In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets 
from torchvision import transforms

from tqdm import tqdm, tqdm_notebook

In [2]:
# Mini-batch size used when building the data loaders.
batch_size = 16
# Seed PyTorch's global RNG so weight initialisation and batch shuffling are
# repeatable across "Restart & Run All" (GPU kernels may still add a small
# amount of nondeterminism).
seed = 0
torch.manual_seed(seed)
# Prefer the GPU when one is visible, but fall back to the CPU so the notebook
# still runs on machines without CUDA instead of failing at `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def train(model, device, train_loader, optimizer, criterion, n_epochs, log_interval):
    """Fit ``model`` on ``train_loader`` for ``n_epochs`` passes over the data.

    Args:
        model: ``nn.Module`` to optimise. It is switched to training mode and
            its parameters are updated in place through ``optimizer``.
        device: device every batch is moved to before the forward pass.
        train_loader: iterable of ``(X, y)`` batches that supports ``len()``;
            its ``dataset`` attribute is read only when a log line is printed.
        optimizer: optimiser whose gradients are zeroed before, and stepped
            after, every backward pass.
        criterion: callable ``criterion(y_pred, y)`` returning a scalar loss
            tensor.
        n_epochs: number of full passes over ``train_loader``.
        log_interval: print the current loss every ``log_interval`` batches.
            ``0`` or ``None`` disables the log lines entirely.

    Returns:
        None. Progress is reported through progress bars and ``print``.
    """
    # NOTE(review): `tqdm_notebook` is reported as deprecated by recent tqdm
    # releases in favour of `tqdm.notebook.tqdm` / `tqdm.auto.tqdm`; consider
    # switching at the import site.
    model.train()
    for epoch in tqdm_notebook(range(n_epochs)):
        # Running count of samples consumed in this epoch *before* the current
        # batch. Unlike `batch_idx * len(X)` it stays correct when the final
        # batch is shorter than the others.
        samples_seen = 0
        for batch_idx, (X, y) in tqdm_notebook(enumerate(train_loader), total = len(train_loader)):
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_pred = model(X)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            # A falsy interval means "no logging"; this also avoids the
            # ZeroDivisionError that `batch_idx % 0` would raise.
            if log_interval and batch_idx % log_interval == 0:
                print(f'Train epoch {epoch}: [{samples_seen:5d}/{len(train_loader.dataset):5d}] Loss: {loss.item():7.4f}')
            samples_seen += len(X)

