In [None]:
# Install + imports (Colab / A100)
!pip -q install peft transformers accelerate sentencepiece

import os, json, math, itertools, collections
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Turn autograd off globally: the code below only runs forward passes, so no
# gradient bookkeeping is needed.
torch.set_grad_enabled(False)

In [None]:
# The five canonical answer bins (UNSURE is excluded for inference).
# Their order must stay aligned with OPT_TOKENS, i.e. with the order used
# during training.
CANON5 = [
    "strong_anti",
    "anti",
    "neutral",
    "pro",
    "strong_pro",
]

# Reverse lookup: bin name -> position in CANON5.
IDX = dict(zip(CANON5, range(len(CANON5))))

# One special answer token per bin, listed in CANON5 order.
OPT_TOKENS = [
    "<OPT_STRONG_ANTI>",
    "<OPT_ANTI>",
    "<OPT_NEUTRAL>",
    "<OPT_PRO>",
    "<OPT_STRONG_PRO>",
]

# ---- Map GSS att5 -> 5 canonical bins (EDIT these to your actual labels) ----
# Any label that is not a key here (DK/NA/Refused, ...) is dropped when the
# 5-bin empirical distribution is built. Note that no label currently maps to
# "neutral".
GSS_ATT5_TO_5CANON = {
    "Illegal in all cases": "strong_anti",
    "Illegal in most cases": "anti",
    "Legal in most cases": "pro",
    "Legal in all cases": "strong_pro",
}

def fivebin_empirical(series):
    """
    Turn a Pandas Series of att5 strings into probabilities over the five canonical bins.

    Each answer is translated through GSS_ATT5_TO_5CANON; answers without a
    mapping are discarded before counting.

    Returns a float np.array of length 5 (ordered as CANON5), or None when no
    countable answers remain.
    """
    canon = series.map(GSS_ATT5_TO_5CANON).dropna()
    if canon.empty:
        return None

    tallies = canon.value_counts()
    # Bins that never occur are absent from the tallies, hence the 0.0 default.
    counts = np.array([tallies.get(name, 0.0) for name in CANON5], dtype=float)
    total = counts.sum()
    return counts / total if total > 0 else None

# =========================================
# 2) Load GSS long panel and compute modal education per group/year
#    Required columns: yearid, year, generation, edu_level, gender, race, att5
# =========================================
GSS_PATH = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/gss_panel_2016_2020_long.parquet"

# Read the long panel and keep only the 2016 and 2020 rows; .copy() gives an
# independent frame rather than a slice of the one read from disk.
gss = (
    pd.read_parquet(GSS_PATH)
      .loc[lambda frame: frame["year"].isin([2016, 2020])]
      .copy()
)

# Full grouping key. edu_level is per-row; edu_2016 & edu_2020 are derived from it below.
GROUP_COLS = ["generation", "edu_level", "gender", "race"]

# Grouping key without education: modal education is computed within each
# (BASE_GROUP, year) cell.
BASE_GROUP = ["generation", "gender", "race"]

def mode_or_na(s):
    """Return the most frequent non-null value of *s*, or NaN if *s* has no non-null values."""
    present = s.dropna()
    # The emptiness check is done on the NaN-free values, not on the counts.
    return np.nan if present.empty else s.value_counts(dropna=True).idxmax()

# Most common education level in every (generation, gender, race, year) cell.
# dropna=False keeps cells whose key contains a missing value.
edu_mode = (
    gss.groupby([*BASE_GROUP, "year"], dropna=False)["edu_level"]
       .agg(mode_or_na)
       .rename("edu_mode")
       .reset_index()
)

# One frame per wave, with the mode column renamed after the wave.
edu_2016 = (
    edu_mode.loc[edu_mode["year"] == 2016, BASE_GROUP + ["edu_mode"]]
            .rename(columns={"edu_mode": "edu_2016"})
)
edu_2020 = (
    edu_mode.loc[edu_mode["year"] == 2020, BASE_GROUP + ["edu_mode"]]
            .rename(columns={"edu_mode": "edu_2020"})
)

# Every base group present in the data, with edu_2016 / edu_2020 attached.
groups = gss[BASE_GROUP].drop_duplicates()
groups = groups.merge(edu_2016, on=BASE_GROUP, how="left")
groups = groups.merge(edu_2020, on=BASE_GROUP, how="left")

