# In [None]:
# --------------------------------------------------------------
# IITB EdTech Internship 2025 – Problem ID 7
# STEP 6 & 7: Final Visualizations + Report + Experimentation
# --------------------------------------------------------------

# --------------------------------------------------------------
# 0. Mount & Imports
# --------------------------------------------------------------
# Mount Google Drive when running inside Colab. Outside Colab (local Jupyter,
# plain Python) the import fails; nothing below reads from /content/drive, so
# the script continues without mounting instead of crashing on import.
try:
    from google.colab import drive
except ImportError:
    drive = None
else:
    drive.mount('/content/drive')

import json
import logging
import os
import random
import warnings
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.patches import Rectangle
from scipy.ndimage import gaussian_filter
from scipy.stats import gaussian_kde
from tqdm.notebook import tqdm

warnings.filterwarnings("ignore")

# --------------------------------------------------------------
# 1. CONFIG
# --------------------------------------------------------------
CLEAN_ROOT = Path("/content/results/cleaned")
MODEL_ROOT = CLEAN_ROOT / "models"
METRICS_ROOT = CLEAN_ROOT / "metrics"
VIS_ROOT = CLEAN_ROOT / "final_report"
# parents=True: do not fail when CLEAN_ROOT itself is missing; the log file
# configured below is written into this directory.
VIS_ROOT.mkdir(parents=True, exist_ok=True)

FDM_CANVAS_DIR = METRICS_ROOT / "fdm_canvas"
TEMP_DIR = METRICS_ROOT / "temporal"
IMG_ROOT = Path("/content/data/stimuli")  # UPDATE IF NEEDED
FIX_CANVAS = CLEAN_ROOT / "fixations_canvas.csv"
EVAL_CSV = MODEL_ROOT / "model_evaluation.csv"

LOG_OUT = VIS_ROOT / "06_step6_7.log"
# force=True: basicConfig() does nothing if the root logger already has
# handlers (as it may in a notebook runtime, or after re-running this cell),
# in which case the log file would never be written.
logging.basicConfig(filename=LOG_OUT, level=logging.INFO,
                    format="%(asctime)s | %(levelname)s | %(message)s",
                    force=True)
log = logging.getLogger()
log.info("=== STEP 6 & 7 STARTED ===")
print("STEP 6 & 7 – Final Visualizations + Report")

# --------------------------------------------------------------
# 2. Load Data
# --------------------------------------------------------------
fix_canvas = pd.read_csv(FIX_CANVAS)
eval_df = pd.read_csv(EVAL_CSV) if EVAL_CSV.exists() else None

# Group-averaged FDMs (from STEP 2): resampled to the 1024×1024 canvas and
# normalised to sum to 1 so the difference maps in 4.2 compare like with like.
avg_fdm = {}
for diff in ['easy', 'medium', 'hard']:
    f = CLEAN_ROOT / "fdm_group_avg" / f"fdm_avg_{diff}.npy"
    if not f.exists():
        # Section 4.2 silently drops any pair involving this level; say why.
        log.warning("Missing group-average FDM for '%s': %s", diff, f)
        continue
    # float64 cast: cv2.resize rejects some dtypes (e.g. int64) and the
    # in-place division below would fail on an integer array.
    mat = cv2.resize(np.load(f).astype(np.float64), (1024, 1024))
    total = mat.sum()
    if total > 0:
        mat /= total
    avg_fdm[diff] = mat

log.info("Loaded all data")

# --------------------------------------------------------------
# 3. Helper: Overlay FDM on Image
# --------------------------------------------------------------
def overlay_fdm(img_path, fdm, alpha=0.6, title="", size=1024):
    """Draw a fixation-density map as a semi-transparent heatmap over its stimulus.

    Parameters
    ----------
    img_path : str or Path
        Stimulus image, read with ``cv2.imread``.
    fdm : array-like, 2-D
        Fixation-density map. It is resampled to ``size``×``size`` and
        normalised to sum to 1. A map that sums to 0 is left as is.
    alpha : float
        Opacity of the heatmap layer.
    title : str
        Figure title.
    size : int
        Side length in pixels that both image and map are resampled to.

    Returns
    -------
    matplotlib.figure.Figure
        The new (current) figure. The caller saves and closes it.

    Raises
    ------
    FileNotFoundError
        If the image cannot be decoded (``cv2.imread`` returns ``None``).
    """
    bgr = cv2.imread(str(img_path))
    if bgr is None:
        # Without this check cvtColor fails with an opaque assertion error.
        raise FileNotFoundError(f"Could not read stimulus image: {img_path}")
    img = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (size, size))

    # Float input for cv2.resize (which rejects e.g. int64). resize returns a
    # new array, so the in-place division never touches the caller's map.
    fdm_resized = cv2.resize(np.asarray(fdm, dtype=np.float64), (size, size))
    total = fdm_resized.sum()
    if total > 0:
        fdm_resized /= total

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(img)
    ax.imshow(fdm_resized, cmap='jet', alpha=alpha)
    ax.set_title(title, fontsize=14)
    ax.axis('off')
    return fig

