In [None]:
# --------------------------------------------------------------
# IITB EdTech Internship 2025 – Problem ID 7
# STEP 3 & 4: Saliency Maps + Comparative Metrics
# --------------------------------------------------------------

# --------------------------------------------------------------
# 0. Mount & Imports
# --------------------------------------------------------------
from google.colab import drive
drive.mount('/content/drive')

import pandas as pd, numpy as np, os, cv2, logging, json
from pathlib import Path
from scipy.interpolate import interp1d
from scipy.stats import gaussian_kde
from scipy.ndimage import gaussian_filter
from sklearn.metrics import roc_auc_score
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")

# --------------------------------------------------------------
# 1. CONFIG
# --------------------------------------------------------------
CLEAN_ROOT = Path("/content/results/cleaned")
METRICS_ROOT = CLEAN_ROOT / "metrics"
# parents=True: on a fresh runtime the parent folders may not exist yet, and a
# plain mkdir() would raise FileNotFoundError before anything useful happens.
METRICS_ROOT.mkdir(parents=True, exist_ok=True)

# Inputs read below (presumably written by the earlier cleaning step).
FIX_CANVAS = CLEAN_ROOT / "fixations_canvas.csv"
FDM_NATIVE_DIR = CLEAN_ROOT / "fdm_native"
AVG_FDM_DIR = CLEAN_ROOT / "fdm_group_avg"

LOG_OUT = METRICS_ROOT / "03_step3_4.log"
# force=True: in a notebook the root logger may already have handlers, or this
# cell may be re-run. In either case a plain basicConfig() is a silent no-op and
# nothing is written to LOG_OUT.
logging.basicConfig(filename=LOG_OUT, level=logging.INFO,
                    format="%(asctime)s | %(levelname)s | %(message)s",
                    force=True)
log = logging.getLogger()
log.info("=== STEP 3 & 4 STARTED ===")
print("STEP 3 & 4 – Saliency Maps + Metrics")

# --------------------------------------------------------------
# 2. Load Data
# --------------------------------------------------------------
# Fixation table. Columns used later: pid, qid, difficulty, x_canvas,
# y_canvas, duration, start.
fix_canvas = pd.read_csv(FIX_CANVAS)

# Map files on disk; glob order is filesystem-dependent.
fdm_files = [*FDM_NATIVE_DIR.glob("*.npy")]            # only counted in the log line
avg_fdm_files = [*AVG_FDM_DIR.glob("fdm_avg_*.npy")]   # consumed in step 4.1

log.info(f"Loaded {len(fix_canvas)} canvas fixations, {len(fdm_files)} native FDMs")

# --------------------------------------------------------------
# 3.1 Empirical FDMs (duration-weighted, Gaussian-smoothed) – Per Trial (pid × qid)
# --------------------------------------------------------------
print("3.1 Building Empirical FDMs (Gaussian σ=visual angle)...")

CANVAS_W, CANVAS_H = 1024, 1024
SIGMA_PX = 40  # smoothing σ in canvas px; author's estimate ~1° visual angle (depends on viewing geometry)


def _fixation_density(fixations):
    """Return a duration-weighted, Gaussian-smoothed, unit-sum fixation map.

    Each fixation deposits its ``duration`` at pixel
    ``(int(y_canvas), int(x_canvas))``. ``np.add.at`` accumulates, so repeated
    fixations on one pixel add up instead of overwriting each other. Rows with
    non-finite coordinates or duration, or that fall outside the canvas, are
    skipped. The map is blurred with ``SIGMA_PX`` and normalised to sum 1 when
    it has any mass.
    """
    xs = fixations['x_canvas'].to_numpy(dtype=float)
    ys = fixations['y_canvas'].to_numpy(dtype=float)
    durs = fixations['duration'].to_numpy(dtype=float)
    finite = np.isfinite(xs) & np.isfinite(ys) & np.isfinite(durs)
    # astype(int) truncates toward zero, like int().
    xs, ys, durs = xs[finite].astype(int), ys[finite].astype(int), durs[finite]
    inside = (xs >= 0) & (xs < CANVAS_W) & (ys >= 0) & (ys < CANVAS_H)

    canvas = np.zeros((CANVAS_H, CANVAS_W))
    np.add.at(canvas, (ys[inside], xs[inside]), durs[inside])
    canvas = gaussian_filter(canvas, sigma=SIGMA_PX)
    total = canvas.sum()
    if total > 0:
        canvas /= total
    return canvas


