<a href="https://colab.research.google.com/github/yilmajung/LLM_POC_Study_2025_v2/blob/main/finetune_TLLM_abortion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive; the data files, checkpoints and saved
# adapters referenced in later cells all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch, json, random
from huggingface_hub import login

# Read the Hugging Face token from the Colab secret named 'HF_TOKEN' (so it is
# never hard-coded in the notebook) and log in; the same token is passed to the
# tokenizer/model `from_pretrained` calls below.
from google.colab import userdata
hf_token = userdata.get('HF_TOKEN')
login(hf_token)

# Special tokens standing for the answer options. Index j of a record's
# `to_dist` refers to OPT_TOKENS[j], so this order is part of the data contract.
OPT_TOKENS = [
  "<OPT_STRONG_ANTI>", "<OPT_ANTI>", "<OPT_NEUTRAL>", "<OPT_PRO>", "<OPT_STRONG_PRO>"
]

model_name = "meta-llama/llama-3.1-8b" # or "mistralai/Mistral-7B-v0.3" or "QWEN 2.5"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, token = hf_token)
# Register the option tokens as additional special tokens of the tokenizer.
tokenizer.add_special_tokens({"additional_special_tokens": OPT_TOKENS})

# `dtype` replaces the deprecated `torch_dtype` keyword (the recorded run of
# this cell printed "`torch_dtype` is deprecated! Use `dtype` instead!").
model = AutoModelForCausalLM.from_pretrained(
    model_name, dtype=torch.bfloat16, device_map="auto", token=hf_token
)
# Grow the embedding matrix / LM head so the new option tokens have rows.
model.resize_token_embeddings(len(tokenizer))

# Optional QLoRA
# model = prepare_model_for_kbit_training(model)
# NOTE(review): only the listed projection layers receive LoRA adapters, so the
# freshly initialised embedding / lm_head rows of the option tokens are not
# trained themselves. If that is not intended, consider PEFT's
# `modules_to_save` (or an equivalent) -- confirm memory budget first.
lora_cfg = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj"]
)
model = get_peft_model(model, lora_cfg)

# map special tokens to ids AFTER resize
opt2id = {opt: tokenizer.convert_tokens_to_ids(opt) for opt in OPT_TOKENS}
# Fail fast if any option token did not get its own id (e.g. it silently mapped
# to the unknown token): the KL loss would otherwise compare wrong logits.
if None in opt2id.values() or len(set(opt2id.values())) != len(OPT_TOKENS):
    raise ValueError(f"Option tokens do not map to distinct vocabulary ids: {opt2id}")
# Option ids in OPT_TOKENS order, kept on the model's device for use in the loss.
opt_ids = torch.tensor([opt2id[o] for o in OPT_TOKENS], device=model.device)
# --- Dataset ---
class TLLMRowDataset(torch.utils.data.Dataset):
    """Map-style dataset over a JSONL file of prompt / target-distribution rows.

    Every non-blank line is parsed as one JSON record. ``__getitem__`` reads
    ``prompt_text`` and ``to_dist`` (indexed in ``OPT_TOKENS`` order) and the
    optional ``weight`` (defaults to 1.0).

    Args:
        jsonl_path: Path of the JSONL file to load eagerly into memory.
        shuffle_opts: If true, the option order is re-drawn with
            ``random.shuffle`` on every ``__getitem__`` call.
        max_length: ``max_length`` handed to the tokenizer (1024 by default,
            the value that used to be hard-coded).
    """

    def __init__(self, jsonl_path, shuffle_opts=True, max_length=1024):
        # Use a context manager so the file handle is closed (a bare open()
        # inside the comprehension leaked it) and skip blank lines, which
        # would otherwise make json.loads raise on e.g. a trailing newline.
        with open(jsonl_path, encoding="utf-8") as fh:
            self.recs = [json.loads(line) for line in fh if line.strip()]
        self.shuffle_opts = shuffle_opts
        self.max_length = max_length

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.recs)

    def __getitem__(self, i):
        """Build the tokenized prompt and aligned targets for record ``i``.

        Returns a dict with ``input_ids`` / ``attention_mask`` (1-D tensors),
        ``to_dist`` (float tensor permuted by ``order``), ``weight`` (scalar
        float tensor) and ``order`` (long tensor: the permutation of
        ``OPT_TOKENS`` indices used for this sample).
        """
        r = self.recs[i]
        # Expect fields: prompt_text, to_dist (list float), weight (float)
        prompt = r["prompt_text"]
        # Randomize option order to reduce position bias
        order = list(range(len(OPT_TOKENS)))
        if self.shuffle_opts:
            random.shuffle(order)
        # Reorder target distribution accordingly
        to_dist = [r["to_dist"][j] for j in order]
        opt_tokens_ordered = [OPT_TOKENS[j] for j in order]

        prompt_with_opts = prompt + "Options: " + " ".join(opt_tokens_ordered) + "\nAnswer:\n"
        # NOTE(review): if a prompt exceeds max_length and the tokenizer
        # truncates on the right, the "Options/Answer" suffix is what gets cut
        # -- confirm prompts fit within max_length.
        enc = tokenizer(prompt_with_opts, return_tensors="pt", truncation=True, max_length=self.max_length)
        return {
            "input_ids": enc["input_ids"][0],
            "attention_mask": enc["attention_mask"][0],
            "to_dist": torch.tensor(to_dist, dtype=torch.float),
            "weight": torch.tensor(r.get("weight", 1.0), dtype=torch.float),
            "order": torch.tensor(order, dtype=torch.long),
        }

