# Run Model Comparison — Interactive

Run the conflict experiment with any Groq model directly from this notebook. Models with existing results are skipped automatically.

**Prerequisites:** `.env` file with `GROQ_API_KEY` set, and `data/hotpotqa/dev.json` downloaded (run `python test_setup.py` first).

In [None]:
import json
import os
import sys
import glob
import time
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
from tqdm.notebook import tqdm
from dotenv import load_dotenv

# Put the parent directory on the import path so the `src.*` imports below resolve
# when the notebook is run from its own folder.
sys.path.append(os.path.abspath('..'))
# Load environment variables from the parent directory's .env file
# (the notebook's prerequisites say GROQ_API_KEY lives there).
load_dotenv(os.path.join('..', '.env'))

# Project-local modules; these must come after the sys.path tweak above.
from src.data.hotpotqa_loader import HotpotQALoader
from src.data.conflict_injector import ConflictInjector
from src.inference.groq_client import GroqClient
from src.inference.prompt_templates import create_cot_prompt, extract_answer
from src.evaluation.metrics import check_answer

# Render figures inline in the notebook at a higher-than-default resolution.
%matplotlib inline
plt.rcParams['figure.dpi'] = 150

# Per-model results are read from / written to RESULTS_DIR/<model id>/experiment.json.
RESULTS_DIR = '../outputs/results'
# Destination folder for the PNG figures saved by the plotting cells.
FIGURES_DIR = '../outputs/figures'
# Number of bridge questions requested from the dataset loader.
N_EXAMPLES = 100

## 1. Configure Models

Edit this list to add/remove models. Any model available on Groq works.

In [None]:
# Each entry pairs a Groq model id (also used as the results sub-folder name)
# with the human-readable label shown in tables and figures.
MODELS = [
    {"id": "llama-3.3-70b-versatile", "label": "Llama-3.3-70B"},
    {"id": "llama-3.1-8b-instant",    "label": "Llama-3.1-8B"},
    # Uncomment or append entries to evaluate further models:
    # {"id": "gemma2-9b-it",           "label": "Gemma2-9B"},
    # {"id": "mixtral-8x7b-32768",     "label": "Mixtral-8x7B"},
]

# Report, for every configured model, whether its experiment file is already on disk.
for model_cfg in MODELS:
    saved_file = os.path.join(RESULTS_DIR, model_cfg['id'], 'experiment.json')
    status = 'DONE' if os.path.exists(saved_file) else 'PENDING'
    print(f"  [{status}] {model_cfg['label']} ({model_cfg['id']})")

## 2. Load Dataset

In [None]:
# `loader` is kept as a notebook-level name on purpose: run_single_model() below
# calls loader.extract_supporting_facts() on each example.
loader = HotpotQALoader()
loader.load(path='../data/hotpotqa/dev.json')
# Request N_EXAMPLES bridge questions.
# NOTE(review): presumably this returns at most N_EXAMPLES items — confirm in HotpotQALoader.
examples = loader.get_bridge_questions(N_EXAMPLES)
print(f'Loaded {len(examples)} bridge questions')

## 3. Run Experiments (Only for Missing Models)

This cell runs the three-condition experiment (no conflict, conflict at hop 1, conflict at hop 2) for each model that doesn't have results yet; models whose `experiment.json` already exists are skipped.

