# Model Class

> The idea here is to build a PyTorch model class with a Keras-like API that lets us train a model!

In [163]:
#| default_exp model

In [164]:
#| hide
from nbdev.showdoc import *

Some Torch imports that are useful!

In [165]:
#| export
from types import SimpleNamespace
from contextlib import nullcontext

from fastprogress import progress_bar

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader

from torchmetrics import Metric, Accuracy

from capetorch.utils import ifnone, to_device

In [166]:
#| export
# Device string used when the caller does not pick one: the GPU if PyTorch
# reports CUDA as available, otherwise the CPU.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [167]:
# Scratch cell (not exported): a handle to the first CUDA device.
d = torch.device("cuda:0")

In [168]:
#| export
# Reusable context manager for mixed-precision forward passes.
# The name is always bound so that importing or referencing it on a CPU-only
# machine cannot raise NameError; without CUDA it is a no-op context.
mixed_precision = torch.autocast("cuda") if torch.cuda.is_available() else nullcontext()

In [194]:
#| export
class Metrics:
    """Group the training and validation metric objects and drive them together.

    Each metric is expected to be callable as ``m(preds, labels)`` and to
    provide ``compute()`` and ``reset()`` (the interface the methods below
    use). Result dictionaries are keyed by ``"train_"`` / ``"valid_"``
    followed by the lower-cased class name of the metric.

    Args:
        train: a single metric, or a list/tuple of metrics, for training.
        valid: a single metric, or a list/tuple of metrics, for validation.
    """

    def __init__(self, train, valid):
        self.train = self._as_list(train)
        self.valid = self._as_list(valid)

    @staticmethod
    def _as_list(metrics):
        """Return `metrics` as a list, wrapping a single metric object."""
        if isinstance(metrics, (list, tuple)):
            return list(metrics)
        return [metrics]

    @staticmethod
    def metric_name(m, suffix):
        """Build the logging key for metric `m`.

        NOTE: despite its name, `suffix` is prepended to the lower-cased
        class name (e.g. ``"train_" + "accuracy"``).
        """
        return suffix + type(m).__name__.lower()

    def update_train(self, preds_b, labels_b):
        """Feed one training batch to every training metric.

        Returns a dict mapping each metric's key to the value its call returned.
        """
        return {self.metric_name(m, "train_"): m(preds_b, labels_b) for m in self.train}

    def update_valid(self, preds_b, labels_b):
        """Feed one validation batch to every validation metric.

        Returns a dict mapping each metric's key to the value its call returned.
        """
        return {self.metric_name(m, "valid_"): m(preds_b, labels_b) for m in self.valid}

    def compute_train(self):
        """Return ``{key: m.compute()}`` for every training metric."""
        # `compute` must be *called*; returning the bound method would log a
        # function object instead of the metric value.
        return {self.metric_name(m, "train_"): m.compute() for m in self.train}

    def compute_valid(self):
        """Return ``{key: m.compute()}`` for every validation metric."""
        return {self.metric_name(m, "valid_"): m.compute() for m in self.valid}

    def reset(self):
        """Call ``reset()`` on every training and validation metric."""
        for m in self.train + self.valid:
            m.reset()