# Number of 2016 rows per base group (size() counts rows, whatever att5 holds).
has_2016 = (
    gss.loc[gss["year"] == 2016]
       .groupby(BASE_GROUP)["att5"]
       .size()
       .reset_index(name="n2016")
)

# Keep only groups with at least one 2016 row; groups without a match have
# n2016 = NaN after the left merge and fail the comparison.
groups = groups.merge(has_2016, on=BASE_GROUP, how="left")
groups = groups[groups["n2016"] >= 1].drop(columns=["n2016"])

print("Num unique base groups with 2016 data:", len(groups))

# =========================================
# 3) Build prompts to mirror training EXACTLY (transition prompts)
#    Example from your training:
#    [Task: Predict transition distribution]
#    Survey: UAS
#    From wave: 2018  →  To wave: 2019
#    Group: edu_2018=Bachelor's Degree; edu_2019=Bachelor's Degree; gender=Female; generation=Baby Boomer; race=Asian
#    Question: Harmonized abortion attitude across waves
#    From option: anti
#    (then you appended the Options: ... line before Answer:)
# =========================================
QUESTION_TEXT = "Harmonized abortion attitude across waves"


def build_transition_prompt(group_meta, edu_2016, edu_2020, from_option):
    """
    Build the 2016 -> 2020 transition prompt for one group and one starting option.

    group_meta must provide 'gender', 'generation' and 'race'. A missing
    edu_2016 / edu_2020 (as judged by pd.notna) is written as 'NA'.

    IMPORTANT: wording, field order (edu first), the Unicode arrow and the
    "; " separators are kept exactly as in training. The returned text ends
    with a newline after the "From option" line.
    """
    edu_from = edu_2016 if pd.notna(edu_2016) else "NA"
    edu_to = edu_2020 if pd.notna(edu_2020) else "NA"

    group_fields = "; ".join(
        [
            f"edu_2016={edu_from}",
            f"edu_2020={edu_to}",
            f"gender={group_meta['gender']}",
            f"generation={group_meta['generation']}",
            f"race={group_meta['race']}",
        ]
    )
    prompt_lines = [
        "[Task: Predict transition distribution]",
        "Survey: UAS",
        "From wave: 2016  \u2192  To wave: 2020",
        f"Group: {group_fields}",
        f"Question: {QUESTION_TEXT}",
        f"From option: {from_option}",
    ]
    return "\n".join(prompt_lines) + "\n"

SUFFIX = "Options: " + " ".join(OPT_TOKENS) + "\nAnswer:\n"

# =========================================
# 4) Load tokenizer + base + LoRA adapters
# =========================================
LORA_DIR = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/tllm_abortion_transitions_lora"
BASE_MODEL_NAME = "meta-llama/llama-3.1-8b"   # or "mistralai/Mistral-7B-v0.3"

from google.colab import userdata

# userdata.get() accepts only the secret name and raises when the secret is
# absent or the notebook has not been granted access, so the "no token"
# fallback has to be explicit rather than a default argument.
try:
    hf_token = userdata.get('HF_TOKEN')
except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
    hf_token = None

# The tokenizer is read from the adapter directory, not from the base model.
tokenizer = AutoTokenizer.from_pretrained(LORA_DIR, use_fast=True, token=hf_token)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Ensure option tokens exist (usually already in the saved tokenizer)
vocab = tokenizer.get_vocab()
missing = [t for t in OPT_TOKENS if t not in vocab]
if missing:
    tokenizer.add_special_tokens({"additional_special_tokens": missing})

base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=hf_token
)

# Grow the base model's embeddings to the tokenizer size BEFORE attaching the
# adapter. Checking the tokenizer alone is not enough: when the saved tokenizer
# already contains the option tokens, the freshly loaded base model can still
# have the smaller original vocabulary, and the option-token ids would then be
# out of range for the logits. Never shrink: some checkpoints pad their
# vocabulary beyond len(tokenizer).
# NOTE(review): rows created by this resize are newly initialised unless the
# adapter checkpoint also stores trained embedding weights - confirm how the
# training run saved them.
if base.get_input_embeddings().weight.shape[0] < len(tokenizer):
    base.resize_token_embeddings(len(tokenizer))

