<a href="https://colab.research.google.com/github/yilmajung/LLM_POC_Study_2025_v2/blob/main/n2_finetune_llms.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install transformers accelerate peft sentencepiece pandas pyarrow

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
import os, json, math, random
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model

In [21]:
# Set up paths
CS_CSV    = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/gss_abt_cs_full.csv"       # cross-sectional long
PANEL_CSV = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/gss_abt_panel_full.csv"    # panel long
OUT_DIR   = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/outputs_gss_multitask"          # where I write JSONLs & checkpoints
os.makedirs(OUT_DIR, exist_ok=True)

# LLM choice
BASE_MODEL_NAME = "meta-llama/llama-3.1-8b"  # or "mistralai/Mistral-7B-v0.3"

# Canonical bins (K=4)
CANON4 = ["strong_anti", "anti", "pro", "strong_pro"]
CAT2ID = {c:i for i,c in enumerate(CANON4)}
K = len(CANON4)
YEARS_CS = list(range(2006, 2025, 2))  # 2006..2024 every 2 years

In [None]:
# Load and harmonize the data
# --- Cross-sectional ---
cs = pd.read_csv(CS_CSV)
# Expect: yearid, year, abortion_att4, generation, race, gender, edu_level, wtssps
# Map the attitude to canonical
cs["att"] = cs["abortion_att4"].astype(str).str.strip()

# keep only canon categories, drop NAs
cs = cs[cs["att"].isin(CANON4)].copy()
cs["wt"] = cs.get("wtssps", pd.Series([1.0]*len(cs)))  # default 1.0 if missing

# --- Panel ---
pl = pd.read_csv(PANEL_CSV)
# Expect: id, year, abortion_att4, generation, race, gender, edu_level
pl["att"] = pl["abortion_att4"].astype(str).str.strip()
pl = pl[pl["att"].isin(CANON4)].copy()

# Define grouping keys
GROUP_COLS = ["generation","gender","race","edu_level"]
for df in (cs, pl):
    for c in GROUP_COLS:
        df[c] = df[c].astype(str).str.strip()

In [None]:
# Build cross-section margins p_cs[g,y] (weighted)
def group_key(row):
    return (row["generation"], row["gender"], row["race"], row['edu_level'])

def weighted_probs(vals, wts, cats=CANON4):
    # vals: list of category strings; wts: weights
    counts = {c:0.0 for c in cats}
    for v, w in zip(vals, wts):
        counts[v] += float(w)
    vec = np.array([counts[c] for c in cats], dtype=float)
    s = vec.sum()
    if s <= 0: return None
    return vec / s

# p_cs[(g,y)] -> np.array[K]
p_cs = {}
effN_cs = {}  # effective N for weighting samples in Task B
for y in YEARS_CS:
    sub = cs[cs["year"]==y]
    if sub.empty:
        continue
    for g_vals, df_g in sub.groupby(GROUP_COLS):
        p = weighted_probs(df_g["att"].tolist(), df_g["wt"].tolist(), CANON4)
        if p is None:
            continue
        p_cs[(g_vals, y)] = p
        effN_cs[(g_vals, y)] = float(df_g["wt"].sum())

In [None]:
# Build panel transitions C[(g,t,Δ)] from consecutive waves per yearid
# Build transitions per (g, t, Δ)
# C[(g,t,Δ)] -> KxK counts; Nfrom[(g,t,Δ)] -> row totals length K
from collections import defaultdict

def canon_index(cat):
    return CAT2ID.get(cat, None)

C = defaultdict(lambda: np.zeros((K,K), dtype=float))
Nfrom = defaultdict(lambda: np.zeros((K,), dtype=float))

# Weight=1.0
pl["w"] = 1.0

# Consecutive transitions per id
for (pid), df_id in pl.groupby("yearid"):
    df_id = df_id.sort_values("year")
    # collapse duplicates per year if any
    df_id = df_id.drop_duplicates(subset=["year"], keep="last")
    years = df_id["year"].values.tolist()
    atts  = df_id["att"].values.tolist()
    wgts  = df_id["w"].values.astype(float).tolist()
    gens  = df_id["generation"].values.tolist()
    gend  = df_id["gender"].values.tolist()
    race  = df_id["race"].values.tolist()
    edu = df_id["edu_level"].values.tolist()

    # require consistent group labels across waves for this id (common in panels)
    if not (len(set(gens))==1 and len(set(gend))==1 and len(set(race))==1 and len(set(edu))==1):
        # if you prefer, skip inconsistent cases
        continue

    g = (gens[0], gend[0], race[0], edu[0])
    for i in range(len(years)-1):
        t, t1 = int(years[i]), int(years[i+1])
        Δ = t1 - t
        if Δ not in (2,4):  # we only have 2-yr and 4-yr gaps here
            continue
        ai = canon_index(atts[i])
        aj = canon_index(atts[i+1])
        if ai is None or aj is None:
            continue
        w = float(wgts[i])  # or min(wgts[i], wgts[i+1])
        C[(g,t,Δ)][ai, aj] += w
        Nfrom[(g,t,Δ)][ai] += w


