<a href="https://colab.research.google.com/github/yilmajung/LLM_POC_Study_2025_v2/blob/main/n3_inference_backtest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive; every data, adapter and output path used
# later in this notebook lives under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os, json, numpy as np, pandas as pd, torch, torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import math

# Set up paths
# Identifier of the base checkpoint the LoRA adapter is attached to.
BASE_MODEL_NAME = "meta-llama/llama-3.1-8b"    # or mistral
# Directory that both the tokenizer and the PEFT adapter are loaded from below.
SAVE_DIR   = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/outputs_gss_multitask/final_multitask"  # from training step
# State dict of the TwoHead module, stored next to the adapter.
HEAD_PATH  = os.path.join(SAVE_DIR, "two_head.pt")

# Canon (K=4): ordered attitude bins. Every length-K vector in this notebook is
# indexed in this order (position i <-> CANON4[i]).
CANON4 = ["strong_anti","anti","pro","strong_pro"]
K = len(CANON4)

# Load tokenizer + base + LoRA
# The tokenizer is read from SAVE_DIR rather than from the base model id.
tokenizer = AutoTokenizer.from_pretrained(SAVE_DIR, use_fast=True)
# Reuse EOS as the pad token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): the recorded output of this cell shows "`torch_dtype` is deprecated!
# Use `dtype` instead!" and flags `output_hidden_states` as an invalid generation
# flag; the prediction helpers pass output_hidden_states=True on each forward call.
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    output_hidden_states=True
)
# Wrap the base model with the adapter stored in SAVE_DIR and switch to eval mode.
model = PeftModel.from_pretrained(base, SAVE_DIR)
model.eval()

# Two-head module (must match training)
import torch.nn as nn
class TwoHead(nn.Module):
    """Two independent linear classifiers applied to the same pooled features.

    Both heads map a ``hidden_size``-dimensional feature vector to ``K`` logits.
    The attribute names ``head_row`` and ``head_margin`` are the state-dict
    keys, so they have to stay as they are for saved weights to load.
    """

    def __init__(self, hidden_size, K):
        super().__init__()
        # Task A head
        self.head_row = nn.Linear(hidden_size, K)
        # Task B head
        self.head_margin = nn.Linear(hidden_size, K)

    def forward(self, feats):
        """Return ``(row_logits, margin_logits)`` for ``feats``."""
        row_logits = self.head_row(feats)
        margin_logits = self.head_margin(feats)
        return row_logits, margin_logits

# Build the heads with the base model's hidden width and place them on the
# same device as the language model.
hidden_size = base.config.hidden_size
two_head = TwoHead(hidden_size, K).to(model.device)
# The checkpoint is consumed purely as a state dict, so restrict unpickling to
# tensors/primitive containers (weights_only=True) instead of allowing
# arbitrary pickled objects to execute on load.
two_head.load_state_dict(
    torch.load(HEAD_PATH, map_location=model.device, weights_only=True)
)
two_head.eval()

def pooled_features(outputs, attention_mask, tail=96):
    """Mean-pool the last-layer hidden states over each row's trailing valid tokens.

    Args:
        outputs: model output exposing ``hidden_states``; the last entry is
            used and is expected to be ``[B, T, H]``.
        attention_mask: ``[B, T]`` mask, non-zero at real (non-pad) tokens.
        tail: number of trailing valid tokens to average. A non-positive value
            falls back to a 32-token window, as before.

    Returns:
        Tensor ``[B, H]`` with one pooled vector per row.

    The window is selected from the mask positions themselves, so the result
    is the same whether the batch is right- or left-padded (the earlier slice
    ``[L - tail, L)`` silently pooled pad positions for left-padded rows). A
    row with no valid token yields a zero vector instead of the NaN produced
    by averaging an empty slice.
    """
    hs = outputs.hidden_states[-1]   # [B,T,H]
    width = tail if tail > 0 else 32
    feats = []
    for b in range(hs.size(0)):
        # Indices of the real tokens in this row, in sequence order.
        valid_idx = torch.nonzero(attention_mask[b], as_tuple=False).flatten()
        window = valid_idx[-width:]
        if window.numel() == 0:
            # Nothing to pool: keep the batch shape without injecting NaNs.
            feats.append(hs.new_zeros(hs.size(-1)))
            continue
        # The mask may live on a different device than the hidden states.
        feats.append(hs[b].index_select(0, window.to(hs.device)).mean(dim=0))
    return torch.stack(feats, dim=0)


