## Functional 

In [1]:
import torch
from torcheval.metrics.functional import binary_accuracy

preds = torch.tensor([0.1,0.7,0.6])
truth = torch.tensor([0,1,0])

binary_accuracy(preds, truth)

tensor(0.6667)

## With Accumulation

In [2]:
from torcheval.metrics import BinaryAccuracy

accuracy = BinaryAccuracy()

preds1 = torch.tensor([0.1,0.7,0.6])
truth1 = torch.tensor([0,1,0])

accuracy.update(preds1, truth1)

accuracy.compute()

tensor(0.6667)

In [3]:
preds2 = torch.tensor([0.4,0.9,0.1])
truth2 = torch.tensor([1,1,1])

accuracy.update(preds2, truth2)

accuracy.compute()

tensor(0.5000)

that would be equivalent to do:

In [5]:
binary_accuracy(input=torch.cat([preds1, preds2]), 
                target=torch.cat([truth1, truth2]))

tensor(0.5000)

## Example usage:

In [66]:
from types import SimpleNamespace

import wandb

from fastprogress import progress_bar

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

import torchvision.transforms as T
from torchvision.datasets import FashionMNIST

import timm

from torcheval.metrics import MulticlassAccuracy, Mean

In [67]:
WANDB_PROJECT = "torcheval"
WANDB_ENTITY = "capecape"

In [68]:
config = SimpleNamespace(
    data_path = ".",
    model_name = "resnet10t",
    lr = 1e-3,
    wd = 0.0,
    bs=512,
    epochs = 20,
    num_workers=8,
    device="cuda" if torch.cuda.is_available() else "cpu",
)



train_tfms = T.Compose([
    T.RandomCrop(28, padding=1), 
    T.RandomHorizontalFlip(),
    T.ToTensor(),
])

val_tfms = T.Compose([
    T.ToTensor(),
])

config.tfms = {"train": train_tfms, "valid":val_tfms}

In [69]:
train_ds = FashionMNIST(config.data_path, download=True, transform=config.tfms["train"])
valid_ds = FashionMNIST(config.data_path, download=True, train=False, transform=config.tfms["valid"])

In [70]:
def dataloaders(train_ds, valid_ds, bs=128, num_workers=8):
    train_dataloader = DataLoader(train_ds, batch_size=bs, shuffle=True, 
                               pin_memory=True, num_workers=num_workers)
    valid_dataloader = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, 
                               num_workers=num_workers)
    return train_dataloader, valid_dataloader

In [71]:
train_dl, valid_dl = dataloaders(train_ds, valid_ds, bs=config.bs, num_workers=config.num_workers)

In [72]:
model = timm.create_model(config.model_name, pretrained=False, num_classes=10, in_chans=1).to(config.device)

optimizer = AdamW(model.parameters(), weight_decay=config.wd)
loss_func = nn.CrossEntropyLoss()
scheduler = OneCycleLR(optimizer, max_lr=config.lr, total_steps=config.epochs*len(train_dl))

In [73]:
train_acc = MulticlassAccuracy(device=config.device)
valid_acc = MulticlassAccuracy(device=config.device)

# another cool trick is keep track of the loss as a metric!
train_loss = Mean(device=config.device)
valid_loss = Mean(device=config.device)

In [74]:
def reset_metrics():
    train_acc.reset()
    valid_acc.reset()
    train_loss.reset()
    valid_loss.reset()

def train_step(loss):
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    return loss

In [75]:
def to_device(t, device):
    "Put tensors on device"
    if isinstance(t, (tuple, list)):
        return [_t.to(device) for _t in t]
    return t.to(device) 

In [76]:
def train_one_epoch():
    model.train()
    pbar = progress_bar(train_dl, leave=False)
    for i, b in enumerate(pbar):
        images, labels = to_device(b, config.device)
        preds_b = model(images)
        loss = loss_func(preds_b, labels)
        train_step(loss)
        
        # update metrics
        train_loss.update(loss.detach(), weight=len(images))
        train_acc.update(preds_b, labels)
        
        # log to W&B
        wandb.log({"train_loss": train_loss.compute(),
                  "train_acc": train_acc.compute(),
                  "learning_rate": scheduler.get_last_lr()[0]})
        pbar.comment = f"train_loss={train_loss.compute():2.3f}, train_acc={train_acc.compute():2.3f}"  

