In [1]:
import random
from types import SimpleNamespace

import wandb
import timm
import torchvision as tv
import torchvision.transforms as T

import pandas as pd
from fastprogress import progress_bar

from torchmetrics import Accuracy

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader

In [2]:
def set_seed(s, reproducible=False):
    """Seed `random`, `torch` (CPU and every CUDA device) and, if installed, `numpy`.

    s:            integer seed shared by all generators.
    reproducible: when True, also force cuDNN into deterministic mode and turn
                  off its benchmark autotuner.
    """
    random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    # NumPy is not imported at file level, so bring it in locally; relying on a
    # global `np` meant the NumPy generator was silently never seeded.
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        # Fold the seed into the 32-bit range NumPy's legacy seeding accepts.
        np.random.seed(s % (2**32 - 1))
    if reproducible:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

## Get the dataset

In [3]:
# Weights & Biases destination. `ImageModel.fit` starts its runs in this
# project/entity, and `ImageModel` builds the path of the validation-data
# artifact from the same two values.
WANDB_PROJECT = "fmnist_pt"
ENTITY = "fastai"

In [4]:
# Notebook-level settings. Of these, only `seed` is read by the cells shown
# here; the model name, batch size and epoch count actually used for training
# are passed to `ImageModel` / `compile` directly.
config = SimpleNamespace(epochs=20, model_name="convnext_nano", bs=128, seed=42)

In [5]:
# Seed the random generators once, before any model or dataloader is created.
set_seed(config.seed)

In [23]:
# Training pipeline: resize to 32x32, apply two random RandAugment ops and a
# random horizontal flip, then convert to a tensor.
train_tfms = T.Compose(
    [
        T.Resize((32, 32)),
        T.RandAugment(num_ops=2),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
    ]
)

# Validation pipeline: deterministic resize + tensor conversion only.
val_tfms = T.Compose([T.Resize((32, 32)), T.ToTensor()])

# Lookup consumed by `ImageModel`, keyed by split name.
tfms = dict(train=train_tfms, valid=val_tfms)

In [24]:
def to_device(t, device):
    """Move a tensor, or every item of a tuple/list, onto `device`.

    t:      a `torch.Tensor`, or a tuple/list whose items expose `.to(device)`.
    device: anything accepted by `Tensor.to` (e.g. "cpu", "cuda").

    Returns the moved tensor, or a *list* of moved items for tuple/list input.
    Raises `TypeError` for any other input type.
    """
    if isinstance(t, (tuple, list)):
        return [_t.to(device) for _t in t]
    if isinstance(t, torch.Tensor):
        return t.to(device)
    # Raising a bare string is itself a TypeError ("exceptions must derive from
    # BaseException") that hides the intended message; raise a real exception.
    raise TypeError(f"Not a Tensor or list of Tensors: got {type(t).__name__}")

In [25]:
class ImageModel:
    """FashionMNIST classifier built on a `timm` backbone, with optional W&B logging.

    Constructing an instance creates the model, opens (downloading if needed) the
    train and test splits, creates the accuracy metrics and builds the dataloaders.
    Call `compile` to create the optimizer, loss and LR schedule, then `fit` to train.
    """

    def __init__(self, data_path=".", tfms=tfms, model_name="convnext_nano", device="cuda", amp=True):
        """
        data_path:  root folder handed to `tv.datasets.FashionMNIST`.
        tfms:       dict with a "train" and a "valid" transform.
        model_name: architecture name handed to `timm.create_model`.
        device:     device the model, metrics and batches are moved to.
        amp:        run forward passes under CUDA autocast with gradient scaling.
                    Ignored (treated as False) when `device` is not a CUDA device.
        """
        self.device = device
        # Mixed precision only makes sense here on CUDA; everywhere else it is off.
        self.use_amp = bool(amp) and torch.device(device).type == "cuda"
        self.config = SimpleNamespace(model_name=model_name, device=device, amp=self.use_amp)

        self.model = timm.create_model(model_name, pretrained=False, num_classes=10, in_chans=1).to(device)

        self.train_ds = tv.datasets.FashionMNIST(data_path, download=True,
                                                 transform=tfms["train"])
        self.valid_ds = tv.datasets.FashionMNIST(data_path, download=True, train=False,
                                                 transform=tfms["valid"])

        self.train_acc = Accuracy(task="multiclass", num_classes=10).to(device)
        self.valid_acc = Accuracy(task="multiclass", num_classes=10).to(device)

        # With enabled=False every scaler call is a pass-through, so `train_step`
        # uses the same code path whether or not mixed precision is active.
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

        self.do_validation = True

        self.dataloaders()
        self.val_data_at = f'{ENTITY}/{WANDB_PROJECT}/validation_data:latest'

    def dataloaders(self, bs=128, num_workers=8):
        "(Re)build the train/valid dataloaders; validation uses a batch of `2*bs`"
        self.config.bs = bs
        self.num_workers = num_workers
        self.train_dataloader = DataLoader(self.train_ds, batch_size=bs, shuffle=True,
                                   pin_memory=True, num_workers=num_workers)
        self.valid_dataloader = DataLoader(self.valid_ds, batch_size=bs*2, shuffle=False,
                                   num_workers=num_workers)

    def compile(self, epochs=5, lr=2e-3, wd=0.01):
        "Create the optimizer, loss and one-cycle schedule, and record the hyper-parameters"
        self.config.epochs = epochs
        self.config.lr = lr
        self.config.wd = wd

        # No `lr` is given to AdamW: the learning rate is driven by the
        # OneCycleLR schedule below, which peaks at `lr`.
        self.optim = AdamW(self.model.parameters(), weight_decay=wd)
        self.loss_func = nn.CrossEntropyLoss()
        # The schedule length is fixed here, so call `compile` again after
        # changing the batch size with `dataloaders`.
        self.schedule = OneCycleLR(self.optim, max_lr=lr,
                                   steps_per_epoch=len(self.train_dataloader),
                                   epochs=epochs)

    def get_data_tensors(self, dl="valid"):
        "stack images and labels as fat tensors"
        dl = self.valid_dataloader if dl=="valid" else self.train_dataloader
        images=[]
        labels=[]
        for img_b, labels_b in dl:
            images.append(img_b)
            labels.append(labels_b)
        images = torch.cat(images, dim=0)
        labels = torch.cat(labels, dim=0)
        return images, labels

    def reset(self):
        "Clear the accumulated state of both accuracy metrics"
        self.train_acc.reset()
        self.valid_acc.reset()

    def train_step(self, loss):
        "Backpropagate `loss`, update the weights and advance the LR schedule"
        self.optim.zero_grad()
        # Scale the loss so fp16 gradients do not underflow; `scaler.step`
        # unscales them before calling the optimizer.
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optim)
        self.scaler.update()
        # The schedule advances once per batch (it was sized that way in `compile`).
        self.schedule.step()
        return loss

    def one_epoch(self, train=True, use_wandb=False):
        """Run one pass over the train (`train=True`) or validation dataloader.

        Returns `(logits, mean_loss)`: the float32 model outputs for every sample
        of the pass concatenated along dim 0, and the sample-weighted mean loss
        as a Python float.
        """
        if train:
            stage = "train"
            self.model.train()
            dl = self.train_dataloader
            metric = self.train_acc
            grad_ctx = torch.enable_grad
        else:
            stage = "valid"
            self.model.eval()
            dl = self.valid_dataloader
            metric = self.valid_acc
            grad_ctx = torch.inference_mode

        total_loss, n_samples = 0., 0
        preds = []
        pbar = progress_bar(dl, leave=False)
        for b in pbar:
            with grad_ctx():
                images, labels = to_device(b, self.device)
                # `with A and B:` only enters B, so autocast used to be skipped
                # entirely. Enter it explicitly, around forward + loss only.
                with torch.autocast("cuda", enabled=self.use_amp):
                    preds_b = self.model(images)
                    loss = self.loss_func(preds_b, labels)
                if train:
                    self.train_step(loss)
                # Detach (no autograd graph kept alive across the epoch) and
                # upcast (autocast may produce half-precision outputs).
                logits_b = preds_b.detach().float()
                acc = metric(logits_b, labels)

            loss_val = loss.item()
            batch_size = labels.shape[0]
            total_loss += loss_val * batch_size
            n_samples += batch_size
            preds.append(logits_b)

            if train and use_wandb:
                wandb.log({"train_loss": loss_val,
                           "train_acc": acc.item(),
                           "learning_rate": self.schedule.get_last_lr()[0]})
            pbar.comment = f"{stage}_loss={loss_val:2.3f}, {stage}_acc={acc:2.3f}"

        # max(...) keeps an empty dataloader from dividing by zero.
        return torch.cat(preds, dim=0), total_loss / max(n_samples, 1)

    def get_preds(self, with_inputs=False):
        "Validation logits and mean loss; with `with_inputs` also the stacked images and labels"
        preds, loss = self.one_epoch(train=False)
        if with_inputs:
            images, labels = self.get_data_tensors()
            return images, labels, preds, loss
        else:
            return preds, loss

    def validation_data(self, n=None):
        "Build a W&B table (idx, image, label) from the first `n` validation samples (all if None)"
        val_table = wandb.Table(columns = ["idx", "image", "label"])
        images, labels = self.get_data_tensors()
        pbar = progress_bar(zip(images, labels), total=len(images), leave=False)
        pbar.comment = "Creating W&B Table with validation DL"
        for i, (img, lbl) in enumerate(pbar):
            if n is not None and i>=n:
                break
            val_table.add_data(i, wandb.Image(img, mode="L"), lbl.item())
        return val_table

    def create_table(self, n=None):
        """Join the stored validation table with fresh predictions on "idx".

        Needs an active W&B run and the artifact named by `self.val_data_at`
        containing a "val_table" entry. `n` limits the prediction rows (all if None).
        """
        artifact = wandb.use_artifact(self.val_data_at, type='data')
        val_table = artifact.get("val_table")
        logits, _ = self.get_preds()
        # The columns are called prob_*, so log probabilities rather than raw logits.
        probs = logits.float().softmax(dim=1).cpu()
        df = pd.DataFrame(probs.numpy(),
                          columns=[f"prob_{i}" for i in range(probs.shape[1])])
        # Insert the integer columns separately: concatenating them with the float
        # scores in one tensor would turn the join key into a float.
        df.insert(0, "preds", logits.argmax(dim=1).cpu().numpy())
        df.insert(0, "idx", list(range(len(df))))
        df = df.iloc[:n]
        preds_table = wandb.Table(dataframe=df)
        return wandb.JoinedTable(val_table, preds_table, "idx")

    def log_preds(self, n=None):
        "Log the joined validation/prediction table to the active W&B run"
        preds_table = self.create_table(n=n)
        wandb.log({"preds_table":preds_table})

    def fit(self, use_wandb=False, log_preds=False, n_preds=None):
        """Train for `self.config.epochs` epochs (call `compile` first).

        use_wandb: start a W&B run and log per-batch training and per-epoch
                   validation metrics to it.
        log_preds: after training, also log the predictions table (only with `use_wandb`).
        n_preds:   row limit forwarded to `log_preds`.
        """
        if use_wandb:
            wandb.init(project=WANDB_PROJECT, entity=ENTITY, config=self.config)

        # Start from clean metrics in case `one_epoch`/`get_preds` ran before `fit`.
        self.reset()
        for epoch in progress_bar(range(self.config.epochs), total=self.config.epochs, leave=True):
            _  = self.one_epoch(train=True, use_wandb=use_wandb)

            if use_wandb:
                wandb.log({"epoch":epoch})

            ## validation
            if self.do_validation:
                _, avg_loss = self.one_epoch(train=False, use_wandb=use_wandb)
                if use_wandb:
                    wandb.log({"val_loss": avg_loss,
                               "val_acc": self.valid_acc.compute().item()})
            self.reset()
        if use_wandb:
            if log_preds:
                print("Logging model predictions on validation data")
                self.log_preds(n=n_preds)
            wandb.finish()

## Log val dataset

In [26]:
# Build the training wrapper around timm's "resnet14t" architecture; every other
# argument keeps its default. Construction also opens the FashionMNIST splits
# (with download=True) and creates the dataloaders.
model = ImageModel(model_name="resnet14t")

In [27]:
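# One-off setup: uncomment and run this cell once to upload the "validation_data"
# artifact (holding "val_table"). `ImageModel.create_table` fetches
# "validation_data:latest", so `fit(..., log_preds=True)` needs it to exist.
# n=None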

# with wandb.init(project=WANDB_PROJECT, entity=ENTITY, job_type="log_data"):
#     val_at = wandb.Artifact("validation_data", type="data")
#     val_at.add(model.validation_data(n=n), "val_table")
#     wandb.log_artifact(val_at)

In [29]:
# Create the optimizer, loss and one-cycle LR schedule: 20 epochs, peak lr 3e-3,
# weight decay left at the `compile` default.
model.compile(epochs=20, lr=3e-3)

In [None]:
# Train with W&B logging, then log the predictions table. `log_preds=True` makes
# `create_table` fetch the "validation_data:latest" artifact, so that artifact
# must already exist in the W&B project.
model.fit(use_wandb=True, log_preds=True)