In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Config
CS_PATH = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/gss_abt_cs_full.csv"
PL_PATH = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/gss_abt_panel_full.csv"
OUT_DIR = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/outputs_gss_multitask_orig_subgroup"

CANON4 = ["strong_anti","anti","pro","strong_pro"]; K = len(CANON4)
CAT2ID = {c:i for i,c in enumerate(CANON4)}

GROUP_SCHEMES = {
    "GROUP_COLS_1": ["gender"],
    "GROUP_COLS_2": ["gender","political_views"],
    "GROUP_COLS_3": ["gender","political_views","edu_level"],
    "GROUP_COLS_4": ["gender","political_views","edu_level","generation"],
    "GROUP_COLS_5": ["gender","political_views","edu_level","generation","race"],
    "GROUP_COLS_ORIG": ["gender","edu_level","generation","race"],
}
CURRENT_GROUP_SCHEME = "GROUP_COLS_ORIG"  # <<< pick a scheme

def get_group_cols(): return GROUP_SCHEMES[CURRENT_GROUP_SCHEME]
print("Active grouping:", get_group_cols())


Active grouping: ['gender', 'edu_level', 'generation', 'race']


In [3]:
import os, json, numpy as np, pandas as pd
from collections import defaultdict
os.makedirs(OUT_DIR, exist_ok=True)


In [4]:
def group_meta_from_tuple(g_tuple, cols=None):
    if cols is None: cols = get_group_cols()
    return {c:v for c,v in zip(cols, g_tuple)}

def format_group_for_prompt(meta, cols=None):
    import pandas as pd
    if cols is None: cols = get_group_cols()
    parts=[]
    for c in cols:
        v=meta.get(c,"NA")
        if isinstance(v,float) and pd.isna(v): v="NA"
        parts.append(f"{c}={v}")
    return "Group: " + "; ".join(parts)

def canon_index(x):
    if isinstance(x,(int,float)) and not pd.isna(x):
        mapping={1:"strong_anti",2:"anti",3:"pro",4:"strong_pro"}
        return CAT2ID.get(mapping.get(int(x), None), None)
    if isinstance(x,str):
        x=x.strip().lower()
        alias={"strongly oppose":"strong_anti","strong_anti":"strong_anti","oppose":"anti","anti":"anti",
               "favor":"pro","pro":"pro","strongly favor":"strong_pro","strong_pro":"strong_pro"}
        return CAT2ID.get(alias.get(x, None), None)
    return None


In [5]:
# Load data
cs = pd.read_csv(CS_PATH)
pl = pd.read_csv(PL_PATH)
for df in (cs, pl):
    if "id" in df.columns: df.rename(columns={"id":"yearid"}, inplace=True)
    df["year"] = df["year"].astype(int)

cs["att_id"] = cs["abortion_att4"].apply(canon_index)
pl["att_id"] = pl["abortion_att4"].apply(canon_index)
cs = cs.dropna(subset=["att_id"]).copy(); cs["att_id"]=cs["att_id"].astype(int)
pl = pl.dropna(subset=["att_id"]).copy(); pl["att_id"]=pl["att_id"].astype(int)
print("CS rows:", len(cs), "PL rows:", len(pl))


CS rows: 13351 PL rows: 12610


In [6]:
# Cross-sectional margins (weighted)
def weighted_margins(df, group_cols, w_col="wtssps"):
    out={}; effN={}
    for (keys,gdf) in df.groupby(group_cols+["year"], dropna=False):
        *g_vals, y = keys
        w = gdf[w_col].fillna(0.0).to_numpy(float)
        k = gdf["att_id"].to_numpy(int)
        vec = np.zeros(K, dtype=float)
        for ki, wi in zip(k, w):
            if 0 <= ki < K: vec[ki] += wi
        tot = vec.sum()
        if tot <= 0: continue
        out[(tuple(g_vals), int(y))] = vec/tot
        sw = w.sum(); sw2 = (w**2).sum()
        eff = (sw**2/sw2) if sw2>0 else len(w)
        effN[(tuple(g_vals), int(y))] = float(eff)
    return out, effN

p_cs, effN_cs = weighted_margins(cs, get_group_cols())
print("CS cells:", len(p_cs))


CS cells: 821


