In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD

import torchvision

import optuna

import matplotlib.pyplot as plt
import seaborn as sns

from torch.utils.data import DataLoader, random_split
from torchvision import models, datasets, transforms
from torchsummary import summary

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#Define image transformations (including normalization)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load training and test data
full_train_data = datasets.MNIST(root=".\MNIST\MNIST_Train", train=True, download=False, transform = transform)
test_data = datasets.MNIST(root=".\MNIST\MNIST_Test", train=False, download=False, transform= transform)

# Set the seed for reproducibility
torch.manual_seed(42)  # You can choose any seed number

# Define train-validation split sizes
train_size = int(0.8 * len(full_train_data))  # 80% for training
val_size = len(full_train_data) - train_size  # 20% for validation

# Split the full training dataset
train_data, val_data = random_split(full_train_data, [train_size, val_size])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class CNNEncoder(nn.Module): #Encoder as we're only interest in the final embedding provided by the CNN 
    def __init__(self):
        super(CNNEncoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        
        self.fc = nn.Linear(256 * 7 * 7, 128)  # Output: 128-dimension embeddings

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)  # 14x14
        
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool(x)  # 7x7
        
        x = x.view(-1, 256 * 7 * 7)  # Flatten
        x = self.fc(x)  # Output embeddings
        return x

class MLPClassifier(nn.Module): # MLP projection Head 
    def __init__(self):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [4]:
class CNNModel(nn.Module):
    def __init__(self, encoder, classifier, optimizer):
        super(CNNModel, self).__init__()
        self.encoder = encoder
        self.classifier = classifier
        self.optimizer = optimizer
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.encoder(x)  # First pass through the encoder to get the embeddings
        x = self.classifier(x)  # Then pass the embeddings through the classifier
        return x

    def train_step(self, x, y):
        """Performs one training step: forward pass, loss calculation, backpropagation, and optimization"""
        self.optimizer.zero_grad()  # Clear gradients
        output = self.forward(x)    # Forward pass
        loss = self.loss_fn(output, y)  # Calculate loss
        loss.backward()             # Backpropagation
        self.optimizer.step()       # Update the model parameters
        return loss.item()          # Return the loss value for monitoring

    def test_step(self, x, y):
        """Performs one testing step: forward pass and loss calculation"""
        with torch.no_grad():  # Disable gradient calculation for testing/evaluation
            output = self.forward(x)  # Forward pass
            loss = self.loss_fn(output, y)  # Calculate loss
            predicted = torch.argmax(output, dim=1)  # Get predicted class
            correct = (predicted == y).sum().item()  # Calculate the number of correct predictions
            total = y.size(0)  # Total number of samples
        return loss.item(), correct, total  # Return loss, correct predictions, and total samples

    def train_model(self, train_loader, epochs):
        """Train the model for a number of epochs"""
        self.train()  # Set model to training mode (this does NOT call this method itself)
        for epoch in range(epochs):
            total_loss = 0
            for batch in train_loader:
                x, y = batch
                loss = self.train_step(x, y)  # Perform training step
                total_loss += loss
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_loader)}")  # Print average loss per epoch

    def test_model(self, test_loader):
        """Evaluate the model on the test set"""
        self.eval()  # Set model to evaluation mode
        total_correct = 0
        total_samples = 0
        total_loss = 0
        for batch in test_loader:
            x, y = batch
            loss, correct, total = self.test_step(x, y)  # Perform testing step
            total_loss += loss
            total_correct += correct
            total_samples += total
        accuracy = total_correct / total_samples  # Calculate accuracy
        print(f"Test Loss: {total_loss / len(test_loader)}, Accuracy: {accuracy * 100}%")

In [5]:
# Hyperparameter tuning with Optuna
def objective(trial):
    # Hyperparameters to tune
    lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
    batch_size = trial.suggest_int('batch_size', 32, 128, step=32)
    
    # Create data loaders
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    
    encoder = CNNEncoder().to(device)
    classifier = MLPClassifier().to(device)
    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=lr)
    
    trainer = CNNModel(encoder, classifier, optimizer)
    
    # Train the model
    trainer.train_model(train_loader, epochs=5)  # Use train_model instead of train
    
    # Evaluate on validation data
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            features = encoder(images)
            outputs = classifier(features)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    return accuracy

In [6]:
#model = CNNEncoder()
#summary(model, input_size=(1, 28, 28))

In [7]:
# Run Optuna study
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=2)

# Best hyperparameters
print("Best hyperparameters: ", study.best_params)

[I 2024-10-17 00:38:01,919] A new study created in memory with name: no-name-f8c848c0-09d3-4918-8d34-1e5059423917


Epoch 1/5, Loss: 0.15571689239102124
Epoch 2/5, Loss: 0.05843935274691466
Epoch 3/5, Loss: 0.045127431534383505
Epoch 4/5, Loss: 0.03903305466469646
Epoch 5/5, Loss: 0.03463148031904893


[I 2024-10-17 00:55:26,711] Trial 0 finished with value: 0.985 and parameters: {'lr': 0.001976089585879866, 'batch_size': 32}. Best is trial 0 with value: 0.985.


Epoch 1/5, Loss: 0.1969921721611172
Epoch 2/5, Loss: 0.04792178798435877
Epoch 3/5, Loss: 0.029979358142241834
Epoch 4/5, Loss: 0.02534065660322085
Epoch 5/5, Loss: 0.017295121989290542


[I 2024-10-17 01:10:48,240] Trial 1 finished with value: 0.9885 and parameters: {'lr': 0.0009359063707485683, 'batch_size': 128}. Best is trial 1 with value: 0.9885.


Best hyperparameters:  {'lr': 0.0009359063707485683, 'batch_size': 128}