In [4]:
def evaluate(model, device, test_loader, criterion):
    """Compute classification accuracy and mean per-sample loss on a loader.

    Args:
        model: ``nn.Module`` producing one score per class for every sample.
            It is switched to evaluation mode and left in that mode.
        device: device every batch is moved to before the forward pass.
        test_loader: iterable of ``(X, y)`` batches.
        criterion: callable ``criterion(y_pred, y)`` returning a scalar loss
            tensor. The aggregation below assumes that scalar is the *mean*
            over the batch (the default reduction of ``nn.CrossEntropyLoss``).

    Returns:
        Tuple ``(accuracy, average_loss)``: the fraction of samples whose
        arg-max prediction equals the label, and the loss averaged over all
        samples that were evaluated.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    samples_seen = 0
    correct_answers = 0
    sum_loss = 0.0
    with torch.no_grad():
        for X, y in tqdm_notebook(test_loader):
            X, y = X.to(device), y.to(device)
            y_pred = model(X)
            class_pred = y_pred.argmax(dim = 1)
            batch_len = len(y)
            samples_seen += batch_len
            correct_answers += (y == class_pred).sum().item()
            # Weight each batch mean by its size so that a shorter final batch
            # does not count as much as a full one in the overall average.
            sum_loss += criterion(y_pred, y).item() * batch_len
    if samples_seen == 0:
        raise ValueError('test_loader yielded no samples to evaluate')
    # Normalise by the samples actually processed; this stays correct when the
    # loader iterates over only part of its dataset.
    accuracy = correct_answers / samples_seen
    average_loss = sum_loss / samples_seen

    return accuracy, average_loss

In [5]:
class CNN(nn.Module):
    """Small convolutional classifier for 28x28 images.

    Two blocks of (3x3 conv, ReLU, 3x3 conv, ReLU, 2x2 max-pool) are followed
    by three fully connected layers. The padded convolutions keep the spatial
    size and each pooling step halves it, so a 28x28 input reaches the first
    linear layer as 128 feature maps of 7x7.

    Args:
        in_channels: number of channels of the input images (default 1).
        num_classes: number of output scores per sample (default 10).
    """

    def __init__(self, in_channels = 1, num_classes = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding = 1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding = 1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding = 1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding = 1)
        self.fc1   = nn.Linear(128 * 7 * 7, 256)
        self.fc2   = nn.Linear(256, 256)
        self.fc3   = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch of images.

        Args:
            x: tensor of shape ``(batch, in_channels, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``; no softmax is applied.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size = (2, 2))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.max_pool2d(x, kernel_size = (2, 2))
        # Flatten per sample. Keeping the batch dimension fixed (rather than
        # `view(-1, 128 * 7 * 7)`) means an input of the wrong spatial size
        # fails at `fc1` instead of being silently reshaped into extra rows.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

In [6]:
# Shuffled mini-batches over the MNIST training split stored under '../data'.
# Each image is converted to a tensor and then standardised with the constants
# below (presumably the MNIST pixel mean and standard deviation - TODO confirm).
train_loader = torch.utils.data.DataLoader(
    dataset = datasets.MNIST(
        root = '../data',
        train = True,
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean = (0.1307,), std = (0.3081,)),
        ]),
        download = True,
    ),
    batch_size = batch_size,
    shuffle = True,
)

In [7]:
# Fixed-order mini-batches over the MNIST test split stored under '../data',
# preprocessed with the same tensor conversion and normalisation constants as
# the training data. Batches are twice the training batch size.
test_loader = torch.utils.data.DataLoader(
    dataset = datasets.MNIST(
        root = '../data',
        train = False,
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean = (0.1307,), std = (0.3081,)),
        ]),
        download = True,
    ),
    batch_size = 2 * batch_size,
    shuffle = False,
)

In [8]:
# Instantiate the network and place its parameters on the configured device.
cnn = CNN()
cnn = cnn.to(device)
# Cross-entropy over the raw class scores returned by the network.
criterion = nn.CrossEntropyLoss()
# Adam with its default learning rate, written out explicitly.
optimizer = optim.Adam(params = cnn.parameters(), lr = 1e-3)

In [9]:
# Train for three passes over the data, printing the loss every 500 batches.
train(
    model = cnn,
    device = device,
    train_loader = train_loader,
    optimizer = optimizer,
    criterion = criterion,
    n_epochs = 3,
    log_interval = 500,
)

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))

Train epoch 0: [    0/60000] Loss:  2.3116
Train epoch 0: [ 8000/60000] Loss:  0.0036
Train epoch 0: [16000/60000] Loss:  0.0058
Train epoch 0: [24000/60000] Loss:  0.1645
Train epoch 0: [32000/60000] Loss:  0.0122
Train epoch 0: [40000/60000] Loss:  0.0218
Train epoch 0: [48000/60000] Loss:  0.1394
Train epoch 0: [56000/60000] Loss:  0.0301


HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))

Train epoch 1: [    0/60000] Loss:  0.0245
Train epoch 1: [ 8000/60000] Loss:  0.0362
Train epoch 1: [16000/60000] Loss:  0.0042
Train epoch 1: [24000/60000] Loss:  0.0036
Train epoch 1: [32000/60000] Loss:  0.0016
Train epoch 1: [40000/60000] Loss:  0.0078
Train epoch 1: [48000/60000] Loss:  0.0713
Train epoch 1: [56000/60000] Loss:  0.0003


HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))

Train epoch 2: [    0/60000] Loss:  0.0079
Train epoch 2: [ 8000/60000] Loss:  0.0048
Train epoch 2: [16000/60000] Loss:  0.0027
Train epoch 2: [24000/60000] Loss:  0.0005
Train epoch 2: [32000/60000] Loss:  0.1556
Train epoch 2: [40000/60000] Loss:  0.0234
Train epoch 2: [48000/60000] Loss:  0.0002
Train epoch 2: [56000/60000] Loss:  0.0001



In [10]:
# Score the trained network on the test loader and report both metrics.
accuracy, avg_loss = evaluate(
    model = cnn,
    device = device,
    test_loader = test_loader,
    criterion = criterion,
)
print(f'Accuracy: {100 * accuracy:5.2f}%, loss: {avg_loss:7.4f}')

HBox(children=(IntProgress(value=0, max=313), HTML(value='')))


Accuracy: 99.09%, loss:  0.0296
