In [None]:
# Install & imports (Colab)
!pip -q install peft transformers accelerate sentencepiece

import os, json, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, List
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          Trainer, TrainingArguments, PreTrainedModel)
from peft import LoraConfig, get_peft_model
from google.colab import userdata
# Hugging Face access token used for the model/tokenizer downloads below.
# NOTE(review): colab's userdata.get() appears to take only the secret name and to
# raise (rather than return a default) when the secret is missing or the notebook
# lacks access, so wrap it and fall back to the HF_TOKEN env var, or None.
try:
    hf_token = userdata.get('HF_TOKEN')
except Exception:
    hf_token = os.environ.get('HF_TOKEN')

In [None]:
# ============
# Config paths
# ============
BASE_MODEL_NAME = "meta-llama/llama-3.1-8b"   # or "mistralai/Mistral-7B-v0.3"

# Inputs and outputs all live under one project folder on the mounted Drive.
_PROJECT_DIR = "/content/drive/MyDrive/LLM_POC_Study_2025_v2"
TRAIN_JSONL = f"{_PROJECT_DIR}/train_rows.jsonl"
VAL_JSONL = f"{_PROJECT_DIR}/val_rows.jsonl"
OUT_DIR = f"{_PROJECT_DIR}/tllm_fixA_head"

# Create the output folder up front; no error if it already exists.
os.makedirs(OUT_DIR, exist_ok=True)

# =====================
# Dataset & data collate
# =====================
class TLLMRowDataset(torch.utils.data.Dataset):
    """Map-style dataset over a JSONL file of prompt rows.

    Each non-blank line must be a JSON object with:
      * ``"prompt_text"`` -- the text passed to the tokenizer,
      * ``"to_dist"``     -- the target vector (returned as a float tensor),
      * ``"weight"``      -- optional per-row loss weight (defaults to 1.0).

    Args:
        jsonl_path: path of the UTF-8 JSONL file; it is read eagerly.
        tokenizer: callable used as ``tokenizer(text, return_tensors="pt",
            truncation=True, max_length=...)``; expected to return
            ``input_ids`` / ``attention_mask`` with a leading batch dim of 1.
        max_len: maximum number of tokens kept per prompt.
    """

    def __init__(self, jsonl_path, tokenizer, max_len=1024):
        # Context manager closes the handle; blank lines (e.g. a trailing empty
        # line) are skipped because json.loads rejects them.
        with open(jsonl_path, "r", encoding="utf-8") as fh:
            self.rows = [json.loads(line) for line in fh if line.strip()]
        self.tok = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        """Tokenize row ``i`` and return 1-D ids/mask plus its target and weight."""
        r = self.rows[i]
        # EXACT same prompt_text you used in training
        text = r["prompt_text"]
        # NOTE(review): truncation uses the tokenizer's configured side (right by
        # default for most HF tokenizers), which drops the END of over-long
        # prompts -- the last token is what the classification head reads.
        # Confirm prompts fit in max_len, or set tokenizer.truncation_side="left".
        enc = self.tok(text, return_tensors="pt", truncation=True, max_length=self.max_len)
        return {
            "input_ids": enc["input_ids"][0],
            "attention_mask": enc["attention_mask"][0],
            "to_dist": torch.tensor(r["to_dist"], dtype=torch.float),
            "weight": torch.tensor(float(r.get("weight", 1.0)), dtype=torch.float),
        }

def data_collate(batch, pad_id=None):
    """Right-pad a list of dataset items into one batch dict.

    Args:
        batch: non-empty list of dicts with 1-D ``input_ids`` / ``attention_mask``
            tensors, a ``to_dist`` tensor and a scalar ``weight`` tensor.
        pad_id: token id used to pad ``input_ids``. When None, uses the global
            ``tokenizer``'s ``pad_token_id``, or its ``eos_token_id`` if that is unset.

    Returns:
        dict with ``input_ids`` / ``attention_mask`` of shape [B, maxlen] (padding
        appended on the right, mask 0 there) and stacked ``to_dist`` / ``weight``.
    """
    if pad_id is None:
        # `is None` rather than `or`: a legitimate pad id of 0 must not fall
        # through to the EOS id.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id

    maxlen = max(x["input_ids"].shape[0] for x in batch)
    input_ids, attn = [], []
    for x in batch:
        ids, mask = x["input_ids"], x["attention_mask"]
        pad = maxlen - ids.shape[0]
        # new_full / new_zeros keep each pad tensor's dtype equal to the row's.
        input_ids.append(torch.cat([ids, ids.new_full((pad,), pad_id)]))
        attn.append(torch.cat([mask, mask.new_zeros(pad)]))
    return {
        "input_ids": torch.stack(input_ids),
        "attention_mask": torch.stack(attn),
        "to_dist": torch.stack([x["to_dist"] for x in batch]),
        "weight": torch.stack([x["weight"] for x in batch]),
    }