In [195]:
#| export
class CapeModel:
    """Keras-style training wrapper around a PyTorch model.

    Typical use: construct the wrapper, attach dataloaders (through the
    constructor or `data`), call `prepare` to set up the optimizer, loss,
    scheduler and metrics, then call `fit`.

    Args:
        model: the ``nn.Module`` to train; it is moved to `device`.
        train_dataloader: iterable of training batches (may be set later).
        valid_dataloader: iterable of validation batches (may be set later).
        device: device spec; falls back to ``DEFAULT_DEVICE`` when ``None``.
        fp16: request autocast mixed precision; only honoured on CUDA devices.
        use_wandb: when true, `log` forwards to ``wandb.log`` and `fit`
            finishes the run with ``wandb.finish``.
    """

    def __init__(self, model, train_dataloader=None, valid_dataloader=None, device=None, fp16=True, use_wandb=False):
        self.device = torch.device(ifnone(device, DEFAULT_DEVICE))
        self.model = model.to(self.device)
        # Mixed precision is only enabled when the model lives on a CUDA device.
        self.fp16 = bool(fp16) and self.device.type == "cuda"
        self.use_wandb = use_wandb

        self.config = SimpleNamespace(model_name=model.__class__, device=device)
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader

    def data(self, train_dataloader, valid_dataloader=None):
        """Attach (or replace) the training and validation dataloaders."""
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader

    def _default_metrics(self, num_classes=None):
        """Build the default metric list: a single multiclass `Accuracy`.

        `num_classes` falls back to the model's ``num_classes`` attribute when
        the model has one.

        Raises:
            ValueError: if the number of classes cannot be determined.
        """
        if num_classes is None:
            num_classes = getattr(self.model, "num_classes", None)
        if num_classes is None:
            raise ValueError(
                "Cannot build the default Accuracy metric: pass `num_classes` "
                "or explicit `train_metrics`/`valid_metrics` to `prepare`."
            )
        return [Accuracy(task="multiclass", num_classes=num_classes).to(self.device)]

    def _defaults(self, epochs=5, lr=2e-3, wd=0.01, optimizer=None, loss_func=None,
                  scheduler=None, train_metrics=None, valid_metrics=None, num_classes=None):
        """Set optimizer, loss, scheduler and metrics, filling in defaults.

        Defaults: ``AdamW`` with weight decay `wd`, ``CrossEntropyLoss``,
        ``OneCycleLR`` peaking at `lr` over ``epochs * len(train_dataloader)``
        steps, and multiclass accuracy for both splits. Any component passed
        explicitly is used as-is.

        Raises:
            ValueError: if a default scheduler is needed but no training
                dataloader is attached, or default metrics are needed but the
                number of classes is unknown.
        """
        if optimizer is None:
            # A user-supplied scheduler is already bound to an optimizer; step
            # that one rather than an unrelated default.
            optimizer = getattr(scheduler, "optimizer", None)
        if optimizer is None:
            optimizer = AdamW(self.model.parameters(), weight_decay=wd)
        self.optimizer = optimizer

        self.loss_func = nn.CrossEntropyLoss() if loss_func is None else loss_func

        if scheduler is None:
            if self.train_dataloader is None:
                raise ValueError(
                    "A training dataloader is required to build the default "
                    "OneCycleLR scheduler; attach one or pass `scheduler`."
                )
            # Built on the *resolved* optimizer so the schedule drives the
            # optimizer that is actually stepped.
            scheduler = OneCycleLR(self.optimizer, max_lr=lr,
                                   steps_per_epoch=len(self.train_dataloader),
                                   epochs=epochs)
        self.scheduler = scheduler

        # Each split falls back to the default independently.
        if train_metrics is None:
            train_metrics = self._default_metrics(num_classes)
        if valid_metrics is None:
            valid_metrics = self._default_metrics(num_classes)
        self.metrics = Metrics(train_metrics, valid_metrics)

    def log(self, d):
        """Send the dict `d` to Weights & Biases when `use_wandb` is set."""
        if self.use_wandb:
            import wandb  # lazy import: only needed when W&B logging is on
            wandb.log(d)

    def prepare(self, epochs=5, lr=2e-3, wd=0.01, optimizer=None,
                loss_func=None, scheduler=None, train_metrics=None, valid_metrics=None,
                num_classes=None):
        """Record the hyper-parameters and set up the training components.

        Args:
            epochs: number of epochs `fit` will run.
            lr: peak learning rate of the default one-cycle schedule.
            wd: weight decay of the default optimizer.
            optimizer, loss_func, scheduler: optional replacements for the
                defaults described in `_defaults`.
            train_metrics, valid_metrics: optional metric (or list of metrics)
                per split; each defaults to multiclass accuracy.
            num_classes: number of classes for the default accuracy metric.
        """
        self.config.epochs = epochs
        self.config.lr = lr
        self.config.wd = wd

        self._defaults(epochs, lr, wd, optimizer=optimizer, loss_func=loss_func,
                       scheduler=scheduler, train_metrics=train_metrics,
                       valid_metrics=valid_metrics, num_classes=num_classes)

    def train_step(self, preds_b, labels_b):
        """Compute the loss for one batch, back-propagate and step.

        Also updates the training metrics and logs loss, learning rate and
        the batch metric values. Returns the loss tensor.
        """
        loss = self.loss_func(preds_b, labels_b)
        self.optimizer.zero_grad()
        # NOTE(review): no gradient scaler is used with fp16 autocast —
        # confirm whether loss scaling is wanted here.
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        # metrics
        metrics = self.metrics.update_train(preds_b, labels_b)
        self.log({"train_loss": loss.item(), "learning_rate": self.scheduler.get_last_lr()[0]})
        self.log(metrics)
        return loss

    def valid_step(self, preds_b, labels_b):
        """Compute the loss for one validation batch and update valid metrics."""
        loss = self.loss_func(preds_b, labels_b)
        self.metrics.update_valid(preds_b, labels_b)
        return loss

    def switch(self, is_train=True):
        """Put the model in train/eval mode.

        Returns the matching dataloader and gradient context:
        ``enable_grad`` for training, ``inference_mode`` for validation.
        """
        if is_train:
            self.model.train()
            dl = self.train_dataloader
            ctx = torch.enable_grad()
        else:
            self.model.eval()
            dl = self.valid_dataloader
            ctx = torch.inference_mode()
        return dl, ctx

    def one_epoch(self, is_train=True):
        """Run one pass over the training or validation dataloader.

        Returns:
            ``(preds, avg_loss)``: all batch predictions concatenated along
            dim 0 (detached from the autograd graph) and the mean of the
            per-batch losses as a Python float.

        Raises:
            ValueError: if the selected dataloader is missing or empty.
        """
        dl, grad_ctx = self.switch(is_train)
        if dl is None:
            split = "training" if is_train else "validation"
            raise ValueError(f"No {split} dataloader attached.")
        amp_ctx = torch.autocast(self.device.type) if self.fp16 else nullcontext()
        step = self.train_step if is_train else self.valid_step

        pbar = progress_bar(dl, leave=False)
        preds = []
        total_loss, n_batches = 0.0, 0
        for b in pbar:
            # `with a, b` enters both managers; `with a and b` would only
            # enter one of them.
            with grad_ctx, amp_ctx:
                images, labels = to_device(b, self.device)
                preds_b = self.model(images)
                loss = step(preds_b, labels)
            # Keep plain values only, so the per-batch autograd graphs can be
            # freed instead of accumulating over the epoch.
            preds.append(preds_b.detach())
            batch_loss = loss.item()
            total_loss += batch_loss
            n_batches += 1
            pbar.comment = f"loss={batch_loss:2.3f}"
        if n_batches == 0:
            raise ValueError("The dataloader yielded no batches.")
        return torch.cat(preds, dim=0), total_loss / n_batches

    def epoch_ends(self):
        """Hook for end-of-epoch output."""
        # NOTE(review): the original statement was truncated in this export
        # (`print(f"`); an empty line is printed as a placeholder — confirm
        # the intended message.
        print()

    def validate(self):
        """Run validation (if a validation dataloader is attached) and log it."""
        if self.valid_dataloader is not None:
            _, avg_loss = self.one_epoch(is_train=False)
            self.log({"val_loss": avg_loss})
            self.log(self.metrics.compute_valid())

    def get_data_tensors(self):
        """Return the validation inputs and labels; not implemented yet."""
        raise NotImplementedError()

    def get_model_preds(self, with_inputs=False):
        """Predict on the validation dataloader.

        Returns ``(preds, loss)``, or ``(images, labels, preds, loss)`` when
        `with_inputs` is true (which relies on `get_data_tensors`).
        """
        preds, loss = self.one_epoch(is_train=False)
        if with_inputs:
            images, labels = self.get_data_tensors()
            return images, labels, preds, loss
        return preds, loss

    def fit(self, log_preds=False):
        """Train for ``config.epochs`` epochs, validating after each one.

        Args:
            log_preds: with W&B enabled, also log validation predictions
                through ``self.preds_logger`` when one is attached.

        Raises:
            RuntimeError: if `prepare` has not been called yet.
        """
        epochs = getattr(self.config, "epochs", None)
        if epochs is None:
            raise RuntimeError("Call `prepare()` before `fit()`.")

        for epoch in progress_bar(range(epochs), total=epochs, leave=True):
            self.one_epoch(is_train=True)

            self.log({"epoch": epoch})

            ## validation
            self.validate()

            self.metrics.reset()
        if self.use_wandb:
            import wandb  # lazy import: only needed when W&B logging is on
            if log_preds:
                preds_logger = getattr(self, "preds_logger", None)
                if preds_logger is None:
                    print("No `preds_logger` attached; skipping prediction logging")
                else:
                    print("Logging model predictions on validation data")
                    preds, _ = self.get_model_preds()
                    preds_logger.log(preds=preds)
            wandb.finish()

