In [1]:
import torch
from torchvision import datasets, transforms
from torch import nn, optim
from tqdm import tqdm

# Seed torch's global RNG before anything random happens. Encoder weight init and
# DataLoader shuffling (shuffle=True draws from the default generator) then become
# reproducible. GPU kernels may still be nondeterministic; that is not addressed here.
SEED = 0
torch.manual_seed(SEED)

# MNIST training split, downloaded into the current directory on first use.
train = datasets.MNIST('.', train=True, download=True,
                       transform=transforms.ToTensor())
loader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True)

class Encoder(nn.Module):
    """Small convolutional encoder mapping single-channel images to 64-d features.

    Two unpadded 3x3 convolutions with ReLU, then global average pooling
    and dropout. ``forward`` takes a batch shaped (B, 1, H, W) and returns
    a (B, 64) feature matrix.
    """

    def __init__(self):
        super().__init__()
        # Keep this exact layer order: it determines the state_dict keys
        # (conv.0.*, conv.2.*) that saved checkpoints rely on.
        layers = [
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # collapse spatial dims -> (B, 64, 1, 1)
            nn.Dropout(0.25),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        batch_size = x.size(0)
        pooled = self.conv(x)                # (B, 64, 1, 1)
        return pooled.view(batch_size, -1)  # (B, 64)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = Encoder().to(device)

# simple supervised training: a linear head on top of the 64-d encoder features
classifier = nn.Linear(64, 10).to(device)

# Optimize encoder AND classifier jointly. Building the optimizer from
# encoder.parameters() alone would leave the linear head frozen at its
# random initialization.
optimizer = optim.Adam(
    list(encoder.parameters()) + list(classifier.parameters()), lr=1e-3
)
criterion = nn.CrossEntropyLoss()

# Early stopping parameters
MAX_EPOCHS = 200
patience = 10
best_loss = float('inf')
patience_counter = 0
CHECKPOINT_PATH = 'mnist_encoder.pth'

# NOTE(review): early stopping monitors the *training* loss, not a held-out
# validation loss, so it stops on plateauing fit rather than on overfitting.
# Consider a validation split.
progress = tqdm(range(MAX_EPOCHS))
for epoch in progress:
    # Re-enable dropout every epoch. The visualization cell calls encoder.eval(),
    # so re-running this cell afterwards would otherwise train with dropout off.
    encoder.train()
    classifier.train()
    running_loss = 0.0
    total_batches = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        feat = encoder(x)
        loss = criterion(classifier(feat), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        total_batches += 1

    # max(..., 1) guards against an empty loader.
    avg_loss = running_loss / max(total_batches, 1)
    progress.set_postfix(loss=f"{avg_loss:.4f}")

    # Early stopping logic
    if avg_loss < best_loss:
        best_loss = avg_loss
        patience_counter = 0
        # Save the best model
        torch.save({
            'encoder': encoder.state_dict(),
            'classifier': classifier.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, CHECKPOINT_PATH)
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}. Best loss: {best_loss:.4f}")
            break
progress.close()  # `break` leaves the bar open otherwise

# Restore the best weights. After early stopping, the in-memory model is
# `patience` epochs past the best checkpoint. best_loss is finite only if a
# checkpoint was written during this run.
if best_loss < float('inf'):
    best_checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    encoder.load_state_dict(best_checkpoint['encoder'])
    classifier.load_state_dict(best_checkpoint['classifier'])

 13%|█▎        | 26/200 [01:56<12:56,  4.46s/it]


KeyboardInterrupt: 

In [None]:
# Save the encoder weights on their own, in a separate file. The training cell
# writes its best checkpoint to 'mnist_encoder.pth' as a dict with 'encoder',
# 'classifier' and 'optimizer' entries. Overwriting that file with a bare encoder
# state_dict would discard the checkpoint and change the file's format.
torch.save(encoder.state_dict(), 'mnist_encoder_weights.pth')

In [None]:
# visualize some embeddings
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
def plot_embeddings(embeddings, labels, title=None, n_iter=300, random_state=None):
    """Project embeddings to 2-D with t-SNE and show a scatter plot colored by label.

    Args:
        embeddings: array-like of shape (n_samples, n_features).
        labels: array-like of shape (n_samples,); used for point colors and
            the legend entries.
        title: optional figure title.
        n_iter: number of t-SNE optimization iterations (300, as before).
        random_state: seed forwarded to TSNE for a reproducible layout.
            None keeps scikit-learn's default, non-deterministic behavior.
    """
    n_samples = len(embeddings)
    # TSNE requires perplexity < n_samples; clamp so small inputs still work.
    perplexity = min(30, max(n_samples - 1, 1))
    try:
        # scikit-learn >= 1.5 renamed `n_iter` to `max_iter`; the old name
        # was removed in 1.7.
        tsne = TSNE(n_components=2, perplexity=perplexity, max_iter=n_iter,
                    random_state=random_state)
    except TypeError:
        # Older scikit-learn only accepts `n_iter`.
        tsne = TSNE(n_components=2, perplexity=perplexity, n_iter=n_iter,
                    random_state=random_state)
    reduced = tsne.fit_transform(embeddings)

    fig, ax = plt.subplots(figsize=(8, 8))
    scatter = ax.scatter(reduced[:, 0], reduced[:, 1], c=labels, cmap='tab10', alpha=0.7)
    ax.legend(*scatter.legend_elements(), title="Digits")
    ax.set_xlabel("t-SNE 1")
    ax.set_ylabel("t-SNE 2")
    if title:
        ax.set_title(title)
    plt.show()
# get some embeddings
# t-SNE on all 60k training embeddings is very slow and overplots;
# plot a seeded random subset instead (set to None to use every point).
MAX_TSNE_POINTS = 5000

encoder.eval()  # disable dropout for deterministic features
with torch.no_grad():
    all_feats = []
    all_labels = []
    for x, y in loader:
        # Only the images need to go to the device; labels stay on the CPU.
        all_feats.append(encoder(x.to(device)).cpu())
        all_labels.append(y.cpu())
    all_feats = torch.cat(all_feats)
    all_labels = torch.cat(all_labels)

if MAX_TSNE_POINTS is not None and len(all_feats) > MAX_TSNE_POINTS:
    subset_gen = torch.Generator().manual_seed(0)
    subset = torch.randperm(len(all_feats), generator=subset_gen)[:MAX_TSNE_POINTS]
    plot_feats, plot_labels = all_feats[subset], all_labels[subset]
else:
    plot_feats, plot_labels = all_feats, all_labels

plot_embeddings(plot_feats.numpy(), plot_labels.numpy(), title="MNIST Embeddings")