# One map per trial (pid, qid). Difficulty is taken from the trial's first row
# (assumed constant within a trial). Each map is a float64 CANVAS_H x CANVAS_W
# array held in memory.
fdm_canvas_list = []
for (pid, qid), grp in tqdm(fix_canvas.groupby(['pid', 'qid']), desc="FDM Canvas"):
    fdm_canvas_list.append({
        'pid': pid, 'qid': qid, 'difficulty': grp['difficulty'].iloc[0],
        'fdm': _fixation_density(grp),
    })

FDM_CANVAS_DIR = METRICS_ROOT / "fdm_canvas"
FDM_CANVAS_DIR.mkdir(exist_ok=True)
for item in fdm_canvas_list:
    np.save(FDM_CANVAS_DIR / f"P{item['pid']:02d}_Q{item['qid']}.npy", item['fdm'])

log.info(f"SAVED {len(fdm_canvas_list)} canvas FDMs → {FDM_CANVAS_DIR}")

# --------------------------------------------------------------
# 3.2 Temporal Maps (Early vs Late)
# --------------------------------------------------------------
print("3.2 Temporal Maps (Early vs Late)...")


def make_map(sub):
    """Return a duration-weighted, Gaussian-smoothed, unit-sum map of the fixations in *sub*.

    Rows with non-finite coordinates or duration, or that fall outside the
    canvas, are skipped. Repeated fixations on one pixel accumulate
    (``np.add.at``) rather than overwrite.
    """
    xs = sub['x_canvas'].to_numpy(dtype=float)
    ys = sub['y_canvas'].to_numpy(dtype=float)
    durs = sub['duration'].to_numpy(dtype=float)
    finite = np.isfinite(xs) & np.isfinite(ys) & np.isfinite(durs)
    xs, ys, durs = xs[finite].astype(int), ys[finite].astype(int), durs[finite]
    inside = (xs >= 0) & (xs < CANVAS_W) & (ys >= 0) & (ys < CANVAS_H)

    m = np.zeros((CANVAS_H, CANVAS_W))
    np.add.at(m, (ys[inside], xs[inside]), durs[inside])
    m = gaussian_filter(m, sigma=SIGMA_PX)
    total = m.sum()
    if total > 0:
        m /= total
    return m


# Index each trial's fixations once, instead of re-filtering the whole table
# for every trial.
trials_by_key = {key: g for key, g in fix_canvas.groupby(['pid', 'qid'])}

temporal_maps = []
for item in fdm_canvas_list:
    pid, qid = item['pid'], item['qid']
    grp = trials_by_key.get((pid, qid))
    if grp is None or len(grp) < 4:
        continue
    # Stable sort so fixations with equal onsets keep their original order.
    grp = grp.sort_values('start', kind='stable')
    mid = len(grp) // 2  # first half of the fixations = early, remainder = late
    temporal_maps.append({
        'pid': pid, 'qid': qid, 'difficulty': item['difficulty'],
        'early': make_map(grp.iloc[:mid]),
        'late': make_map(grp.iloc[mid:]),
    })

TEMP_DIR = METRICS_ROOT / "temporal"
TEMP_DIR.mkdir(exist_ok=True)
for t in temporal_maps:
    np.save(TEMP_DIR / f"P{t['pid']:02d}_Q{t['qid']}_early.npy", t['early'])
    np.save(TEMP_DIR / f"P{t['pid']:02d}_Q{t['qid']}_late.npy", t['late'])