In [196]:
import timm
import torchvision as tv
import torchvision.transforms as T

In [197]:
# Architecture name handed to timm for the demo run.
model_name = "resnet10t"
# Untrained weights, 10 output classes and a single input channel
# (presumably chosen to match the FashionMNIST data loaded below — TODO confirm).
model = timm.create_model(model_name, pretrained=False, num_classes=10, in_chans=1)

In [198]:
# Wrap the model; no dataloaders are passed here, they are attached in a later cell.
CM = CapeModel(model=model)

In [199]:
# Root directory passed to the FashionMNIST datasets below (current directory).
data_path = "."

In [200]:
# Training pipeline: random crop with padding, horizontal flip, tensor
# conversion, normalisation, then random erasing.
# NOTE(review): 0.1307 / 0.3081 look like the commonly quoted MNIST mean/std —
# confirm they are the intended statistics for FashionMNIST.
train_tfms = T.Compose([
    # T.Resize((32,32)),
    T.RandomCrop(28, padding=4), 
    # T.RandAugment(num_ops=2),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.1307,), (0.3081,)),
    T.RandomErasing(),
])

# Validation pipeline: no augmentation, only tensor conversion and the same
# normalisation as training.
val_tfms = T.Compose([
    # T.Resize((32,32)),
    T.ToTensor(),
    T.Normalize((0.1307,), (0.3081,)),
])

