In [None]:
# --- IPython: keep the autoreload extension active --- #
from IPython import get_ipython
# First execution loads the extension; re-running this cell reloads it instead.
get_ipython().run_line_magic(
    'reload_ext'
    if 'IPython.extensions.autoreload' in get_ipython().extension_manager.loaded
    else 'load_ext',
    'autoreload',
)
# Same effect as the `%autoreload 2` line magic.
get_ipython().run_line_magic('autoreload', '2')

# --- Put the parent directory on the import path --- #
import os
import sys
REPO_PATH = os.path.abspath('..')
if REPO_PATH not in sys.path:
    sys.path.append(REPO_PATH)
print(f"REPO_PATH: {REPO_PATH}")

# --- Silence every warning for the rest of the session --- #
import warnings
warnings.filterwarnings("ignore")

In [None]:
# --- Imports --- #
import pandas as pd
# from datasets import load_dataset

In [None]:
import os
import pandas as pd
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer

# ==== Paths (edit these before running) ====
train_path = "/path/to/train.csv"
valid_path = "/path/to/valid.csv"
test_path = "/path/to/test.csv"
output_dir = "/mnt/data/thaisum_finetune"

# ==== Tokenizer used by the formatting step below ====
tokenizer = AutoTokenizer.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct",
    trust_remote_code=True,
)

# ==== Read each split into its own DataFrame ====
train_df = pd.read_csv(train_path)
valid_df = pd.read_csv(valid_path)
test_df = pd.read_csv(test_path)

# ==== Bundle the three frames into one DatasetDict keyed by split name ====
dataset = DatasetDict({
    split_name: Dataset.from_pandas(frame)
    for split_name, frame in (
        ("train", train_df),
        ("validation", valid_df),
        ("test", test_df),
    )
})

# ==== Formatting function ====
def format_llada_prompt(example, max_length=1024, tok=None):
    """Turn one (body, summary) record into padded token ids plus the prompt length.

    The prompt is the user turn *including* the assistant header; everything
    after it (summary, end marker, padding) is the response region.

    Args:
        example: Mapping with a ``'body'`` entry (text to summarise) and a
            ``'summary'`` entry (target text).
        max_length: Length every ``input_ids`` list is padded/truncated to.
        tok: Tokenizer callable to use. Defaults to the module-level
            ``tokenizer`` when omitted.

    Returns:
        dict with ``"input_ids"`` (list of ``max_length`` token ids) and
        ``"prompt_length"`` (number of leading tokens that belong to the
        prompt, never larger than ``max_length``).
    """
    tok = tokenizer if tok is None else tok

    # NOTE(review): the role markers below are written as plain text. Confirm the
    # tokenizer maps them to the intended special tokens (e.g. compare with the
    # output of `tokenizer.apply_chat_template`) — TODO verify.
    # Build the prompt on its own rather than searching the full string for the
    # assistant marker: a search could match text inside the body, and stopping
    # *before* the marker would count the assistant header as response tokens.
    prompt_text = f"<start_id>user<end_id>\n{example['body']}<eot_id><start_id>assistant<end_id>\n"
    full_text = f"{prompt_text}{example['summary']}<EOS>"

    tokenized = tok(full_text, padding="max_length", truncation=True, max_length=max_length)
    prompt_tokens = tok(prompt_text)["input_ids"]

    return {
        "input_ids": tokenized["input_ids"],
        # A very long body can be truncated before the response starts; clamp so
        # the stored length never points past the end of `input_ids`.
        "prompt_length": min(len(prompt_tokens), max_length),
    }

# ==== Tokenise every split and write it out as <split>.jsonl ====
os.makedirs(output_dir, exist_ok=True)

for split in dataset.keys():
    print(f"Processing {split} split...")
    # Resolve the destination first, then format and dump the split.
    output_path = os.path.join(output_dir, f"{split}.jsonl")
    processed = dataset[split].map(format_llada_prompt)
    processed.to_json(output_path)
    print(f"✅ Saved: {output_path}")


In [None]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer, AdamW
from tqdm import tqdm
import torch.nn.functional as F

# ==== Run settings ====
model_name = "GSAI-ML/LLaDA-8B-Instruct"
train_file = "/mnt/data/thaisum_finetune/train.jsonl"
mask_token_id = 126336  # id written over positions the model must predict
batch_size = 2
num_epochs = 1  # Adjust as needed
lr = 1e-5

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ==== Training data (JSON file with one "train" split) ====
dataset = load_dataset("json", data_files={"train": train_file})["train"]

# ==== Tokenizer and model, both from the same checkpoint ====
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
model = model.to(device)
model.train()

# ==== Collate function ====
def collate_fn(batch):
    """Stack a list of records into batch tensors.

    Args:
        batch: Sequence of mappings, each with ``"input_ids"`` and
            ``"prompt_length"`` entries.

    Returns:
        dict with ``"input_ids"`` (one row per record) and ``"prompt_lengths"``
        (one entry per record), both as ``torch.Tensor``.
    """
    id_rows = []
    lengths = []
    for record in batch:
        id_rows.append(record["input_ids"])
    for record in batch:
        lengths.append(record["prompt_length"])
    return {
        "input_ids": torch.tensor(id_rows),
        "prompt_lengths": torch.tensor(lengths),
    }

# ==== DataLoader ====
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# ==== Optimizer ====
# NOTE(review): `AdamW` is the name imported from `transformers` in this cell; that
# implementation is deprecated upstream in favour of `torch.optim.AdamW` — consider
# switching the import.
optimizer = AdamW(model.parameters(), lr=lr)

# ==== Training loop ====
# Lower bound on the per-sequence mask ratio, so the 1 / p_mask weight stays finite.
min_mask_ratio = 1e-3

for epoch in range(num_epochs):
    pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}")
    for batch in pbar:
        input_ids = batch["input_ids"].to(device)
        prompt_lengths = batch["prompt_lengths"].to(device)
        num_seqs, seq_len = input_ids.shape

        # True for every position at or after the end of the prompt (the response
        # region, including any padding that follows the summary).
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        is_response = positions >= prompt_lengths.unsqueeze(1)

        # Sample one mask ratio per sequence instead of always masking the whole
        # response, so the model also trains on partially revealed targets.
        t = torch.rand(num_seqs, device=input_ids.device)
        p_mask = ((1 - min_mask_ratio) * t + min_mask_ratio).unsqueeze(1).expand(num_seqs, seq_len)

        # Noise is applied to response tokens only; the prompt is never masked.
        mask_index = (torch.rand(num_seqs, seq_len, device=input_ids.device) < p_mask) & is_response
        if not mask_index.any():
            # Nothing was masked in this batch, so there is nothing to learn from.
            continue
        noisy_batch = input_ids.masked_fill(mask_index, mask_token_id)

        # Per-sequence response length, used to keep the loss scale independent
        # of how long the response region is. clamp() guards against a prompt
        # that fills the whole sequence.
        answer_lengths = is_response.sum(dim=1, keepdim=True).clamp(min=1).expand(num_seqs, seq_len)

        logits = model(input_ids=noisy_batch).logits

        # Cross-entropy in float32 (the model runs in bfloat16), re-weighted by
        # the inverse mask probability of each masked token.
        token_loss = F.cross_entropy(
            logits[mask_index].float(), input_ids[mask_index], reduction="none"
        ) / p_mask[mask_index]

        loss = (token_loss / answer_lengths[mask_index]).sum() / num_seqs

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        pbar.set_postfix(loss=loss.item())