In [None]:
# JSONL builders: Task A (panel rows) and Task B (margins)
def smooth_row(row_counts, alpha):
    rc = np.array(row_counts, dtype=float) + alpha
    s = rc.sum()
    if s <= 0:
        return np.ones_like(rc)/len(rc)
    return rc / s

def build_taskA_rows(C, Nfrom, p_cs, out_jsonl, alpha_small=0.05, alpha_big=0.25, n_thresh=20, cap_weight=100.0):
    out = open(out_jsonl, "w", encoding="utf-8")
    n_rows = 0
    for (g,t,Δ), mat in C.items():
        nrow = Nfrom[(g,t,Δ)]
        for i in range(K):
            alpha = alpha_small if nrow[i] >= n_thresh else alpha_big
            tgt = smooth_row(mat[i,:], alpha=alpha).tolist()
            # Build prompt
            prompt = (
                "[Task: Predict transition row]\n"
                f"From: <Y{t}> → To: <Y{t+Δ}> <DT{Δ}>\n"
                f"Group: generation={g[0]}; gender={g[1]}; race={g[2]}; edu_level={g[3]}\n"
                f"From option: {CANON4[i]}\n"
                "Answer:\n"
            )
            rec = {
                "task": "row",
                "group": {"generation": g[0], "gender": g[1], "race": g[2], "edu_level": g[3]},
                "year_t": t,
                "year_t1": t+Δ,
                "dt": Δ,
                "from_bin": CANON4[i],
                "prompt_text": prompt,
                "to_dist": tgt,
                "weight": float(min(nrow[i], cap_weight))
            }
            # Attach margins if both available (for consistency loss later)
            if ((g,t) in p_cs) and ((g,t+Δ) in p_cs):
                rec["p_curr"] = p_cs[(g,t)].tolist()
                rec["p_next"] = p_cs[(g,t+Δ)].tolist()
            out.write(json.dumps(rec) + "\n")
            n_rows += 1
    out.close()
    return n_rows

def build_taskB_rows(p_cs, effN_cs, out_jsonl, lags=(2,4), cap_weight=500.0):
    out = open(out_jsonl, "w", encoding="utf-8")
    n_rows = 0
    # group by g
    by_g = {}
    for (g,y) in p_cs.keys():
        by_g.setdefault(g, []).append(y)
    for g, years in by_g.items():
        ys = sorted(years)
        for y in ys:
            # context from previous lags
            ctx = []
            for L in lags:
                yprev = y - L
                if (g, yprev) in p_cs:
                    ctx.append((yprev, p_cs[(g, yprev)]))
            if len(ctx) == 0:
                continue
            # Build prompt
            ctx_parts = " ".join([f"<Y{yy}>[{','.join(f'{x:.4f}' for x in p)}]" for (yy,p) in ctx])
            prompt = (
                "[Task: Forecast next-wave margin]\n"
                f"Group: generation={g[0]}; gender={g[1]}; race={g[2]}; edu_level={g[3]}\n"
                f"Context: {ctx_parts}\n"
                f"Predict: <Y{y}>\n"
                "Answer:\n"
            )
            w = float(min(effN_cs.get((g,y), 1.0), cap_weight))
            rec = {
                "task": "margin",
                "group": {"generation": g[0], "gender": g[1], "race": g[2], "edu_level": g[3]},
                "year": y,
                "prompt_text": prompt,
                "to_dist": p_cs[(g,y)].tolist(),
                "weight": w
            }
            out.write(json.dumps(rec) + "\n")
            n_rows += 1
    out.close()
    return n_rows


In [26]:
TASKA_JSONL = os.path.join(OUT_DIR, "taskA_panel_rows.jsonl")
TASKB_JSONL = os.path.join(OUT_DIR, "taskB_margins.jsonl")

na = build_taskA_rows(C, Nfrom, p_cs, TASKA_JSONL)
nb = build_taskB_rows(p_cs, effN_cs, TASKB_JSONL)
print(f"Wrote Task A rows: {na}")
print(f"Wrote Task B rows: {nb}")


