In [None]:
import pandas as pd
import torch
import random
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.amp import autocast, GradScaler
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Custom Dataset class for ThaiSum with instruction prompting
class ThaiSumDataset(Dataset):
    """ThaiSum rows formatted as instruction-prompted causal-LM examples.

    Each item is the tokenized prompt (Thai instruction + article body),
    padded/truncated to ``max_input_length``, followed by the tokenized
    summary, padded/truncated to ``max_target_length``. Only real summary
    tokens are supervised; the prompt and all padding are labelled -100.

    Args:
        dataframe: table with ``'body'`` (article) and ``'summary'``
            (reference summary) columns; values are passed through ``str``.
        tokenizer: callable tokenizer returning ``input_ids`` (and ideally
            ``attention_mask``) as ``(1, length)`` tensors; must expose
            ``pad_token_id``.
        max_input_length: fixed token length of the prompt segment.
        max_target_length: fixed token length of the summary segment.
    """

    def __init__(self, dataframe, tokenizer, max_input_length=512, max_target_length=128):
        self.texts = dataframe['body'].tolist()
        self.summaries = dataframe['summary'].tolist()
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length
        self.instruction = "สรุปข้อความนี้เป็นภาษาไทย:"  # Thai: "Summarize this text in Thai:"

    def __len__(self):
        return len(self.texts)

    def _attention_mask(self, encoding, ids):
        """Return the 1-D attention mask for ``ids`` (derived from pad id if absent)."""
        mask = encoding.get('attention_mask')
        if mask is None:
            return ids.ne(self.tokenizer.pad_token_id).long()
        return mask[0]

    def __getitem__(self, idx):
        """Return ``input_ids``, ``labels`` and ``attention_mask`` for row ``idx``.

        All three are 1-D tensors of length
        ``max_input_length + max_target_length``. ``labels`` equals
        ``input_ids`` on real summary tokens and -100 everywhere else.
        """
        text = str(self.texts[idx])
        summary = str(self.summaries[idx])

        # Prepend instruction to input text; the summary is the target.
        input_text = f"{self.instruction} {text}"
        target_text = f"{summary}"

        input_encoding = self.tokenizer(
            input_text,
            max_length=self.max_input_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        target_encoding = self.tokenizer(
            target_text,
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # [0] (not .squeeze()) keeps a 1-D tensor even when a length is 1.
        prompt_ids = input_encoding['input_ids'][0]
        target_ids = target_encoding['input_ids'][0]
        prompt_mask = self._attention_mask(input_encoding, prompt_ids)
        target_mask = self._attention_mask(target_encoding, target_ids)

        input_ids = torch.cat([prompt_ids, target_ids], dim=0)
        attention_mask = torch.cat([prompt_mask, target_mask], dim=0)

        labels = input_ids.clone()
        labels[:prompt_ids.shape[0]] = -100  # no loss on the prompt
        # No loss on padding either; otherwise the model learns to emit pads.
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
        }

# q_t_given_0 function from LLADA (adapted for input)
def q_t_given_0(input_ids, mask_token_id, t, N, tokenizer):
    """Sample from the forward noising process q_{t|0}.

    Every token that is not ``tokenizer.pad_token_id`` is independently
    replaced by ``mask_token_id`` with probability ``t / N``; pad tokens are
    never touched. The input tensor is not modified.

    Returns:
        ``(noised_ids, mask_positions)``, where ``mask_positions`` is a bool
        tensor shaped like ``input_ids`` marking the replaced tokens.
    """
    mask_prob = t / N
    protected = input_ids == tokenizer.pad_token_id
    # Draw on the default (CPU) generator, then move to the ids' device.
    coin_flips = torch.bernoulli(torch.full(input_ids.shape, mask_prob))
    chosen = coin_flips.bool().to(input_ids.device)
    selected = chosen & protected.logical_not()

    noised = input_ids.clone()
    noised[selected] = mask_token_id
    return noised, selected

# Training step with LLADA diffusion and mixed precision
def training_step(model, tokenizer, batch, N, scaler):
    """Run one mixed-precision forward/backward pass with noise on the prompt.

    A single noise level ``t`` is drawn uniformly from ``1..N`` for the whole
    batch. Prompt tokens (the leading positions whose label is -100) are
    replaced by ``tokenizer.pad_token_id`` with probability ``t / N`` via
    ``q_t_given_0``; the supervised summary tokens are left intact.

    The caller owns ``optimizer.zero_grad()``, ``scaler.step(optimizer)``
    and ``scaler.update()``.

    Args:
        model: causal LM accepting ``input_ids``/``attention_mask``/``labels``
            and returning an object with ``.loss``.
        tokenizer: provides ``pad_token_id`` (also used as the mask token).
        batch: dict with ``'input_ids'`` and ``'labels'`` tensors of shape
            ``(batch, seq)`` (a single ``(seq,)`` example also works); an
            optional ``'attention_mask'`` is forwarded to the model.
        N: number of noise levels.
        scaler: ``GradScaler`` used to scale the loss before ``backward()``.

    Returns:
        The unscaled loss tensor.
    """
    device = model.device
    input_ids = batch['input_ids'].to(device)
    labels = batch['labels'].to(device)
    attention_mask = batch.get('attention_mask')
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    # Prompt length per row = number of leading -100 labels. The old formula
    # used input_ids.shape[0] (the batch size) minus a batch-wide count, which
    # is negative for batched input, so no prompt token was ever noised.
    unsupervised = labels.eq(-100).long()
    prompt_len = unsupervised.cumprod(dim=-1).sum(dim=-1)
    positions = torch.arange(labels.shape[-1], device=labels.device)
    prompt_region = positions < prompt_len.unsqueeze(-1)

    t = random.randint(1, N)
    noised_ids, _ = q_t_given_0(input_ids, tokenizer.pad_token_id, t, N, tokenizer)
    # torch.where returns a new tensor, so the caller's batch is not mutated.
    input_ids = torch.where(prompt_region, noised_ids, input_ids)

    # Mixed precision forward pass
    with autocast(device_type='cuda'):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

    # Scale loss and backpropagate
    scaler.scale(loss).backward()
    return loss

# Training loop with mixed precision
def train_summarization_llada(model, tokenizer, dataset, optimizer, epochs=3, N=10, batch_size=8):
    """Fine-tune ``model`` on ``dataset`` using ``training_step`` per batch.

    Batches are shuffled every epoch; gradients go through a ``GradScaler``
    for mixed precision. After each epoch the mean per-batch loss is printed.
    """
    model.train()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    grad_scaler = GradScaler()

    for epoch in range(epochs):
        running = 0.0
        progress = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in progress:
            optimizer.zero_grad()
            step_loss = training_step(model, tokenizer, batch, N, grad_scaler)
            grad_scaler.step(optimizer)
            grad_scaler.update()
            running += step_loss.item()
        mean_loss = running / len(loader)
        print(f"[Epoch {epoch+1}] Avg Loss: {mean_loss:.4f}")

def main():
    """Fine-tune facebook/xglm-564M on the ThaiSum CSV and save the result."""
    # Hyperparameters and artifact locations.
    model_name = 'facebook/xglm-564M'  # ~564M parameters, supports Thai
    output_dir = './thai_summarization_llada_xglm564m'
    epochs = 1
    batch_size = 2
    N = 10  # number of noise levels for q_t_given_0
    learning_rate = 2e-5

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    frame = pd.read_csv('/kaggle/input/thaisum-train-10000-1024-nlpfinal/train-10000-1024.csv')
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    train_set = ThaiSumDataset(frame, tokenizer, max_input_length=1024, max_target_length=256)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    train_summarization_llada(
        model, tokenizer, train_set, optimizer,
        epochs=epochs, N=N, batch_size=batch_size,
    )

    for artifact in (model, tokenizer):
        artifact.save_pretrained(output_dir)
    print("Model and tokenizer saved to './thai_summarization_llada_xglm564m'")

# Run the training pipeline when executed directly; inside a notebook cell
# __name__ is also "__main__", so executing the cell starts training.
if __name__ == "__main__":
    main()

Using device: cuda


tokenizer_config.json:   0%|          | 0.00/433 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/4.92M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.03M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/276 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/546 [00:00<?, ?B/s]

2025-05-08 13:55:58.483575: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746712558.712118      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746712558.777264      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


pytorch_model.bin:   0%|          | 0.00/1.13G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.13G [00:00<?, ?B/s]

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.nn.functional import softmax

def infer_llada(input_text, model, tokenizer, L=128, N=10):
    """Summarize ``input_text`` with the fine-tuned causal LM.

    The prompt is built exactly as in training (Thai instruction followed by
    the text, truncated to 512 tokens). The summary is produced by the
    model's ``generate`` with a single beam.

    Args:
        input_text: Thai source text to summarize.
        model: causal LM; it is moved to the available device and set to eval.
        tokenizer: matching tokenizer; ``pad_token_id`` is passed to generate.
        L: maximum number of summary tokens to generate.
        N: accepted for backward compatibility and unused. The previous
            iterative unmask/remask loop could not run (it passed ``L``-length
            labels against a 512-token input) and its result never fed the
            returned summary.

    Returns:
        The decoded summary, excluding the prompt tokens.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    # Prepend instruction to input text (same prompt as training).
    instruction = "สรุปข้อความนี้เป็นภาษาไทย:"
    prompt = f"{instruction} {input_text}"

    # No padding: for a single prompt, right-padding would make generation
    # continue after pad tokens. Truncation still caps it at 512 tokens.
    input_encoding = tokenizer(
        prompt,
        max_length=512,
        truncation=True,
        return_tensors='pt'
    ).to(device)
    prompt_ids = input_encoding['input_ids']
    prompt_len = prompt_ids.shape[1]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=prompt_ids,
            attention_mask=input_encoding.get('attention_mask'),
            # max_length would count the prompt tokens as well.
            max_new_tokens=L,
            num_beams=1,
            pad_token_id=tokenizer.pad_token_id,
        )

    # generate() returns prompt + continuation; decode only the continuation.
    return tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)

def main(model_path='./thai_summarization_llada_xglm564m'):
    """Load the fine-tuned summarizer and print a summary of a sample text.

    Args:
        model_path: directory containing the saved model and tokenizer. It
            defaults to the directory the training cell saves to. The former
            hard-coded './thai_summarization_llada_qwen05b' is not where that
            cell writes its output.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)

    # Example input text
    input_text = "นี่คือตัวอย่างข้อความภาษาไทยที่ยาวมากเกี่ยวกับข่าวประจำวัน ซึ่งมีรายละเอียดเกี่ยวกับเหตุการณ์สำคัญในประเทศไทย เช่น การเมือง เศรษฐกิจ และวัฒนธรรม"

    # Perform inference
    summary = infer_llada(input_text, model, tokenizer, L=128, N=10)
    print("Summary:", summary)

# Run the inference demo when executed directly; inside a notebook cell
# __name__ is also "__main__", so executing the cell runs it.
if __name__ == "__main__":
    main()