In [4]:
import random
from types import SimpleNamespace
from contextlib import nullcontext

import wandb
import timm
import torchvision as tv
import torchvision.transforms as T

import pandas as pd
from fastprogress import progress_bar

from torcheval.metrics import MulticlassAccuracy, Mean

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader

from preds_logger import PredsLogger

In [5]:
def set_seed(s, reproducible=False):
    "Set random seed for `random`, `torch`, and `numpy` (where available)"
    try: torch.manual_seed(s)
    except NameError: pass
    try: torch.cuda.manual_seed_all(s)
    except NameError: pass
    try: np.random.seed(s%(2**32-1))
    except NameError: pass
    random.seed(s)
    if reproducible:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

## Get the dataset

In [6]:
WANDB_PROJECT = "fmnist_bench"
WANDB_ENTITY = "capecape"

In [11]:
config = SimpleNamespace(epochs=20, model_name="resnet10t", bs=512, seed=42)

In [12]:
set_seed(config.seed)

In [13]:
mean, std = (0.28, 0.35)

In [14]:
train_tfms = T.Compose([
    T.RandomCrop(28, padding=1), 
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean, std),
    T.RandomErasing(scale=(0.02, 0.25), value="random"),
])

val_tfms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean, std),
])

tfms = {"train": train_tfms, "valid":val_tfms}

In [15]:
def to_device(t, device):
    if isinstance(t, (tuple, list)):
        return [_t.to(device) for _t in t]
    elif isinstance(t, torch.Tensor):
        return t.to(device)
    else:
        raise("Not a Tensor or list of Tensors")
    return t     

