# Torch Runner - Example
This example fine-tunes an ImageNet-pretrained ResNet-50 on the CIFAR-10 dataset using `torch_runner`.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googlecolab/colabtools/blob/master/notebooks/colab-github-demo.ipynb)

In [1]:
!pip install torch_runner



In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch_runner as T
from sklearn.metrics import accuracy_score

In [3]:
# Number of samples per batch; used by both the training and validation dataloaders.
BATCH_SIZE = 128
# Number of epochs passed to the trainer's fit() call.
EPOCHS = 5
# Learning rate handed to the AdamW optimizer.
LR = 0.001

In [4]:
class BatchCollate:
    """Collate function that turns ``(image, label)`` samples into a dict batch."""

    def __call__(self, batch_list):
        """Merge a list of samples into batched tensors.

        Args:
            batch_list: Iterable of ``(image, label)`` pairs, where every
                image is a tensor of the same shape and every label is an
                integer class index.

        Returns:
            A dict with two keys: ``"images"`` (the images stacked along a
            new leading dimension) and ``"labels"`` (a ``torch.LongTensor``
            holding the labels in the same order).
        """
        # Materialise once so the input is traversed a single time and every
        # sample is validated as a 2-tuple, exactly like a plain for-loop.
        pairs = [(image, label) for image, label in batch_list]

        return {
            "images": torch.stack([image for image, _ in pairs]),
            "labels": torch.LongTensor([label for _, label in pairs]),
        }

def load_dataloaders():
    """Build the CIFAR-10 training and validation dataloaders.

    Both splits are stored under ``./data`` (downloaded when missing) and
    share the same preprocessing: conversion to a tensor followed by
    per-channel normalisation. The mean/std values match the widely used
    ImageNet statistics.

    Returns:
        A ``(train_dataloader, val_dataloader)`` tuple. Only the training
        loader shuffles; both use ``BATCH_SIZE``, two worker processes and
        a ``BatchCollate`` instance, so batches are dicts with ``"images"``
        and ``"labels"``.
    """
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])

    loaders = []
    # Training split first, then the held-out split; shuffling follows the
    # split flag so only the training loader is shuffled.
    for is_train in (True, False):
        dataset = torchvision.datasets.CIFAR10(
            root='./data',
            train=is_train,
            download=True,
            transform=pipeline,
        )
        loaders.append(
            torch.utils.data.DataLoader(
                dataset,
                batch_size=BATCH_SIZE,
                shuffle=is_train,
                num_workers=2,
                collate_fn=BatchCollate(),
            )
        )

    train_dataloader, val_dataloader = loaders
    return train_dataloader, val_dataloader

