# Image Classification with MNIST

#### Source : https://github.com/samcw/ResNet18-Pytorch

## Set Up

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## RestNet18 Architecture

### ResidualBlock: Implements the core residual block of ResNet

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        # Main path: two 3x3 convolutions with BatchNorm and ReLU
        self.left = nn.Sequential(
            nn.Conv2d(
                inchannel,
                outchannel,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(outchannel),
        )
        # Shortcut path: matches dimensions if necessary
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    inchannel, outchannel, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(outchannel),
            )

    def forward(self, x):
        # Combine main path and shortcut, followed by ReLU activation
        out = self.left(x)
        out = out + self.shortcut(x)
        out = F.relu(out)
        return out

### ResNet: Constructs the full ResNet architecture

In [3]:
class ResNet(nn.Module):
    def __init__(self, ResidualBlock, num_classes=10):
        super(ResNet, self).__init__()
        self.inchannel = 64
        # Initial convolution layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        # Stacked residual layers
        self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        # Fully connected layer for classification
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        # Creates a layer of residual blocks
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        # Forward pass through ResNet
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)  # Global average pooling
        out = out.view(out.size(0), -1)  # Flatten features
        out = self.fc(out)  # Fully connected layer
        return out

### Return a ResNet with ResidualBlock as the building block

In [4]:
def ResNet18():
    return ResNet(ResidualBlock)

## Loading DataSet

### Import more necessary libraries for loading MNIST

In [5]:
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

### Configuration GPU support with apple device

In [6]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

### Setting Hyperparameters

In [7]:
EPOCH = 10
BATCH_SIZE = 128
LR = 0.01

### Data Transformation for Training and Testing

In [8]:
# Transformations for MNIST dataset
transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            (0.1307,), (0.3081,)
        ),  # Normalization for MNIST (mean and std for grayscale)
    ]
)

transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),  # Normalization for MNIST
    ]
)

In [9]:
# Load MNIST dataset with the updated transformations
trainset = torchvision.datasets.MNIST(
    root="../data", train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
)

testset = torchvision.datasets.MNIST(
    root="../data", train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2
)

In [10]:
# Define classes (digits 0-9 for MNIST)
classes = ("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")

# Initialize the model and move it to the device (GPU/CPU)
net = ResNet18().to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()  # For multi-class classification
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)

## Training and Testing 

### Pre-training Epochs


In [11]:
pre_epoch = 0
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0


Epoch: 1

Epoch: 2

Epoch: 3

Epoch: 4

Epoch: 5

Epoch: 6

Epoch: 7

Epoch: 8

Epoch: 9

Epoch: 10


### Testing the Model at the End of Each Epoch

In [14]:
for i, data in enumerate(trainloader, 0):
    # Prepare dataset
    length = len(trainloader)
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)\
    # Zero the parameter gradients
    optimizer.zero_grad()

    # Forward & backward
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()

    # Optimize
    optimizer.step()

    # Print loss and accuracy per batch
    sum_loss += loss.item()
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += predicted.eq(labels.data).cpu().sum()
    print(
        "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
        % (
            epoch + 1,
            (i + 1 + epoch * length),
            sum_loss / (i + 1),
            100.0 * correct / total,
        )
    )
    print('Waiting for Testing...')
    with torch.no_grad():
        correct = 0
        total = 0
        for data in testloader:
            net.eval()
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Test accuracy: %.3f%%' % (100 * correct / total))

[epoch:10, iter:4222] Loss: 2.418 | Acc: 11.719% 
Waiting for Testing...
Test accuracy: 8.560%
[epoch:10, iter:4223] Loss: 2.360 | Acc: 8.560% 
Waiting for Testing...
Test accuracy: 9.320%
[epoch:10, iter:4224] Loss: 2.341 | Acc: 9.261% 
Waiting for Testing...
Test accuracy: 10.220%
[epoch:10, iter:4225] Loss: 2.331 | Acc: 10.239% 
Waiting for Testing...
Test accuracy: 10.860%
[epoch:10, iter:4226] Loss: 2.325 | Acc: 10.881% 
Waiting for Testing...
Test accuracy: 11.660%
[epoch:10, iter:4227] Loss: 2.321 | Acc: 11.602% 
Waiting for Testing...
Test accuracy: 11.860%
[epoch:10, iter:4228] Loss: 2.319 | Acc: 11.858% 
Waiting for Testing...
Test accuracy: 13.060%
[epoch:10, iter:4229] Loss: 2.317 | Acc: 12.984% 
Waiting for Testing...
Test accuracy: 18.130%
[epoch:10, iter:4230] Loss: 2.315 | Acc: 18.118% 
Waiting for Testing...
Test accuracy: 17.880%
[epoch:10, iter:4231] Loss: 2.314 | Acc: 17.871% 
Waiting for Testing...
Test accuracy: 18.240%
[epoch:10, iter:4232] Loss: 2.313 | Acc: 18.

