In [1]:
import numpy as np
import torch
from torch import nn

Shiesh and SITA functions

In [2]:
def Shiesh(x):
    return torch.arcsinh(np.exp(1)*torch.sinh(x))
    
def DShiesh(x):
    den = 1+(np.exp(1)*torch.sinh(x))**2
    jac = np.exp(1)*torch.cosh(x)/(den**0.5)
    return jac

def activation(x):
    asd = torch.abs(x) > 30
    int_val = torch.zeros_like(x)
    jac = torch.ones_like(x)
    int_val[asd] = x[asd] + torch.sign(x[asd])
    int_val[~asd] = Shiesh(x[~asd])
    jac[~asd] = DShiesh(x[~asd])
    return int_val.to(torch.float32), jac

def sita(K, scores):
    id_tensor = torch.eye(K).to(scores.dtype)
    A = scores.masked_fill(id_tensor == 1, 0)
    B = scores.masked_fill(id_tensor == 0, -1e8)
    A += torch.nn.Softplus()(B)
    A = torch.tril(A)
    return A


$\text{profiti}^{-1}$

In [3]:
def ProFITi_inv(x, A, theta, phi):
    nlloss = nn.GaussianNLLLoss(full=True, reduction='none')
    
    LJD = torch.tensor([0.]) # initiating log det jacobian

    # applying sita
    x = torch.matmul(A, x[:,None])[:,0]
    ljd = torch.log(torch.diagonal(A, dim1 = -2, dim2 = -1)).sum()
    LJD += ljd

    # applying elementwise transformation
    x = x*theta + phi
    LJD += torch.log(theta).sum()

    # applying shiesh
    x, ljd = activation(x)
    LJD += torch.log(ljd).sum()

    gnll = nlloss(torch.zeros_like(x), x, torch.ones_like(x)) # computing diagonal gaussian nll
    nll = gnll.sum() - LJD # computing joint nll
    density = torch.exp(-nll) # computing likelihood
    return density

Verifying if joint densities integrate to 1.

$x \in \mathbb{R}^2$ is uniformly sampled from (-50,50) ($x$ indicates observation space).

We map $x$ to gaussian via flows.

In [4]:
cum_density = 0
scores = torch.randn(2,2)
A = sita(2, scores).to(torch.float32) # Triangular Attention 
theta = torch.rand(2) # while theta = exp(tanh()), I just took small positive value for experiment
phi = torch.randn(2) # it is the offset
# now integrating from -50 to 50 
for i in np.linspace(-50,50,500):
    for j in np.linspace(-50,50,500):
        x = torch.tensor([i,j]).to(torch.float32)
        cum_density += 0.2*0.2*ProFITi_inv(x, A, theta, phi)
print(cum_density)

tensor([0.9959])