log.info(f"SAVED {len(temporal_maps)} temporal maps")

# --------------------------------------------------------------
# 4.1 Load Group-Averaged FDMs
# --------------------------------------------------------------
print("4.1 Loading Group-Averaged FDMs...")
avg_fdms = {}
for f in avg_fdm_files:
    # "fdm_avg_<difficulty>.npy" -> "<difficulty>"
    diff = f.stem.split('_')[-1]
    # Work in float64: cv2.resize does not accept every dtype np.load may
    # return (e.g. bool / int64).
    mat = np.load(f).astype(np.float64)
    # Resample onto the same canvas as the per-trial FDMs built in 3.1.
    mat_resized = cv2.resize(mat, (CANVAS_W, CANVAS_H))
    total = mat_resized.sum()
    if total > 0:
        mat_resized /= total
    else:
        # Dividing by a zero sum would turn the whole map into NaN.
        log.warning(f"Group FDM {f.name} has no positive mass; left unnormalised")
    avg_fdms[diff] = mat_resized

# --------------------------------------------------------------
# 4.2 Metric Functions
# --------------------------------------------------------------
from scipy.stats import wasserstein_distance
from scipy.spatial.distance import jensenshannon
from sklearn.metrics import roc_auc_score
import numpy as np

def similarity(m1, m2):
    """SIM (histogram-intersection) similarity between two saliency maps.

    Each map is normalised to unit sum and the score is ``sum(min(P, Q))``.
    The score is 1 for identical distributions and 0 for disjoint support. This
    is the SIM metric of the MIT/Tuebingen saliency benchmark. Note that
    ``sum(sqrt(P * Q))`` would be the Bhattacharyya coefficient, a different
    quantity.

    Returns NaN when either map has no positive mass.
    """
    p = np.asarray(m1, dtype=float)
    q = np.asarray(m2, dtype=float)
    p_sum, q_sum = p.sum(), q.sum()
    if p_sum <= 0 or q_sum <= 0:
        return np.nan
    return float(np.minimum(p / p_sum, q / q_sum).sum())

def kl_divergence(p, q):
    """Kullback-Leibler divergence KL(p || q), in nats, between two non-negative maps.

    Inputs of any shape (arrays or nested sequences) are flattened. A 1e-12
    floor is added to every entry to avoid log(0) and division by zero. Each map
    is then renormalised to unit sum, so unnormalised density maps still give a
    proper, non-negative divergence.
    """
    p = np.asarray(p, dtype=float).ravel() + 1e-12
    q = np.asarray(q, dtype=float).ravel() + 1e-12
    p /= p.sum()
    q /= q.sum()
    return np.sum(p * np.log(p / q))

def nss(saliency, fixations):
    """Normalized Scanpath Saliency of *fixations* on *saliency*.

    The map is z-scored: subtract the mean, divide by the population std, with
    ``1e-8`` guarding a flat map. It is then averaged at the fixated pixels.
    ``fixations`` is an ``(N, >=2)`` array whose first two columns are
    ``(x, y)`` = (column, row) pixel coordinates.

    Non-finite or off-map fixations are ignored. Otherwise negative coordinates
    would silently wrap to the opposite edge and too-large ones would raise
    IndexError. Returns NaN if no fixation falls inside the map.
    """
    sal = np.asarray(saliency, dtype=float)
    z = (sal - sal.mean()) / (sal.std() + 1e-8)
    h, w = z.shape

    pts = np.asarray(fixations, dtype=float)[:, :2]
    pts = pts[np.isfinite(pts).all(axis=1)]
    cols = pts[:, 0].astype(int)
    rows = pts[:, 1].astype(int)
    inside = (cols >= 0) & (cols < w) & (rows >= 0) & (rows < h)
    if not inside.any():
        return np.nan
    return z[rows[inside], cols[inside]].mean()

