# Imports

In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import json

  Referenced from: <E03EDA44-89AE-3115-9796-62BA9E0E2EDE> /Users/theebankumaresan/anaconda3/lib/python3.11/site-packages/torchvision/image.so
  warn(


# Classes and Setup

In [2]:
class customCOCODataset(Dataset):
    def __init__(self, data_dir, annotations_file, transform=None):
        self.data_dir = data_dir
        self.annotations_file = annotations_file
        self.transform = transform
        self.image_list = self.load_image_list()
        self.annotations = self.load_annotations()

    def load_image_list(self):
        image_list = []
        for root, _, files in os.walk(self.data_dir):
            for file in files:
                if file.endswith('.jpg'):
                    image_list.append(os.path.join(root, file))
        return image_list

    def load_annotations(self):
        with open(self.annotations_file, 'r') as f:
            coco_data = json.load(f)

        
        annotations = {}
        for annotation in coco_data['annotations']:
            img_filename = os.path.splitext(os.path.basename(coco_data['images'][annotation['image_id']]['file_name']))[0]
            if img_filename not in annotations:
                annotations[img_filename] = []
            annotations[img_filename].append(annotation)

        return annotations

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_path = self.image_list[idx]
        image = Image.open(img_path).convert('RGB')

        img_filename = os.path.splitext(os.path.basename(img_path))[0]  

        annotations = self.annotations.get(img_filename, []) 

        if self.transform:
            image = self.transform(image)

        return image, annotations



#### Setting up the Pretrained resnet model and freezing 

In [3]:
resNet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
for param in resNet.parameters():
    param.requires_grad = False  

In [4]:
class RetrievalModel(nn.Module):
    def __init__(self):
        super(RetrievalModel, self).__init__()
        self.features = nn.Sequential(*list(resNet.children())[:-1])
        self.fc = nn.Linear(2048, 256) 

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
retrievalModel = RetrievalModel()

In [5]:
dataTransform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [6]:
dataDir = 'my_data/train'
annotationsFile = 'my_data/my_train_coco.json'
custom_dataset = customCOCODataset(dataDir, annotationsFile, transform=dataTransform)


In [7]:
dataLoader = DataLoader(custom_dataset, batch_size=32, shuffle=True)
criterion = nn.CosineEmbeddingLoss()
optimizer = torch.optim.Adam(retrievalModel.parameters(), lr=0.001)

# Training the model

In [8]:
num_epochs = 5
for epoch in range(num_epochs):
    for images, _ in dataLoader:
        anchor_images, positive_images = torch.chunk(images, 2, dim=0) 
        anchor_feat = retrievalModel(anchor_images)
        positive_feat = retrievalModel(positive_images)

        target = torch.ones(anchor_images.size(0))

        loss = criterion(anchor_feat, positive_feat, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch + 1}/{num_epochs}] Loss: {loss.item():.4f}')


# torch.save(retrievalModel.state_dict(), 'retrieval_model.pth')

Epoch [1/5] Loss: 0.6588
Epoch [2/5] Loss: 0.3470
Epoch [3/5] Loss: 0.0879
Epoch [4/5] Loss: 0.0775
Epoch [5/5] Loss: 0.0516