config.json:   0%|          | 0.00/826 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

In [3]:
# Helpers: predict a row/margin/full matrix
def _to_head_dtype(x):
    """Cast ``x`` to the dtype of the row head's weight so the heads can consume it."""
    head_dtype = two_head.head_row.weight.dtype
    return x.to(head_dtype)

@torch.no_grad()
def predict_row_distribution(group, year_t, year_t1, from_bin, dt=None, max_len=768):
    """Task A: one row (from_bin) of the transition matrix.

    Args:
        group: mapping with 'generation', 'gender' and 'race' entries; each is
            interpolated into the prompt.
        year_t: origin year, written into the prompt as ``<Y{year_t}>``.
        year_t1: destination year, written as ``<Y{year_t1}>``.
        from_bin: origin bin label, written verbatim into the prompt.
        dt: year gap for the ``<DT..>`` tag; defaults to
            ``int(year_t1) - int(year_t)``.
        max_len: truncation length passed to the tokenizer.

    Returns:
        float32 NumPy array of shape [K]: softmax over the row head's logits.
    """
    if dt is None:
        dt = int(year_t1) - int(year_t)
    prompt = (
        "[Task: Predict transition row]\n"
        f"From: <Y{year_t}> → To: <Y{year_t1}> <DT{dt}>\n"
        f"Group: generation={group['generation']}; gender={group['gender']}; race={group['race']}\n"
        f"From option: {from_bin}\n"
        "Answer:\n"
    )
    enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_len)
    ids = enc["input_ids"].to(model.device)
    attn = enc["attention_mask"].to(model.device)
    # torch.cuda.amp.autocast is deprecated (and CUDA-only); torch.autocast with
    # an explicit device_type is the supported spelling and follows the inputs.
    with torch.autocast(device_type=ids.device.type, dtype=torch.bfloat16):
        out = model(input_ids=ids, attention_mask=attn, output_hidden_states=True)
    feats = pooled_features(out, attn, tail=96).to(model.device)
    feats = _to_head_dtype(feats)
    logits_row, _ = two_head(feats)
    # Single prompt -> batch of one; return its row as a plain vector.
    p = F.softmax(logits_row, dim=1).float().cpu().numpy()[0]
    return p  # shape [K]

@torch.no_grad()
def predict_full_transition(group, year_t, year_t1, max_len=768):
    """Build the K×K transition matrix by querying the row head once per origin bin.

    Row ``i`` is the predicted distribution for ``CANON4[i]``. The stacked rows
    are clipped away from zero and renormalised so each row sums to one.
    """
    gap = int(year_t1) - int(year_t)
    row_preds = [
        predict_row_distribution(group, year_t, year_t1, origin, dt=gap, max_len=max_len)
        for origin in CANON4
    ]
    matrix = np.asarray(row_preds, dtype=np.float32)
    # force row-stochastic
    matrix = np.clip(matrix, 1e-12, 1)
    row_totals = matrix.sum(axis=1, keepdims=True)
    return matrix / row_totals