model = PeftModel.from_pretrained(base, LORA_DIR)
model.eval()

# Vocabulary ids of the five option tokens, in OPT_TOKENS order.
opt_ids = torch.tensor([tokenizer.convert_tokens_to_ids(t) for t in OPT_TOKENS], device=model.device)

# =========================================
# 5) Batch predictor over the five option tokens (same as training head)
# =========================================
def predict_probs_for_texts(texts, max_length=768, batch_size=32):
    """
    Score each prompt over the five option tokens.

    SUFFIX is appended to every text, the batch is tokenized with padding and
    truncation to max_length, and the model's logits at the last attended
    position are restricted to opt_ids and soft-maxed.

    Parameters:
        texts: sequence of prompt strings (sliced in chunks of batch_size).
        max_length: tokenizer truncation length.
        batch_size: number of prompts per forward pass.

    Returns:
        np.float32 array of shape [len(texts), len(OPT_TOKENS)]; each row is a
        probability distribution over OPT_TOKENS. Empty input yields a
        zero-row array.
    """
    out_probs = np.zeros((len(texts), len(OPT_TOKENS)), dtype=np.float32)
    device = model.device
    # Autocast is only meaningful on CUDA; keep it disabled elsewhere.
    use_autocast = device.type == "cuda"

    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            chunk = texts[start:start + batch_size]
            enc = tokenizer(
                [t + SUFFIX for t in chunk],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length
            )
            input_ids = enc["input_ids"].to(device)
            attn = enc["attention_mask"].to(device)

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_autocast):
                out = model(input_ids=input_ids, attention_mask=attn)

            logits = out.logits  # [B, T, V]

            # Position of the last attended token in each row. Multiplying the
            # 0/1 mask by the position index and taking the argmax works for
            # both right- and left-padded batches, whereas "mask.sum() - 1" is
            # only correct for right padding.
            positions = torch.arange(attn.size(1), device=attn.device)
            last_idx = (attn * positions).argmax(dim=1).to(logits.device)

            batch_rows = torch.arange(logits.size(0), device=logits.device)
            last_logits = logits[batch_rows, last_idx]  # [B,V]

            # Up-cast before the softmax so the probabilities are not rounded
            # to bfloat16 precision.
            opt_logits = last_logits[:, opt_ids.to(logits.device)].float()  # [B,5]
            probs = F.softmax(opt_logits, dim=1).cpu().numpy()
            out_probs[start:start + len(chunk)] = probs

    return out_probs

# =========================================
# 6) Build all prompts per (group, from_option) and predict transitions
#    For each group we need 5 prompts (one per From option) -> 5x5 transition
# =========================================
group_dicts = groups[BASE_GROUP + ["edu_2016","edu_2020"]].to_dict(orient="records")

# index_map[i] == (group_idx, from_idx) identifies the prompt stored at
# all_prompts[i]; every group contributes one prompt per CANON5 option.
index_map = [
    (gi, k)
    for gi in range(len(group_dicts))
    for k in range(len(CANON5))
]
all_prompts = [
    build_transition_prompt(
        group_dicts[gi],
        group_dicts[gi]["edu_2016"],
        group_dicts[gi]["edu_2020"],
        CANON5[k],
    )
    for gi, k in index_map
]

print("Total prompts:", len(all_prompts))

all_probs = predict_probs_for_texts(all_prompts, max_length=768, batch_size=32)  # [G*5, 5]

# One 5x5 transition matrix per group: row = From option, columns = predicted
# distribution over the options.
G = len(group_dicts)
T_mats = [np.zeros((5,5), dtype=np.float32) for _ in range(G)]
for row_prob, (gi, from_k) in zip(all_probs, index_map):
    T_mats[gi][from_k] = row_prob

# =========================================
# 7) Empirical 2016 distribution per group (5 bins, UNSURE dropped)
#    Project to 2020 via p2020 = p2016 @ T
# =========================================
def sub_mask(df, g):
    """Return a boolean Series marking the rows of *df* whose generation, gender and race all equal those in *g*."""
    mask = df["generation"] == g["generation"]
    for key in ("gender", "race"):
        mask = mask & (df[key] == g[key])
    return mask