In [None]:
def compute_metrics(results):
    """Summarise per-condition accuracy and conflict-behaviour rates.

    Args:
        results: dict mapping each condition name ('no_conflict',
            'conflict_hop1', 'conflict_hop2') to a list of per-example
            result dicts. Every dict needs a 'correct' entry; dicts under
            the two conflict conditions also need 'followed_context' and
            'used_parametric'.

    Returns:
        dict keyed by condition name, each value holding 'n', 'accuracy',
        'context_following_rate' and 'parametric_override_rate'.
        Conditions with no examples are omitted. For 'no_conflict' both
        rates are reported as 0 because no conflict was injected.
    """
    metrics = {}
    for cond in ['no_conflict', 'conflict_hop1', 'conflict_hop2']:
        rows = results[cond]
        n = len(rows)
        if n == 0:
            continue
        accuracy = sum(r['correct'] for r in rows) / n
        # 'no_conflict' contains the substring 'conflict', so a substring test
        # would wrongly treat the baseline as a conflict condition. Match on
        # the prefix instead.
        if cond.startswith('conflict'):
            cfr = sum(r['followed_context'] for r in rows) / n
            por = sum(r['used_parametric'] for r in rows) / n
        else:
            cfr = por = 0
        metrics[cond] = {
            'n': n, 'accuracy': accuracy,
            'context_following_rate': cfr,
            'parametric_override_rate': por
        }
    return metrics


def _evaluate_condition(client, injector, condition, question, doc1, doc2, answer, conflict_hop=None):
    """Query the model once under one condition and return the scored result dict.

    With conflict_hop=None the original documents are used and the prediction
    is scored against the gold answer only. Otherwise the injector rewrites the
    documents for that hop and the fake answer it returns is passed to
    check_answer as well. The returned dict is tagged with 'condition' and
    'question'.
    """
    if conflict_hop is None:
        prompt = create_cot_prompt(question, doc1, doc2)
        pred = extract_answer(client.generate(prompt))
        result = check_answer(pred, answer)
    else:
        mod_doc1, mod_doc2, fake = injector.inject_conflict(
            question, doc1, doc2, answer, conflict_hop=conflict_hop)
        prompt = create_cot_prompt(question, mod_doc1, mod_doc2)
        pred = extract_answer(client.generate(prompt))
        result = check_answer(pred, answer, fake)
    result['condition'] = condition
    result['question'] = question
    return result


def run_single_model(model_id, examples):
    """Run the 3-condition experiment for one model and save the results.

    Relies on the notebook-level `loader` to split each example into
    (question, doc1, doc2, answer); examples with an empty doc1 or doc2 are
    skipped.

    Args:
        model_id: Groq model id; also the results sub-folder name.
        examples: iterable of dataset examples accepted by
            loader.extract_supporting_facts.

    Returns:
        The metrics dict from compute_metrics, or None when
        RESULTS_DIR/<model_id>/experiment.json already exists (nothing is run).
    """
    result_path = os.path.join(RESULTS_DIR, model_id, 'experiment.json')
    if os.path.exists(result_path):
        print(f'SKIP {model_id} — results already exist')
        return None

    print(f'\nRunning: {model_id}')
    injector = ConflictInjector()
    client = GroqClient(model=model_id)
    # (condition name, hop to corrupt), in the order the model is queried.
    plan = (('no_conflict', None), ('conflict_hop1', 1), ('conflict_hop2', 2))
    results = {condition: [] for condition, _ in plan}

    for example in tqdm(examples, desc=model_id):
        question, doc1, doc2, answer = loader.extract_supporting_facts(example)
        if not doc1 or not doc2:
            continue
        for condition, hop in plan:
            results[condition].append(
                _evaluate_condition(client, injector, condition,
                                    question, doc1, doc2, answer, conflict_hop=hop))

    # Results are written only after every example has been processed.
    metrics = compute_metrics(results)
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    with open(result_path, 'w') as f:
        json.dump({'metrics': metrics, 'raw_results': results}, f, indent=2)
    print(f'Saved: {result_path}')
    return metrics


# Launch the experiment for every configured model; run_single_model itself
# skips any model whose results file is already present.
for model_cfg in MODELS:
    run_single_model(model_cfg['id'], examples)

## 4. Load All Results

In [None]:
# Canonical condition order and the display name for each condition.
conditions = ['no_conflict', 'conflict_hop1', 'conflict_hop2']
cond_labels = {
    'no_conflict': 'No Conflict',
    'conflict_hop1': 'Conflict@Hop1',
    'conflict_hop2': 'Conflict@Hop2',
}