In [7]:
# Panel transitions with dynamic grouping
INCONSISTENCY_STRATEGY = "anchor_t"  # "strict" | "anchor_t" | "mode"
C = defaultdict(lambda: np.zeros((K,K), dtype=float))
Nfrom = defaultdict(lambda: np.zeros((K,), dtype=float))

pl2 = pl.sort_values(["yearid","year"]).drop_duplicates(["yearid","year"], keep="last").copy()
pl2["w"] = 1.0

def group_tuple_at(df_id, t, cols):
    row = df_id.loc[df_id["year"]==t]
    if row.empty: return None
    row = row.iloc[-1]
    return tuple(row[c] for c in cols)

for pid, df_id in pl2.groupby("yearid", sort=False):
    df_id = df_id.sort_values("year")
    yrs = df_id["year"].astype(int).tolist()
    att = df_id["att_id"].astype(int).tolist()
    wg  = df_id["w"].astype(float).tolist()
    if len(yrs) < 2: continue
    for i in range(len(yrs)-1):
        t, t1 = yrs[i], yrs[i+1]; d = t1 - t
        if d not in (2,4): continue
        ai, aj = att[i], att[i+1]
        if INCONSISTENCY_STRATEGY == "anchor_t":
            g_tuple = group_tuple_at(df_id, t, get_group_cols())
        elif INCONSISTENCY_STRATEGY == "mode":
            g_tuple = tuple(df_id[c].mode(dropna=False).iloc[0] for c in get_group_cols())
        else:
            ok = all(len(set(df_id[c].tolist()))==1 for c in get_group_cols())
            g_tuple = tuple(df_id.iloc[0][c] for c in get_group_cols()) if ok else None
        if g_tuple is None: continue
        w = float(wg[i])
        C[(g_tuple, t, d)][ai, aj] += w
        Nfrom[(g_tuple, t, d)][ai]  += w

print("Panel cells:", len(C))


Panel cells: 400


In [8]:
# Write training JSONL
import json, os
TA_PATH = os.path.join(OUT_DIR, f"taskA_panel_rows_{CURRENT_GROUP_SCHEME}.jsonl")
TB_PATH = os.path.join(OUT_DIR, f"taskB_cs_margins_{CURRENT_GROUP_SCHEME}.jsonl")

def write_taskA(C,Nfrom,p_cs,out_path,alpha=0.25):
    n=0
    with open(out_path,"w",encoding="utf-8") as f:
        for (g_tuple,t,d), mat in C.items():
            rowsums=Nfrom[(g_tuple,t,d)]
            meta = group_meta_from_tuple(g_tuple)
            for k_from in range(K):
                if rowsums[k_from] <= 0: continue
                row = mat[k_from,:].astype(float)
                vec = (row + alpha) / max(row.sum() + alpha*K, 1e-12)
                prompt = (
                    "[Task: Predict transition row]\n"
                    f"From: <Y{t}> → To: <Y{t+d}> <DT{d}>\n"
                    f"{format_group_for_prompt(meta)}\n"
                    f"From option: {CANON4[k_from]}\n"
                    "Answer:\n"
                )
                rec = {"task":"row","group":meta,"year_t":int(t),"year_t1":int(t+d),"dt":int(d),
                       "from_bin": CANON4[k_from],
                       "prompt_text": prompt, "to_dist": vec.tolist(), "weight": float(rowsums[k_from])}
                p_curr = p_cs.get((g_tuple, int(t)))
                p_next = p_cs.get((g_tuple, int(t+d)))
                if p_curr is not None and p_next is not None:
                    rec["p_curr"] = [float(x) for x in p_curr]
                    rec["p_next"] = [float(x) for x in p_next]
                f.write(json.dumps(rec)+"\n"); n+=1
    return n

