In [None]:

# Inference with small classification head (Fix A)
!pip -q install peft transformers accelerate sentencepiece

import os, json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from google.colab import userdata

In [None]:
# -----------------
# Config (EDIT ME)
# -----------------
BASE_MODEL_NAME = "meta-llama/llama-3.1-8b"   # or "mistralai/Mistral-7B-v0.3"
INF_DIR         = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/tllm_fixA_head"     # where you saved training outputs
LORA_DIR        = os.path.join(INF_DIR, "lora")                                    # saved LoRA+tokenizer folder
HEAD_PATH       = os.path.join(INF_DIR, "dist_head.pt")                             # saved small head weights
GSS_PATH        = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/gss_panel_2016_2020_long.parquet"

# Output CSVs
PRED_OUT   = os.path.join(INF_DIR, "gss_tllm_fixA_proj_2020_from_2016.csv")
EVAL_OUT   = os.path.join(INF_DIR, "gss_tllm_fixA_eval_2020_vs_empirical.csv")
DELTA_OUT  = os.path.join(INF_DIR, "gss_tllm_fixA_pred_deltas_2020_minus_2016.csv")

# If you need HF auth in Colab:
# userdata.get() takes only the secret name and raises when the secret is
# missing or the notebook has not been granted access, so the "optional token"
# behaviour is spelled out explicitly here.
try:
    hf_token = userdata.get('HF_TOKEN')
except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
    # Fall back to the environment; None is passed on as "no token".
    hf_token = os.environ.get("HF_TOKEN")

# -----------------
# Canonical bins (5-way; UNSURE excluded at inference)
# -----------------
CANON5 = [
    "strong_anti",
    "anti",
    "neutral",
    "pro",
    "strong_pro",
]
# Bin name -> position of that bin in CANON5
IDX = dict(zip(CANON5, range(len(CANON5))))

# Map your GSS att5 -> 5 canonical bins (EDIT if your labels differ)
# Note: none of the labels below maps to "neutral".
# Any other responses (DK/NA/Refused/etc.) will be dropped for the 5-bin distribution
GSS_ATT5_TO_5CANON = dict([
    ("Illegal in all cases", "strong_anti"),
    ("Illegal in most cases", "anti"),
    ("Legal in most cases", "pro"),
    ("Legal in all cases", "strong_pro"),
])

def fivebin_empirical(series):
    """
    Turn a Series of raw att5 answers into a probability vector over CANON5.

    Answers are translated through GSS_ATT5_TO_5CANON; answers without a
    mapping are discarded. Returns a float array in CANON5 order that sums
    to 1, or None when nothing could be mapped.
    """
    canon_answers = series.map(GSS_ATT5_TO_5CANON).dropna()
    if canon_answers.empty:
        return None

    tallies = canon_answers.value_counts()
    counts = np.zeros(len(CANON5), dtype=float)
    for pos, bin_name in enumerate(CANON5):
        # Bins that never occur are absent from the tallies -> count 0
        counts[pos] = tallies.get(bin_name, 0.0)

    total = counts.sum()
    return None if total <= 0 else counts / total

# -----------------
# Load model + LoRA + small head
# -----------------
# The tokenizer is read from the LoRA folder, not from the base model repo.
tokenizer = AutoTokenizer.from_pretrained(
    LORA_DIR,
    use_fast=True,
    token=hf_token,
)
if tokenizer.pad_token is None:
    # No dedicated pad token: reuse EOS so batches can be padded
    tokenizer.pad_token = tokenizer.eos_token

# Base causal LM in bfloat16 with hidden states enabled
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    token=hf_token,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    output_hidden_states=True,
)

# Attach the saved LoRA adapter and switch to inference mode
model = PeftModel.from_pretrained(base, LORA_DIR)
model.eval()

# Small 5-way classification head must match training definition
class DistHead(nn.Module):
    """Single linear layer that maps a hidden vector to K logits."""

    def __init__(self, hidden_size, K=5):
        super().__init__()
        # The attribute name `out` determines the state-dict keys
        # (out.weight / out.bias).
        self.out = nn.Linear(in_features=hidden_size, out_features=K)

    def forward(self, x):
        """Project x of shape [B, H] to logits of shape [B, K]."""
        logits = self.out(x)
        return logits

# Instantiate the head with the base model's hidden size and restore its weights
hidden_size = base.config.hidden_size
dist_head = DistHead(hidden_size, K=5).to(model.device)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute code. The result is passed to
# load_state_dict, i.e. HEAD_PATH is expected to hold a plain state dict.
head_state = torch.load(HEAD_PATH, map_location=model.device, weights_only=True)
dist_head.load_state_dict(head_state)
dist_head.eval()

# -----------------
# GSS data prep (2016 & 2020)
# -----------------
# Read the panel and keep only the two waves used below
gss = (
    pd.read_parquet(GSS_PATH)
    .loc[lambda frame: frame["year"].isin([2016, 2020])]
    .copy()
)

