In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file shipped under the Kaggle input directory.
# (Explicit loop names avoid rebinding IPython's `_` output variable.)
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install bert-score shap --quiet

In [None]:
import csv
import glob
import json
import os
import random
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
import torch
from bert_score import BERTScorer
from bert_score import score as bert_score
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm
from transformers import GPT2Tokenizer, GPT2LMHeadModel

In [None]:
# CONFIG
# Seed every RNG the notebook draws from (torch.randperm in the ablation and
# SHAP helpers, random.choice in the example cell) so re-runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

path_to_nlu_dir = "/kaggle/input/"
data_dir = path_to_nlu_dir  # append "nlu-data/" if the JSONL files live in that subfolder
path_to_res = "/kaggle/working/"
save_dir = path_to_res + "results"
os.makedirs(save_dir, exist_ok=True)

# Each entry: JSONL file (relative to data_dir) and the name used in output files.
datasets = [
    # {"path": "RACE-H_v1_tst.jsonl", "name": "RACE-H"},
    {"path": "SATACT_v3_tst.jsonl", "name": "SATACT"},
]

In [None]:
model_name = "Salm00n/gpt2-xl_SATACT_v3"
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# GPT-2's tokenizer has no dedicated pad token by default; reuse EOS so that
# batched (padded) tokenisation works.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Inference only: eval() switches off dropout.
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
model.eval()

In [None]:
# Accumulators shared by the evaluation loop and the summary plots below:
# gold option indices, predicted option indices, and per-option scores.
# Re-run this cell before re-evaluating, otherwise entries are appended twice.
GLOBAL_Y_TRUE, GLOBAL_Y_PRED, GLOBAL_LOGITS = [], [], []

