In [9]:
#Import All libraries 
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import optuna
import torch.nn as nn
import torch.nn.functional as F
from pytorch_metric_learning.losses import SupConLoss
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator


In [10]:
# Define image transformations (including normalization)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the full MNIST training set
full_train_data = datasets.MNIST(root="D:\MNIST\MNIST_Train", train=True, download=False, transform=transform)
test_data = datasets.MNIST(root="D:\MNIST\MNIST_Test", train=False, download=False, transform=transform)

In [11]:
#Making the Train-Test Split 
torch.manual_seed(42) 

train_size = int(0.8 * len(full_train_data))  
val_size = len(full_train_data) - train_size  

# Split the full training dataset
train_data, val_data = random_split(full_train_data, [train_size, val_size])

# Load the training and validation data
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)

# Test data
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

In [12]:
class CNNEncoder(nn.Module): #Encoder as we're only interest in the final embedding provided by the CNN 
    def __init__(self):
        super(CNNEncoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        
        self.fc = nn.Linear(256 * 7 * 7, 128)  # Output: 128-dimension embeddings

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)  # 14x14
        
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool(x)  # 7x7
        
        x = x.view(-1, 256 * 7 * 7)  # Flatten
        x = self.fc(x)  # Output embeddings
        return x

In [17]:
class MLPClassifier(nn.Module):
    def __init__(self):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [16]:
# Define the Supervised Contrastive Loss from the library
contrastive_loss = SupConLoss()

# Initialize accuracy calculator
ccuracy_calculator = AccuracyCalculator(include=("precision_at_1",), knn_func=None)

class ContrastiveTrainer:
    def __init__(self, encoder, classifier, criterion, optimizer):
        self.encoder = encoder
        self.classifier = classifier
        self.criterion = criterion
        self.optimizer = optimizer

    def train_contrastive(self, train_loader, epochs=10):
        self.encoder.train()
        for epoch in range(epochs):
            total_loss = 0
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                self.optimizer.zero_grad()

                # Forward pass through the encoder
                features = self.encoder(images)

                # Compute supervised contrastive loss
                loss = self.criterion(features, labels)
                loss.backward()
                self.optimizer.step()
                
                total_loss += loss.item()
            print(f'Epoch [{epoch + 1}/{epochs}], Contrastive Loss: {total_loss / len(train_loader):.4f}')
    
    def train_classifier(self, train_loader, epochs=5):
        self.encoder.eval()  # Freeze the encoder for classifier training
        self.classifier.train()
        
        for epoch in range(epochs):
            total_loss = 0
            correct = 0
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                self.optimizer.zero_grad()

                # Extract frozen features
                with torch.no_grad():
                    features = self.encoder(images)

                # Forward pass through the classifier
                outputs = self.classifier(features)
                loss = F.cross_entropy(outputs, labels)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()

            accuracy = correct / len(train_loader.dataset)
            print(f'Epoch [{epoch + 1}/{epochs}], Classification Loss: {total_loss / len(train_loader):.4f}, Accuracy: {accuracy * 100:.2f}%')

    def evaluate_embeddings(self, test_loader):
        self.encoder.eval()
        all_embeddings = []
        all_labels = []

        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                features = self.encoder(images)
                
                all_embeddings.append(features.cpu())
                all_labels.append(labels.cpu())

        all_embeddings = torch.cat(all_embeddings)
        all_labels = torch.cat(all_labels)

        print(accuracy_calculator.get_accuracy(all_embeddings, all_labels, all_embeddings, all_labels))

NameError: name 'faiss' is not defined

In [None]:
# Initialize encoder, classifier, and optimizer
contrastive_loss = SupConLoss().to(device)
encoder = CNNEncoder().to(device)
classifier = MLPClassifier().to(device)

# Use Adam optimizer for both contrastive learning and classification
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=0.001)

# Initialize trainer
trainer = ContrastiveTrainer(encoder, classifier, contrastive_loss, optimizer)

# 1. Train the encoder with Supervised Contrastive Learning
trainer.train_contrastive(train_loader, epochs=10)

# 2. Fine-tune the classifier on top of the frozen encoder
trainer.train_classifier(train_loader, epochs=5)