In [77]:
def validate():
    model.eval()
    pbar = progress_bar(valid_dl, leave=False)
    for i, b in enumerate(pbar):
        images, labels = to_device(b, config.device)
        with torch.inference_mode():
            preds_b = model(images)
        
        # update metrics
        valid_loss.update(loss_func(preds_b, labels).detach(), weight=len(images))
        valid_acc.update(preds_b, labels)
        
    # log to W&B, we log at the end of the validation
    wandb.log({"valid_loss": valid_loss.compute(),
              "train_acc": valid_acc.compute()})

In [78]:
def print_metrics(epoch):
    print((f"epoch: {epoch:3}, "
           f"train_loss: {train_loss.compute():10.3f}, "
           f"train_acc: {train_acc.compute():3.3f}"
           "  ||   "
           f"val_loss: {valid_loss.compute():10.3f}, "
           f"val_acc: {valid_acc.compute():3.3f}"
          ))

In [79]:
def fit():         
    for epoch in progress_bar(range(config.epochs), total=config.epochs, leave=True):
        train_one_epoch()

        wandb.log({"epoch":epoch})

        # validation
        validate()
        
        print_metrics(epoch)
        
        # we set metrics to zero for the new epoch
        reset_metrics()

In [80]:
with wandb.init(project=WANDB_PROJECT, entity=WANDB_ENTITY, config=config):
    fit()

epoch:   0, train_loss:      1.629, train_acc: 0.489||   val_loss:      0.976, val_acc: 0.707


epoch:   1, train_loss:      0.714, train_acc: 0.761||   val_loss:      0.512, val_acc: 0.814


epoch:   2, train_loss:      0.423, train_acc: 0.847||   val_loss:      0.383, val_acc: 0.859


epoch:   3, train_loss:      0.341, train_acc: 0.876||   val_loss:      0.396, val_acc: 0.848


epoch:   4, train_loss:      0.304, train_acc: 0.886||   val_loss:      0.432, val_acc: 0.833


epoch:   5, train_loss:      0.280, train_acc: 0.896||   val_loss:      0.297, val_acc: 0.891


epoch:   6, train_loss:      0.261, train_acc: 0.903||   val_loss:      0.349, val_acc: 0.870


epoch:   7, train_loss:      0.243, train_acc: 0.910||   val_loss:      0.267, val_acc: 0.901


epoch:   8, train_loss:      0.229, train_acc: 0.914||   val_loss:      0.299, val_acc: 0.887


epoch:   9, train_loss:      0.216, train_acc: 0.919||   val_loss:      0.276, val_acc: 0.895


epoch:  10, train_loss:      0.203, train_acc: 0.924||   val_loss:      0.256, val_acc: 0.908


epoch:  11, train_loss:      0.194, train_acc: 0.928||   val_loss:      0.245, val_acc: 0.913


epoch:  12, train_loss:      0.181, train_acc: 0.932||   val_loss:      0.247, val_acc: 0.910


epoch:  13, train_loss:      0.172, train_acc: 0.937||   val_loss:      0.215, val_acc: 0.921


epoch:  14, train_loss:      0.162, train_acc: 0.940||   val_loss:      0.215, val_acc: 0.924


epoch:  15, train_loss:      0.151, train_acc: 0.944||   val_loss:      0.207, val_acc: 0.926


epoch:  16, train_loss:      0.142, train_acc: 0.947||   val_loss:      0.206, val_acc: 0.927


epoch:  17, train_loss:      0.136, train_acc: 0.950||   val_loss:      0.201, val_acc: 0.930


epoch:  18, train_loss:      0.130, train_acc: 0.952||   val_loss:      0.202, val_acc: 0.929


epoch:  19, train_loss:      0.130, train_acc: 0.952||   val_loss:      0.201, val_acc: 0.928


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
learning_rate,▁▁▂▂▃▄▅▆▆▇███████▇▇▇▇▆▆▆▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
train_acc,▁▂▅▆▇▇▇▇▇▇▇▇▇███████████████████████████
train_loss,█▇▄▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss,█▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
epoch,19.0
learning_rate,0.0
train_acc,0.928
train_loss,0.12964
valid_loss,0.20133