# =======================
# Load tokenizer & base LM
# =======================
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL_NAME,
    token=hf_token,
    use_fast=True,
)
# Some base-model tokenizers ship without a pad token; reuse EOS for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# bf16 weights with automatic device placement. output_hidden_states=True asks the
# model to return per-layer hidden states; the loss reads the last layer's.
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    token=hf_token,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    output_hidden_states=True,
)

# =========
# Add LoRA (optional but helpful)
# =========
# Rank-16 adapters (alpha 32, dropout 0.05) on the attention projections and the
# MLP up/down projections.
# NOTE(review): gate_proj is not in the list -- confirm that is intentional.
lora_cfg = LoraConfig(
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",   # attention
        "up_proj", "down_proj",                   # MLP
    ],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)
# Wrap `base` with the adapters; this PEFT model is the one handed to the Trainer.
model = get_peft_model(base, lora_cfg)

# ==========================
# Small 5-way classification head
# ==========================
class DistHead(nn.Module):
    """Single linear layer mapping a pooled hidden vector to K unnormalized scores.

    Args:
        hidden_size: width H of the input vector.
        K: number of output classes (default 5).
    """

    def __init__(self, hidden_size, K=5):
        super().__init__()
        # The attribute name `out` fixes the saved state_dict keys
        # ("out.weight", "out.bias").
        self.out = nn.Linear(in_features=hidden_size, out_features=K)

    def forward(self, last_hidden_vec):
        """Map a [B, H] batch of vectors to [B, K] logits."""
        scores = self.out(last_hidden_vec)
        return scores

# Attach the head to the PEFT model. Assigning an nn.Module attribute registers it
# as a submodule, so its parameters show up in model.parameters().
# The head stays float32 while the base model is bf16; the loss runs it under autocast.
hidden_size = base.config.hidden_size
model.dist_head = DistHead(hidden_size=hidden_size, K=5).to(device=model.device)

