In [1]:
!pip install torcheval wandb timm fastprogress

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


# Torcheval: A PyTorch metrics library

This is the code for the following Weights and Biases [report](http://wandb.me/torcheval)

## Functional 

In [2]:
import torch
from torcheval.metrics.functional import binary_accuracy

# Three predicted scores and the matching binary ground-truth labels
preds = torch.tensor([0.1,0.7,0.6])
truth = torch.tensor([0,1,0])

# Functional API: a single call computes the metric from the tensors passed in
# (the recorded output is 0.6667, i.e. 2 of the 3 samples counted as correct)
binary_accuracy(preds, truth)

tensor(0.6667)

## With Accumulation

In [3]:
from torcheval.metrics import BinaryAccuracy

# Metric object: keeps its state between update() calls
accuracy = BinaryAccuracy()

# First batch of predicted scores and labels
preds1 = torch.tensor([0.1,0.7,0.6])
truth1 = torch.tensor([0,1,0])

# Feed the first batch into the metric
accuracy.update(preds1, truth1)

# Accuracy over everything passed to update() so far (only the first batch here)
accuracy.compute()

tensor(0.6667)

In [4]:
# Second batch of predicted scores and labels
preds2 = torch.tensor([0.4,0.9,0.1])
truth2 = torch.tensor([1,1,1])

# Same metric object as before, so this batch is added to the first one
accuracy.update(preds2, truth2)

# Accuracy now covers both batches (recorded output: 0.5000)
accuracy.compute()

tensor(0.5000)

This is equivalent to computing the metric once over the concatenation of both batches:

In [5]:
# Cross-check: the functional metric on both batches concatenated;
# the recorded output (0.5000) matches the accumulated metric's result
binary_accuracy(input=torch.cat([preds1, preds2]), 
                target=torch.cat([truth1, truth2]))

tensor(0.5000)

## Example usage:

Here we show a minimal example of how to integrate torcheval into a training loop.

In [6]:
from types import SimpleNamespace

import wandb

from fastprogress import progress_bar

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

import torchvision.transforms as T
from torchvision.datasets import FashionMNIST

import timm

from torcheval.metrics import MulticlassAccuracy, Mean

In [7]:
# W&B project name passed to wandb.init when the run is started
WANDB_PROJECT = "torcheval"
# W&B entity passed to wandb.init; None presumably selects the logged-in default — TODO confirm
WANDB_ENTITY = None

In [8]:
# All tunable settings for the experiment, gathered in one namespace
# (the same object is later handed to wandb.init as the run config)
config = SimpleNamespace(
    data_path = ".",  # root directory given to the FashionMNIST datasets
    model_name = "resnet10t",  # model name given to timm.create_model
    lr = 1e-3,  # used as max_lr of the OneCycleLR scheduler
    wd = 0.0,  # weight_decay of the AdamW optimizer
    bs=512,  # training batch size; the validation loader uses twice this
    epochs = 20,  # number of passes over the training set
    num_workers=8,  # worker count for both DataLoaders
    device="cuda" if torch.cuda.is_available() else "cpu",  # use the GPU when one is available
)



# Training-time augmentation: random crop (size 28, padding 1) and random
# horizontal flip, followed by conversion to a tensor
train_tfms = T.Compose([
    T.RandomCrop(28, padding=1), 
    T.RandomHorizontalFlip(),
    T.ToTensor(),
])

# Validation images are only converted to tensors, with no augmentation
val_tfms = T.Compose([
    T.ToTensor(),
])

# Store the transforms on the config, keyed by split name
config.tfms = {"train": train_tfms, "valid":val_tfms}

In [9]:
# Training split with the augmenting transforms; download=True presumably
# fetches the data into config.data_path when it is missing — TODO confirm
train_ds = FashionMNIST(config.data_path, download=True, transform=config.tfms["train"])
# train=False selects the other split, used here as the validation set
valid_ds = FashionMNIST(config.data_path, download=True, train=False, transform=config.tfms["valid"])

In [10]:
def dataloaders(train_ds, valid_ds, bs=128, num_workers=8):
    """Build the training and validation `DataLoader`s.

    Args:
        train_ds: dataset served in shuffled batches of size `bs`.
        valid_ds: dataset served in order, in batches of size `bs*2`.
        bs: training batch size.
        num_workers: number of loader worker processes, used by both loaders.

    Returns:
        The tuple `(train_dataloader, valid_dataloader)`.
    """
    # Options shared by both loaders. The validation loader previously skipped
    # pin_memory; it is now configured the same way as the training loader.
    shared = dict(pin_memory=True, num_workers=num_workers)
    train_dataloader = DataLoader(train_ds, batch_size=bs, shuffle=True, **shared)
    valid_dataloader = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, **shared)
    return train_dataloader, valid_dataloader