# model id -> {'label': display name, 'data': parsed experiment.json}
all_models = {}
for model_cfg in MODELS:
    saved_file = os.path.join(RESULTS_DIR, model_cfg['id'], 'experiment.json')
    if not os.path.exists(saved_file):
        print(f"Missing: {model_cfg['label']}")
        continue
    with open(saved_file, 'r') as fh:
        payload = json.load(fh)
    all_models[model_cfg['id']] = {'label': model_cfg['label'], 'data': payload}
    print(f"Loaded: {model_cfg['label']}")

## 5. Comparison Table

In [None]:
# Build one table row per (model, condition) that has saved metrics.
rows = []
for entry in all_models.values():
    metrics = entry['data']['metrics']
    for cond in conditions:
        if cond not in metrics:
            continue
        cond_metrics = metrics[cond]
        # CFR/POR only make sense when a conflict was injected. Show '-' for
        # the baseline, and show a real 0.0% for a conflict condition rather
        # than hiding it behind the same placeholder.
        is_baseline = cond == 'no_conflict'
        rows.append({
            'Model': entry['label'],
            'Condition': cond_labels[cond],
            'N': cond_metrics['n'],
            'Accuracy': f"{cond_metrics['accuracy']:.1%}",
            'CFR': '-' if is_baseline else f"{cond_metrics['context_following_rate']:.1%}",
            'POR': '-' if is_baseline else f"{cond_metrics['parametric_override_rate']:.1%}",
        })

pd.DataFrame(rows)

## 6. Grouped Bar Chart

In [None]:
# Grouped bar chart: one cluster per condition, one bar per model,
# each bar annotated with its accuracy percentage.
fig, ax = plt.subplots(figsize=(12, 6))

tick_labels = ['No Conflict\n(Baseline)', 'Conflict at\nHop 1', 'Conflict at\nHop 2']
group_pos = np.arange(len(tick_labels))
model_count = len(all_models)
bar_width = 0.7 / model_count
palette = ['#3498db', '#e67e22', '#2ecc71', '#9b59b6']

for rank, entry in enumerate(all_models.values()):
    per_cond = entry['data']['metrics']
    # A condition missing from the saved metrics is drawn as a zero-height bar.
    heights = [per_cond[c]['accuracy'] * 100 if c in per_cond else 0 for c in conditions]
    # Shift each model's bars so the whole cluster is centred on its tick.
    shift = (rank - (model_count - 1) / 2) * bar_width
    drawn = ax.bar(
        group_pos + shift, heights, bar_width,
        label=entry['label'],
        color=palette[rank % len(palette)],  # colours repeat after four models
        edgecolor='black', linewidth=1,
    )
    for rect in drawn:
        top = rect.get_height()
        ax.annotate(
            f'{top:.1f}%',
            xy=(rect.get_x() + rect.get_width()/2, top),
            xytext=(0, 3), textcoords='offset points',
            ha='center', va='bottom', fontsize=9, fontweight='bold',
        )

ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('Model Comparison: Impact of Knowledge Conflicts', fontsize=14, fontweight='bold')
ax.set_xticks(group_pos)
ax.set_xticklabels(tick_labels, fontsize=11)
ax.set_ylim(0, 100)
ax.legend(fontsize=10)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
plt.tight_layout()

# Make sure the figures folder exists before writing the PNG.
os.makedirs(FIGURES_DIR, exist_ok=True)
figure_file = os.path.join(FIGURES_DIR, 'model_comparison.png')
plt.savefig(figure_file, dpi=300, bbox_inches='tight')
plt.show()
print('Saved to outputs/figures/model_comparison.png')

## 7. Statistical Comparison

In [None]:
model_ids = list(all_models.keys())

