In [19]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time

In [20]:
n_train = 6000
n_test  = 1000
batch_size = 256
n_epochs = 20

lr = 0.01
gamma = 0.985

In [21]:
def load_dataset(n_train, n_test, batch_size):
    """
    Loads train & test sets from MNIST with user-specified sizes.

    Args:
        n_train (int): Desired number of samples in the training set.
        n_test (int): Desired number of samples in the testing set.
        batch_size (int): Batch size for the DataLoaders.

    Returns:
        tuple: (train_loader, test_loader) where each loader is a
               torch.utils.data.DataLoader.
    """
    # Define transformations for the dataset
    transform = transforms.Compose([transforms.ToTensor(), 
                                    transforms.Normalize((0.5,), (0.5,)),
                                    transforms.Lambda(lambda img: F.interpolate(img.unsqueeze(0), size=(14, 14), 
                                        mode='bilinear', align_corners=False).squeeze(0))])

    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Subset the datasets to the desired number of samples
    train_subset = torch.utils.data.Subset(train_dataset, range(n_train))
    test_subset = torch.utils.data.Subset(test_dataset, range(n_test))

    # Create DataLoaders for training and testing sets
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, pin_memory=True)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False, pin_memory=True)

    print("Number of training samples:", len(train_subset))
    print("Number of test samples:", len(test_subset))

    return train_loader, test_loader

In [22]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()

        # Convolutional layer (Maps 1 input channel to 4 output channels)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=2, stride=2)

        # Calculate output size: [(W - F + 2P) / S] + 1 = [(14 - 2 + 0) / 2] + 1 = 7
        # The output of our conv layer will be 7x7x4 (4 output channels)

        # Fully connected layer to perform the final classification
        self.fc1 = nn.Linear(7 * 7 * 4, 10)  # Assuming 10 output classes

    def forward(self, x):
        x = F.relu(self.conv1(x))  # Apply convolution and ReLU activation
        x = x.view(-1, 7 * 7 * 4)  # Flatten for the fully connected layer
        x = self.fc1(x)
        x = F.softmax(x, dim=1)  # Apply softmax
        return x

model = SimpleCNN()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")  # Use GPU if available
#model.to(device)
train_loader, test_loader = load_dataset(n_train, n_test, batch_size)

Number of training samples: 6000
Number of test samples: 1000


In [23]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) 
loss_fn = F.cross_entropy

In [24]:
for epoch in range(n_epochs):

    # ---- Training Phase ----
    running_loss = 0.0 
    correct = 0  
    total = 0
    train_losses = []
    test_losses = []
    start_time = time.time()
    for i, (images, labels) in enumerate(train_loader):
        # --- 2. Forward Pass ---
        outputs = model(images)

        # --- 3. Loss Calculation ---
        loss = loss_fn(outputs, labels)
        train_losses.append(loss.item())
        # --- 4. Backpropagation and Optimization ---
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()  
        train_losses.append(loss.item())  # Store loss for each batch

        # Accuracy calculation for the epoch
        _, predicted = torch.max(outputs.data, 1) 
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    scheduler.step() 
    model.eval()  # Set model to evaluation mode
    val_loss = 0.0
    val_corrects = 0
    
    with torch.no_grad():  # Disable gradients for evaluation
        val_correct = 0
        val_total = 0

        for images, labels in test_loader:
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            test_losses.append(loss.item())

            _, predicted = torch.max(outputs.data, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    val_loss = sum(test_losses) / len(test_losses)
    val_acc = 100 * val_correct / val_total 

    # Print statistics for the epoch
    epoch_loss = sum(train_losses) / len(train_losses)  
    epoch_acc = 100 * correct / total 

    # Periodic Evaluation and Logging
    print(f'[i] Epoch {epoch + 1}\n\tLoss: {epoch_loss:.4f}\tAccuracy: {epoch_acc :.2f}%\n\tVal Loss: {val_loss:.4f}\t Val Accuracy: {val_acc :.2f}%\n\tElapsed Time: {time.time() - start_time:.2f}s')

[i] Epoch 1
	Loss: 2.1576	Accuracy: 34.98%
	Val Loss: 1.9451	 Val Accuracy: 60.40%
	Elapsed Time: 0.69s
[i] Epoch 2
	Loss: 1.7972	Accuracy: 72.13%
	Val Loss: 1.7466	 Val Accuracy: 75.00%
	Elapsed Time: 0.64s
[i] Epoch 3
	Loss: 1.6972	Accuracy: 79.07%
	Val Loss: 1.7063	 Val Accuracy: 77.60%
	Elapsed Time: 0.69s
[i] Epoch 4
	Loss: 1.6678	Accuracy: 81.25%
	Val Loss: 1.6869	 Val Accuracy: 79.00%
	Elapsed Time: 0.68s
[i] Epoch 5
	Loss: 1.6538	Accuracy: 82.40%
	Val Loss: 1.6815	 Val Accuracy: 79.60%
	Elapsed Time: 0.57s
[i] Epoch 6
	Loss: 1.6427	Accuracy: 83.20%
	Val Loss: 1.6749	 Val Accuracy: 79.50%
	Elapsed Time: 0.64s
[i] Epoch 7
	Loss: 1.6371	Accuracy: 83.77%
	Val Loss: 1.6644	 Val Accuracy: 80.60%
	Elapsed Time: 0.71s
[i] Epoch 8
	Loss: 1.6283	Accuracy: 84.17%
	Val Loss: 1.6607	 Val Accuracy: 80.70%
	Elapsed Time: 0.71s
[i] Epoch 9
	Loss: 1.6284	Accuracy: 84.25%
	Val Loss: 1.6581	 Val Accuracy: 81.00%
	Elapsed Time: 0.87s
[i] Epoch 10
	Loss: 1.6260	Accuracy: 84.37%
	Val Loss: 1.6630	 V