@torch.no_grad()
def predict_next_margin(group, context, target_year, max_len=768):
    """Task B: forecast the marginal distribution for ``target_year``.

    Args:
        group: mapping with 'generation', 'gender' and 'race' entries; each is
            interpolated into the prompt.
        context: list of ``(year, prob_vector)`` tuples with length-K vectors;
            each is rendered as ``<Y{year}>[p1,p2,...]`` with four decimals.
        target_year: year written into the ``Predict:`` line.
        max_len: truncation length passed to the tokenizer.

    Returns:
        float32 NumPy array of shape [K]: softmax over the margin head's logits.
    """
    ctx_parts = " ".join([f"<Y{yy}>[{','.join(f'{x:.4f}' for x in p)}]" for (yy,p) in context])
    prompt = (
        "[Task: Forecast next-wave margin]\n"
        f"Group: generation={group['generation']}; gender={group['gender']}; race={group['race']}\n"
        f"Context: {ctx_parts}\n"
        f"Predict: <Y{target_year}>\n"
        "Answer:\n"
    )
    enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_len)
    ids = enc["input_ids"].to(model.device)
    attn = enc["attention_mask"].to(model.device)
    # torch.cuda.amp.autocast is deprecated (and CUDA-only); torch.autocast with
    # an explicit device_type is the supported spelling and follows the inputs.
    with torch.autocast(device_type=ids.device.type, dtype=torch.bfloat16):
        out = model(input_ids=ids, attention_mask=attn, output_hidden_states=True)
    feats = pooled_features(out, attn, tail=96).to(model.device)
    feats = _to_head_dtype(feats)
    _, logits_margin = two_head(feats)
    # Single prompt -> batch of one; return its row as a plain vector.
    p = F.softmax(logits_margin, dim=1).float().cpu().numpy()[0]
    return p


In [4]:
# Using real margins and forecasting
# Load cross-sectional CSV and build p_cs[(g,y)]
# The rest of this cell reads the columns `abortion_att4`, `year`, the three
# group columns and, when present, `wtssps` from this frame.
# NOTE(review): presumably one row per respondent — confirm against the file.
cs = pd.read_csv("/content/drive/MyDrive/LLM_POC_Study_2025_v2/gss_abt_cs_full.csv")
# Map to 4 bins if numeric:
def map_att(v):
    """Map a raw attitude value to one of the four canonical bin labels.

    Numeric codes 1..4 (Python or NumPy ints/floats) map to
    'strong_anti', 'anti', 'pro', 'strong_pro'. Any other numeric value,
    including non-integral codes such as 2.7 and ±inf, returns None rather
    than being truncated onto a neighbouring bin. Non-numeric values and NaN
    are returned as their stripped string form, unchanged from before, so the
    caller's membership filter decides whether to keep them.
    """
    codes = {1: "strong_anti", 2: "anti", 3: "pro", 4: "strong_pro"}
    # np.floating covers float32 scalars, which are not Python floats.
    if isinstance(v, (int, np.integer, float, np.floating)) and not pd.isna(v):
        # int() would silently floor 2.7 to 2 and raise on inf; reject both.
        if not float(v).is_integer():
            return None
        return codes.get(int(v), None)
    return str(v).strip()
# Canonical attitude label per row; rows that do not land in one of the four
# bins are dropped.
cs["att"] = cs["abortion_att4"].apply(map_att)
cs = cs[cs["att"].isin(CANON4)].copy()
# Row weight: use the `wtssps` column when the file has it, otherwise weight
# every row equally. A scalar is broadcast to the frame's own (filtered) index;
# a fresh 0..n-1 Series would be aligned by label and leave NaN weights on
# every row whose index label is >= len(cs).
if "wtssps" in cs.columns:
    cs["wt"] = cs["wtssps"]
else:
    cs["wt"] = 1.0

GROUP_COLS = ["generation","gender","race"]
# Normalise group keys to stripped strings so they compare and group consistently.
for c in GROUP_COLS:
    cs[c] = cs[c].astype(str).str.strip()

def weighted_probs(vals, wts):
    """Weighted share of each canonical bin.

    Args:
        vals: iterable of bin labels; each must be a member of CANON4.
        wts: iterable of weights, paired positionally with ``vals``.

    Returns:
        Length-K float array ordered like CANON4 and summing to one, or None
        when the total usable weight is not positive.

    Observations whose weight is NaN or infinite are skipped. Previously a
    single NaN weight made the total NaN, which failed the ``> 0`` test and
    discarded the whole group silently.
    """
    totals = dict.fromkeys(CANON4, 0.0)
    for label, weight in zip(vals, wts):
        weight = float(weight)
        if not np.isfinite(weight):
            continue
        totals[label] += weight
    vec = np.array([totals[c] for c in CANON4], dtype=float)
    s = vec.sum()
    return vec / s if s > 0 else None

