In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time

In [2]:
n_train = 60000
n_test  = 10000 
batch_size = 64
n_epochs = 20

In [3]:
def load_dataset(batch_size):
    """
    Loads train & test sets from MNIST with user-specified sizes.

    Args:
        n_train (int): Desired number of samples in the training set.
        n_test (int): Desired number of samples in the testing set.
        batch_size (int): Batch size for the DataLoaders.

    Returns:
        tuple: (train_loader, test_loader) where each loader is a
               torch.utils.data.DataLoader.
    """
    # Define transformations for the dataset
    transform = transforms.Compose([transforms.ToTensor(), 
                                    transforms.Normalize((0.5,), (0.5,)),
                                    transforms.Lambda(lambda img: F.interpolate(img.unsqueeze(0), size=(14, 14), 
                                        mode='bilinear', align_corners=False).squeeze(0))])

    train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

    # Create DataLoaders for training and testing sets
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [4]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()

        # Convolutional layer (Maps 1 input channel to 4 output channels)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=2, stride=2)

        # Calculate output size: [(W - F + 2P) / S] + 1 = [(14 - 2 + 0) / 2] + 1 = 7
        # The output of our conv layer will be 7x7x4 (4 output channels)

        # Fully connected layer to perform the final classification
        self.fc1 = nn.Linear(7 * 7 * 4, 10)  # Assuming 10 output classes

    def forward(self, x):
        x = F.relu(self.conv1(x))  # Apply convolution and ReLU activation
        x = x.view(-1, 7 * 7 * 4)  # Flatten for the fully connected layer
        x = self.fc1(x)
        x = F.softmax(x, dim=1)  # Apply softmax
        return x

model = SimpleCNN()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")  # Use GPU if available
#model.to(device)
train_loader, test_loader = load_dataset(batch_size)

In [5]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)  # Reduce if val. loss doesn't improve
loss_fn = F.cross_entropy

In [6]:
for epoch in range(n_epochs):

    # ---- Training Phase ----
    running_loss = 0.0 
    correct = 0  
    total = 0
    train_losses = []
    test_losses = []
    start_time = time.time()
    for i, (images, labels) in enumerate(train_loader):
        # --- 2. Forward Pass ---
        outputs = model(images)

        # --- 3. Loss Calculation ---
        loss = loss_fn(outputs, labels)
        train_losses.append(loss.item())
        # --- 4. Backpropagation and Optimization ---
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()  
        train_losses.append(loss.item())  # Store loss for each batch

        # Accuracy calculation for the epoch
        _, predicted = torch.max(outputs.data, 1) 
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    model.eval()  # Set model to evaluation mode
    val_loss = 0.0
    val_corrects = 0
    
    with torch.no_grad():  # Disable gradients for evaluation
        val_correct = 0
        val_total = 0

        for images, labels in test_loader:
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            test_losses.append(loss.item())

            _, predicted = torch.max(outputs.data, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    val_loss = sum(test_losses) / len(test_losses)
    val_acc = 100 * val_correct / val_total 

    # Print statistics for the epoch
    epoch_loss = sum(train_losses) / len(train_losses)  
    epoch_acc = 100 * correct / total 

    scheduler.step(val_loss) 

    # Periodic Evaluation and Logging
    print(f'[i] Epoch {epoch + 1}\n\tLoss: {epoch_loss:.4f}\tAccuracy: {epoch_acc :.2f}%\n\tVal Loss: {val_loss:.4f}\t Val Accuracy: {val_acc :.2f}%\n\tElapsed Time: {time.time() - start_time:.2f}s')

[i] Epoch 1
	Loss: 1.7061	Accuracy: 76.20%
	Val Loss: 1.6737	 Val Accuracy: 78.75%
	Elapsed Time: 6.65s
[i] Epoch 2
	Loss: 1.6623	Accuracy: 80.01%
	Val Loss: 1.6641	 Val Accuracy: 79.75%
	Elapsed Time: 6.34s
[i] Epoch 3
	Loss: 1.6575	Accuracy: 80.40%
	Val Loss: 1.6612	 Val Accuracy: 80.01%
	Elapsed Time: 6.53s
[i] Epoch 4
	Loss: 1.6549	Accuracy: 80.62%
	Val Loss: 1.6584	 Val Accuracy: 80.28%
	Elapsed Time: 6.57s
[i] Epoch 5
	Loss: 1.6523	Accuracy: 80.88%
	Val Loss: 1.6608	 Val Accuracy: 80.05%
	Elapsed Time: 6.55s
[i] Epoch 6
	Loss: 1.6510	Accuracy: 81.00%
	Val Loss: 1.6608	 Val Accuracy: 79.96%
	Elapsed Time: 6.52s
[i] Epoch 7
	Loss: 1.6505	Accuracy: 81.06%
	Val Loss: 1.6555	 Val Accuracy: 80.49%
	Elapsed Time: 6.69s
[i] Epoch 8
	Loss: 1.6485	Accuracy: 81.23%
	Val Loss: 1.6549	 Val Accuracy: 80.59%
	Elapsed Time: 6.63s
[i] Epoch 9
	Loss: 1.6316	Accuracy: 82.90%
	Val Loss: 1.6355	 Val Accuracy: 82.44%
	Elapsed Time: 6.54s
[i] Epoch 10
	Loss: 1.6169	Accuracy: 84.35%
	Val Loss: 1.6293	 V