def write_taskB(p_cs, effN_cs, out_path, max_ctx=3):
    n=0
    by_group={}
    for (g,y), _ in p_cs.items(): by_group.setdefault(g, []).append(int(y))
    with open(out_path,"w",encoding="utf-8") as f:
        for g, ys in by_group.items():
            ys = sorted(ys)
            for i in range(1, len(ys)):
                target = ys[i]; prevs = ys[max(0, i-max_ctx):i]
                ctx_pairs = [(yy, p_cs[(g, yy)]) for yy in prevs if (g,yy) in p_cs]
                if not ctx_pairs: continue
                meta = group_meta_from_tuple(g)
                ctx_str = " ".join([f"<Y{yy}>[" + ",".join(f"{x:.4f}" for x in vec) + "]" for yy, vec in ctx_pairs])
                prompt = (
                    "[Task: Forecast next-wave margin]\n"
                    f"{format_group_for_prompt(meta)}\n"
                    f"Context: {ctx_str}\n"
                    f"Predict: <Y{target}>\n"
                    "Answer:\n"
                )
                rec = {"task":"margin","group":meta,"year":int(target),
                       "prompt_text":prompt,"to_dist":[float(x) for x in p_cs[(g,target)]],
                       "weight": float(min(effN_cs.get((g,target), 200.0), 1000.0))}
                f.write(json.dumps(rec)+"\n"); n+=1
    return n

nA = write_taskA(C,Nfrom,p_cs,TA_PATH); nB = write_taskB(p_cs,effN_cs,TB_PATH)
print("Wrote TaskA:", nA, "TaskB:", nB)
print("Paths:", TA_PATH, TB_PATH)


Wrote TaskA: 1062 TaskB: 675
Paths: /content/drive/MyDrive/LLM_POC_Study_2025_v2/outputs_gss_multitask_orig_subgroup/taskA_panel_rows_GROUP_COLS_ORIG.jsonl /content/drive/MyDrive/LLM_POC_Study_2025_v2/outputs_gss_multitask_orig_subgroup/taskB_cs_margins_GROUP_COLS_ORIG.jsonl


