In [None]:
# -- iPython Config --
from IPython import get_ipython

# Load the autoreload extension the first time this cell runs and reload it on
# later runs, then switch autoreload to mode 2.
get_ipython().run_line_magic(
    "reload_ext"
    if "IPython.extensions.autoreload" in get_ipython().extension_manager.loaded
    else "load_ext",
    "autoreload",
)
get_ipython().run_line_magic("autoreload", "2")

# -- System and Path --
import os
import sys

# The parent of the current working directory is treated as the repository
# root and appended to sys.path once.
REPO_PATH = os.path.abspath("..")
if REPO_PATH not in sys.path:
    sys.path.append(REPO_PATH)
print(f"REPO_PATH: {REPO_PATH}")

import warnings

# Silences every warning category for the rest of the session.
warnings.filterwarnings("ignore")


In [None]:
# -- Imports --
import os
import pandas as pd
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from datasets import Dataset, DatasetDict, load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer


In [None]:
# -- Configuration --
class Config:
    """Container for the notebook-wide settings.

    Attributes are set on the instance at construction time:
    REPO_PATH mirrors the module-level REPO_PATH and SEED is fixed to 42.
    """

    def __init__(self) -> None:
        self.REPO_PATH = REPO_PATH  # repository root resolved in the setup cell
        self.SEED = 42


config = Config()

# -- device
def select_device():
    """Choose the torch backend to run on and report it.

    CUDA is preferred; Apple MPS is only probed when CUDA is unavailable;
    CPU is the fallback.

    Returns:
        One of the strings "cuda", "mps" or "cpu".
    """
    if torch.cuda.is_available():
        chosen = "cuda"
    else:
        chosen = "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {chosen}")
    return chosen


device = select_device()


In [None]:
# -- datasets
data_dir = os.path.join(config.REPO_PATH, "data", "thaisum", "raw")
train_data_file = os.path.join(data_dir, "train.csv")
# valid_data_file = os.path.join(data_dir, "valid.csv")
# test_data_file = os.path.join(data_dir, "test.csv")
def load_dataset_from_csv(
    train_file: str = None,
    valid_file: str = None,
    test_file: str = None,
    nrows: int = 100,
) -> DatasetDict:
    """Read the given CSV files into a DatasetDict, one entry per split.

    Args:
        train_file: Path of the CSV stored under the "train" key.
        valid_file: Path of the CSV stored under the "validation" key.
        test_file: Path of the CSV stored under the "test" key.
        nrows: Number of leading rows read from each CSV. Defaults to 100,
            the demonstration sample size; pass None to read whole files.

    Returns:
        A DatasetDict holding only the splits whose path was provided
        (splits left as None or empty are skipped).
    """
    split_files = {
        "train": train_file,
        "validation": valid_file,
        "test": test_file,
    }

    dct = {}
    for split, file_path in tqdm(split_files.items(), desc="Loading CSV splits"):
        if not file_path:
            continue
        # ! [Sample] only the first `nrows` rows are read (100 by default)
        df = pd.read_csv(file_path, nrows=nrows)
        dct[split] = Dataset.from_pandas(df)
    return DatasetDict(dct)
dataset_dict = load_dataset_from_csv(train_file=train_data_file)

In [None]:
# -- Tokenizer --
MODEL_NAME = "GSAI-ML/LLaDA-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# -- Formatting --
def format_llada_prompt(example, max_length=1024):
    """Turn one row into padded token ids plus the length of its prompt part.

    The prompt is the user turn followed by the assistant header; the
    response is the summary followed by the end marker. The assistant header
    is counted as prompt so that only the summary (and whatever follows it in
    the padded sequence) lies beyond ``prompt_length``.

    Args:
        example: Mapping with "body" and "summary" entries.
        max_length: Length every sequence is padded or truncated to.

    Returns:
        A dict with "input_ids" (the token ids of the full sequence) and
        "prompt_length" (number of prompt tokens, never above ``max_length``).
    """
    # NOTE(review): the tags below are passed as plain text. Confirm the
    # tokenizer maps them to its special tokens; if it does not, build the
    # text with tokenizer.apply_chat_template instead.
    prompt = (
        f"<start_id>user<end_id>\n{example['body']}<eot_id>"
        "<start_id>assistant<end_id>\n"
    )
    instruction = f"{prompt}{example['summary']}<EOS>"

    tokenized = tokenizer(
        instruction, padding="max_length", truncation=True, max_length=max_length
    )

    # The prompt is tokenized on its own, so the boundary is approximate if the
    # tokenizer merges characters across the prompt/response seam.
    prompt_tokens = tokenizer(prompt)["input_ids"]

    # The prompt is tokenized without truncation, so a long body can yield more
    # prompt tokens than the truncated sequence holds. Clamp so the length
    # always indexes into "input_ids"; such rows carry no response tokens.
    prompt_length = min(len(prompt_tokens), max_length)
    return {"input_ids": tokenized["input_ids"], "prompt_length": prompt_length}

