In [26]:
import numpy as np
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

from torch.nn import Module
from torch import nn
from tqdm.notebook import trange, tqdm
import time

In [7]:
class Model(Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y

In [8]:
BATCH_SIZE = 256
train_dataset = mnist.MNIST(root='./train', train=True, download=True, transform=ToTensor())
test_dataset = mnist.MNIST(root='./test', train=False, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [13]:
model = Model()
sgd = SGD(model.parameters(), lr=1e-1)
loss_fn = CrossEntropyLoss()
EPOCHS = 10

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
loss_fn = loss_fn.to(device)

In [19]:
def calculate_accuracy(y_pred, y):
    top_pred = y_pred.argmax(1, keepdim=True)
    correct = top_pred.eq(y.view_as(top_pred)).sum()
    acc = correct.float() / y.shape[0]
    return acc

In [28]:
def train(model, iterator, optimizer, criterion, device):
    epoch_loss = 0
    epoch_acc = 0
    model.train()
    
    for (x, y) in tqdm(iterator, desc="Training", leave=False):
        
        x = x.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        acc = calculate_accuracy(y_pred, y)
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [29]:
def evaluate(model, iterator, criterion, device):
    
    epoch_loss = 0
    epoch_acc = 0

    model.eval()

    with torch.no_grad():

        for (x, y) in tqdm(iterator, desc="Evaluating", leave=False):

            x = x.to(device)
            y = y.to(device)

            y_pred = model(x)

            loss = criterion(y_pred, y)

            acc = calculate_accuracy(y_pred, y)

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [30]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [31]:
EPOCHS = 10

best_valid_loss = float('inf')

for epoch in trange(EPOCHS):

    start_time = time.monotonic()

    train_loss, train_acc = train(model, train_loader, sgd, loss_fn, device)

    end_time = time.monotonic()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')

  0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 01 | Epoch Time: 0m 6s
	Train Loss: 2.076 | Train Acc: 27.20%


Training:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 02 | Epoch Time: 0m 6s
	Train Loss: 0.795 | Train Acc: 71.61%


Training:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 03 | Epoch Time: 0m 6s
	Train Loss: 0.453 | Train Acc: 83.68%


Training:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 04 | Epoch Time: 0m 6s
	Train Loss: 0.353 | Train Acc: 86.79%


Training:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 05 | Epoch Time: 0m 6s
	Train Loss: 0.326 | Train Acc: 87.48%


Training:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 06 | Epoch Time: 0m 6s
	Train Loss: 0.311 | Train Acc: 87.94%


Training:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 07 | Epoch Time: 0m 6s
	Train Loss: 0.300 | Train Acc: 88.21%


Training:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 08 | Epoch Time: 0m 6s
	Train Loss: 0.292 | Train Acc: 88.41%


Training:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 09 | Epoch Time: 0m 6s
	Train Loss: 0.257 | Train Acc: 89.87%


Training:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 10 | Epoch Time: 0m 6s
	Train Loss: 0.053 | Train Acc: 98.42%


In [32]:
torch.save(model.state_dict(), 'models/lenet.pt')