# --------------------------------------------------------------
# 4.1 Heatmaps: Per-Difficulty FDMs on Stimuli
# --------------------------------------------------------------
print("4.1 Generating per-difficulty FDM overlays...")

pdf = PdfPages(VIS_ROOT / "final_report.pdf")

# Sample one stimulus per difficulty. sample_stim is reused in 4.3 and 6.
# NOTE: fix_canvas has one row per fixation, so stimuli with more fixations
# are proportionally more likely to be chosen.
sample_stim = fix_canvas.groupby('difficulty').apply(lambda x: x.sample(1, random_state=42)).reset_index(drop=True)

for _, row in sample_stim.iterrows():
    pid, qid = int(row['pid']), row['qid']  # int(): the :02d format needs an int
    stem = f"P{pid:02d}_Q{qid}"
    img_path = IMG_ROOT / f"{stem}.jpg"
    fdm_path = FDM_CANVAS_DIR / f"{stem}.npy"
    if not (img_path.exists() and fdm_path.exists()):
        log.warning("4.1: skipping %s (missing image or FDM)", stem)
        continue

    # overlay_fdm resizes and normalises the map, with a guard for all-zero
    # maps, so the raw map is passed through. Dividing here without that
    # guard would turn an empty map into NaNs.
    fdm = np.load(fdm_path)
    fig = overlay_fdm(img_path, fdm, title=f"{row['difficulty'].capitalize()} Question ({stem})")
    pdf.savefig(fig, bbox_inches='tight')
    plt.close(fig)

log.info("Added per-difficulty FDM overlays")

# --------------------------------------------------------------
# 4.2 Difference Heatmaps
# --------------------------------------------------------------
print("4.2 Difference Heatmaps...")
pairs = [('hard','easy'), ('hard','medium'), ('medium','easy')]
for d1, d2 in pairs:
    if d1 not in avg_fdm or d2 not in avg_fdm:
        log.warning("4.2: skipping %s - %s (group FDM missing)", d1, d2)
        continue
    diff = avg_fdm[d1] - avg_fdm[d2]
    # imshow instead of sns.heatmap: seaborn draws one vector cell per matrix
    # entry (~1M for 1024×1024), which is slow and bloats the PDF.
    # Symmetric limits give the same colour mapping as heatmap(center=0):
    # the midpoint colour means "no difference".
    lim = float(np.abs(diff).max()) or 1e-12  # avoid vmin == vmax for identical maps
    fig, ax = plt.subplots(figsize=(7, 6))
    im = ax.imshow(diff, cmap='coolwarm', vmin=-lim, vmax=lim)
    fig.colorbar(im, ax=ax)
    ax.set_title(f"FDM Difference: {d1.capitalize()} − {d2.capitalize()}")
    ax.axis('off')
    pdf.savefig(fig, bbox_inches='tight')
    plt.close(fig)

# --------------------------------------------------------------
# 4.3 Temporal Ribbons (Early vs Late)
# --------------------------------------------------------------
print("4.3 Temporal Ribbons...")


def _load_attention_map(path, size=1024):
    """Load an early/late attention map, resample it to size×size and
    normalise it to sum to 1. An all-zero map stays zero instead of NaN."""
    amap = cv2.resize(np.load(path).astype(np.float64), (size, size))
    total = amap.sum()
    if total > 0:
        amap /= total
    return amap


for _, row in sample_stim.iterrows():
    pid, qid = int(row['pid']), row['qid']
    stem = f"P{pid:02d}_Q{qid}"
    early_path = TEMP_DIR / f"{stem}_early.npy"
    late_path = TEMP_DIR / f"{stem}_late.npy"
    img_path = IMG_ROOT / f"{stem}.jpg"
    if not all(p.exists() for p in (early_path, late_path, img_path)):
        log.warning("4.3: skipping %s (missing image or temporal map)", stem)
        continue

    bgr = cv2.imread(str(img_path))
    if bgr is None:  # file exists but could not be decoded
        log.warning("4.3: could not read image %s", img_path)
        continue
    img = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (1024, 1024))
    early = _load_attention_map(early_path)
    late = _load_attention_map(late_path)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    axes[0].imshow(img)
    axes[0].set_title("Stimulus")
    axes[1].imshow(img)
    axes[1].imshow(early, cmap='Reds', alpha=0.6)
    axes[1].set_title("Early Attention")
    axes[2].imshow(img)
    axes[2].imshow(late, cmap='Blues', alpha=0.6)
    axes[2].set_title("Late Attention")
    for ax in axes:
        ax.axis('off')
    fig.suptitle(f"Temporal Dynamics – {row['difficulty'].capitalize()} ({stem})", fontsize=16)
    pdf.savefig(fig, bbox_inches='tight')
    plt.close(fig)

