In [None]:
# converting https://github.com/pytorch/examples/blob/master/mnist/main.py to composer

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import composer
from composer import DataloaderSpec, Trainer
from composer.models import BaseMosaicModel

# Example algorithms to train with
from composer.algorithms import CutOut, LabelSmoothing

from torchmetrics.classification.accuracy import Accuracy

## Params

In [2]:
# Batch sizes for the training and evaluation dataloaders.
train_batch_size = 64
eval_batch_size = 64

epochs = 14
lr = 3e-3
gamma = 0.7  # learning rate scheduler gamma (decay factor passed to StepLR below)
cuda = False  # NOTE(review): not referenced by any later cell shown here
seed = 1337
dry_run = False  # TODO: is this supported by default in the trainer? Not referenced below.

save_model = True

# Actually apply the seed. Previously `seed` was declared but never used, so
# e.g. the model's weight initialization (drawn from torch's global RNG when
# `MosaicNet()` is constructed later) differed on every run.
torch.manual_seed(seed)


# Simple ConvNet from the PyTorch examples

- The actual architecture doesn't matter.

In [3]:
class Net(nn.Module):
    """Small two-conv-layer MNIST classifier adapted from the PyTorch examples.

    Layer sizes are fixed for single-channel 28x28 images. Two unpadded 3x3
    convolutions shrink 28 -> 26 -> 24, and a 2x2 max-pool halves that to 12.
    That gives 64 * 12 * 12 = 9216 features into ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor (1 input channel -> 32 -> 64).
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Dropout after pooling and after the hidden fully-connected layer.
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # Classifier head: 9216 flattened features -> 128 hidden -> 10 classes.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of images to log-probabilities of shape (batch, 10)."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        flat = torch.flatten(features, 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        # log_softmax output pairs with F.nll_loss in the wrapper's loss().
        return F.log_softmax(self.fc2(hidden), dim=1)

## Mosaic wrapper
notes:
- I deliberately didn't use `MosaicClassifier` here because I believe it will help adoption if people know exactly where their code is going (PyTorch Lightning does this very well). Especially for researchers and tinkerers, we don't want to come off as too "fastai"-like, where everything is done for you in one line of code. Also, `MosaicClassifier` isn't a very generic class, since it needs to be modified for different loss functions, auxiliary tasks, and different metric tracking. I think those classes should stay around as examples, but we shouldn't set the expectation that people will use them for every classification task, or even most classification tasks, as the name suggests.

- Since we separate the forward pass from the loss computation (unlike PyTorch Lightning), it could be possible for us to deploy wrapped models even in cases where the full network and forward pass aren't defined outside the wrapper. We could probably also reuse the forward pass as the default validate step. The only difference is that `forward` returns just the outputs, while `validate` also returns the targets. That said, this may not be a good habit to encourage.

- Should the `metrics` method be optional? We could default it to reporting the loss, which is almost always reasonable. We could also automatically wrap lists of returned metrics in `MetricContainer`.


In [4]:
class MosaicNet(BaseMosaicModel):
    """Composer wrapper around ``Net``.

    Supplies the forward pass, loss, metrics and validation step used by the
    Trainer. ``Net`` already returns log-probabilities (``log_softmax``), so
    the loss is ``F.nll_loss`` rather than a cross-entropy on raw logits.
    Batches are unpacked as ``(inputs, targets)`` pairs.
    """

    def __init__(self):
        super().__init__()
        # Separate metric objects so training and validation accuracy are
        # tracked independently.
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.model = Net()

    def forward(self, batch):
        """Run ``Net`` on the inputs of an ``(inputs, targets)`` batch.

        Returns:
            The per-class log-probabilities produced by ``Net``.
        """
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        """Negative log-likelihood of ``outputs`` against the batch targets."""
        _, targets = batch
        return F.nll_loss(outputs, targets)

    def metrics(self, train: bool = False):
        """Return the accuracy metric for the requested phase.

        Args:
            train: ``True`` to get the training metric, ``False`` (default)
                to get the validation metric.
        """
        # Must test the ``train`` argument. ``self.train`` is the bound
        # ``nn.Module.train`` method, which is always truthy, so the old check
        # never returned the validation metric.
        return self.train_acc if train else self.val_acc

    def validate(self, batch):
        """Return ``(outputs, targets)`` for updating the validation metric."""
        _, targets = batch
        outputs = self.forward(batch)
        return outputs, targets

In [5]:
# Instantiate the Composer wrapper; it constructs its own `Net` internally.
model = MosaicNet()

# Data

In [6]:
# Dataset location and preprocessing shared by the train and eval splits.
# The Normalize values are the MNIST mean/std used by the PyTorch example.
mnist_root = '~/data'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Your custom train DataloaderSpec: the MNIST training split, shuffled.
# download=True fetches the data into `mnist_root` if it is missing.
train_dataset = datasets.MNIST(mnist_root, train=True, transform=transform, download=True)
train_dataloader_spec = DataloaderSpec(
    dataset=train_dataset,
    drop_last=False,
    shuffle=True,
)

# Your custom eval DataloaderSpec: the MNIST test split, in fixed order.
eval_dataset = datasets.MNIST(mnist_root, train=False, transform=transform)
eval_dataloader_spec = DataloaderSpec(
    dataset=eval_dataset,
    drop_last=False,
    shuffle=False,
)

## Optimizer / scheduler

In [7]:
from composer.optim.scheduler import StepLRHparams
from composer.optim.optimizer_hparams import AdamWHparams

In [8]:
# AdamW at the configured learning rate, with a StepLR schedule
# (step_size=1, decay factor `gamma` from the params cell).
optimizer_hparams = AdamWHparams(lr=lr)
scheduler_hparams = StepLRHparams(step_size=1, gamma=gamma)

# Train
notes:
- Should the checkpointer create its directories, or at least check up front that they exist? It's annoying to get through a full epoch and then hit an error because your checkpoint directory wasn't created ahead of time.

In [20]:
# Initialize the Trainer with the wrapped model, the custom train/eval
# dataloader specs, and the optimizer/scheduler hyperparameters.
# NOTE(review): the CutOut / LabelSmoothing algorithms imported at the top are
# not passed in here, so this run trains the plain baseline.
trainer = Trainer(model=model,
                  train_dataloader_spec=train_dataloader_spec,
                  eval_dataloader_spec=eval_dataloader_spec,
                  max_epochs=epochs,
                  train_batch_size=train_batch_size,
                  eval_batch_size=eval_batch_size,
                  optimizer_hparams=optimizer_hparams,
                  schedulers_hparams=scheduler_hparams,
                  checkpoint_folder='.',
                  # Honor the `save_model` flag from the params cell; it was
                  # previously ignored. A unit of None disables checkpointing
                  # in this Trainer version (TODO confirm against composer docs).
                  checkpoint_interval_unit="ep" if save_model else None,
                  checkpoint_interval=1)




In [21]:
# Train for `max_epochs=epochs` epochs with the configuration above; the
# progress output below shows the epoch-1 training and validation passes.
trainer.fit()

Epoch 1: 100%|██████████| 938/938 [02:25<00:00,  6.45it/s, loss/train=0.0296]   

Epoch 1 (val):   0%|          | 0/157 [00:00<?, ?it/s]                          [A
Epoch 1 (val):   0%|          | 0/157 [00:00<?, ?it/s]                          [A
Epoch 1 (val):   2%|▏         | 3/157 [00:00<00:07, 20.93it/s]                  [A
Epoch 1 (val):   4%|▍         | 6/157 [00:00<00:07, 21.17it/s]                  [A
Epoch 1 (val):   6%|▌         | 9/157 [00:00<00:06, 21.25it/s]                  [A
Epoch 1 (val):   8%|▊         | 12/157 [00:00<00:06, 21.28it/s]                 [A
Epoch 1 (val):  10%|▉         | 15/157 [00:00<00:06, 21.27it/s]                 [A
Epoch 1 (val):  11%|█▏        | 18/157 [00:00<00:06, 21.27it/s]                 [A
Epoch 1 (val):  13%|█▎        | 21/157 [00:00<00:06, 21.28it/s]                 [A
Epoch 1 (val):  15%|█▌        | 24/157 [00:01<00:06, 21.28it/s]                 [A
Epoch 1 (val):  17%|█▋        | 27/157 [00:01<00:06, 21.28it/s]               