In [18]:
import torch
import torch.nn.functional as F
from torch.distributions import biject_to, transform_to
import math
import pyro.distributions as dist
from pyro.distributions.lkj import LKJCorrCholesky, corr_cholesky_constraint, _signed_stick_breaking_tril

torch.set_default_tensor_type(torch.FloatTensor)

In [68]:
dimension = 4
d = LKJCorrCholesky(dimension=dimension, concentration=1)
sample = d.sample()

In [78]:
# Draw a batch of Cholesky factors and evaluate their log-density.
# Relies on `dimension` defined in an earlier cell.
d = LKJCorrCholesky(dimension=dimension, concentration=1)
sample_shape = torch.Size([100])
sample = d.sample(sample_shape)
log_prob = d.log_prob(sample)

# Now we will compute Jacobian of the transform from cholesky to correlation
# Boolean mask selecting the strictly-lower-triangular entries of every matrix in the batch.
tril_index = sample.new_ones(sample.shape).tril(diagonal=-1) > 0.5
# Flat copy of those entries, made a leaf that tracks gradients.
sample_tril = sample[tril_index].clone().requires_grad_()
sample_cloned = sample.new_zeros(sample.shape)
# The diagonal is derived from a separate scratch copy, so the in-place writes into
# `sample_cloned` below never touch a tensor that `pow` was applied to.
sample_cloned_tmp = sample_cloned.clone()
sample_cloned_tmp[tril_index] = sample_tril.reshape(-1)
# Diagonal entry = sqrt(1 - sum of squares of the off-diagonal entries in the same row).
sample_cloned_diag = (1 - sample_cloned_tmp.pow(2).sum(-1)).sqrt()
sample_cloned[tril_index] = sample_tril
# Flatten each matrix to length dimension**2; a stride of dimension + 1 then walks its diagonal.
sample_cloned.view(sample_shape + (dimension * dimension,))[..., ::dimension + 1] = sample_cloned_diag
# Rebuild L @ L^T from the differentiable copy and keep its strictly-lower entries.
y = sample_cloned.matmul(sample_cloned.transpose(-2, -1))
corr_tril = y[tril_index]

In [77]:
sample_cloned_diag

tensor([[1.0000, 0.7084, 0.8603, 0.7924],
        [1.0000, 0.9820, 0.7884, 0.4685],
        [1.0000, 0.8950, 0.9178, 0.6933],
        [1.0000, 0.9013, 0.8255, 0.6771],
        [1.0000, 0.8084, 0.8430, 0.5885],
        [1.0000, 0.5765, 0.5880, 0.5614],
        [1.0000, 0.8189, 0.9792, 0.5258],
        [1.0000, 0.9999, 0.8501, 0.7135],
        [1.0000, 0.8813, 0.9052, 0.6342],
        [1.0000, 0.7850, 0.7748, 0.5940],
        [1.0000, 0.9863, 0.9019, 0.3314],
        [1.0000, 0.6839, 0.7195, 0.4769],
        [1.0000, 0.9318, 0.7580, 0.7520],
        [1.0000, 0.9817, 0.4821, 0.7459],
        [1.0000, 0.9834, 0.7387, 0.8711],
        [1.0000, 0.9955, 0.8699, 0.4756],
        [1.0000, 0.9231, 0.8307, 0.3116],
        [1.0000, 0.8228, 0.9406, 0.6265],
        [1.0000, 0.9969, 0.7438, 0.4063],
        [1.0000, 0.9426, 0.7667, 0.4488],
        [1.0000, 0.8246, 0.8403, 0.4185],
        [1.0000, 0.9820, 0.9155, 0.3787],
        [1.0000, 0.5453, 0.9216, 0.5922],
        [1.0000, 0.9275, 0.6920, 0

In [76]:
sample_cloned.shape

torch.Size([100, 4, 4])

In [71]:
sample_cloned.shape

torch.Size([100, 4, 4])

In [65]:
torch.tensor(-2.4594)

tensor(-2.4594)

In [35]:
corr_tril = y[tril_index]

(tensor(-1.6547), tensor(-0.0584))

In [15]:
d.log_prob(sample)

tensor(-1.8193)

In [16]:
_autograd_log_det(corr_tril, sample_tril)

tensor(-0.2230)

In [17]:
-3 * math.log(2)

-2.0794415416798357

In [47]:
y = x.matmul(x.t())

In [49]:
_autograd_log_det(y[tril_index], x[tril_index])

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

In [None]:
sample_tril = sample[tril_index]

In [10]:
def _autograd_log_det(ys, x):
    # computes log_abs_det_jacobian of y w.r.t. x
    return torch.stack([torch.autograd.grad(y, (x,), retain_graph=True)[0]
                        for y in ys]).slogdet()[1]

In [20]:
dimension = 4
concentration = torch.tensor(1.)

In [21]:
# Single (unbatched) draw; relies on `dimension` and `concentration` from the previous cell.
d = LKJCorrCholesky(dimension, concentration, sample_method="cvine")
sample = d.sample()

# Start with the lower triangular part of a sample, then we will transform it back to a
# partial correlation; compute its log_prob and the Jacobian of the transform.
# Boolean mask of the strictly-lower-triangular positions of a (dimension, dimension) matrix.
tril_index = sample.new_ones(dimension, dimension).tril(diagonal=-1) > 0.5
# Flat copy of those entries, made a leaf that tracks gradients.
sample_tril = sample[tril_index].clone().requires_grad_()
# Starts with ones below the diagonal; every one of them is overwritten via `tril_index` below.
sample_cloned = sample.new_ones(dimension, dimension).tril(diagonal=-1)
# The diagonal is derived from a separate scratch copy, so the in-place writes into
# `sample_cloned` below never touch a tensor that `pow` was applied to.
sample_cloned_tmp = sample_cloned.clone()
sample_cloned_tmp[tril_index] = sample_tril
# Diagonal entry = sqrt(1 - sum of squares of the off-diagonal entries in the same row).
sample_cloned_diag = (1 - sample_cloned_tmp.pow(2).sum(-1)).sqrt()
sample_cloned[tril_index] = sample_tril
# In the flattened matrix, a stride of dimension + 1 walks the diagonal.
sample_cloned.view(-1)[::dimension + 1] = sample_cloned_diag

In [22]:
sample

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.0528, 0.9986, 0.0000, 0.0000],
        [0.4672, 0.0638, 0.8819, 0.0000],
        [0.5898, 0.2748, 0.0927, 0.7537]])

In [23]:
sample_cloned

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.0528, 0.9986, 0.0000, 0.0000],
        [0.4672, 0.0638, 0.8819, 0.0000],
        [0.5898, 0.2748, 0.0927, 0.7537]], grad_fn=<CopySlices>)

In [24]:
sample

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.0528, 0.9986, 0.0000, 0.0000],
        [0.4672, 0.0638, 0.8819, 0.0000],
        [0.5898, 0.2748, 0.0927, 0.7537]])

In [25]:
partial_corr = transform_to(corr_cholesky_constraint).inv(sample_cloned).tanh()
beta_sample = (partial_corr + 1) / 2  # inverse affine transform
partial_corr_log_prob = d._beta_dist.log_prob(beta_sample).sum(-1)
partial_corr_log_prob

tensor(0.9588, grad_fn=<SumBackward2>)

