In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import json
import os
import os
import json
import torch
from torchvision import transforms
from PIL import Image  # Import PIL for image loading
import numpy as np
from torch.utils.data import Dataset, DataLoader

import ssl

# NOTE(review): this monkey-patch turns off TLS certificate verification process-wide for every
# HTTPS connection that relies on ssl's default context (e.g. urllib.request, which torch.hub /
# torchvision downloads use) -- not only the ImageNet-label download it was added for. Nothing in
# this section downloads anything; prefer repairing the local certificate store (on python.org
# macOS builds: run "Install Certificates.command") and then deleting this line.
ssl._create_default_https_context = ssl._create_unverified_context

# Choose the compute device in order of preference: CUDA, then Apple MPS, then plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# Location of the pre-extracted feature tensors and the SALICON fixation annotations.
# Defaults to the author's local project folder; set the SALICON_DATA_DIR environment
# variable to run the notebook from another machine or directory.
data_root = os.environ.get("SALICON_DATA_DIR", "/Users/nouira/Desktop/deeplearning/project")

# Despite the *_dir names (kept so later cells keep working), these are file paths.
train_dir = os.path.join(data_root, "train_features.pt")
val_dir = os.path.join(data_root, "val_features.pt")
test_dir = os.path.join(data_root, "test_features.pt")

train_annotations_file = os.path.join(data_root, "fixations_train2014.json")
val_annotations_file = os.path.join(data_root, "fixations_val2014.json")

# Load onto the CPU whatever device the tensors were saved from:
# - a CUDA-saved file otherwise fails to deserialize on a machine without CUDA;
# - DataLoader(pin_memory=True) can only pin CPU tensors.
# Batches are moved to `device` inside the training loop.
train_features = torch.load(train_dir, map_location="cpu")
val_features = torch.load(val_dir, map_location="cpu")
test_features = torch.load(test_dir, map_location="cpu")


# Dataset Class for loading features and corresponding heatmaps
class SaliconDataset(Dataset):
    """Pairs pre-extracted image features with fixation-count heatmaps.

    Args:
        features: indexable collection of per-image features (a tensor in this notebook).
            Row ``i`` is assumed to belong to ``annotations['images'][i]`` -- NOTE(review):
            confirm the feature-extraction step preserved that order.
        annotations_file: path to a JSON file with an ``images`` list (each entry has an
            ``id``) and an ``annotations`` list (each entry has ``image_id`` and
            ``fixations``, a list of ``[row, col]`` pairs treated as 1-indexed).
        transform: optional callable applied to the feature before it is returned.
        heatmap_size: side length of the square target heatmap (default 224).

    Each item is ``(features, heatmap)``. ``heatmap`` is a float tensor of shape
    ``(heatmap_size, heatmap_size)``. Each cell holds the number of fixations that landed
    there, summed over every annotation that refers to the image.
    """

    def __init__(self, features, annotations_file, transform=None, heatmap_size=224):
        self.features = features  # pre-extracted features
        self.transform = transform
        self.heatmap_size = heatmap_size

        with open(annotations_file, 'r') as f:
            self.annotations = json.load(f)

        images = self.annotations['images']
        self.image_ids = [img['id'] for img in images]
        self.img_id_to_index = {img_id: idx for idx, img_id in enumerate(self.image_ids)}

        # Source-image (height, width), when the records provide COCO-style 'height'/'width'
        # keys. It is used to rescale fixations into the heatmap grid. Without it, every
        # fixation beyond the grid would be clamped onto the last row/column.
        self.img_id_to_size = {
            img['id']: (img['height'], img['width'])
            for img in images
            if 'height' in img and 'width' in img
        }

        # Index fixations by image once, instead of a linear scan of all annotations on every
        # __getitem__ call. Fixations from every annotation of an image are combined, rather
        # than keeping only the first matching annotation.
        self.img_id_to_fixations = {}
        for ann in self.annotations['annotations']:
            self.img_id_to_fixations.setdefault(ann['image_id'], []).extend(ann['fixations'])

    def __len__(self):
        return len(self.image_ids)

    def _to_cell(self, coord, extent):
        """Map one 1-indexed fixation coordinate to a heatmap index in [0, heatmap_size - 1].

        When ``extent`` (the source image size along this axis) is known, the coordinate is
        rescaled into the grid. Otherwise it is only shifted to 0-based and clamped.
        """
        size = self.heatmap_size
        cell = int(coord) - 1
        if extent:
            cell = cell * size // extent
        return min(max(cell, 0), size - 1)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        fixations = self.img_id_to_fixations.get(img_id)
        if fixations is None:
            # Raise KeyError explicitly. A StopIteration escaping __getitem__ (as next() on an
            # exhausted generator raises) would propagate out of the DataLoader iterator's
            # __next__ and end the epoch early without any error.
            raise KeyError(f"No fixation annotations for image id {img_id}")

        features = self.features[self.img_id_to_index[img_id]]

        height, width = self.img_id_to_size.get(img_id, (None, None))
        heatmap = torch.zeros((self.heatmap_size, self.heatmap_size))
        for row, col in fixations:
            heatmap[self._to_cell(row, height), self._to_cell(col, width)] += 1

        if self.transform:
            features = self.transform(features)

        return features, heatmap


