<a href="https://colab.research.google.com/github/yilmajung/LLM_POC_Study_2025_v2/blob/main/1_4_train_head.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install & imports (Colab)
!pip -q install peft transformers accelerate sentencepiece

import os, json, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, List
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          Trainer, TrainingArguments, PreTrainedModel)
from peft import LoraConfig, get_peft_model
from google.colab import userdata
hf_token = userdata.get('HF_TOKEN')

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Config paths
BASE_MODEL_NAME = "meta-llama/llama-3.1-8b"   # or "mistralai/Mistral-7B-v0.3"
TRAIN_JSONL = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/train_rows.jsonl"
VAL_JSONL   = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/val_rows.jsonl"
OUT_DIR     = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/tllm_fixA_head"

os.makedirs(OUT_DIR, exist_ok=True)


# Dataset & data collate
class TLLMRowDataset(torch.utils.data.Dataset):
    def __init__(self, jsonl_path, tokenizer, max_len=1024):
        self.rows = [json.loads(x) for x in open(jsonl_path, "r", encoding="utf-8")]
        self.tok = tokenizer; self.max_len = max_len
    def __len__(self): return len(self.rows)
    def __getitem__(self, i):
        r = self.rows[i]
        enc = self.tok(r["prompt_text"], return_tensors="pt", truncation=True, max_length=self.max_len)
        return {
            "input_ids": enc["input_ids"][0],
            "attention_mask": enc["attention_mask"][0],
            "to_dist": torch.tensor(r["to_dist"], dtype=torch.float),
            "weight": torch.tensor(float(r.get("weight", 1.0)), dtype=torch.float),
            "text": r["prompt_text"],
        }

def data_collate(batch):
    maxlen = max(x["input_ids"].shape[0] for x in batch)
    pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    input_ids, attn, targets, weights, texts = [], [], [], [], []
    for x in batch:
        pad = maxlen - x["input_ids"].shape[0]
        input_ids.append(torch.cat([x["input_ids"], torch.full((pad,), pad_id)]))
        attn.append(torch.cat([x["attention_mask"], torch.zeros(pad, dtype=torch.long)]))
        targets.append(x["to_dist"]); weights.append(x["weight"]); texts.append(x["text"])
    return {
        "input_ids": torch.stack(input_ids),
        "attention_mask": torch.stack(attn),
        "to_dist": torch.stack(targets),
        "weight": torch.stack(weights),
        "texts": texts,
    }


# class TLLMRowDataset(torch.utils.data.Dataset):
#     def __init__(self, jsonl_path, tokenizer, max_len=1024):
#         self.rows = [json.loads(x) for x in open(jsonl_path, "r", encoding="utf-8")]
#         self.tok = tokenizer
#         self.max_len = max_len
#     def __len__(self): return len(self.rows)
#     def __getitem__(self, i):
#         r = self.rows[i]
#         # EXACT same prompt_text you used in training
#         text = r["prompt_text"]
#         enc = self.tok(text, return_tensors="pt", truncation=True, max_length=self.max_len)
#         return {
#             "input_ids": enc["input_ids"][0],
#             "attention_mask": enc["attention_mask"][0],
#             "to_dist": torch.tensor(r["to_dist"], dtype=torch.float),
#             "weight": torch.tensor(float(r.get("weight", 1.0)), dtype=torch.float)
#         }

# def data_collate(batch):
#     # simple pad to max len
#     maxlen = max(x["input_ids"].shape[0] for x in batch)
#     pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
#     input_ids, attn, targets, weights = [], [], [], []
#     for x in batch:
#         pad = maxlen - x["input_ids"].shape[0]
#         input_ids.append(torch.cat([x["input_ids"], torch.full((pad,), pad_id)]))
#         attn.append(torch.cat([x["attention_mask"], torch.zeros(pad, dtype=torch.long)]))
#         targets.append(x["to_dist"])
#         weights.append(x["weight"])
#     return {
#         "input_ids": torch.stack(input_ids),
#         "attention_mask": torch.stack(attn),
#         "to_dist": torch.stack(targets),
#         "weight": torch.stack(weights),
#     }


# Load tokenizer & base LM
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, use_fast=True, token=hf_token)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=hf_token,
    # ensure hidden states are returned
    output_hidden_states=True
)


# Add LoRA
lora_cfg = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj"]
)
model = get_peft_model(base, lora_cfg)

tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/826 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

In [None]:
# Small 5-way classification head
class DistHead(nn.Module):
    def __init__(self, hidden_size, K=5):
        super().__init__()
        self.out = nn.Linear(hidden_size, K)
    def forward(self, last_hidden_vec):  # [B, H]
        return self.out(last_hidden_vec) # [B, K]

# attach the head to the model
hidden_size = base.config.hidden_size
model.dist_head = DistHead(hidden_size, K=5).to(model.device)


# Find the span start at "Group:" and end at the line with "From option:"
# (We do this per-example via string positions mapped to token indices.)
def find_span_indices(text, tokenizer, enc):
    # byte offsets
    group_pos = text.find("Group:")
    from_pos  = text.find("From option:")
    if group_pos == -1 or from_pos == -1 or from_pos <= group_pos:
        # fallback: use whole sequence except the first line
        return 0, int(enc["attention_mask"][0].sum().item()) - 1
    # Get token start/end by tokenizing prefixes
    pre_group = tokenizer(text[:group_pos], return_tensors="pt")
    pre_from  = tokenizer(text[:from_pos],  return_tensors="pt")
    start = pre_group["input_ids"].shape[1]
    end   = pre_from["input_ids"].shape[1] + 8  # include a few tokens into "From option:" line
    return start, end

class HeadTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        input_ids      = inputs["input_ids"].to(model.device)
        attention_mask = inputs["attention_mask"].to(model.device)
        to_dist        = inputs["to_dist"].to(model.device)
        weight         = inputs["weight"].to(model.device)

        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)

        # Build pooled features per example over [Group: ... From option:] span
        hs = out.hidden_states[-1]  # [B, T, H]
        B, T, H = hs.shape
        feats = []
        # We need original texts to locate spans; add them to batch via dataset/collator
        texts = inputs["texts"]  # add this field in your collate (see below)
        for b in range(B):
            text_b = texts[b]
            # re-encode single example (CPU) to get token positions reliably
            enc_b = tokenizer(text_b, return_tensors="pt", truncation=True, max_length=1024)
            s, e = find_span_indices(text_b, tokenizer, enc_b)
            s = max(0, min(s, T-1)); e = max(0, min(e, T))
            if e <= s:  # fallback: use last 64 non-pad tokens
                valid_len = int(attention_mask[b].sum().item())
                s = max(0, valid_len - 64); e = valid_len
            vec = hs[b, s:e, :].mean(dim=0)  # [H]
            feats.append(vec)
        feats = torch.stack(feats, dim=0)  # [B, H]

        logits5 = model.dist_head(feats)         # [B, 5]
        p_llm   = F.softmax(logits5, dim=1)

        p_h = (to_dist / (to_dist.sum(dim=1, keepdim=True) + 1e-12)).clamp_min(1e-8)
        m   = 0.5 * (p_h + p_llm)
        kl1 = (p_h   * (p_h.add(1e-12).log() - m.add(1e-12).log())).sum(dim=1)
        kl2 = (p_llm * (p_llm.add(1e-12).log() - m.add(1e-12).log())).sum(dim=1)
        loss_vec = 0.5 * (kl1 + kl2)
        loss = (weight * loss_vec).mean()

        return loss


# # Custom Trainer with JS loss
# class HeadTrainer(Trainer):
#     def compute_loss(self, model: PreTrainedModel, inputs, return_outputs=False, num_items_in_batch=None):
#         input_ids = inputs["input_ids"].to(model.device)
#         attention_mask = inputs["attention_mask"].to(model.device)
#         to_dist = inputs["to_dist"].to(model.device)        # [B,5]
#         weight  = inputs["weight"].to(model.device)         # [B]

#         with torch.cuda.amp.autocast(dtype=torch.bfloat16):
#             out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
#             # last token index
#             last_idx = attention_mask.sum(dim=1) - 1
#             # last hidden state from last layer: out.hidden_states[-1] shape [B, T, H]
#             last_vec = out.hidden_states[-1][torch.arange(out.hidden_states[-1].size(0)), last_idx]  # [B, H]
#             logits5 = model.dist_head(last_vec)         # [B,5]
#             p_llm = F.softmax(logits5, dim=1)           # [B,5]

#         # normalize targets
#         p_h = to_dist / (to_dist.sum(dim=1, keepdim=True) + 1e-12)
#         p_h = p_h.clamp_min(1e-8)

#         # Jensen-Shannon divergence
#         m = 0.5*(p_h + p_llm)
#         kl1 = (p_h * (p_h.add(1e-12).log() - m.add(1e-12).log())).sum(dim=1)
#         kl2 = (p_llm * (p_llm.add(1e-12).log() - m.add(1e-12).log())).sum(dim=1)
#         loss_vec = 0.5*(kl1 + kl2)

#         loss = (weight * loss_vec).mean()
#         return (loss, {"logits5": logits5}) if return_outputs else loss


# Datasets
train_ds = TLLMRowDataset(TRAIN_JSONL, tokenizer, max_len=1024)
eval_ds  = TLLMRowDataset(VAL_JSONL,   tokenizer, max_len=1024)


# Training settings
args = TrainingArguments(
    output_dir=OUT_DIR,
    learning_rate=1e-4,                   # LoRA + small head
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    num_train_epochs=6,                   # give it a bit more runway
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=200,
    save_steps=200,
    save_total_limit=2,
    bf16=True,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.0,
    report_to="none",
    label_names=["to_dist", "weight"]    # keep labels
)

# Ensure the head’s parameters are trainable
for n,p in model.dist_head.named_parameters():
    p.requires_grad_(True)

trainer = HeadTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=data_collate,
)

trainer.train()


# Save adapter + head + tok
model.save_pretrained(os.path.join(OUT_DIR, "lora"))
tokenizer.save_pretrained(os.path.join(OUT_DIR, "lora"))
# Save the small head weights separately
torch.save(model.dist_head.state_dict(), os.path.join(OUT_DIR, "dist_head.pt"))
print("Saved LoRA + head to:", OUT_DIR)


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


Step,Training Loss,Validation Loss
200,0.0277,0.016767
400,0.0176,0.018134
600,0.013,0.014208
800,0.0091,0.017408
1000,0.0099,0.014102
1200,0.0075,0.013593
1400,0.0076,0.01344


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


Saved LoRA + head to: /content/drive/MyDrive/LLM_POC_Study_2025_v2/tllm_fixA_head
