In [133]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
from torchvision import transforms
from torch.nn.functional import normalize
from torchvision.models import resnet50, ResNet50_Weights


In [134]:
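# Superseded draft of ProjectionHead, kept for reference only; the active definition is in the next cell.
# class ProjectionHead(nn.Module):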
#     def __init__(self, input_dim, hidden_dim, output_dim):
#         super(ProjectionHead, self).__init__()
#         self.fc1 = nn.Linear(input_dim, hidden_dim)
#         self.fc2 = nn.Linear(hidden_dim, output_dim)
#         self.relu = nn.ReLU()
#         self.bn = nn.BatchNorm1d(hidden_dim)

#     def forward(self, x):
#         x = self.fc1(x)
#         x = self.bn(x)
#         x = self.relu(x)
#         x = self.fc2(x)
#         return x

In [135]:

class ProjectionHead(nn.Module):
    """MLP that maps backbone features to L2-normalised embeddings.

    Architecture: two hidden stages of (Linear -> BatchNorm1d -> ReLU -> Dropout),
    then a final Linear layer. The output is L2-normalised along dim 1, so every row
    has unit length.

    Args:
        input_dim: size of the incoming feature vector (last dim of the input).
        hidden_dim: width of both hidden stages.
        output_dim: size of the output embedding.
        dropout_rate: drop probability applied after each hidden stage.

    Note: expects a 2-D ``(batch, input_dim)`` input. In training mode BatchNorm1d
    needs more than one sample per batch.
    """

    def __init__(self, input_dim=2048, hidden_dim=512, output_dim=128, dropout_rate=0.1):
        super().__init__()

        # Attribute names and registration order are deliberately kept fixed:
        # they define the state_dict keys used by saved checkpoints.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU(inplace=True)        # shared by both hidden stages
        self.dropout = nn.Dropout(dropout_rate)  # shared by both hidden stages

        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)

        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden_stages = ((self.fc1, self.bn1), (self.fc2, self.bn2))
        for linear, batch_norm in hidden_stages:
            x = self.dropout(self.relu(batch_norm(linear(x))))

        # Final projection, then place each embedding on the unit hypersphere.
        return nn.functional.normalize(self.fc3(x), p=2, dim=1)

In [136]:
class ImgEncoder_CNN(nn.Module):
    """Image encoder: ImageNet-pretrained ResNet-50 features -> ProjectionHead.

    Args:
        projection_dim: size of the output embedding.
        hidden_dim: width of the projection head's hidden layers (default 256,
            matching the previously hard-coded value).
        freeze_base: if True (default), the ResNet-50 parameters are frozen AND its
            BatchNorm layers are kept in eval mode. That way ``model.train()`` does not
            switch them to batch statistics or update their pretrained running stats.

    Returns (forward): L2-normalised embeddings of shape ``(batch, projection_dim)``.
    """

    def __init__(self, projection_dim=128, hidden_dim=256, freeze_base=True):
        super().__init__()

        base_model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        # Drop ResNet-50's final classification (fc) layer; global average pooling is kept.
        self.base_model = nn.Sequential(*list(base_model.children())[:-1])

        self.freeze_base = freeze_base
        if freeze_base:
            # Freeze the parameters of the base model.
            for param in self.base_model.parameters():
                param.requires_grad = False
            # requires_grad=False does not stop BatchNorm running-stat updates;
            # keep the frozen backbone in inference mode.
            self.base_model.eval()

        # Projection head on top of the 2048-d ResNet-50 pooled features.
        self.projection_head = ProjectionHead(2048, hidden_dim, projection_dim)

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen backbone permanently in eval mode."""
        super().train(mode)
        if getattr(self, "freeze_base", True):
            self.base_model.eval()
        return self

    def forward(self, x):
        # (batch, 2048, 1, 1) -> (batch, 2048). Unlike squeeze(), flatten keeps the
        # batch dimension when batch size is 1.
        h = torch.flatten(self.base_model(x), start_dim=1)

        # Pass through the projection head: (batch, projection_dim).
        z = self.projection_head(h)

        # ProjectionHead already L2-normalises; renormalising is kept for safety
        # and leaves unit-length rows unchanged.
        return normalize(z, dim=1)

In [137]:
# Configuration for this run.
SEED = 42
PROJECTION_DIM = 512  # size of the output embedding produced by ImgEncoder_CNN

# Seed before building the model so the projection head's random initialisation
# is reproducible under Restart & Run All.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ImgEncoder_CNN(projection_dim=PROJECTION_DIM).to(device)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 156MB/s]


In [138]:
from torchinfo import summary

# Layer-by-layer parameter table for the encoder; parameter counts shown in
# parentheses are the non-trainable ones (here, the frozen ResNet-50 backbone).
summary(model=model)


Layer (type:depth-idx)                        Param #
ImgEncoder_CNN                                --
├─Sequential: 1-1                             --
│    └─Conv2d: 2-1                            (9,408)
│    └─BatchNorm2d: 2-2                       (128)
│    └─ReLU: 2-3                              --
│    └─MaxPool2d: 2-4                         --
│    └─Sequential: 2-5                        --
│    │    └─Bottleneck: 3-1                   (75,008)
│    │    └─Bottleneck: 3-2                   (70,400)
│    │    └─Bottleneck: 3-3                   (70,400)
│    └─Sequential: 2-6                        --
│    │    └─Bottleneck: 3-4                   (379,392)
│    │    └─Bottleneck: 3-5                   (280,064)
│    │    └─Bottleneck: 3-6                   (280,064)
│    │    └─Bottleneck: 3-7                   (280,064)
│    └─Sequential: 2-7                        --
│    │    └─Bottleneck: 3-8                   (1,512,448)
│    │    └─Bottleneck: 3-9                   (1,1