In [None]:
import torch
from torch import Tensor, jit

# Prefer the GPU, but fall back to CPU so the notebook still runs (e.g. on a laptop
# or CI box) instead of crashing at the first tensor allocation.
# NOTE: timings below are only comparable between runs on the same selected device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Profiling implementations of the pinball (quantile) loss `P`

In [None]:
@jit.script
def P1(x: Tensor, xhat: Tensor, q: float = 0.5) -> Tensor:
    """Pinball (quantile) loss via an explicit branch on the sign of the error.

    Elementwise: ``q * |x - xhat|`` where ``x > xhat`` and
    ``(1 - q) * |x - xhat|`` everywhere else (including ties).
    """
    magnitude = (x - xhat).abs()
    # The target lies above the prediction, i.e. the model under-predicted.
    under_predicted = x > xhat
    return torch.where(under_predicted, q * magnitude, (1 - q) * magnitude)


@jit.script
def P2(x: Tensor, xhat: Tensor, q: float = 0.5) -> Tensor:
    """Pinball (quantile) loss as the elementwise max of two linear branches.

    With ``u = x - xhat`` this returns ``max((q - 1) * u, q * u)``. For ``q`` in
    [0, 1], the ``q * u`` branch wins when ``u > 0`` and the other when ``u < 0``.
    """
    residual = x - xhat
    below = (q - 1) * residual
    above = q * residual
    return torch.max(below, above)


@jit.script
def P3(x: Tensor, xhat: Tensor, q: float = 0.5) -> Tensor:
    """Pinball (quantile) loss in its textbook form ``u * (q - 1[u < 0])``.

    With ``u = x - xhat`` this equals ``q * u`` for ``u >= 0`` and ``(q - 1) * u``
    for ``u < 0``. These are the same values the ``where``/``max`` formulations give.
    It multiplies by an indicator weight instead, so it is a genuinely distinct
    variant to profile rather than a copy of the ``max``-based version.
    """
    errors = x - xhat
    # 1.0 where the prediction overshoots the target (u < 0), else 0.0.
    # Built from a comparison, so it carries no gradient.
    overshoot = (errors < 0).to(errors.dtype)
    return errors * (q - overshoot)

In [None]:
import tsdm

In [None]:
# Sanity-check tsdm's reference ND metric. `x`/`xhat` are not defined above this cell
# on a fresh kernel, so build the inputs here instead of relying on leftover state.
# NOTE(review): assumes ND accepts two same-shaped tensors — confirm against tsdm docs.
x_nd, xhat_nd = torch.randn(2, 20, 30, device=DEVICE)
tsdm.metrics.ND()(x_nd, xhat_nd)

In [None]:
%%timeit

x, xhat = torch.randn(2, 1_000_000, device=DEVICE)
P1(x, xhat)

In [None]:
%%timeit
x = torch.nn.Parameter(torch.randn(1_000_000, device=DEVICE))
xhat = torch.nn.Parameter(torch.randn(1_000_000, device=DEVICE))
result = torch.mean(P1(x, xhat))
result.backward()

In [None]:
%%timeit
x = torch.nn.Parameter(torch.randn(1_000_000, device=DEVICE))
xhat = torch.nn.Parameter(torch.randn(1_000_000, device=DEVICE))
result = torch.mean(P2(x, xhat))
result.backward()

In [None]:
def QL(x, xhat, p=0.5):
    """Normalized quantile loss: ``2 * sum(P(x, xhat, p)) / sum(|x|)``.

    For ``p = 0.5`` the pinball loss is ``0.5 * |x - xhat|``, so this reduces to
    ``sum(|x - xhat|) / sum(|x|)``.

    Parameters
    ----------
    x : Tensor
        Targets; also provide the normalizing denominator.
    xhat : Tensor
        Predictions (presumably the same shape as ``x`` — confirm at call sites).
    p : float
        Quantile level forwarded to the pinball loss.

    Returns
    -------
    Tensor
        A 0-d tensor. NOTE(review): this is inf/nan when ``sum(|x|) == 0``.
    """
    # The original called an undefined `P`. P2 is one of the equivalent
    # pinball-loss implementations defined in this notebook. `p` is passed as a
    # float to match the scripted loss's `q: float` signature.
    return 2 * torch.sum(P2(x, xhat, float(p))) / torch.sum(torch.abs(x))

In [None]:
# A small trainable target/prediction pair on DEVICE; x is drawn first, then xhat.
x, xhat = (
    torch.nn.Parameter(torch.randn(20, 30, device=DEVICE)) for _ in range(2)
)