In [1]:
from multiprocessing import cpu_count

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

In [2]:
import sys
sys.path.append('../')
import olympic

In [3]:
# Convert images to tensors and normalise them. 0.1307 / 0.3081 are presumably
# the MNIST training-set pixel mean / std (the constants used in the official
# PyTorch MNIST example).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train = datasets.MNIST('', train=True, transform=transform, download=True)
val = datasets.MNIST('', train=False, transform=transform, download=True)

# Reshuffle the training set every epoch so SGD does not see the mini-batches
# in the same fixed order each time; validation order is irrelevant, so the
# validation loader stays unshuffled.
train_loader = DataLoader(train, batch_size=128, shuffle=True, num_workers=cpu_count())
val_loader = DataLoader(val, batch_size=128, num_workers=cpu_count())

# Display the split sizes as a quick sanity check.
len(train), len(val)

(60000, 10000)

In [4]:
class Net(nn.Module):
    """Small two-convolution CNN classifier for single-channel 28x28 images (MNIST).

    Layers: conv(1->10, 5x5) -> maxpool(2) -> ReLU
            -> conv(10->20, 5x5) -> Dropout2d -> maxpool(2) -> ReLU
            -> fc(320->50) -> ReLU -> dropout -> fc(50->10) -> log_softmax.

    ``forward`` returns per-class log-probabilities of shape (batch, 10), so the
    matching training criterion is ``nn.NLLLoss``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # Dropout2d drops entire feature maps (channels), not single activations.
        self.conv2_drop = nn.Dropout2d()
        # 320 = 20 channels * 4 * 4: a 28x28 input shrinks 28 -> 24 -> 12 -> 8 -> 4
        # through the two (5x5 conv, 2x2 pool) stages.
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) class log-probabilities."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample while keeping the batch dimension. Unlike
        # view(-1, 320), a wrongly sized input now fails at fc1 instead of being
        # silently reshaped into a different batch size.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        # Functional dropout does not track train/eval mode by itself, so tie it
        # to the module's mode explicitly.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [5]:
# Seed torch's global RNG so weight initialisation, dropout masks and any other
# sampling drawn from it are reproducible on Restart & Run All.
torch.manual_seed(0)

model = Net()
optimiser = optim.SGD(model.parameters(), lr=0.1)
# Net.forward already ends in log_softmax, so NLLLoss is the matching criterion;
# CrossEntropyLoss would apply a second, redundant log_softmax on top.
loss_fn = nn.NLLLoss()

In [6]:
# Per-epoch evaluation on the held-out split.
evaluate_cb = olympic.callbacks.Evaluate(val_loader)
# Keep only the checkpoint with the best validation accuracy. It monitors
# 'val_accuracy', presumably produced by the Evaluate callback, so it is listed
# after it (assumes olympic runs callbacks in list order; confirm).
checkpoint_cb = olympic.callbacks.ModelCheckpoint(
    'model.pt',
    save_best_only=True,
    monitor='val_accuracy',
)
# Write every tracked metric to a CSV file.
csv_logger_cb = olympic.callbacks.CSVLogger('log.csv')

callbacks = [evaluate_cb, checkpoint_cb, csv_logger_cb]

In [7]:
# Train for 10 epochs on the training loader, reporting accuracy next to the
# loss. Validation, checkpointing and CSV logging are handled by the callbacks.
olympic.fit(
    model, optimiser, loss_fn,
    dataloader=train_loader,
    callbacks=callbacks,
    metrics=['accuracy'],
    epochs=10,
)

Epoch 1:   0%|          | 0/469 [00:00<?, ?it/s]

Begin training...


Epoch 1: 100%|██████████| 469/469 [00:13<00:00, 34.70it/s, loss=0.659, accuracy=0.789] 
Epoch 2: 100%|██████████| 469/469 [00:16<00:00, 28.83it/s, loss=0.305, accuracy=0.909] 
Epoch 3: 100%|██████████| 469/469 [00:18<00:00, 26.01it/s, loss=0.247, accuracy=0.927] 
Epoch 4: 100%|██████████| 469/469 [00:14<00:00, 33.43it/s, loss=0.214, accuracy=0.938] 
Epoch 5: 100%|██████████| 469/469 [00:13<00:00, 34.98it/s, loss=0.199, accuracy=0.942] 
Epoch 6: 100%|██████████| 469/469 [00:13<00:00, 11.95it/s, loss=0.182, accuracy=0.948] 
Epoch 7: 100%|██████████| 469/469 [00:13<00:00, 34.47it/s, loss=0.169, accuracy=0.949] 
Epoch 8: 100%|██████████| 469/469 [00:13<00:00, 33.50it/s, loss=0.164, accuracy=0.953] 
Epoch 9: 100%|██████████| 469/469 [00:13<00:00, 34.68it/s, loss=0.155, accuracy=0.956] 
Epoch 10: 100%|██████████| 469/469 [00:14<00:00, 31.64it/s, loss=0.153, accuracy=0.955] 

Finished.



