# U-Net Reproduction on the COCO Dataset

Pipeline Steps:
1. **Data Generation**: Download COCO images and generate teacher meaning maps (using LLaVA).
2. **Training**: Train the U-Net on these (Image, Meaning Map) pairs.
3. **Evaluation**: Evaluate the U-Net on the 14 held-out Henderson & Hayes test scenes.

Prerequisites
Ensure you have run `utils/generate_public_training_data.py` and the LLaVA inference pipeline to populate `public_training_data/`.

In [None]:
import os
import time
import random
import numpy as np
import pandas as pd
from glob import glob
from PIL import Image
from tqdm import tqdm
from scipy.stats import pearsonr

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

# Run on the CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Model Architecture

In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes, from ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, the second keeps
        # out_channels; the layer order is conv, batch norm, ReLU for each.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, 3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        # A single Sequential named ``conv`` keeps the parameter names
        # (conv.0, conv.1, conv.3, conv.4) used by saved checkpoints.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv -> BN -> ReLU stages."""
        return self.conv(x)

class UNet(nn.Module):
    """U-Net with five down/up-sampling levels and 1x1-reduced skip connections.

    Encoder: an input ``DoubleConv`` followed by five ``MaxPool2d(2)`` +
    ``DoubleConv`` stages (64, 128, 256, 512, 1024, 1024 channels).
    Decoder: five ``ConvTranspose2d`` (stride 2) upsampling stages. Before each
    concatenation the encoder skip features pass through a ``reduce*`` stack of
    1x1 convolutions so that they have as many channels as the upsampled
    decoder features.

    Args:
        n_channels: number of channels of the input image.
        n_classes: number of channels of the returned logits.

    The attribute names of the sub-modules define the ``state_dict`` keys of
    saved checkpoints, so they must not be renamed.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(64, 128))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(128, 256))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(256, 512))
        self.down4 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(512, 1024))
        self.down5 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(1024, 1024))

        # Decoder
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv1 = DoubleConv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv4 = DoubleConv(128, 64)
        self.up5 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.up_conv5 = DoubleConv(128, 64)

        # Feature reduction layers (Custom Architecture): 1x1 convolutions that
        # shrink each skip tensor to the channel count of the upsampled tensor.
        self.reduce1 = nn.Sequential(nn.Conv2d(1024, 768, kernel_size=1), nn.ReLU(inplace=True), nn.Conv2d(768, 512, kernel_size=1))
        self.reduce2 = nn.Sequential(nn.Conv2d(512, 384, kernel_size=1), nn.ReLU(inplace=True), nn.Conv2d(384, 256, kernel_size=1))
        self.reduce3 = nn.Sequential(nn.Conv2d(256, 192, kernel_size=1), nn.ReLU(inplace=True), nn.Conv2d(192, 128, kernel_size=1))
        self.reduce4 = nn.Sequential(nn.Conv2d(128, 96, kernel_size=1), nn.ReLU(inplace=True), nn.Conv2d(96, 64, kernel_size=1))
        self.reduce5 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1))

        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )

        self.outc = nn.Conv2d(16, n_classes, kernel_size=1)

    @staticmethod
    def _pad_to_match(x, skip):
        """Zero-pad ``x`` so that its height and width equal those of ``skip``.

        ``MaxPool2d(2)`` floors odd sizes, so after upsampling ``x`` can be one
        pixel smaller than the matching encoder tensor. When the sizes already
        agree (e.g. inputs whose sides are multiples of 32) ``x`` is returned
        untouched.
        """
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        # Padding order for the last two dims is (left, right, top, bottom).
        return nn.functional.pad(
            x,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        """Return logits with ``n_classes`` channels and the input's height/width.

        ``x`` is expected to be a 4-D batch (N, ``n_channels``, H, W). H and W
        need not be multiples of 32, but each should be at least 32 so that the
        five poolings never produce an empty tensor.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)

        # (upsample, skip-reduction, fusion conv, encoder skip), deepest first.
        decoder_stages = (
            (self.up1, self.reduce1, self.up_conv1, x5),
            (self.up2, self.reduce2, self.up_conv2, x4),
            (self.up3, self.reduce3, self.up_conv3, x3),
            (self.up4, self.reduce4, self.up_conv4, x2),
            (self.up5, self.reduce5, self.up_conv5, x1),
        )

        x = x6
        for upsample, reduce, fuse, skip in decoder_stages:
            x = self._pad_to_match(upsample(x), skip)
            x = fuse(torch.cat([x, reduce(skip)], dim=1))

        x = self.final_conv(x)
        logits = self.outc(x)
        return logits

## 2. Dataset Loading
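Loads (Image, Meaning Map) pairs from `public_training_data/`: images are read from the `images/` subdirectory, and each one is matched by file basename to a map in `meaning_maps/` (`.npy`, `.jpg`, or `.png`).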

In [None]:
class COCODistillationDataset(Dataset):
    """(image, teacher meaning map) pairs read from a directory tree.

    Expected layout::

        root_dir/images/<name>.jpg | .jpeg | .png
        root_dir/meaning_maps/<name>.npy | .jpg | .png

    A map is matched to its image by file basename. When no map exists for an
    image, a warning is printed and an all-zero map of the image's size is
    used instead, so that a partially generated dataset can still be loaded.

    Args:
        root_dir: directory containing ``images/`` and ``meaning_maps/``.
        transform: optional callable applied to the RGB PIL image.
        target_transform: optional callable applied to the 8-bit grayscale
            PIL map.
    """

    # Image extensions are matched case-insensitively.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    # Map extensions are tried in this order; the first existing file wins.
    MAP_EXTENSIONS = ('.npy', '.jpg', '.png')

    def __init__(self, root_dir, transform=None, target_transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform

        self.image_dir = os.path.join(root_dir, 'images')
        self.map_dir = os.path.join(root_dir, 'meaning_maps')

        # Sorted so that indices are stable across runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(self.image_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_files)

    def _find_map_path(self, basename):
        """Return the path of the first existing map for ``basename``, or None."""
        for ext in self.MAP_EXTENSIONS:
            candidate = os.path.join(self.map_dir, basename + ext)
            if os.path.exists(candidate):
                return candidate
        return None

    @staticmethod
    def _load_map(map_path):
        """Load a meaning map from ``map_path`` as an 8-bit grayscale PIL image."""
        if map_path.endswith('.npy'):
            values = np.load(map_path)
            # Arrays are assumed to hold values in [0, 1]. Clip after scaling
            # so that slightly out-of-range values saturate instead of
            # wrapping around in the uint8 cast.
            scaled = np.clip(values * 255, 0, 255)
            return Image.fromarray(scaled.astype(np.uint8))
        with Image.open(map_path) as map_file:
            return map_file.convert('L')

    def __getitem__(self, idx):
        """Return ``(image, target_map)`` for index ``idx`` after the transforms."""
        img_name = self.image_files[idx]
        basename = os.path.splitext(img_name)[0]

        map_path = self._find_map_path(basename)
        if map_path is None:
            print(f"Warning: Map not found for {img_name}")

        # convert() returns a fully loaded copy, so the file can be closed.
        with Image.open(os.path.join(self.image_dir, img_name)) as img_file:
            image = img_file.convert('RGB')

        if map_path is None:
            # Fallback: all-zero map with the same size as the image.
            target_map = Image.new('L', image.size)
        else:
            target_map = self._load_map(map_path)

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            target_map = self.target_transform(target_map)

        return image, target_map

## 3. Training Loop
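Trains the model for 20 epochs (default), minimizing the MSE between the sigmoid of the U-Net output and the teacher meaning map. 90% of the pairs are used for training and 10% for validation; the weights with the lowest validation loss are saved to `unet_coco_best.pth`.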

In [None]:
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-sample loss.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch; without one it is put in eval mode and gradients are disabled.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    n_samples = 0
    with torch.set_grad_enabled(training):
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)

            if training:
                optimizer.zero_grad()
            # The network returns logits; the targets are maps in [0, 1].
            loss = criterion(torch.sigmoid(model(images)), targets)
            if training:
                loss.backward()
                optimizer.step()

            # Weight by batch size so a smaller last batch does not bias the mean.
            batch_n = images.size(0)
            loss_sum += loss.item() * batch_n
            n_samples += batch_n

    return loss_sum / max(n_samples, 1)


def train_model(data_dir='public_training_data', batch_size=8, num_epochs=20, lr=1e-4,
                save_path='unet_coco_best.pth', num_workers=2, seed=None):
    """Train the U-Net on (image, meaning map) pairs with an MSE loss.

    The dataset is split 90/10 into training and validation subsets. After
    every epoch the weights are written to ``save_path`` if the validation
    loss improved.

    Args:
        data_dir: root directory passed to ``COCODistillationDataset``.
        batch_size: batch size for both data loaders.
        num_epochs: number of passes over the training subset.
        lr: Adam learning rate.
        save_path: file that receives the best ``state_dict``.
        num_workers: worker processes per data loader.
        seed: if given, seeds the train/validation split so it is reproducible;
            ``None`` uses PyTorch's global random state.

    Returns:
        The model with the weights of the last epoch (the best weights are the
        ones stored in ``save_path``), or ``None`` if there is not enough data.
    """
    # Resize to 256x256 for training efficiency
    img_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    map_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    dataset = COCODistillationDataset(data_dir, transform=img_transform, target_transform=map_transform)

    if len(dataset) == 0:
        print("No data found! Please generate data first.")
        return

    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0:
        # With a single sample the 90/10 split leaves nothing to train on.
        print("Not enough data to build a training split! Please generate more data.")
        return

    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(dataset, [train_size, val_size], **split_kwargs)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    model = UNet(n_channels=3, n_classes=1).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        avg_train_loss = _run_epoch(model, progress, criterion, optimizer)
        avg_val_loss = _run_epoch(model, val_loader, criterion)

        print(f'Epoch {epoch+1}: Train Loss {avg_train_loss:.4f}, Val Loss {avg_val_loss:.4f}')

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            print("Saved best model.")

    return model

## 4. Evaluation on Henderson & Hayes Scenes
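Evaluates the trained U-Net on the 14 held-out scenes: for each scene, the predicted map is compared with the ground-truth attention map using the Pearson correlation, and the mean and standard deviation across scenes are reported to reproduce the correlation result.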

In [None]:
def _find_attention_map(attention_dir, scene_id):
    """Return the first existing ground-truth map path for ``scene_id``, or None."""
    candidates = (
        f"{scene_id}.png",
        f"{scene_id}.jpg",
        f"scene_{scene_id}.png",  # Common internal naming
    )
    for filename in candidates:
        path = os.path.join(attention_dir, filename)
        if os.path.exists(path):
            return path
    return None


def evaluate_on_test_scenes(model_path='unet_coco_best.pth', test_dir='scenes', attention_dir='attention_maps'):
    """Correlate U-Net predictions with ground-truth attention maps.

    For every ``*.jpg`` in ``test_dir`` the model predicts a map at 256x256,
    which is resized back to the image's original size and compared with the
    scene's attention map (looked up in ``attention_dir``) using the Pearson
    correlation over all pixels. Per-scene values and their mean/SD are
    printed.

    Scenes without a ground-truth map are reported and skipped. So are scenes
    where either map is constant: the Pearson correlation is undefined there
    and the resulting NaN would otherwise poison the reported mean.

    Args:
        model_path: checkpoint holding a ``UNet`` ``state_dict``.
        test_dir: directory with the test scene images.
        attention_dir: directory with the ground-truth attention maps.
    """
    model = UNet(n_channels=3, n_classes=1).to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((256, 256)),  # Inference size
        transforms.ToTensor(),
    ])

    test_images = sorted(glob(os.path.join(test_dir, '*.jpg')))

    correlations = []

    print(f"\nEvaluating on {len(test_images)} test scenes...")

    for img_path in test_images:
        scene_id = os.path.splitext(os.path.basename(img_path))[0]

        # Look the ground truth up first: without it there is nothing to score.
        gt_path = _find_attention_map(attention_dir, scene_id)
        if gt_path is None:
            print(f"{scene_id}: GT Attention Map not found in {attention_dir}")
            continue

        # convert() returns a fully loaded copy, so the file can be closed.
        with Image.open(img_path) as img_file:
            img = img_file.convert('RGB')
        original_size = img.size
        img_tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(img_tensor)
            pred_map = torch.sigmoid(output).squeeze().cpu().numpy()

        # Resize prediction back to original size for correlation
        pred_map_pil = Image.fromarray((pred_map * 255).astype(np.uint8))
        pred_map_resized = pred_map_pil.resize(original_size, Image.BICUBIC)
        pred_flat = np.array(pred_map_resized).flatten().astype(float)

        with Image.open(gt_path) as gt_file:
            gt_map = gt_file.convert('L').resize(original_size, Image.BICUBIC)
        gt_flat = np.array(gt_map).flatten().astype(float)

        # pearsonr returns NaN when one input has zero variance.
        if pred_flat.std() == 0 or gt_flat.std() == 0:
            print(f"{scene_id}: correlation undefined (constant map), skipped")
            continue

        corr, _ = pearsonr(pred_flat, gt_flat)
        correlations.append(corr)
        print(f"{scene_id}: r = {corr:.3f}")

    if correlations:
        print(f"\nMean Test Correlation: {np.mean(correlations):.3f} (SD={np.std(correlations):.3f})")
    else:
        print("\nNo scene could be scored; check test_dir and attention_dir.")