## This notebook demonstrates the use of the factorized log prob method in gensn

In [1]:
import torch
from torch import nn
import torch.distributions as D

import matplotlib.pyplot as plt
import seaborn as sns

from gensn.distributions import TrainableDistributionAdapter, Joint
from gensn.parameters import TransformedParameter
from gensn.variational import ELBOMarginal, VariationalDequantizedDistribution
from gensn.flow import FlowDistribution
import gensn.transforms.invertible as T
import gensn.distributions as G
import torch.distributions as D

from gensn.utils import squeeze_tuple, turn_to_tuple


seed = 100
torch.manual_seed(seed);

## Factorized log prob for independent distributions

### Create a 2d independent distribution

In [2]:
n_dims = 2
loc = torch.zeros(n_dims)
scale = torch.ones(n_dims)
dist = G.IndependentNormal(loc, scale)

### Call log prob and factorized log prob

In [3]:
x = torch.zeros(n_dims)
lp = dist.log_prob(x)
print(lp, lp.shape)

tensor(-1.8379) torch.Size([])


In [4]:
flp = dist.factorized_log_prob(x)
print(flp, flp.shape)

tensor([-0.9189, -0.9189]) torch.Size([2])


In [5]:
assert torch.isclose(flp.sum(), lp)

## Factorized log prob for flow distributions

### First create factorized transformations

In [6]:
affine = T.IndependentAffine(n_dims)

In [7]:
# prepare an initializer
init_std = 0.1
def init_module(module):
    if isinstance(module, T.IndependentAffine):
        module.weight.data.normal_(mean=1.0, std=init_std)
        module.bias.data.normal_(std=init_std * 0.1)

affine.apply(init_module);

### Call forward and factorized forward

In [8]:
y, log_det = affine(x)
print(y, y.shape)
print(log_det, log_det.shape)

tensor([-0.0039,  0.0024], grad_fn=<AddBackward0>) torch.Size([2])
tensor(0.0064, grad_fn=<SumBackward1>) torch.Size([])


In [9]:
fy, flog_det = affine.factorized_forward(x)
print(fy, fy.shape)
print(flog_det, flog_det.shape)

tensor([-0.0039,  0.0024], grad_fn=<AddBackward0>) torch.Size([2])
tensor([ 0.0354, -0.0290], grad_fn=<MulBackward0>) torch.Size([2])


In [10]:
assert torch.isclose(flog_det.sum(), log_det)

### Create a inverse transform

In [11]:
inv_affine = T.InverseTransform(T.IndependentAffine(n_dims))

inv_affine.apply(init_module);

In [12]:
y, log_det = inv_affine(x)
print(y, y.shape)
print(log_det, log_det.shape)

tensor([0.0037, 0.0113], grad_fn=<DivBackward0>) torch.Size([2])
tensor(0.4120, grad_fn=<SumBackward1>) torch.Size([])


In [13]:
fy, flog_det = inv_affine.factorized_forward(x)
print(fy, fy.shape)
print(flog_det, flog_det.shape)

tensor([0.0037, 0.0113], grad_fn=<DivBackward0>) torch.Size([2])
tensor([0.1489, 0.2631], grad_fn=<NegBackward0>) torch.Size([2])


In [14]:
assert torch.isclose(flog_det.sum(), log_det)

### Create a sequential transform

In [15]:
transform_sequence = [
    T.InverseTransform(T.Softplus()),
    T.IndependentAffine(n_dims),
    T.ELU(),
    T.IndependentAffine(n_dims),
    T.ELU(),
    T.IndependentAffine(n_dims),
]

sequential = T.SequentialTransform(*transform_sequence)

sequential.apply(init_module);

In [16]:
y, log_det = sequential(x)
print(y, y.shape)
print(log_det, log_det.shape)

tensor([-0.5234, -0.6372], grad_fn=<AddBackward0>) torch.Size([2])
tensor(-14.7422, grad_fn=<AddBackward0>) torch.Size([])


In [17]:
fy, flog_det = sequential.factorized_forward(x)
print(fy, fy.shape)
print(flog_det, flog_det.shape)

tensor([-0.5234, -0.6372], grad_fn=<AddBackward0>) torch.Size([2])
tensor([-16.1689,   1.4267], grad_fn=<AddBackward0>) torch.Size([2])


In [18]:
assert torch.isclose(flog_det.sum(), log_det)

### Create a flow distribution

In [19]:
flow_base_dist = G.IndependentNormal(loc, scale)
flow_dist = FlowDistribution(flow_base_dist, sequential)

In [20]:
lp = flow_dist.log_prob(x)
print(lp, lp.shape)

tensor(-16.9201, grad_fn=<AddBackward0>) torch.Size([])


In [21]:
flp = flow_dist.factorized_log_prob(x)
print(flp, flp.shape)

tensor([-17.2248,   0.3048], grad_fn=<AddBackward0>) torch.Size([2])


In [22]:
assert torch.isclose(flp.sum(), lp)

## Factorized elbo for variational dequantized distributions with normalizing flows

In [23]:
prior_base_dist = G.IndependentNormal(
    loc=torch.zeros(n_dims), scale=torch.ones(n_dims)
)
prior_dist = FlowDistribution(prior_base_dist, sequential)

dequant_base_dist = G.IndependentLaplace(
    loc=torch.zeros(n_dims),
    scale=torch.ones(n_dims),
)
dequant_dist = FlowDistribution(dequant_base_dist, sequential)

vdd = VariationalDequantizedDistribution(
    prior=prior_dist,
    dequantizer=dequant_dist,
)

In [24]:
elbo = vdd(x, n_samples=10_000_000)
print(elbo, elbo.shape)

tensor(-0.1279, grad_fn=<MeanBackward1>) torch.Size([])


In [25]:
felbo = vdd.factorized_elbo(x, n_samples=10_000_000)
print(flp, flp.shape)

tensor([-17.2248,   0.3048], grad_fn=<AddBackward0>) torch.Size([2])


In [26]:
felbo.sum() - elbo

tensor(-0.0010, grad_fn=<SubBackward0>)

In [27]:
assert torch.isclose(felbo.sum(), elbo)

AssertionError: 

In [28]:
iw_bound = vdd.iw_bound(x, n_samples=10_000_000)
print(iw_bound, iw_bound.shape)

tensor(0.1083, grad_fn=<SubBackward0>) torch.Size([])


In [29]:
fiw_bound = vdd.factorized_iw_bound(x, n_samples=10_000_000)
print(fiw_bound, fiw_bound.shape)

tensor([0.0466, 0.0617], grad_fn=<SubBackward0>) torch.Size([2])


In [30]:
fiw_bound.sum() - iw_bound

tensor(-2.6703e-05, grad_fn=<SubBackward0>)

In [31]:
assert torch.isclose(fiw_bound.sum(), iw_bound)

AssertionError: 