# p_cs maps (group tuple, year) -> weighted bin distribution, where the group
# tuple follows GROUP_COLS order. Combinations with no positive weight are
# left out.
p_cs = {}
for key, df_gy in cs.groupby(GROUP_COLS + ["year"]):
    gvals, y = tuple(key[:-1]), key[-1]
    dist = weighted_probs(df_gy["att"].tolist(), df_gy["wt"].tolist())
    if dist is None:
        continue
    p_cs[(gvals, int(y))] = dist


In [5]:
# Forecast 2024 from 2022 (per subgroup)

def jsd(p, q, eps=1e-9):
    """Jensen-Shannon divergence (natural log) of two vectors after clipping to [eps, 1] and renormalising."""
    p = np.clip(p, eps, 1)
    q = np.clip(q, eps, 1)
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5 * (p + q)
    return 0.5*np.sum(p*np.log(p/m)) + 0.5*np.sum(q*np.log(q/m))

def rmse(p, q):
    """Root-mean-square difference between two equally sized vectors."""
    return float(np.sqrt(np.mean((p-q)**2)))

# choose subgroups present in 2022
groups_2022 = sorted({g for (g,y) in p_cs.keys() if y==2022})

rows = []
alpha = 0.5  # ensemble weight between transition-based and margin-head forecast

for g in groups_2022:
    group = {"generation": g[0], "gender": g[1], "race": g[2]}
    # context margins e.g., 2018, 2020, 2022 if available.
    # 2022 is always present for these groups, so the context is never empty.
    ctx = [(yy, p_cs[(g,yy)]) for yy in [2018, 2020, 2022] if (g, yy) in p_cs]

    p_2022 = p_cs[(g,2022)]
    # Task A: transition matrix 2022->2024, applied to the observed 2022 margin
    T_22_24 = predict_full_transition(group, 2022, 2024)
    p_2024_trans = p_2022 @ T_22_24

    # Task B: direct next margin
    p_2024_margin = predict_next_margin(group, ctx, 2024)

    # Ensemble (optional)
    p_2024_hat = alpha * p_2024_trans + (1 - alpha) * p_2024_margin
    p_2024_hat = p_2024_hat / p_2024_hat.sum()

    # Observed 2024 for backtest if available
    p_2024_obs = p_cs.get((g,2024), None)

    rec = {
        "generation": g[0], "gender": g[1], "race": g[2],
        **{f"p2022_{c}": float(p_2022[i]) for i,c in enumerate(CANON4)},
        **{f"p2024_trans_{c}": float(p_2024_trans[i]) for i,c in enumerate(CANON4)},
        **{f"p2024_margin_{c}": float(p_2024_margin[i]) for i,c in enumerate(CANON4)},
        **{f"p2024_hat_{c}": float(p_2024_hat[i]) for i,c in enumerate(CANON4)},
    }
    if p_2024_obs is not None:
        rec.update({f"p2024_obs_{c}": float(p_2024_obs[i]) for i,c in enumerate(CANON4)})

        # quick metrics (only for groups with an observed 2024 margin)
        rec["JSD_hat_vs_obs"] = float(jsd(p_2024_hat, p_2024_obs))
        rec["RMSE_hat_vs_obs"] = rmse(p_2024_hat, p_2024_obs)
        rec["JSD_trans_vs_obs"] = float(jsd(p_2024_trans, p_2024_obs))
        rec["JSD_margin_vs_obs"]= float(jsd(p_2024_margin, p_2024_obs))

    rows.append(rec)

df_forecast = pd.DataFrame(rows)
# The file is written to the project root, not SAVE_DIR: os.path.join drops all
# earlier components when a later one is absolute, so the former
# os.path.join(SAVE_DIR, "/content/...") already resolved to exactly this path.
out_csv = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/forecast_2024_from_2022_by_group.csv"
df_forecast.to_csv(out_csv, index=False)
out_csv


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


'/content/drive/MyDrive/LLM_POC_Study_2025_v2/forecast_2024_from_2022_by_group.csv'

