### Multi-modal RAG
- CLIP embeddings
    - Shared embedding space of text and images
- Multi-modal prompting
- Tool Calling

### Motivation 
Looking through possible options
 - Option 1 - Simple binary classification model 
 - Option 2 - Transfer Learning
 The above are still designed around prediction of 1 or 0
 ``` 
 Solution: Contrastive learning using Siamese Networks, as it
  learns to map both inputs to a shared embedding space
  - If the distance b/w the embeddings is LOW, they are similar.
  - If distance b/w the embeddings is HIGH, they are dissimilar.

  * Contrastive loss is used to train such a model:
L = (1 - y) * D^2 + y * max(0, margin - D)^2

y is the true label (0 for a similar pair, 1 for a dissimilar pair)
D is the distance b/w the two embeddings
margin is a hyperparameter, typically greater than 1.
 ```


In [3]:
# Siamese Networks in face unlock
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
from torch import optim


In [6]:
# Image preprocessing applied later by the pair dataset: ToTensor only.
transform = transforms.Compose([transforms.ToTensor()])

# Raw MNIST splits, stored under ./data and downloaded on first use.
# The train split is built first, then the test split; no transform is
# attached to the raw datasets themselves.
mnist_train, mnist_test = (
    MNIST(root='./data', train=is_train, download=True)
    for is_train in (True, False)
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:34<00:00, 283648.67it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 62966.00it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:09<00:00, 167651.54it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2536012.88it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [11]:
class SiameseDataset(Dataset):
    """Turn a labelled dataset into random image pairs for contrastive training.

    Each item pairs the sample at ``index`` (the anchor) with a second sample
    drawn at random from the same underlying dataset. With probability 1/2
    the partner shares the anchor's label, otherwise it has a different one.

    Args:
        data: Indexable, sized collection of ``(image, label)`` tuples.
        transform: Optional callable applied to both images of a pair.

    Note:
        The partner is found by rejection sampling with no retry limit, so
        ``data`` must contain at least two distinct labels; otherwise a
        "different class" draw can never succeed and the loop never ends.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of anchors, i.e. the size of the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(imgA, imgB, pair_label)`` for the anchor at ``index``.

        ``pair_label`` is a float32 tensor of shape ``(1,)``: ``0.0`` when both
        images share a label and ``1.0`` when their labels differ.
        """
        imgA, labelA = self.data[index]

        # Coin flip: should the partner come from the anchor's class?
        want_same_class = bool(random.randint(0, 1))

        # Always draw at least one candidate and keep drawing until it matches
        # the requested relation. Sampling first (instead of seeding labelB
        # with a sentinel such as -1) guarantees imgB is bound even when a
        # real label happens to equal the sentinel value.
        while True:
            imgB, labelB = random.choice(self.data)
            if (labelB == labelA) == want_same_class:
                break

        if self.transform:
            imgA = self.transform(imgA)
            imgB = self.transform(imgB)

        pair_label = torch.tensor(
            [1.0 if labelA != labelB else 0.0], dtype=torch.float32
        )

        return imgA, imgB, pair_label

In [13]:
# Wrap each raw MNIST split so that indexing yields (imgA, imgB, pair_label)
# triples; both wrappers share the same image transform.
siamese_train, siamese_test = (
    SiameseDataset(split, transform=transform)
    for split in (mnist_train, mnist_test)
)



## Defining the network

In [15]:
class SiameseNetwork(nn.Module):
    """Siamese encoder: one set of weights embeds both inputs of a pair.

    The encoder is three conv -> ReLU -> 2x2 max-pool stages followed by a
    three-layer MLP that projects to a 2-dimensional embedding. The first
    linear layer expects 256 * 3 * 3 features, which is what the conv stack
    produces for single-channel 28x28 inputs (28 -> 14 -> 7 -> 3).
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels, out_channels, kernel, pad):
            # One feature-extraction stage; the padding keeps the spatial
            # size through the convolution and the pool then halves it.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=1, padding=pad),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, stride=2),
            ]

        # Convolutional feature extractor (shared by both branches).
        self.cnn = nn.Sequential(
            *conv_stage(1, 64, 5, 2),
            *conv_stage(64, 128, 5, 2),
            *conv_stage(128, 256, 3, 1),
        )

        # Projection head: flattened features -> 2-D embedding.
        self.fc = nn.Sequential(
            nn.Linear(256 * 3 * 3, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 2),
        )

    def forward_once(self, x):
        """Embed a single batch of images and return the embeddings."""
        features = self.cnn(x)
        # Collapse everything except the batch dimension before the MLP.
        flat = features.view(features.shape[0], -1)
        return self.fc(flat)

    def forward(self, inputA, inputB):
        """Embed both batches with the same weights; return ``(embA, embB)``."""
        return self.forward_once(inputA), self.forward_once(inputB)

### Defining the contrastive loss

In [19]:
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over pairs of embeddings.

    For a pair at Euclidean distance ``D`` with label ``y``
    (``0`` = same class, ``1`` = different class)::

        L = (1 - y) * D^2 + y * max(0, margin - D)^2

    Same-class pairs are pulled together; different-class pairs are pushed
    apart until they are at least ``margin`` away from each other.

    Args:
        margin: Distance beyond which different-class pairs add no loss.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, outputA, outputB, y):
        """Return the mean contrastive loss over the batch.

        Args:
            outputA: Embeddings of the first images.
            outputB: Embeddings of the second images, same shape as ``outputA``.
            y: Pair labels, 0 for same class and 1 for different class. A
                tensor with one element per pair is accepted in any shape
                (e.g. ``(B, 1)`` or ``(B,)``).

        Returns:
            A scalar tensor holding the batch-averaged loss.
        """
        euclidean_distance = F.pairwise_distance(outputA, outputB, keepdim=True)

        # keepdim=True makes the distance (B, 1). A flat (B,) label tensor
        # would broadcast against it to (B, B) and silently mix every label
        # with every distance, so align the label layout with the distances
        # whenever there is exactly one label per pair.
        if (
            torch.is_tensor(y)
            and y.shape != euclidean_distance.shape
            and y.numel() == euclidean_distance.numel()
        ):
            y = y.reshape(euclidean_distance.shape)

        same_class_loss = (1 - y) * euclidean_distance.pow(2)
        gap_to_margin = torch.clamp(self.margin - euclidean_distance, min=0.0)
        diff_class_loss = y * gap_to_margin.pow(2)

        return torch.mean(same_class_loss + diff_class_loss)

In [28]:
# Training setup: shuffled mini-batches of 64 pairs loaded in the main
# process, a freshly initialised Siamese encoder, the contrastive loss with
# its default margin, and Adam over all model parameters.
train_dataloader = DataLoader(
    siamese_train,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)
model = SiameseNetwork()
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [29]:


# Run five passes over the training pairs, printing the summed (not
# averaged) batch loss after each pass.
for epoch in range(5):
    total_loss = 0

    for batch in train_dataloader:
        imgA, imgB, label = batch
        # NOTE: no .to(device) calls are made here; to train on a GPU, move
        # the model and each of these batch tensors to that device first.

        # Standard step: clear old gradients, embed the pair, score it,
        # back-propagate, then update the weights.
        optimizer.zero_grad()
        outputA, outputB = model(imgA, imgB)
        loss_contrastive = criterion(outputA, outputB, label)
        loss_contrastive.backward()
        optimizer.step()

        total_loss += loss_contrastive.item()

    print(f"Epoch {epoch}; Loss {total_loss}")

Epoch 0; Loss 276.25927489995956
Epoch 1; Loss 107.70143886096776
Epoch 2; Loss 58.102975176647305
Epoch 3; Loss 44.54895405960269
Epoch 4; Loss 36.63022816204466