In [5]:
class CIFAR10Model(nn.Module):
    """ResNet-50 feature extractor followed by a 10-way linear classifier."""

    def __init__(self, pretrained=True):
        """Create the backbone and the classification head.

        Args:
            pretrained: Forwarded to ``torchvision.models.resnet50`` to
                select whether pretrained weights are loaded.
        """
        super().__init__()
        backbone = torchvision.models.resnet50(pretrained=pretrained)
        # Keep every child module except the last one, which is replaced by
        # the 10-class head below.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.head = nn.Linear(2048, 10)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image batch fed to the backbone.

        Returns:
            The output of the linear head: one row of 10 scores per sample.
        """
        features = self.resnet(x)
        # Average over dimensions 2 and 3 so the head receives one 2048-d
        # vector per sample.
        pooled = features.mean((2, 3))
        return self.head(pooled)

In [6]:
class Trainer(T.TrainerModule):
    """Classification trainer that plugs loss and accuracy into torch_runner.

    The training and validation steps read ``self.model``, ``self.optimizer``
    and ``self.device``; these are presumably set by ``T.TrainerModule`` from
    the constructor arguments (confirm against the torch_runner sources).
    """

    def __init__(
        self,
        model,
        optimizer,
        scheduler,
        early_stop=True,
        early_stop_params=None,
        early_stop_metric="accuracy",
        experiment_name="model",
        device="cuda",
    ):
        """Forward the configuration to ``T.TrainerModule``.

        Args:
            model: Network to train.
            optimizer: Optimizer used in ``train_one_step``.
            scheduler: Learning-rate scheduler handed to the base class.
            early_stop: Whether early stopping is enabled.
            early_stop_params: Early-stopping settings. ``None`` (the
                default) means ``{"patience": 2, "mode": "max", "delta": 0.0}``.
            early_stop_metric: Key of the step-result dict that is monitored.
            experiment_name: Name given to the experiment.
            device: Device that batches are moved to.
        """
        # Build the default per instance: a dict literal in the signature
        # would be one shared object mutated across all Trainer instances.
        if early_stop_params is None:
            early_stop_params = {"patience": 2, "mode": "max", "delta": 0.0}

        super(Trainer, self).__init__(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            early_stop=early_stop,
            early_stop_params=early_stop_params,
            early_stop_metric=early_stop_metric,
            experiment_name=experiment_name,
            device=device,
        )

    def calc_metric(self, preds, targets):
        """Return the classification accuracy of ``preds`` against ``targets``.

        Args:
            preds: Per-class scores (logits), classes on the last dimension.
            targets: Ground-truth class indices.

        Returns:
            The fraction of correct predictions as computed by
            ``sklearn.metrics.accuracy_score``.
        """
        # Softmax is monotonic, so the arg-max of the raw logits is already
        # the predicted class; no need to normalise first.
        predicted = torch.argmax(preds.detach(), dim=-1).cpu().numpy()
        return accuracy_score(targets.cpu().numpy(), predicted)

    def loss_fct(self, preds, targets):
        """Return the cross-entropy loss between logits and class indices."""
        criterion = nn.CrossEntropyLoss()
        return criterion(preds, targets)

    def train_one_step(self, batch, batch_id):
        """Run one optimisation step on a batch.

        Args:
            batch: Dict with ``"images"`` and ``"labels"`` tensors.
            batch_id: Index of the batch (unused here).

        Returns:
            Dict with the scalar ``"loss"`` and the batch ``"accuracy"``.
        """
        ## Must return a dict
        image = batch["images"].to(self.device)
        labels = batch["labels"].to(self.device)

        self.optimizer.zero_grad()
        outputs = self.model(image)
        loss = self.loss_fct(outputs, labels)
        loss.backward()
        self.optimizer.step()
        acc_score = self.calc_metric(outputs, labels)

        return {"loss": loss.item(), "accuracy": acc_score}

    def valid_one_step(self, batch, batch_id):
        """Evaluate one batch without updating the model.

        Args:
            batch: Dict with ``"images"`` and ``"labels"`` tensors.
            batch_id: Index of the batch (unused here).

        Returns:
            Dict with the scalar ``"loss"`` and the batch ``"accuracy"``.
        """
        ## Must return a dict
        image = batch["images"].to(self.device)
        labels = batch["labels"].to(self.device)

        # Nothing is back-propagated during validation, so skip building the
        # autograd graph to save memory and time.
        with torch.no_grad():
            outputs = self.model(image)
            loss = self.loss_fct(outputs, labels)
        acc_score = self.calc_metric(outputs, labels)
        return {"loss": loss.item(), "accuracy": acc_score}

In [7]:
# The device is hard-coded to CUDA, so this cell requires a CUDA-capable GPU.
device = torch.device("cuda")
# Build the classifier with pretrained=True and move its parameters to the device.
model = CIFAR10Model(pretrained=True).to(device)
# AdamW over all model parameters, using the learning rate defined in LR.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
# NOTE(review): T_max=10 while EPOCHS=5 — confirm how often the trainer steps
# the scheduler and whether this annealing period is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
# Create the CIFAR-10 train/validation dataloaders (data is downloaded to ./data if missing).
train_dataloader, val_dataloader = load_dataloaders()

Files already downloaded and verified
Files already downloaded and verified


A directory named after the experiment will be created. It will contain the hyperparameters as a YAML file, the training log file, and a checkpoint of the best model.

In [8]:
# Only the model, optimizer and scheduler are passed; every other Trainer
# argument (early stopping, experiment name, device) keeps its default.
runner = Trainer(model, optimizer, scheduler)
# Train on the training loader and evaluate on the validation loader for EPOCHS epochs.
runner.fit(train_dataloader, val_dataloader, batch_size=BATCH_SIZE, epochs=EPOCHS)

Epoch: 1/5


HBox(children=(FloatProgress(value=0.0, description='Training', max=391.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Validation', max=79.0, style=ProgressStyle(description_wi…


Validation score improved (-inf --> 0.7586036392405063). Saving model!
Epoch: 2/5


HBox(children=(FloatProgress(value=0.0, description='Training', max=391.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Validation', max=79.0, style=ProgressStyle(description_wi…


Validation score improved (0.7586036392405063 --> 0.8066653481012658). Saving model!
Epoch: 3/5


HBox(children=(FloatProgress(value=0.0, description='Training', max=391.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Validation', max=79.0, style=ProgressStyle(description_wi…


EarlyStopping counter: 1 out of 2, Best Score: 0.8066653481012658
Epoch: 4/5


HBox(children=(FloatProgress(value=0.0, description='Training', max=391.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Validation', max=79.0, style=ProgressStyle(description_wi…


Validation score improved (0.8066653481012658 --> 0.814181170886076). Saving model!
Epoch: 5/5


HBox(children=(FloatProgress(value=0.0, description='Training', max=391.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Validation', max=79.0, style=ProgressStyle(description_wi…


Validation score improved (0.814181170886076 --> 0.8305973101265823). Saving model!