def data_collate(batch):
    """Right-pad a list of dataset items to a common length and stack them.

    Args:
        batch: List of dicts as produced by ``TLLMRowDataset.__getitem__``
            (keys ``input_ids``, ``attention_mask``, ``to_dist``, ``weight``,
            ``order``).

    Returns:
        Dict with the same keys; ``input_ids`` and ``attention_mask`` are
        ``[B, T_max]``, padded on the right (mask 0 on padding), the remaining
        entries are stacked along a new batch dimension.
    """
    # Fall back to EOS only when no pad token is configured. The previous
    # `pad_token_id or eos_token_id` also discarded a legitimate pad id of 0.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    maxlen = max(x["input_ids"].shape[0] for x in batch)
    input_ids = []
    attn = []
    for x in batch:
        ids, mask = x["input_ids"], x["attention_mask"]
        pad = maxlen - ids.shape[0]
        # new_full / new_zeros inherit dtype and device from the source tensor,
        # so torch.cat never sees mismatched dtypes.
        input_ids.append(torch.cat([ids, ids.new_full((pad,), pad_id)]))
        attn.append(torch.cat([mask, mask.new_zeros(pad)]))
    return {
        "input_ids": torch.stack(input_ids),
        "attention_mask": torch.stack(attn),
        "to_dist": torch.stack([x["to_dist"] for x in batch]),
        "weight": torch.stack([x["weight"] for x in batch]),
        "order": torch.stack([x["order"] for x in batch]),
    }

# --- Custom Trainer with KL loss over option tokens ---
import torch.nn.functional as F
class KLTrainer(Trainer):
    """Trainer whose loss is a weighted forward KL over the option tokens.

    For each sequence the logits at the last attended position are restricted
    to the option-token ids (module-level ``opt_ids``, permuted per example by
    ``order``), normalised over those options only, and compared with the
    target distribution ``to_dist``. Batches are expected to be right-padded,
    as produced by ``data_collate``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the mean of ``weight * KL(p_human || p_llm)`` over the batch.

        Args:
            model: The model being trained; called with ``input_ids`` and
                ``attention_mask`` only.
            inputs: Batch dict with ``input_ids`` [B, T], ``attention_mask``
                [B, T], ``to_dist`` [B, K], ``weight`` [B] and ``order`` [B, K].
            return_outputs: If true, return ``(loss, model_outputs)``.
            num_items_in_batch: Accepted for Trainer API compatibility; unused.
        """
        out = model(input_ids=inputs["input_ids"].to(model.device),
                    attention_mask=inputs["attention_mask"].to(model.device))
        logits = out.logits
        # Everything below lives on the logits' device so indexing never mixes devices.
        dev = logits.device
        attention_mask = inputs["attention_mask"].to(dev)
        to_dist = inputs["to_dist"].to(dev)   # [B, K]
        weight = inputs["weight"].to(dev)     # [B]
        order  = inputs["order"].to(dev)      # [B, K]

        # last token position per sequence
        last_idx = attention_mask.sum(dim=1) - 1  # [B]
        rows = torch.arange(logits.size(0), device=dev)
        last_logits = logits[rows, last_idx]  # [B, V]

        # fetch logits for option tokens in the *ordered* list:
        # one vectorised lookup instead of a Python loop over the batch
        opt_ids_ordered = opt_ids.to(dev)[order]                     # [B,K]
        # Upcast before normalising so the loss does not depend on the model's
        # compute dtype; log_softmax is the stable form of log(softmax(x) + eps).
        opt_logits = last_logits.gather(1, opt_ids_ordered).float()  # [B,K]
        log_p_llm = F.log_softmax(opt_logits, dim=1)                 # [B,K]

        # forward KL: sum p_human * (log p_human - log p_llm)
        # (targets are renormalised, then clamped so log() stays finite)
        p_h = (to_dist / (to_dist.sum(dim=1, keepdim=True) + 1e-12)).clamp_min(1e-8)
        loss_vec = (p_h * (p_h.log() - log_p_llm)).sum(dim=1)  # [B]
        # weight by n_from etc.
        loss = (weight * loss_vec).mean()

        return (loss, out) if return_outputs else loss

