## Readings

modules, losses, activations, etc. :
 - https://pytorch.org/docs/stable/nn.html
 
optimizers and schedulers:
 - https://pytorch.org/docs/stable/optim.html
 
examples:
 - https://github.com/pytorch/examples/blob/master/vae/main.py
 - https://github.com/pytorch/examples/blob/master/mnist/main.py

In [7]:
from importlib import reload

import torch
from torch import nn
from torch import optim
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose

from Models.mlp_classifier import Classifier
import utils

## Model

In [18]:
import Models.mlp_classifier
reload(Models.mlp_classifier)
from Models.mlp_classifier import Classifier

cls = Classifier(input_size=784, num_classes=10, hidden_layers=[100,50])
print(cls)
# cls.cuda()

Classifier(
  (net): Sequential(
    (0): Linear(in_features=784, out_features=100, bias=True)
    (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=100, out_features=50, bias=True)
    (4): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=50, out_features=10, bias=True)
    (7): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): Softmax(dim=None)
  )
)


In [19]:
utils.get_n_params(cls)

84380

## Dataset

In [27]:
class Flatten(object):

    def __call__(self, sample):
        return sample.view(-1)
    
final_transform = Compose([ToTensor(), Flatten()])

In [28]:
mnist = MNIST(root='data', train=False, transform=final_transform, download=True)
len(mnist)

10000

In [29]:
train_data, test_data = random_split(mnist, [7000,3000])
len(train_data), len(test_data)

(7000, 3000)

In [30]:
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=True)

## Training

In [31]:
optimizer = optim.Adam(cls.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

In [33]:
reload(utils)

utils.train(cls, train_loader, test_loader, criterion, optimizer, device='cpu', epochs=10, verbose=1)

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

1 train_loss:1.8788471351970326 train_acc:0.7916666666666666
val_loss:1.862539139199764 val_acc:0.8928571428571429

2 train_loss:1.8368669336492365 train_acc:0.875
val_loss:1.830928457544205 val_acc:0.875

3 train_loss:1.8131488940932534 train_acc:0.7916666666666666
val_loss:1.8110796390695776 val_acc:0.8214285714285714

4 train_loss:1.7945027882402593 train_acc:0.875
val_loss:1.8034629365231127 val_acc:0.9285714285714286

5 train_loss:1.781535104188052 train_acc:0.8333333333333334
val_loss:1.7884985157783995 val_acc:0.9107142857142857

6 train_loss:1.770592923597856 train_acc:0.7916666666666666
val_loss:1.7763321830871257 val_acc:0.9107142857142857

7 train_loss:1.7584263649853793 train_acc:0.9166666666666666
val_loss:1.7726412225276866 val_acc:0.9285714285714286

8 train_loss:1.7478667183355852 train_acc:0.7916666666666666
val_loss:1.7628133626694376 val_acc:0.9107142857142857

9 train_loss:1.7399583621458574 train_acc:0.9166666666666666
val_loss:1.7565453356884895 val_acc:0.91071428