# Import required libraries

In [1]:
import os
import json
import torch
import requests
import ssl
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from scipy.ndimage import gaussian_filter

In [2]:
# NOTE(review): this disables HTTPS certificate verification for the WHOLE process,
# not just one request. It is presumably a workaround for a local certificate-store
# problem; the only download visible in this notebook is the pretrained ResNet18
# weights. Prefer fixing the local certificates and deleting this line.
ssl._create_default_https_context = ssl._create_unverified_context

# Pick the best available backend: CUDA GPU, then Apple-silicon MPS, then CPU.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

# Directory with the SALICON test images. Each collaborator can point this at their
# own copy by setting the SALICON_TEST_DIR environment variable instead of editing
# (or commenting/uncommenting) hard-coded absolute paths in the notebook.
test_images_dir = os.environ.get(
    "SALICON_TEST_DIR", "/Users/nouira/Desktop/deeplearning/project/test"
)

# Resize every image to a fixed 256x256 and convert it to a float tensor in [0, 1].
# NOTE(review): no ImageNet mean/std normalization is applied here, although the
# backbone loaded later is ImageNet-pretrained. Confirm this is intentional and that
# it matches how any training features were extracted before changing it.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

Using device: mps


## Define dataset class


In [3]:

import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class SaliconTestDataset(Dataset):
    """Unlabelled SALICON test split: yields one (optionally transformed) RGB image per file.

    Args:
        image_dir: Directory containing the test images.
        transform: Optional callable applied to each PIL image, e.g. a
            ``torchvision.transforms.Compose`` pipeline.
        extensions: File suffixes (case-insensitive) treated as images. Anything
            else in ``image_dir`` is ignored, including hidden entries such as
            ``.DS_Store`` or macOS ``._*`` AppleDouble files, sub-directories and
            stray notes. This keeps ``__getitem__`` from trying to decode a
            non-image file.
    """

    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff", ".webp")

    def __init__(self, image_dir, transform=None, extensions=IMAGE_EXTENSIONS):
        self.image_dir = image_dir
        self.transform = transform
        suffixes = tuple(ext.lower() for ext in extensions)
        # Sorted so that index -> filename is deterministic across runs and machines.
        self.image_files = sorted(
            name
            for name in os.listdir(image_dir)
            if not name.startswith(".")
            and name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        # The context manager closes the underlying file once the pixels are
        # decoded; convert() returns a new image that does not need the file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image  # No heatmap, just the image


# Load dataset

In [6]:
test_dataset = SaliconTestDataset(test_images_dir, transform=transform)
# shuffle=False keeps batches in the same order as test_dataset.image_files, so any
# per-batch predictions can be matched back to their filenames and runs are
# reproducible (shuffling a test set buys nothing and no seed is set).
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# ImageNet-style preprocessing: resize, 224x224 center crop, tensor conversion and
# ImageNet mean/std normalization.
# NOTE(review): not used by any cell visible in this notebook. test_dataset above
# uses `transform`, which does not normalize. Either wire this in (consistently for
# train and test features) or remove it.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])



# Load pretrained ResNet model


In [7]:
# ImageNet-pretrained ResNet18 used as a frozen feature extractor: keep every child
# module except the last one (the final fully-connected classifier layer).
model = torch.nn.Sequential(
    *list(models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).children())[:-1]
)
model = model.to(device)
# Put layers such as BatchNorm into inference mode. Because this is the cell's last
# expression, it also displays the architecture.
model.eval()

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con

# Extract features


In [10]:
def extract_features(dataset, batch_size=2, device=device, feature_extractor=None):
    """Run every item of ``dataset`` through a frozen backbone and stack the flattened outputs.

    Args:
        dataset: Dataset whose items are image tensors with no labels, e.g.
            ``SaliconTestDataset``. Items are batched by a ``DataLoader``.
        batch_size: Number of images per forward pass.
        device: Device the images are moved to; it must match the device the
            backbone lives on. Defaults to the notebook-level ``device`` as it
            was when this cell was run.
        feature_extractor: Module to run. When ``None``, the notebook-level
            ``model`` is used, looked up at call time.

    Returns:
        CPU tensor with one row per dataset item, in dataset order. Each row is
        the backbone output for that image flattened to 1-D. This assumes the
        backbone preserves the batch dimension.

    Raises:
        ValueError: If the dataset yields no items, e.g. ``test_images_dir``
            points at an empty folder. Without this check, ``torch.cat`` would
            fail with a less helpful error.

    NOTE(review): no normalization happens here. Images reach the backbone exactly
    as the dataset's transform returns them.
    """
    net = model if feature_extractor is None else feature_extractor
    # shuffle=False keeps output rows aligned with dataset indices.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    features = []
    with torch.no_grad():
        for images in tqdm(dataloader, desc="Extracting features"):
            images = images.to(device)  # DataLoader already adds the batch dimension
            # flatten(start_dim=1) matches view(batch, -1) but also tolerates
            # non-contiguous backbone outputs.
            features.append(net(images).flatten(start_dim=1).cpu())
    if not features:
        raise ValueError("extract_features: dataset is empty, nothing to extract")
    return torch.cat(features, dim=0)



# Extract features once for the whole test set and cache them on disk. They can be
# reloaded with torch.load("test_features.pt") instead of being recomputed.
features = extract_features(test_dataset)
torch.save(features, "test_features.pt")
print("Extracted features shape: {}".format(features.shape))


Extracting features: 100%|██████████████████| 2500/2500 [01:11<00:00, 35.05it/s]


Extracted features shape: torch.Size([5000, 512])


In [11]:
# Sanity check: extract_features should produce exactly one feature row per test image.
assert len(features) == len(test_dataset), "feature count does not match the number of test images"
print(len(features))


5000
