In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision.models.resnet import ResNet50_Weights

In [2]:
class ImgEncoder_CNN(nn.Module):
    """Image encoder: frozen ImageNet ResNet50 backbone + trainable MLP projection head.

    The backbone (ResNet50 without its final fully connected layer) is used as a
    fixed feature extractor. Only the projection head has trainable parameters.
    The output embeddings are L2-normalized along the feature dimension.

    Args:
        projection_dim: Size of the output embedding.
        hidden_dim: Width of the two hidden layers of the projection head.
        dropout_rate: Dropout probability applied after each hidden layer.
    """

    def __init__(self, projection_dim=512, hidden_dim=256, dropout_rate=0.1):
        super().__init__()

        # Load the pre-trained ResNet50 model
        resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

        # Width of the pooled feature vector (2048 for ResNet50), read from the
        # classifier layer that is about to be dropped instead of hard-coding it.
        feature_dim = resnet.fc.in_features

        # Keep every child module except the final fully connected layer
        self.base_model = nn.Sequential(*list(resnet.children())[:-1])

        # Freeze the backbone. requires_grad=False only stops gradient updates;
        # BatchNorm layers would still update their running statistics in
        # training mode, so the backbone is also switched to eval mode.
        for param in self.base_model.parameters():
            param.requires_grad = False
        self.base_model.eval()

        # Define the projection head
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),

            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),

            nn.Linear(hidden_dim, projection_dim)  # Output dim is projection_dim
        )

    def train(self, mode=True):
        """Set the training mode of the projection head; the backbone stays in eval.

        Keeping the frozen backbone in eval mode prevents its BatchNorm layers
        from overwriting the pre-trained running statistics during training.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            self, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.base_model.eval()
        return self

    def forward(self, x):
        """Encode a batch of images into L2-normalized embeddings.

        Args:
            x: Image batch accepted by the ResNet50 backbone
                (e.g. ``[batch_size, 3, 224, 224]``).

        Returns:
            Tensor of shape ``[batch_size, projection_dim]`` with unit L2 norm
            along dim 1.
        """
        # Extract features from the base model
        h = self.base_model(x)

        # Flatten the pooled feature map to [batch_size, feature_dim]
        h = torch.flatten(h, 1)

        # Pass through the projection head
        z = self.projection_head(h)

        # Normalize the output embeddings
        return F.normalize(z, dim=1)

In [3]:
img_encoder = ImgEncoder_CNN(projection_dim=512)
input_image = torch.randn(16, 3, 224, 224)  # dummy batch: 16 three-channel 224x224 inputs

# Smoke-test the forward pass in eval mode and without autograd, so that random
# noise neither updates BatchNorm running statistics nor builds a graph.
img_encoder.eval()
with torch.no_grad():
    output_embedding = img_encoder(input_image)
img_encoder.train()  # restore the default training mode for later cells

assert output_embedding.shape == (16, 512)
print(output_embedding.shape)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 181MB/s]


torch.Size([16, 512])


In [4]:
from torchinfo import summary

# Per-layer output shapes and parameter counts of the encoder, traced with the dummy batch
summary(img_encoder, input_data=input_image)

Layer (type:depth-idx)                        Output Shape              Param #
ImgEncoder_CNN                                [16, 512]                 --
├─Sequential: 1-1                             [16, 2048, 1, 1]          --
│    └─Conv2d: 2-1                            [16, 64, 112, 112]        (9,408)
│    └─BatchNorm2d: 2-2                       [16, 64, 112, 112]        (128)
│    └─ReLU: 2-3                              [16, 64, 112, 112]        --
│    └─MaxPool2d: 2-4                         [16, 64, 56, 56]          --
│    └─Sequential: 2-5                        [16, 256, 56, 56]         --
│    │    └─Bottleneck: 3-1                   [16, 256, 56, 56]         (75,008)
│    │    └─Bottleneck: 3-2                   [16, 256, 56, 56]         (70,400)
│    │    └─Bottleneck: 3-3                   [16, 256, 56, 56]         (70,400)
│    └─Sequential: 2-6                        [16, 512, 28, 28]         --
│    │    └─Bottleneck: 3-4                   [16, 512, 28, 28]      