In [16]:
class ImageModel:
    def __init__(self, model, data_path=".", tfms=tfms, device="cuda", bs=256, use_wandb=False):
        
        self.device = device
        self.config = SimpleNamespace(device=device)
        
        self.model = model.to(device)
        
        self.train_ds = tv.datasets.FashionMNIST(data_path, download=True, 
                                                 transform=tfms["train"])
        self.valid_ds = tv.datasets.FashionMNIST(data_path, download=True, train=False, 
                                                 transform=tfms["valid"])
        
        self.train_acc = MulticlassAccuracy(device=device)
        self.valid_acc = MulticlassAccuracy(device=device)
        self.loss = Mean()
        
        self.do_validation = True
        
        self.dataloaders(bs=bs)
        self.config.tfms = tfms
        
        self.use_wandb = use_wandb
        
        # get validation data reference
        if use_wandb:
            self.preds_logger = PredsLogger(ds=self.valid_ds) 
    
    @classmethod
    def from_timm(cls, model_name, data_path=".", tfms=tfms, device="cuda", bs=256 ):
        model = timm.create_model(model_name, pretrained=False, num_classes=10, in_chans=1)
        image_model = cls(model, data_path, tfms, device, bs)
        image_model.config.model_name = model_name
        return image_model
            
    def dataloaders(self, bs=128, num_workers=8):
        self.config.bs = bs
        self.num_workers = num_workers
        self.train_dataloader = DataLoader(self.train_ds, batch_size=bs, shuffle=True, 
                                   pin_memory=True, num_workers=num_workers)
        self.valid_dataloader = DataLoader(self.valid_ds, batch_size=bs*2, shuffle=False, 
                                   num_workers=num_workers)
    
    def log(self, d):
        if self.use_wandb and (wandb.run is not None):
            wandb.log(d)
    
    def compile(self, epochs=5, lr=2e-3, wd=0.01):
        self.config.epochs = epochs
        self.config.lr = lr
        self.config.wd = wd
        
        self.optim = AdamW(self.model.parameters(), weight_decay=wd)
        self.loss_func = nn.CrossEntropyLoss()
        self.schedule = OneCycleLR(self.optim, 
                                   max_lr=lr, 
                                   total_steps=epochs*len(self.train_dataloader))
    
    def reset_metrics(self):
        self.train_acc.reset()
        self.valid_acc.reset()
        self.loss.reset()
        
    def train_step(self, loss):
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        self.schedule.step()
        return loss
        
    def one_epoch(self, train=True):
        if train: 
            self.model.train()
            dl = self.train_dataloader
        else: 
            self.model.eval()
            dl = self.valid_dataloader
        pbar = progress_bar(dl, leave=False)
        preds = []
        for i, b in enumerate(pbar):
            with (torch.inference_mode() if not train else torch.enable_grad()):
                images, labels = to_device(b, self.device)
                # with torch.autocast("cuda"):
                preds_b = self.model(images)
                loss = self.loss_func(preds_b, labels)
                self.loss.update(loss.detach().cpu(), weight=len(images))
                preds.append(preds_b)
                if train:
                    self.train_step(loss)
                    acc = self.train_acc.update(preds_b, labels)
                    self.log({"train_loss": loss.item(),
                              "train_acc": acc.compute(),
                              "learning_rate": self.schedule.get_last_lr()[0]})
                else:
                    acc = self.valid_acc.update(preds_b, labels)
            pbar.comment = f"train_loss={loss.item():2.3f}, train_acc={acc.compute():2.3f}"      
            
        return torch.cat(preds, dim=0), self.loss.compute()
    
    def log_preds(self, preds):
        if self.use_wandb:
            print("Logging model predictions on validation data")
            preds, _ = self.get_model_preds()
            self.preds_logger.log(preds=preds)
    
    def get_model_preds(self, with_inputs=False):
        preds, loss = self.one_epoch(train=False)
        if with_inputs:
            images, labels = self.get_data_tensors()
            return images, labels, preds, loss
        else:
            return preds, loss
            
    def print_metrics(self, epoch, train_loss, val_loss):
        print(f"epoch: {epoch:3}, train_loss: {train_loss:10.3f}, train_acc: {self.train_acc.compute():3.3f}   ||   val_loss: {val_loss:10.3f}, val_acc: {self.valid_acc.compute():3.3f}")
    
    def fit(self, log_preds=False):         
        for epoch in progress_bar(range(self.config.epochs), total=self.config.epochs, leave=True):
            _, train_loss = self.one_epoch(train=True)
            
            self.log({"epoch":epoch})
                
            ## validation
            if self.do_validation:
                _, val_loss = self.one_epoch(train=False)
                self.log({"val_loss": val_loss,
                          "val_acc": self.valid_acc.compute()})
            self.print_metrics(epoch, train_loss, val_loss)
            self.reset_metrics()
        if log_preds:
            self.log_preds(preds)

## Timm

In [17]:
model = ImageModel.from_timm(model_name="resnet10t", bs=512)

In [19]:
model.compile(epochs=2, lr=1e-2, wd=0.)

In [20]:
with wandb.init(project=WANDB_PROJECT, entity=WANDB_ENTITY, config=model.config):
    model.fit()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcapecape[0m. Use [1m`wandb login --relogin`[0m to force relogin


epoch:   0, train_loss:      0.736, train_acc: 0.734   ||   val_loss:      0.715, val_acc: 0.780


epoch:   1, train_loss:      0.355, train_acc: 0.869   ||   val_loss:      0.345, val_acc: 0.894


## Jeremy's Resnet

In [27]:
from jeremy_resnet import resnet

In [28]:
leak = 0.1
sub = 0.4

In [29]:
model = ImageModel(resnet(leak, sub), bs=512)
model.config.model_name = {"model_name":"jeremy", "leak":leak, "sub":sub}

In [30]:
model.compile(epochs=2, lr=1e-2, wd=0.0)

In [31]:
with wandb.init(project=WANDB_PROJECT, entity=WANDB_ENTITY, config=model.config, tags=["jeremy"]):
    model.fit()

epoch:   0, train_loss:      0.780, train_acc: 0.768   ||   val_loss:      0.736, val_acc: 0.840


epoch:   1, train_loss:      0.411, train_acc: 0.863   ||   val_loss:      0.399, val_acc: 0.888