def cc(m1, m2):
    """Pearson linear correlation between two maps, compared pixel by pixel."""
    flat_a = m1.ravel()
    flat_b = m2.ravel()
    corr_matrix = np.corrcoef(flat_a, flat_b)
    # The 2x2 matrix's off-diagonal entry is r(m1, m2).
    return corr_matrix[0][1]

def auc_judd(saliency, fixations, jitter=True, *, rng=None):
    """Pixel-level ROC AUC of *saliency* at fixated pixels (AUC-Judd style).

    Every pixel holding at least one fixation is a positive and every other
    pixel is a negative. ``roc_auc_score`` ranks them by saliency. This is close
    to, but not identical with, Judd's original AUC: several fixations on one
    pixel count once here.

    Parameters
    ----------
    saliency : 2-D array of any size. The linear pixel index uses this
        array's own width.
    fixations : ``(N, >=2)`` array of ``(x, y)`` = (column, row) coordinates.
        Non-finite or off-map points are dropped rather than wrapped into a
        neighbouring row/edge.
    jitter : add tiny uniform noise to break ties between equal saliency values.
    rng : optional ``numpy.random.Generator`` for reproducible jitter. Defaults
        to the global ``np.random`` state.

    Returns NaN when no valid fixation remains or every pixel is fixated, since
    ROC AUC is undefined with a single class. The caller's array is never
    modified.
    """
    # np.array copies, so nothing below can touch the caller's saliency map.
    sal = np.array(saliency, dtype=float)
    h, w = sal.shape
    scores = sal.ravel()

    span = scores.max() - scores.min()
    if span > 0:
        # ROC AUC depends only on ranks, so min-max scaling leaves it unchanged.
        # It keeps the 1e-8 jitter tiny relative to the map's own range.
        scores = (scores - scores.min()) / span

    pts = np.asarray(fixations, dtype=float)[:, :2]
    pts = pts[np.isfinite(pts).all(axis=1)]
    cols = pts[:, 0].astype(int)
    rows = pts[:, 1].astype(int)
    inside = (cols >= 0) & (cols < w) & (rows >= 0) & (rows < h)

    labels = np.zeros(scores.size, dtype=int)
    labels[rows[inside] * w + cols[inside]] = 1
    if labels.min() == labels.max():
        return np.nan

    if jitter:
        gen = np.random if rng is None else rng
        scores = scores + gen.random(scores.shape) * 1e-8
    return roc_auc_score(labels, scores)

# --------------------------------------------------------------
# 4.3 Compute Comparisons: Easy vs Medium vs Hard
# --------------------------------------------------------------
print("4.3 Computing Comparative Metrics...")

# Side length of the coarse grid the maps are pooled to before EMD. Larger
# values mean finer resolution but a slower transport solve.
EMD_GRID = 32


def _spatial_emd(m1, m2, grid=EMD_GRID):
    """Return the Earth Mover's Distance between two 2-D density maps, in map pixels.

    Applying a 1-D Wasserstein distance to flattened maps would compare
    histograms of pixel *values* and ignore where the mass lies. Instead, each
    map is area-pooled to ``grid x grid`` cells and renormalised to unit mass.
    Every non-empty cell becomes a point mass at its centre, and ``cv2.EMD``
    solves the transport problem with Euclidean ground distance. Returns NaN if
    either map has no positive mass.
    """
    signatures = []
    for m in (m1, m2):
        m = np.asarray(m, dtype=np.float32)
        h, w = m.shape
        pooled = cv2.resize(m, (grid, grid), interpolation=cv2.INTER_AREA)
        pooled = np.clip(pooled, 0, None)
        total = pooled.sum()
        if total <= 0:
            return np.nan
        pooled /= total
        rows, cols = np.nonzero(pooled)
        # cv2.EMD signature rows: [weight, x, y]
        signatures.append(np.column_stack([
            pooled[rows, cols],
            (cols + 0.5) * (w / grid),
            (rows + 0.5) * (h / grid),
        ]).astype(np.float32))
    emd, _, _ = cv2.EMD(signatures[0], signatures[1], cv2.DIST_L2)
    return float(emd)


