In [None]:
# --------------------------------------------------------------
# IITB EdTech Internship 2025 – Problem ID 7
# STEP 5: Deep Gaze Modeling & Validation
# --------------------------------------------------------------

# --------------------------------------------------------------
# 0. Mount & Imports
# --------------------------------------------------------------
from google.colab import drive
drive.mount('/content/drive')

import pandas as pd, numpy as np, os, cv2, logging, json, random
from pathlib import Path
from scipy.ndimage import gaussian_filter
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import warnings
warnings.filterwarnings("ignore")

# --------------------------------------------------------------
# 1. CONFIG
# --------------------------------------------------------------
# Root of the cleaned outputs produced by the earlier pipeline steps.
CLEAN_ROOT = Path("/content/results/cleaned")
METRICS_ROOT = CLEAN_ROOT / "metrics"
MODEL_ROOT = CLEAN_ROOT / "models"
# parents=True so a fresh runtime (where CLEAN_ROOT does not exist yet) does
# not fail here with FileNotFoundError before anything useful has happened.
MODEL_ROOT.mkdir(parents=True, exist_ok=True)

# Per-stimulus fixation density maps (.npy), stimulus images and the
# fixation table consumed by this step.
FDM_CANVAS_DIR = METRICS_ROOT / "fdm_canvas"
IMG_ROOT = Path("/content/data/stimuli")  # UPDATE: your actual image folder
FIX_CANVAS = CLEAN_ROOT / "fixations_canvas.csv"

LOG_OUT = MODEL_ROOT / "05_step5.log"
# force=True: notebook runtimes usually pre-install handlers on the root
# logger, in which case basicConfig() would otherwise be a silent no-op and
# nothing would ever be written to LOG_OUT.
logging.basicConfig(filename=LOG_OUT, level=logging.INFO,
                    format="%(asctime)s | %(levelname)s | %(message)s",
                    force=True)
log = logging.getLogger()
log.info("=== STEP 5 STARTED ===")
print("STEP 5 – Deep Gaze Modeling & Validation")

# --------------------------------------------------------------
# 2. Load Stimulus List & Split
# --------------------------------------------------------------
fix_canvas = pd.read_csv(FIX_CANVAS)

# One row per (pid, qid) stimulus together with its difficulty label.
# De-duplicating on the key columns guarantees exactly one label per stimulus,
# so the stratification vector always has the same length and order as
# `stim_list` (a separate merge could yield extra rows if a stimulus ever
# carried more than one difficulty value).
stim_meta = fix_canvas[['pid', 'qid', 'difficulty']].drop_duplicates(subset=['pid', 'qid'])
stim_list = stim_meta[['pid', 'qid']]

# Train/val split by stimulus (80/20), stratified by difficulty
train_stim, val_stim = train_test_split(
    stim_list,
    test_size=0.2,
    random_state=42,
    stratify=stim_meta['difficulty'],
)

# Persist the split so later steps can reuse exactly the same partition.
train_stim.to_csv(MODEL_ROOT / "train_stimuli.csv", index=False)
val_stim.to_csv(MODEL_ROOT / "val_stimuli.csv", index=False)

log.info(f"Split: {len(train_stim)} train, {len(val_stim)} val stimuli")

