# Convolutional Neural Networks with the MNIST handwriting dataset

Deep Neural Networks can have an exceedingly large number of parameters, which makes them difficult to train and stabilize.

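In the case of image data, there is no reason for every pixel to be connected to every neuron in the next layer: the features that matter (edges, strokes, corners) are local and carried by small groups of neighboring pixels. Restricting each neuron to a small patch of the image and reusing the same weights at every position (convolution) greatly reduces the number of parameters, and pooling information from groups of neighboring pixels makes the network less sensitive to small shifts of the input.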
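To put rough numbers on this, the sketch below compares a convolutional layer with 32 filters of size 3x3 (the same settings as the first layer of the model defined later) against a hypothetical fully connected layer mapping the 28 x 28 input to the same 32 x 28 x 28 outputs. The `count_params` helper is our own illustration, not a PyTorch function.

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    """Illustrative helper: total number of trainable scalars (weights + biases)."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

# Convolution: each of the 32 filters sees only a 3x3 patch, and the same
# 3x3 weights are shared across every position in the image.
conv = nn.Conv2d(1, 32, kernel_size=3, padding=1)

# Hypothetical fully connected alternative: every one of the 28*28 input pixels
# is connected to every one of the 32*28*28 outputs.
dense = nn.Linear(28 * 28, 32 * 28 * 28)

print(count_params(conv))   # 32*1*3*3 weights + 32 biases
print(count_params(dense))  # 784*25088 weights + 25088 biases
```

Both layers produce the same number of output values, but the convolution has 320 parameters against 19,694,080 for the dense layer, because its weights are local and shared across positions.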

In this notebook we will introduce a Convolutional Neural Network trained to recognize and classify handwritten digits from the MNIST dataset.

We'll also continue to evolve the "standard" PyTorch approach for setting up CNNs (or other complex DNNs) and training on large datasets.

## Load the MNIST dataset

As always, we start by checking the data.
In this case, we load the MNIST dataset as raw images, without any transformation, that is, without rescaling or normalizing the individual pixel values.

In [None]:
from torchvision import datasets, transforms

# For visualization: no transform needed
vis_dataset = datasets.MNIST('./data', train=True, download=True, transform=None)


## Plot example data for checks

The raw pixel intensities (integers from 0 to 255) can be displayed with a grayscale color map. Each image is 28 x 28 pixels.

In [None]:
import matplotlib.pyplot as plt

# Create a figure with subplots
fig, axes = plt.subplots(2, 5, figsize=(12, 6))
fig.suptitle('MNIST Dataset Examples', fontsize=16)

# Display the first 10 samples
for i, ax in enumerate(axes.flat):
    # Get the i-th image (a PIL image) and its label
    img, label = vis_dataset[i]

    # Display the image
    ax.imshow(img, cmap='gray')
    ax.set_title(f'Label: {label}')
    ax.axis('off')

plt.tight_layout()
plt.show()


## Define network model

For the first time, we implement the NN model as a separate class, made up of an initialization function (`__init__`) that defines the layers and a `forward` function that specifies how the input passes through them. We got away without such a class for the simple feedforward networks, but using a class is best practice in PyTorch, and we will need it for more complex models.

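The model is made up of:
- two convolutional layers, each followed by a ReLU activation
- a 2x2 max-pooling layer, applied after each convolutional layer
- two fully connected layers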

Note that the flattening happens between the last pooling layer and the first fully connected layer.
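The `64 * 7 * 7` input size of `fc1` in the model below comes from tracing one 28 x 28 image through the two conv/pool blocks: a 3x3 kernel with `padding=1` keeps the spatial size unchanged, and each 2x2 max-pool halves it (28 → 14 → 7). A quick check, sketched here with the convolutional part of `SimpleCNN` rebuilt as an `nn.Sequential` so the snippet runs on its own, is to push a dummy batch through and print the shapes. `torch.flatten(out, start_dim=1)` does the same job as the `x.view(-1, 64 * 7 * 7)` call in `forward`.

```python
import torch
import torch.nn as nn

# Same layer settings as the convolutional part of SimpleCNN
features = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, padding=1),   # (N, 1, 28, 28) -> (N, 32, 28, 28)
    nn.ReLU(),
    nn.MaxPool2d(2, 2),                           # -> (N, 32, 14, 14)
    nn.Conv2d(32, 64, kernel_size=3, padding=1),  # -> (N, 64, 14, 14)
    nn.ReLU(),
    nn.MaxPool2d(2, 2),                           # -> (N, 64, 7, 7)
)

x = torch.zeros(8, 1, 28, 28)            # dummy batch of 8 grayscale images
out = features(x)
flat = torch.flatten(out, start_dim=1)   # keep the batch dimension, flatten the rest
print(out.shape)   # torch.Size([8, 64, 7, 7])
print(flat.shape)  # torch.Size([8, 3136]), and 3136 == 64 * 7 * 7
```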

In [None]:
import torch
import torch.nn as nn

# Define the CNN architecture
# Best practice with PyTorch is to wrap this in a class.
# (It doesn't matter for feedforward networks, but it will matter
#  for more complicated networks.)
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # __init__ only defines the layers; the order in which they are applied is set in forward()
        # First convolutional layer: 1 input channel (grayscale), 32 output channels, 3x3 kernel
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        # Max pooling layer: 2x2 window
        self.pool = nn.MaxPool2d(2, 2)
        # Second convolutional layer: 32 input channels, 64 output channels, 3x3 kernel
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Fully connected layers
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        # Activation function: ReLU has no parameters, so a single instance
        # can be reused after every layer in forward() below
        self.relu = nn.ReLU()

    def forward(self, x):
        # Defines the forward pass: the order in which data flows through the layers above
        # First conv block: conv -> relu -> pool
        x = self.pool(self.relu(self.conv1(x)))  # 28x28 -> 14x14
        # Second conv block: conv -> relu -> pool
        x = self.pool(self.relu(self.conv2(x)))  # 14x14 -> 7x7
        # Flatten for fully connected layers
        x = x.view(-1, 64 * 7 * 7)
        # First fully connected layer with ReLU
        x = self.relu(self.fc1(x))
        # Output layer (no activation, will use CrossEntropyLoss)
        x = self.fc2(x)
        return x


## Set up training and testing

We define the training and testing loops as functions so that we can call them once per epoch.
Our loss function is CrossEntropyLoss (why?), and we take the output class with the highest score as our prediction.
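One way to see why `CrossEntropyLoss` fits here: it expects raw, unnormalized scores (logits), applies `log_softmax` internally, and then takes the negative log-probability of the correct class. That is why `SimpleCNN` has no activation on its output layer. The sketch below checks this equivalence on made-up logits and labels, and also shows the `argmax` prediction step used in the `test` function below.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # made-up scores: batch of 4, 10 digit classes
target = torch.tensor([3, 7, 0, 9])    # made-up true labels

# Built-in loss (averages over the batch by default)
loss_builtin = nn.CrossEntropyLoss()(logits, target)

# Same thing by hand: log-softmax, pick the log-probability of the true class,
# negate, and average over the batch
log_probs = F.log_softmax(logits, dim=1)
loss_manual = -log_probs[torch.arange(len(target)), target].mean()

print(torch.allclose(loss_builtin, loss_manual))  # True

# Prediction: the class with the largest score. Softmax is monotonic, so the
# argmax of the logits equals the argmax of the probabilities.
pred = logits.argmax(dim=1)
print(pred.shape)  # torch.Size([4])
```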

In [None]:
# Training function
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = nn.CrossEntropyLoss()(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Training Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}')

# Testing function
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.CrossEntropyLoss()(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')



## Loading and training

There are 3 things of note here:
- We use the GPU if one is available. This can speed up training dramatically.
- The training and testing data are transformed: ToTensor rescales the pixel values from 0-255 to [0, 1], and Normalize then subtracts the MNIST mean (0.1307) and divides by the standard deviation (0.3081), putting the values in a convenient range for the input layer. We did not need to do this to visualize the data.
- The training batch size (64) is larger than we have used in the past. The batch size for the testing data does not matter at all. (Or does it?)

In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# For training and testing: convert to tensors scaled to [0, 1], then normalize
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])

# Load MNIST dataset
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Create and load model
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train for 5 epochs, evaluating on the test set after each one
num_epochs = 5
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)


# Save the model
torch.save(model.state_dict(), 'mnist_cnn.pth')
print('Model saved to mnist_cnn.pth')

## Saving the model (CNN)

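The last lines of the cell above save the model's learned parameters (its state_dict) to a file, so that the result of training (paid for with many GPU cycles and energy units!) is not lost. This is useful for preserving the state of training, for continuing training later with more data, and for using the model for prediction (also known as inference) without retraining it.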
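A minimal sketch of the reverse step: rebuild the same architecture, load the saved weights into it, and switch to evaluation mode before predicting. To keep the snippet self-contained it uses a small stand-in network (`make_model` is a hypothetical helper) and a temporary file instead of `SimpleCNN` and `mnist_cnn.pth`; with the real model you would create `SimpleCNN()` and load `'mnist_cnn.pth'`. Note that a state_dict holds only the model's weights, so this sketch also saves `optimizer.state_dict()` (for Adam, its running moment estimates), which you would need to resume training exactly where it stopped.

```python
import os
import tempfile
import torch
import torch.nn as nn

def make_model() -> nn.Module:
    # Hypothetical stand-in for SimpleCNN; any nn.Module works the same way
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

trained = make_model()
optimizer = torch.optim.Adam(trained.parameters(), lr=0.001)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")

    # Save weights plus optimizer state (the latter is needed to resume training)
    torch.save({"model": trained.state_dict(),
                "optimizer": optimizer.state_dict()}, path)

    # Later or elsewhere: rebuild the architecture, then load the weights into it
    checkpoint = torch.load(path, map_location="cpu")
    restored = make_model()
    restored.load_state_dict(checkpoint["model"])
    restored.eval()  # evaluation mode (matters for layers like dropout or batch norm)

    # Inference on a dummy batch: both models give identical outputs
    x = torch.randn(5, 1, 28, 28)
    with torch.no_grad():
        same = torch.equal(trained(x), restored(x))
    print(same)  # True
```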

How many parameters does this model have?