In [None]:
import torch
from torch import nn
import torchvision

| Layers | Hyper-parameters |
| :--- | :--- |
| Convolution 1 | Kernel size $= (5, 5, 32)$, SAME padding. Followed by BatchNorm and ReLU. |
| Pooling 1 | Average operation. Kernel size $= (2, 2)$. Stride $= 2$. Padding $= 0$. |
| Convolution 2 | Kernel size $= (5, 5, 32)$, SAME padding. Followed by BatchNorm and ReLU. |
| Pooling 2 | Average operation. Kernel size $= (2, 2)$. Stride $= 2$. Padding $= 0$. |
| Convolution 3 | Kernel size $= (5, 5, 64)$, SAME padding. Followed by BatchNorm and ReLU. |
| Pooling 3 | Average operation. Kernel size $= (2, 2)$. Stride $= 2$. Padding $= 0$. |
| Fully Connected 1 | Output channels $= 64$. Followed by BatchNorm and ReLU. |
| Fully Connected 2 | Output channels $= 10$. Followed by Softmax. |

In [None]:
# Pick the compute device once: the GPU when CUDA is usable, the CPU otherwise.
use_cda = torch.cuda.is_available()
device = torch.device("cuda") if use_cda else torch.device("cpu")


class DigitClassifcation(nn.Module):
    """Convolutional digit classifier with three conv stages and two linear layers.

    Each conv stage is ``Conv2d(5x5, padding=2) -> BatchNorm2d -> ReLU ->
    AvgPool2d(2, stride 2)`` with 32, 32 and 64 output channels respectively.
    The head is ``Linear(576, 64) -> BatchNorm1d -> ReLU -> Linear(64, 10) ->
    Softmax`` over dimension 1, so ``forward`` returns per-class values that
    sum to 1 for every row.

    The flatten size ``64 * 3 * 3`` is hard-coded; it corresponds to
    single-channel 28x28 inputs (28 -> 14 -> 7 -> 3 after the three poolings).
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels) of each convolutional stage, in order.
        stage_channels = ((1, 32), (32, 32), (32, 64))
        conv_layers = []
        for in_ch, out_ch in stage_channels:
            conv_layers.extend(
                [
                    # 5x5 kernel with padding 2 keeps the spatial size ("SAME").
                    nn.Conv2d(in_ch, out_ch, kernel_size=5, padding=2),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(),
                    # Halves height and width.
                    nn.AvgPool2d(kernel_size=2, stride=2),
                ]
            )
        self.seq = nn.Sequential(*conv_layers)

        self.fc = nn.Sequential(
            nn.Linear(64 * 3 * 3, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Run the conv stages, flatten to ``(-1, 576)`` and apply the head."""
        feature_maps = self.seq(x)
        flattened = feature_maps.view(-1, 64 * 3 * 3)
        return self.fc(flattened)


# Build the classifier and place its parameters on the selected device.
model = DigitClassifcation()
model = model.to(device)

In [None]:
from torch.utils.data import DataLoader

# Preprocessing applied to every image: convert to a float tensor, then map
# values with (x - 0.5) / 0.5, i.e. the same affine rescale as (x - 0.5) * 2.
# Normalize is used instead of a Lambda wrapping a Python lambda because it is
# the built-in transform for this and, unlike a lambda, can be pickled when
# DataLoader workers are spawned as separate processes.
normalize = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST train/test splits, downloaded into ./data on first use.
train_set = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=normalize
)
test_set = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=normalize
)

# Shuffle only the training data; evaluation order does not matter.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

In [None]:
import matplotlib.pyplot as plt


def _plot_training_history(train_losses, test_losses, test_accuracies):
    """Plot per-epoch train/test loss (left) and test accuracy (right)."""
    # Derive the x axis from what was actually recorded, so the plot is
    # correct whether or not training stopped early.
    epochs = range(1, len(train_losses) + 1)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    loss_ax.plot(epochs, train_losses, label="Train Loss")
    loss_ax.plot(epochs, test_losses, label="Test Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Training and Testing Loss")
    loss_ax.legend()

    acc_ax.plot(epochs, test_accuracies, label="Test Accuracy", color="orange")
    acc_ax.set_xlabel("Epoch")
    acc_ax.set_ylabel("Accuracy (%)")
    acc_ax.set_title("Test Accuracy")
    acc_ax.legend()

    fig.tight_layout()
    plt.show()


def train(num_epochs=10, lr=0.0005, target_accuracy=99):
    """Train the global ``model`` and plot its loss and accuracy curves.

    Uses the module-level ``model``, ``train_loader`` and ``test_loader``.
    After every epoch the model is evaluated on ``test_loader``; training
    stops early once the test accuracy reaches ``target_accuracy``.

    Args:
        num_epochs: Maximum number of passes over ``train_loader``.
        lr: Learning rate for the Adam optimizer.
        target_accuracy: Test accuracy (in percent) that triggers early stop.

    Returns:
        None. The recorded history is shown as a matplotlib figure. When at
        least one epoch ran, the model is left in eval mode.
    """
    # Batches must live on the same device as the model's weights, otherwise
    # the forward pass fails when the model sits on a GPU.
    model_device = next(model.parameters()).device

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # The model already ends in Softmax, so it outputs probabilities rather
    # than logits. CrossEntropyLoss would apply a second (log-)softmax on top;
    # the correct cross-entropy here is NLLLoss on the log of the outputs.
    criterion = nn.NLLLoss()

    def cross_entropy(probabilities, targets):
        # Clamp before the log so a probability that underflowed to 0 yields
        # a large finite loss instead of inf/NaN.
        return criterion(torch.log(probabilities.clamp_min(1e-12)), targets)

    train_losses = []
    test_losses = []
    test_accuracies = []
    for epoch in range(num_epochs):
        # Evaluation below switches to eval mode; switch back every epoch so
        # BatchNorm keeps using batch statistics while training.
        model.train()
        losses_train = 0.0
        for inputs, labels in train_loader:
            inputs = inputs.to(model_device)
            labels = labels.to(model_device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size: the last batch may be smaller.
            losses_train += loss.item() * inputs.size(0)
        print("Epoch %d/(up to) %d" % (epoch + 1, num_epochs))
        train_losses.append(losses_train / len(train_loader.dataset))

        model.eval()
        correct = 0
        total = 0
        losses_test = 0.0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = inputs.to(model_device)
                labels = labels.to(model_device)
                outputs = model(inputs)
                predicted = outputs.argmax(dim=1)
                total += len(predicted)
                correct += (predicted == labels).sum().item()
                # .item() keeps the history as plain floats (plottable even
                # when the tensors live on a GPU).
                losses_test += cross_entropy(outputs, labels).item() * inputs.size(0)

        test_losses.append(losses_test / len(test_loader.dataset))
        test_accuracies.append(100 * correct / total)
        if test_accuracies[-1] >= target_accuracy:
            break

    _plot_training_history(train_losses, test_losses, test_accuracies)

In [None]:
# Train the global model; train() shows the loss/accuracy plots when it finishes.
train()

In [None]:
# Compile the trained model to TorchScript and serialize it to disk.
m = torch.jit.script(model)
m.save("mnist_model.pth")