# ==========================
# Custom Trainer with JS loss
# ==========================
class HeadTrainer(Trainer):
    """Trainer whose loss is a per-row weighted Jensen-Shannon divergence.

    The prediction is the softmax of ``model.dist_head`` applied to the final-layer
    hidden state of each sequence's last non-padded token. The target is the row's
    ``to_dist`` vector renormalized to sum to 1.

    NOTE(review): intermediate checkpoints (``save_steps``) are written through the
    PEFT model's own saving path, which presumably stores only adapter weights --
    ``dist_head`` would then be missing from them. Confirm before resuming from or
    evaluating a checkpoint.
    """

    def compute_loss(self, model: PreTrainedModel, inputs, return_outputs=False, num_items_in_batch=None):
        """Weighted JS divergence between the head's softmax and the target.

        Args:
            model: the PEFT-wrapped causal LM carrying a ``dist_head`` attribute.
            inputs: batch dict from the collator: right-padded ``input_ids`` /
                ``attention_mask`` [B, T], ``to_dist`` [B, K], ``weight`` [B].
            return_outputs: also return ``{"logits5": logits}`` (used in evaluation).
            num_items_in_batch: accepted for Trainer API compatibility; unused.

        Returns:
            Scalar loss, or ``(loss, {"logits5": logits})`` if ``return_outputs``.
        """
        head = model.dist_head
        # All loss math happens on the device the head was placed on.
        head_device = next(head.parameters()).device

        input_ids = inputs["input_ids"].to(model.device)
        attention_mask = inputs["attention_mask"].to(model.device)

        # torch.autocast replaces the deprecated torch.cuda.amp.autocast and also
        # works when the batch is not on CUDA.
        with torch.autocast(device_type=input_ids.device.type, dtype=torch.bfloat16):
            # NOTE(review): the causal-LM forward also materializes vocab logits for
            # every position, which are unused here; a hidden-state-only forward
            # would save a lot of memory.
            out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
            hidden = out.hidden_states[-1]  # final layer, [B, T, H]
            # Right padding => the last real token of each row is at
            # (number of attended tokens - 1). Indices are built on the hidden
            # state's device: with device_map="auto" the last layer may live on a
            # different GPU than model.device.
            last_idx = (attention_mask.sum(dim=1) - 1).to(hidden.device)
            rows = torch.arange(hidden.size(0), device=hidden.device)
            last_vec = hidden[rows, last_idx].to(head_device)  # [B, H]
            logits5 = head(last_vec)                           # [B, K]

        # Probabilities in float32, via log_softmax for a numerically stable log.
        log_p_llm = F.log_softmax(logits5.float(), dim=1)    # [B, K]
        p_llm = log_p_llm.exp()

        to_dist = inputs["to_dist"].to(head_device, dtype=torch.float32)  # [B, K]
        weight = inputs["weight"].to(head_device, dtype=torch.float32)    # [B]

        # Normalize targets to a distribution; the floor keeps log() finite.
        p_h = to_dist / (to_dist.sum(dim=1, keepdim=True) + 1e-12)
        p_h = p_h.clamp_min(1e-8)

        # Jensen-Shannon divergence (natural log):
        #   JS = 0.5 * KL(p_h || m) + 0.5 * KL(p_llm || m),  m = (p_h + p_llm) / 2
        m = 0.5 * (p_h + p_llm)
        log_m = m.add(1e-12).log()
        kl1 = (p_h * (p_h.add(1e-12).log() - log_m)).sum(dim=1)
        kl2 = (p_llm * (log_p_llm - log_m)).sum(dim=1)
        loss_vec = 0.5 * (kl1 + kl2)  # [B]

        # Weighted mean over the batch (divides by B, not by the sum of weights).
        loss = (weight * loss_vec).mean()
        return (loss, {"logits5": logits5}) if return_outputs else loss

# =========
# Datasets
# =========
# Both splits share the tokenizer and the 1024-token limit; train is built first.
train_ds, eval_ds = (
    TLLMRowDataset(path, tokenizer, max_len=1024) for path in (TRAIN_JSONL, VAL_JSONL)
)

# =================
# Training settings
# =================
# LoRA adapters and the small head are optimized together. Effective batch per
# device = per_device_train_batch_size * gradient_accumulation_steps = 8 * 2 = 16.
args = TrainingArguments(
    output_dir=OUT_DIR,
    report_to="none",
    # --- optimization ---
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.0,
    num_train_epochs=6,
    bf16=True,
    # --- batching ---
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    # --- logging / evaluation / checkpointing ---
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=200,
    save_steps=200,
    save_total_limit=2,
    # Keep the custom target keys in each batch. Assumption to verify against the
    # installed transformers: with the default remove_unused_columns, Trainer
    # strips batch keys that are neither forward() arguments nor label names.
    label_names=["to_dist", "weight"],
)

# Ensure the head's parameters are trainable. A freshly built nn.Linear already has
# requires_grad=True; this makes it explicit. Module.requires_grad_ flips every
# parameter without a loop whose name variable went unused.
model.dist_head.requires_grad_(True)

trainer = HeadTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=data_collate,
)

trainer.train()

# =========================
# Save adapter + head + tok
# =========================
# Adapter weights and the tokenizer go into the same "lora" sub-folder.
lora_dir = os.path.join(OUT_DIR, "lora")
for artifact in (model, tokenizer):
    artifact.save_pretrained(lora_dir)

# The head is saved on its own as a plain state_dict (not via save_pretrained).
head_path = os.path.join(OUT_DIR, "dist_head.pt")
torch.save(model.dist_head.state_dict(), head_path)
print("Saved LoRA + head to:", OUT_DIR)