# Training configuration for the Llama run, grouped by concern.
# NOTE(review): output_dir sits under "LLM_POC_Study_2025_v2" while the data
# files and the saved adapter use "LLM_POC_Study_2025" -- confirm this is intended.
args = TrainingArguments(
    output_dir="/content/drive/MyDrive/LLM_POC_Study_2025_v2/tllm_abortion_transitions",
    # -- batching --
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    # -- optimisation --
    num_train_epochs=3,
    learning_rate=2e-4,  # LoRA can take a higher LR
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.0,
    bf16=True,
    # -- logging / evaluation / checkpointing --
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=2,
    report_to="none",
    # Add label_names to prevent removal
    label_names=["to_dist", "weight", "order"],
)

# Training rows get a shuffled option order; validation rows keep the fixed order.
train_ds = TLLMRowDataset(
    "/content/drive/MyDrive/LLM_POC_Study_2025/train_rows.jsonl",
    shuffle_opts=True,
)
eval_ds = TLLMRowDataset(
    "/content/drive/MyDrive/LLM_POC_Study_2025/val_rows.jsonl",
    shuffle_opts=False,
)

tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/826 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [None]:
# Check if model can be trained

# Count train/eval rows and expected steps
from math import ceil
print("Train rows:", len(train_ds), "Eval rows:", len(eval_ds))

# Read the schedule from `args` instead of repeating the numbers here, so this
# report cannot drift from the configuration the Trainer actually uses.
# (The arithmetic assumes a single device, as the per-device batch size is used.)
B = args.per_device_train_batch_size
GA = args.gradient_accumulation_steps
E = args.num_train_epochs
steps_per_epoch = ceil(len(train_ds) / (B))
opt_steps_per_epoch = ceil(len(train_ds) / (B * GA))
print("Steps/epoch (forward passes):", steps_per_epoch)
print("Optimizer steps/epoch:", opt_steps_per_epoch)
# ceil keeps the printed total an integer even if the epoch count is stored as a float.
print("Total optimizer steps:", ceil(opt_steps_per_epoch * E))

# Confirm trainable parameters (LoRA) are actually on
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.3f}% expected small for LoRA)")

Train rows: 4075 Eval rows: 2240
Steps/epoch (forward passes): 510
Optimizer steps/epoch: 255
Total optimizer steps: 765
Trainable params: 32,505,856 / 8,062,808,064 (0.403% expected small for LoRA)


In [None]:
# Wire the KL trainer to the model, datasets and padding collator, then train.
trainer = KLTrainer(
    model=model,
    args=args,
    data_collator=data_collate,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)
trainer.train()

# Save the PEFT-wrapped model and the extended tokenizer into the same Drive folder.
lora_save_dir = "/content/drive/MyDrive/LLM_POC_Study_2025/tllm_abortion_transitions_lora"
model.save_pretrained(lora_save_dir)
tokenizer.save_pretrained(lora_save_dir)

Step,Training Loss,Validation Loss
500,0.2029,0.057673




