In [1]:
%pip install pandas torch datasets tqdm transformers

Note: you may need to restart the kernel to use updated packages.


In [1]:
# -- iPython Config --
from IPython import get_ipython
if "IPython.extensions.autoreload" not in get_ipython().extension_manager.loaded:
    get_ipython().run_line_magic("load_ext", "autoreload")
else:
    get_ipython().run_line_magic("reload_ext", "autoreload")
%autoreload 2

# -- System and Path --
import os
import sys
REPO_PATH = os.path.abspath(os.path.join(".."))
if REPO_PATH not in sys.path:
    sys.path.append(REPO_PATH)
print(f"REPO_PATH: {REPO_PATH}")
import warnings
warnings.filterwarnings("ignore")

REPO_PATH: /root


In [1]:
# -- Imports --
import os
import pandas as pd
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# -- Configuration --
class Config:
    """Holds all configuration parameters for the script."""
    def __init__(self):
        self.repo_path = os.path.abspath(os.path.join(".."))
        self.data_dir = self.repo_path
        self.tokenized_data_dir = os.path.join(self.repo_path, "tokenized")
        self.model_name = "GSAI-ML/LLaDA-8B-Instruct"
        self.batch_size = 2
        self.lr = 1e-5
        self.num_epochs = 1
        self.seed = 42
        self.mask_token_id = 126336
        self.device = self._select_device()

    def _select_device(self):
        """Selects the best available device (CUDA, MPS, or CPU)."""
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        print(f"Using device: {device}")
        return device
config = Config()

Using device: cuda


# Training

In [3]:
# Data Loading and Preprocessing
def load_dataset_from_csv(train_file: str = None, valid_file: str = None, test_file: str = None, sample_size: int = None) -> DatasetDict:
    """Loads dataset splits from CSV files and optionally samples rows."""
    split_files = {"train": train_file, "validation": valid_file, "test": test_file}
    dct = {}
    for split, file_path in split_files.items():
        if file_path:
            df = pd.read_csv(file_path)
            if sample_size:
                df = df.sample(sample_size)
            dct[split] = Dataset.from_pandas(df)
    return DatasetDict(dct)

def format_llada_prompt(example, tokenizer):
    """Formats an example into a prompt for the LLaDA model and tokenizes it."""
    instruction = f"<start_id>user<end_id>\nสรุปข้อความต่อไปนี้\n{example['body']}<eot_id><start_id>assistant<end_id>\n{example['summary']}<EOS>"
    tokenized = tokenizer(instruction, padding="max_length", truncation=True, max_length=2048)
    prompt_end = instruction.find("<start_id>assistant<end_id>")
    prompt_tokens = tokenizer(instruction[:prompt_end])["input_ids"]
    return {"input_ids": tokenized["input_ids"], "prompt_length": len(prompt_tokens)}

def load_and_preprocess_data(config: Config, sample_size: int = 100) -> tuple:
    """Loads and preprocesses the dataset, saving tokenized data to disk."""
    train_data_file = "train-7000-1024.csv" #os.path.join("train100.csv")
    dataset_dict = load_dataset_from_csv(train_file=train_data_file, sample_size=sample_size)
    tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)

    os.makedirs(config.tokenized_data_dir, exist_ok=True)
    processed_data = dataset_dict["train"].map(lambda x: format_llada_prompt(x, tokenizer))
    output_path = os.path.join(config.tokenized_data_dir, "train.jsonl")
    processed_data.to_json(output_path)
    print(f"Saved tokenized data to: {output_path}")
    return processed_data, tokenizer

# Model Loading
def load_model(config: Config) -> AutoModel:
    """Loads the LLaDA model and prepares it for training."""
    print(f"Loading {config.model_name} model...")
    model = AutoModel.from_pretrained(config.model_name, trust_remote_code=True, torch_dtype=torch.bfloat16)
    model.to(config.device)
    model.train()
    print(f"{config.model_name} model loaded successfully.")
    return model

# DataLoader Collate Function
def collate_fn(batch):
    """Prepares a batch for training by converting to tensors."""
    input_ids = torch.tensor([item["input_ids"] for item in batch])
    prompt_lengths = torch.tensor([item["prompt_length"] for item in batch])
    return {"input_ids": input_ids, "prompt_lengths": prompt_lengths}

# Training Function
def train_model(model: AutoModel, dataloader: DataLoader, optimizer: AdamW, num_epochs: int, device: str, mask_token_id: int):
    """Trains the model using a masked language modeling approach."""
    for epoch in range(num_epochs):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        for batch in pbar:
            input_ids = batch["input_ids"].to(device)
            prompt_lengths = batch["prompt_lengths"].to(device)

            # Create noisy batch by masking post-prompt tokens
            noisy_batch = input_ids.clone()
            for i in range(noisy_batch.shape[0]):
                noisy_batch[i, prompt_lengths[i]:] = mask_token_id
            mask_index = (noisy_batch == mask_token_id)

            # Forward pass
            logits = model(input_ids=noisy_batch).logits
            p_mask = torch.ones_like(noisy_batch, dtype=torch.float32).to(device)

            # Compute loss only on masked tokens
            token_loss = F.cross_entropy(logits[mask_index], input_ids[mask_index], reduction="none") / p_mask[mask_index]
            loss = token_loss.sum() / input_ids.shape[0]

            # Backward pass
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            pbar.set_postfix(loss=loss.item())