# Lookup of the transform pipeline per split.
tfms ={"train": train_tfms, "valid":val_tfms}

In [201]:
# FashionMNIST datasets rooted at `data_path`, downloaded if needed. The
# second one passes train=False and uses the validation transforms.
train_ds = tv.datasets.FashionMNIST(data_path, download=True, transform=tfms["train"])
valid_ds = tv.datasets.FashionMNIST(data_path, download=True, train=False, transform=tfms["valid"])

In [202]:
# Training batch size and number of DataLoader worker processes.
bs=256
num_workers=8

In [203]:
# Training loader: shuffled, pinned memory.
train_dataloader = DataLoader(train_ds, batch_size=bs, shuffle=True, pin_memory=True, num_workers=num_workers)
# Validation loader: fixed order, twice the training batch size.
valid_dataloader = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, num_workers=num_workers)

In [204]:
# Attach the dataloaders through the attributes set in `CapeModel.__init__`;
# this works whether or not the class provides a `data` helper.
CM.train_dataloader = train_dataloader
CM.valid_dataloader = valid_dataloader

In [205]:
# Set up optimizer, loss, scheduler and metrics with defaults for a one-epoch run.
CM.prepare(epochs=1)

In [206]:
# Run the training loop (with validation after each epoch).
CM.fit()

In [16]:
# Scratch value for the number-formatting experiment in the next cell.
n = 123123

In [22]:
# `:10d` right-aligns the integer in a field 10 characters wide.
print(f"{n:10d}")

    123123


In [None]:
#| hide
# Export the cells marked `#| export` from this notebook via nbdev.
import nbdev

nbdev.nbdev_export()