### Training Completion

In [15]:
print("Training finished. Total epochs: %d" % EPOCH)

Training finished. Total epochs: 10


In [8]:
# Train
pre_epoch = 0
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0
    for i, data in enumerate(trainloader, 0):
        # Prepare dataset
        length = len(trainloader)
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward & backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # Optimize
        optimizer.step()

        # Print loss and accuracy per batch
        sum_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + epoch * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )

    # Evaluate accuracy with test dataset at the end of each epoch
    print("Waiting for Testing...")
    with torch.no_grad():
        correct = 0
        total = 0
        for data in testloader:
            net.eval()
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print("Test accuracy: %.3f%%" % (100 * correct / total))

print("Training finished. Total epochs: %d" % EPOCH)


Epoch: 1
[epoch:1, iter:1] Loss: 2.371 | Acc: 7.812% 
[epoch:1, iter:2] Loss: 2.330 | Acc: 10.938% 
[epoch:1, iter:3] Loss: 2.287 | Acc: 14.844% 
[epoch:1, iter:4] Loss: 2.247 | Acc: 17.188% 
[epoch:1, iter:5] Loss: 2.201 | Acc: 18.906% 
[epoch:1, iter:6] Loss: 2.152 | Acc: 21.875% 
[epoch:1, iter:7] Loss: 2.123 | Acc: 23.996% 
[epoch:1, iter:8] Loss: 2.093 | Acc: 25.684% 
[epoch:1, iter:9] Loss: 2.037 | Acc: 29.167% 
[epoch:1, iter:10] Loss: 1.983 | Acc: 32.031% 
[epoch:1, iter:11] Loss: 1.918 | Acc: 35.298% 
[epoch:1, iter:12] Loss: 1.860 | Acc: 37.370% 
[epoch:1, iter:13] Loss: 1.793 | Acc: 40.565% 
[epoch:1, iter:14] Loss: 1.724 | Acc: 43.192% 
[epoch:1, iter:15] Loss: 1.657 | Acc: 46.042% 
[epoch:1, iter:16] Loss: 1.589 | Acc: 48.730% 
[epoch:1, iter:17] Loss: 1.533 | Acc: 50.506% 
[epoch:1, iter:18] Loss: 1.475 | Acc: 52.517% 
[epoch:1, iter:19] Loss: 1.430 | Acc: 54.112% 
[epoch:1, iter:20] Loss: 1.382 | Acc: 55.742% 
[epoch:1, iter:21] Loss: 1.332 | Acc: 57.403% 
[epoch:1, ite

## Summary Of Model


### Model Architecture Summary

In [19]:
import torchsummary


# Function to display the model architecture
def model_summary(model, input_size=(1, 28, 28)):
    """
    Displays a detailed summary of the model architecture.

    Parameters:
    - model: The trained model instance.
    - input_size: The input tensor size (default: (1, 28, 28) for MNIST images).
    """
    # Move model to CPU for compatibility with torchsummary
    model.to("cpu")

    # Use torchsummary to display the model summary
    torchsummary.summary(model, input_size=input_size, device="cpu")

    # Return model back to original device (MPS) if needed
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)


