In [2]:
import torch
import numpy as np

In [18]:
def compute_gaussian_values(
    means2D: torch.Tensor,    # (N, 2)
    covs2D: torch.Tensor,     # (N, 2, 2)
    pixels: torch.Tensor      # (H, W, 2)
) -> torch.Tensor:           # (N, H, W)
    N = means2D.shape[0]
    H, W = pixels.shape[:2]
    
    # Compute offset from mean (N, H, W, 2)
    dx = pixels.unsqueeze(0) - means2D.reshape(N, 1, 1, 2)  # (N, H, W, 2)
    
    # Add small epsilon to diagonal for numerical stability
    eps = 1e-4
    covs2D = covs2D + eps * torch.eye(2, device=covs2D.device).unsqueeze(0)
    
    # Compute determinant for normalization
    det = torch.det(covs2D)  # (N,)
    inv_covs2D = torch.inverse(covs2D)  # (N, 2, 2)
    
    # Reshape dx for einsum: (N, H, W, 2) -> (N, H*W, 2)
    dx_reshaped = dx.reshape(N, H * W, 2)  # (N, H*W, 2)
    
    # Compute exponent term using einsum
    # Equation: 'nij,njk,nlk->nil'
    # n: batch dimension (N)
    # i: H*W dimension (flattened spatial dimensions)
    # j, k: 2D covariance dimensions (2)
    # l: H*W dimension (flattened spatial dimensions)
    exponent = -0.5 * torch.einsum('nijk,nkl,nijl->nij', dx, inv_covs2D, dx)
    
    # Reshape exponent back to (N, H, W)
    exponent = exponent.reshape(N, H, W)  # (N, H, W)
    
    # Compute gaussian values
    gaussian = torch.exp(exponent) / (2 * np.pi * torch.sqrt(det)).unsqueeze(-1).unsqueeze(-1)  # (N, H, W)
    
    return gaussian

In [19]:
# N=13, H=7, W=11
means2D = torch.rand([13,2])
covs2D = torch.rand([13, 2,2])
pixels = torch.rand([7,11,2])

In [7]:
N = means2D.shape[0]
H, W = pixels.shape[:2]
print(N,H,W)

13 7 11


In [8]:
# Compute offset from mean (N, H, W, 2)
dx = pixels.unsqueeze(0) - means2D.reshape(N, 1, 1, 2)  # (N, H, W, 2)
print(dx.shape)

torch.Size([13, 7, 11, 2])


In [9]:
# Add small epsilon to diagonal for numerical stability
eps = 1e-4
covs2D = covs2D + eps * torch.eye(2, device=covs2D.device).unsqueeze(0)

# Compute determinant for normalization
det = torch.det(covs2D)  # (N,)
inv_covs2D = torch.inverse(covs2D)  # (N, 2, 2)

# Reshape dx for einsum: (N, H, W, 2) -> (N, H*W, 2)
dx_reshaped = dx.reshape(N, H * W, 2)  # (N, H*W, 2)


In [12]:
dx_reshaped.shape
det.shape

torch.Size([13])

In [10]:

# Compute exponent term using einsum
# Equation: 'nij,njk,nlk->nil'
# n: batch dimension (N)
# i: H*W dimension (flattened spatial dimensions)
# j, k: 2D covariance dimensions (2)
# l: H*W dimension (flattened spatial dimensions)
exponent = -0.5 * torch.einsum('nij,njk,nlk->nil', dx_reshaped, inv_covs2D, dx_reshaped)  # (N, H*W)
print(exponent.shape)

torch.Size([13, 77, 77])


In [None]:

# Reshape exponent back to (N, H, W)
exponent = exponent.reshape(N, H, W)  # (N, H, W)

# Compute gaussian values
gaussian = torch.exp(exponent) / (2 * np.pi * torch.sqrt(det)).unsqueeze(-1).unsqueeze(-1)  # (N, H, W)


In [21]:
dx = compute_gaussian_values(means2D,covs2D,pixels)
dx.shape

torch.Size([13, 7, 11])

In [16]:
EEE = torch.rand([3,2,2])
EEE = EEE + eps * torch.eye(2, device=EEE.device).unsqueeze(0)
print(torch.det(torch.inverse(EEE)))
print(torch.det(EEE))


tensor([10.6378, 18.2494, -4.0876])
tensor([ 0.0940,  0.0548, -0.2446])