# Model: LinearProbe (outputting heatmap of size 224x224)
class LinearProbe(nn.Module):
    """Single linear layer mapping a flattened feature sample to a square heatmap.

    Args:
        input_size: number of elements in one feature sample after flattening.
        output_size: side length of the predicted square heatmap. The default of 224 matches
            the 224x224 heatmaps produced by SaliconDataset.
    """

    def __init__(self, input_size, output_size=224):
        super().__init__()
        self.output_size = output_size
        self.fc = nn.Linear(input_size, output_size * output_size)

    def forward(self, x):
        """Map a batch ``(B, ...)`` of features to heatmaps ``(B, output_size, output_size)``."""
        x = torch.flatten(x, start_dim=1)  # (B, input_size)
        x = self.fc(x)  # (B, output_size * output_size)
        return x.view(x.size(0), self.output_size, self.output_size)


In [None]:
# Seed torch's global RNG so the probe's weight initialisation and the training shuffle order
# are reproducible on Restart & Run All.
torch.manual_seed(0)

# Size of one flattened feature sample, read from the loaded features instead of being
# hard-coded (the features this notebook was written for have 512). LinearProbe.forward
# flattens each sample the same way.
input_size = train_features[0].numel()
model = LinearProbe(input_size).to(device)

# Loss function and optimizer
criterion = nn.MSELoss()  # Mean Squared Error loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Create dataset and dataloaders for training and validation
train_dataset = SaliconDataset(train_features, train_annotations_file)
val_dataset = SaliconDataset(val_features, val_annotations_file)

# Pinned host memory only speeds up host-to-CUDA copies. On MPS or CPU it brings no benefit,
# and recent PyTorch versions warn about it, so enable it only when training on CUDA.
batch_size = 16
use_pin_memory = device == "cuda"
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=use_pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=use_pin_memory)

In [None]:
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
import numpy as np



# Training loop with progress percentage.
# Each epoch runs one optimisation pass over train_loader, then a gradient-free pass over
# val_loader, so over-fitting is visible.
# Uses model, criterion, optimizer, train_loader, val_loader and device from earlier cells.
num_epochs = 10
num_batches = len(train_loader)
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch_idx, (features, heatmap) in enumerate(train_loader):
        features = features.to(device)
        heatmap = heatmap.to(device)

        optimizer.zero_grad()

        # Forward pass, then MSE between predicted and target heatmaps
        outputs = model(features)
        loss = criterion(outputs, heatmap)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # .item() copies the scalar back to the host (waiting for the device on GPU);
        # call it once per batch and reuse the value.
        batch_loss = loss.item()
        total_loss += batch_loss

        # Calculate and display the progress percentage
        progress = (batch_idx + 1) / num_batches * 100
        print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{num_batches} - Progress: {progress:.2f}% - Loss: {batch_loss:.4f}", end='\r')

    # max(..., 1) avoids ZeroDivisionError if the loader yields no batches.
    print(f"\nEpoch {epoch+1}/{num_epochs}, Average Loss: {total_loss/max(num_batches, 1):.4f}")

    # Validation pass: eval mode, no gradient tracking.
    model.eval()
    val_total = 0.0
    with torch.no_grad():
        for features, heatmap in val_loader:
            outputs = model(features.to(device))
            val_total += criterion(outputs, heatmap.to(device)).item()
    print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_total/max(len(val_loader), 1):.4f}")