### Our saliency model based on ResNet-50 and linear probing
In this notebook we go through our model and explain it step by step.

We actually ran training and validation as two separate Python scripts (see the folder Train_and_Val). We wrote this notebook to walk through that code and explain each step of the model.

In [None]:
# Import the necessary libraries
import os
import json

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from scipy.stats import pearsonr
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

First, we load ResNet-50 as a feature extractor and remove its final classification layer.

In [None]:
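device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ImageNet-pretrained ResNet-50 as a feature extractor
# (torchvision >= 0.13 API; older versions use models.resnet50(pretrained=True))
backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
feature_dim = backbone.fc.in_features  # 2048 for ResNet-50
backbone = nn.Sequential(*list(backbone.children())[:-1])  # Remove final fc layer; output shape is (B, 2048, 1, 1)
backbone.to(device).eval()  # Move model to GPU/CPU and set to evaluation mode

# Linear probing: freeze the backbone so that only the probe is trained
for param in backbone.parameters():
    param.requires_grad = False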
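To make the shapes concrete: ResNet-50 downsamples its input by a factor of 32 and its last stage has 2048 channels, and the avgpool layer kept in the Sequential averages each channel over the spatial grid. A 256×256 input therefore gives an 8×8×2048 map before pooling and a (batch, 2048, 1, 1) tensor after it. The NumPy sketch below mimics only that final pooling step on random stand-in activations (the activations and helper names are made up for illustration; nothing here loads ResNet).

```python
import numpy as np

def global_average_pool(feature_maps):
    """Average each channel over its spatial grid, keeping singleton H and W dims.

    Mirrors the shape behaviour of ResNet's final avgpool (AdaptiveAvgPool2d(1)):
    (batch, channels, H, W) -> (batch, channels, 1, 1).
    """
    return feature_maps.mean(axis=(2, 3), keepdims=True)

def resnet_last_stage_size(input_size, total_stride=32):
    """Spatial size of ResNet-50's last stage for a square input (assumes input divisible by 32)."""
    return input_size // total_stride

# Stand-in activations: batch of 2, 2048 channels on the grid a 256x256 input produces
side = resnet_last_stage_size(256)
rng = np.random.default_rng(0)
activations = rng.standard_normal((2, 2048, side, side)).astype(np.float32)

pooled = global_average_pool(activations)
print(side, pooled.shape)  # 8 (2, 2048, 1, 1)
```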

Then we define our linear probing class, which maps the features extracted by ResNet to saliency maps.

In [None]:
class LinearProbe(nn.Module):
    def __init__(self, feature_dim, output_size=(256, 256)):
        """
        Initializes the LinearProbe module.

        Args:
            feature_dim (int): The dimensionality of the input feature vector.
            output_size (tuple): The desired output size of the saliency map (height, width).
        """
        super().__init__()
        self.output_size = output_size  # (height, width) of the predicted map
        # Fully connected layer mapping the feature vector to one value per output pixel
        self.fc = nn.Linear(feature_dim, output_size[0] * output_size[1])

    def forward(self, x):
        """
        Forward pass of the LinearProbe.

        Args:
            x (torch.Tensor): Input feature tensor of shape (batch_size, feature_dim, ...),
                e.g. (batch_size, 2048, 1, 1) from the ResNet-50 backbone.

        Returns:
            torch.Tensor: A saliency map of shape (batch_size, 1, height, width),
                (batch_size, 1, 256, 256) by default.
        """
        # Flatten the input feature map to (batch_size, feature_dim)
        x = x.view(x.size(0), -1)

        # Apply the fully connected layer to map features to a vector of size height * width
        x = self.fc(x)

        # Reshape the output to a saliency map of shape (batch_size, 1, height, width)
        return x.view(x.size(0), 1, *self.output_size)
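Mapping one 2048-d vector directly to every output pixel makes the probe large: its nn.Linear layer holds feature_dim × H × W weights plus H × W biases. The sketch below computes that count and reproduces the probe's flatten, affine map and reshape in NumPy on a deliberately tiny output size; the helper names and toy sizes are our own, not part of the training scripts.

```python
import numpy as np

def linear_probe_param_count(feature_dim, output_size):
    """Weights (feature_dim x H*W) plus biases (H*W) of the probe's nn.Linear layer."""
    n_pixels = output_size[0] * output_size[1]
    return feature_dim * n_pixels + n_pixels

def linear_probe_forward(features, weight, bias, output_size):
    """NumPy analogue of LinearProbe.forward: flatten, affine map, reshape to (B, 1, H, W)."""
    x = features.reshape(features.shape[0], -1)   # (B, C, 1, 1) -> (B, C)
    x = x @ weight.T + bias                        # nn.Linear computes x W^T + b
    return x.reshape(x.shape[0], 1, *output_size)

print(linear_probe_param_count(2048, (256, 256)))  # 134283264

# Toy sizes so the example runs instantly: 16 features, 4x4 output map
rng = np.random.default_rng(0)
feats = rng.standard_normal((3, 16, 1, 1))
W = rng.standard_normal((4 * 4, 16))
b = np.zeros(4 * 4)
out = linear_probe_forward(feats, W, b, (4, 4))
print(out.shape)  # (3, 1, 4, 4)
```

So the probe alone has roughly 134 million parameters, several times more than the roughly 23.5 million in ResNet-50's convolutional layers, which is worth keeping in mind for memory use and overfitting.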

After that, we combined the feature extractor and linear probing class into one class that we called SaliencyPredictor

In [None]:
class SaliencyPredictor(nn.Module):
    def __init__(self, backbone, feature_dim):
        """
        Initializes the SaliencyPredictor model.

        Args:
            backbone (nn.Module): A pre-trained backbone network (ResNet) 
                                to extract features from input images.
            feature_dim (int): The dimensionality of the feature vector output by the backbone.
        """
        super(SaliencyPredictor, self).__init__()
        self.backbone = backbone  # Backbone network for feature extraction
        self.probe = LinearProbe(feature_dim)  # Linear probe to map features to saliency predictions

    def forward(self, x):
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor representing an image or batch of images.

        Returns:
            torch.Tensor: A saliency map (or batch of saliency maps) with values in the range [0, 1].
        """
        features = self.backbone(x)  # Extract features using the backbone network
        return torch.sigmoid(self.probe(features))  # Apply the linear probe and sigmoid to get saliency map
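The final torch.sigmoid keeps every predicted pixel in [0, 1], the same range as the max-normalized ground-truth maps that the MSE loss compares against. A minimal NumPy sketch of the full composition, with a hypothetical pooling "backbone" and random probe weights standing in for the real networks:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def toy_backbone(images):
    """Hypothetical stand-in for ResNet-50: global-average-pool the RGB channels -> (B, 3, 1, 1)."""
    return images.mean(axis=(2, 3), keepdims=True)

def toy_saliency_predictor(images, weight, bias, output_size):
    features = toy_backbone(images)                       # feature extraction
    flat = features.reshape(features.shape[0], -1)        # (B, 3)
    logits = flat @ weight.T + bias                       # linear probe
    maps = logits.reshape(flat.shape[0], 1, *output_size)
    return sigmoid(maps)                                  # squash to [0, 1]

rng = np.random.default_rng(1)
images = rng.random((2, 3, 32, 32))
weight = rng.standard_normal((8 * 8, 3)) * 5.0
bias = rng.standard_normal(8 * 8)
pred = toy_saliency_predictor(images, weight, bias, (8, 8))
print(pred.shape, pred.min() >= 0.0, pred.max() <= 1.0)
```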

Now we go through the SALICON dataset and how we processed it.
SALICON provides:
- 10k training images
- 5k validation images
- 5k test images
- 2 annotation files: one for the training set and one for the validation set.

The SaliencyDataset class takes an image directory and its annotation file (train or val, since only these two splits come with annotation files) and returns the images with their corresponding saliency maps.
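The class reads only a few fields of the annotation JSON: "images" entries with id, file_name, height and width, and "annotations" entries with image_id and fixations. The code allows several annotation entries to share an image_id and concatenates their fixations. The toy dictionary below is a made-up example containing just those fields (not real SALICON data), and the loop mirrors the grouping done in SaliencyDataset.__init__:

```python
# Hypothetical miniature annotation file with only the fields SaliencyDataset reads
annotations = {
    "images": [
        {"id": 1, "file_name": "img_1.jpg", "height": 480, "width": 640},
        {"id": 2, "file_name": "img_2.jpg", "height": 480, "width": 640},
    ],
    "annotations": [
        {"image_id": 1, "fixations": [[100, 200], [120, 210]]},
        {"image_id": 1, "fixations": [[300, 50]]},   # a second entry for the same image
        {"image_id": 2, "fixations": [[240, 320]]},
    ],
}

# Group fixations by image ID, as in SaliencyDataset.__init__
fixations = {}
for ann in annotations["annotations"]:
    fixations.setdefault(ann["image_id"], []).extend(ann["fixations"])

# Map image IDs to (file name, height, width)
image_id_to_file = {
    img["id"]: (img["file_name"], img["height"], img["width"])
    for img in annotations["images"]
}

print(fixations[1])          # [[100, 200], [120, 210], [300, 50]]
print(image_id_to_file[2])   # ('img_2.jpg', 480, 640)
```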

The saliency maps are generated inside the SaliencyDataset class by our function points_to_saliency. It takes a list of fixation points ((row, col) coordinates) from the annotation files and generates a smooth saliency map of the specified size. The map highlights regions where fixations are concentrated, with Gaussian smoothing applied to turn individual points into a continuous distribution.
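Written out, the construction in points_to_saliency is as follows (ignoring cv2's kernel truncation and border handling). Each fixation $(r_i, c_i)$, given 1-indexed in an original image of size $H \times W$, is mapped to the 256 × 256 grid; the resulting counts are convolved with a normalized Gaussian of standard deviation $\sigma$ (10 pixels by default), and the result is divided by its maximum, so the Gaussian's normalizing constant cancels:

```latex
\hat r_i = \operatorname{clip}\!\left(\left\lfloor (r_i - 1)\,\tfrac{256}{H} \right\rfloor,\, 0,\, 255\right),
\qquad
\hat c_i = \operatorname{clip}\!\left(\left\lfloor (c_i - 1)\,\tfrac{256}{W} \right\rfloor,\, 0,\, 255\right)

G(u, v) = \sum_{i} \exp\!\left(-\frac{(u - \hat r_i)^2 + (v - \hat c_i)^2}{2\sigma^2}\right),
\qquad
S(u, v) = \frac{G(u, v)}{\max_{u', v'} G(u', v')}
```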

In [None]:
class SaliencyDataset(Dataset):
    def __init__(self, image_dir, annotation_file, transform=None):

        self.image_dir = image_dir  # Directory where images are stored
        self.transform = transform  # Optional transformations (e.g., resizing, normalization)
        
        # Load annotations from the JSON file
        with open(annotation_file, 'r') as f:
            self.annotations = json.load(f)
        
        # Organize fixations by image ID
        self.fixations = {}
        for ann in self.annotations['annotations']:
            image_id = ann['image_id']
            if image_id not in self.fixations:
                self.fixations[image_id] = []
            self.fixations[image_id].extend(ann['fixations'])  # Add fixations for the image
        
        # Map image IDs to file names and original dimensions
        self.image_id_to_file = {
            img['id']: (img['file_name'], img['height'], img['width']) 
            for img in self.annotations['images']
        }
        
        # List of all image IDs
        self.image_ids = list(self.fixations.keys())

    def __len__(self):
        """
        Returns the number of images in the dataset.
        """
        return len(self.image_ids)

    def __getitem__(self, idx):

        # Get image ID and corresponding file name and dimensions
        image_id = self.image_ids[idx]
        image_filename, original_height, original_width = self.image_id_to_file[image_id]
        
        # Load the image
        image_path = os.path.join(self.image_dir, image_filename)
        image = Image.open(image_path).convert("RGB")  # Ensure image is in RGB format
        
        # Get fixation points for the image
        points = self.fixations.get(image_id, [])
        
        # Convert fixation points to a saliency map
        saliency_map = points_to_saliency(
            points, 
            image_size=(256, 256), 
            original_image_height=original_height, 
            original_image_width=original_width
        )
        
        # Convert saliency map to a PIL image
        saliency_map = Image.fromarray((saliency_map * 255).astype(np.uint8))
        
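        # Apply transformations if specified (the same transform is applied to the
        # saliency map, so it must not contain image-only steps such as Normalize)
        if self.transform:
            image = self.transform(image)
            saliency_map = self.transform(saliency_map)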
        
        return image, saliency_map

# Convert Fixation Points to Saliency Map
def points_to_saliency(points, image_size=(256, 256), original_image_height=None, original_image_width=None, sigma=10):

    # Initialize an empty saliency map
    saliency_map = np.zeros(image_size, dtype=np.float32)
    
    # Scale fixation points (1-indexed (row, col) in the original image) to the target image size
    for (row, col) in points:
        row = int((row - 1) * (image_size[0] / original_image_height))
        col = int((col - 1) * (image_size[1] / original_image_width))
        
        # Ensure points are within bounds
        row = min(max(row, 0), image_size[0] - 1)
        col = min(max(col, 0), image_size[1] - 1)
        
        # Increment the saliency map at the fixation point
        saliency_map[row, col] += 1.0
    
    # Apply Gaussian smoothing to the saliency map
    saliency_map = cv2.GaussianBlur(saliency_map, (0, 0), sigmaX=sigma, sigmaY=sigma)
    
    # Normalize the saliency map to [0, 1]
    if saliency_map.max() > 0:
        saliency_map /= saliency_map.max()
    
    return saliency_map
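For readers without OpenCV, the sketch below reproduces the same steps with NumPy only: rescaling 1-indexed (row, col) fixations, accumulating counts, Gaussian smoothing and max-normalization. It swaps cv2.GaussianBlur for a separable convolution with a kernel truncated at 3σ and zero padding at the borders, so values near the image edges differ slightly from the OpenCV version; the function name points_to_saliency_numpy is hypothetical.

```python
import numpy as np

def points_to_saliency_numpy(points, image_size, original_h, original_w, sigma=10.0):
    """NumPy-only approximation of points_to_saliency (zero-padded separable Gaussian)."""
    counts = np.zeros(image_size, dtype=np.float64)
    for row, col in points:
        # Rescale 1-indexed coordinates to the target grid and clip to its bounds
        r = min(max(int((row - 1) * image_size[0] / original_h), 0), image_size[0] - 1)
        c = min(max(int((col - 1) * image_size[1] / original_w), 0), image_size[1] - 1)
        counts[r, c] += 1.0

    # 1-D normalized Gaussian kernel truncated at 3 sigma
    radius = int(np.ceil(3 * sigma))
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-(x ** 2) / (2 * sigma ** 2))
    kernel /= kernel.sum()

    # Separable blur: filter every row, then every column ("same" keeps the size)
    blurred = np.apply_along_axis(lambda v: np.convolve(v, kernel, mode="same"), 1, counts)
    blurred = np.apply_along_axis(lambda v: np.convolve(v, kernel, mode="same"), 0, blurred)

    # Normalize to [0, 1]
    if blurred.max() > 0:
        blurred /= blurred.max()
    return blurred

# One fixation at (row=241, col=321) in a 480x640 image lands at (128, 128) on a 256x256 grid
smap = points_to_saliency_numpy([(241, 321)], (256, 256), 480, 640)
print(np.unravel_index(smap.argmax(), smap.shape), smap.max())
```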


### Training: defining the training loop
- 10 epochs
- Batch size = 32
- Learning rate = 0.0001
- Optimizer: Adaptive Moment Estimation (Adam)
- Criterion: Mean Squared Error (MSE) loss
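For reference, the criterion and optimizer listed above fit in a few lines. This NumPy sketch implements the MSE loss and the Adam update rule with PyTorch's default betas (0.9, 0.999) and eps (1e-8), then applies them to a toy least-squares problem; the toy data, the larger learning rate used in the demo and the step count are illustrative assumptions, not values from our training run.

```python
import numpy as np

def mse(pred, target):
    """Mean squared error, as computed by nn.MSELoss with its default 'mean' reduction."""
    return float(np.mean((pred - target) ** 2))

def adam_step(param, grad, m, v, t, lr=1e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based step count used for bias correction."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad ** 2     # second-moment (uncentred variance) estimate
    m_hat = m / (1 - beta1 ** t)                # bias-corrected moments
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

# Toy problem: recover w_true from y = X @ w_true
rng = np.random.default_rng(0)
X = rng.standard_normal((64, 5))
w_true = rng.standard_normal(5)
y = X @ w_true

w = np.zeros(5)
m, v = np.zeros_like(w), np.zeros_like(w)
loss_start = mse(X @ w, y)
for t in range(1, 501):
    grad = 2.0 / len(y) * X.T @ (X @ w - y)     # gradient of the MSE with respect to w
    w, m, v = adam_step(w, grad, m, v, t, lr=0.05)
loss_end = mse(X @ w, y)
print(f"MSE before: {loss_start:.4f}, after 500 Adam steps: {loss_end:.6f}")
```

With 10,000 training images and a batch size of 32, one epoch is ceil(10000 / 32) = 313 optimizer steps (the DataLoader keeps the final partial batch of 16 images by default), so 10 epochs amount to 3,130 Adam updates, assuming every training image has fixations.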

We also applied transformations to the images and to the target saliency maps (see the transform variable): resizing to 256×256 and conversion to tensors.
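Both transform steps also affect the targets, because the same transform is applied to the saliency maps: Resize brings every image and map to 256×256, and ToTensor turns an H×W×C uint8 PIL image into a C×H×W float tensor scaled to [0, 1], so a single-channel saliency map becomes 1×256×256. The NumPy sketch below imitates only the ToTensor conversion on fake arrays; the helper to_tensor_like is hypothetical and illustrates the documented behaviour rather than torchvision's code.

```python
import numpy as np

def to_tensor_like(array_uint8):
    """Imitate torchvision's ToTensor for uint8 input: HWC (or HW) -> CHW float32 in [0, 1]."""
    if array_uint8.ndim == 2:                  # grayscale, e.g. a saliency map in mode "L"
        array_uint8 = array_uint8[:, :, None]
    return array_uint8.transpose(2, 0, 1).astype(np.float32) / 255.0

rgb = np.full((256, 256, 3), 255, dtype=np.uint8)   # fake white RGB image
saliency = np.zeros((256, 256), dtype=np.uint8)
saliency[128, 128] = 255                             # fake map with one peak

print(to_tensor_like(rgb).shape, to_tensor_like(saliency).shape)  # (3, 256, 256) (1, 256, 256)
```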

In [None]:
# Training Function
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    model.backbone.eval()  # Linear probing: keep the backbone's BatchNorm statistics fixed
    running_loss = 0.0
    for images, saliency_maps in dataloader:
        images, saliency_maps = images.to(device), saliency_maps.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, saliency_maps)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(dataloader)

# Main Script
if __name__ == "__main__":
    batch_size = 32
    learning_rate = 0.0001
    num_epochs = 10
    image_dir = "../train"  #Define the path to the train dataset 
    annotation_file = ".../fixations_train2014.json" # Define the path to the annotation file

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    dataset = SaliencyDataset(image_dir, annotation_file, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

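    model = SaliencyPredictor(backbone, feature_dim).to(device)
    criterion = nn.MSELoss()
    # Adam optimizer over the linear probe's parameters (the backbone is not updated)
    optimizer = optim.Adam(model.probe.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        epoch_loss = train(model, dataloader, criterion, optimizer, device)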
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    # ✅ Save Trained Model
    torch.save(model.state_dict(), "/resnet_saliency.pth") # Define a path and save the model 
    print("✅ Model saved!") 


### Validation: defining the validation loop
We run the model on the 5k validation images, save its outputs (the predicted saliency maps), and save one figure per validation image (original image + ground-truth heatmap + predicted heatmap) to a folder.

We also compute metrics such as MSE loss, Pearson correlation and AUC, and save them to a file so that we can analyse the model's performance (see the notebook Performance_analysis).
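The two per-image metrics beyond MSE can be computed without SciPy or scikit-learn, which also makes their definitions explicit. In this sketch, pearson_cc is the linear correlation coefficient between the flattened maps (the first value scipy.stats.pearsonr returns), and rank_auc computes the ROC AUC through the Mann-Whitney rank formula with average ranks for ties, which gives the same value as sklearn's roc_auc_score. Positives are ground-truth pixels above 0.5, matching the threshold in our validation code. Both return NaN where the metric is undefined; the function names are our own.

```python
import numpy as np

def pearson_cc(pred, target):
    """Pearson correlation between two maps; NaN if either map is constant."""
    a = np.ravel(pred).astype(float) - np.mean(pred)
    b = np.ravel(target).astype(float) - np.mean(target)
    denom = np.sqrt(np.sum(a * a) * np.sum(b * b))
    return float(np.sum(a * b) / denom) if denom > 0 else float("nan")

def average_ranks(x):
    """1-based ranks of x, with tied values sharing the mean of their ranks."""
    order = np.argsort(x, kind="mergesort")
    sorted_x = x[order]
    ranks = np.empty(len(x), dtype=float)
    i = 0
    while i < len(x):
        j = i
        while j + 1 < len(x) and sorted_x[j + 1] == sorted_x[i]:
            j += 1
        ranks[order[i:j + 1]] = (i + j) / 2.0 + 1.0
        i = j + 1
    return ranks

def rank_auc(target, pred, threshold=0.5):
    """ROC AUC of pred against target > threshold, via the Mann-Whitney U statistic."""
    labels = np.ravel(target) > threshold
    scores = np.ravel(pred).astype(float)
    n_pos, n_neg = labels.sum(), (~labels).sum()
    if n_pos == 0 or n_neg == 0:
        return float("nan")                     # AUC is undefined with a single class
    ranks = average_ranks(scores)
    u = ranks[labels].sum() - n_pos * (n_pos + 1) / 2.0
    return float(u / (n_pos * n_neg))

print(rank_auc(np.array([0.0, 0.0, 1.0, 1.0]), np.array([0.1, 0.4, 0.35, 0.8])))  # 0.75
print(pearson_cc(np.array([1.0, 2.0, 3.0]), np.array([2.0, 4.0, 6.0])))
```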

In [None]:
# Define the validation data paths and load the data

val_image_dir = "/Users/nouira/Desktop/deeplearning/project/val"
val_annotation_file = "/Users/nouira/Desktop/deeplearning/project/fixations_val2014.json"

# Same preprocessing as in training (resize to 256x256, convert to tensor)
val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

val_dataset = SaliencyDataset(val_image_dir, val_annotation_file, transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4, pin_memory=True)

model.eval()  # Put the model in evaluation mode

In [None]:
#  Evaluation Metrics
mse_loss = torch.nn.MSELoss()
mse_total = 0
pearson_corrs = [] # a list to save the pearson correlation values of every image
auc_scores = [] # a list to save the AUC values of every image

# Directory to Save Visualizations
output_dir = ".../Predicted_saliency_maps" # Define the output folder where the predicted saliency maps from validation are saved
os.makedirs(output_dir, exist_ok=True)

# Dictionary to store metrics for each image
image_metrics = {}

The validation loop:

In [None]:
# ✅ Run Model on Validation Set
print("🔹 Running validation...")
with torch.no_grad():
    # SaliencyDataset returns (image, saliency_map) pairs, so image IDs are recovered
    # from val_dataset.image_ids; this relies on shuffle=False in val_loader
    for batch_idx, (images, targets) in enumerate(val_loader):
        images, targets = images.to(device), targets.to(device)

        # ✅ Use full model instead of just the probe
        predictions = model(images)

        # Compute MSE Loss
        mse = mse_loss(predictions, targets)
        mse_total += mse.item()

        # ✅ Save Visualizations and Compute Metrics for Each Image
        for i in range(images.size(0)):
            # Get the image ID
            image_id = val_dataset.image_ids[batch_idx * val_loader.batch_size + i]

            # Extract the ground truth and predicted saliency maps
            gt_map = targets[i].cpu().squeeze().numpy()
            pred_map = predictions[i].cpu().squeeze().numpy()

            # Compute Pearson Correlation (undefined if either map is constant)
            pred_flat = pred_map.flatten()
            target_flat = gt_map.flatten()
            if target_flat.std() > 0 and pred_flat.std() > 0:
                pearson_corr, _ = pearsonr(pred_flat, target_flat)
            else:
                pearson_corr = np.nan

            # Compute AUC Score (ground truth binarized at 0.5; undefined if only one class is present)
            labels = (target_flat > 0.5).astype(int)
            if 0 < labels.sum() < labels.size:
                auc = roc_auc_score(labels, pred_flat)
            else:
                auc = np.nan

            # Store metrics in the dictionary
            image_metrics[image_id] = {
                "mse": float(np.mean((pred_flat - target_flat) ** 2)),  # MSE for this image
                "pearson_corr": float(pearson_corr),  # Pearson Correlation for this image
                "auc": float(auc),  # AUC for this image
            }

            # Save Visualization
            img = images[i].cpu().permute(1, 2, 0).numpy()
            fig, axes = plt.subplots(1, 3, figsize=(12, 4))
            axes[0].imshow(img)
            axes[0].set_title("Input Image")
            axes[0].axis("off")

            axes[1].imshow(gt_map, cmap="jet")
            axes[1].set_title("Ground Truth Saliency")
            axes[1].axis("off")

            axes[2].imshow(pred_map, cmap="jet")
            axes[2].set_title("Predicted Saliency")
            axes[2].axis("off")

            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, f"saliency_{image_id}.png"))
            plt.close()


Calculate the final average metrics and save the results to a file:

In [None]:
# ✅ Compute Final Metrics (averages)
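avg_mse = mse_total / len(val_loader)  # Mean of the per-batch MSE values
# np.nanmean skips images whose Pearson correlation or AUC is undefined (NaN)
avg_pearson = float(np.nanmean([metrics["pearson_corr"] for metrics in image_metrics.values()]))
avg_auc = float(np.nanmean([metrics["auc"] for metrics in image_metrics.values()]))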

# ✅ Save Results to File
results_path = "/../validation_results_resnet.json" # Define a path 
with open(results_path, "w") as f:
    json.dump({
        "average_metrics": {
            "mse": avg_mse,
            "pearson_corr": avg_pearson,
            "auc": avg_auc,
        },
        "image_metrics": image_metrics,  # Metrics for each image
    }, f, indent=4)

print(f"✅ Validation results saved to {results_path}")
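Because some per-image Pearson or AUC values can be NaN (constant or single-class maps), two details matter when aggregating and saving: a plain np.mean returns NaN as soon as one value is NaN, whereas np.nanmean skips them, and Python's json module writes NaN as the bare token NaN by default, which strict JSON parsers reject. A small sketch on made-up metric values shows both, plus one way (our assumption, not what the script above does) to store NaN as null instead; nan_to_none is a hypothetical helper.

```python
import json
import math
import numpy as np

# Made-up per-image metrics; image 2 has an undefined AUC
image_metrics = {
    1: {"pearson_corr": 0.5, "auc": 0.8},
    2: {"pearson_corr": 0.7, "auc": float("nan")},
}

aucs = [m["auc"] for m in image_metrics.values()]
print(np.mean(aucs), np.nanmean(aucs))   # nan 0.8

print(json.dumps(image_metrics[2]))      # {"pearson_corr": 0.7, "auc": NaN}

def nan_to_none(obj):
    """Recursively replace float NaN with None so json.dumps emits standard null."""
    if isinstance(obj, dict):
        return {k: nan_to_none(v) for k, v in obj.items()}
    if isinstance(obj, float) and math.isnan(obj):
        return None
    return obj

# allow_nan=False raises ValueError if any NaN were left, so this confirms valid JSON
print(json.dumps(nan_to_none(image_metrics), allow_nan=False))
```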