# n3_inference_backtest_multi.ipynb
Inference and one-step-ahead backtesting (each survey year forecast from the previous observed year) for multiple subgroup schemes.


In [None]:
# Input cross-section file and the directory holding trained artefacts/outputs.
CS_PATH = "data/GSS/gss_abt_cs_full.csv"
OUT_DIR = "outputs_gss_multitask"

# Ordered attitude categories; K is the number of categories / head outputs.
CANON4 = ["strong_anti", "anti", "pro", "strong_pro"]
K = len(CANON4)

# Nested grouping schemes: GROUP_COLS_n uses the first n columns of the list.
GROUP_SCHEMES = {
    f"GROUP_COLS_{n}": ["gender", "political_views", "edu_level", "generation", "race"][:n]
    for n in range(1, 6)
}
CURRENT_GROUP_SCHEME = "GROUP_COLS_3"


def get_group_cols():
    """Return the list of grouping columns for ``CURRENT_GROUP_SCHEME``."""
    return GROUP_SCHEMES[CURRENT_GROUP_SCHEME]


print("Active grouping:", get_group_cols())


In [None]:
import os, json, numpy as np, pandas as pd, torch, torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer


In [None]:
# Artefact locations for the active grouping scheme.
SAVE_DIR = os.path.join(OUT_DIR, f"final_{CURRENT_GROUP_SCHEME}")
LM_DIR = os.path.join(SAVE_DIR, "lora")
HEAD_PATH = os.path.join(SAVE_DIR, "two_head.pt")

# Tokenizer and causal LM both come from the "lora" sub-directory;
# device_map="auto" delegates weight placement to transformers/accelerate.
tokenizer = AutoTokenizer.from_pretrained(LM_DIR)
base = AutoModelForCausalLM.from_pretrained(LM_DIR, device_map="auto")
model = base  # alias referenced by the inference helpers (e.g. pooled_vec)
hidden_size = base.config.hidden_size
class TwoHead(nn.Module):
    """Two parallel linear heads applied to the same pooled hidden vector.

    ``trans`` produces the logits used for transition-row prediction and
    ``margin`` the logits used for next-wave margin forecasting; each maps an
    ``H``-dimensional input to ``K`` logits.
    """

    def __init__(self, H, K):
        super().__init__()
        # The attribute names define the state_dict keys (trans.*, margin.*)
        # expected by the saved checkpoint, so they must not change.
        self.trans = nn.Linear(H, K)
        self.margin = nn.Linear(H, K)

    def forward(self, z):
        """Return ``(transition_logits, margin_logits)`` for input ``z``."""
        transition_logits = self.trans(z)
        margin_logits = self.margin(z)
        return transition_logits, margin_logits
# Build the head on CPU, load the saved weights there, and move it once to the
# device holding the base model's first parameter (the original code moved it
# twice).
two_head = TwoHead(hidden_size, K)
# The loaded object is passed straight to load_state_dict, so it is a plain
# mapping of tensors: weights_only=True is sufficient and avoids unpickling
# arbitrary objects from the checkpoint file.
two_head.load_state_dict(torch.load(HEAD_PATH, map_location="cpu", weights_only=True))
two_head = two_head.to(next(base.parameters()).device).eval()
print("Loaded:", SAVE_DIR)


In [None]:
def span_indices(text, tokenizer):
    """Return the ``[start, end)`` token span of the ``Group:`` line in ``text``.

    ``start`` is the token count of the text before ``"Group:"`` and ``end``
    is the token count of the text up to (not including) the newline that
    ends that line, or up to the end of ``text`` if there is no newline.
    Counts are taken from the tokenizer output as-is, so any special tokens
    the tokenizer adds (e.g. a BOS token, if it adds one) are included.

    If ``text`` has no ``"Group:"`` marker, the whole text is tokenized
    (truncated to 1024 tokens) and ``(0, max(1, L - 1))`` is returned, i.e.
    every token except the last, with an end of at least 1.
    """
    def n_tokens(s, **kwargs):
        return int(tokenizer(s, return_tensors="pt", **kwargs)["input_ids"].shape[1])

    group_at = text.find("Group:")
    if group_at == -1:
        total = n_tokens(text, truncation=True, max_length=1024)
        return 0, max(1, total - 1)

    line_end = text.find("\n", group_at)
    if line_end == -1:
        line_end = len(text)
    return n_tokens(text[:group_at]), n_tokens(text[:line_end])