metrics_df = []
pairs = [('easy','medium'), ('easy','hard'), ('medium','hard')]

for diff1, diff2 in pairs:
    missing = [d for d in (diff1, diff2) if d not in avg_fdms]
    if missing:
        log.warning(f"Skipping {diff1}_vs_{diff2}: no group FDM for {missing}")
        continue
    m1 = avg_fdms[diff1]
    m2 = avg_fdms[diff2]
    metrics_df.append({
        'comparison': f"{diff1}_vs_{diff2}",
        'SIM': similarity(m1, m2),
        'KL': kl_divergence(m1, m2),
        # scipy's jensenshannon returns the JS *distance* (sqrt of the divergence).
        'JS': jensenshannon(m1.ravel(), m2.ravel()),
        'CC': cc(m1, m2),
        'EMD': _spatial_emd(m1, m2),
    })

metrics_df = pd.DataFrame(metrics_df)
metrics_df.to_csv(METRICS_ROOT / "group_comparison_metrics.csv", index=False)
log.info("SAVED group comparison metrics")

# --------------------------------------------------------------
# 4.4 Point/Scanpath-Level Metrics (NSS, AUC-Judd)
# --------------------------------------------------------------
print("4.4 Point/Scanpath-Level Metrics...")

# NOTE(review): each trial's FDM is scored against the very fixations it was
# built from. NSS/AUC here therefore reflect how concentrated a trial's own
# fixations are, not how well an independent map predicts them. Consider a
# leave-one-out group map if prediction is the intent.

# Index each trial's fixations once, instead of re-filtering the whole table
# per trial.
trials_by_key = {key: g for key, g in fix_canvas.groupby(['pid', 'qid'])}

point_metrics = []
for item in fdm_canvas_list:
    pid, qid, diff = item['pid'], item['qid'], item['difficulty']
    fdm = item['fdm']
    fix = trials_by_key.get((pid, qid))
    if fix is None:
        continue
    xy = fix[['x_canvas', 'y_canvas']].to_numpy(dtype=float)
    xy = xy[np.isfinite(xy).all(axis=1)].astype(int)
    # Keep only on-map points, like the FDM construction in 3.1. Off-map points
    # would otherwise raise IndexError or wrap around inside NSS / AUC.
    h, w = fdm.shape
    on_map = (xy[:, 0] >= 0) & (xy[:, 0] < w) & (xy[:, 1] >= 0) & (xy[:, 1] < h)
    pts = xy[on_map]
    if len(pts) == 0:
        continue
    point_metrics.append({
        'pid': pid, 'qid': qid, 'difficulty': diff,
        'NSS': nss(fdm, pts),
        'AUC_Judd': auc_judd(fdm, pts)
    })

point_df = pd.DataFrame(point_metrics)
point_df.to_csv(METRICS_ROOT / "point_level_metrics.csv", index=False)

# Aggregate. Select the metric columns *before* mean(): pandas >= 2 raises on
# non-numeric columns (e.g. a string qid) in a plain groupby().mean().
agg = point_df.groupby('difficulty')[['NSS', 'AUC_Judd']].mean()
agg.to_csv(METRICS_ROOT / "point_metrics_by_difficulty.csv")
log.info("SAVED point-level metrics")

# --------------------------------------------------------------
# 4.5 Statistical Testing
# --------------------------------------------------------------
from scipy.stats import ttest_ind

