# Utils

> Training utilities, such as a metric for accumulating the loss over mini-batches

In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import numpy as np
import torch
import torch.functional as F
import yaml
from fastcore.basics import Path
from torch import nn
from torchmetrics import Metric
from torchvision import datasets

In [None]:
from torchmetrics import Metric

In [None]:
#| export 

class LossAccumulator(Metric):
    """
    A PyTorch metric for accumulating loss values over multiple mini-batches.

    This class inherits from the `Metric` class provided by the `torchmetrics` package.
    It takes a loss function during initialization and uses it to calculate the loss for each batch during the update step.
    The final loss value is the summed loss divided by the total number of target elements,
    so batches of different sizes are weighted by how many elements they contain.

    Args:
        loss_func: Callable invoked as `loss_func(acts, target)` that returns a loss tensor.
            When omitted, a fresh `nn.CrossEntropyLoss()` is created for this instance.
            If the callable exposes a `reduction` attribute (as the `torch.nn` loss modules do)
            it is honoured: "mean" results are re-weighted by the element count, "sum"
            results are added as they are, and unreduced (non-scalar) results are summed.
            A callable without a `reduction` attribute is assumed to return a batch mean.
    """
    def __init__(self, loss_func=None):
        super().__init__()
        # Build the default per instance instead of sharing one module object
        # between every metric created without an explicit loss function.
        self.loss_func = nn.CrossEntropyLoss() if loss_func is None else loss_func
        self.add_state("loss", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, acts, target):
        """
        Add one mini-batch to the running totals.

        Args:
            acts: Model outputs for the batch; the first dimension is the batch dimension.
            target: Targets for the batch; must have the same first dimension as `acts`.

        Raises:
            ValueError: If `acts` and `target` disagree on the size of the first dimension.
        """
        if acts.shape[0] != target.shape[0]:
            raise ValueError(
                "acts and target must have the same batch size, "
                f"got {acts.shape[0]} and {target.shape[0]}"
            )
        acts, target = acts.to(self.device), target.to(self.device)
        batch_loss = self.loss_func(acts, target)
        n_elems = target.numel()
        if batch_loss.dim() > 0:
            # Unreduced loss (e.g. reduction="none"): collapse it to the batch sum.
            batch_loss = batch_loss.sum()
        elif getattr(self.loss_func, "reduction", "mean") == "mean":
            # A mean-reduced loss has already been divided by the element count.
            # Undo that here, otherwise compute() would divide by it a second time.
            # NOTE(review): assumes the mean is taken over target.numel() elements,
            # which holds for class-index targets - confirm for other target layouts.
            batch_loss = batch_loss * n_elems
        self.loss += batch_loss
        self.total += n_elems

    def compute(self):
        """Return the accumulated loss divided by the accumulated number of target elements."""
        return self.loss / self.total

In [None]:
# Metric instance used by the demo cells below; with no argument it uses the
# default loss function, nn.CrossEntropyLoss.
lossaccum = LossAccumulator()

In [None]:
# Push ten random 5x2 mini-batches through the metric, collecting the value
# returned for each batch so that their mean can be printed at the end.
tmp = []

for i in range(10):
    # Rows of `acts` are softmax-normalised random scores for two classes.
    acts = torch.softmax(torch.rand(5, 2), dim=-1)
    # Integer labels: 1 where the second column exceeds 0.5, otherwise 0.
    target = (acts[:, 1] > 0.5).long()
    loss = lossaccum(acts, target)
    print(loss)
    tmp.append(loss.item())

print(f"avg loss: {np.mean(tmp)}")

tensor(0.1240)
tensor(0.1212)
tensor(0.1225)
tensor(0.1247)
tensor(0.1132)
tensor(0.1177)
tensor(0.1301)
tensor(0.1309)
tensor(0.1192)
tensor(0.1303)
avg loss: 0.12338263019919396


In [None]:
# Value over everything accumulated so far: summed loss / summed element count.
# NOTE(review): compute() as defined in this file does not clear the accumulated
# state; presumably lossaccum.reset() should follow at the end of each epoch - confirm.
lossaccum.compute() # remember to reset the metric after the end of each epoch

tensor(0.1234)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()