def pooled_vec(prompt):
    """Mean-pool the last-layer hidden states of ``model`` over the prompt's Group span.

    The prompt is tokenized (truncated to 1024 tokens) and run through the
    global ``model`` without gradients. The span from ``span_indices`` is
    clamped to the actual sequence length (start in ``[0, L-1]``, end in
    ``[start+1, L]``) so it always covers at least one token.

    Returns a tensor of shape ``(1, hidden_size)`` (batch of one prompt).
    """
    batch = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    batch = {name: tensor.to(model.device) for name, tensor in batch.items()}
    with torch.no_grad():
        outputs = model(**batch, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        seq_len = last_hidden.shape[1]
        start, stop = span_indices(prompt, tokenizer)
        start = max(0, min(start, seq_len - 1))
        stop = max(start + 1, min(stop, seq_len))
        return last_hidden[:, start:stop, :].mean(dim=1)


In [None]:
def predict_full_transition(meta, year_t, year_t1):
    """Predict the full transition matrix between two survey years for one subgroup.

    One prompt is built per source category in ``CANON4``. Each prompt is
    pooled with ``pooled_vec`` and scored with the transition head of
    ``two_head``; the softmax of each head output becomes one row.

    Parameters
    ----------
    meta : dict
        Group column -> value, rendered into the prompt's ``Group:`` line.
    year_t, year_t1 : int
        Source and target years (``<DT..>`` encodes their difference).

    Returns
    -------
    numpy.ndarray
        Matrix with one row per ``CANON4`` category (in that order). Each row
        is renormalised to sum to 1.
    """
    group_str = "; ".join(f"{k}={v}" for k, v in meta.items())
    head_param = next(two_head.parameters())
    rows = []
    for from_bin in CANON4:
        prompt = ("[Task: Predict transition row]\n"
                  f"From: <Y{year_t}> → To: <Y{year_t1}> <DT{year_t1 - year_t}>\n"
                  f"Group: {group_str}\n"
                  f"From option: {from_bin}\n"
                  "Answer:\n")
        # Align the pooled vector with the head's device/dtype: the base LM's
        # hidden states may be in a different precision than the head weights,
        # which would otherwise make the Linear layers fail.
        z = pooled_vec(prompt).to(device=head_param.device, dtype=head_param.dtype)
        with torch.no_grad():
            trans_logits, _ = two_head(z)
            # .float() keeps .numpy() working even for reduced-precision dtypes.
            rows.append(torch.softmax(trans_logits, dim=1)[0].float().cpu().numpy())
    T = np.vstack(rows)
    return T / np.clip(T.sum(axis=1, keepdims=True), 1e-12, None)

def predict_next_margin(meta, context_list, target_year):
    """Forecast a subgroup's margin for ``target_year`` with the margin head.

    Parameters
    ----------
    meta : dict
        Group column -> value, rendered into the prompt's ``Group:`` line.
    context_list : iterable of (year, probs)
        Previously observed margins. Each probability is rendered with 4
        decimals as ``<Y{year}>[p1,p2,...]``.
    target_year : int
        Year being forecast.

    Returns
    -------
    numpy.ndarray
        Softmax of the margin head output, renormalised to sum to 1.
    """
    ctx_str = " ".join(
        f"<Y{yy}>[" + ",".join(f"{x:.4f}" for x in p) + "]" for yy, p in context_list
    )
    group_str = "; ".join(f"{k}={v}" for k, v in meta.items())
    prompt = ("[Task: Forecast next-wave margin]\n"
              f"Group: {group_str}\n"
              f"Context: {ctx_str}\n"
              f"Predict: <Y{target_year}>\n"
              "Answer:\n")
    head_param = next(two_head.parameters())
    # Align the pooled vector with the head's device/dtype: the base LM's
    # hidden states may be in a different precision than the head weights.
    z = pooled_vec(prompt).to(device=head_param.device, dtype=head_param.dtype)
    with torch.no_grad():
        _, margin_logits = two_head(z)
        p = torch.softmax(margin_logits, dim=1)[0].float().cpu().numpy()
    return p / np.clip(p.sum(), 1e-12, None)


In [None]:
def load_cs_margins(path, group_cols, n_cats=None):
    """Compute weighted attitude margins per (subgroup, year) from a cross-section CSV.

    Parameters
    ----------
    path : str or file-like
        Anything ``pandas.read_csv`` accepts. It must provide the
        ``group_cols``, ``year``, ``abortion_att4`` and ``wtssps`` columns.
    group_cols : list of str
        Columns defining a subgroup. They are combined with ``year`` for
        grouping, and NaN group values are kept as their own group.
    n_cats : int, optional
        Number of attitude categories. Codes ``1..n_cats`` are counted and
        all other codes are ignored. Defaults to the module-level ``K``.

    Returns
    -------
    (margins, eff_n) : tuple of dict
        Both dicts are keyed by ``(tuple(group_values), int(year))``.
        ``margins`` maps each key to a weight-normalised length-``n_cats``
        distribution. ``eff_n`` maps each key to the Kish effective sample
        size ``(sum w)^2 / sum w^2``, computed over the rows that actually
        contributed to that margin. The following are omitted from both:
        - cells whose valid rows have zero total weight;
        - rows with a missing year.
    """
    if n_cats is None:
        n_cats = K
    df = pd.read_csv(path)
    # Missing or non-numeric attitude codes become 0, so the range filter below
    # drops them instead of astype(int) raising on NaN.
    df = df.assign(
        att_id=pd.to_numeric(df["abortion_att4"], errors="coerce").fillna(0).astype(int)
    )
    out = {}
    eff_n = {}
    for keys, gdf in df.groupby(list(group_cols) + ["year"], dropna=False):
        *g_vals, year = keys
        if pd.isna(year):
            continue  # cannot place a row without a survey year
        key = (tuple(g_vals), int(year))
        w = gdf["wtssps"].fillna(0.0).to_numpy(float)
        codes = gdf["att_id"].to_numpy(int)
        valid = (codes >= 1) & (codes <= n_cats)
        w_valid = w[valid]
        vec = np.bincount(codes[valid] - 1, weights=w_valid, minlength=n_cats)
        tot = vec.sum()
        if tot <= 0:
            continue
        out[key] = vec / tot
        # tot > 0 implies at least one non-zero valid weight, so sum(w^2) > 0.
        eff_n[key] = float(w_valid.sum() ** 2 / (w_valid ** 2).sum())
    return out, eff_n

# Weighted cross-sectional margins and effective sample sizes, keyed by
# (group-values tuple, year), for the active grouping scheme.
p_cs, effN_cs = load_cs_margins(
    CS_PATH,
    get_group_cols(),
)
print("Loaded CS margins:", len(p_cs))


In [None]:
def _jsd(a, b, eps=1e-12):
    """Jensen-Shannon divergence (natural log) between two probability vectors.

    Both inputs are clipped to ``[eps, 1]`` and renormalised first, so zero
    entries do not produce ``log(0)``. The inputs are not modified.
    """
    a = np.clip(np.asarray(a, dtype=float), eps, 1)
    b = np.clip(np.asarray(b, dtype=float), eps, 1)
    a = a / a.sum()
    b = b / b.sum()
    m = 0.5 * (a + b)
    return 0.5 * np.sum(a * np.log(a / m)) + 0.5 * np.sum(b * np.log(b / m))


def forecast_every_year_from_previous(p_cs, alpha=0.5, context_lags=(4,2,0)):
    """Backtest: forecast every observed year of each subgroup from its previous observed year.

    For each subgroup and each pair of consecutive observed years
    ``(y_prev, y)`` the function builds three forecasts:

    - ``p_trans``: ``p_prev @ T``, where ``T`` comes from ``predict_full_transition``.
    - ``p_margin``: the output of ``predict_next_margin``. Its context is the
      observed margins at ``y_prev - lag`` for each lag in ``context_lags``
      that exists in ``p_cs``. If none exist, ``(y_prev, p_prev)`` is used.
    - ``p_hat``: ``alpha * p_trans + (1 - alpha) * p_margin``, renormalised.

    Parameters
    ----------
    p_cs : dict
        ``(group_values_tuple, year) -> probability vector``, as returned by
        ``load_cs_margins``.
    alpha : float
        Mixing weight for the transition-based forecast.
    context_lags : tuple of int
        Offsets (in years) before ``y_prev`` used as margin context.

    Returns
    -------
    pandas.DataFrame
        One row per forecast. Columns:
        - the group columns;
        - ``year_t`` and ``year_t1``;
        - ``p_{prev,trans,margin,hat}_<category>`` per category;
        - when the target year is observed, ``p_obs_<category>`` and the
          ``JSD_hat``/``JSD_trans``/``JSD_margin`` divergences (natural log).
    """
    group_cols = get_group_cols()
    years_by_group = {}
    for g, y in p_cs:
        years_by_group.setdefault(g, []).append(y)

    rows = []
    for g, years in years_by_group.items():
        meta = dict(zip(group_cols, g))
        years = sorted(years)
        for y_prev, y in zip(years, years[1:]):
            p_prev = p_cs[(g, y_prev)]
            T = predict_full_transition(meta, y_prev, y)
            p_trans = p_prev @ T

            ctx = [(y_prev - lag, p_cs[(g, y_prev - lag)])
                   for lag in context_lags if (g, y_prev - lag) in p_cs]
            if not ctx:
                ctx = [(y_prev, p_prev)]
            p_margin = predict_next_margin(meta, ctx, y)

            p_hat = alpha * p_trans + (1 - alpha) * p_margin
            p_hat = p_hat / np.clip(p_hat.sum(), 1e-12, None)

            rec = dict(meta)
            rec["year_t"] = y_prev
            rec["year_t1"] = y
            for prefix, vec in (("p_prev", p_prev), ("p_trans", p_trans),
                                ("p_margin", p_margin), ("p_hat", p_hat)):
                rec.update({f"{prefix}_{name}": float(vec[i]) for i, name in enumerate(CANON4)})

            p_obs = p_cs.get((g, y))
            if p_obs is not None:
                rec.update({f"p_obs_{name}": float(p_obs[i]) for i, name in enumerate(CANON4)})
                rec["JSD_hat"] = float(_jsd(p_hat, p_obs))
                rec["JSD_trans"] = float(_jsd(p_trans, p_obs))
                rec["JSD_margin"] = float(_jsd(p_margin, p_obs))
            rows.append(rec)
    return pd.DataFrame(rows)

# Run the one-step-ahead backtest and persist it for the active scheme.
bt = forecast_every_year_from_previous(p_cs, alpha=0.5)
BT_PATH = os.path.join(
    OUT_DIR,
    f"backtest_{CURRENT_GROUP_SCHEME}.csv",
)
bt.to_csv(BT_PATH, index=False)
print("Backtest saved:", BT_PATH)
bt.head(3)  # preview; shown as the cell output when run in a notebook