Wrote Task A rows: 712
Wrote Task B rows: 228


In [27]:
# Datasets and Collator (multitask)
class MTJsonlDataset(torch.utils.data.Dataset):
    def __init__(self, jsonl_paths: List[str], tokenizer, max_len=768):
        self.rows = []
        for p in jsonl_paths:
            with open(p, "r", encoding="utf-8") as f:
                self.rows.extend([json.loads(x) for x in f])
        random.shuffle(self.rows)
        self.tok = tokenizer
        self.max_len = max_len

    def __len__(self): return len(self.rows)

    def __getitem__(self, i):
        r = self.rows[i]
        enc = self.tok(r["prompt_text"], return_tensors="pt", truncation=True, max_length=self.max_len)
        out = {
            "task_type": r["task"],
            "input_ids": enc["input_ids"][0],
            "attention_mask": enc["attention_mask"][0],
            "to_dist": torch.tensor(r["to_dist"], dtype=torch.float),
            "weight": torch.tensor(float(r.get("weight", 1.0)), dtype=torch.float),
        }
        # add optional consistency fields
        if r["task"]=="row" and ("p_curr" in r) and ("p_next" in r):
            out["p_curr"] = torch.tensor(r["p_curr"], dtype=torch.float)
            out["p_next"] = torch.tensor(r["p_next"], dtype=torch.float)
            out["has_consistency"] = torch.tensor(1, dtype=torch.long)
        else:
            out["has_consistency"] = torch.tensor(0, dtype=torch.long)
        return out

def mt_collate(batch):
    pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    maxlen = max(x["input_ids"].shape[0] for x in batch)

    def pad(seq, pad_val, target_len):
        pad_n = target_len - seq.shape[0]
        if pad_n <= 0: return seq
        return torch.cat([seq, torch.full((pad_n,), pad_val, dtype=seq.dtype)])

    input_ids      = torch.stack([pad(x["input_ids"], pad_id, maxlen) for x in batch])
    attention_mask = torch.stack([pad(x["attention_mask"], 0, maxlen) for x in batch])
    to_dist        = torch.stack([x["to_dist"] for x in batch])
    weight         = torch.stack([x["weight"] for x in batch])
    has_cons       = torch.stack([x["has_consistency"] for x in batch])
    # p_curr/p_next if present; else zeros
    p_curr = torch.zeros(len(batch), K, dtype=torch.float)
    p_next = torch.zeros(len(batch), K, dtype=torch.float)
    for i, x in enumerate(batch):
        if x["has_consistency"]==1:
            p_curr[i] = x["p_curr"]
            p_next[i] = x["p_next"]

    task_types = [x["task_type"] for x in batch]

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "to_dist": to_dist,
        "weight": weight,
        "has_consistency": has_cons,
        "p_curr": p_curr,
        "p_next": p_next,
        "task_types": task_types,
    }


In [None]:
# Model: LLM backbone + two small heads + last-K pooling
# Tokenizer & backbone
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    output_hidden_states=True
)

# Add LoRA to attention/MLP projections
lora_cfg = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj"]
)
model = get_peft_model(base, lora_cfg)

# Two task heads
class TwoHead(nn.Module):
    def __init__(self, hidden_size, K):
        super().__init__()
        self.head_row    = nn.Linear(hidden_size, K)   # Task A
        self.head_margin = nn.Linear(hidden_size, K)   # Task B

    def forward(self, feats):
        return self.head_row(feats), self.head_margin(feats)

hidden_size = base.config.hidden_size
two_head = TwoHead(hidden_size, K).to(model.device)