# We ignore education in prompts now, per your request.
# Grouping keys (adjust if you need to split race/gender differently)
GROUP_COLS = [
    "generation",
    "gender",
    "race",
]

# -----------------
# Prompt builder (mirror your training style, but set edu_* = NA)
# -----------------
QUESTION_TEXT = "Harmonized abortion attitude across waves"

def build_transition_prompt(group_meta, from_option):
    """
    Render the training-style transition prompt for one subgroup and one
    starting option. The edu_2016 / edu_2020 fields are kept but fixed to NA.

    group_meta must provide the keys 'gender', 'generation' and 'race'.
    The returned text ends with a newline. Example output:

    [Task: Predict transition distribution]
    Survey: UAS
    From wave: 2016  →  To wave: 2020
    Group: edu_2016=NA; edu_2020=NA; gender=Female; generation=Baby Boomer; race=Asian
    Question: Harmonized abortion attitude across waves
    From option: anti
    """
    group_fields = "; ".join([
        "edu_2016=NA",
        "edu_2020=NA",
        f"gender={group_meta['gender']}",
        f"generation={group_meta['generation']}",
        f"race={group_meta['race']}",
    ])
    prompt_lines = [
        "[Task: Predict transition distribution]",
        "Survey: UAS",
        "From wave: 2016  \u2192  To wave: 2020",
        f"Group: {group_fields}",
        f"Question: {QUESTION_TEXT}",
        f"From option: {from_option}",
    ]
    # Every line, including the last one, is newline-terminated
    return "\n".join(prompt_lines) + "\n"

# -----------------
# Head-based predictor
# -----------------
def predict_probs_with_head(texts, max_length=768, batch_size=32):
    """
    Given a list of prompts (same style as training), return [N,5] probs over CANON5.

    Uses the module-level `tokenizer`, `model` and `dist_head`.

    Args:
        texts: list of prompt strings.
        max_length: truncation length passed to the tokenizer.
        batch_size: number of prompts scored per forward pass.

    Returns:
        float32 numpy array of shape [len(texts), 5]; each row is the softmax
        of the head's logits for the corresponding prompt. An empty `texts`
        yields an array of shape [0, 5].
    """
    out = np.zeros((len(texts), 5), dtype=np.float32)
    # Reference parameter: tells us which device/dtype the head expects
    head_param = next(dist_head.parameters())
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            chunk = texts[i:i+batch_size]
            enc = tokenizer(
                chunk, return_tensors="pt",
                padding=True, truncation=True, max_length=max_length
            )
            ids = enc["input_ids"].to(model.device)
            attn = enc["attention_mask"].to(model.device)
            # torch.autocast replaces the deprecated, CUDA-only torch.cuda.amp.autocast
            with torch.autocast(device_type=ids.device.type, dtype=torch.bfloat16):
                m_out = model(input_ids=ids, attention_mask=attn, output_hidden_states=True)
            final_layer = m_out.hidden_states[-1]  # last layer, indexed below as [batch, position]

            # Index of the last attended token per row. Taking the largest
            # position whose mask is 1 works for right- and left-padded
            # batches alike (mask.sum() - 1 is only right for right padding).
            positions = torch.arange(attn.size(1), device=attn.device)
            last_idx = (attn * positions).argmax(dim=1).to(final_layer.device)
            batch_rows = torch.arange(final_layer.size(0), device=final_layer.device)
            last_vec = final_layer[batch_rows, last_idx]  # [B,H]

            # The model runs in bfloat16 while the head keeps its own dtype;
            # align device and dtype so the linear layer does not fail on a
            # dtype mismatch.
            last_vec = last_vec.to(device=head_param.device, dtype=head_param.dtype)
            logits5 = dist_head(last_vec)  # [B,5]
            p = F.softmax(logits5.float(), dim=1).cpu().numpy()
            out[i:i+batch_size] = p
    return out

# -----------------
# Build subgroup list (only those present in GSS)
# -----------------
# Rows with a missing grouping value are dropped first: an equality filter on
# a missing value never matches any row, so such "groups" could never produce
# output and would only add meaningless prompts to the scoring pass.
subgroups = (
    gss[GROUP_COLS]
    .dropna()
    .drop_duplicates()
    .reset_index(drop=True)
)
subgroup_dicts = subgroups.to_dict(orient="records")

# -----------------
# Build prompts for each subgroup × from_option (5 prompts per subgroup)
# -----------------
all_prompts = []
index_map = []  # (group_idx, from_idx)
for gi, g in enumerate(subgroup_dicts):
    for k, from_opt in enumerate(CANON5):
        all_prompts.append(build_transition_prompt(g, from_opt))
        index_map.append((gi, k))

print(f"Total prompts to score: {len(all_prompts)}  (groups={len(subgroup_dicts)} × 5 from-options)")

# -----------------
# Predict transitions (5x5 per subgroup)
# -----------------
all_probs = predict_probs_with_head(all_prompts, max_length=768, batch_size=32)  # [G*5, 5]