In [6]:
def backtest_interval(year_t, year_t1, alpha=0.5):
    """Score the forecast of ``year_t1`` margins from ``year_t`` for every eligible group.

    A group is eligible when it has an observed margin in ``p_cs`` for both
    years. For each one, the transition-based forecast (observed margin at
    ``year_t`` times the predicted transition matrix) and the margin-head
    forecast are blended with weight ``alpha`` on the transition forecast.
    The blend is renormalised to sum to one, matching the 2024 forecast cell,
    so RMSE is computed on the same quantity in both places.

    Args:
        year_t: origin year.
        year_t1: target year.
        alpha: weight on the transition-based forecast; ``1 - alpha`` goes to
            the margin-head forecast.

    Returns:
        DataFrame with one row per eligible group and columns year_t, year_t1,
        generation, gender, race, JSD_hat, JSD_trans, JSD_margin, RMSE_hat.
        It is empty (no columns) when no group is eligible.
    """
    def jsd(p, q, eps=1e-9):
        # Jensen-Shannon divergence (natural log) after clipping and renormalising.
        p = np.clip(p, eps, 1)
        q = np.clip(q, eps, 1)
        p = p / p.sum()
        q = q / q.sum()
        m = 0.5 * (p + q)
        return 0.5*np.sum(p*np.log(p/m)) + 0.5*np.sum(q*np.log(q/m))

    def rmse(p, q):
        return float(np.sqrt(np.mean((p-q)**2)))

    results = []
    groups = sorted({g for (g,y) in p_cs.keys() if y==year_t})
    for g in groups:
        # Every g has a year_t margin by construction; only the target can be missing.
        if (g,year_t1) not in p_cs:
            continue
        group = {"generation": g[0], "gender": g[1], "race": g[2]}
        p_t = p_cs[(g,year_t)]
        # context: use up to two previous margins if present (year_t itself is always there)
        ctx_years = [year_t-4, year_t-2, year_t]
        ctx = [(yy, p_cs[(g,yy)]) for yy in ctx_years if (g,yy) in p_cs]
        T = predict_full_transition(group, year_t, year_t1)
        p_next_trans = p_t @ T
        p_next_margin = predict_next_margin(group, ctx, year_t1)
        p_hat = alpha*p_next_trans + (1-alpha)*p_next_margin
        p_hat = p_hat / p_hat.sum()
        p_obs = p_cs[(g,year_t1)]
        results.append({
            "year_t": year_t, "year_t1": year_t1,
            "generation": g[0], "gender": g[1], "race": g[2],
            "JSD_hat": float(jsd(p_hat, p_obs)),
            "JSD_trans": float(jsd(p_next_trans, p_obs)),
            "JSD_margin": float(jsd(p_next_margin, p_obs)),
            "RMSE_hat": float(rmse(p_hat, p_obs)),
        })
    return pd.DataFrame(results)

# Backtest every two-year step: 2008→2010 … 2022→2024
bt_all = []
for y0 in range(2008, 2024, 2):
    bt_all.append(backtest_interval(y0, y0+2, alpha=0.5))

# Intervals where no group is observed at both ends come back as empty frames
# (they are simply absent from the summary); leave them out of the concat and
# fail with a clear message if nothing at all could be scored.
bt_scored = [frame for frame in bt_all if not frame.empty]
if not bt_scored:
    raise ValueError("backtest produced no rows: no group has observed margins at both ends of any interval")
bt = pd.concat(bt_scored, ignore_index=True)

# The file is written to the project root, not SAVE_DIR: os.path.join drops all
# earlier components when a later one is absolute, so the former
# os.path.join(SAVE_DIR, "/content/...") already resolved to exactly this path.
bt_out = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/backtest_by_interval.csv"
bt.to_csv(bt_out, index=False)
bt_out, bt.groupby(["year_t","year_t1"]).JSD_hat.mean()

  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


('/content/drive/MyDrive/LLM_POC_Study_2025_v2/backtest_by_interval.csv',
 year_t  year_t1
 2008    2010       0.034290
 2010    2012       0.041960
 2012    2014       0.069244
 2014    2016       0.036331
 2016    2018       0.057928
 2022    2024       0.055560
 Name: JSD_hat, dtype: float64)