In [1]:
from gensn.distributions import TrainableDistributionAdapter, Joint
from gensn.variational import ELBOMarginal
from gensn.parameters import TransformedParameter, PositiveDiagonal, Covariance

In [3]:
import torch
from torch import nn
from torch.optim import Adam, SGD
from torch.distributions import Normal, MultivariateNormal

## Example usage: learn Normal distribution

### Learn only the mean

In [4]:
# only the mean (loc) is learnable here: it is wrapped in nn.Parameter
# std (scale) is passed as a plain tensor, so it is not among normal.parameters() and stays at 2.0
normal = TrainableDistributionAdapter(Normal, 
                               loc=nn.Parameter(torch.Tensor([0.0])), 
                               scale=torch.Tensor([2.0]))

In [5]:
normal.state_dict()

OrderedDict([('loc', tensor([0.])), ('scale', tensor([2.]))])

In [12]:
list(normal.parameters())

[Parameter containing:
 tensor([0.], requires_grad=True)]

In [13]:
# setup target normal distribution to learn
target = Normal(torch.Tensor([5.0]), torch.Tensor([2.0]))

In [15]:
optim = Adam(normal.parameters(), lr=0.5)

In [16]:
# Fit the learnable mean: every step draws a fresh batch of 100 samples from
# `target` and minimises the quantity reported below as the negative log-probability.
for i in range(1000):
    optim.zero_grad()
    targets = target.sample((100,))
    nlogp = -normal(targets).mean()
    if (i+1) % 100 == 0:
        # Reporting only: read the current mean just on the steps that print
        # (still before this step's update is applied, as before).
        mean = normal.loc.detach()
        print(f'Neg logP: {nlogp}, mean={mean}')
    nlogp.backward()
    optim.step()

Neg logP: 2.1155459880828857, mean=tensor([4.9734])
Neg logP: 2.1493823528289795, mean=tensor([4.9213])
Neg logP: 2.211949348449707, mean=tensor([4.9919])
Neg logP: 2.0131587982177734, mean=tensor([5.1923])
Neg logP: 2.2272799015045166, mean=tensor([4.9601])
Neg logP: 2.091372013092041, mean=tensor([5.1384])
Neg logP: 2.0637428760528564, mean=tensor([5.2518])
Neg logP: 2.1895248889923096, mean=tensor([5.1309])
Neg logP: 2.1547136306762695, mean=tensor([4.9220])
Neg logP: 2.1652190685272217, mean=tensor([5.2024])


### Learn both mean and stdev

In [17]:
# make the mean & std both learnable
# std (scale) is the square of an underlying value initialised to 1.0,
# so it can never become negative whatever the optimiser does to that value
normal = TrainableDistributionAdapter(Normal, 
                               loc=nn.Parameter(torch.Tensor([0.0])), 
                               scale=TransformedParameter(torch.Tensor([1.0]), lambda x: x**2))

In [18]:
# setup target normal distribution to learn
target = Normal(torch.Tensor([5]), torch.Tensor([8]))

In [19]:
optim = Adam(normal.parameters(), lr=1.0)

In [20]:
# Fit mean and std jointly: every step draws a fresh batch of 100 samples from
# `target` and minimises the quantity reported below as the negative log-probability.
for i in range(1000):
    optim.zero_grad()
    targets = target.sample((100,))
    nlogp = -normal(targets).mean()
    if (i+1) % 100 == 0:
        # Reporting only: evaluate the current mean and (transformed) std just
        # on the steps that print, before this step's update is applied.
        mean = normal.loc.detach()
        std = normal.scale.value.detach()
        print(f'Neg logP: {nlogp}, mean={mean}, std={std}')
    nlogp.backward()
    optim.step()

Neg logP: 4.5553154945373535, mean=tensor([6.3387]), std=tensor([37.1205])
Neg logP: 3.6742889881134033, mean=tensor([5.9076]), std=tensor([13.9000])
Neg logP: 3.489955186843872, mean=tensor([5.0773]), std=tensor([8.1060])
Neg logP: 3.5183615684509277, mean=tensor([4.9740]), std=tensor([7.9830])
Neg logP: 3.4774651527404785, mean=tensor([4.9874]), std=tensor([7.9501])
Neg logP: 3.4990129470825195, mean=tensor([5.2858]), std=tensor([7.5867])
Neg logP: 3.4986166954040527, mean=tensor([4.8456]), std=tensor([7.8560])
Neg logP: 3.4958086013793945, mean=tensor([4.7768]), std=tensor([7.9615])
Neg logP: 3.541909694671631, mean=tensor([4.8205]), std=tensor([7.5591])
Neg logP: 3.453545570373535, mean=tensor([4.9117]), std=tensor([7.9764])


### Learn both mean and stdev (same as above, but specified positionally)

In [21]:
# make the mean & std both learnable (arguments given positionally instead of by keyword)
# std (scale) is set to be square of a parameter, thus ensuring a non-negative value
# the positional arguments are read back in the training cell as normal._arg0 and normal._arg1
normal = TrainableDistributionAdapter(Normal, 
                               nn.Parameter(torch.Tensor([0.0])), 
                               TransformedParameter(torch.Tensor([1.0]), lambda x: x**2))

In [22]:
# setup target normal distribution to learn
target = Normal(torch.Tensor([5]), torch.Tensor([8]))

In [23]:
optim = Adam(normal.parameters(), lr=1.0)

In [24]:
# Same fit as the keyword-argument version; the distribution arguments are
# reached through their positional names `_arg0` (mean) and `_arg1` (std).
for i in range(1000):
    optim.zero_grad()
    targets = target.sample((100,))
    nlogp = -normal(targets).mean()
    if (i+1) % 100 == 0:
        # Reporting only: evaluate the current mean and (transformed) std just
        # on the steps that print, before this step's update is applied.
        mean = normal._arg0.detach()
        std = normal._arg1.value.detach()
        print(f'Neg logP: {nlogp}, mean={mean}, std={std}')
    nlogp.backward()
    optim.step()

Neg logP: 4.551199436187744, mean=tensor([6.3952]), std=tensor([36.8402])
Neg logP: 3.6840832233428955, mean=tensor([5.9138]), std=tensor([13.3099])
Neg logP: 3.5708117485046387, mean=tensor([4.9872]), std=tensor([8.0710])
Neg logP: 3.582395076751709, mean=tensor([5.1409]), std=tensor([8.0214])
Neg logP: 3.56783390045166, mean=tensor([5.0813]), std=tensor([7.7231])
Neg logP: 3.5683016777038574, mean=tensor([5.0931]), std=tensor([8.2432])
Neg logP: 3.538731575012207, mean=tensor([4.8726]), std=tensor([8.0664])
Neg logP: 3.4939463138580322, mean=tensor([5.1158]), std=tensor([7.8221])
Neg logP: 3.5021615028381348, mean=tensor([5.0405]), std=tensor([7.8277])
Neg logP: 3.493650436401367, mean=tensor([5.1432]), std=tensor([8.1597])


# Conditional case

## Simple conditioning

Now let us learn a more complex relationship: the conditional distribution $p(z|x)$ (in the code below, $z$ is the variable `y`). Specifically, let $p(z|x) = \mathcal{N}(f(x), \sigma^2)$.

For simplicity, we'll assume a simple linear mapping for $f(x)$

Prepare data generator

In [25]:
def get_batch(batch_size):
    """Draw one batch of (x, y) pairs from the ground-truth conditional model.

    x is uniform on [0, 1) and y | x ~ Normal(-5 * x + 9, 7).

    Args:
        batch_size: number of (x, y) pairs to draw.

    Returns:
        Tuple ``(x, y)`` of tensors, each of shape ``(batch_size, 1)``.
    """
    x = torch.rand((batch_size, 1))
    mu = -5 * x + 9
    # `mu` already carries the batch dimension, so draw exactly one y per x.
    # Passing sample_shape=(batch_size,) here would instead return a
    # (batch_size, batch_size, 1) tensor, i.e. batch_size draws for every x.
    y = Normal(mu, scale=7).sample()
    return x, y

Set up the conditional network:

In [26]:
# make the mean & std both learnable
# mean (loc) is an nn.Linear module rather than a tensor; presumably the adapter
# evaluates it on the `cond` input supplied at call time -- TODO confirm
# std (scale) is set to be abs of a parameter, thus ensuring a non-negative value
normal = TrainableDistributionAdapter(Normal, 
                               loc=nn.Linear(1, 1), 
                               scale=TransformedParameter(torch.Tensor([1.0]), torch.abs))

In [27]:
optim = Adam(normal.parameters(), lr=1.0)

In [28]:
# Fit the conditional model on fresh (x, y) batches by minimising the quantity
# reported below as the negative log-probability of y given x.
for i in range(1000):
    optim.zero_grad()
    x, y = get_batch(100)
    nlogp = -normal(y, cond=x).mean()
    if (i+1) % 100 == 0:
        # Reporting only: snapshot the state dict just on the steps that print
        # instead of rebuilding it on every iteration.
        params = {k:v.detach() for k,v in normal.state_dict().items()}
        print(f'Neg logP: {nlogp}, params={params}')
    nlogp.backward()
    optim.step()

Neg logP: 3.4037044048309326, params={'loc.weight': tensor([[0.5491]]), 'loc.bias': tensor([5.8328]), 'scale.parameter': tensor([7.8711])}
Neg logP: 3.381765604019165, params={'loc.weight': tensor([[-1.3447]]), 'loc.bias': tensor([6.9216]), 'scale.parameter': tensor([7.5908])}
Neg logP: 3.3738582134246826, params={'loc.weight': tensor([[-3.0239]]), 'loc.bias': tensor([7.8740]), 'scale.parameter': tensor([7.3113])}
Neg logP: 3.354703187942505, params={'loc.weight': tensor([[-4.1132]]), 'loc.bias': tensor([8.4775]), 'scale.parameter': tensor([7.1333])}
Neg logP: 3.3553593158721924, params={'loc.weight': tensor([[-4.6604]]), 'loc.bias': tensor([8.8195]), 'scale.parameter': tensor([7.0414])}
Neg logP: 3.3557655811309814, params={'loc.weight': tensor([[-4.8893]]), 'loc.bias': tensor([8.9620]), 'scale.parameter': tensor([7.0145])}
Neg logP: 3.3629565238952637, params={'loc.weight': tensor([[-4.9671]]), 'loc.bias': tensor([8.9659]), 'scale.parameter': tensor([7.0023])}
Neg logP: 3.37320590019

## One network with multiple outputs (returning dict):

In [29]:
class NormalParams(nn.Module):
    """Single linear layer that emits both Normal parameters as a dict.

    The layer maps a 1-feature input to two outputs per row: the first is
    returned unchanged as ``loc``; the second is squared and returned as
    ``scale``.
    """

    def __init__(self):
        super().__init__()
        self.core = nn.Linear(1, 2)

    def forward(self, x):
        """Return ``{'loc': ..., 'scale': ...}``, each keeping a trailing dim of size 1."""
        raw = self.core(x)
        loc = raw[:, :1]
        scale = raw[:, 1:].pow(2)
        return {'loc': loc, 'scale': scale}

In [30]:
def get_batch(batch_size):
    """Draw one batch from the ground-truth model with input-dependent noise.

    x is uniform on [0, 1) and y | x ~ Normal(-5 * x + 9, 3 * x + 1).

    Args:
        batch_size: number of (x, y) pairs to draw.

    Returns:
        Tuple ``(x, y, log_prob)`` of tensors, each of shape
        ``(batch_size, 1)``; ``log_prob`` is the ground-truth log-density of
        each y given its x.
    """
    x = torch.rand((batch_size, 1))
    mu = -5 * x + 9
    scale = (3 * x + 1)
    model = Normal(mu, scale=scale)
    # `mu` and `scale` already carry the batch dimension, so draw exactly one
    # y per x. Passing sample_shape=(batch_size,) would instead return a
    # (batch_size, batch_size, 1) tensor, i.e. batch_size draws for every x.
    y = model.sample()
    return x, y, model.log_prob(y)

In [31]:
# mean & std are both produced by a single network (NormalParams), passed via `_parameters`
# std (scale) is the square of the network's second output, thus ensuring a non-negative value
normal = TrainableDistributionAdapter(Normal, _parameters=NormalParams())

In [32]:
optim = Adam(normal.parameters(), lr=1)

In [33]:
# Fit the parameter network on fresh batches; the ground-truth value printed
# as `gt` is the mean log-density returned by get_batch for the same batch.
for i in range(1000):
    optim.zero_grad()
    x, y, true_logp = get_batch(100)
    nlogp = -normal(y, cond=x).mean()
    if i % 100 == 0:
        # Reporting only: snapshot the state dict just on the steps that print
        # instead of rebuilding it on every iteration.
        params = {k:v.detach() for k,v in normal.state_dict().items()}
        print(f'Neg logP: {nlogp:0.3f} (gt={-true_logp.mean():0.3f}), params={params}')
    nlogp.backward()
    optim.step()

Neg logP: 879.588 (gt=2.240), params={'parameter_generator.core.weight': tensor([[-0.3289],
        [-0.2346]]), 'parameter_generator.core.bias': tensor([ 0.6627, -0.3280])}
Neg logP: 5.413 (gt=2.316), params={'parameter_generator.core.weight': tensor([[ 5.6894],
        [-6.2071]]), 'parameter_generator.core.bias': tensor([ 6.6866, -6.3188])}
Neg logP: 5.295 (gt=2.197), params={'parameter_generator.core.weight': tensor([[ 5.6862],
        [-6.1472]]), 'parameter_generator.core.bias': tensor([ 6.6860, -6.2889])}
Neg logP: 5.308 (gt=2.240), params={'parameter_generator.core.weight': tensor([[ 5.6816],
        [-6.0674]]), 'parameter_generator.core.bias': tensor([ 6.6849, -6.2484])}
Neg logP: 5.273 (gt=2.231), params={'parameter_generator.core.weight': tensor([[ 5.6757],
        [-5.9698]]), 'parameter_generator.core.bias': tensor([ 6.6836, -6.1987])}
Neg logP: 5.364 (gt=2.331), params={'parameter_generator.core.weight': tensor([[ 5.6686],
        [-5.8548]]), 'parameter_generator.core.b

## One network with multiple outputs (returning positionally):

In [34]:
class NormalParams(nn.Module):
    """Single linear layer that emits both Normal parameters as a tuple.

    The layer maps a 1-feature input to two outputs per row: the first is
    returned unchanged as the first tuple element (loc); the second is
    squared and returned as the second element (scale).
    """

    def __init__(self):
        super().__init__()
        self.core = nn.Linear(1, 2)

    def forward(self, x):
        """Return ``(loc, scale)``, each keeping a trailing dim of size 1."""
        raw = self.core(x)
        loc = raw[:, :1]
        scale = raw[:, 1:].pow(2)
        return loc, scale

In [35]:
def get_batch(batch_size):
    """Draw one batch from the ground-truth model with input-dependent noise.

    x is uniform on [0, 1) and y | x ~ Normal(-5 * x + 9, 3 * x + 1).

    Args:
        batch_size: number of (x, y) pairs to draw.

    Returns:
        Tuple ``(x, y, log_prob)`` of tensors, each of shape
        ``(batch_size, 1)``; ``log_prob`` is the ground-truth log-density of
        each y given its x.
    """
    x = torch.rand((batch_size, 1))
    mu = -5 * x + 9
    scale = (3 * x + 1)
    model = Normal(mu, scale=scale)
    # `mu` and `scale` already carry the batch dimension, so draw exactly one
    # y per x. Passing sample_shape=(batch_size,) would instead return a
    # (batch_size, batch_size, 1) tensor, i.e. batch_size draws for every x.
    y = model.sample()
    return x, y, model.log_prob(y)

In [36]:
# mean & std are both produced by a single network (NormalParams), passed via `_parameters`
# here the network returns them positionally as a (loc, scale) tuple
# std (scale) is the square of the network's second output, thus ensuring a non-negative value
normal = TrainableDistributionAdapter(Normal, _parameters=NormalParams())

In [37]:
optim = Adam(normal.parameters(), lr=1)

In [38]:
# Fit the parameter network on fresh batches; the ground-truth value printed
# as `gt` is the mean log-density returned by get_batch for the same batch.
for i in range(1000):
    optim.zero_grad()
    x, y, true_logp = get_batch(100)
    nlogp = -normal(y, cond=x).mean()
    if i % 100 == 0:
        # Reporting only: snapshot the state dict just on the steps that print
        # instead of rebuilding it on every iteration.
        params = {k:v.detach() for k,v in normal.state_dict().items()}
        print(f'Neg logP: {nlogp:0.3f} (gt={-true_logp.mean():0.3f}), params={params}')
    nlogp.backward()
    optim.step()

Neg logP: 3234.762 (gt=2.269), params={'parameter_generator.core.weight': tensor([[-0.6911],
        [-0.8402]]), 'parameter_generator.core.bias': tensor([0.3904, 0.9635])}
Neg logP: 5.410 (gt=2.277), params={'parameter_generator.core.weight': tensor([[5.3150],
        [5.1641]]), 'parameter_generator.core.bias': tensor([6.3980, 6.9668])}
Neg logP: 5.463 (gt=2.332), params={'parameter_generator.core.weight': tensor([[5.3151],
        [5.1626]]), 'parameter_generator.core.bias': tensor([6.3982, 6.9634])}
Neg logP: 5.388 (gt=2.249), params={'parameter_generator.core.weight': tensor([[5.3148],
        [5.1603]]), 'parameter_generator.core.bias': tensor([6.3979, 6.9586])}
Neg logP: 5.361 (gt=2.222), params={'parameter_generator.core.weight': tensor([[5.3144],
        [5.1575]]), 'parameter_generator.core.bias': tensor([6.3976, 6.9528])}
Neg logP: 5.417 (gt=2.289), params={'parameter_generator.core.weight': tensor([[5.3139],
        [5.1542]]), 'parameter_generator.core.bias': tensor([6.397

# Test sampling and joint

In [39]:
# prior over x: Normal with mean 5 (wrapped in nn.Parameter) and std 2 (plain tensor)
prior = TrainableDistributionAdapter(Normal, loc=nn.Parameter(torch.Tensor([5])),
                                     scale=torch.Tensor([2]))

# conditional mean is the linear map -2 * x + 6: weight and bias are overwritten by hand
linear = nn.Linear(1, 1)
linear.weight.data = torch.Tensor([[-2]])
linear.bias.data = torch.Tensor([6])
# y given x: Normal with mean `linear` and std 1
# (presumably the adapter evaluates `linear` on the sampled x -- TODO confirm)
conditional = TrainableDistributionAdapter(Normal, loc=linear, scale=torch.Tensor([1]))

# create a joint distribution out of prior and conditional
joint = Joint(prior, conditional)

In [40]:
x, y = joint.sample((10000,))

Should be 5, 2

In [41]:
x.mean(), x.std()

(tensor(5.0032), tensor(2.0049))

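Should be approximately -4 and $\sqrt{17} \approx 4.12$ (mean $= -2 \cdot 5 + 6 = -4$; variance $= (-2)^2 \cdot 2^2 + 1^2 = 17$)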

In [42]:
y.mean(), y.std()

(tensor(-4.0130), tensor(4.1541))

# Simulating a multi-dimensional joint distribution

### Define the ground-truth generative model

In [47]:
n_latents = 5
# ground-truth prior over the latents: mean vector of ones, identity covariance
gt_prior = TrainableDistributionAdapter(MultivariateNormal, 
                                        torch.ones([n_latents]),
                                        torch.eye(n_latents))

n_obs = 10
# random feature matrix, one row per latent dimension
features = torch.randn([n_latents, n_obs])

# linear map latents -> observations with weight = features.T (n_obs x n_latents) and zero bias
feature_map = nn.Linear(n_latents, n_obs)
feature_map.weight.data = features.T
feature_map.bias.data.zero_()

# ground-truth conditional over the observations: mean given by feature_map, identity covariance
# (presumably the adapter evaluates feature_map on the latent sample -- TODO confirm)
gt_conditional = TrainableDistributionAdapter(MultivariateNormal,
                                              feature_map,
                                              torch.eye(n_obs))

gt_joint = Joint(gt_prior, gt_conditional)

### Now prepare a trainable model

class Covariance(nn.Module):
    def __init__(self, n_dim):
        

In [48]:
n_latents = 5
# Trainable prior over the latents: the mean is an nn.Parameter initialised at
# zero and the covariance comes from PositiveDiagonal (it prints as a diagonal
# matrix in the "Checking the learned parameters" cells).
model_prior = TrainableDistributionAdapter(MultivariateNormal, 
                                        nn.Parameter(torch.zeros([n_latents])),
                                        PositiveDiagonal(n_latents))

n_obs = 10
# Unlike the ground-truth model, the latent -> observation map keeps
# nn.Linear's default random initialisation, so it has to be learned.
feature_map = nn.Linear(n_latents, n_obs)

# The observation covariance is a plain identity tensor, not an nn.Parameter.
model_conditional = TrainableDistributionAdapter(MultivariateNormal,
                                              feature_map,
                                              torch.eye(n_obs))

model_joint = Joint(model_prior, model_conditional)

### Go ahead and train the joint model

In [49]:
optim = Adam(model_joint.parameters(), lr=0.5)

In [50]:
# Maximum-likelihood fit of the trainable joint on samples drawn from the
# ground-truth joint.
for i in range(1000):
    optim.zero_grad()
    samples = gt_joint.sample((100,))
    model_logl = model_joint.log_prob(*samples).mean()
    if i % 100 == 0:
        # The ground-truth log-likelihood is only a reference number for the
        # printout, so evaluate it just on reporting steps and without
        # recording an autograd graph.
        with torch.no_grad():
            gt_logl = gt_joint.log_prob(*samples).mean()
        print(f'Model logp: {model_logl:.3f} / GT logp: {gt_logl:.3f}')
    (-model_logl).backward()
    optim.step()

Model logp: -115.531 / GT logp: -21.522
Model logp: -21.543 / GT logp: -21.270
Model logp: -21.440 / GT logp: -21.152
Model logp: -21.756 / GT logp: -21.236
Model logp: -21.606 / GT logp: -21.260
Model logp: -21.964 / GT logp: -21.495
Model logp: -21.295 / GT logp: -21.023
Model logp: -21.824 / GT logp: -21.071
Model logp: -21.861 / GT logp: -21.264
Model logp: -21.820 / GT logp: -21.440


Checking the learned parameters

In [51]:
model_joint.prior._arg1()

tensor([[0.9212, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.2304, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.5738, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.0520, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.0407]], grad_fn=<DiagBackward0>)

In [52]:
model_joint.prior.state_dict()

OrderedDict([('_arg0', tensor([0.8332, 0.8988, 0.9068, 0.8510, 0.7687])),
             ('_arg1.D',
              tensor([-0.9598, -1.1092, -1.2545,  1.0257, -1.0201]))])

In [53]:
model_joint.conditional.state_dict()

OrderedDict([('_arg1',
              tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                      [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
                      [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
                      [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
                      [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
                      [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
                      [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])),
             ('_arg0.weight',
              tensor([[-0.9050,  1.5812,  0.4306,  0.5078, -1.2745],
                      [ 0.0943,  1.1739, -1.0866, -0.3643,  2.3895],
                      [ 0.0312, -1.0371,  0.4130,  0.8667, -0.7674],
                      [ 0.3733,  0.1724,  1.1796,  0.7658,  1.0307],
                      [ 0.84

## Compute the posterior distribution via ELBO

First, we'll train the posterior for the ground-truth model, using ELBO

In [54]:
# prepare a posterior distribution
# mean: linear map from the observations (n_obs) to the latents (n_latents),
# left at nn.Linear's default random initialisation
linear_map = nn.Linear(n_obs, n_latents)

# covariance: provided by the Covariance module imported from gensn.parameters
posterior = TrainableDistributionAdapter(MultivariateNormal,
                                              linear_map,
                                              Covariance(n_latents, rank=5, eps=1e-6))


In [55]:
# ELBO objective built from the ground-truth joint and the trainable posterior
# (presumably the posterior acts as the variational distribution over z given x -- TODO confirm)
elbo_x = ELBOMarginal(gt_joint, posterior)

# only training the posterior
optim = Adam(posterior.parameters(), lr=1e-2)

In [56]:
# draw one fixed dataset up front; only x_sample is used in the loop below
z_sample, x_sample = gt_joint.sample((1000,))

# maximise the mean ELBO over the fixed x_sample (implemented as descent on -elbo)
for i in range(1000):
    optim.zero_grad()
    elbo = elbo_x(x_sample).mean()
    if i % 100 == 0:
        print(f'Model elbo: {elbo:.3f}')
    (-elbo).backward()
    optim.step()

Model elbo: -726.289
Model elbo: -42.695
Model elbo: -24.010
Model elbo: -20.559
Model elbo: -20.051
Model elbo: -19.918
Model elbo: -19.868
Model elbo: -19.838
Model elbo: -19.816
Model elbo: -19.827


### Evaluate the posterior
Here we will (roughly) evaluate how good the posterior is by sampling $\hat{z}$ from the learned posterior, which approximates $p(z|x)$, and evaluating the expected $\log p(\hat{z}, x)$.
If $\hat{z}$ approximates the true distribution over $z$, then the expected $\log p(\hat{z}, x)$ will closely approximate the expected $\log p(z, x)$, which is the negative of the joint entropy.

In [57]:
z_sample, x_sample = gt_joint.sample((1000,))

In [58]:
# ground-truth negative entropy
gt_joint.log_prob(z_sample, x_sample).mean()

tensor(-21.3407, grad_fn=<MeanBackward0>)

In [59]:
# sample from the trained posterior
z_hat = posterior.sample(cond=x_sample)

# negative entropy of the approximation:
# mean ground-truth joint log-probability of the pairs (z_hat, x_sample)
gt_joint.log_prob(z_hat, x_sample).mean()

tensor(-21.3697, grad_fn=<MeanBackward0>)

## Compute the posterior distribution via direct fit to samples

Now, we'll train the posterior for the ground-truth model by training directly on the samples

In [60]:
# prepare a posterior distribution
# mean: linear map from the observations (n_obs) to the latents (n_latents),
# left at nn.Linear's default random initialisation
linear_map = nn.Linear(n_obs, n_latents)

# covariance: provided by the Covariance module imported from gensn.parameters
posterior = TrainableDistributionAdapter(MultivariateNormal,
                                              linear_map,
                                              Covariance(n_latents, rank=5, eps=1e-6))


In [61]:
# only training the posterior
optim = Adam(posterior.parameters(), lr=1e-2)

In [62]:

# fit the posterior directly: each step draws fresh (z, x) pairs from the ground-truth
# joint and maximises the mean log-probability of z given x (descent on -logp)
for i in range(1000):
    optim.zero_grad()
    z_sample, x_sample = gt_joint.sample((100,))
    logp = posterior.log_prob(z_sample, cond=x_sample).mean()
    if i % 100 == 0:
        print(f'LogP: {logp:.3f}')
    (-logp).backward()
    optim.step()

LogP: -142.656
LogP: -7.445
LogP: -6.255
LogP: -4.870
LogP: -4.178
LogP: -3.914
LogP: -3.563
LogP: -3.272
LogP: -3.157
LogP: -2.958


### Evaluate the posterior
Here we will (roughly) evaluate how good the posterior is by sampling $\hat{z}$ from the learned posterior, which approximates $p(z|x)$, and evaluating the expected $\log p(\hat{z}, x)$.
If $\hat{z}$ approximates the true distribution over $z$, then the expected $\log p(\hat{z}, x)$ will closely approximate the expected $\log p(z, x)$, which is the negative of the joint entropy.

In [63]:
z_sample, x_sample = gt_joint.sample((1000,))

In [64]:
# ground-truth negative entropy
gt_joint.log_prob(z_sample, x_sample).mean()

tensor(-21.2337, grad_fn=<MeanBackward0>)

In [65]:
# sample from the trained posterior
z_hat = posterior.sample(cond=x_sample)

# negative entropy of the approximation:
# mean ground-truth joint log-probability of the pairs (z_hat, x_sample)
# NOTE(review): the recorded value (-37.6) is far below the ground-truth figure (-21.2),
# and LogP was still rising when training stopped -- this fit has probably not converged
gt_joint.log_prob(z_hat, x_sample).mean()

tensor(-37.6097, grad_fn=<MeanBackward0>)