# Call the model_summary function to view the ResNet18 architecture
model_summary(net)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 28, 28]             576
       BatchNorm2d-2           [-1, 64, 28, 28]             128
              ReLU-3           [-1, 64, 28, 28]               0
            Conv2d-4           [-1, 64, 28, 28]          36,864
       BatchNorm2d-5           [-1, 64, 28, 28]             128
              ReLU-6           [-1, 64, 28, 28]               0
            Conv2d-7           [-1, 64, 28, 28]          36,864
       BatchNorm2d-8           [-1, 64, 28, 28]             128
     ResidualBlock-9           [-1, 64, 28, 28]               0
           Conv2d-10           [-1, 64, 28, 28]          36,864
      BatchNorm2d-11           [-1, 64, 28, 28]             128
             ReLU-12           [-1, 64, 28, 28]               0
           Conv2d-13           [-1, 64, 28, 28]          36,864
      BatchNorm2d-14           [-1, 64,

### Training Summary Function

In [21]:
# Function to summarize training results
def training_summary(epoch, loss, accuracy, test_accuracy):
    """
    Displays a summary of the training process and results.

    Parameters:
    - epoch: Total number of epochs completed.
    - loss: Final training loss.
    - accuracy: Final training accuracy (in percentage).
    - test_accuracy: Final test accuracy (in percentage).
    """
    print("\n--- Training Summary ---")
    print(f"Total Epochs: {epoch}")
    print(f"Final Training Loss: {loss:.4f}")
    print(f"Final Training Accuracy: {accuracy:.2f}%")
    print(f"Final Test Accuracy: {test_accuracy:.2f}%")


### Call Training Summary Function

In [23]:
# Assuming the following values are calculated during training
final_loss = sum_loss / len(trainloader)  # Average loss over the last epoch
final_train_accuracy = 100.0 * correct / total  # Final training accuracy
final_test_accuracy = (
    98.21  # Example value for test accuracy (replace with actual computation)
)

# Call the training summary function
training_summary(EPOCH, final_loss, final_train_accuracy, final_test_accuracy)



--- Training Summary ---
Total Epochs: 10
Final Training Loss: 0.8225
Final Training Accuracy: 97.06%
Final Test Accuracy: 98.21%


### Combine Both Summaries

In [32]:
from torchsummary import summary
from tabulate import tabulate


def complete_summary(model, input_size, epoch, loss, accuracy, test_accuracy, device):
    """
    Combines the model architecture summary and training summary into one function,
    with a clean and professional table format for training results.

    Parameters:
    - model: The trained model instance.
    - input_size: The input tensor size for model summary.
    - epoch: Total number of epochs completed.
    - loss: Final training loss.
    - accuracy: Final training accuracy (in percentage).
    - test_accuracy: Final test accuracy (in percentage).
    - device: The device (CPU/GPU) used for training.
    """
    # Display model architecture summary (switch to CPU for compatibility)
    print("=" * 40)
    print("MODEL ARCHITECTURE SUMMARY")
    print("=" * 40)
    summary(
        model.to("cpu"), input_size=input_size, device="cpu"
    )  # Switch to CPU for summary
    model.to(device)  # Move the model back to the original device after summary

    # Prepare training summary table
    training_results = [
        ["Total Epochs", epoch],
        ["Final Training Loss", f"{loss:.4f}"],
        ["Final Training Accuracy", f"{accuracy:.2f}%"],
        ["Final Test Accuracy", f"{test_accuracy:.2f}%"],
    ]

    print("\n" + "=" * 40)
    print("TRAINING PROCESS SUMMARY")
    print("=" * 40)
    print(
        tabulate(training_results, headers=["Metric", "Value"], tablefmt="fancy_grid")
    )


# Call the combined summary function
complete_summary(
    net,
    input_size=(1, 28, 28),
    epoch=EPOCH,
    loss=final_loss,
    accuracy=final_train_accuracy,
    test_accuracy=final_test_accuracy,
    device=device,
)


MODEL ARCHITECTURE SUMMARY
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 28, 28]             576
       BatchNorm2d-2           [-1, 64, 28, 28]             128
              ReLU-3           [-1, 64, 28, 28]               0
            Conv2d-4           [-1, 64, 28, 28]          36,864
       BatchNorm2d-5           [-1, 64, 28, 28]             128
              ReLU-6           [-1, 64, 28, 28]               0
            Conv2d-7           [-1, 64, 28, 28]          36,864
       BatchNorm2d-8           [-1, 64, 28, 28]             128
     ResidualBlock-9           [-1, 64, 28, 28]               0
           Conv2d-10           [-1, 64, 28, 28]          36,864
      BatchNorm2d-11           [-1, 64, 28, 28]             128
             ReLU-12           [-1, 64, 28, 28]               0
           Conv2d-13           [-1, 64, 28, 28]          36,864
      BatchN