In [43]:
import wandb
import timm
import torchvision as tv
import torchvision.transforms as T

from types import SimpleNamespace

from fastprogress import progress_bar

from torchmetrics import Accuracy

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader

## Get the dataset

In [44]:
config = SimpleNamespace(epochs=20, model_name="convnext_nano", bs=128)

In [45]:
train_tfms = T.Compose([
    T.Resize((32, 32)), 
    T.RandAugment(num_ops=3),
    T.ToTensor()])

val_tfms = T.Compose([
    T.Resize((32,32)),
    T.ToTensor()])

tfms ={"train": train_tfms, "valid":val_tfms}

In [46]:
def to_device(t, device):
    if isinstance(t, (tuple, list)):
        return [_t.to(device) for _t in t]
    elif isinstance(t, torch.Tensor):
        return t.to(device)
    else:
        raise("Not a Tensor or list of Tensors")
    return t     

In [51]:
class ImageModel:
    def __init__(self, data_path=".", tfms=tfms, model_name="convnext_nano", device="cuda"):
        
        self.device = device
        self.config = SimpleNamespace(model_name=model_name, device=device)
        
        self.model = timm.create_model(model_name, pretrained=False, num_classes=10, in_chans=1).to(device)
        
        self.train_ds = tv.datasets.FashionMNIST(data_path, download=True, 
                                                 transform=tfms["train"])
        self.valid_ds = tv.datasets.FashionMNIST(data_path, download=True, train=False, 
                                                 transform=tfms["valid"])
        
        self.train_acc = Accuracy(task="multiclass", num_classes=10).to(device)
        self.valid_acc = Accuracy(task="multiclass", num_classes=10).to(device)
        
        self.do_validation = True
        
        self.dataloaders()
                 
        
    
    def dataloaders(self, bs=128, num_workers=8):
        self.config.bs = bs
        self.num_workers = num_workers
        self.train_dataloader = DataLoader(self.train_ds, batch_size=bs, shuffle=True, 
                                   pin_memory=True, num_workers=num_workers)
        self.valid_dataloader = DataLoader(self.valid_ds, batch_size=bs*2, shuffle=False, 
                                   num_workers=num_workers)

    def compile(self, epochs=5, lr=2e-3, wd=0.01, num_workers=8):
        self.config.epochs = epochs
        self.config.lr = lr
        self.config.wd = wd
        
        self.optim = AdamW(self.model.parameters(), weight_decay=wd)
        self.loss_func = nn.CrossEntropyLoss()
        self.schedule = OneCycleLR(self.optim, max_lr=lr, 
                                   steps_per_epoch=len(self.train_dataloader), 
                                   epochs=epochs)
    
    def reset(self):
        self.train_acc.reset()
        self.valid_acc.reset()
        
    def train_step(self, loss):
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        self.schedule.step()
        return loss
        
    def one_epoch(self, train=True, use_wandb=False):
        avg_loss = 0.
        if train: 
            self.model.train()
            dl = self.train_dataloader
        else: 
            self.model.eval()
            dl = self.valid_dataloader
        pbar = progress_bar(dl, leave=False)
        for i, b in enumerate(pbar):
            with torch.autocast("cuda") and (torch.inference_mode() if not train else torch.enable_grad()):
                images, labels = to_device(b, self.device)
                preds = self.model(images)
                loss = self.loss_func(preds, labels)
                avg_loss += loss
                if train:
                    self.train_step(loss)
                    acc = self.train_acc(preds, labels)
                    if use_wandb: 
                        wandb.log({"train_loss": loss.item(),
                                   "train_acc": acc,
                                   "learning_rate": self.schedule.get_last_lr()[0]})
                else:
                    acc = self.valid_acc(preds, labels)
            pbar.comment = f"train_loss={loss.item():2.3f}, train_acc={acc:2.3f}"      
            
        return avg_loss.mean().item(), acc
    
    def fit(self, use_wandb=False):
        if use_wandb:
            run = wandb.init(project="fmnist_pytorch", entity="fastai", config=self.config)
            
        for epoch in progress_bar(range(self.config.epochs), total=self.config.epochs, leave=True):
            _  = self.one_epoch(train=True, use_wandb=use_wandb)
            
            if use_wandb:
                wandb.log({"epoch":epoch})
                
            ## validation
            if self.do_validation:
                avg_loss, acc = self.one_epoch(train=False, use_wandb=use_wandb)
                if use_wandb:
                    wandb.log({"val_loss": avg_loss,
                               "val_acc": self.valid_acc.compute()})
            self.reset()
        if use_wandb:
            wandb.finish()

In [52]:
model = ImageModel()

In [53]:
model.compile(epochs=20, lr=2e-3)

In [54]:
model.fit(use_wandb=True)

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
learning_rate,▁▁▂▂▃▄▅▆▆▇███████▇▇▇▇▆▆▅▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
train_acc,▂▁▃▆▅▅▄▄▆▆▆▃▅▇▅▇▆▆▇▆▇▆▆▆▆▆▇▆▆▆▇█▆▇▇▇▇█▇█
train_loss,██▆▄▅▅▅▄▃▃▄▆▃▃▃▃▃▃▃▃▂▃▃▂▃▃▂▂▂▂▂▁▃▂▂▂▂▁▁▂
val_acc,▁▃▃▄▄▃▅▅▆▆▆▇▇▇▇▇████
val_loss,█▆▅▅▅▅▃▃▃▃▃▂▂▂▂▁▁▁▁▁

0,1
epoch,19.0
learning_rate,0.0
train_acc,0.90625
train_loss,0.20611
val_acc,0.919
val_loss,8.84315
