In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [3]:
# ---- Hyper-parameters -------------------------------------------------
input_size = 28    # side length of the square input images (28 x 28)
num_class = 10     # number of target classes
num_epochs = 3     # full passes over the training set
batch_size = 64    # number of samples in each mini-batch

# ---- Datasets ---------------------------------------------------------
# Both splits share the same storage root, tensor conversion and download
# flag; only the `train` switch differs between them.
_mnist_kwargs = dict(root="data", transform=transforms.ToTensor(), download=True)
train_dataset = datasets.MNIST(train=True, **_mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **_mnist_kwargs)

# ---- Data loaders -----------------------------------------------------
# Only the training split is shuffled; the test split keeps a fixed order.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|████████████████████████████| 9912422/9912422 [00:01<00:00, 5419813.07it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|████████████████████████████████| 28881/28881 [00:00<00:00, 3491847.85it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|████████████████████████████| 1648877/1648877 [00:00<00:00, 5195344.10it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████| 4542/4542 [00:00<00:00, 4015710.11it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [8]:
# create model
class CNN(nn.Module):
    """Small two-stage convolutional classifier for 28 x 28 images.

    Each stage is Conv2d(kernel 5, stride 1, padding 2) -> ReLU -> MaxPool2d(2).
    The padding keeps the spatial size through the convolution and each
    pooling step halves it, so a 28 x 28 input becomes 14 x 14 and then 7 x 7.
    The final linear layer therefore expects 32 * 7 * 7 flattened features.

    Args:
        num_classes: Width of the output layer (number of logits per sample).
            Defaults to 10, the value that was previously hard-coded.
        in_channels: Number of channels in the input images. Defaults to 1.
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        # (in_channels, 28, 28) -> (16, 28, 28) -> (16, 14, 14)
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=16,          # 16 feature maps (number of kernels)
                kernel_size=5,            # 5 x 5 filter
                stride=1,
                padding=2,                # pads 28 -> 32 so the 5 x 5 conv yields 28 again
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # 2 x 2 pooling: 28 / 2 = 14
        )
        # (16, 14, 14) -> (32, 14, 14) -> (32, 7, 7)
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Fully connected output layer producing one logit per class.
        self.out = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x):
        """Return logits of shape (batch_size, num_classes) for a batch ``x``."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Keep the batch dimension and flatten the rest to (batch_size, 32 * 7 * 7).
        x = torch.flatten(x, start_dim=1)
        return self.out(x)

In [9]:
# define accuracy function
def accuracy(predictions, labels):
    """Count how many samples in a batch were classified correctly.

    Args:
        predictions: Tensor of per-class scores with classes along dim 1;
            the index of the largest score is taken as the predicted class.
        labels: Tensor of target class indices; it is viewed with the same
            shape as the predicted indices before comparison.

    Returns:
        A tuple ``(num_correct, batch_len)``: ``num_correct`` is a 0-dim
        tensor holding the number of matches, ``batch_len`` is ``len(labels)``.
    """
    _, predicted_idx = predictions.data.max(dim=1)
    targets = labels.data.view_as(predicted_idx)
    num_correct = (predicted_idx == targets).sum()
    return num_correct, len(labels)

In [11]:
# create model instance
model = CNN()
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Run a full pass over test_loader every `eval_every` training batches.
eval_every = 100

# train loop
for epoch in range(num_epochs):
    # (num_correct, batch_len) pairs for every training batch seen this epoch
    train_acc_list = []

    for idx, (train_data, label) in enumerate(train_loader):
        # The periodic evaluation below switches to eval mode, so switch back
        # to training mode at the start of every step.
        model.train()
        output = model(train_data)
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_acc_list.append(accuracy(output, label))

        if idx % eval_every == 0:
            model.eval()
            val_acc_list = []

            # Evaluation needs no gradients; disabling autograd avoids
            # building a graph for every test batch. Separate variable names
            # keep the training batch's `output` / `label` from being shadowed.
            with torch.no_grad():
                for val_data, val_label in test_loader:
                    val_output = model(val_data)
                    val_acc_list.append(accuracy(val_output, val_label))

            # calculate the acc: (total correct, total samples) for each split
            train_r = (sum(tup[0] for tup in train_acc_list), sum(tup[1] for tup in train_acc_list))
            val_r = (sum(tup[0] for tup in val_acc_list), sum(tup[1] for tup in val_acc_list))

            print(f"Epoch: {epoch}, {idx * batch_size}/{len(train_loader.dataset)}| Loss: {loss.item()} | Acc of train data: {100. * train_r[0].item() / train_r[1]} | Acc of val data: {100. * val_r[0].item() / val_r[1]}")

Epoch: 0, 0/60000| Loss: 2.3169734477996826 | Acc of train data: 3.125 | Acc of val data: 9.87
Epoch: 0, 6400/60000| Loss: 0.3007182478904724 | Acc of train data: 76.39232673267327 | Acc of val data: 92.5
Epoch: 0, 12800/60000| Loss: 0.13541360199451447 | Acc of train data: 84.81032338308458 | Acc of val data: 94.9
Epoch: 0, 19200/60000| Loss: 0.14530989527702332 | Acc of train data: 88.35132890365449 | Acc of val data: 96.37
Epoch: 0, 25600/60000| Loss: 0.0494905561208725 | Acc of train data: 90.32107231920199 | Acc of val data: 97.23
Epoch: 0, 32000/60000| Loss: 0.029833775013685226 | Acc of train data: 91.59181636726547 | Acc of val data: 97.18
Epoch: 0, 38400/60000| Loss: 0.06421170383691788 | Acc of train data: 92.47868136439268 | Acc of val data: 97.95
Epoch: 0, 44800/60000| Loss: 0.026181388646364212 | Acc of train data: 93.23956847360913 | Acc of val data: 98.12
Epoch: 0, 51200/60000| Loss: 0.07353786379098892 | Acc of train data: 93.79486579275905 | Acc of val data: 98.16
Epoc