In [1]:
# settings / paths
from pathlib import Path

# load data
import pickle
import gzip

# data prep
import torch
from torch.utils.data import TensorDataset, DataLoader

# modeling
import torch.nn as nn
import torch.nn.functional as F

# training & testing
from torch import optim

In [2]:
# Data directory, given relative to the current working directory.
DATA_PATH = Path("..") / "data"
# Gzip-compressed pickle that the next cell opens and unpacks.
PATH = DATA_PATH.joinpath("mnist", "mnist.pkl.gz")

In [3]:
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only point PATH at a copy of the file you trust.
with gzip.open(PATH.as_posix(), "rb") as mnist_file:
    # The pickle holds three entries; the third one is not used here.
    # encoding="latin-1" is presumably needed because the file was written
    # by Python 2 — TODO confirm.
    train_split, valid_split, _ = pickle.load(mnist_file, encoding="latin-1")

x_train, y_train = train_split
x_valid, y_valid = valid_split

In [4]:
# Mini-batch size for the training loader (the validation loader uses twice this).
batch_size = 32

In [5]:
def x_transform(x):
    """Return a view of ``x`` shaped ``(-1, 1, 28, 28)``.

    The leading dimension is inferred from the number of elements; the
    remaining three are fixed. Because ``view`` is used, the result shares
    storage with ``x``.
    """
    trailing_shape = (1, 28, 28)
    return x.view(-1, *trailing_shape)

def y_transform(y):
    """Identity transform for targets: return ``y`` unchanged.

    The previous stub body (``pass``) returned ``None``, so using this
    function as a target transform would have replaced the labels with
    ``None``. Returning the input makes it a safe no-op placeholder.
    """
    return y

In [6]:
class WrappedDataLoader:
    """Build a ``DataLoader`` over tensor copies of ``x`` and ``y``.

    Both inputs are converted to tensors, optionally passed through a
    transform, paired in a ``TensorDataset`` and wrapped in a ``DataLoader``.
    ``len()`` and iteration are delegated to that loader.

    Args:
        x: Inputs; an existing tensor or anything ``torch.tensor`` accepts.
        y: Targets; same accepted types as ``x``.
        x_trans: Optional callable applied to the ``x`` tensor.
        y_trans: Optional callable applied to the ``y`` tensor.
        **kwargs: Forwarded unchanged to ``DataLoader``
            (e.g. ``batch_size``, ``shuffle``).
    """

    def __transform(self, x, func):
        """Return a tensor copy of ``x``, passed through ``func`` when given."""
        if isinstance(x, torch.Tensor):
            # torch.tensor(existing_tensor) emits a UserWarning; detach().clone()
            # is the documented equivalent way to copy a tensor.
            x_transformed = x.detach().clone()
        else:
            x_transformed = torch.tensor(x)

        # Compare against None explicitly so that a callable which happens to
        # be falsy (e.g. an empty container-like object) is still applied.
        if func is not None:
            x_transformed = func(x_transformed)

        return x_transformed

    def __init__(self, x, y, x_trans=None, y_trans=None, **kwargs):
        x_transformed = self.__transform(x, x_trans)
        y_transformed = self.__transform(y, y_trans)
        self.dl = DataLoader(TensorDataset(x_transformed, y_transformed), **kwargs)

    def __len__(self):
        """Number of batches, as reported by the wrapped ``DataLoader``."""
        return len(self.dl)

    def __iter__(self):
        """Start a fresh iteration over the wrapped ``DataLoader``."""
        return iter(self.dl)

In [7]:
# Training loader: shuffled, batch_size samples per batch.
train_dl = WrappedDataLoader(
    x_train,
    y_train,
    x_trans=x_transform,
    shuffle=True,
    batch_size=batch_size,
)
# Validation loader: not shuffled, batches twice as large as in training.
valid_dl = WrappedDataLoader(
    x_valid,
    y_valid,
    x_trans=x_transform,
    batch_size=2 * batch_size,
)