In [11]:
# Instantiate both loaders with the batch size and worker count from config
train_dl, valid_dl = dataloaders(train_ds, valid_ds, bs=config.bs, num_workers=config.num_workers)

In [12]:
# Build the network named in config with timm (pretrained=False, 10 output
# classes, 1 input channel) and move it to the configured device
model = timm.create_model(config.model_name, pretrained=False, num_classes=10, in_chans=1).to(config.device)

# No lr is passed to AdamW here; config.lr is given to the scheduler as max_lr instead
optimizer = AdamW(model.parameters(), weight_decay=config.wd)
loss_func = nn.CrossEntropyLoss()
# The scheduler is stepped once per training batch (see train_step),
# hence total_steps = epochs * number of batches per epoch
scheduler = OneCycleLR(optimizer, max_lr=config.lr, total_steps=config.epochs*len(train_dl))

In [13]:
# Accuracy metrics, one for the training split and one for the validation split
train_acc = MulticlassAccuracy(device=config.device)
valid_acc = MulticlassAccuracy(device=config.device)

# Another useful trick is to keep track of the loss as a metric too!
# The training/validation loops update these with weight=batch size
train_loss = Mean(device=config.device)
valid_loss = Mean(device=config.device)

In [14]:
def reset_metrics():
    "Clear the accumulated state of every training and validation metric"
    # same order as before: accuracies first, then losses
    for metric in (train_acc, valid_acc, train_loss, valid_loss):
        metric.reset()

def train_step(loss):
    "Backpropagate `loss`, then advance the optimizer and the LR scheduler by one step"
    optimizer.zero_grad()  # clear the gradients before the new backward pass
    loss.backward()
    optimizer.step()
    scheduler.step()  # the scheduler is stepped on every call, i.e. once per batch
    return loss  # the same tensor that was passed in

In [15]:
def to_device(t, device):
    """Put tensors on device.

    Args:
        t: a tensor (anything with a `.to(device)` method), or a tuple/list of
            such objects. Tuples/lists may be nested to any depth.
        device: target device, forwarded unchanged to `.to`.

    Returns:
        The moved tensor, or a list mirroring the structure of `t`
        (tuples come back as lists, as before).
    """
    if isinstance(t, (tuple, list)):
        # recurse so that nested containers work instead of failing on `list.to`
        return [to_device(_t, device) for _t in t]
    return t.to(device)

In [16]:
def train_one_epoch():
    """Run one pass over `train_dl`, training the model and updating the metrics.

    For each batch: forward pass, loss, one optimisation step via `train_step`,
    then `train_loss`/`train_acc` are updated and the batch loss and current
    learning rate are logged to W&B. The accumulated training accuracy is
    logged once after the last batch.
    """
    model.train()
    pbar = progress_bar(train_dl, leave=False)
    for b in pbar:
        images, labels = to_device(b, config.device)
        preds_b = model(images)
        loss = loss_func(preds_b, labels)
        train_step(loss)

        # read the Python scalar once and reuse it for logging and the progress bar
        batch_loss = loss.detach().item()

        # update metrics
        train_loss.update(loss.detach(), weight=len(images))
        train_acc.update(preds_b, labels)

        # log to W&B
        wandb.log({"train_loss": batch_loss,  # log at current batch
                   "learning_rate": scheduler.get_last_lr()[0]})
        pbar.comment = f"train_loss={batch_loss:2.3f}, train_acc={train_acc.compute():2.3f}"

    # log training accuracy at the end of epoch
    wandb.log({"train_acc": train_acc.compute()})

In [17]:
def validate():
    """Evaluate the model on `valid_dl` and log the validation metrics to W&B.

    Each batch is run through the model in eval mode; `valid_loss` and
    `valid_acc` accumulate over the whole loader and their values are logged
    once, after the last batch.
    """
    model.eval()
    pbar = progress_bar(valid_dl, leave=False)
    for b in pbar:
        images, labels = to_device(b, config.device)
        # forward pass and loss both run under inference_mode, so nothing here
        # is tracked by autograd and no detach() is needed afterwards
        with torch.inference_mode():
            preds_b = model(images)
            batch_loss = loss_func(preds_b, labels)

        # update metrics
        valid_loss.update(batch_loss, weight=len(images))
        valid_acc.update(preds_b, labels)

    # log to W&B, we log at the end of the validation
    wandb.log({"valid_loss": valid_loss.compute(),
               "valid_acc": valid_acc.compute()})

