### CNN on MNIST

In [1]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision import datasets

from matplotlib import pyplot as plt

import pandas as pd
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Hyper
n_epochs = 10
batch_size = 200
lr = 0.001

In [3]:
train_data = datasets.MNIST(root='./data',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)

test_data  = datasets.MNIST(root='./data',
                            train=False,
                            transform=transforms.ToTensor(),
                            download=True)

train_load = DataLoader(train_data,
                        batch_size=batch_size,
                        shuffle=True)

test_load  = DataLoader(test_data,
                        batch_size=batch_size,
                        shuffle=True)

In [4]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)   # N 64  28 28
        self.pool1 = nn.MaxPool2d(2)                  # N 64  14 14
        self.conv2 = nn.Conv2d(64, 128, 3, padding=0) # N 128 12 12
        self.pool2 = nn.MaxPool2d(2)                  # N 128  6  6
        self.lin1 = nn.Linear(128*6*6, 1024)
        self.lin2 = nn.Linear(1024, 10)

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.pool1(out)
        out = self.relu(self.conv2(out))
        out = self.pool2(out)
        out = self.flatten(out)
        out = self.lin1(out)
        out = self.lin2(out)
        return out

In [5]:
model = CNN().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(n_epochs):
    model.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_load):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        
        predict = model(images)
        loss = criterion(predict, labels)
        loss.backward()
        
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_load)

    print(f"Epoch: [{epoch+1}/{n_epochs}], loss: {epoch_loss}")

Epoch: [1/10], loss: 0.17029022024013102
Epoch: [2/10], loss: 0.047724027594861884
Epoch: [3/10], loss: 0.03373051142552867
Epoch: [4/10], loss: 0.02725422229268588
Epoch: [5/10], loss: 0.020018086722120642
Epoch: [6/10], loss: 0.01858218543425513
Epoch: [7/10], loss: 0.012712505349967007
Epoch: [8/10], loss: 0.01278484737413237
Epoch: [9/10], loss: 0.012277075772726676
Epoch: [10/10], loss: 0.010265988356695743


In [6]:
import time

with torch.no_grad():
    model.eval()
    correct = 0
    
    for images, labels in test_load:  
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        _, preds = torch.max(outputs, 1)

        ## Plotting incorrect images
        # incorrect_indices = torch.where(preds != labels)[0].tolist()
        # 
        # if len(incorrect_indices) > 0:
        #     plt.ion()
        #     for i in incorrect_indices:
        #         
        #         plt.imshow(images[i][0].cpu(), cmap='gray')
        #         plt.title(f'Pred: {preds[i].item()}, True: {labels[i].item()}')
        #         plt.draw()
        #         plt.pause(1)
        #         plt.clf()
        #     plt.ioff()
        
        correct += (preds == labels).sum().item()

    acc = correct / len(test_load.dataset)
    print(f"Accuracy: {acc * 100:.2f}%")

Accuracy: 99.06%