('/content/drive/MyDrive/LLM_POC_Study_2025/tllm_abortion_transitions_lora/tokenizer_config.json',
 '/content/drive/MyDrive/LLM_POC_Study_2025/tllm_abortion_transitions_lora/special_tokens_map.json',
 '/content/drive/MyDrive/LLM_POC_Study_2025/tllm_abortion_transitions_lora/tokenizer.json')

In [None]:
# Show the trainer's final step and epoch counters.
final_state = trainer.state
print(f"Global step: {final_state.global_step}")
print(f"Epoch: {final_state.epoch}")


Global step: 765
Epoch: 3.0


In [None]:
# Use Mistral
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch, json, random
from huggingface_hub import login

# Read the Hugging Face token from the Colab secret named 'HF_TOKEN' and log in
# (repeated here so this cell can be run without the Llama cell above).
from google.colab import userdata
hf_token = userdata.get('HF_TOKEN')
login(hf_token)

# Special tokens standing for the answer options. Index j of a record's
# `to_dist` refers to OPT_TOKENS[j], so this order is part of the data contract.
OPT_TOKENS = [
  "<OPT_STRONG_ANTI>", "<OPT_ANTI>", "<OPT_NEUTRAL>", "<OPT_PRO>", "<OPT_STRONG_PRO>"
]

# NOTE(review): if the Llama cells above ran in this session, `trainer` still
# references the previous model; loading a second model here may need that
# memory released first -- confirm on the target GPU.
model_name = "mistralai/Mistral-7B-v0.3" # "meta-llama/llama-3.1-8b" # or "mistralai/Mistral-7B-v0.3" or "QWEN 2.5"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, token = hf_token)
# Register the option tokens as additional special tokens of the tokenizer.
tokenizer.add_special_tokens({"additional_special_tokens": OPT_TOKENS})

# `dtype` replaces the deprecated `torch_dtype` keyword (an earlier recorded
# run in this notebook printed "`torch_dtype` is deprecated! Use `dtype` instead!").
model = AutoModelForCausalLM.from_pretrained(
    model_name, dtype=torch.bfloat16, device_map="auto", token=hf_token
)
# Grow the embedding matrix / LM head so the new option tokens have rows.
model.resize_token_embeddings(len(tokenizer))

# Optional QLoRA
# model = prepare_model_for_kbit_training(model)
# NOTE(review): only the listed projection layers receive LoRA adapters, so the
# freshly initialised embedding / lm_head rows of the option tokens are not
# trained themselves. If that is not intended, consider PEFT's
# `modules_to_save` (or an equivalent) -- confirm memory budget first.
lora_cfg = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj"]
)
model = get_peft_model(model, lora_cfg)

# map special tokens to ids AFTER resize
opt2id = {opt: tokenizer.convert_tokens_to_ids(opt) for opt in OPT_TOKENS}
# Fail fast if any option token did not get its own id (e.g. it silently mapped
# to the unknown token): the KL loss would otherwise compare wrong logits.
if None in opt2id.values() or len(set(opt2id.values())) != len(OPT_TOKENS):
    raise ValueError(f"Option tokens do not map to distinct vocabulary ids: {opt2id}")
# Option ids in OPT_TOKENS order, kept on the model's device for use in the loss.
opt_ids = torch.tensor([opt2id[o] for o in OPT_TOKENS], device=model.device)