In [26]:
_autograd_log_det(beta_sample, sample_tril)

tensor(-3.5467)

In [27]:
d.log_prob(sample)

tensor(-3.1572)

In [28]:
target_log_prob = partial_corr_log_prob + _autograd_log_det(beta_sample, sample_tril)
target_log_prob

tensor(-2.5879, grad_fn=<AddBackward0>)

In [30]:
sample[1, 1].log()

tensor(-0.0014)

In [41]:
# Hand-computed pieces of the log-density, for comparison with `d.log_prob(sample)`.
# Relies on `sample`, `dimension` and `concentration` from earlier cells.
value = sample
# Offsets 4 - D, ..., 2 (D - 1 values); the 2.1 endpoint keeps the value 2 inside the range.
order_offset = torch.arange(4 - dimension, 2.1)
order = 2 * concentration.unsqueeze(-1) - order_offset

# Compute unnormalized log_prob:
# weighted sum of the log of the diagonal entries, skipping the first diagonal entry.
cholesky_logprob = (order * value.diagonal(dim1=-2, dim2=-1)[..., 1:].log()).sum(-1)

# Compute normalization constant (on the first proof of page 1999 of [1])
# NOTE(review): reference [1] is not defined anywhere in this notebook.
denominator_concentration = concentration + (dimension - 1) / 2.
denominator = torch.lgamma(denominator_concentration) * (dimension - 1)
numerator = torch.mvlgamma(denominator_concentration - 0.5, dimension - 1)
# pi_constant in [1] is D * (D - 1) / 4 * log(pi)
# pi_constant in torch.mvlgamma is (D - 1) * (D - 2) / 4 * log(pi)
# hence, we need to add the difference: pi_constant = (D - 1) / 2 * log(pi)
pi_constant = (dimension - 1) / 2. * math.log(math.pi)
normalization_constant = pi_constant + numerator - denominator

In [42]:
normalization_constant

tensor(2.4594)

In [39]:
numerator = (torch.lgamma(concentration + 1) + torch.lgamma(concentration + 0.5) +
             torch.lgamma(concentration) - torch.lgamma(concentration + 1.5) -
             torch.lgamma(concentration + 1.5) - torch.lgamma(concentration + 1.5) +
             math.log(math.pi) * 3 / 2 + math.log(math.pi) * 1 / 2 + math.log(math.pi) * 2 / 2)
numerator

tensor(2.4594)

In [None]:
torch.lgamma(concentration + 0.5) + torch.lgamma(concentration)

#### end

In [15]:
sample_new

tensor([[ 0.0000,  1.0000,  1.0000,  1.0000,  1.0000],
        [ 0.0276,  0.0000,  1.0000,  1.0000,  1.0000],
        [ 0.0949, -0.4510,  0.0000,  1.0000,  1.0000],
        [ 0.0035,  0.1989, -0.6447,  0.0000,  1.0000],
        [ 0.6909, -0.0108, -0.3325, -0.0022,  0.0000]], grad_fn=<CopySlices>)

In [10]:
sample_tril

tensor([ 0.0276,  0.0949, -0.4510,  0.0035,  0.1989, -0.6447,  0.6909, -0.0108,
        -0.3325, -0.0022], requires_grad=True)

In [11]:
corr

tensor([[ 4.0000,  3.0000,  1.5490,  0.5542, -0.3455],
        [ 3.0000,  3.0008,  2.0026,  0.3554, -0.3156],
        [ 1.5490,  2.0026,  2.2124,  0.9106,  0.0682],
        [ 0.5542,  0.3554,  0.9106,  1.4552,  0.2146],
        [-0.3455, -0.3156,  0.0682,  0.2146,  0.5880]], grad_fn=<MmBackward>)

In [7]:
target_log_prob

tensor(-inf)

In [2]:
d = LKJCorrCholesky(80, torch.rand(1), sample_method="onion")

In [10]:
x = torch.rand(3, 10, 10)

In [14]:
x.new_ones(3, 3)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

In [13]:
torch.diagonal(x)

tensor([[0.4958, 0.5597, 0.2474],
        [0.3369, 0.6387, 0.4935],
        [0.4567, 0.1401, 0.6528],
        [0.4648, 0.3759, 0.9469],
        [0.1449, 0.3248, 0.0386],
        [0.1499, 0.7719, 0.9055],
        [0.2691, 0.9263, 0.3220],
        [0.8208, 0.8923, 0.4726],
        [0.9984, 0.2575, 0.1213],
        [0.5694, 0.6186, 0.5278]])

In [6]:
d.sample(sample_shape=torch.Size([3]))

RuntimeError: expand(torch.FloatTensor{[3, 1, 80]}, size=[3, 80]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)

In [7]:
x.diagonal

