Import essential libraries for neural network training, data loading, and visualization.

In [None]:
import warnings
warnings.filterwarnings('ignore')

# Import neural network training libraries
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms

# Import basic computation and data visualization libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn.decomposition import PCA
import umap
import umap.plot
import plotly.graph_objs as go
import plotly.io as pio
pio.renderers.default = 'iframe'

# Import the custom dataset class
from mnist_dataset import MNISTDataset


Load the MNIST dataset and apply image transformations for both training and validation sets.

In [None]:
# Number of trailing CSV rows held out for validation
val_count = 1000

# Read the full training table from disk
data = pd.read_csv('digit-recognizer/train.csv')

# One shared pipeline for both splits: to PIL image, back to a tensor,
# then normalize with mean 0.5 and std 0.5
default_transform = transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
)

# Everything except the last `val_count` rows trains; the tail validates
train_rows = data.iloc[:-val_count]
val_rows = data.iloc[-val_count:]
dataset = MNISTDataset(train_rows, default_transform)
val_dataset = MNISTDataset(val_rows, default_transform)


Set up data loaders to efficiently handle the training and validation batches.



In [None]:
# Options common to both loaders: shuffled batches, pinned host memory,
# two worker processes, each prefetching up to 100 batches
loader_options = dict(
    shuffle=True,
    pin_memory=True,
    num_workers=2,
    prefetch_factor=100,
)

# Training batches (feel free to modify the batch size)
trainLoader = DataLoader(dataset, batch_size=16, **loader_options)

# Validation batches
valLoader = DataLoader(val_dataset, batch_size=64, **loader_options)


Display examples of anchor images and corresponding contrastive images from the dataset.

In [None]:
# Function to display a row of images under a common title
def show_images(images, title=''):
    """Show a row of grayscale images under a shared figure title.

    Args:
        images: Sequence of array-likes. Each item is passed through
            ``np.squeeze`` and drawn with ``imshow`` using a gray colormap.
        title: Text placed above the row via ``fig.suptitle``.

    Returns:
        None. Nothing is drawn when ``images`` is empty.
    """
    num_images = len(images)
    if num_images == 0:
        # A subplot grid needs at least one column, so there is nothing to show.
        return

    # squeeze=False keeps `axes` two-dimensional for any column count; without
    # it a single image yields a bare Axes object that cannot be indexed.
    fig, axes = plt.subplots(1, num_images, figsize=(9, 3), squeeze=False)
    for ax, image in zip(axes[0], images):
        ax.imshow(np.squeeze(image), cmap='gray')
        ax.axis('off')
    fig.suptitle(title)
    plt.show()

# Preview the first training batch only, then stop iterating
for batch_idx, (anchor_images, contrastive_images, distances, labels) in enumerate(trainLoader):
    # Tensors -> numpy arrays for plotting
    anchor_images, contrastive_images, labels = (
        tensor.numpy() for tensor in (anchor_images, contrastive_images, labels)
    )

    # Show the first four anchors, then their paired contrastive examples
    for caption, batch in (
        ('Anchor Image', anchor_images),
        ('+/- Example', contrastive_images),
    ):
        show_images(batch[:4], title=caption)

    break  # one batch is enough for the demonstration


Define a neural network architecture for MNIST image representation in 64 dimensions.



