# Fashion MNIST: A Classification Problem

Let's train the best possible model on the Fashion MNIST dataset. 
- For 5 and 20 epochs

In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW

import torchvision.transforms as T
from torchvision.datasets import FashionMNIST

import timm

from torcheval.metrics import MulticlassAccuracy, Mean

from utils import to_device

Setup

In [2]:
mean, std = (0.28, 0.35)  # computed from FashionMNIST train set

train_tfms = T.Compose([
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean, std),
])

val_tfms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean, std),
])

tfms = {"train": train_tfms, "valid":val_tfms}

Data

In [21]:
data_path="."
tfms=tfms
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "mps"
num_workers = 6 
bs=512

In [22]:
train_ds = FashionMNIST(data_path, download=True, transform=tfms["train"])
valid_ds = FashionMNIST(data_path, download=True, train=False, transform=tfms["valid"])

train_dataloader = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=num_workers)
valid_dataloader = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, num_workers=num_workers)

A model

In [23]:
model_name = "resnet18"

model = timm.create_model(model_name, pretrained=False, num_classes=10, in_chans=1).to(device)

In [24]:
lr = 1e-3
wd = 0.01

optim = AdamW(model.parameters(), lr=lr, weight_decay=wd)

A Training loop

In [25]:
train_acc = MulticlassAccuracy(device=device)
valid_acc = MulticlassAccuracy(device=device)
metric_loss = Mean()

epochs = 5

In [26]:
def train_step(optim, loss):
    optim.zero_grad()
    loss.backward()
    optim.step()
    return loss

def do_one_epoch(dl, train=True):
    if train:
        model.train()
    else:
        model.eval()
    preds = []
    for b in dl:
        with (torch.inference_mode() if not train else torch.enable_grad()):
            # grab a batch
            images, labels = to_device(b, device)

            # compute preds on batch
            preds_b = model(images)
            preds.append(preds_b)

            # compute loss
            loss = F.cross_entropy(preds_b, labels)

            # store value on metric
            metric_loss.update(loss.detach().cpu(), weight=len(images))
            
            if train:
                # update weights
                train_step(optim, loss)

                # update metrics
                train_acc.update(preds_b, labels)
                
            else:
                valid_acc.update(preds_b, labels)        
    
    return torch.cat(preds, dim=0), metric_loss.compute()

In [27]:
preds, loss = do_one_epoch(train_dataloader, train=True)

In [28]:
def print_metrics(epoch, train_loss, val_loss):
    print(f"epoch: {epoch:3}, train_loss: {train_loss:10.3f}, train_acc: {train_acc.compute():3.3f}   ||   val_loss: {val_loss:10.3f}, val_acc: {valid_acc.compute():3.3f}")

def reset_metrics():
    train_acc.reset()
    valid_acc.reset()
    metric_loss.reset()

def fit(epochs):         
    for epoch in range(epochs):
        _, train_loss = do_one_epoch(train_dataloader, train=True)
        
        ## validation
        _, val_loss = do_one_epoch(valid_dataloader, train=False)
        print_metrics(epoch, train_loss, val_loss)
        reset_metrics()

In [30]:
fit(5)

epoch:   0, train_loss:      0.241, train_acc: 0.910   ||   val_loss:      0.249, val_acc: 0.892
epoch:   1, train_loss:      0.223, train_acc: 0.917   ||   val_loss:      0.232, val_acc: 0.897
epoch:   2, train_loss:      0.203, train_acc: 0.925   ||   val_loss:      0.217, val_acc: 0.888
epoch:   3, train_loss:      0.196, train_acc: 0.925   ||   val_loss:      0.213, val_acc: 0.884
epoch:   4, train_loss:      0.180, train_acc: 0.933   ||   val_loss:      0.196, val_acc: 0.894