In [None]:
def load_data(filepath):
    """Read a JSON Lines file into a list of parsed records.

    Args:
        filepath: path to a ``.jsonl`` file (one JSON value per line).

    Returns:
        list of parsed JSON values, in file order. Blank lines are skipped.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(filepath, 'r', encoding='utf-8') as f:
        # Skip blank/whitespace-only lines (e.g. a trailing empty line), which
        # json.loads would otherwise reject.
        return [json.loads(line) for line in f if line.strip()]

def get_prompt(example):
    """Build the text prompt for one example.

    Precedence: a non-empty ``prompt`` field is used as-is; otherwise
    ``context + " " + question`` when both keys exist; otherwise ``question``
    alone; otherwise the empty string.
    """
    explicit_prompt = example.get('prompt')
    if explicit_prompt:
        return explicit_prompt

    has_question = 'question' in example
    if has_question and 'context' in example:
        return example['context'] + " " + example['question']
    if has_question:
        return example['question']
    return ""

def get_options(example):
    """Return the answer choices of an example as a list.

    Uses a non-empty ``options`` list when present; otherwise collects
    ``answerA``..``answerD`` if all four keys exist; otherwise ``[]``.
    """
    listed = example.get("options")
    if listed:
        return listed

    letter_keys = ("answerA", "answerB", "answerC", "answerD")
    if not all(key in example for key in letter_keys):
        return []
    return [example[key] for key in letter_keys]

def prepare_input(prompt, max_length=1024):
    """Tokenise `prompt`, keeping only its last `max_length` tokens, on `device`.

    Args:
        prompt: text to encode.
        max_length: maximum number of tokens to keep (GPT-2's context size
            is 1024 positions).

    Returns:
        The tokenizer's encoding with "input_ids" and "attention_mask" of
        shape (1, <= max_length), moved to `device`.
    """
    # Do NOT use the tokenizer's own truncation: by default it cuts from the
    # right (keeping the FIRST tokens), which made the tail-keeping slice below
    # dead code and drops whatever was appended at the end of the prompt.
    enc = tokenizer(prompt, return_tensors="pt", truncation=False, padding=False)
    if enc['input_ids'].shape[-1] > max_length:
        enc['input_ids'] = enc['input_ids'][:, -max_length:]
        enc['attention_mask'] = enc['attention_mask'][:, -max_length:]
    return enc.to(device)

In [None]:
def predict(model, batch):
    """Score each multiple-choice option by its log-likelihood under `model`.

    For every example in `batch` (dicts understood by get_prompt/get_options),
    each option is appended to the prompt after a space. The option's score is
    the summed log-probability of its tokens conditioned on everything before
    them. The highest-scoring option is the prediction. Scores are
    un-normalised sums, so shorter options are favoured.

    Args:
        model: causal LM returning `.logits` of shape (1, seq, vocab).
        batch: iterable of example dicts.

    Returns:
        (preds, logits_all_examples): `preds[i]` is the argmax option index (int)
        and `logits_all_examples[i]` the list of per-option summed log-probs.
        Examples without options are skipped, so both lists can be shorter
        than `batch`.
    """
    max_len = getattr(model.config, "n_positions", 1024)
    preds = []
    logits_all_examples = []
    for example in batch:
        prompt = get_prompt(example)
        options = get_options(example)
        if not options:
            continue
        prompt_ids = tokenizer(prompt)["input_ids"]
        log_probs = []
        for opt in options:
            # Encode the option WITH its leading space: GPT-2's BPE encodes
            # " word" and "word" as different tokens, and the option follows
            # the prompt after a space. Tokenising prompt and option separately
            # tells us exactly which ids belong to the option.
            opt_ids = tokenizer(" " + opt)["input_ids"]
            # Keep the tail when too long so the option tokens are never cut off.
            ids = (prompt_ids + opt_ids)[-max_len:]
            # Option tokens that have a preceding position predicting them.
            n_scored = min(len(opt_ids), len(ids) - 1)
            input_ids = torch.tensor([ids], device=device)
            with torch.no_grad():
                logits = model(input_ids=input_ids).logits[0]
            # logits[t] is the distribution over token t + 1, hence the shift.
            step_logits = logits[len(ids) - n_scored - 1:len(ids) - 1].float()
            step_log_probs = torch.log_softmax(step_logits, dim=-1)
            targets = input_ids[0, len(ids) - n_scored:]
            log_probs.append(step_log_probs.gather(1, targets.unsqueeze(1)).sum().item())
        pred_idx = int(np.argmax(log_probs))
        preds.append(pred_idx)
        logits_all_examples.append(log_probs)
    return preds, logits_all_examples

In [None]:
def get_attention_map(model, input_text):
    """Head-averaged attention of the model's last layer for `input_text`.

    Returns:
        (attention, tokens): a numpy array (token position x token position)
        from the final layer averaged over heads, and the token strings.
    """
    encoded = prepare_input(input_text)
    with torch.no_grad():
        layer_attentions = model(**encoded, output_attentions=True).attentions
    # Last layer, first (only) example in the batch; then average the heads.
    final_layer = layer_attentions[-1][0]
    head_average = final_layer.mean(dim=0).cpu().numpy()
    token_strings = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
    return head_average, token_strings

def plot_attention_heatmap(attn_map, tokens, max_tokens=10):
    """Heatmap of the top-left `max_tokens` x `max_tokens` corner of an attention map.

    Args:
        attn_map: square 2-D array indexed by token position on both axes
            (e.g. the first value returned by get_attention_map).
        tokens: token strings used as tick labels on both axes.
        max_tokens: number of leading tokens to display.
    """
    shown_tokens = tokens[:max_tokens]
    shown_attn = attn_map[:max_tokens, :max_tokens]
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(shown_attn, xticklabels=shown_tokens, yticklabels=shown_tokens,
                cmap="coolwarm", ax=ax)
    ax.set_title("GPT-2 XL Final Layer Attention", fontsize=18, fontweight='bold')
    ax.tick_params(axis='x', labelrotation=90, labelsize=12)
    ax.tick_params(axis='y', labelrotation=0, labelsize=12)
    fig.tight_layout()
    plt.show()

In [None]:
def true_attention_rollout(model, input_text):
    """Attention rollout (Abnar & Zuidema, 2020) across all layers.

    For each layer:
    * average the attention over heads;
    * model the residual connection by mixing in the identity
      (0.5 * A + 0.5 * I) and renormalise rows to sum to 1;
    * multiply the per-layer matrices from the first layer to the last.

    (The previous implementation only averaged the layers, which is not
    rollout.)

    Returns:
        (rollout, tokens): a (seq, seq) numpy array whose row i gives how much
        each input position contributes to position i after the last layer,
        and the corresponding token strings.
    """
    inputs = prepare_input(input_text)
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    # Stack the per-layer attentions, drop the batch axis of the single
    # example, then average heads -> (layers, seq, seq).
    per_layer = torch.stack(outputs.attentions).squeeze(1).mean(dim=1).float()
    identity = torch.eye(per_layer.size(-1), device=per_layer.device, dtype=per_layer.dtype)
    rollout = identity
    for layer_attn in per_layer:
        mixed = 0.5 * layer_attn + 0.5 * identity
        mixed = mixed / mixed.sum(dim=-1, keepdim=True)
        rollout = mixed @ rollout
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    return rollout.cpu().numpy(), tokens

def plot_attention_rollout(attn_rollout, tokens, max_tokens=10):
    """Heatmap of the leading `max_tokens` rows/columns of a rollout matrix.

    Args:
        attn_rollout: square 2-D array indexed by token position on both axes
            (e.g. the first value returned by true_attention_rollout).
        tokens: token strings used as tick labels on both axes.
        max_tokens: number of leading tokens to display.
    """
    labels = tokens[:max_tokens]
    corner = attn_rollout[:max_tokens, :max_tokens]
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(corner, xticklabels=labels, yticklabels=labels, cmap="coolwarm", ax=ax)
    ax.set_title("True Attention Rollout Across All Layers (Mean Heads)",
                 fontsize=18, fontweight='bold')
    ax.tick_params(axis='x', labelrotation=90, labelsize=12)
    ax.tick_params(axis='y', labelrotation=0, labelsize=12)
    fig.tight_layout()
    plt.show()

In [None]:
def logit_lens_visual(model, input_text, layer=None):
    """Per-token max softmax probability of the next-token distribution.

    Args:
        model: a GPT-2 LM-head model.
        input_text: text to analyse (tokenised with prepare_input).
        layer: None (default) uses the model's final output logits, as before.
            An integer selects an entry of ``outputs.hidden_states`` (0 = the
            embedding output; negative indices allowed). That entry is decoded
            through the final layer norm and LM head (the "logit lens").

    Returns:
        1-D numpy array with one confidence per input token.

    Raises:
        IndexError: if `layer` is out of range for the model's hidden states.
    """
    tokens = prepare_input(input_text)
    want_hidden = layer is not None
    with torch.no_grad():
        # Only request hidden states when needed: for GPT-2 XL that is one
        # (seq, hidden) tensor per layer, which the default path never used.
        outputs = model(**tokens, output_hidden_states=want_hidden)
        if want_hidden:
            n_states = len(outputs.hidden_states)
            hidden = outputs.hidden_states[layer]
            if layer % n_states == n_states - 1:
                # The last hidden state already has ln_f applied; its decoding
                # is exactly outputs.logits.
                logits = outputs.logits[0]
            else:
                logits = model.lm_head(model.transformer.ln_f(hidden))[0]
        else:
            logits = outputs.logits[0]

    confidences = logits.float().softmax(dim=-1).max(dim=-1).values.cpu().numpy()
    return confidences

In [None]:
def model_predict_for_shap(texts):
    """Next-token logits after each text; the model function handed to SHAP.

    Args:
        texts: iterable of strings. SHAP's Text masker passes a batch of masked
            variants of the input, possibly as a numpy array, so it is
            converted to a list.

    Returns:
        numpy array of shape (len(texts), vocab_size): the logits at the last
        non-padding position of each text.
    """
    inputs = tokenizer(list(texts), return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # With padding, position -1 is a pad token for every text shorter than the
    # longest one (GPT-2 pads on the right by default). Locate each row's last
    # real token from the attention mask instead; this also works for
    # left padding.
    mask = inputs["attention_mask"]
    positions = torch.arange(mask.size(1), device=mask.device)
    last_positions = (positions * mask).max(dim=1).values
    rows = torch.arange(mask.size(0), device=mask.device)
    logits = outputs.logits[rows, last_positions, :]
    return logits.cpu().numpy()

masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(model_predict_for_shap, masker)

def fast_compute_shap_values(example, max_tokens=20):
    """Mean absolute SHAP value over a reduced version of the example's prompt.

    To keep SHAP affordable, prompts longer than `max_tokens` tokens are
    reduced to a random subset of `max_tokens` token positions. The subset is
    kept in original order but is NOT a contiguous span, so the decoded text
    is a "bag" of the prompt's tokens.

    Args:
        example: example dict understood by get_prompt.
        max_tokens: maximum number of prompt tokens passed to SHAP.

    Returns:
        float mean |SHAP value|, or None when the prompt (or its sample) is
        empty or too short, or when explanation fails.
    """
    prompt = get_prompt(example)

    if not isinstance(prompt, str) or prompt.strip() == "":
        return None

    try:
        input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids[0]

        if len(input_ids) > max_tokens:
            selected_idxs = torch.randperm(len(input_ids))[:max_tokens]
            input_ids = input_ids[selected_idxs.sort().values]

        sampled_prompt = tokenizer.decode(input_ids, skip_special_tokens=True)

        if not isinstance(sampled_prompt, str) or sampled_prompt.strip() == "" or len(sampled_prompt.split()) < 3:
            return None

        shap_values = explainer([sampled_prompt])
        return float(np.mean(np.abs(shap_values.values)))

    except Exception:
        # Best effort: a failed explanation yields None. Catch Exception rather
        # than using a bare except, so KeyboardInterrupt still stops the loop.
        return None

In [None]:
def shallow_heuristic_check(question, options):
    """Flag options that share at least one whole word with the question.

    A crude lexical-overlap signal: an option scores 1 when any of its words
    (case-insensitive) also appears as a word in the question, else 0.

    Args:
        question: question text.
        options: list of option strings.

    Returns:
        list of 0/1 ints, one per option, in the same order.
    """
    # Compare whole words, not substrings: `word in question_string` counted
    # "at" as present in "theatre" and an empty option as present in any
    # question, pushing the score towards 1.
    question_words = set(re.findall(r"\w+", question.lower()))
    return [
        int(any(word in question_words for word in re.findall(r"\w+", opt.lower())))
        for opt in options
    ]

def context_fidelity_score(model, example):
    """Whether removing the example's context changes the model's prediction.

    Args:
        model: model passed through to predict.
        example: dict with "question", "context" and options readable by
            get_options.

    Returns:
        1.0 if the predicted option differs with vs. without the context, 0.0
        if it is the same. Returns None when the score is undefined: no
        options, no question, or no (non-empty) context. Without a context the
        two prompts are essentially identical, so a 0.0 would be meaningless.
        Also None when prediction fails.
    """
    options = get_options(example)
    question = example.get("question", "")
    context = example.get("context")
    if not options or not question or not context:
        return None
    try:
        prompt_full = context + " " + question
        prompt_no_context = question.strip()
        pred_full = predict(model, [{"prompt": prompt_full, "options": options}])[0][0]
        pred_no_context = predict(model, [{"prompt": prompt_no_context, "options": options}])[0][0]
        return 1.0 if pred_full != pred_no_context else 0.0
    except Exception:
        # Best effort (e.g. non-string fields); don't swallow KeyboardInterrupt.
        return None

In [None]:
def reasoning_score(model, input_text, options, max_ablations=20):
    """Fraction of single-token ablations of the prompt that flip the prediction.

    Up to `max_ablations` randomly chosen token positions are ablated one at a
    time. Each ablation replaces the token with ``tokenizer.unk_token_id`` and
    decodes with skip_special_tokens. For the stock GPT-2 tokenizer the unk
    token is the special <|endoftext|> token, so this effectively deletes the
    token from the text. The model's prediction is then recomputed.

    Args:
        model: model passed through to predict.
        input_text: prompt text.
        options: list of option strings.
        max_ablations: maximum number of token positions to ablate.

    Returns:
        changes / number_of_ablations in [0, 1]; 0 when there are no
        options/tokens or an error occurs.
    """
    try:
        input_ids = tokenizer(input_text, return_tensors="pt")['input_ids']
        # predict returns (preds, scores); take this single example's prediction.
        # Previously the whole list was compared with `int`, so the function
        # always returned 0.
        baseline_preds, _ = predict(model, [{"prompt": input_text, "options": options}])
        if not baseline_preds:
            return 0
        baseline_pred = int(baseline_preds[0])

        num_tokens = input_ids.size(1)
        sample_size = min(max_ablations, num_tokens)
        if sample_size == 0:
            return 0
        sampled_idxs = torch.randperm(num_tokens)[:sample_size]
        changes = 0
        for idx in sampled_idxs:
            ablated = input_ids.clone()
            ablated[0, idx] = tokenizer.unk_token_id
            ablated_text = tokenizer.decode(ablated[0], skip_special_tokens=True)
            pred_ablated, _ = predict(model, [{"prompt": ablated_text, "options": options}])
            if pred_ablated and int(pred_ablated[0]) != baseline_pred:
                changes += 1
        return changes / sample_size
    except Exception:
        # Best effort; don't swallow KeyboardInterrupt.
        return 0

In [None]:
def compute_probing_accuracy(examples):
    """Accuracy of a logistic-regression probe on mean input-token embeddings.

    Each example's prompt is embedded as the mean of the model's input token
    embeddings (``model.transformer.wte``; no transformer layers are run). A
    LogisticRegression is fit to predict the example's label.

    Caveat: the probe is scored on the same examples it is fit on, so this is
    TRAINING accuracy and overstates how decodable the answer is.

    Label lookup: the first of 'label', 'gold', 'correct', 'answer' whose
    value is not None. Single letters map to 0-based indices ("A" -> 0),
    digit strings to ints. Examples with no usable label or an empty prompt
    are skipped.

    Returns:
        float accuracy, or None when fewer than two distinct labels remain.
    """
    label_keys = ('label', 'gold', 'correct', 'answer')
    embeddings = []
    labels = []

    for ex in examples:
        # Test against None instead of chaining with `or`, which skipped a
        # legitimate integer label of 0.
        label = next((ex[k] for k in label_keys if ex.get(k) is not None), None)
        if isinstance(label, str):
            text = label.strip().upper()
            if len(text) == 1 and 'A' <= text <= 'Z':
                label = ord(text) - ord('A')
            elif text.isdigit():
                label = int(text)
            else:
                label = None
        if label is None:
            continue  # unusable label: cannot contribute to the probe

        enc = prepare_input(get_prompt(ex))
        if enc['input_ids'].shape[-1] == 0:
            continue  # empty prompt: the mean embedding would be NaN
        with torch.no_grad():
            emb = model.transformer.wte(enc['input_ids']).mean(dim=1).cpu().numpy()
        embeddings.append(emb[0])
        labels.append(label)

    if len(set(labels)) < 2:
        return None

    clf = LogisticRegression(max_iter=1000)
    clf.fit(embeddings, labels)
    return float(clf.score(embeddings, labels))

In [None]:
# Lazily created scorer shared by all calls. bert_score.score() loads the
# underlying model on every call, which is very slow when run per example.
_BERT_SCORER = None


def run_bertscore(references, candidates):
    """BERTScore F1 of each candidate against its reference (English defaults).

    Args:
        references: list of reference strings.
        candidates: list of candidate strings, aligned with `references`.

    Returns:
        list of F1 values, one per pair.
    """
    global _BERT_SCORER
    if _BERT_SCORER is None:
        _BERT_SCORER = BERTScorer(lang="en")
    P, R, F1 = _BERT_SCORER.score(candidates, references, verbose=False)
    return list(F1.numpy())

In [None]:
def contrastive_distractor_swap(example):
    """Predict with the original and with the reversed option order.

    Returns:
        (pred_orig, pred_swapped), both as indices into the ORIGINAL option
        list, so the pair can be compared directly. Equal values mean the
        model picked the same answer text under both orderings.
        (None, None) if the example has no options.

    Note: predict scores each option against the prompt independently of the
    other options, so the two predictions are expected to agree except for
    ties.
    """
    prompt = get_prompt(example)
    options = get_options(example)
    if not options:
        return None, None
    swapped = options[::-1]
    pred_orig = int(predict(model, [{"prompt": prompt, "options": options}])[0][0])
    pred_in_swapped = int(predict(model, [{"prompt": prompt, "options": swapped}])[0][0])
    # Position k in the reversed list is position len-1-k in the original one.
    # Without this mapping the caller compared positions in different orderings,
    # which marked content-consistent predictions as non-robust.
    pred_swapped = len(options) - 1 - pred_in_swapped
    return pred_orig, pred_swapped

In [None]:
def process_dataset(dataset):
    """Evaluate the model on one dataset, compute per-example diagnostics, save a CSV.

    Args:
        dataset: dict with "path" (JSONL file relative to `data_dir`) and
            "name" (used for the output filename and the "dataset" column).

    Side effects:
        * appends to GLOBAL_Y_TRUE / GLOBAL_Y_PRED / GLOBAL_LOGITS (re-running
          without re-creating those lists duplicates entries);
        * writes ``<save_dir>/<name>_results.csv``.

    Returns:
        The results DataFrame (the same data written to the CSV).
    """
    examples = load_data(data_dir + dataset["path"])
    rows = []
    # Dataset-level metric, repeated on every row of the output.
    probing_acc = compute_probing_accuracy(examples)

    for idx, example in enumerate(tqdm(examples)):
        prompt = get_prompt(example)
        options = get_options(example)
        if not options:
            continue

        preds, logits = predict(model, [{"prompt": prompt, "options": options}])
        pred_idx = preds[0]
        # Letter answer -> option index; a missing answer defaults to "A".
        gold_idx = ord(example.get("answer", "A").upper()) - ord('A')

        GLOBAL_Y_TRUE.append(gold_idx)
        GLOBAL_Y_PRED.append(pred_idx)
        GLOBAL_LOGITS.append(logits[0])

        # Best-effort diagnostics: a failure yields None for that column only.
        # `except Exception` (not bare) so KeyboardInterrupt still stops the run.
        try:
            heuristics = shallow_heuristic_check(example.get('question', ''), options)
            heuristic_score = sum(heuristics) / len(heuristics) if heuristics else None
        except Exception:
            heuristic_score = None

        try:
            shap_val = fast_compute_shap_values(example)
        except Exception:
            shap_val = None

        try:
            contrastive_orig, contrastive_swap = contrastive_distractor_swap(example)
            contrastive_orig = chr(65 + contrastive_orig) if contrastive_orig is not None else None
            contrastive_swap = chr(65 + contrastive_swap) if contrastive_swap is not None else None
            if contrastive_orig is None or contrastive_swap is None:
                contrastive_robust = None
            else:
                contrastive_robust = contrastive_orig == contrastive_swap
        except Exception:
            contrastive_orig, contrastive_swap, contrastive_robust = None, None, None

        # An answer letter beyond the available options used to raise
        # IndexError here and abort the whole dataset.
        if 0 <= gold_idx < len(options):
            bertscore_f1 = run_bertscore([options[gold_idx]], [options[pred_idx]])[0]
        else:
            bertscore_f1 = None

        row = {
            "id": idx,
            "dataset": dataset["name"],
            "question": example.get("question", ""),
            "option_0": options[0] if len(options) > 0 else "",
            "option_1": options[1] if len(options) > 1 else "",
            "option_2": options[2] if len(options) > 2 else "",
            "option_3": options[3] if len(options) > 3 else "",
            "probing_accuracy": probing_acc,
            "logits": logits[0],
            "prediction": pred_idx,
            "actual": gold_idx,
            "bertscore_f1": bertscore_f1,
            "heuristic_dependence_score": heuristic_score,
            "reasoning_score": reasoning_score(model, prompt, options),
            "context_fidelity_score": context_fidelity_score(model, example),
            "shap_score": shap_val,
            "contrastive_prediction_original": contrastive_orig,
            "contrastive_prediction_swapped": contrastive_swap,
            "contrastive_robust": contrastive_robust
        }

        rows.append(row)

    results_df = pd.DataFrame(rows)
    results_df.to_csv(os.path.join(save_dir, f"{dataset['name']}_results.csv"), index=False)
    return results_df

In [None]:
def plot_reasoning_context(save_dir, dataset_name):
    """Histograms of the reasoning and context-fidelity scores for one dataset.

    Reads ``<save_dir>/<dataset_name>_results.csv`` and draws one histogram
    per score column (missing values dropped), stacked vertically.
    """
    results = pd.read_csv(os.path.join(save_dir, f"{dataset_name}_results.csv"))
    panels = [
        (results['reasoning_score'].dropna(), "Reasoning Score"),
        (results['context_fidelity_score'].dropna(), "Context Fidelity"),
    ]

    fig, axes = plt.subplots(2, 1, figsize=(8, 10))
    for axis, (values, title) in zip(axes, panels):
        sns.histplot(values, ax=axis, kde=False, color="steelblue", edgecolor='black', linewidth=0.6)
        axis.set_title(title, fontsize=18, fontweight='bold')
        axis.set_xlabel("Score", fontsize=14)
        axis.set_ylabel("Frequency", fontsize=14)
        axis.tick_params(axis='both', labelsize=12)
        axis.grid(False)

    plt.tight_layout()
    plt.show()

In [None]:
def plot_softmax_distribution():
    """Histogram of per-option probabilities pooled over all evaluated examples.

    Each entry of GLOBAL_LOGITS (one list of per-option summed log-probs per
    example) is turned into a probability distribution over that example's
    options with a softmax. All probabilities are pooled into one histogram.
    """
    if not GLOBAL_LOGITS:
        print("No predictions recorded yet; run the evaluation first.")
        return
    all_softmax = []
    for logit in GLOBAL_LOGITS:
        scores = np.asarray(logit, dtype=np.float64)
        # Subtract the max before exponentiating: summed log-probs of long
        # options are very negative and exp() underflows to 0 -> 0/0 = NaN.
        exp_scores = np.exp(scores - scores.max())
        all_softmax.append(exp_scores / exp_scores.sum())
    all_softmax_flat = np.concatenate(all_softmax)
    plt.figure(figsize=(6, 5))
    sns.histplot(all_softmax_flat, bins=50, color='steelblue')
    plt.title("Softmax Probability Distribution", fontsize=18)
    plt.xlabel("Softmax Probability", fontsize=14)
    plt.ylabel("Frequency", fontsize=14)
    plt.grid(False)
    plt.tight_layout()
    plt.show()

In [None]:
def final_confusion_matrix():
    """Print overall accuracy and draw the confusion matrix of recorded predictions.

    Reads the module-level GLOBAL_Y_TRUE / GLOBAL_Y_PRED lists, which the
    evaluation loop appends to for every dataset processed.
    """
    y_true, y_pred = GLOBAL_Y_TRUE, GLOBAL_Y_PRED
    accuracy = accuracy_score(y_true, y_pred)
    matrix = confusion_matrix(y_true, y_pred)
    print(f"Accuracy: {accuracy:.3f}")

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(matrix, annot=True, cmap="Blues", fmt="d", ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix")
    plt.show()

In [None]:
def datamap_confidence_scatter():
    """Scatter of each example's top option probability vs. prediction correctness.

    GLOBAL_LOGITS holds summed log-likelihoods, which are unbounded and depend
    on option length. Each example's scores are therefore converted to a
    softmax over its options first, and the x-axis shows the probability of
    the predicted (argmax) option. This also handles examples with differing
    option counts, which np.array(GLOBAL_LOGITS) could not.
    """
    if not GLOBAL_LOGITS:
        print("No predictions recorded yet; run the evaluation first.")
        return
    correct = np.array([int(p == y) for p, y in zip(GLOBAL_Y_PRED, GLOBAL_Y_TRUE)])
    confidences = []
    for logit in GLOBAL_LOGITS:
        scores = np.asarray(logit, dtype=np.float64)
        # Numerically stable softmax; after the shift the max term is exp(0) = 1.
        exp_scores = np.exp(scores - scores.max())
        confidences.append(exp_scores.max() / exp_scores.sum())
    conf = np.array(confidences)
    plt.figure(figsize=(8, 6))
    sns.scatterplot(x=conf, y=correct, hue=correct, palette="coolwarm")
    plt.xlabel("Max Confidence")
    plt.ylabel("Correct Prediction")
    plt.title("DataMap Scatter: Confidence vs Accuracy")
    plt.show()

In [None]:
# Attention visualisations for the first example of the first configured dataset.
first_path = data_dir + datasets[0]["path"]
example = load_data(first_path)[0]
prompt = get_prompt(example)

# Final-layer attention, averaged over heads.
attn_map, tokens = get_attention_map(model, prompt)
plot_attention_heatmap(attn_map, tokens)

# Attention aggregated over all layers.
rollout_map, rollout_tokens = true_attention_rollout(model, prompt)
plot_attention_rollout(rollout_map, rollout_tokens)

In [None]:
# Evaluate every configured dataset; each call saves <save_dir>/<name>_results.csv.
for ds in datasets:
    process_dataset(ds)

In [None]:
# Plot every configured dataset rather than a hard-coded name, so this cell
# stays in sync with the `datasets` config.
for ds in datasets:
    plot_reasoning_context(save_dir, dataset_name=ds["name"])

In [None]:
# Pooled per-option softmax probabilities over every example evaluated so far.
plot_softmax_distribution()

In [None]:
# Overall accuracy and confusion matrix across all processed datasets.
final_confusion_matrix()

In [None]:
# Per-example confidence vs. correctness.
datamap_confidence_scatter()

In [None]:
# random example
# Token-wise max softmax confidence for one randomly chosen example.
candidate_examples = load_data(data_dir + datasets[0]["path"])
example = random.choice(candidate_examples)
prompt = get_prompt(example)
confidences = logit_lens_visual(model, prompt)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(confidences, marker='o')
ax.set_title("Logit Lens: Token-wise Confidence (Random Example)", fontsize=18, fontweight='bold')
ax.set_xlabel("Token Position", fontsize=14)
ax.set_ylabel("Max Softmax Confidence", fontsize=14)
ax.grid(True)
fig.tight_layout()
plt.show()