In [13]:
import torch
from torch import nn
from torch.nn import functional as F
from time import perf_counter
import einops

from typing import Literal
from copy import deepcopy

In [52]:
class L2CELoss(nn.Module):
    """Cross-entropy on the arg-max class, rescaled by a ratio of two MSE terms.

    The scalar returned by ``forward`` is::

        CE(output, argmax(labels)) * MSE(output, labels) / MSE(output, onehot(argmax(labels)))

    When ``labels`` is already one-hot the two MSE terms are the same quantity,
    so the loss reduces to plain cross-entropy.
    """

    def __init__(self):
        super().__init__()
        # The targets fed to ``ce`` come from ``argmax`` and are therefore never -1;
        # ``ignore_index`` is kept only so the module's configuration is unchanged.
        self.ce = nn.CrossEntropyLoss(reduction="mean", ignore_index=-1)
        self.l2 = nn.MSELoss(reduction="mean")

    def forward(self, output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the rescaled loss.

        Args:
            output: logits whose last axis is the class axis; any number of
                leading axes is accepted.
            labels: dense targets, same shape as ``output``; the arg-max over the
                last axis is used as the hard class for the cross-entropy term.

        Returns:
            A 0-dim tensor holding the mean-reduced loss.

        Note:
            The denominator is zero when ``output`` exactly equals the one-hot
            encoding of the arg-max class; no epsilon is added here.
        """
        # Reduce over the class axis only. A bare ``.squeeze()`` would also drop a
        # batch axis of size 1 and break the cross-entropy call.
        class_idx = labels.argmax(dim=-1)
        labels_onehot = F.one_hot(class_idx, num_classes=labels.shape[-1]).float()

        # CrossEntropyLoss expects classes on axis 1, so fold every leading axis
        # into a single batch axis for this term. The MSE terms are mean-reduced
        # over all elements and do not care about the layout.
        flat_logits = output.reshape(-1, output.shape[-1])
        ce_term = self.ce(flat_logits, class_idx.reshape(-1))

        # TODO: reduction="none", then manually reduce?
        loss = ce_term * self.l2(output, labels) / self.l2(output, labels_onehot)
        return loss

In [58]:
def test_l2ce_loss():
    """Sanity checks for ``L2CELoss`` on random logits and one-hot targets.

    Two properties are asserted:

    1. Interpolating the target from the logits themselves (weight 0) towards a
       one-hot target never decreases the loss, and the loss is exactly 0 when
       the target equals the logits.
    2. With a pure one-hot target the loss matches ``nn.CrossEntropyLoss``.
    """
    b, s, d = 16, 16, 512
    num_iter = 50
    loss_fn = L2CELoss()

    # Draw from a private, seeded generator so a failing assertion can be
    # reproduced without disturbing the global RNG state.
    rng = torch.Generator().manual_seed(0)
    x = torch.randn(b, s, d, generator=rng)
    y = F.one_hot(torch.randint(0, d, (b, s), generator=rng), num_classes=d).float()

    x_flat = x.flatten(0, 1)
    losses = []
    for step in range(num_iter):
        # Interpolation weight in [0, 1): 0 -> target is x, towards 1 -> target is y.
        alpha = step / num_iter
        target = torch.lerp(x, y, alpha)
        losses.append(loss_fn(x_flat, target.flatten(0, 1)).item())
    print(x_flat.shape)
    assert losses == sorted(losses)
    assert losses[0] == 0.0

    # Check that this is equivalent to the CE loss if we give a one-hot target
    ce = nn.CrossEntropyLoss(reduction="mean", ignore_index=-1)

    # argmax over the last axis of the 2-D view is already a 1-D index tensor,
    # so no squeeze is needed.
    class_idx = y.view(-1, d).argmax(dim=-1)
    ce_loss = ce(x.view(-1, d), class_idx)
    l2ce_loss = loss_fn(x.view(-1, d), y.view(-1, d))
    assert torch.allclose(l2ce_loss, ce_loss)


test_l2ce_loss()


torch.Size([256, 512])


In [124]:
# Timing sweep: for growing feature width d, run one Adam step of a d x d linear
# layer on the combined loss  CE * MSE(noisy one-hot) / MSE(one-hot)  and print
# the loss together with the wall-clock time of the whole iteration (data
# creation, model construction, forward, backward and optimizer step).
b, s, d_init = 16, 16, 256
l2 = nn.MSELoss()
ce = nn.CrossEntropyLoss()

for d_prod in range(1, 10):
    d = d_prod * d_init

    t0 = perf_counter()
    x = torch.randn(b, s, d)
    target = torch.randint(0, d, (b, s))
    target_onehot = F.one_hot(target, num_classes=d).float()
    noisy_target_onehot = target_onehot + torch.randn_like(target_onehot) * 0.1

    model = nn.Linear(d, d)
    optimizer = torch.optim.Adam(model.parameters())
    optimizer.zero_grad()

    # One forward pass shared by all three loss terms; re-running model(x) per
    # term would repeat the same matmul inside the timed region.
    logits = model(x)
    loss_ce = ce(logits.view(-1, d), target.view(-1))
    loss_l2 = l2(logits, target_onehot)
    loss_l2_noisy = l2(logits, noisy_target_onehot)
    loss = loss_ce * loss_l2_noisy / loss_l2
    loss.backward()
    optimizer.step()
    print(f"{d=}, {loss=}, {perf_counter() - t0:.3f}s")


d=256, loss=tensor(5.8629, grad_fn=<DivBackward0>), 0.060s
d=512, loss=tensor(6.6182, grad_fn=<DivBackward0>), 0.094s
d=768, loss=tensor(6.9890, grad_fn=<DivBackward0>), 0.097s
d=1024, loss=tensor(7.3310, grad_fn=<DivBackward0>), 0.151s
d=1280, loss=tensor(7.5418, grad_fn=<DivBackward0>), 0.139s
d=1536, loss=tensor(7.7343, grad_fn=<DivBackward0>), 0.208s
d=1792, loss=tensor(7.8542, grad_fn=<DivBackward0>), 0.259s
d=2048, loss=tensor(8.0103, grad_fn=<DivBackward0>), 0.638s
d=2304, loss=tensor(8.1578, grad_fn=<DivBackward0>), 0.433s
