In [1]:
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import numpy as np
from torch.nn import functional as F
import torchvision.transforms as tfs
from datetime import datetime

In [2]:
def data_tf(x):
    data_aug = tfs.Compose([
        tfs.RandomHorizontalFlip(),
        tfs.ColorJitter(brightness = 0.5, contrast = 0.5, hue = 0.2),
        tfs.ToTensor()
    ])
    x = data_aug(x)
    return x

In [3]:
train_set = MNIST('./data', train = True, transform = data_tf, download = True)
train_data = DataLoader(train_set, batch_size = 64, shuffle = True)
test_set = MNIST('./data', train = False, transform = data_tf, download = True)
test_data = DataLoader(test_set, batch_size = 128, shuffle = False)

In [4]:
class conv_net(nn.Module):
    def __init__(self):
        super(conv_net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, 2, 2)
        self.norm1 = nn.BatchNorm2d(16)
        self.maxpool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 8, 5, padding = 2)
        self.norm2 = nn.BatchNorm2d(8)
        self.maxpool2 = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(72, 10)
        
    
    def forward(self,x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = F.relu(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = F.relu(x)
        x = self.maxpool2(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

In [5]:
def set_learning_rate(optimizer, lr):
    for param_groups in optimizer.param_groups:
        param_groups['lr'] = lr

In [6]:
def train(net, train_data, test_data, epoch, optimizer, criterion):
    train_losses = []
    test_losses = []
    if torch.cuda.is_available():
        net = net.cuda()
    prev_time = datetime.now()
    for epoch in range(epoch):
        if epoch == 15:
            set_learning_rate(optimizer, 0.001)
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, labels in train_data:
            if torch.cuda.is_available():
                im = im.cuda()
                labels = labels.cuda()
            
            output = net(im)
            loss = criterion(output, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.data.item()
            train_acc += get_acc(output, labels)
        cur_time = datetime.now()
        h,remainder = divmod((cur_time - prev_time).seconds, 3600)
        m,s = divmod(remainder, 60)
        time_str = 'Time:%02d:%02d:%02d'%(h, m, s)
        
        test_loss = 0
        test_acc = 0
        net = net.eval()
        for im, labels in test_data:
            if torch.cuda.is_available():
                im = im.cuda()
                labels = labels.cuda()
            output = net(im)
            loss = criterion(output, labels)
            test_loss += loss.data.item()
            test_acc += get_acc(output, labels)
        epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), test_loss / len(test_data),
                   test_acc / len(test_data)))
        prev_time = cur_time
        train_losses.append(train_loss / len(train_data))
        test_losses.append(test_loss / len(test_data))
        print(epoch_str + time_str)

In [7]:
net = conv_net()
optimizer = torch.optim.Adam(net.parameters(), lr = 0.01)
criterion = nn.CrossEntropyLoss()

In [8]:
def get_acc(output, label):
    total = output.shape[0]
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().data.item()
    return num_correct / total

In [9]:
train(net, train_data, test_data, 25, optimizer, criterion)

Epoch 0. Train Loss: 0.366595, Train Acc: 0.888543, Valid Loss: 0.160965, Valid Acc: 0.952433, Time:00:00:13
Epoch 1. Train Loss: 0.180588, Train Acc: 0.946479, Valid Loss: 0.148654, Valid Acc: 0.955004, Time:00:00:14
Epoch 2. Train Loss: 0.153448, Train Acc: 0.954458, Valid Loss: 0.139718, Valid Acc: 0.953916, Time:00:00:14
Epoch 3. Train Loss: 0.139942, Train Acc: 0.958206, Valid Loss: 0.137114, Valid Acc: 0.959949, Time:00:00:14
Epoch 4. Train Loss: 0.133178, Train Acc: 0.959771, Valid Loss: 0.121486, Valid Acc: 0.964102, Time:00:00:14
Epoch 5. Train Loss: 0.125181, Train Acc: 0.963136, Valid Loss: 0.121876, Valid Acc: 0.963212, Time:00:00:14
Epoch 6. Train Loss: 0.120299, Train Acc: 0.964053, Valid Loss: 0.126054, Valid Acc: 0.962718, Time:00:00:14
Epoch 7. Train Loss: 0.115742, Train Acc: 0.965285, Valid Loss: 0.114627, Valid Acc: 0.963805, Time:00:00:14
Epoch 8. Train Loss: 0.111039, Train Acc: 0.966301, Valid Loss: 0.131807, Valid Acc: 0.958267, Time:00:00:14
Epoch 9. Train Loss