In [8]:
class CNN(nn.Module):
    """Two conv -> ReLU -> max-pool stages followed by two linear layers.

    Both convolutions use 3x3 kernels with stride 1 and padding 1; each
    pooling uses a 2x2 window with stride 2. The first linear layer expects
    ``16 * 7 * 7`` flattened features (e.g. what a 28x28 input produces
    after the two poolings), and the last layer emits 10 values per sample.

    Args:
        in_channel: Number of channels in the input images.
    """

    def __init__(self, in_channel):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(in_channel, 8, **conv_kwargs)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(8, 16, **conv_kwargs)
        self.fc1 = nn.Linear(16 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch ``x`` to a ``(batch, 10)`` tensor of raw scores."""
        features = self.maxpool(self.relu(self.conv1(x)))
        features = self.maxpool(self.relu(self.conv2(features)))
        # Keep the batch dimension and flatten everything else.
        flat = features.view(features.size(0), -1)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

In [9]:
def run_epoch(dl, model, loss_func, opt=None):
    """Run ``model`` over every batch of ``dl`` once, optionally training it.

    Args:
        dl: Iterable yielding ``(x, y)`` batches.
        model: Callable mapping ``x`` to per-sample scores; the predicted
            class is the argmax over dimension 1.
        loss_func: Callable ``loss_func(y_pred, y)`` returning the batch
            loss as a tensor.
        opt: Optional optimizer. When given, every batch is followed by
            ``loss.backward()``, ``opt.step()`` and ``opt.zero_grad()``;
            when ``None`` no parameters are updated.

    Returns:
        ``(loss, accuracy)``: the average of the per-batch losses weighted by
        batch size (a tensor detached from the autograd graph), and the
        fraction of samples whose argmax prediction equals ``y``.

    Raises:
        ZeroDivisionError: If ``dl`` yields no batches.
    """
    total_loss = 0
    total_size = 0
    cnt_true = 0
    for x, y in dl:
        n_samples = x.size(0)
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        # Accumulate a detached value: adding the live loss would chain every
        # batch's autograd graph onto the running total and hand the caller a
        # tensor that still requires grad.
        total_loss += loss.detach() * n_samples
        total_size += n_samples
        cnt_true += torch.sum(torch.argmax(y_pred, dim=1) == y).item()
        if opt is not None:
            loss.backward()
            opt.step()
            opt.zero_grad()

    return total_loss / total_size, cnt_true / total_size


def fit(train_dl, valid_dl, model, loss_func, opt, epochs):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Each epoch puts the model in train mode and runs ``run_epoch`` over
    ``train_dl`` with the optimizer, then puts it in eval mode and runs
    ``run_epoch`` over ``valid_dl`` under ``torch.no_grad()`` without an
    optimizer. One summary line per epoch is printed; nothing is returned.
    """
    for epoch in range(epochs):
        # Training pass: only the loss is reported, the accuracy is dropped.
        model.train()
        train_loss = run_epoch(train_dl, model, loss_func, opt)[0]

        # Validation pass: no optimizer and no gradient tracking.
        model.eval()
        with torch.no_grad():
            valid_metrics = run_epoch(valid_dl, model, loss_func)
        valid_loss, accuracy = valid_metrics

        report = f'Epoch {epoch} -  training loss: {train_loss}   validation loss: {valid_loss}   accuracy: {accuracy}'
        print(report)

In [10]:
# Network with one input channel, cross-entropy loss, and Adam at lr = 1e-3.
model = CNN(in_channel=1)
loss_func = F.cross_entropy
opt = optim.Adam(model.parameters(), lr=1e-3)

In [11]:
# Train for 10 epochs; one summary line is printed per epoch.
fit(train_dl, valid_dl, model, loss_func, opt, epochs=10)

Epoch 0 -  training loss: 0.253691166639328   validation loss: 0.09074848890304565   accuracy: 0.9734
Epoch 1 -  training loss: 0.07913655042648315   validation loss: 0.06637964397668839   accuracy: 0.9809
Epoch 2 -  training loss: 0.05328388512134552   validation loss: 0.054564137011766434   accuracy: 0.9839
Epoch 3 -  training loss: 0.04223530367016792   validation loss: 0.048925142735242844   accuracy: 0.9857
Epoch 4 -  training loss: 0.03290123865008354   validation loss: 0.056081052869558334   accuracy: 0.9833
Epoch 5 -  training loss: 0.02563486620783806   validation loss: 0.05559058114886284   accuracy: 0.9849
Epoch 6 -  training loss: 0.021003657951951027   validation loss: 0.04349739849567413   accuracy: 0.9882
Epoch 7 -  training loss: 0.0171432476490736   validation loss: 0.04357772320508957   accuracy: 0.9885
Epoch 8 -  training loss: 0.014722688123583794   validation loss: 0.051204029470682144   accuracy: 0.9873
Epoch 9 -  training loss: 0.011063211597502232   validation l