In [207]:
import torch
from torch import nn
from torch.nn import functional as F

from typing import Literal

In [208]:
class CrossEntropyLoss(nn.Module):
    """Cross-entropy between raw logits and integer class targets.

    The class dimension is the LAST dimension of ``outputs``; ``targets``
    holds one class index per remaining position.

    Args:
        reduction: ``'mean'`` averages over the non-ignored positions,
            ``'sum'`` adds them up, ``'none'`` returns the per-position loss.
        ignore_index: target value marking positions that contribute nothing
            to the loss. It does not have to be a valid class index.
        num_classes: expected size of the last dimension of ``outputs``.
    """

    def __init__(self, reduction: Literal['mean', 'sum', 'none'] = 'mean', ignore_index: int = -1, num_classes: int = 256):
        super().__init__()
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.num_classes = num_classes

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor):
        """Compute the loss.

        Args:
            outputs: logits of shape ``(..., num_classes)``.
            targets: integer class indices of shape ``outputs.shape[:-1]``.

        Returns:
            A scalar for ``'mean'`` / ``'sum'``. For ``'none'``, a tensor shaped
            like ``targets`` with ignored positions set to 0. With ``'mean'``,
            the result is NaN when every position is ignored (mean of an
            empty selection).

        Raises:
            ValueError: if the shapes of ``outputs`` and ``targets`` do not
                match the contract above.
        """
        if outputs.size(-1) != self.num_classes:
            raise ValueError(
                f"expected outputs with {self.num_classes} classes in the last dimension, got {outputs.size(-1)}"
            )
        if targets.shape != outputs.shape[:-1]:
            raise ValueError(
                f"targets shape {tuple(targets.shape)} does not match outputs shape {tuple(outputs.shape)} without its last dimension"
            )

        # Build the mask BEFORE sanitising the targets; once the ignored
        # entries are overwritten they can no longer be told apart.
        keep = targets != self.ignore_index

        # ignore_index may be out of range (e.g. -1), so point ignored entries
        # at class 0 just to make the lookup valid; they are masked out below.
        safe_targets = targets.masked_fill(~keep, 0).long()
        target_logit = outputs.gather(-1, safe_targets.unsqueeze(-1)).squeeze(-1)

        # -log softmax(x)[t] == logsumexp(x) - x[t]; logsumexp stays finite for
        # large logits where exp() followed by log() would overflow.
        loss = torch.logsumexp(outputs, dim=-1) - target_logit

        if self.reduction == 'mean':
            return loss[keep].mean()
        if self.reduction == 'sum':
            return loss[keep].sum()
        # 'none': keep the per-position layout and zero the ignored entries.
        return loss.masked_fill(~keep, 0.0)

In [209]:
def test_ce_loss():
    """Compare the custom loss with ``nn.CrossEntropyLoss`` on random data.

    Runs ten rounds of uniformly random logits and random integer labels,
    flattened to ``(batch * seq, classes)`` / ``(batch * seq,)``, with class 1
    as the ignored index and ``'mean'`` reduction. The two results must agree
    within ``rtol=1e-3`` / ``atol=1e-3``.
    """
    batch, steps, vocab = 16, 16, 256

    custom_loss = CrossEntropyLoss(ignore_index=1, reduction='mean')
    reference_loss = nn.CrossEntropyLoss(ignore_index=1, reduction='mean')

    for _ in range(10):
        # Draw logits first, then labels, and flatten both once.
        flat_logits = torch.rand(batch, steps, vocab).view(-1, vocab)
        flat_labels = torch.randint(0, vocab, (batch, steps)).view(-1)

        got = custom_loss(flat_logits, flat_labels)
        expected = reference_loss(flat_logits, flat_labels)

        assert torch.allclose(got, expected, rtol=1e-3, atol=1e-3)

In [210]:
# Run the comparison test; an AssertionError here means the custom loss
# disagreed with nn.CrossEntropyLoss beyond the test's tolerance.
test_ce_loss()

In [211]:
def test_ce_loss_multiclass():
    """Draft check of ``CrossEntropyLoss`` with float (soft) targets.

    Interpolates targets between a mask derived from the argmax of random
    logits and the logits themselves, collects the loss at each step, and
    asserts the losses are non-increasing, start above zero and end at
    exactly zero.

    NOTE(review): this function is not called in the visible cells and, as
    written, does not look runnable — see the inline notes before relying
    on it.
    """
    output = torch.randn(16, 16, 256)
    onehot = torch.zeros_like(output)
    # NOTE(review): indexing the last axis with the (16, 16) argmax tensor
    # writes those column indices for EVERY (batch, position) pair, not one
    # column per position; a per-position one-hot would need something like
    # scatter_ — confirm intent.
    onehot[:, :, output.argmax(-1)] = 1
    # Default settings: reduction='mean', ignore_index=-1, num_classes=256.
    loss_fn = CrossEntropyLoss()
    
    losses = []
    # NOTE(review): range() only accepts integers, so the 1.1 stop / 0.1 step
    # raises TypeError before the loop body ever runs.
    for alpha in range(0, 1.1, 0.1):
        # lerp weight 0 yields `onehot`, weight 1 yields the raw logits.
        target = torch.lerp(onehot, output, alpha)
        # NOTE(review): CrossEntropyLoss.forward treats targets as integer
        # class indices, so a float tensor shaped like `output` is presumably
        # rejected — confirm whether soft-target support is intended.
        loss = loss_fn(output, target)
        losses.append(loss.item())

    # Losses are expected to be sorted from largest to smallest.
    assert losses == sorted(losses, reverse=True)
    # NOTE(review): exact float equality with 0; it is unclear why the loss
    # against the logits themselves should be exactly zero — confirm.
    assert losses[-1] == 0
    assert losses[0] > 0
