In [12]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torchvision import datasets
import torchvision.transforms as T

import matplotlib.pyplot as plt

import composer
from composer import Trainer
from composer.algorithms import ChannelsLast, CutMix, LabelSmoothing, BlurPool, RandAugment, MixUp
from composer.models import mnist_model

from composer.loggers import WandBLogger

torch.manual_seed(42) # For replicability

<torch._C.Generator at 0x7f0a5019b2d0>

In [13]:
PROJECT = "fmnist_bench"
ENTITY = "capecape"

In [14]:
data_directory = "."

bs = 256
lr = 1e-3
wd = 1e-3
epochs = 20

In [15]:
wandb_logger = WandBLogger(project=PROJECT, entity=ENTITY, tags=["composer"])

In [16]:
train_tfms = T.Compose([
    T.RandomCrop(28, padding=4), 
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.1307,), (0.3081,)),
])
val_tfms = T.Compose([
    # T.Resize((32,32)),
    T.ToTensor(),
    T.Normalize((0.1307,), (0.3081,)),
])

tfms = {"train": train_tfms, "valid":val_tfms}

In [17]:
train_dataset = datasets.FashionMNIST(data_directory, download=True, train=True, transform=tfms["train"])
eval_dataset = datasets.FashionMNIST(data_directory, download=True, train=False, transform=tfms["valid"])

train_dataloader = DataLoader(train_dataset, batch_size=bs, num_workers=8, pin_memory=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=bs, num_workers=8)

## Timm

In [18]:
import timm
from composer.models import ComposerClassifier

model_name = "resnet10t"

timm_model = timm.create_model(model_name, pretrained=False, num_classes=10, in_chans=1)
model = ComposerClassifier(timm_model)

In [19]:
optimizer = AdamW(model.parameters(), weight_decay=wd)
scheduler = OneCycleLR(optimizer, max_lr=lr, 
                       steps_per_epoch=len(train_dataloader), 
                       epochs=epochs)

In [20]:
train_epochs = f"{epochs}ep" # Train for 3 epochs because we're assuming Colab environment and hardware
device = "gpu" if torch.cuda.is_available() else "cpu" # select the device

trainer = composer.trainer.Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=scheduler,
    device=device,
    precision='amp',
    loggers=wandb_logger,
)

0,1
epoch,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
loss/train/total,█▆▅▄▃▃▃▃▃▂▂▂▂▂▂▁▂▂▂▁▁▂▂▁▂▁▂▁▂▂▂▁▁▂▁▂▁▂▁▁
metrics/eval/Accuracy,▁▃▄▅▆▆▆▇▇▇▇▇▇▇██████
metrics/eval/CrossEntropy,█▅▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁
metrics/train/Accuracy,▁▃▃▅▅▅▅▆▅▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇▇▇▇▇▇██▇██▇▇▇▇
trainer/batch_idx,▃▅▁▆▂▇▃▅▁▆▂▇▃▇▁▆▂▇▃▇▁▆▂▇▃▇▃▆▂▆▃▇▃▆▂▆▂▇▃█
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/grad_accum,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,20.0
loss/train/total,0.2643
metrics/eval/Accuracy,0.872
metrics/eval/CrossEntropy,0.33687
metrics/train/Accuracy,0.89583
trainer/batch_idx,234.0
trainer/global_step,4700.0
trainer/grad_accum,1.0


In [21]:
trainer.fit()

******************************
Config:
num_gpus_per_node: 1
num_nodes: 1
rank_zero_seed: 114437740

******************************
train          Epoch   0:  100%|| 235/235 [00:06<00:00, 37.83ba/s, loss/train/total=1.0973]                                                                                                                          
eval           Epoch   0:  100%|| 40/40 [00:00<00:00, 56.59ba/s, metrics/eval/Accuracy=0.6912]                                                                                                                       
train          Epoch   1:  100%|| 235/235 [00:06<00:00, 36.79ba/s, loss/train/total=0.7663]                                                                                                                          
eval           Epoch   1:  100%|| 40/40 [00:00<00:00, 56.82ba/s, metrics/eval/Accuracy=0.7515]                                                                                                                       
train        

0,1
epoch,▁▂▂▃▃▄▅▅▆▆▇▇█
loss/train/total,█▆▅▄▄▃▃▂▃▃▃▂▂▂▂▂▂▂▂▂▁▂▂▁▂▂▁▁▂▂▁▂▁▁▁▁▂▁▁▁
metrics/eval/Accuracy,▁▄▅▆▆▇▇▇████
metrics/eval/CrossEntropy,█▅▄▃▂▂▂▂▁▁▁▁
metrics/train/Accuracy,▁▂▄▅▆▆▆▇▆▆▆▇▇▆▇▇▇▇▇▇▇▇▇█▇▇██▇▇█▇████▇███
trainer/batch_idx,▂▄▆█▃▅▇▂▅▇▂▅▇▁▃▆▁▃▆█▃▅█▂▄▇▂▄▆▂▄▆▁▃▅▇▃▅▇▃
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/grad_accum,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,12.0
loss/train/total,0.42729
metrics/eval/Accuracy,0.854
metrics/eval/CrossEntropy,0.38717
metrics/train/Accuracy,0.84375
trainer/batch_idx,90.0
trainer/global_step,2910.0
trainer/grad_accum,1.0


train          Epoch  12:   39%|| 91/235 [00:08<00:53,  2.69ba/s, loss/train/total=0.4273]                                                                                                                           

Error: You must call wandb.init() before wandb.log()