# Simple pooled features: mean of last K tokens (tail window)
def pooled_features(outputs, attention_mask, tail=96):
    hs = outputs.hidden_states[-1]     # [B,T,H]
    B, T, H = hs.shape
    valid_lens = attention_mask.sum(dim=1)  # [B]
    feats = []
    for b in range(B):
        L = int(valid_lens[b].item())
        s = max(0, L - tail); e = L
        if e <= s: s, e = max(0, L-32), L
        feats.append(hs[b, s:e, :].mean(dim=0))
    feats = torch.stack(feats, dim=0)
    return feats


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [31]:
# Trainer with multitask losses (KL + optional consistency)
class MTTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        self.two_head = kwargs.pop("two_head")
        self.lambda_A = kwargs.pop("lambda_A", 1.0)
        self.lambda_B = kwargs.pop("lambda_B", 1.0)
        self.lambda_C = kwargs.pop("lambda_C", 0.5)
        super().__init__(*args, **kwargs)

    def compute_loss(
        self,
        model,
        inputs,
        return_outputs: bool = False,
        num_items_in_batch: Optional[int] = None,  # <-- accept the kwarg
    ):
        input_ids      = inputs["input_ids"].to(model.device)
        attention_mask = inputs["attention_mask"].to(model.device)
        to_dist        = inputs["to_dist"].to(model.device)     # [B,K]
        weight         = inputs["weight"].to(model.device)      # [B]
        has_cons       = inputs["has_consistency"].to(model.device)  # [B]
        p_curr         = inputs["p_curr"].to(model.device)      # [B,K]
        p_next         = inputs["p_next"].to(model.device)      # [B,K]
        task_types     = inputs["task_types"]                   # list[str]

        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        feats = pooled_features(out, attention_mask, tail=96).to(model.device)  # [B,H]

        logits_row, logits_margin = self.two_head(feats)   # [B,K], [B,K]
        p_row_hat    = F.softmax(logits_row, dim=1)
        p_margin_hat = F.softmax(logits_margin, dim=1)

        # Build masks by task
        is_row    = torch.tensor([t == "row"    for t in task_types], device=model.device, dtype=torch.bool)
        is_margin = torch.tensor([t == "margin" for t in task_types], device=model.device, dtype=torch.bool)

        eps = 1e-8
        def fwd_kl(p, q):
            p = p.clamp_min(eps); q = q.clamp_min(eps)
            return (p * (p.log() - q.log())).sum(dim=1)

        loss = torch.tensor(0.0, device=model.device)

        # Task A: panel rows
        if is_row.any():
            L_A = fwd_kl(to_dist[is_row], p_row_hat[is_row])
            loss = loss + self.lambda_A * (weight[is_row] * L_A).mean()

        # Task B: margins
        if is_margin.any():
            L_B = fwd_kl(to_dist[is_margin], p_margin_hat[is_margin])
            loss = loss + self.lambda_B * (weight[is_margin] * L_B).mean()

        # Consistency (disabled unless you wire full-row assembly):
        # cons_mask = is_row & (has_cons==1)
        # if cons_mask.any() and self.lambda_C > 0:
        #     ...

        if return_outputs:
            return loss, {"logits_row": logits_row, "logits_margin": logits_margin, "labels": to_dist}
        else:
            return loss



In [33]:
# Training setup
tokenizer.padding_side = "left"  # (helps with some backbones; optional)

train_ds = MTJsonlDataset([TASKA_JSONL, TASKB_JSONL], tokenizer, max_len=768)
# For a quick start you can split train/val here:
val_frac = 0.1
n_val = int(len(train_ds) * val_frac)
indices = list(range(len(train_ds)))
random.seed(0); random.shuffle(indices)
val_idx  = set(indices[:n_val])
train_idx= set(indices[n_val:])

class SubsetDS(torch.utils.data.Dataset):
    def __init__(self, base, keep_idx):
        self.base = base
        self.keep = sorted(list(keep_idx))
    def __len__(self): return len(self.keep)
    def __getitem__(self, i): return self.base[self.keep[i]]

ds_train = SubsetDS(train_ds, train_idx)
ds_val   = SubsetDS(train_ds, val_idx)

args = TrainingArguments(
    output_dir=os.path.join(OUT_DIR, "ckpt"),
    learning_rate=1e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    num_train_epochs=20,
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=200,
    save_steps=200,
    save_total_limit=2,
    bf16=True,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    warmup_steps=300,
    weight_decay=0.0,
    report_to="none",
    remove_unused_columns=False,  # keep our custom fields
)

trainer = MTTrainer(
    model=model,
    args=args,
    train_dataset=ds_train,
    eval_dataset=ds_val,
    data_collator=mt_collate,
    two_head=two_head,
    lambda_A=1.0,
    lambda_B=1.0,
    lambda_C=0.0,   # keep 0.0 for now; we’ll add strict consistency later
)

trainer.train()

# Save adapter + heads
save_dir = os.path.join(OUT_DIR, "final_multitask")
os.makedirs(save_dir, exist_ok=True)
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
torch.save(two_head.state_dict(), os.path.join(save_dir, "two_head.pt"))
print("Saved to:", save_dir)


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


Step,Training Loss,Validation Loss
200,1.185,No log
400,1.8792,No log
600,1.6729,No log
800,1.2961,No log
1000,1.1278,No log
1200,1.079,No log
1400,0.7647,No log
1600,0.4841,No log
1800,0.3148,No log
2000,0.2587,No log


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


Saved to: /content/drive/MyDrive/LLM_POC_Study_2025_v2/outputs_gss_multitask/final_multitask