> [0;32m/home/fehiepsi/pyro/pyro/distributions/lkj.py[0m(311)[0;36m_rsample_onion[0;34m()[0m
[0;32m    308 [0;31m        [0mtril_index[0m [0;34m=[0m [0mcholesky[0m[0;34m.[0m[0mnew_ones[0m[0;34m([0m[0mcholesky[0m[0;34m.[0m[0mshape[0m[0;34m)[0m[0;34m.[0m[0mtril[0m[0;34m([0m[0mdiagonal[0m[0;34m=[0m[0;34m-[0m[0;36m1[0m[0;34m)[0m [0;34m>[0m [0;36m0.5[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    309 [0;31m        [0mcholesky[0m[0;34m[[0m[0mtril_index[0m[0;34m][0m [0;34m=[0m [0mw[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    310 [0;31m        [0mcholesky_diag[0m [0;34m=[0m [0;34m([0m[0;36m1[0m [0;34m-[0m [0mcholesky[0m[0;34m.[0m[0mpow[0m[0;34m([0m[0;36m2[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m.[0m[0msqrt[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 311 [0;31m        [0mcholesky[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m-

ipdb>  cholesky_diag.shape


torch.Size([3, 1, 80])


ipdb>  cholesky.shape


torch.Size([3, 1, 80, 80])


ipdb>  exit


In [5]:
d.sample().shape

torch.Size([3, 80, 80])

In [5]:
%debug

> [0;32m/home/fehiepsi/pyro/pyro/distributions/lkj.py[0m(299)[0;36m_rsample_onion[0;34m()[0m
[0;32m    297 [0;31m        [0mnormal_sample[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mnormal[0m[0;34m([0m[0mloc[0m[0;34m,[0m [0mscale[0m[0;34m)[0m[0;34m.[0m[0mtril[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    298 [0;31m        [0msphere_uniform_sample[0m [0;34m=[0m [0mnormal_sample[0m [0;34m/[0m [0mnormal_sample[0m[0;34m.[0m[0mnorm[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;34m-[0m[0;36m1[0m[0;34m,[0m [0mkeepdim[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 299 [0;31m        [0mtril_index[0m [0;34m=[0m [0mscale[0m[0;34m.[0m[0mtril[0m[0;34m([0m[0;34m)[0m [0;34m>[0m [0;36m0.5[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    300 [0;31m        [0mw[0m [0;34m=[0m [0mbeta_sample[0m[0;34m.[0m[0msqrt[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0;34m

ipdb>  beta_sample.shape


torch.Size([3, 9480])


ipdb>  sphere_uniform_sample.shape


torch.Size([3, 79, 79])


ipdb>  79 * 79


6241


ipdb>  80 * 79


6320


ipdb>  normal_sample.shape


torch.Size([3, 79, 79])


ipdb>  loc.shape


torch.Size([3, 79, 79])


ipdb>  self._beta_dist.batch_shape


torch.Size([3, 9480])


ipdb>  exit


In [11]:
d = LKJCorrCholesky(80, 1, sample_method="cvine")

In [3]:
%%time
x = d.sample(torch.Size([5000]))
y = x.matmul(x.transpose(-1, -2))

CPU times: user 21 s, sys: 1.31 s, total: 22.3 s
Wall time: 8.32 s


In [14]:
x.shape

torch.Size([60000, 10, 10])

In [5]:
x = torch.rand(10, 10).tril(diagonal=-1)

In [6]:
x

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1770, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.9371, 0.0278, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.9500, 0.8768, 0.1039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.4127, 0.4784, 0.0989, 0.2805, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.6206, 0.4922, 0.1069, 0.4032, 0.1080, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1665, 0.3607, 0.4326, 0.2511, 0.0284, 0.7881, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.8304, 0.9402, 0.2872, 0.5601, 0.7791, 0.9055, 0.9199, 0.0000, 0.0000,
         0.0000],
        [0.8698, 0.4121, 0.9474, 0.3823, 0.1347, 0.5006, 0.9804, 0.2945, 0.0000,
         0.0000],
        [0.8448, 0.8238, 0.7491, 0.0591, 0.5597, 0.2987, 0.7390, 0.1199, 0.5501,
         0.0000]])

In [4]:
%debug

> [0;32m/home/fehiepsi/pyro/pyro/distributions/lkj.py[0m(52)[0;36m_signed_stick_breaking_tril[0;34m()[0m
[0;32m     50 [0;31m[0;34m[0m[0m
[0m[0;32m     51 [0;31m    [0;31m# transform t to tril matrix with identity diagonal[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 52 [0;31m    [0mr[0m [0;34m=[0m [0mt[0m[0;34m.[0m[0mnew_ones[0m[0;34m([0m[0mt[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;34m:[0m[0;34m-[0m[0;36m1[0m[0;34m][0m [0;34m+[0m [0;34m([0m[0mD[0m[0;34m,[0m [0mD[0m[0;34m)[0m[0;34m)[0m[0;34m.[0m[0mtril[0m[0;34m([0m[0mdiagonal[0m[0;34m=[0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     53 [0;31m    [0mtril_index[0m [0;34m=[0m [0mr[0m [0;34m>[0m [0;36m0.5[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     54 [0;31m    [0mr[0m[0;34m[[0m[0mtril_index[0m[0;34m][0m [0;34m=[0m [0mt[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[

ipdb>  t.shape


torch.Size([100000, 45])


ipdb>  t.device


device(type='cuda', index=0)


ipdb>  D


10


ipdb>  t.new_ones(1)


tensor([1.])


ipdb>  t.new_ones(t.shape[:-1])


tensor([1., 1., 1.,  ..., 1., 1., 1.])


ipdb>  t.new_ones(t.shape[:-1] + (D, D))


tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        ...,

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1., 

ipdb>  t.new_ones(t.shape[:-1] + (D, D)).tril(diagonal=-1)


*** RuntimeError: CUDA error: invalid configuration argument


ipdb>  t.new_ones(t.shape[:-1] + (D, D)).shape


torch.Size([100000, 10, 10])


ipdb>  exit()


In [9]:
t = x.matmul(x.transpose(-2, -1))

In [33]:
t.std(dim=0)

tensor([[0.0000e+00, 3.1614e-01, 3.1634e-01, 3.1656e-01, 3.1553e-01, 3.1650e-01,
         3.1666e-01, 3.1687e-01, 3.1708e-01, 3.1792e-01],
        [3.1614e-01, 9.1151e-17, 3.1555e-01, 3.1533e-01, 3.1601e-01, 3.1601e-01,
         3.1606e-01, 3.1626e-01, 3.1506e-01, 3.1666e-01],
        [3.1634e-01, 3.1555e-01, 1.4001e-16, 3.1707e-01, 3.1603e-01, 3.1666e-01,
         3.1648e-01, 3.1639e-01, 3.1561e-01, 3.1669e-01],
        [3.1656e-01, 3.1533e-01, 3.1707e-01, 1.0386e-16, 3.1578e-01, 3.1637e-01,
         3.1560e-01, 3.1628e-01, 3.1621e-01, 3.1538e-01],
        [3.1553e-01, 3.1601e-01, 3.1603e-01, 3.1578e-01, 1.1097e-16, 3.1534e-01,
         3.1585e-01, 3.1727e-01, 3.1732e-01, 3.1606e-01],
        [3.1650e-01, 3.1601e-01, 3.1666e-01, 3.1637e-01, 3.1534e-01, 1.1836e-16,
         3.1649e-01, 3.1514e-01, 3.1581e-01, 3.1619e-01],
        [3.1666e-01, 3.1606e-01, 3.1648e-01, 3.1560e-01, 3.1585e-01, 3.1649e-01,
         1.2413e-16, 3.1631e-01, 3.1643e-01, 3.1585e-01],
        [3.1687e-01, 3.1626

In [22]:
D = 4
x = torch.ones(D * (D-1) // 2, requires_grad=True)
y = x.new_zeros(D, D)
tril_index = y.new_ones(D, D).tril(diagonal=-1) > 0.5
y[tril_index] = x
y.view(-1)[::D+1] = y.sum(dim=-1)

In [24]:
y.sum().backward()

In [28]:
d.sample()

tensor([[ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.4663,  0.8846,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.6254,  0.3690,  0.6875,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.4112,  0.0647, -0.4068,  0.8131,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.3380, -0.0882, -0.0610, -0.7045,  0.6147,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.2838, -0.2251, -0.4623,  0.2745, -0.1423,  0.7480,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.0847,  0.1818,  0.2294, -0.4110,  0.3110, -0.3466,  0.7220,  0.0000,
          0.0000,  0.0000],
        [-0.3346,  0.3746, -0.3475,  0.3819, -0.4790,  0.0401, -0.4967,  0.0581,
          0.0000,  0.0000],
        [ 0.0861,  0.5125, -0.1190, -0.0087, -0.4350,  0.3188,  0.4801,  0.2605,
          0.3555,  0.0000],
        [ 0.0688, -

In [27]:
y

tensor([[0., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 2., 0.],
        [1., 1., 1., 3.]], grad_fn=<CopySlices>)

In [26]:
x.grad

tensor([2., 2., 2., 2., 2., 2.])

In [41]:
x = torch.rand(10).cuda()

In [44]:
torch.set_default_tensor_type(torch.cuda.DoubleTensor)

In [45]:
x = torch.rand(10)

In [47]:
x.device

device(type='cuda', index=0)

In [43]:
x.new_ones(x.shape + (10,)).tril()

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [13]:
dist.Beta(4.5, 4.5).variance.sqrt() * 2

tensor(0.3162)

In [11]:
t.std(dim=0)

tensor([[0.0000e+00, 3.1614e-01, 3.1634e-01, 3.1656e-01, 3.1553e-01, 3.1650e-01,
         3.1666e-01, 3.1687e-01, 3.1708e-01, 3.1792e-01],
        [3.1614e-01, 9.1151e-17, 3.1555e-01, 3.1533e-01, 3.1601e-01, 3.1601e-01,
         3.1606e-01, 3.1626e-01, 3.1506e-01, 3.1666e-01],
        [3.1634e-01, 3.1555e-01, 1.4001e-16, 3.1707e-01, 3.1603e-01, 3.1666e-01,
         3.1648e-01, 3.1639e-01, 3.1561e-01, 3.1669e-01],
        [3.1656e-01, 3.1533e-01, 3.1707e-01, 1.0386e-16, 3.1578e-01, 3.1637e-01,
         3.1560e-01, 3.1628e-01, 3.1621e-01, 3.1538e-01],
        [3.1553e-01, 3.1601e-01, 3.1603e-01, 3.1578e-01, 1.1097e-16, 3.1534e-01,
         3.1585e-01, 3.1727e-01, 3.1732e-01, 3.1606e-01],
        [3.1650e-01, 3.1601e-01, 3.1666e-01, 3.1637e-01, 3.1534e-01, 1.1836e-16,
         3.1649e-01, 3.1514e-01, 3.1581e-01, 3.1619e-01],
        [3.1666e-01, 3.1606e-01, 3.1648e-01, 3.1560e-01, 3.1585e-01, 3.1649e-01,
         1.2413e-16, 3.1631e-01, 3.1643e-01, 3.1585e-01],
        [3.1687e-01, 3.1626

In [9]:
t.mean(dim=0)

tensor([[ 1.0000e+00, -1.5795e-11, -1.0260e-03, -4.4503e-04],
        [-1.5795e-11,  1.0000e+00, -5.3766e-04, -1.5094e-03],
        [-1.0260e-03, -5.3766e-04,  1.0000e+00,  2.4660e-03],
        [-4.4503e-04, -1.5094e-03,  2.4660e-03,  1.0000e+00]])

In [12]:
beta_sample = d._beta_dist.rsample()

In [13]:
beta_sample

tensor([2.2204e-16, 8.7688e-01, 1.6567e-01, 8.0272e-01, 9.2591e-01, 1.9920e-01])

In [14]:
d._beta_dist.concentration0

tensor([2.0000, 1.5000, 1.5000, 1.0000, 1.0000, 1.0000])

In [15]:
d._beta_dist.concentration1

tensor([0.0000, 0.5000, 0.5000, 1.0000, 1.0000, 1.0000])

In [None]:
# Sketch of the onion-method construction of a Cholesky factor.
# Relies on `beta_sample` and `D` already existing in the session
# (assumes beta_sample has D * (D - 1) / 2 entries in its last dimension — TODO confirm).
loc = beta_sample.new_zeros(beta_sample.shape[:-1] + (D - 1, D - 1))
scale = loc.new_ones(loc.shape)
# Lower-triangular standard-normal draws; normalizing each row gives a unit vector.
normal_sample = torch.normal(loc, scale).tril()
sphere_uniform_sample = normal_sample / normal_sample.norm(dim=-1, keepdim=True)
# Mask of the lower triangle *including* the diagonal of the (D - 1, D - 1) block.
tril_index = scale.tril() > 0.5
w = beta_sample.sqrt().reshape(-1) * sphere_uniform_sample[tril_index]

# Note that w is the triangular part of a Cholesky factor of a correlation
# matrix (from the procedure in algorithm 3.2 of [1]).
# The diagonal entries of the Cholesky factor are sqrt(1 - w^2). We can show it by linear
# algebra or by recalling that each row of a Cholesky factor has unit Euclidean length.
cholesky = beta_sample.new_zeros(beta_sample.shape[:-1] + (D, D))
tril_index = cholesky.new_ones(cholesky.shape).tril(diagonal=-1) > 0.5
cholesky[tril_index] = w
cholesky_diag = (1 - cholesky.pow(2).sum(-1)).sqrt()
# Flatten only the trailing (D, D) block and keep every batch dimension, so the strided
# diagonal view has exactly the shape of `cholesky_diag`. Collapsing the batch dimensions
# with view(-1, D * D) made the assignment fail whenever there was more than one batch
# dimension (e.g. a (3, 1, D) diagonal could not be written into a (3, D) view).
cholesky.view(cholesky_diag.shape[:-1] + (D * D,))[..., ::D + 1] = cholesky_diag

In [9]:
torch.arange(end=10., step=0.5)

TypeError: arange() received an invalid combination of arguments - got (step=float, end=float, ), but expected one of:
 * (Number end, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool requires_grad)
 * (Number start, Number end, Number step, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool requires_grad)


In [4]:
%debug

> [0;32m/home/fehiepsi/pyro/pyro/distributions/lkj.py[0m(205)[0;36m__init__[0;34m()[0m
[0;32m    203 [0;31m            beta_concentration_offset = torch.arange((dimension - 1.5) / 2, step=0.5,
[0m[0;32m    204 [0;31m                                                     [0mdtype[0m[0;34m=[0m[0mconcentration[0m[0;34m.[0m[0mdtype[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 205 [0;31m                                                     device=concentration.device)
[0m[0;32m    206 [0;31m            [0mbeta_concentration[0m [0;34m=[0m [0mbeta_concentration_init[0m [0;34m-[0m [0mbeta_concentration_offset[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    207 [0;31m            [0;31m# expand to a matrix then takes the vector form of the lower triangular part[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  dimension


10


ipdb>  exit


In [3]:
f = LKJCorrCholesky(10, 1, sample_method="C-vine")

In [4]:
x = d(torch.Size([8000]))

tensor([[[ 1.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.3373,  0.9414,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.5406, -0.1343,  0.8305,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.1373,  0.0849,  0.2268,  ...,  0.6343,  0.0000,  0.0000],
         [-0.1566,  0.1786,  0.1847,  ...,  0.0589,  0.8067,  0.0000],
         [ 0.3151, -0.3075,  0.0927,  ..., -0.4077,  0.2942,  0.2706]],

        [[ 1.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.5596,  0.8287,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.3119,  0.1231,  0.9421,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.1565,  0.3144,  0.7781,  ...,  0.4808,  0.0000,  0.0000],
         [-0.2162, -0.1328,  0.4667,  ..., -0.4097,  0.5027,  0.0000],
         [ 0.2822,  0.2354,  0.0455,  ..., -0.0159,  0.1627,  0.3905]],

        [[ 1.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0121,  0.9999,  0.0000,  ...,  0

In [17]:
%%timeit
samples = d(torch.Size([8000]))

209 ms ± 1.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
%%timeit
samples = f(torch.Size([8000]))

264 ms ± 6.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
D = 10
y = torch.rand(D, D).tril()
y = y / y.norm(dim=-1, keepdim=True)
x = t.inv(y)

In [None]:
%timeit t.inv(y)

In [208]:
D = 80
x = torch.rand(5000, D * (D - 1) // 2)

In [222]:
torch.tensor(0.).cosh()

tensor(1.)

In [213]:
def f(x):
    """Map an unconstrained vector to a lower-triangular matrix by tanh + signed stick-breaking.

    The last dimension of ``x`` must hold ``D * (D - 1) / 2`` values for some integer
    ``D``; the result has shape ``x.shape[:-1] + (D, D)``.

    Raises:
        ValueError: if the last dimension of ``x`` is not of the form ``D * (D - 1) / 2``.
    """
    # recover D from the number of strictly-lower-triangular entries
    n_entries = x.size(-1)
    D = round((1 + math.sqrt(1 + 8 * n_entries)) / 2)
    if 2 * n_entries != D * (D - 1):
        raise ValueError("This transformation requires an input with last shape is "
                         "D*(D-1)/2 for some integer D.")

    # tanh is applied to the flat vector (before the matrix is built) and the result is
    # kept one machine epsilon away from -1 and 1
    tiny = torch.finfo(x.dtype).eps
    squashed = x.tanh().clamp(min=(-1 + tiny), max=(1 - tiny))

    # scatter the squashed values below the diagonal, then put ones on the diagonal
    tril_matrix = squashed.new_ones(squashed.shape[:-1] + (D, D)).tril(diagonal=-1)
    below_diag = tril_matrix > 0.5
    tril_matrix[below_diag] = squashed.reshape(-1)
    tril_matrix.diagonal(dim1=-2, dim2=-1).fill_(1)

    # stick-breaking on the squared values: running product of (1 - r^2) along each row.
    # With s = z * z_cumprod the target is sign(r) * sqrt(s) = r * sqrt(z_cumprod),
    # so s itself is never formed.
    remaining = (1 - tril_matrix.pow(2)).cumprod(dim=-1)

    # sqrt is evaluated only at the strictly-lower positions to sidestep NaNs leaking
    # through the backward pass (https://github.com/pytorch/pytorch/issues/15506)
    remaining_sqrt = remaining.new_zeros(remaining.shape)
    remaining_sqrt[below_diag] = remaining[below_diag].sqrt()

    # shift one column to the right; the first column is scaled by 1
    row_scale = F.pad(remaining_sqrt[..., :-1], pad=(1, 0), value=1)
    return tril_matrix * row_scale

In [215]:
def g(x):
    """Variant of the tanh + stick-breaking map that fills only the strictly-lower triangle.

    The last dimension of ``x`` must hold ``D * (D - 1) / 2`` values for some integer
    ``D``. The result has shape ``x.shape[:-1] + (D, D)``; its diagonal and upper
    triangle are zero.

    Raises:
        ValueError: if the last dimension of ``x`` is not of the form ``D * (D - 1) / 2``.
    """
    # recover D from the number of strictly-lower-triangular entries
    n_entries = x.size(-1)
    D = round((1 + math.sqrt(1 + 8 * n_entries)) / 2)
    if 2 * n_entries != D * (D - 1):
        raise ValueError("This transformation requires an input with last shape is "
                         "D*(D-1)/2 for some integer D.")

    # tanh is applied to the flat vector and kept one machine epsilon away from -1 and 1
    tiny = torch.finfo(x.dtype).eps
    squashed = x.tanh().clamp(min=(-1 + tiny), max=(1 - tiny))

    # scatter the squashed values below the diagonal, then put ones on the diagonal
    tril_matrix = squashed.new_ones(squashed.shape[:-1] + (D, D)).tril(diagonal=-1)
    below_diag = tril_matrix > 0.5
    tril_matrix[below_diag] = squashed.reshape(-1)
    tril_matrix.diagonal(dim1=-2, dim2=-1).fill_(1)

    # stick-breaking on the squared values: running product of (1 - r^2) along each row
    remaining = (1 - tril_matrix.pow(2)).cumprod(dim=-1)

    # sqrt is evaluated only at the strictly-lower positions to sidestep NaNs leaking
    # through the backward pass (https://github.com/pytorch/pytorch/issues/15506)
    remaining_sqrt = remaining.new_zeros(remaining.shape)
    remaining_sqrt[below_diag] = remaining[below_diag].sqrt()
    row_scale = F.pad(remaining_sqrt[..., :-1], pad=(1, 0), value=1)

    # unlike a full elementwise product, only the strictly-lower entries are written
    out = remaining.new_zeros(remaining.shape)
    out[below_diag] = squashed.reshape(-1) * row_scale[below_diag]
    return out

In [214]:
%timeit f(x)

4.81 s ± 21.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [217]:
%timeit g(x)

7.89 s ± 160 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [26]:
%timeit (1 + x) / (1 - x)

12.8 µs ± 77.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [27]:
%timeit 1 - (1 - x).reciprocal()

12.1 µs ± 79.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [30]:
%timeit  2 / (1 - x) - 1

18 µs ± 51.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


%timeit t.log_abs_det_jacobian(x, y)

transform = _PartialCorrToCorrCholeskyTransform()
x_shape = (6,)
x = torch.empty(x_shape).uniform_(-1, 1).requires_grad_()

def _vector_to_l_cholesky(z):
    D = (1.0 + math.sqrt(1.0 + 8.0 * z.shape[-1]))/2.0
    if D % 1 != 0:
        raise ValueError("Correlation matrix transformation requires d choose 2 inputs")
    D = int(D)
    x = torch.zeros(list(z.shape[:-1]) + [D,D], device=z.device)

    x[..., 0,0] = 1
    x[..., 1:,0] = z[..., :(D-1)]
    i = D - 1
    last_squared_x = torch.zeros(list(z.shape[:-1]) + [D], device=z.device)
    for j in range(1, D):
        distance_to_copy = D - 1 - j
        last_squared_x = last_squared_x[..., 1:] + x[...,j:,(j-1)].clone()**2
        x[..., j, j] = (1 - last_squared_x[..., 0]).sqrt()
        x[..., (j+1):, j] = z[..., i:(i + distance_to_copy)] * (1 - last_squared_x[..., 1:]).sqrt()
        i += distance_to_copy
    return x

def _call(x):
    # make sure that x.size(-1) = D * (D - 1) / 2 for some D
    D = round((1 + math.sqrt(1 + 8 * x.size(-1))) / 2)
    if D * (D - 1) != 2 * x.size(-1):
        raise ValueError("This transformation requires an input with last shape is "
                         "D*(D-1)/2 for some integer D.")

    # transform to tril matrix
    r = x.new_ones(x.shape[:-1] + (D, D)).tril()
    tril_index = r.tril(diagonal=-1) > 0.5
    r[tril_index] = x.reshape(-1)

    # apply stick-breaking on the squared values
    z = r ** 2
    z_cumprod = (1 - z).cumprod(-1)
    # we omit the step computing s = z * z_cumprod by using the trick:
    #     y = sign(r) * s = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod)

    # workaround the issue: NaN propagated through backward pass even when not accessed
    # at https://github.com/pytorch/pytorch/issues/15506;
    # here we only take sqrt at tril_index
    z_cumprod_sqrt = z_cumprod.new_zeros(z_cumprod.shape)
    z_cumprod_sqrt[tril_index] = z_cumprod[tril_index].sqrt()        
    z_cumprod_sqrt_shifted = F.pad(z_cumprod_sqrt[..., :-1], pad=(1, 0), value=1)
    y = r * z_cumprod_sqrt_shifted
    return y

D = 100

%timeit torch.arange(beta_c_init - 0.5, beta_c_init - D / 2. + 0.1, 0.5).expand(D, D)

torch.arange(1., 6, 0.5)

y = torch.arange(1., 10)

%timeit 2 * y

%timeit torch.arange(2., 20, 2)

torch.finfo(torch.float).epsneg

torch.finfo(torch.float)

x.device

D = 3
x = torch.randn(D * (D-1) // 2)

import math

%timeit _call(x)

import torch

torch.set_printoptions(10)
torch.set_default_dtype(torch.double)

x = torch.cuda.FloatTensor([1])

x

torch.finfo(x.dtype).eps

torch.tanh(torch.tensor(-10.))

torch.sigmoid(torch.tensor(-40.))

torch.tensor(-40.).tanh()

torch.rand(10).expand(torch.Size([3, 3, -1])).shape

%timeit _vector_to_l_cholesky(x)

def _call(x):
    """Compact stick-breaking transform built on two external helpers.

    NOTE(review): `_vector_to_lower_triangular_with_identity_diagonal` and `_masked_sqrt`
    are not defined in the visible part of this notebook; presumably the first lays `x`
    out as a lower-triangular matrix with ones on the diagonal and the second takes a
    NaN-safe square root — confirm against their definitions.
    """
    # transform to tril matrix
    tril_matrix = _vector_to_lower_triangular_with_identity_diagonal(x)

    # stick-breaking on the squared values: running product of (1 - r^2) along each row
    remaining = (1 - tril_matrix ** 2).cumprod(dim=-1)

    # with s = z * z_cumprod the target is sign(r) * sqrt(s) = r * sqrt(z_cumprod), so s
    # is never formed; the square roots are shifted one column right and the first
    # column is scaled by 1
    row_scale = F.pad(_masked_sqrt(remaining[..., :-1]), pad=(1, 0), value=1)
    return tril_matrix * row_scale

tril_index = x.new_ones(x.shape).tril(diagonal=diagonal) > 0.5

torch.uint8

torch.full((4, 4), 0.5).cumsum(-1)

x = torch.rand(3, 3).cuda()

torch.tensor(3.).size(-1)

x.new_ones((3, 3), dtype=torch.uint8)

import math

def abc(x):
    """Lay a flat vector out as a lower-triangular matrix with ones on the diagonal.

    Returns:
        A pair ``(r, tril_index)``: ``r`` has shape ``x.shape[:-1] + (D, D)`` with the
        values of ``x`` below the diagonal, and ``tril_index`` is the mask of those
        strictly-lower positions.

    Raises:
        ValueError: if ``x.size(-1)`` is not ``D * (D - 1) / 2`` for an integer ``D``.
    """
    # solve D * (D - 1) / 2 == x.size(-1) for D and insist on an integer solution
    root = (1 + math.sqrt(1 + 8 * x.size(-1))) / 2
    if not root.is_integer():
        raise ValueError("This transformation requires an input with size D*(D-1)/2.")
    size = int(root)

    matrix = x.new_ones(x.shape[:-1] + (size, size)).tril()
    # the mask is taken while the lower triangle still holds ones
    below_diag = matrix.tril(diagonal=-1) > 0.5
    matrix[below_diag] = x.reshape(-1)
    return matrix, below_diag


def defg(x, mask):
    """Square root of ``x`` at the positions selected by ``mask``; zero everywhere else.

    The root is computed on the selected entries only, which works around NaNs being
    propagated through the backward pass even when they are never accessed
    (https://github.com/pytorch/pytorch/issues/15506).
    """
    selected_sqrt = x[mask].sqrt()
    out = x.new_zeros(x.shape)
    out[mask] = selected_sqrt
    return out

D = 10
x = torch.rand(10, D * (D - 1) // 2)
value = torch.randn(10, D, D).tril()
value.diagonal(dim1=-2, dim2=-1).exp_()
y = value / value.norm(2, dim=-1, keepdim=True)

def f(x):
    """Stick-breaking transform of a flat vector, built on `abc` and `defg`.

    `abc` lays ``x`` out as a lower-triangular matrix ``r`` with ones on the diagonal
    and returns the strictly-lower mask; `defg` takes the square root at masked
    positions only. The result is ``r`` scaled by the right-shifted square root of the
    running product of ``1 - r ** 2`` along each row, diagonal included.
    """
    # transform to tril matrix
    r, i = abc(x)

    # apply stick-breaking on the squared values
    z = r ** 2
    z_cumprod = (1 - z).cumprod(-1)

    # we omit the step computing s = z * z_cumprod by using the trick:
    #     y = sign(r) * sqrt(s) = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod)
    # The masked sqrt has to come BEFORE the right shift. Masking the shifted tensor with
    # the strictly-lower mask zeroed every diagonal entry (and y[..., 0, 0]), so the
    # result had an all-zero diagonal. After the shift, diagonal entry (k, k) reads the
    # strictly-lower entry (k, k - 1), and the padded first column supplies the 1 for
    # y[..., 0, 0].
    z_cumprod_sqrt = defg(z_cumprod, i)
    y = r * F.pad(z_cumprod_sqrt[..., :-1], pad=(1, 0), value=1)
    return y


def f1(x):
    """
    Signed stick-breaking map from a flat vector `x` to a lower triangular matrix `y`.

    Only the forward transform is computed here; no log-determinant of the
    Jacobian is returned. The relation used below,
    y / x = sqrt(z_cumprod)  (modulo right shifted),
    is the one a Jacobian computation would be built on.
    """
    # lay the entries of x out below a diagonal of ones
    tril = _vector_to_lower_triangular_with_identity_diagonal(x)

    # running product of (1 - r^2) along each row
    remaining = (1 - tril.pow(2)).cumprod(-1)

    # Shortcut: sign(r) * sqrt(z * z_cumprod) equals r * sqrt(z_cumprod), so
    # s = z * z_cumprod is never formed. The running product is shifted right
    # by one column and the first column is scaled by 1.
    shifted_sqrt = F.pad(_masked_sqrt(remaining[..., :-1]), pad=(1, 0), value=1)
    return tril * shifted_sqrt

# Inline (function-free) version of the forward stick-breaking transform.
# Relies on `x` and `D` already being defined in the kernel.
# Build a unit-diagonal lower triangular matrix and fill its strictly lower
# part with the entries of x.
r = x.new_ones(x.shape[:-1] + (D, D)).tril()
tril_index = r.tril(diagonal=-1) > 0.5
r[tril_index] = x.reshape(-1)

# apply stick-breaking on the squared values
z = r ** 2
z_cumprod = (1 - z).cumprod(-1)
# we omit the step computing s = z * z_cumprod by using the trick:
#     y = sign(r) * s = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod)

# workaround the issue NaN propagated through backward pass even when not accessed
# https://github.com/pytorch/pytorch/issues/15506
# here we only take sqrt at mask index
z_cumprod_sqrt = z_cumprod.new_zeros(z_cumprod.shape)
z_cumprod_sqrt[tril_index] = z_cumprod[tril_index].sqrt()        
# shift right by one column: column j of y is scaled by column j - 1 of the
# masked sqrt, and the first column is scaled by 1
z_cumprod_sqrt_shifted = F.pad(z_cumprod_sqrt[..., :-1], pad=(1, 0), value=1)
y = r * z_cumprod_sqrt_shifted

x.shape

%debug

f1(x).shape

f(x).shape

%timeit f(x)

%timeit f1(x)

def _inverse(y):
    """Inverse stick-breaking: recover the flat vector from a lower triangular ``y``.

    Both ``y`` and the shifted remainder are first reduced to their strictly
    lower entries (``diagonal=-1``) with ``_lower_triangular_to_vector``; the
    division and square root are then done on those vectors only.
    """
    # what is left of each row's unit budget after the squared entries so far
    remaining = 1 - y.pow(2).cumsum(-1)
    # entry (i, j) is divided by the remainder before column j
    remaining_shifted = F.pad(remaining[..., :-1], pad=(1, 0), value=1)

    numerator = _lower_triangular_to_vector(y, diagonal=-1)
    denominator = _lower_triangular_to_vector(remaining_shifted, diagonal=-1).sqrt()
    return numerator / denominator

def _inverse3(y):
    # inverse stick-breaking
    z_cumprod = 1 - y.pow(2).cumsum(-1)
    z_cumprod_shifted = F.pad(z_cumprod[..., :-1], pad=(1, 0), value=1)
    tril_index = y.new_ones(y.shape).tril(diagonal=-1) > 0.5
    r = y[tril_index] / z_cumprod_shifted[tril_index].sqrt()
    return r.reshape(y.shape[:-2] + (-1,))

def _inverse1(y):
    """Inverse stick-breaking done on the full matrix, then reduced to a vector.

    The division and square root are applied to every entry of ``y``; the
    strictly lower entries (``diagonal=-1``) are extracted afterwards with
    ``_lower_triangular_to_vector``.
    """
    # remainder of each row's unit budget, shifted right by one column
    remaining = 1 - y.pow(2).cumsum(-1)
    remaining_shifted = F.pad(remaining[..., :-1], pad=(1, 0), value=1)

    ratio = y / remaining_shifted.sqrt()
    return _lower_triangular_to_vector(ratio, diagonal=-1)

def _inverse2(y):
    """Inverse stick-breaking on the full matrix, using ``_masked_sqrt``.

    The square root of the remainder goes through ``_masked_sqrt`` before it
    is shifted right by one column; the strictly lower entries
    (``diagonal=-1``) of the quotient are then extracted with
    ``_lower_triangular_to_vector``.
    """
    remaining = 1 - y.pow(2).cumsum(-1)
    denominator = F.pad(_masked_sqrt(remaining[..., :-1]), pad=(1, 0), value=1)
    ratio = y / denominator
    return _lower_triangular_to_vector(ratio, diagonal=-1)

# Fixture for the timing cells below: a single 100 x 100 lower triangular
# matrix `y` whose rows have unit L2 norm.
D = 100
value = torch.randn(D, D).tril()
value.diagonal(dim1=-2, dim2=-1).exp_()  # in-place on the diagonal view: makes the diagonal positive
y = value / value.norm(2, dim=-1, keepdim=True)  # divide each row by its L2 norm

%timeit _inverse3(y)

%timeit _inverse(y)

%timeit _inverse1(y)

%timeit _inverse2(y)

tril_index = y.new_ones(y.shape).tril(diagonal=-1) > 0.5
y_tril_vector = y[tril_index]
y_tril_vector / x

def _autograd_log_det(ys, x):
    # computes log_abs_det_jacobian of y w.r.t. x
    return torch.stack([torch.autograd.grad(y, (x,), retain_graph=True)[0]
                        for y in ys]).slogdet()[1]

_autograd_log_det(y_tril_vector, x)

z_cumprod = 1 - y.pow(2).cumsum(-1)
z_cumprod

z_cumprod.sqrt()

y_tril_vector / x

(y_tril_vector / x).log().sum()

z = r ** 2

z_cumprod = (1 - z).cumprod(-1)

def _masked_sqrt(x):
    # hack around the issue NaN propagated through backward pass even when not accessed
    # https://github.com/pytorch/pytorch/issues/15506
    y = x.new_zeros(x.shape)
    mask = x > 0
    y[mask] = x[mask].sqrt()
    return y

t = F.pad(_masked_sqrt(z_cumprod), pad=(1, 0), value=1)
t

torch.autograd.grad(t[0, 0], (x,), retain_graph=True)

# Experiment 1: ask autograd for the gradient of y[0] w.r.t. x, where y[0] is
# the constant inserted by F.pad (its value does not depend on x) and sqrt is
# evaluated at x = 0.
# NOTE(review): presumably this reproduces the NaN gradient discussed in
# https://github.com/pytorch/pytorch/issues/15506 -- confirm by running.
x = torch.zeros(1, requires_grad=True)
y = torch.nn.functional.pad(x.sqrt(), pad=(1, 0), value=1)
torch.autograd.grad(y[0], (x,))[0]

# Experiment 2: same question with the constant prepended via torch.cat and
# with x = 1 instead of 0.
x = torch.ones(1, requires_grad=True)
#z = x.new_ones(2)
#z[1] = x.sqrt()
z = torch.cat((x.new_ones(1), x.sqrt()))
torch.autograd.grad(z[0], (x,))[0]

torch.autograd.grad(F.pad(z_cumprod[..., :-1].sqrt(), pad=(1, 0), value=1)[0, 0], (x,), retain_graph=True)

# apply stick-breaking on the squared values
z = r ** 2
z_cumprod = (1 - z).cumprod(-1)

# we omit the step to compute s = z * z_cumprod by using the trick:
#     y = sign(r) * s = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod)
y = r * F.pad(z_cumprod[..., :-1].sqrt(), pad=(1, 0), value=1)
return y

x

y

y

def _vector_to_lower_triangular_with_identity_diagonal(x):
    # make sure that x.size(-1) = D * (D - 1) / 2 for some D
    D = (1 + math.sqrt(1 + 8 * x.size(-1))) / 2
    if D % 1 != 0:
        raise ValueError("This transformation requires an input with size D*(D-1)/2.")
    D = int(D)
    r = x.new_ones(x.shape[:-1] + (D, D)).tril()
    tril_index = r.tril(diagonal=-1) > 0.5
    r[tril_index] = x
    return r

def _autograd_log_det(ys, x):
    # computes log_abs_det_jacobian of y w.r.t. x
    return torch.stack([torch.autograd.grad(y, (x,), retain_graph=True)[0]
                        for y in ys]).slogdet()[1].log()

torch.autograd.grad(y[0, 0], (x,), retain_graph=True)

y_tril_vector

torch.stack([torch.autograd.grad(y, (x,), retain_graph=True)[0]
                        for y in y_tril_vector])

_autograd_log_det(y_tril_vector, x)

def _call(x):
    """Signed stick-breaking map from a flat vector to a lower triangular matrix.

    Variant that zeroes the upper triangle of the running product with
    ``.tril()`` and then applies a plain ``sqrt`` to all of it.
    """
    # lay the entries of x out below a diagonal of ones
    tril = _vector_to_lower_triangular_with_identity_diagonal(x)

    # running product of (1 - r^2) along each row, lower triangle only
    remaining = (1 - tril.pow(2)).cumprod(-1).tril()

    # Shortcut: sign(r) * sqrt(z * z_cumprod) equals r * sqrt(z_cumprod), so
    # s = z * z_cumprod is never formed. The running product is shifted right
    # by one column and the first column is scaled by 1.
    shifted_sqrt = F.pad(remaining[..., :-1].sqrt(), pad=(1, 0), value=1)
    return tril * shifted_sqrt

F.pad(z_cumprod.sqrt(), pad=(1, 0), value=1)

_call(torch.rand(6))

_vector_to_lower_triangular(torch.rand(6))

torch.rand(3, 3).tril(diagonal=-1)

D = 100
x = torch.rand(10, D, D).tril()
x_norm = x.norm(dim=-1)
y = x / x_norm.unsqueeze(-1)

%timeit f1(y)

%timeit f2(y)

f1(y)

f2(y)

x = torch.rand(10)

%timeit f1(x)

%timeit f2(x)

transform = _PartialCorrToCorrCholeskyTransform()

# Sanity check: three expressions for log(d tanh(x) / dx) that should agree:
#   1. autograd,
#   2. log(1 - tanh(x)^2),
#   3. -2 * log(cosh(x)).
x = torch.randn((), requires_grad=True)
print(x)
y = x.tanh()
torch.autograd.grad(y.sum(), (x,))[0].log()

(1 - y**2).log()

# 1 - tanh(x)^2 = 1 / cosh(x)^2, so the closed form is in terms of x itself,
# not of y = tanh(x).
(-2) * x.cosh().log()

x_shape = (6,)
x = torch.rand(x_shape, requires_grad=True) * 2 - 1
y = transform(x)
x

def _autograd_log_det(y, x):
    # computes log_abs_det_jacobian of y w.r.t. x
    triu_index = y.new_ones(y.shape).triu(diagonal=1) > 0.5
    y_tril_vector = y.t()[triu_index]
    return torch.stack([torch.autograd.grad(y, (x,), retain_graph=True)[0] for y in y_tril_vector]).det().abs().log()

# Debugging cell: build the lower triangular matrix y from x one column at a
# time, printing intermediate state and the autograd log-det after each
# diagonal entry is written. Anomaly detection is on for the whole block.
# Relies on `x` and `_autograd_log_det(y, x)` already defined in the kernel.
with torch.autograd.set_detect_anomaly(True):
    D = 4  # assumes x has D * (D - 1) / 2 = 6 entries in its last dimension -- TODO confirm
    y = x.new_zeros(x.shape[:-1] + (D, D))
    y[..., 0, 0] = 1
    # column 0 below the diagonal is the first D - 1 entries of x, unscaled
    y[..., 1:, 0] = x[..., :(D - 1)]
    pos_x = D - 1  # index of the next unread entry of x
    past_y_squared_sum = None
    print(y)
    # FIX ME: find a vectorized way to compute y instead of loop
    for j in range(1, D):
        # running sum of squares of columns 0..j-1, kept for rows j..D-1
        if j == 1:
            past_y_squared_sum = y[..., j:, (j - 1)].pow(2)
        else:
            # drop the row that is now above the current one, then add column j-1
            past_y_squared_sum = past_y_squared_sum[..., 1:] + y[..., j:, (j - 1)].pow(2)
        print(past_y_squared_sum)
        # diagonal entry: whatever is left of row j's unit norm
        y[..., j, j] = (1 - past_y_squared_sum[..., :1]).sqrt()
        print(j, j)
        print(_autograd_log_det(y, x))
        new_pos_x = pos_x + D - 1 - j
        # column j below the diagonal: the next D - 1 - j entries of x, each
        # scaled by the square root of what is left of its row
        y[..., (j + 1):, j] = x[..., pos_x:new_pos_x] * (1 - past_y_squared_sum[..., 1:]).sqrt()
        print(range(j+1, D), j)
        pos_x = new_pos_x

x

triu_index = y.new_ones(y.shape).triu(diagonal=1) > 0.5
y_tril_vector = y.t()[triu_index]
_autograd_log_det(y_tril_vector, x)

x1 = y.new_ones(y.shape)
triu_index = x1.triu(diagonal=1) > 0.5
x1[..., :, 0] = y[..., :, 0]
x1[..., :, 1:] = y[..., :, 1:] / (1 - y.pow(2).cumsum(-1)[..., :, :-1]).sqrt()


x1.transpose(-1, -2)[triu_index]

torch.autograd.grad(y[1,0], (x,))

z = transform.inv(y)

y

transform.domain.check(z)

# Inline inverse of the stick-breaking transform, with debugging prints.
# Relies on `y` already defined in the kernel; note that `x` is overwritten.
x = y.new_ones(y.shape)
print(x)
triu_index = x.triu(diagonal=1) > 0.5  # mask of strictly upper entries
# first column is copied unchanged
x[..., :, 0] = y[..., :, 0]
print(x)
# entry (i, j), j >= 1, is divided by sqrt(1 - sum of squares of row i before column j)
x[..., :, 1:] = y[..., :, 1:] / (1 - y.pow(2).cumsum(-1)[..., :, :-1]).sqrt()
print(x)
# we transpose and take upper triangular indices to arrange the result vector
# by (x21, x31, x41,..., x32, x42,...) instead of (x21, x31, x32, x41, x42,...)
z = x.transpose(-1, -2)[triu_index]

x

z

# test codomain
assert_tensors_equal(transform.codomain.check(y), torch.ones(x.shape[:-1]))

# test inv
z = transform.inv(y)
assert_tensors_equal(x, z)

torch.arange(10).reshape(2, 5).cumsum(-1)

torch.rand(3, 3).tril(diagonal=-1)

x = torch.rand(3, 3)

(1 - x.pow(2).cumsum(-1)).tril(diagonal=-1)

D = 4
x = torch.rand(D, D)

def f1(x):
    """Vectorized inverse stick-breaking of a lower triangular matrix.

    Args:
        x: tensor of shape ``(..., D, D)`` whose rows satisfy the stick-breaking
            construction (unit-norm rows of a lower triangular matrix).

    Returns:
        1-D tensor of the recovered strictly lower entries in column-major
        order ``(x21, x31, ..., x32, ...)``, where entry ``(i, j)`` is
        ``x[i, j] / sqrt(1 - sum_{k < j} x[i, k] ** 2)``.
    """
    y = x.new_zeros(x.shape)
    # first column is unscaled
    y[..., :, 0] = x[..., :, 0]
    # The running sum of squares is taken along each row, so the shift by one
    # must be along the column axis as well: entry (i, j) is divided by the
    # remainder of row i before column j.
    y[..., :, 1:] = x[..., :, 1:] / (1 - x.pow(2).cumsum(-1)[..., :, :-1]).sqrt()
    # mask derived from x itself rather than from a global D
    triu_index = x.new_ones(x.shape).triu(diagonal=1) > 0.5
    # strictly upper entries of the transpose = strictly lower entries, column-major
    return y.transpose(-1, -2)[triu_index]

f1(x)


def _inverse(x):
    if (x.shape[0] != x.shape[1]):
        raise ValueError("A matrix that isn't square can't be a Cholesky factor of a correlation matrix")
    D = x.shape[0]

    z_stack = [
        x[1:, 0]
    ]
    current_x = z_stack[0]
    last_squared_x = None
    for j in range(1, D):
        if last_squared_x is None:
            last_squared_x = current_x**2
        else:
            last_squared_x += current_x[1:]**2
    current_x = x[j:, j]
    z_stack.append(current_x / (1 - last_squared_x).sqrt())
    z = torch.cat(z_stack)
    return z

_inverse(x)

%timeit f1(x)

x.pow(2)

f1(x)

x.triu(diagonal=-1)

x

x.pow(2).cumsum(-1)[..., :-1]

x[..., 1:]

x[..., :-1]

y[:, 0] = x[..., 0]

y[]