In [18]:
def print_metrics(epoch):
    "Print one summary line with the current train/valid loss and accuracy for `epoch`"
    # build the two halves separately, then join them around the divider
    train_part = (f"train_loss: {train_loss.compute():10.3f}, "
                  f"train_acc: {train_acc.compute():3.3f}")
    valid_part = (f"val_loss: {valid_loss.compute():10.3f}, "
                  f"val_acc: {valid_acc.compute():3.3f}")
    print(f"epoch: {epoch:3}, " + train_part + "  ||   " + valid_part)

In [19]:
def fit():
    """Train for `config.epochs` epochs, validating and printing metrics after each.

    Per epoch: train on `train_dl`, log the epoch index to W&B, validate on
    `valid_dl`, print the summary line, then reset all metrics.
    """
    for epoch in progress_bar(range(config.epochs), total=config.epochs, leave=True):
        try:
            train_one_epoch()

            wandb.log({"epoch":epoch})

            # validation
            validate()

            print_metrics(epoch)
        finally:
            # we set metrics to zero for the new epoch; doing it in `finally`
            # also clears them when an epoch is interrupted or fails, so a
            # later fit() call does not start from half-accumulated values
            reset_metrics()

In [20]:
# Start a W&B run with the experiment config and train inside it;
# the run is used as a context manager, so it is closed when the block exits
with wandb.init(project=WANDB_PROJECT, entity=WANDB_ENTITY, config=config):
    fit()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcapecape[0m. Use [1m`wandb login --relogin`[0m to force relogin


epoch:   0, train_loss:      1.648, train_acc: 0.479  ||   val_loss:      1.024, val_acc: 0.690


epoch:   1, train_loss:      0.744, train_acc: 0.753  ||   val_loss:      0.517, val_acc: 0.815


epoch:   2, train_loss:      0.434, train_acc: 0.843  ||   val_loss:      0.391, val_acc: 0.856


epoch:   3, train_loss:      0.351, train_acc: 0.872  ||   val_loss:      0.356, val_acc: 0.872


epoch:   4, train_loss:      0.311, train_acc: 0.884  ||   val_loss:      0.360, val_acc: 0.865


epoch:   5, train_loss:      0.285, train_acc: 0.895  ||   val_loss:      0.342, val_acc: 0.876


epoch:   6, train_loss:      0.268, train_acc: 0.900  ||   val_loss:      0.299, val_acc: 0.892


epoch:   7, train_loss:      0.248, train_acc: 0.909  ||   val_loss:      0.301, val_acc: 0.888


epoch:   8, train_loss:      0.232, train_acc: 0.914  ||   val_loss:      0.262, val_acc: 0.903


epoch:   9, train_loss:      0.216, train_acc: 0.919  ||   val_loss:      0.270, val_acc: 0.899


epoch:  10, train_loss:      0.208, train_acc: 0.922  ||   val_loss:      0.255, val_acc: 0.909


epoch:  11, train_loss:      0.196, train_acc: 0.927  ||   val_loss:      0.239, val_acc: 0.911


epoch:  12, train_loss:      0.186, train_acc: 0.931  ||   val_loss:      0.234, val_acc: 0.919


epoch:  13, train_loss:      0.175, train_acc: 0.935  ||   val_loss:      0.227, val_acc: 0.919


epoch:  14, train_loss:      0.165, train_acc: 0.938  ||   val_loss:      0.217, val_acc: 0.919


epoch:  15, train_loss:      0.152, train_acc: 0.944  ||   val_loss:      0.210, val_acc: 0.925


epoch:  16, train_loss:      0.146, train_acc: 0.946  ||   val_loss:      0.205, val_acc: 0.929


epoch:  17, train_loss:      0.140, train_acc: 0.948  ||   val_loss:      0.202, val_acc: 0.930


epoch:  18, train_loss:      0.136, train_acc: 0.951  ||   val_loss:      0.201, val_acc: 0.930


epoch:  19, train_loss:      0.134, train_acc: 0.951  ||   val_loss:      0.201, val_acc: 0.930


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
learning_rate,▁▁▂▂▃▄▅▆▆▇███████▇▇▇▇▆▆▆▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
train_acc,▁▅▆▇▇▇▇▇▇███████████
train_loss,█▆▄▃▃▂▂▂▂▂▁▁▂▂▂▂▁▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_acc,▁▅▆▆▆▆▇▇▇▇▇▇████████
valid_loss,█▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
epoch,19.0
learning_rate,0.0
train_acc,0.95105
train_loss,0.15139
valid_acc,0.93
valid_loss,0.20147