# --- Dataset ---
class TLLMRowDataset(torch.utils.data.Dataset):
    """Map-style dataset over a JSONL file of prompt / target-distribution rows.

    This cell re-declares the class used for the Llama run so the Mistral run
    is self-contained. Every non-blank line is parsed as one JSON record.
    ``__getitem__`` reads ``prompt_text`` and ``to_dist`` (indexed in
    ``OPT_TOKENS`` order) and the optional ``weight`` (defaults to 1.0).

    Args:
        jsonl_path: Path of the JSONL file to load eagerly into memory.
        shuffle_opts: If true, the option order is re-drawn with
            ``random.shuffle`` on every ``__getitem__`` call.
        max_length: ``max_length`` handed to the tokenizer (1024 by default,
            the value that used to be hard-coded).
    """

    def __init__(self, jsonl_path, shuffle_opts=True, max_length=1024):
        # Use a context manager so the file handle is closed (a bare open()
        # inside the comprehension leaked it) and skip blank lines, which
        # would otherwise make json.loads raise on e.g. a trailing newline.
        with open(jsonl_path, encoding="utf-8") as fh:
            self.recs = [json.loads(line) for line in fh if line.strip()]
        self.shuffle_opts = shuffle_opts
        self.max_length = max_length

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.recs)

    def __getitem__(self, i):
        """Build the tokenized prompt and aligned targets for record ``i``.

        Returns a dict with ``input_ids`` / ``attention_mask`` (1-D tensors),
        ``to_dist`` (float tensor permuted by ``order``), ``weight`` (scalar
        float tensor) and ``order`` (long tensor: the permutation of
        ``OPT_TOKENS`` indices used for this sample).
        """
        r = self.recs[i]
        # Expect fields: prompt_text, to_dist (list float), weight (float)
        prompt = r["prompt_text"]
        # Randomize option order to reduce position bias
        order = list(range(len(OPT_TOKENS)))
        if self.shuffle_opts:
            random.shuffle(order)
        # Reorder target distribution accordingly
        to_dist = [r["to_dist"][j] for j in order]
        opt_tokens_ordered = [OPT_TOKENS[j] for j in order]

        prompt_with_opts = prompt + "Options: " + " ".join(opt_tokens_ordered) + "\nAnswer:\n"
        # NOTE(review): if a prompt exceeds max_length and the tokenizer
        # truncates on the right, the "Options/Answer" suffix is what gets cut
        # -- confirm prompts fit within max_length.
        enc = tokenizer(prompt_with_opts, return_tensors="pt", truncation=True, max_length=self.max_length)
        return {
            "input_ids": enc["input_ids"][0],
            "attention_mask": enc["attention_mask"][0],
            "to_dist": torch.tensor(to_dist, dtype=torch.float),
            "weight": torch.tensor(r.get("weight", 1.0), dtype=torch.float),
            "order": torch.tensor(order, dtype=torch.long),
        }

def data_collate(batch):
    """Right-pad a list of dataset items to a common length and stack them.

    Re-declared in this cell so the Mistral run is self-contained.

    Args:
        batch: List of dicts as produced by ``TLLMRowDataset.__getitem__``
            (keys ``input_ids``, ``attention_mask``, ``to_dist``, ``weight``,
            ``order``).

    Returns:
        Dict with the same keys; ``input_ids`` and ``attention_mask`` are
        ``[B, T_max]``, padded on the right (mask 0 on padding), the remaining
        entries are stacked along a new batch dimension.
    """
    # Fall back to EOS only when no pad token is configured. The previous
    # `pad_token_id or eos_token_id` also discarded a legitimate pad id of 0.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    maxlen = max(x["input_ids"].shape[0] for x in batch)
    input_ids = []
    attn = []
    for x in batch:
        ids, mask = x["input_ids"], x["attention_mask"]
        pad = maxlen - ids.shape[0]
        # new_full / new_zeros inherit dtype and device from the source tensor,
        # so torch.cat never sees mismatched dtypes.
        input_ids.append(torch.cat([ids, ids.new_full((pad,), pad_id)]))
        attn.append(torch.cat([mask, mask.new_zeros(pad)]))
    return {
        "input_ids": torch.stack(input_ids),
        "attention_mask": torch.stack(attn),
        "to_dist": torch.stack([x["to_dist"] for x in batch]),
        "weight": torch.stack([x["weight"] for x in batch]),
        "order": torch.stack([x["order"] for x in batch]),
    }