# One 5x5 matrix per subgroup: row = from-option, column = predicted to-option
G = len(subgroup_dicts)
T_mats = [np.zeros((5,5), dtype=np.float32) for _ in range(G)]
for (gi, from_k), row_prob in zip(index_map, all_probs):
    T_mats[gi][from_k, :] = row_prob

# -----------------
# Empirical 2016 for each subgroup; project to 2020: p2020 = p2016 @ T
# Also (optional) empirical 2020 for evaluation
# -----------------
def mask_group(df, g):
    """
    Boolean row mask selecting the rows of `df` that belong to group `g`.

    Args:
        df: DataFrame containing one column per key of `g`.
        g: dict mapping column name -> required value (one subgroup record).

    Returns:
        Boolean Series on df's index, True where every column named in `g`
        equals the corresponding value. Filtering on the keys of `g` rather
        than on fixed column names keeps this in step with whatever grouping
        columns the subgroup records were built from.
    """
    m = pd.Series(True, index=df.index, dtype=bool)
    for col, val in g.items():
        m &= df[col] == val
    return m

# Metric helpers are defined once, up front, instead of on every loop iteration.
def rmse(a, b):
    """Root-mean-square difference between two equally shaped vectors."""
    return float(np.sqrt(np.mean((a-b)**2)))

def jsd(p, q, eps=1e-9):
    """
    Jensen-Shannon divergence (natural log) between two distributions.

    Both inputs are clipped to [eps, 1] and renormalised first, so zero
    entries do not produce log(0). The inputs themselves are not modified.
    """
    p = np.clip(p, eps, 1)
    q = np.clip(q, eps, 1)
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5*(p+q)
    def kl(x, y): return float(np.sum(x*(np.log(x+eps)-np.log(y+eps))))
    return 0.5*kl(p,m)+0.5*kl(q,m)

# Year masks do not depend on the subgroup, so compute them once.
is_2016 = gss["year"]==2016
is_2020 = gss["year"]==2020

rows_pred, rows_eval = [], []
for gi, g in enumerate(subgroup_dicts):
    in_group = mask_group(gss, g)

    m2016 = in_group & is_2016
    emp2016 = fivebin_empirical(gss.loc[m2016, "att5"])
    if emp2016 is None:
        # skip groups with no mappable 2016 answers
        continue

    # Project the empirical 2016 distribution through the predicted
    # transition matrix (row = from-option, column = to-option).
    T = T_mats[gi]  # [5,5]
    pred2020 = emp2016 @ T

    rec = {**g}
    for j, cat in enumerate(CANON5):
        rec[f"emp2016_{cat}"] = float(emp2016[j])
        rec[f"pred2020_{cat}"] = float(pred2020[j])
    rows_pred.append(rec)

    # Optional evaluation vs empirical 2020
    m2020 = in_group & is_2020
    emp2020 = fivebin_empirical(gss.loc[m2020, "att5"])
    if emp2020 is not None:
        # n2016 / n2020 count all rows of the group in that year, including
        # answers that were dropped by the 5-bin mapping.
        ev = {**g, "n2016": int(m2016.sum()), "n2020": int(m2020.sum())}
        ev["RMSE"] = rmse(emp2020, pred2020)
        ev["JSD"]  = jsd(emp2020, pred2020)
        for j, cat in enumerate(CANON5):
            ev[f"emp2020_{cat}"]  = float(emp2020[j])
            ev[f"pred2020_{cat}"] = float(pred2020[j])
        rows_eval.append(ev)

# -----------------
# Save outputs
# -----------------
df_pred = pd.DataFrame(rows_pred)
if df_pred.empty:
    # With no rows the frame has no columns, so sorting by GROUP_COLS would
    # raise KeyError; report the situation instead of crashing.
    print("No subgroups with mappable 2016 answers; no projections saved (check GSS mapping).")
else:
    df_pred = df_pred.sort_values(GROUP_COLS).reset_index(drop=True)
    df_pred.to_csv(PRED_OUT, index=False)
    print("Saved projections:", PRED_OUT)

if len(rows_eval) > 0:
    df_eval = pd.DataFrame(rows_eval).sort_values(GROUP_COLS).reset_index(drop=True)
    df_eval.to_csv(EVAL_OUT, index=False)
    print("Saved eval:", EVAL_OUT)
else:
    print("No empirical 2020 rows available for evaluation (check GSS mapping).")

# Year-over-year deltas (predicted) per subgroup
if not df_pred.empty:
    wide = df_pred[[*GROUP_COLS, *[f"emp2016_{c}" for c in CANON5], *[f"pred2020_{c}" for c in CANON5]]].copy()
    # compute 2020-2016 deltas for each category
    for c in CANON5:
        wide[f"delta_{c}"] = wide[f"pred2020_{c}"] - wide[f"emp2016_{c}"]
    wide.to_csv(DELTA_OUT, index=False)
    print("Saved deltas:", DELTA_OUT)