# --------------------------------------------------------------
# 3. Dataset Class
# --------------------------------------------------------------
class EyeGazeDataset(Dataset):
    """Pairs each stimulus image with its ground-truth fixation density map.

    Args:
        stim_df: DataFrame with one row per stimulus; must provide ``pid`` and
            ``qid`` columns (used to build the file names).
        img_root: Directory containing ``P{pid:02d}_Q{qid}.jpg`` images.
        fdm_dir: Directory containing ``P{pid:02d}_Q{qid}.npy`` density maps.
        canvas_size: Side length the density map is resized to.
        sigma: Stored on the instance; not used by this class.

    Each item is ``(img_tensor, fdm_tensor, pid, qid)`` where ``img_tensor`` is
    the normalised 3x224x224 image and ``fdm_tensor`` is the
    1 x canvas_size x canvas_size density map scaled to sum to 1.
    """

    def __init__(self, stim_df, img_root, fdm_dir, canvas_size=1024, sigma=40):
        self.stim_df = stim_df
        self.img_root = Path(img_root)
        self.fdm_dir = Path(fdm_dir)
        self.canvas_size = canvas_size
        self.sigma = sigma
        self.transform = transforms.Compose([
            # cv2 yields numpy arrays, but Resize only accepts PIL images or
            # tensors, so convert first.
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.stim_df)

    def __getitem__(self, idx):
        """Load and preprocess the ``idx``-th stimulus.

        Raises:
            FileNotFoundError: If the stimulus image cannot be read.
        """
        row = self.stim_df.iloc[idx]
        pid, qid = row['pid'], row['qid']
        img_path = self.img_root / f"P{pid:02d}_Q{qid}.jpg"  # UPDATE extension if needed
        fdm_path = self.fdm_dir / f"P{pid:02d}_Q{qid}.npy"

        # Load image (cv2 returns None instead of raising on a missing file)
        img = cv2.imread(str(img_path))
        if img is None:
            raise FileNotFoundError(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_tensor = self.transform(img)

        # Load ground-truth FDM. Cast to float up front so normalisation also
        # works when the map was saved with an integer dtype.
        fdm = np.load(fdm_path).astype(np.float64)
        fdm = cv2.resize(fdm, (self.canvas_size, self.canvas_size))
        total = fdm.sum()
        if total > 0:
            fdm = fdm / total
        fdm_tensor = torch.tensor(fdm, dtype=torch.float32).unsqueeze(0)

        return img_tensor, fdm_tensor, pid, qid

# --------------------------------------------------------------
# 4. Baselines
# --------------------------------------------------------------
# Progress marker shown in the notebook output before the baselines are built.
print("4. Building Baselines...")

def center_bias(canvas_size=1024, sigma=200):
    """Build an isotropic Gaussian prior over a square canvas.

    The Gaussian is centred on pixel ``canvas_size // 2`` along both axes and
    the returned ``(canvas_size, canvas_size)`` map is scaled to sum to 1.
    """
    axis = np.arange(canvas_size, dtype=float)
    mid = canvas_size // 2
    # Squared distance to the centre, built by broadcasting a row vector
    # (x offsets) against a column vector (y offsets).
    sq_dist = (axis[np.newaxis, :] - mid) ** 2 + (axis[:, np.newaxis] - mid) ** 2
    weights = np.exp(-sq_dist / (2 * sigma**2))
    return weights / weights.sum()

# Center-bias baseline map built with the default arguments of center_bias();
# reused for every stimulus during evaluation and saved next to the model
# artefacts.
CB_MAP = center_bias()
np.save(MODEL_ROOT / "center_bias.npy", CB_MAP)

def blur_baseline(fdm_native, sigma=80):
    """Return ``fdm_native`` smoothed with a Gaussian filter.

    Args:
        fdm_native: Array to blur (passed straight to ``gaussian_filter``).
        sigma: Standard deviation of the Gaussian kernel. Defaults to 80,
            the value this baseline previously had hard-coded.

    Returns:
        The blurred array, same shape as the input.
    """
    return gaussian_filter(fdm_native, sigma=sigma)

# --------------------------------------------------------------
# 5. DeepGaze II Model (Pretrained + Fine-tune)
# --------------------------------------------------------------
print("5. Loading & Fine-tuning DeepGaze II...")

# `!` is IPython shell-escape syntax: this line only works inside a notebook,
# not when the file is run as a plain Python script.
# NOTE(review): confirm that the PyPI package named `deepgaze` is the one that
# provides the DeepGaze saliency model and exports `DeepGazeII` with a
# `pretrained` argument — not verifiable from this file.
!pip install -q deepgaze

from deepgaze import DeepGazeII

# Load pretrained model
model = DeepGazeII(pretrained=True)
# Starts in eval mode; the training loop below switches to train mode itself.
model.eval()

# Freeze all but last layer
# (every parameter is frozen first, then only those under `centerbias_layer`
# are re-enabled, so the optimizer below updates that sub-module alone).
# NOTE(review): assumes the model exposes a `centerbias_layer` attribute — confirm.
for param in model.parameters():
    param.requires_grad = False
for param in model.centerbias_layer.parameters():
    param.requires_grad = True

# Fine-tune setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# KLDivLoss takes log-probabilities as input and probabilities as target,
# which is why the training loop applies a log to the prediction.
criterion = nn.KLDivLoss(reduction='batchmean')
# Only parameters that still require gradients are handed to Adam.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

# Datasets are built from the stimulus split created earlier in this file.
train_dataset = EyeGazeDataset(train_stim, IMG_ROOT, FDM_CANVAS_DIR)
val_dataset = EyeGazeDataset(val_stim, IMG_ROOT, FDM_CANVAS_DIR)

# Shuffle only the training data; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

# --------------------------------------------------------------
# 6. Training Loop
# --------------------------------------------------------------
print("6. Fine-tuning DeepGaze II...")

EPOCHS = 10
best_val_loss = float('inf')
history = []


def _log_density(pred, target, eps=1e-12):
    """Return log(pred) after matching `target`'s spatial size and renormalising.

    The images fed to the model are 224x224 while the ground-truth maps are
    larger, so the raw prediction cannot be compared with the target directly
    (the evaluation code resizes it for the same reason). The prediction is
    resampled when needed, scaled so each sample sums to 1, and clamped before
    the log so zero-valued pixels do not produce -inf.

    NOTE(review): assumes the model returns a non-negative density-like map
    shaped (N, H, W) or (N, 1, H, W) — confirm against the model's API.
    """
    if pred.dim() == 3:
        pred = pred.unsqueeze(1)
    if pred.shape[-2:] != target.shape[-2:]:
        pred = nn.functional.interpolate(
            pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred = pred.clamp_min(0)
    pred = pred / pred.sum(dim=(-2, -1), keepdim=True).clamp_min(eps)
    return pred.clamp_min(eps).log()


for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    for imgs, fdms, _, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
        imgs, fdms = imgs.to(device), fdms.to(device)
        log_pred = _log_density(model(imgs), fdms)
        loss = criterion(log_pred, fdms)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, fdms, _, _ in val_loader:
            imgs, fdms = imgs.to(device), fdms.to(device)
            log_pred = _log_density(model(imgs), fdms)
            loss = criterion(log_pred, fdms)
            val_loss += loss.item()

    # Per-batch averages; max(..., 1) avoids ZeroDivisionError on an empty loader.
    avg_train = train_loss / max(len(train_loader), 1)
    avg_val = val_loss / max(len(val_loader), 1)
    history.append({'epoch': epoch+1, 'train_loss': avg_train, 'val_loss': avg_val})

    # Keep only the checkpoint with the lowest validation loss seen so far.
    if avg_val < best_val_loss:
        best_val_loss = avg_val
        torch.save(model.state_dict(), MODEL_ROOT / "deepgaze_finetuned.pth")
        log.info(f"New best model saved at epoch {epoch+1}")

    print(f"Epoch {epoch+1} | Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f}")

history_df = pd.DataFrame(history)
history_df.to_csv(MODEL_ROOT / "training_history.csv", index=False)

# --------------------------------------------------------------
# 7. Evaluation Metrics (Same as STEP 4)
# --------------------------------------------------------------
from scipy.stats import wasserstein_distance
from scipy.spatial.distance import jensenshannon

def similarity(m1, m2):
    """Sum over all elements of sqrt(m1 * m2).

    NOTE(review): this is the Bhattacharyya coefficient; saliency benchmarks
    usually define SIM as sum(min(m1, m2)) — confirm which one STEP 4 reports.
    """
    overlap = np.sqrt(m1 * m2)
    return np.sum(overlap)
def kl_div(p, q):
    """KL-style divergence sum(p * log(p / q)), with a 1e-12 guard on the
    denominator and inside the log so zeros do not produce inf/nan."""
    eps = 1e-12
    ratio = p / (q + eps)
    return np.sum(p * np.log(ratio + eps))
def cc(m1, m2):
    """Pearson correlation coefficient between the two flattened maps."""
    flat_a = m1.ravel()
    flat_b = m2.ravel()
    corr_matrix = np.corrcoef(flat_a, flat_b)
    # Off-diagonal entry is the correlation between the two inputs.
    return corr_matrix[0, 1]
def nss(sal, pts):
    """Normalized Scanpath Saliency: mean z-scored saliency at the fixations.

    Args:
        sal: 2-D saliency map indexed as ``sal[row, col]``.
        pts: Array of ``(x, y)`` fixation coordinates (x = column, y = row);
            values are truncated to integers.

    Returns:
        Mean of the standardised map over the in-bounds fixations, or NaN when
        no fixation falls inside the map.
    """
    pts = np.asarray(pts).reshape(-1, 2)
    cols = pts[:, 0].astype(int)
    rows = pts[:, 1].astype(int)
    height, width = sal.shape[0], sal.shape[1]
    # Drop off-canvas fixations: negative indices would silently wrap around
    # to the opposite edge and too-large ones would raise IndexError.
    inside = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)
    if not inside.any():
        return float("nan")
    sal_norm = (sal - sal.mean()) / (sal.std() + 1e-8)
    return sal_norm[rows[inside], cols[inside]].mean()

def auc_judd(sal, pts):
    """ROC AUC of the saliency map for separating fixated from other pixels.

    Every pixel is a sample, labelled positive when at least one fixation
    lands on it. The AUC is computed from ranks (Mann-Whitney U), which is
    equal to the area under the ROC curve and counts ties as one half, so no
    random tie-breaking jitter is needed and the result is deterministic.

    Args:
        sal: 2-D saliency map indexed as ``sal[row, col]``.
        pts: Array of ``(x, y)`` fixation coordinates (x = column, y = row);
            values are truncated to integers, off-canvas points are ignored.

    Returns:
        AUC in [0, 1], or NaN when there are no fixated pixels or no
        non-fixated pixels.
    """
    # Local import: this module does not import a ROC-AUC helper at the top.
    from scipy.stats import rankdata

    pts = np.asarray(pts).reshape(-1, 2)
    cols = pts[:, 0].astype(int)
    rows = pts[:, 1].astype(int)
    height, width = sal.shape[0], sal.shape[1]
    inside = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)

    # Flat index uses the map's own width as the row stride rather than a
    # fixed canvas size, so non-1024 maps are labelled correctly.
    labels = np.zeros(sal.size, dtype=bool)
    labels[rows[inside] * width + cols[inside]] = True

    n_pos = int(labels.sum())
    n_neg = labels.size - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    ranks = rankdata(sal.ravel())  # 1-based, ties share their average rank
    u_stat = ranks[labels].sum() - n_pos * (n_pos + 1) / 2
    return float(u_stat / (n_pos * n_neg))

# --------------------------------------------------------------
# 8. Run Inference & Evaluate
# --------------------------------------------------------------
print("8. Evaluating Models...")

# Reload the best checkpoint written by the training loop. map_location keeps
# this working when the checkpoint was saved from a different device than the
# one currently in use (e.g. saved on GPU, re-run on a CPU-only runtime).
model.load_state_dict(torch.load(MODEL_ROOT / "deepgaze_finetuned.pth", map_location=device))
model.eval()

results = []
# batch_size=1 so each iteration handles exactly one stimulus and pid/qid can
# be unpacked with .item().
eval_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

with torch.no_grad():
    for img_tensor, gt_fdm, pid, qid in tqdm(eval_loader, desc="Evaluating"):
        pid, qid = pid.item(), qid.item()
        img_tensor = img_tensor.to(device)
        # Model prediction, brought up to the 1024x1024 canvas and scaled to
        # sum to 1 so it is comparable with the ground-truth map.
        pred = model(img_tensor).cpu().numpy().squeeze()
        pred = cv2.resize(pred, (1024, 1024))
        if pred.sum() > 0: pred /= pred.sum()

        gt = gt_fdm.numpy().squeeze()
        cb = CB_MAP.copy()

        # Get fixations
        fix = fix_canvas[(fix_canvas['pid']==pid) & (fix_canvas['qid']==qid)]
        pts = fix[['x_canvas','y_canvas']].values.astype(int)

        # `*_emp` columns score the model prediction, `*_cb` columns score the
        # center-bias baseline, both against the same ground truth / fixations.
        row = {
            'pid': pid, 'qid': qid,
            'difficulty': fix['difficulty'].iloc[0],
            'SIM_emp': similarity(gt, pred),
            'KL_emp': kl_div(gt, pred),
            'CC_emp': cc(gt, pred),
            'NSS_emp': nss(pred, pts),
            'AUC_emp': auc_judd(pred, pts),
            'SIM_cb': similarity(gt, cb),
            'KL_cb': kl_div(gt, cb),
            'CC_cb': cc(gt, cb),
            'NSS_cb': nss(cb, pts),
            'AUC_cb': auc_judd(cb, pts)
        }
        results.append(row)

results_df = pd.DataFrame(results)
results_df.to_csv(MODEL_ROOT / "model_evaluation.csv", index=False)

# --------------------------------------------------------------
# 9. Stratified Summary
# --------------------------------------------------------------
# Mean of every numeric metric per difficulty level. numeric_only=True keeps
# the aggregation from raising if a non-numeric column is present.
summary = results_df.groupby('difficulty').mean(numeric_only=True)
summary.to_csv(MODEL_ROOT / "summary_by_difficulty.csv")
# Report the signed difference only; whether the model actually beats the
# baseline depends on the sign, so the message no longer asserts it.
nss_delta = summary['NSS_emp'].mean() - summary['NSS_cb'].mean()
log.info(f"DeepGaze vs center-bias: NSS Δ = {nss_delta:.3f}")

# --------------------------------------------------------------
# 10. Visualizations
# --------------------------------------------------------------
VIS_DIR = MODEL_ROOT / "visualizations"
VIS_DIR.mkdir(exist_ok=True)

# Mirror the Resize / ToTensor / Normalize preprocessing that EyeGazeDataset
# applies, so these predictions are produced from the same kind of input the
# model was fine-tuned and evaluated on. ToPILImage is needed because cv2
# yields numpy arrays, which Resize does not accept.
vis_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Sample predictions (at most 3; fewer when the evaluation set is smaller)
sample = results_df.sample(min(3, len(results_df)))
for _, row in sample.iterrows():
    # iterrows() may upcast integers to float; the file-name format needs ints.
    pid, qid = int(row['pid']), int(row['qid'])
    img_path = IMG_ROOT / f"P{pid:02d}_Q{qid}.jpg"
    bgr = cv2.imread(str(img_path))
    if bgr is None:
        raise FileNotFoundError(img_path)
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    gt = np.load(FDM_CANVAS_DIR / f"P{pid:02d}_Q{qid}.npy").astype(np.float64)
    gt = cv2.resize(gt, (1024, 1024))
    if gt.sum() > 0:
        gt = gt / gt.sum()

    # no_grad: the fine-tuned parameters require gradients, and .numpy()
    # cannot be called on a tensor that is still attached to the graph.
    with torch.no_grad():
        pred = model(vis_transform(img).unsqueeze(0).to(device)).cpu().numpy().squeeze()
    pred = cv2.resize(pred, (1024, 1024))
    if pred.sum() > 0:
        pred = pred / pred.sum()

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(img); axes[0].set_title("Stimulus")
    sns.heatmap(gt, ax=axes[1], cmap='viridis', cbar=True); axes[1].set_title("Empirical FDM")
    sns.heatmap(pred, ax=axes[2], cmap='viridis', cbar=True); axes[2].set_title("DeepGaze Pred")
    for ax in axes: ax.axis('off')
    plt.tight_layout()
    plt.savefig(VIS_DIR / f"pred_P{pid:02d}_Q{qid}.png")
    plt.close()

# Performance bar plot: mean NSS of the model vs the center-bias baseline,
# grouped by difficulty level.
plot_df = summary[['NSS_emp','NSS_cb']].reset_index().melt(id_vars='difficulty', var_name='Model', value_name='NSS')
plt.figure(figsize=(8,5))
sns.barplot(data=plot_df, x='difficulty', y='NSS', hue='Model', order=['easy','medium','hard'])
plt.title("NSS: DeepGaze vs Center-Bias by Difficulty")
plt.savefig(VIS_DIR / "nss_comparison.png")
plt.close()

log.info(f"Visualizations → {VIS_DIR}")

# --------------------------------------------------------------
# FINAL SUMMARY
# --------------------------------------------------------------
# Console recap of where the outputs went and which artefacts were produced.
artifact_lines = (
    "   • deepgaze_finetuned.pth",
    "   • model_evaluation.csv",
    "   • summary_by_difficulty.csv",
    "   • visualizations/ (preds + NSS plot)",
)

print("\nSTEP 5 COMPLETE!")
print(f"Outputs → {MODEL_ROOT}")
for artifact_line in artifact_lines:
    print(artifact_line)

log.info("STEP 5 FINISHED")