def rmse(a, b):
    """Root-mean-square difference between two equal-length vectors."""
    return float(np.sqrt(np.mean((a - b) ** 2)))


def jsd(p, q, eps=1e-9):
    """
    Jensen-Shannon divergence (natural log) between two probability vectors.

    Both inputs are clipped to [eps, 1] and re-normalised first; the inputs
    themselves are not modified.
    """
    p = np.clip(p, eps, 1)
    q = np.clip(q, eps, 1)
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5 * (p + q)

    def kl(x, y):
        return float(np.sum(x * (np.log(x + eps) - np.log(y + eps))))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


# Explicit column layouts, in the same order the row dicts are filled. Passing
# them to pd.DataFrame keeps sort_values() working when no rows were produced
# (a column-less frame would raise KeyError on the sort keys).
_KEY_COLS = ["generation", "gender", "race", "edu_2016", "edu_2020"]
_PRED_COLS = _KEY_COLS + [
    f"{prefix}_{cat}" for cat in CANON5 for prefix in ("emp2016", "pred2020")
]
_EVAL_COLS = _KEY_COLS + ["n2016", "n2020", "RMSE", "JSD"] + [
    f"{prefix}_{cat}" for cat in CANON5 for prefix in ("emp2020", "pred2020")
]

rows_pred = []
rows_eval = []  # optional, add empirical 2020 if available

# Year masks do not depend on the group, so compute them once.
is_2016 = gss["year"] == 2016
is_2020 = gss["year"] == 2020

for gi, g in enumerate(group_dicts):
    in_group = sub_mask(gss, g)

    # empirical 2016 (5 bins)
    m2016 = in_group & is_2016
    emp2016 = fivebin_empirical(gss.loc[m2016, "att5"])
    if emp2016 is None:
        continue  # skip groups with no mappable 2016 answers

    # project to 2020: p2020 = p2016 @ T
    T = T_mats[gi]  # 5x5
    pred2020 = emp2016 @ T

    rec = {k: g[k] for k in _KEY_COLS}
    for j, cat in enumerate(CANON5):
        rec[f"emp2016_{cat}"] = float(emp2016[j])
        rec[f"pred2020_{cat}"] = float(pred2020[j])
    rows_pred.append(rec)

    # optional: empirical 2020 to compare
    m2020 = in_group & is_2020
    emp2020 = fivebin_empirical(gss.loc[m2020, "att5"])
    if emp2020 is None:
        continue

    ev = {
        **{k: rec[k] for k in _KEY_COLS},
        # Row counts of the masks, i.e. including answers that do not map to a bin.
        "n2016": int(m2016.sum()),
        "n2020": int(m2020.sum()),
    }
    ev["RMSE"] = rmse(emp2020, pred2020)
    ev["JSD"] = jsd(emp2020, pred2020)
    for j, cat in enumerate(CANON5):
        ev[f"emp2020_{cat}"] = float(emp2020[j])
        ev[f"pred2020_{cat}"] = float(pred2020[j])
    rows_eval.append(ev)

df_pred = (
    pd.DataFrame(rows_pred, columns=_PRED_COLS)
      .sort_values(["generation", "gender", "race"])
      .reset_index(drop=True)
)
df_eval = (
    pd.DataFrame(rows_eval, columns=_EVAL_COLS)
      .sort_values(["generation", "gender", "race"])
      .reset_index(drop=True)
)

PRED_OUT = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/gss_tllm_proj_2020_from_2016.csv"
if not df_pred.empty:
    df_pred.to_csv(PRED_OUT, index=False)
    print("Saved projections:", PRED_OUT)
else:
    # Do not overwrite an earlier result file with a header-only CSV.
    print("No groups with mappable 2016 answers; nothing to project. Check GSS mapping.")

if not df_eval.empty:
    EVAL_OUT = "/content/drive/MyDrive/LLM_POC_Study_2025_v2/gss_tllm_eval_2020_vs_empirical.csv"
    df_eval.to_csv(EVAL_OUT, index=False)
    print("Saved eval:", EVAL_OUT)
else:
    print("No empirical 2020 (5-bin) rows available for evaluation; check GSS mapping.")