In [9]:
!pip -q install transformers peft accelerate bitsandbytes


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.3/61.3 MB[0m [31m42.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [12]:
import torch, torch.nn as nn, torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model

MODEL_NAME = "meta-llama/llama-3.1-8b"
USE_BF16 = torch.cuda.is_available()

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16 if USE_BF16 else torch.float16, device_map="auto")
hidden_size = base.config.hidden_size

tokenizer.padding_side = "right"
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
    base.config.pad_token_id = tokenizer.eos_token_id

class TwoHead(nn.Module):
    def __init__(self,H,K):
        super().__init__(); self.trans=nn.Linear(H,K); self.margin=nn.Linear(H,K)
    def forward(self,z): return self.trans(z), self.margin(z)

two_head = TwoHead(hidden_size, K).to(next(base.parameters()).device)

lora_cfg = LoraConfig(r=16,lora_alpha=32,lora_dropout=0.05,
                      target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj"])
model = get_peft_model(base, lora_cfg); model.train()

def span_indices(text, tokenizer):
    gpos = text.find("Group:")
    if gpos < 0:
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
        L = int(enc["input_ids"].shape[1]); return 0, max(1, L-1)
    end = text.find("\n", gpos); end = len(text) if end < 0 else end
    enc_pre = tokenizer(text[:gpos], return_tensors="pt")
    enc_end = tokenizer(text[:end],  return_tensors="pt")
    return int(enc_pre["input_ids"].shape[1]), int(enc_end["input_ids"].shape[1])

class JsonlDistDataset(torch.utils.data.Dataset):
    def __init__(self, paths, tokenizer, max_len=1024):
        self.rows=[]
        for p in (paths if isinstance(paths,(list,tuple)) else [paths]):
            self.rows.extend([json.loads(x) for x in open(p,"r",encoding="utf-8")])
        self.tok=tokenizer; self.max_len=max_len
    def __len__(self): return len(self.rows)
    def __getitem__(self,i):
        r=self.rows[i]
        enc=self.tok(r["prompt_text"], return_tensors="pt", truncation=True, max_length=self.max_len)
        # pool over Group: line
        gpos=r["prompt_text"].find("Group:")
        if gpos<0:
            s,e=0,int(enc["attention_mask"].sum().item())-1
        else:
            end=r["prompt_text"].find("\n", gpos); end=len(r["prompt_text"]) if end<0 else end
            s = self.tok(r["prompt_text"][:gpos], return_tensors="pt")["input_ids"].shape[1]
            e = self.tok(r["prompt_text"][:end],  return_tensors="pt")["input_ids"].shape[1]
        return {"input_ids":enc["input_ids"][0], "attention_mask":enc["attention_mask"][0],
                "span":torch.tensor([s,e]), "to_dist":torch.tensor(r["to_dist"],dtype=torch.float32),
                "weight":torch.tensor(r.get("weight",1.0),dtype=torch.float32),
                "task":0 if r["task"]=="row" else 1}

def collate(batch):
    pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    maxT = max(x["input_ids"].shape[0] for x in batch)
    input_ids=[]; attn=[]; spans=[]; y=[]; w=[]; task=[]
    for x in batch:
        pad=maxT-x["input_ids"].shape[0]
        input_ids.append(torch.cat([x["input_ids"], torch.full((pad,), pad_id)]))
        attn.append(torch.cat([x["attention_mask"], torch.zeros(pad, dtype=torch.long)]))
        spans.append(x["span"]); y.append(x["to_dist"]); w.append(x["weight"]); task.append(x["task"])
    return {"input_ids":torch.stack(input_ids),"attention_mask":torch.stack(attn),
            "span":torch.stack(spans),"to_dist":torch.stack(y),"weight":torch.stack(w),
            "task":torch.tensor(task)}

def js_divergence(p,q,eps=1e-12):
    p=torch.clamp(p,min=eps); q=torch.clamp(q,min=eps)
    p=p/p.sum(dim=1,keepdim=True); q=q/q.sum(dim=1,keepdim=True)
    m=0.5*(p+q);
    return 0.5*torch.sum(p*(torch.log(p)-torch.log(m)),dim=1) + 0.5*torch.sum(q*(torch.log(q)-torch.log(m)),dim=1)

from transformers import Trainer, TrainingArguments

class MTTrainer(Trainer):
    def __init__(self, *args, two_head=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.two_head = two_head

    # Accept any future kwargs (e.g., num_items_in_batch)
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        input_ids = inputs["input_ids"].to(model.device)
        attn      = inputs["attention_mask"].to(model.device)
        span      = inputs["span"].to(model.device)
        y         = inputs["to_dist"].to(model.device)
        w         = inputs["weight"].to(model.device)
        task      = inputs["task"].to(model.device)

        with torch.cuda.amp.autocast(enabled=True):
            out = model(input_ids=input_ids, attention_mask=attn, output_hidden_states=True)
            hs = out.hidden_states[-1]  # [B, T, H]
            B, T, H = hs.shape

            feats = []
            for b in range(B):
                s, e = int(span[b, 0]), int(span[b, 1])
                s = max(0, min(s, T - 1))
                e = max(s + 1, min(e, T))
                feats.append(hs[b, s:e, :].mean(dim=0))
            z = torch.stack(feats, dim=0)

            lr, lm = self.two_head(z)
            p_row  = torch.softmax(lr, dim=1)
            p_mrg  = torch.softmax(lm, dim=1)
            p_pred = torch.where(task.unsqueeze(1) == 0, p_row, p_mrg)

            loss_vec = js_divergence(y, p_pred)
            loss = torch.mean(w * loss_vec)

        return (loss, out) if return_outputs else loss


train_ds = JsonlDistDataset([TA_PATH, TB_PATH], tokenizer)
args = TrainingArguments(output_dir=os.path.join(OUT_DIR,f"runs_{CURRENT_GROUP_SCHEME}"),
                         per_device_train_batch_size=8, gradient_accumulation_steps=2,
                         learning_rate=1e-4, num_train_epochs=3,
                         logging_steps=50, save_steps=400, save_total_limit=2, bf16=USE_BF16, report_to="none",
                         remove_unused_columns=False)
trainer = MTTrainer(model=model, args=args, train_dataset=train_ds, data_collator=collate, two_head=two_head)
trainer.train()

SAVE_DIR = os.path.join(OUT_DIR, f"final_{CURRENT_GROUP_SCHEME}")
os.makedirs(SAVE_DIR, exist_ok=True)
model.save_pretrained(os.path.join(SAVE_DIR,"lora"))
tokenizer.save_pretrained(os.path.join(SAVE_DIR,"lora"))
torch.save(two_head.state_dict(), os.path.join(SAVE_DIR,"two_head.pt"))
print("Saved:", SAVE_DIR)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=True):


Step,Training Loss
50,0.6367
100,0.5799
150,0.555
200,0.5306
250,0.5534
300,0.5082


Saved: /content/drive/MyDrive/LLM_POC_Study_2025_v2/outputs_gss_multitask_orig_subgroup/final_GROUP_COLS_ORIG