# Compare every unordered pair of loaded models, condition by condition, with a
# chi-square test on the 2x2 table of correct / incorrect counts.
for i in range(len(model_ids)):
    for j in range(i + 1, len(model_ids)):
        m1, m2 = model_ids[i], model_ids[j]
        l1, l2 = all_models[m1]['label'], all_models[m2]['label']
        met1, met2 = all_models[m1]['data']['metrics'], all_models[m2]['data']['metrics']

        print(f'\n=== {l1} vs {l2} ===')
        stat_rows = []
        for cond in conditions:
            if cond not in met1 or cond not in met2:
                continue
            a1, n1 = met1[cond]['accuracy'], met1[cond]['n']
            a2, n2 = met2[cond]['accuracy'], met2[cond]['n']
            # Accuracy was stored as correct / n, so accuracy * n can land just
            # below the true count (0.57 * 100 == 56.99999999999999). Round to
            # the nearest integer instead of truncating, which would drop one.
            k1 = round(a1 * n1)
            k2 = round(a2 * n2)
            ct = [[k1, n1 - k1],
                  [k2, n2 - k2]]
            try:
                chi2, p_val, _, _ = stats.chi2_contingency(ct)
            except ValueError:
                # chi2_contingency rejects tables whose expected frequencies
                # contain a zero (e.g. both models all-correct or all-wrong).
                # Report the row as not testable instead of aborting the cell.
                chi2_txt = p_txt = significant = 'n/a'
            else:
                chi2_txt = f'{chi2:.2f}'
                p_txt = f'{p_val:.4f}'
                significant = 'Yes' if p_val < 0.05 else 'No'
            stat_rows.append({
                'Condition': cond_labels[cond],
                l1: f'{a1:.1%}',
                l2: f'{a2:.1%}',
                'Chi2': chi2_txt,
                'p-value': p_txt,
                'Significant': significant
            })

        display(pd.DataFrame(stat_rows))

## 8. Behavior Breakdown (CFR / POR / Other)

In [None]:
# One stacked-bar panel per model: the share of conflict examples where the
# model followed the injected context, overrode it, or did neither.
model_count = len(all_models)
fig, axes = plt.subplots(1, model_count, figsize=(6 * model_count, 5), sharey=True)
if model_count == 1:
    # With a single column plt.subplots returns a bare Axes; wrap it so the
    # loop and the axes[0] / axes[-1] lookups below work unchanged.
    axes = [axes]

conflict_conds = ['conflict_hop1', 'conflict_hop2']
bar_labels = ['Hop 1', 'Hop 2']
xp = np.arange(len(bar_labels))
w = 0.5

for panel, entry in zip(axes, all_models.values()):
    per_cond = entry['data']['metrics']

    followed = [per_cond[c]['context_following_rate'] * 100 for c in conflict_conds]
    overrode = [per_cond[c]['parametric_override_rate'] * 100 for c in conflict_conds]
    # Whatever is left after the two tracked behaviours.
    remainder = [100 - f - p for f, p in zip(followed, overrode)]
    stacked_base = [f + p for f, p in zip(followed, overrode)]

    panel.bar(xp, followed, w, label='Followed Context (Wrong)', color='#e74c3c')
    panel.bar(xp, overrode, w, bottom=followed, label='Parametric Override (Correct)', color='#2ecc71')
    panel.bar(xp, remainder, w, bottom=stacked_base,
              label='Other/Hallucination', color='#95a5a6')

    panel.set_title(entry['label'], fontsize=12, fontweight='bold')
    panel.set_xticks(xp)
    panel.set_xticklabels(bar_labels)
    panel.set_ylim(0, 100)
    for side in ('top', 'right'):
        panel.spines[side].set_visible(False)

# Shared y-axis: label only the leftmost panel, hang the legend off the rightmost.
axes[0].set_ylabel('Percentage (%)')
axes[-1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)
plt.suptitle('Model Behavior Under Knowledge Conflict', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
figure_file = os.path.join(FIGURES_DIR, 'behavior_comparison.png')
plt.savefig(figure_file, dpi=300, bbox_inches='tight')
plt.show()