In [None]:
class Network(nn.Module):
    """Convolutional encoder producing a 64-dimensional vector per image.

    Two convolutional stages (5x5 convolution, batch norm, ReLU, 2x2 max-pool,
    dropout) feed a two-layer fully connected head. The head expects
    64 * 4 * 4 flattened features, which is what a single-channel 28x28
    input produces after both stages.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one conv -> batch norm -> ReLU -> max-pool -> dropout stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 5),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 2), stride=2),
            nn.Dropout(0.3),
        )

    def __init__(self):
        super().__init__()
        self.conv1 = self._conv_stage(1, 32)
        self.conv2 = self._conv_stage(32, 64)
        self.linear1 = nn.Sequential(
            nn.Linear(64 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 64),
        )

    def forward(self, x):
        """Encode a batch of images; returns a tensor of shape (batch, 64)."""
        features = self.conv2(self.conv1(x))
        # Collapse channel and spatial dimensions, keeping the batch dimension
        flat = features.view(features.size(0), -1)
        return self.linear1(flat)


Define a custom contrastive loss function using cosine similarity for comparing anchor and contrastive images.



In [None]:
class ContrastiveLoss(nn.Module):
    """Mean squared error between cosine similarity and a target score.

    The similarity is taken along the last dimension of the two embeddings.
    Although the third argument is named ``distance``, it is compared directly
    against the cosine similarity, so it acts as the target similarity value.
    """

    def __init__(self):
        super().__init__()
        self.similarity = nn.CosineSimilarity(dim=-1, eps=1e-7)

    def forward(self, anchor, contrastive, distance):
        """Return the scalar MSE between cos(anchor, contrastive) and ``distance``."""
        predicted = self.similarity(anchor, contrastive)
        criterion = nn.MSELoss()
        return criterion(predicted, distance)


Set up the network, optimizer, loss function, and learning rate scheduler.

In [None]:
# Use the first CUDA device when one is present, otherwise stay on the CPU
device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'

# Build the encoder directly on the chosen device
net = Network().to(device)

# Adam optimizer, cosine-similarity loss, and a step schedule that multiplies
# the learning rate by 0.3 every 7 scheduler steps
optimizer = optim.Adam(net.parameters(), lr=0.005)
loss_function = ContrastiveLoss()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.3)


Implement the training loop to train the network for a defined number of epochs.



In [None]:
import os

# Directory to save model checkpoints
checkpoint_dir = 'checkpoints/'
# exist_ok=True creates the directory when missing and is a no-op when it is
# already there, avoiding the check-then-create race of a separate exists() test.
os.makedirs(checkpoint_dir, exist_ok=True)

def train_model(epoch_count=10):
    """Train the module-level ``net`` and save a checkpoint after every epoch.

    Uses the module-level ``net``, ``optimizer``, ``scheduler``,
    ``loss_function``, ``trainLoader``, ``device`` and ``checkpoint_dir``.

    Args:
        epoch_count: Number of passes over ``trainLoader``.

    Returns:
        dict with keys ``"net"`` (the trained network) and ``"losses"``
        (list with the mean batch loss of each epoch as Python floats).
    """
    # NOTE: no fresh Network() is created here. The optimizer and scheduler
    # were built from the module-level `net`'s parameters, and that instance
    # already lives on `device`; training any other instance would leave its
    # weights untouched by optimizer.step().
    lrs = []
    losses = []

    # Dropout and batch norm must be in training mode, even if `net` was
    # switched to eval mode earlier in the session.
    net.train()

    for epoch in range(epoch_count):
        epoch_loss = 0.0
        batches = 0
        print('epoch -', epoch)
        lrs.append(optimizer.param_groups[0]['lr'])
        print('learning rate', lrs[-1])

        for anchor, contrastive, distance, label in tqdm(trainLoader):
            batches += 1
            optimizer.zero_grad()
            anchor_out = net(anchor.to(device))
            contrastive_out = net(contrastive.to(device))
            distance = distance.to(torch.float32).to(device)
            loss = loss_function(anchor_out, contrastive_out, distance)
            loss.backward()
            optimizer.step()
            # .item() yields a plain float, so the running total does not keep
            # every batch's autograd graph alive.
            epoch_loss += loss.item()

        # Guard against a loader that produced no batches
        losses.append(epoch_loss / batches if batches else float('nan'))
        scheduler.step()
        print('epoch_loss', losses[-1])

        # Save model checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch}.pt')
        torch.save(net.state_dict(), checkpoint_path)

    return {"net": net, "losses": losses}


Either train the model from scratch or (the default) load a pre-trained model from a saved checkpoint.

In [None]:
def load_model_from_checkpoint(checkpoint_path='checkpoints/model_epoch_99.pt'):
    """Build a ``Network`` and restore its weights from a saved state dict.

    Args:
        checkpoint_path: File written by ``torch.save(net.state_dict(), ...)``.
            The default names the epoch-99 checkpoint, which exists only
            after a run of at least 100 epochs.

    Returns:
        The restored network, placed on the module-level ``device`` and
        switched to evaluation mode.
    """
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint saved on a GPU still loads on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net = Network()
    net.load_state_dict(checkpoint)
    # Later cells send inputs to `device`, so the model must live there too.
    net.to(device)
    net.eval()
    return net

train = False  # Set to True if you want to train the model

# Default path restores saved weights; flip `train` to run training instead
if not train:
    model = load_model_from_checkpoint()
else:
    training_result = train_model()
    model = training_result["net"]


Plot the training loss curve if the model was just trained; otherwise display a saved loss-curve image.

In [None]:
# `display` is imported explicitly so this cell does not rely on the name
# being injected into the namespace by an interactive IPython session.
from IPython.display import Image, display

if train:
    # Fresh run: plot the per-epoch mean losses returned by train_model()
    plt.plot(training_result["losses"])
    plt.show()
else:
    # No training in this session: show the stored loss-curve image instead
    display(Image(filename="images/loss-curve.png", height=600, width=600))


Encode the training data with the model, use PCA to reduce the 64D embeddings to 3D, and plot the result interactively.

In [None]:
encoded_data = []
labels = []

# Inputs are sent to `device` below, so the model has to be there as well.
model = model.to(device)
# Evaluation mode disables dropout and makes batch norm use its running
# statistics; a freshly trained model would otherwise still be in train mode
# and produce noisy embeddings.
model.eval()

with torch.no_grad():
    for anchor, _, _, label in tqdm(trainLoader):
        output = model(anchor.to(device))
        # Collect per-sample embeddings and labels in matching order
        encoded_data.extend(output.cpu().numpy())
        labels.extend(label.cpu().numpy())

encoded_data = np.array(encoded_data)
labels = np.array(labels)

# Project the embeddings onto their first three principal components
pca = PCA(n_components=3)
encoded_data_3d = pca.fit_transform(encoded_data)

# One coordinate array per principal component
pc1 = encoded_data_3d[:, 0]
pc2 = encoded_data_3d[:, 1]
pc3 = encoded_data_3d[:, 2]

# Points are colored by label; hovering shows the label
marker_style = dict(size=4, color=labels, colorscale='Viridis', opacity=0.8)
scatter = go.Scatter3d(
    x=pc1,
    y=pc2,
    z=pc3,
    mode='markers',
    marker=marker_style,
    text=labels,
    hoverinfo='text',
)

# Axis titles and figure size
scene_axes = dict(
    xaxis=dict(title="PC1"),
    yaxis=dict(title="PC2"),
    zaxis=dict(title="PC3"),
)
layout = go.Layout(
    title="MNIST Dataset - Encoded and PCA Reduced 3D Scatter Plot",
    scene=scene_axes,
    width=1000,
    height=750,
)

fig = go.Figure(data=[scatter], layout=layout)
fig.show()


Use UMAP for dimensionality reduction and visualize the data in 2D.

In [None]:
# First view: UMAP fitted with the cosine metric
cosine_reducer = umap.UMAP(metric='cosine', random_state=42)
mapper = cosine_reducer.fit(encoded_data)
umap.plot.points(mapper, labels=labels)

# Second view: UMAP with no metric argument (the Euclidean variant)
euclidean_reducer = umap.UMAP(random_state=42)
mapper = euclidean_reducer.fit(encoded_data)
umap.plot.points(mapper, labels=labels)