# --------------------------------------------------------------
# 4.4 AOI Analysis (Simulated AOIs)
# --------------------------------------------------------------
print("4.4 AOI Chord & Violin Plots...")
# Simulate 4 AOIs per image (top-left, top-right, bottom-left, bottom-right)
def get_aoi(x, y, w=1024, h=1024):
    """Assign a canvas coordinate to one of four quadrant AOIs.

    The canvas is split at ``w // 2`` and ``h // 2``. Points exactly on a split
    line fall into the right / bottom quadrant.

    Returns
    -------
    str or None
        ``"TL"``, ``"TR"``, ``"BL"`` or ``"BR"``. Returns ``None`` when either
        coordinate is missing (NaN/None), so pandas groupby drops the row.
    """
    # A missing coordinate fails every comparison and would otherwise fall
    # through to "BR", silently inflating that quadrant's dwell time.
    if pd.isna(x) or pd.isna(y):
        return None
    vertical = "T" if y < h // 2 else "B"
    horizontal = "L" if x < w // 2 else "R"
    return vertical + horizontal

# Quadrant label per fixation. The comprehension calls the same classifier as
# DataFrame.apply(axis=1) without building a Series for every row.
fix_canvas['aoi'] = [get_aoi(x, y) for x, y in zip(fix_canvas['x_canvas'], fix_canvas['y_canvas'])]

# Total dwell per (stimulus, AOI). unstack/stack with fill_value=0 adds
# explicit zero rows for quadrants a stimulus never received fixations in.
# Without them the violins describe only attended AOIs and overstate the
# typical dwell share.
aoi_dwell = (
    fix_canvas.groupby(['pid', 'qid', 'difficulty', 'aoi'])['duration'].sum()
    .unstack('aoi', fill_value=0)
    .stack()
    .rename('duration')
    .reset_index()
)
# Fraction of each stimulus' total dwell time spent in each AOI.
aoi_dwell['dwell_norm'] = (
    aoi_dwell['duration'] / aoi_dwell.groupby(['pid', 'qid'])['duration'].transform('sum')
)

# Violin plot. split=True is meant for exactly two hue levels (seaborn < 0.13
# raises ValueError otherwise), so only split for a two-level comparison.
n_levels = aoi_dwell['difficulty'].nunique()
fig, ax = plt.subplots(figsize=(10, 6))
sns.violinplot(data=aoi_dwell, x='aoi', y='dwell_norm', hue='difficulty',
               order=['TL', 'TR', 'BL', 'BR'], split=(n_levels == 2),
               inner='quartile', ax=ax)
ax.set_title("Normalized Dwell Time per AOI by Difficulty")
ax.set_ylabel("Proportion of Total Dwell Time")
pdf.savefig(fig, bbox_inches='tight')
plt.close(fig)

# AOI transitions: a bar chart standing in for a chord diagram.
# All 12 ordered moves between distinct quadrants are counted, not only the
# clockwise TL→TR→BR→BL→TL cycle, so diagonal and counter-clockwise moves are
# included in the "average transition" summary.
# NOTE(review): consecutive rows within a (pid, qid) group are treated as
# consecutive fixations, i.e. fix_canvas is assumed to be in temporal order
# per stimulus. TODO: confirm, or sort by an onset column first.
_AOI_ORDER = ['TL', 'TR', 'BR', 'BL']
_all_moves = [f"{a}→{b}" for a in _AOI_ORDER for b in _AOI_ORDER if a != b]
_next_aoi = fix_canvas.groupby(['pid', 'qid'])['aoi'].shift(-1)
_moved = fix_canvas['aoi'].notna() & _next_aoi.notna() & (fix_canvas['aoi'] != _next_aoi)
# Average over all stimuli; a stimulus without a given move contributes 0.
# This equals counting per stimulus and then taking .mean().
_n_stimuli = max(fix_canvas.groupby(['pid', 'qid']).ngroups, 1)
aoi_trans = (
    (fix_canvas.loc[_moved, 'aoi'] + '→' + _next_aoi[_moved])
    .value_counts()
    .reindex(_all_moves, fill_value=0)
    / _n_stimuli
)

