In [62]:
import torch
import torch.nn.functional as F

# t = torch.Tensor([[1.9392, -1.9266, 0.9664], 
#             [0.0000, -1.9266, 0.9664], 
#             [0.0000, -0.0000, 0.9664]]) 

def covariance_formula(ret):
    """Return ``0.5 * (tr(S) - log det(S) - 3)`` for ``S = A @ A.T``.

    ``ret`` holds the element-wise logarithms of the six lower-triangular
    entries of the 3x3 factor ``A``: the first three values are the diagonal
    entries (0,0), (1,1), (2,2) and the last three are the off-diagonal
    entries (1,0), (2,0), (2,1).

    The loss is evaluated in closed form directly from ``ret``:
    ``tr(A @ A.T)`` is the sum of the squared entries of ``A`` (i.e.
    ``sum(exp(2 * ret))``), and ``log det(A @ A.T)`` is twice the sum of the
    log-diagonal (i.e. ``2 * sum(ret[:3])``). The same quantity is then
    recomputed explicitly from ``S`` as a consistency check.

    NOTE(review): the expression has the form of the KL divergence between a
    zero-mean 3-D Gaussian with covariance ``S`` and a standard normal; that
    interpretation is an assumption, confirm against the caller.

    Args:
        ret: 1-D tensor with 6 elements.

    Returns:
        A 0-dim tensor containing the loss, in the floating-point dtype and
        on the device of ``ret.exp()``.

    Raises:
        AssertionError: if the closed-form loss and the value recomputed
            from ``S`` are not close.
    """
    entries = ret.exp()

    # Allocate the scratch tensors to match the exponentiated input so that
    # non-default dtypes (e.g. float64) and non-CPU devices work; a plain
    # torch.zeros((3, 3)) is always default-dtype on the CPU.
    idx = torch.tensor(
        [[0, 1, 2, 1, 2, 2], [0, 1, 2, 0, 0, 1]],
        dtype=torch.int64,
        device=entries.device,
    )
    A = torch.zeros((3, 3), dtype=entries.dtype, device=entries.device)
    A[idx[0], idx[1]] = entries
    S = A @ A.T

    loss = 0.5 * ((2 * ret).exp().sum() - 2 * ret[:3].sum() - 3)
    # Cross-check the closed form against the explicit trace / determinant.
    test_loss = 0.5 * (torch.trace(S) - torch.log(torch.linalg.det(S)) - 3)
    assert torch.isclose(loss, test_loss)
    return loss

def naive_batch_covariance_formula(ret):
    """Apply ``covariance_formula`` to each row of ``ret``, one at a time.

    This is the loop-based reference implementation: every element along the
    first dimension of ``ret`` is passed to ``covariance_formula`` and the
    resulting scalar losses are stacked.

    Args:
        ret: batch of inputs; iterating over it must yield one
            ``covariance_formula`` argument per sample.

    Returns:
        A 1-D tensor with one loss per sample. For an empty batch an empty
        tensor of shape ``(0,)`` is returned instead of raising.
    """
    losses = [covariance_formula(row) for row in ret]
    if not losses:
        # torch.stack rejects an empty list; return an empty result so an
        # empty batch behaves like any other batch size.
        if isinstance(ret, torch.Tensor):
            return ret.new_zeros((0,))
        return torch.zeros((0,))
    return torch.stack(losses, dim=0)

def batch_covariance_formula(input):
    """Batched ``0.5 * (tr(S) - log det(S) - 3)`` with ``S = A @ A^T``.

    Each row of ``input`` holds the element-wise logarithms of the six
    lower-triangular entries of a 3x3 factor ``A``: columns 0-2 are the
    diagonal entries (0,0), (1,1), (2,2) and columns 3-5 are the off-diagonal
    entries (1,0), (2,0), (2,1).

    The loss is computed in closed form per row (``sum(exp(2 * row))`` for
    the trace and ``2 * sum(row[:3])`` for the log-determinant) and then
    cross-checked against the value recomputed explicitly from ``S``.

    Args:
        input: 2-D tensor of shape ``(batch, 6)``.

    Returns:
        A 1-D tensor of shape ``(batch,)`` with one loss per row, in the
        floating-point dtype and on the device of ``input.exp()``.

    Raises:
        AssertionError: if the closed-form losses and the values recomputed
            from ``S`` are not all close.
    """
    entries = input.exp()

    # Match the scratch tensors to the exponentiated input so that
    # non-default dtypes (e.g. float64) and non-CPU devices work; a plain
    # torch.zeros(...) is always default-dtype on the CPU.
    idx = torch.tensor(
        [[0, 1, 2, 1, 2, 2], [0, 1, 2, 0, 0, 1]],
        dtype=torch.int64,
        device=entries.device,
    )

    # TODO: add requires grad.
    A = torch.zeros(
        (input.shape[0], 3, 3), dtype=entries.dtype, device=entries.device
    )
    A[:, idx[0], idx[1]] = entries
    S = A @ A.transpose(1, 2)

    loss = 0.5 * ((2 * input).exp().sum(dim=1) - 2 * input[:, :3].sum(dim=1) - 3)
    # Batched trace: sum of the main diagonal of each 3x3 matrix.
    trace = S.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)
    test_loss = 0.5 * (trace - torch.log(torch.linalg.det(S)) - 3)
    assert torch.allclose(loss, test_loss)
    return loss

# Sanity check: on a random batch, the vectorised implementation must
# reproduce the per-sample reference loop.
batch_size = 32
input = torch.rand(batch_size, 6)  # one row of 6 log-entries per sample

naive_output = naive_batch_covariance_formula(input)  # reference, row by row
output = batch_covariance_formula(input)  # vectorised over the batch
assert torch.allclose(output, naive_output)

# Trailing expression: shown as the notebook cell's result.
naive_output, output

(tensor([ 9.6399,  6.5368,  7.5383,  8.0682,  7.2965,  7.9007,  9.1775,  3.8679,
          6.7374,  7.2718,  7.0281,  5.5908,  4.4065,  7.1589,  8.0531, 10.3503,
          8.8353,  9.0627,  5.4845,  6.3137,  5.6081,  9.2538,  7.7487,  6.9330,
          9.3055,  4.4201,  5.3666,  3.2983,  5.7141,  6.3828,  6.1005,  7.9595]),
 tensor([ 9.6399,  6.5368,  7.5383,  8.0682,  7.2965,  7.9007,  9.1775,  3.8679,
          6.7374,  7.2718,  7.0281,  5.5908,  4.4065,  7.1589,  8.0531, 10.3503,
          8.8353,  9.0627,  5.4845,  6.3137,  5.6081,  9.2538,  7.7487,  6.9330,
          9.3055,  4.4201,  5.3666,  3.2983,  5.7141,  6.3828,  6.1005,  7.9595]))