In [4]:
# Load and preprocess data
processed_data, tokenizer = load_and_preprocess_data(config, sample_size=4000)
# Load model
model = load_model(config)
# Prepare DataLoader
dataloader = DataLoader(
    processed_data, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn
)
# Train model
optimizer = AdamW(model.parameters(), lr=config.lr)
train_model(
    model, dataloader, optimizer, config.num_epochs, config.device, config.mask_token_id
)

Map: 100%|██████████| 100/100 [00:00<00:00, 130.97 examples/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 27.28ba/s]


Saved tokenized data to: /root/tokenized/train.jsonl
Loading GSAI-ML/LLaDA-8B-Instruct model...


Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00, 43.24it/s]


GSAI-ML/LLaDA-8B-Instruct model loaded successfully.


Epoch 1/1: 100%|██████████| 50/50 [01:03<00:00,  1.28s/it, loss=0]      


# Inference

In [5]:
# Generation Functions
def add_gumbel_noise(logits, temperature):
    """Adds Gumbel noise to logits for sampling."""
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise

def get_num_transfer_tokens(mask_index, steps):
    """Calculates the number of tokens to transfer at each generation step."""
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1
    return num_transfer_tokens

@torch.no_grad()
def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=126336):
    """Generates text using the trained model."""
    x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()
    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    for num_block in range(num_blocks):
        start_idx = prompt.shape[1] + num_block * block_length
        end_idx = start_idx + block_length
        block_mask_index = (x[:, start_idx:end_idx] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        for i in range(steps_per_block):
            mask_index = (x == mask_id)
            if cfg_scale > 0:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            logits_with_noise = add_gumbel_noise(logits, temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)

            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(torch.gather(p, dim=-1, index=x0.unsqueeze(-1)), -1)
            elif remasking == 'random':
                x0_p = torch.rand_like(x0, device=x0.device)
            else:
                raise NotImplementedError(f"Remasking strategy '{remasking}' not implemented.")

            x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x

In [7]:
model.eval()
# Generate summary
prompt_text = (
    "สรุปข้อความต่อไปนี้\n ความเก่ง เกิดขึ้นได้หลายแบบไม่ว่าจะ "
    "ความหมั่นเพียร(ฝึกซ้อม), ประสบการณ์, สิ่งแวดล้อมเกื้อหนุน, มีต้นทุนบางอย่างดี "
    "เหมือนคนเกิดมาร่างกายสูงใหญ่มีโอกาสเก่งในกีฬาหลายประเภท นี่ก็ถือว่าต้นทุนดี "
    "แต่เหล่านี้เองจึงย้อนไปบั่นทอนคนที่คิดว่าตนไม่เก่ง เช่น เราขี้เกียจ-ไม่มีเวลาซ้อม, "
    "เราไม่เคยทำมาก่อน, ยังไม่พร้อม, ต้นทุนไม่ดีเหมือนเขา ส่วนหนึ่งก็ใช่ว่าผิด "
    "แต่แน่นอนไม่ถูก และกลายเป็นถ่วงอนาคตอย่างมาก"
)
messages = [{"role": "user", "content": prompt_text}]
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)
input_ids = torch.tensor(tokenizer(prompt)["input_ids"]).to(config.device).unsqueeze(0)
output_ids = generate(
    model,
    input_ids,
    steps=128,
    gen_length=128,
    block_length=32,
    temperature=0.0,
    cfg_scale=0.0,
    remasking="low_confidence",
    mask_id=config.mask_token_id,
)
summary = tokenizer.batch_decode(
    output_ids[:, input_ids.shape[1] :], skip_special_tokens=True
)[0]
print("Generated Summary:")
print(summary)

Generated Summary:
ความเก่ง เกิดขึ้นได้หลายแบบไม่ว่าจะ ความหมั่นเพียร(ฝึกซ้อม), ประสบการณ์, สิ่งแวดล้อมเกื้อหนุน, มีต้นทุนบางอย่างดี เหมือนเขา


In [None]:
from huggingface_hub import login
login(token="")
model.push_to_hub("pupipatsk/llada-thaisum-finetuned")

model-00001-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]
model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s][A

Upload 4 LFS files:   0%|          | 0/4 [00:00<?, ?it/s][A[A


model-00003-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s][A[A[A



model-00001-of-00004.safetensors:   0%|          | 4.11M/5.00G [00:00<02:01, 41.0MB/s][A
model-00002-of-00004.safetensors:   0%|          | 3.80M/5.00G [00:00<02:11, 38.0MB/s][A


model-00003-of-00004.safetensors:   0%|          | 4.21M/5.00G [00:00<01:59, 42.0MB/s][A[A[A



model-00004-of-00004.safetensors:   0%|          | 4.00M/1.04G [00:00<00:25, 39.8MB/s][A[A[A[A
model-00002-of-00004.safetensors:   0%|          | 9.55M/5.00G [00:00<08:47, 9.47MB/s][A


model-00001-of-00004.safetensors:   0%|          | 8.22M/5.00G [00:01<11:46, 7.06MB/s][A[A[A
model-00002-of-00004.safetensors:   0%|          | 14.5M/5.00G [00:01<05:44, 14.5MB/s][A


model-00001-of-00004.saf