# Colour bars by their source quadrant (same four colours as before).
_src_colour = dict(zip(_AOI_ORDER, ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']))
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(aoi_trans.index, aoi_trans.values,
       color=[_src_colour[m.split('→')[0]] for m in aoi_trans.index])
ax.set_title("Average AOI Transition Frequency")
ax.set_ylabel("Transitions per stimulus")
ax.tick_params(axis='x', labelrotation=45)
pdf.savefig(fig, bbox_inches='tight')
plt.close(fig)

# --------------------------------------------------------------
# 4.5 Small Multiples Grid
# --------------------------------------------------------------
print("4.5 Small Multiples...")
# One row per difficulty, with up to N_PER_ROW *distinct* stimuli per row.
# Sampling from the de-duplicated stimulus list avoids showing one stimulus
# several times (fix_canvas has one row per fixation). The min() avoids the
# ValueError that sample() raises when a group is smaller than N_PER_ROW.
N_PER_ROW = 3
_stimuli = fix_canvas[['pid', 'qid', 'difficulty']].drop_duplicates()
_groups = list(_stimuli.groupby('difficulty'))
n_rows = max(len(_groups), 1)
fig, axes = plt.subplots(n_rows, N_PER_ROW, figsize=(15, 4 * n_rows), squeeze=False)
for ax in axes.flat:
    ax.axis('off')  # panels whose stimulus is skipped stay blank

_picked = []
for r, (_, grp) in enumerate(_groups):
    picks = grp.sample(n=min(N_PER_ROW, len(grp)), random_state=42)
    _picked.append(picks)
    for c, (_, row) in enumerate(picks.iterrows()):
        ax = axes[r, c]
        pid, qid = int(row['pid']), row['qid']
        img_path = IMG_ROOT / f"P{pid:02d}_Q{qid}.jpg"
        fdm_path = FDM_CANVAS_DIR / f"P{pid:02d}_Q{qid}.npy"
        if not (img_path.exists() and fdm_path.exists()):
            continue
        bgr = cv2.imread(str(img_path))
        if bgr is None:  # file exists but could not be decoded
            continue
        img = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (340, 340))
        fdm = cv2.resize(np.load(fdm_path).astype(np.float64), (340, 340))
        total = fdm.sum()
        if total > 0:  # all-zero map stays zero instead of NaN
            fdm /= total
        ax.imshow(img)
        ax.imshow(fdm, cmap='jet', alpha=0.5)
        ax.set_title(f"{row['difficulty'][:1].upper()} | P{pid:02d}_Q{qid}", fontsize=10)

# Stimuli actually drawn (one row per stimulus).
stim_sample = pd.concat(_picked, ignore_index=True) if _picked else _stimuli.iloc[0:0]
fig.suptitle("Small Multiples: FDMs × Difficulty", fontsize=16)
pdf.savefig(fig, bbox_inches='tight')
plt.close(fig)

# --------------------------------------------------------------
# 5. Model Performance Summary
# --------------------------------------------------------------
if eval_df is not None:
    # Select the metric columns *before* aggregating. Since pandas 2.0,
    # GroupBy.mean() defaults to numeric_only=False and raises TypeError if
    # eval_df carries any text column besides 'difficulty'.
    metric_cols = ['NSS_emp', 'NSS_cb', 'AUC_emp']
    summary = eval_df.groupby('difficulty')[metric_cols].mean()
    # Show easy → medium → hard instead of alphabetical order; any other
    # labels follow in their existing order.
    known = [d for d in ['easy', 'medium', 'hard'] if d in summary.index]
    summary = summary.reindex(known + [d for d in summary.index if d not in known])

    ax = summary[['NSS_emp', 'NSS_cb']].plot(kind='bar', figsize=(8, 5), color=['#1f77b4', '#ff7f0e'])
    ax.set_title("Model Performance (NSS) by Difficulty")
    ax.set_ylabel("NSS")
    ax.tick_params(axis='x', labelrotation=0)
    pdf.savefig(ax.figure, bbox_inches='tight')
    plt.close(ax.figure)

# --------------------------------------------------------------
# 6. Experiment: Gaussian σ Sensitivity
# --------------------------------------------------------------
print("6. Experiment: Gaussian σ...")
sigmas = [20, 40, 60, 80]


def _self_nss(grp, sigma, size=1024):
    """NSS of a duration-weighted, Gaussian-smoothed fixation map, evaluated at
    the fixations it was built from. Returns NaN if no fixation lies on the
    size×size canvas.

    NOTE(review): the map and the evaluation points are the same fixations, so
    this measures self-consistency rather than predictive power, and smaller σ
    will tend to score higher. A held-out split would give a fair σ comparison.
    """
    coords = grp[['x_canvas', 'y_canvas']].to_numpy(dtype=float)
    dur = grp['duration'].to_numpy(dtype=float)
    finite = np.isfinite(coords).all(axis=1)
    pts = coords[finite].astype(int)  # truncation, as int() did per row
    dur = dur[finite]
    # Drop off-canvas points: >= size would raise IndexError, and negative
    # indices would silently wrap to the opposite edge.
    inside = ((pts >= 0) & (pts < size)).all(axis=1)
    pts, dur = pts[inside], dur[inside]
    if len(pts) == 0:
        return np.nan

    canvas = np.zeros((size, size))
    # np.add.at accumulates repeated indices, so refixations add their
    # durations. Drawing with cv2.circle overwrote the earlier value.
    np.add.at(canvas, (pts[:, 1], pts[:, 0]), dur)
    canvas = gaussian_filter(canvas, sigma=sigma)
    # z-scoring is scale-invariant, so no sum-normalisation is needed. Skipping
    # it also keeps the 1e-8 epsilon negligible relative to the std.
    sal_norm = (canvas - canvas.mean()) / (canvas.std() + 1e-8)
    return sal_norm[pts[:, 1], pts[:, 0]].mean()


# Fixations of each sampled stimulus, selected once instead of once per σ.
_sample_groups = [
    fix_canvas[(fix_canvas['pid'] == row['pid']) & (fix_canvas['qid'] == row['qid'])]
    for _, row in sample_stim.iterrows()
]
nss_by_sigma = []
for sigma in sigmas:
    nss_list = [_self_nss(g, sigma) for g in _sample_groups]
    nss_by_sigma.append({'sigma': sigma, 'NSS': np.nanmean(nss_list) if nss_list else np.nan})

sigma_df = pd.DataFrame(nss_by_sigma)
fig, ax = plt.subplots(figsize=(7, 5))
sns.lineplot(data=sigma_df, x='sigma', y='NSS', marker='o', ax=ax)
ax.set_title("NSS vs Gaussian σ (Visual Angle)")
ax.set_xlabel("σ (pixels @ 1024px)")
pdf.savefig(fig, bbox_inches='tight')
plt.close(fig)

# --------------------------------------------------------------
# 7. Finalize PDF Report
# --------------------------------------------------------------
# Cover page: one centred text line per entry (y-position, text, font style).
# NOTE(review): PdfPages only appends, so this "cover" becomes the LAST page of
# the report. To make it the first page, render it right after the PdfPages
# object is opened in section 4.1.
fig, ax = plt.subplots(figsize=(8.5, 11))
cover_lines = [
    (0.7, "IITB EdTech Internship 2025", {'fontsize': 20, 'fontweight': 'bold'}),
    (0.6, "Problem ID-7: Visual Attention Mapping Across Task Types", {'fontsize': 14}),
    (0.5, "TEAM_DRS | Rohan Sandip Nikam", {'fontsize': 12}),
    (0.4, "Mentor: Mrs. Murnal M Wakarekar", {'fontsize': 12}),
    (0.3, f"Generated: {pd.Timestamp.now().strftime('%Y-%m-%d')}", {'fontsize': 10}),
]
for y_pos, text, style in cover_lines:
    ax.text(0.5, y_pos, text, ha='center', va='center', **style)
ax.axis('off')
pdf.savefig(fig)
plt.close(fig)

# Flush and close the multi-page report; no pages can be added after this.
pdf.close()
print(f"\nFINAL REPORT GENERATED: {VIS_ROOT / 'final_report.pdf'}")

# --------------------------------------------------------------
# 8. Summary
# --------------------------------------------------------------
print("\nSTEP 6 & 7 COMPLETE!")
print(f"Outputs → {VIS_ROOT}")
# Static inventory of the report. It is printed unconditionally, so sections
# that were skipped above (missing inputs) are still listed.
print("\n".join([
    "   • final_report.pdf (full visual report)",
    "   • All visualizations included:",
    "     - Per-difficulty FDM overlays",
    "     - Difference heatmaps",
    "     - Temporal ribbons",
    "     - AOI violin + chord",
    "     - Small multiples",
    "     - Model performance",
    "     - Gaussian σ experiment",
]))

log.info("STEP 6 & 7 FINISHED")