# --- Custom Trainer with KL loss over option tokens ---
import torch.nn.functional as F
class KLTrainer(Trainer):
    """Trainer whose loss is a weighted forward KL over the option tokens.

    Re-declared in this cell so the Mistral run is self-contained. For each
    sequence the logits at the last attended position are restricted to the
    option-token ids (module-level ``opt_ids``, permuted per example by
    ``order``), normalised over those options only, and compared with the
    target distribution ``to_dist``. Batches are expected to be right-padded,
    as produced by ``data_collate``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the mean of ``weight * KL(p_human || p_llm)`` over the batch.

        Args:
            model: The model being trained; called with ``input_ids`` and
                ``attention_mask`` only.
            inputs: Batch dict with ``input_ids`` [B, T], ``attention_mask``
                [B, T], ``to_dist`` [B, K], ``weight`` [B] and ``order`` [B, K].
            return_outputs: If true, return ``(loss, model_outputs)``.
            num_items_in_batch: Accepted for Trainer API compatibility; unused.
        """
        out = model(input_ids=inputs["input_ids"].to(model.device),
                    attention_mask=inputs["attention_mask"].to(model.device))
        logits = out.logits
        # Everything below lives on the logits' device so indexing never mixes devices.
        dev = logits.device
        attention_mask = inputs["attention_mask"].to(dev)
        to_dist = inputs["to_dist"].to(dev)   # [B, K]
        weight = inputs["weight"].to(dev)     # [B]
        order  = inputs["order"].to(dev)      # [B, K]

        # last token position per sequence
        last_idx = attention_mask.sum(dim=1) - 1  # [B]
        rows = torch.arange(logits.size(0), device=dev)
        last_logits = logits[rows, last_idx]  # [B, V]

        # fetch logits for option tokens in the *ordered* list:
        # one vectorised lookup instead of a Python loop over the batch
        opt_ids_ordered = opt_ids.to(dev)[order]                     # [B,K]
        # Upcast before normalising so the loss does not depend on the model's
        # compute dtype; log_softmax is the stable form of log(softmax(x) + eps).
        opt_logits = last_logits.gather(1, opt_ids_ordered).float()  # [B,K]
        log_p_llm = F.log_softmax(opt_logits, dim=1)                 # [B,K]

        # forward KL: sum p_human * (log p_human - log p_llm)
        # (targets are renormalised, then clamped so log() stays finite)
        p_h = (to_dist / (to_dist.sum(dim=1, keepdim=True) + 1e-12)).clamp_min(1e-8)
        loss_vec = (p_h * (p_h.log() - log_p_llm)).sum(dim=1)  # [B]
        # weight by n_from etc.
        loss = (weight * loss_vec).mean()

        return (loss, out) if return_outputs else loss

# Training configuration for the Mistral run, grouped by concern.
args = TrainingArguments(
    output_dir="/content/drive/MyDrive/LLM_POC_Study_2025/tllm_abortion_transitions_mistral",
    # -- batching --
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    # -- optimisation --
    num_train_epochs=3,
    learning_rate=2e-4,  # LoRA can take a higher LR
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.0,
    bf16=True,
    # -- logging / evaluation / checkpointing --
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=2,
    report_to="none",
    # Add label_names to prevent removal
    label_names=["to_dist", "weight", "order"],
)

# Training rows get a shuffled option order; validation rows keep the fixed order.
train_ds = TLLMRowDataset(
    "/content/drive/MyDrive/LLM_POC_Study_2025/train_rows.jsonl",
    shuffle_opts=True,
)
eval_ds = TLLMRowDataset(
    "/content/drive/MyDrive/LLM_POC_Study_2025/val_rows.jsonl",
    shuffle_opts=False,
)

# Wire the KL trainer to the Mistral model, datasets and padding collator, then train.
trainer = KLTrainer(
    model=model,
    args=args,
    data_collator=data_collate,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)
trainer.train()

# Save the PEFT-wrapped model and the extended tokenizer into the same Drive folder.
mistral_save_dir = "/content/drive/MyDrive/LLM_POC_Study_2025/tllm_abortion_transitions_mistral_lora"
model.save_pretrained(mistral_save_dir)
tokenizer.save_pretrained(mistral_save_dir)

tokenizer_config.json:   0%|          | 0.00/137k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/587k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/601 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.55G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Step,Training Loss,Validation Loss
500,0.2029,0.057849




('/content/drive/MyDrive/LLM_POC_Study_2025/tllm_abortion_transitions_mistral_lora/tokenizer_config.json',
 '/content/drive/MyDrive/LLM_POC_Study_2025/tllm_abortion_transitions_mistral_lora/special_tokens_map.json',
 '/content/drive/MyDrive/LLM_POC_Study_2025/tllm_abortion_transitions_mistral_lora/tokenizer.model',
 '/content/drive/MyDrive/LLM_POC_Study_2025/tllm_abortion_transitions_mistral_lora/added_tokens.json',
 '/content/drive/MyDrive/LLM_POC_Study_2025/tllm_abortion_transitions_mistral_lora/tokenizer.json')