# NOTE(review): trials from the same participant are not independent, so a
# paired or mixed-effects analysis may be more appropriate than this
# between-groups test.
stats = []
for metric in ['NSS', 'AUC_Judd']:
    easy = point_df.loc[point_df['difficulty'] == 'easy', metric]
    hard = point_df.loc[point_df['difficulty'] == 'hard', metric]
    # Welch's t-test (equal_var=False): group sizes and variances need not match.
    # nan_policy='omit': a trial whose metric is undefined must not make p NaN.
    t, p = ttest_ind(easy, hard, equal_var=False, nan_policy='omit')
    stats.append({'metric': metric, 't_stat': float(t), 'p_value': float(p)})

stats_df = pd.DataFrame(stats)
stats_df.to_csv(METRICS_ROOT / "statistical_tests.csv", index=False)
log.info(f"Stats: NSS p={stats[0]['p_value']:.4f}, AUC p={stats[1]['p_value']:.4f}")

# --------------------------------------------------------------
# 5. Visualizations
# --------------------------------------------------------------
print("5. Visualizing Results...")

VIS_DIR = METRICS_ROOT / "visualizations"
VIS_DIR.mkdir(exist_ok=True)

# Group FDMs. imshow draws each 1024x1024 map as a single image, whereas
# sns.heatmap builds a ~1M-cell mesh, which is very slow and memory-hungry.
# Difficulties without a loaded map get a blank panel instead of a KeyError.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, diff in zip(axes, ['easy', 'medium', 'hard']):
    if diff not in avg_fdms:
        ax.axis('off')
        continue
    im = ax.imshow(avg_fdms[diff], cmap='viridis')
    fig.colorbar(im, ax=ax)
    ax.set_title(f"Avg FDM – {diff.capitalize()}")
fig.tight_layout()
fig.savefig(VIS_DIR / "group_avg_fdms.png")
plt.close(fig)

# Difference heatmaps, with the colour scale symmetric about 0.
for diff1, diff2 in pairs:
    if diff1 not in avg_fdms or diff2 not in avg_fdms:
        continue
    diff_map = avg_fdms[diff1] - avg_fdms[diff2]
    vmax = float(np.abs(diff_map).max()) or 1.0  # avoid a degenerate [0, 0] range
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(diff_map, cmap='coolwarm', vmin=-vmax, vmax=vmax)
    fig.colorbar(im, ax=ax)
    ax.set_title(f"FDM Diff: {diff1} - {diff2}")
    fig.tight_layout()
    fig.savefig(VIS_DIR / f"diff_{diff1}_vs_{diff2}.png")
    plt.close(fig)

# Metric Bar Plots
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
sns.barplot(data=point_df, x='difficulty', y='NSS', ax=ax[0], order=['easy','medium','hard'])
ax[0].set_title("NSS by Difficulty")
sns.barplot(data=point_df, x='difficulty', y='AUC_Judd', ax=ax[1], order=['easy','medium','hard'])
ax[1].set_title("AUC-Judd by Difficulty")
plt.tight_layout()
plt.savefig(VIS_DIR / "point_metrics_bar.png")
plt.close()

log.info(f"Visualizations saved → {VIS_DIR}")

# --------------------------------------------------------------
# FINAL SUMMARY
# --------------------------------------------------------------
# Summary of what this cell wrote under METRICS_ROOT.
print("\nSTEP 3 & 4 COMPLETE!")
print(f"Outputs → {METRICS_ROOT}")
print(f"   • fdm_canvas/                  : {len(fdm_canvas_list)} Gaussian FDMs")
print(f"   • temporal/                    : {len(temporal_maps)} early/late map pairs")
print("   • group_comparison_metrics.csv : SIM, KL, JS, CC, EMD")
print("   • point_level_metrics.csv      : NSS, AUC-Judd per trial")
print("   • point_metrics_by_difficulty.csv : mean NSS, AUC-Judd per difficulty")
print("   • statistical_tests.csv        : easy vs hard t-tests (NSS, AUC-Judd)")
print("   • visualizations/              : heatmaps, diffs, bars")
print(f"   • log                          : {LOG_OUT.name}")

log.info("STEP 3 & 4 FINISHED")