In [None]:
from importlib import reload


In [None]:
import torch
from torch import optim, nn

from model import Net, model_train, model_test,\
     summary_printer, plot_loss_n_acc
from utils import prepare_mnist_data,\
     plot_img_batch

from torchvision import transforms

In [None]:
# Seed PyTorch's global random number generator so that repeated
# "Restart & Run All" executions start from the same random state.
# NOTE: this alone does not guarantee bit-identical results on every backend.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

## Data preparation

In [None]:
def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize steps shared by both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]


# Augmentations applied to training images only: an occasional (p=0.1)
# centre crop to 22 px, a resize back to 28x28, and a random rotation
# within +/-15 degrees with zero fill.
_train_augmentations = [
    transforms.RandomApply([transforms.CenterCrop(22)], p=0.1),
    transforms.Resize((28, 28)),
    transforms.RandomRotation((-15., 15.), fill=0),
]

# Training pipeline: augmentations followed by tensor conversion + normalisation.
train_transforms = transforms.Compose(_train_augmentations + _tensor_and_normalize())

# Test pipeline: tensor conversion + normalisation only, no augmentation.
test_transforms = transforms.Compose(_tensor_and_normalize())

In [None]:
# Build the train/test loaders, handing each split its own transform pipeline.
train_loader, test_loader = prepare_mnist_data(
    train_transforms,
    test_transforms,
)

In [None]:
# Preview a batch from the training loader. The trailing semicolon hides the
# return value's repr without rebinding `_`, which IPython reserves for the
# most recent cell output.
plot_img_batch(train_loader);

## Modelling

In [None]:
# Instantiate the network on the `device` chosen in the setup cell at the top
# of the notebook (it is not re-detected here, so there is a single source of
# truth for device selection), then print its summary.
model = Net().to(device)
summary_printer(model)

In [None]:

# Training configuration: SGD with momentum, and the learning rate multiplied
# by 0.1 after every 15 scheduler steps (one step per epoch below).
num_epochs = 20
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# The scheduler's `verbose=` keyword is deprecated in PyTorch, so it is not
# passed; the current learning rate is printed explicitly via `get_last_lr()`.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

for epoch in range(1, num_epochs + 1):
    print(f'Epoch {epoch}')
    print(f'Learning rate: {scheduler.get_last_lr()[0]:.4e}')
    model_train(model, device, train_loader, optimizer)
    model_test(model, device, test_loader)
    # Advance the schedule once per epoch, after the optimizer has been used.
    scheduler.step()

In [None]:
# Plot the training results. The helper takes no arguments, so it presumably
# reads loss/accuracy history recorded by model_train/model_test inside the
# `model` module — TODO confirm against model.py.
plot_loss_n_acc()