# -- Process and Save --
# Tokenized splits are written next to the raw data, one JSON file per split.
output_dir = os.path.join(config.REPO_PATH, "data", "thaisum", "tokenized")
os.makedirs(output_dir, exist_ok=True)

for split, split_dataset in tqdm(dataset_dict.items(), desc="Processing splits"):
    print(f"Processing {split} split...")
    output_path = os.path.join(output_dir, f"{split}.jsonl")

    # Apply the prompt formatter to every row, then persist the result.
    processed_data = split_dataset.map(format_llada_prompt)
    processed_data.to_json(output_path)
    print(f"Saved: {output_path}")


In [None]:
# -- Training hyperparameters
num_epochs = 1
batch_size = 2
lr = 1e-5
mask_token_id = 126336  # token id written into masked positions by the training loop

# -- Load tokenized dataset
train_file = os.path.join(
    config.REPO_PATH, "data", "thaisum", "tokenized", "train.jsonl"
)
train_dataset = load_dataset("json", data_files={"train": train_file})["train"]

# -- Load the model
print(f"Loading {MODEL_NAME} model...")
model = AutoModel.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
print(f"{MODEL_NAME} model load successfully.")

# Move the weights to the selected device and switch to training mode.
model.to(device)
model.train()

# ==== Collate function ====
def collate_fn(batch):
    """Combine a list of tokenized examples into batch tensors.

    Args:
        batch: Sequence of mappings, each with an "input_ids" list and a
            "prompt_length" integer.

    Returns:
        A dict with "input_ids" (one row per example) and "prompt_lengths"
        (one entry per example), both as torch tensors.
    """
    token_rows = []
    lengths = []
    for example in batch:
        token_rows.append(example["input_ids"])
        lengths.append(example["prompt_length"])
    return {
        "input_ids": torch.tensor(token_rows),
        "prompt_lengths": torch.tensor(lengths),
    }

# ==== DataLoader ====
# Shuffle with a generator seeded from config.SEED so the batch order is the
# same on every run. The generator is rebuilt each time this code executes, so
# re-running it restarts the same order.
shuffle_generator = torch.Generator().manual_seed(config.SEED)
dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    generator=shuffle_generator,
)

# ==== Optimizer ====
optimizer = AdamW(model.parameters(), lr=lr)

# ==== Training loop ====
# Every position at or after the prompt boundary is replaced by the mask token
# and the model is trained to recover the original ids at those positions.
# NOTE(review): the whole response is masked on every step, so the 1/p_mask
# weight of the masked-diffusion loss is identically 1 and is omitted. The
# reference LLaDA SFT recipe samples a random masking ratio per sequence;
# confirm that always masking everything is intended.
for epoch in range(num_epochs):
    pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}")
    for batch in pbar:
        input_ids = batch["input_ids"].to(device)
        prompt_lengths = batch["prompt_lengths"].to(device)

        # Boolean (batch, seq) mask: True where the position index is at or
        # past that row's prompt length.
        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        mask_index = positions.unsqueeze(0) >= prompt_lengths.unsqueeze(1)

        # Nothing to predict when every row's prompt fills the sequence.
        if not mask_index.any():
            continue

        # Mask everything except the prompt
        noisy_batch = input_ids.masked_fill(mask_index, mask_token_id)

        logits = model(input_ids=noisy_batch).logits

        # The model is loaded in bfloat16; compute the loss in float32 to
        # avoid losing precision in the softmax.
        token_loss = F.cross_entropy(
            logits[mask_index].float(), input_ids[mask_index], reduction="none"
        )

        # Weight each token by 1 / (number of masked tokens in its row) so the
        # loss is a per-sequence mean and does not grow with response length.
        # The clamp guards rows without masked tokens (they select nothing).
        answer_lengths = mask_index.sum(dim=1, keepdim=True).clamp(min=1)
        token_weight = (1.0 / answer_lengths).expand_as(mask_index)[mask_index]

        loss = (token_loss * token_weight).sum() / input_ids.shape[0]

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        pbar.set_postfix(loss=loss.item())
