In [10]:
from gensn.distributions import TrainableDistributionAdapter, Joint
from gensn.variational import ELBOMarginal
from gensn.parameters import TransformedParameter, PositiveDiagonal, Covariance

In [11]:
import torch
from torch import nn
from torch.optim import Adam
from torch.distributions import Normal, MultivariateNormal

## Example usage: learn a Normal distribution

### Learn only the mean

In [12]:
# Only the mean (loc) is learnable: it is wrapped in nn.Parameter.
# scale is passed as a plain tensor, so it is not among normal.parameters()
# and stays fixed at 2.0 during training.
normal = TrainableDistributionAdapter(Normal,
                                      loc=nn.Parameter(torch.tensor([0.0])),
                                      scale=torch.tensor([2.0]))

In [13]:
# the state dict lists both the learnable loc and the fixed scale tensor
normal.state_dict()

OrderedDict([('loc', tensor([0.])), ('scale', tensor([2.]))])

In [14]:
# materialize the parameter generator: only loc should show up as trainable
[*normal.parameters()]

[Parameter containing:
 tensor([0.], requires_grad=True)]

In [15]:
# target distribution the model should recover: N(5, 2^2)
target = Normal(loc=torch.tensor([5.0]), scale=torch.tensor([2.0]))

In [16]:
# Adam over the trainable parameters only (here just loc)
optim = Adam(params=normal.parameters(), lr=0.5)

In [17]:
for i in range(1000):
    optim.zero_grad()
    # fresh batch of 100 draws from the target at every step
    targets = target.sample((100,))
    # negative mean log-likelihood of the batch under the current model
    nlogp = -normal(targets).mean()
    if (i + 1) % 100 == 0:
        # snapshot of the current mean estimate, only needed when logging
        mean = normal.loc.detach()
        print(f'Neg logP: {nlogp}, mean={mean}')
    nlogp.backward()
    optim.step()

Neg logP: 2.0490012168884277, mean=tensor([4.9324])
Neg logP: 2.1136226654052734, mean=tensor([4.9247])
Neg logP: 2.0624942779541016, mean=tensor([5.0990])
Neg logP: 2.133841037750244, mean=tensor([5.0030])
Neg logP: 2.133676290512085, mean=tensor([4.9113])
Neg logP: 2.116976737976074, mean=tensor([5.1844])
Neg logP: 2.1371238231658936, mean=tensor([4.8856])
Neg logP: 2.0871145725250244, mean=tensor([5.2740])
Neg logP: 2.124019145965576, mean=tensor([5.1323])
Neg logP: 2.1224982738494873, mean=tensor([4.9836])


### Learn both mean and stdev

In [18]:
# Learn both parameters:
#  - loc is a raw nn.Parameter
#  - scale is the square of an underlying parameter, which keeps it non-negative
normal = TrainableDistributionAdapter(
    Normal,
    loc=nn.Parameter(torch.tensor([0.0])),
    scale=TransformedParameter(torch.tensor([1.0]), lambda raw: raw ** 2),
)

In [19]:
# target distribution the model should recover: N(5, 8^2)
target = Normal(loc=torch.tensor([5.0]), scale=torch.tensor([8.0]))

In [20]:
# Adam over loc and the raw (pre-square) scale parameter
optim = Adam(params=normal.parameters(), lr=1.0)

In [21]:
for i in range(1000):
    optim.zero_grad()
    targets = target.sample((100,))
    # negative mean log-likelihood of the batch under the current model
    nlogp = -normal(targets).mean()
    if (i + 1) % 100 == 0:
        # scale is read through the TransformedParameter's .value
        mean = normal.loc.detach()
        std = normal.scale.value.detach()
        print(f'Neg logP: {nlogp}, mean={mean}, std={std}')
    nlogp.backward()
    optim.step()

Neg logP: 4.549954891204834, mean=tensor([6.2335]), std=tensor([36.8581])
Neg logP: 3.7240371704101562, mean=tensor([5.9152]), std=tensor([14.1016])
Neg logP: 3.4181244373321533, mean=tensor([5.0096]), std=tensor([7.9622])
Neg logP: 3.5762038230895996, mean=tensor([5.0104]), std=tensor([7.8075])
Neg logP: 3.583933115005493, mean=tensor([4.8745]), std=tensor([7.9401])
Neg logP: 3.420821189880371, mean=tensor([5.1534]), std=tensor([8.1514])
Neg logP: 3.516340970993042, mean=tensor([5.0027]), std=tensor([8.2877])
Neg logP: 3.6439175605773926, mean=tensor([5.0950]), std=tensor([8.2710])
Neg logP: 3.349782943725586, mean=tensor([5.1073]), std=tensor([8.0211])
Neg logP: 3.498356819152832, mean=tensor([5.1375]), std=tensor([7.9661])


### Learn both mean and stdev (same as above, but specified positionally)

In [22]:
# Same model as above, but the parameters are passed positionally in
# Normal's (loc, scale) order; the adapter exposes them as _arg0 / _arg1.
# scale is the square of an underlying parameter, which keeps it non-negative.
normal = TrainableDistributionAdapter(
    Normal,
    nn.Parameter(torch.tensor([0.0])),
    TransformedParameter(torch.tensor([1.0]), lambda raw: raw ** 2),
)

In [23]:
# target distribution the model should recover: N(5, 8^2)
target = Normal(loc=torch.tensor([5.0]), scale=torch.tensor([8.0]))

In [24]:
# Adam over both positional parameters
optim = Adam(params=normal.parameters(), lr=1.0)

In [25]:
for i in range(1000):
    optim.zero_grad()
    targets = target.sample((100,))
    # negative mean log-likelihood of the batch under the current model
    nlogp = -normal(targets).mean()
    if (i + 1) % 100 == 0:
        # positional arguments are exposed as _arg0 (loc) and _arg1 (scale)
        mean = normal._arg0.detach()
        std = normal._arg1.value.detach()
        print(f'Neg logP: {nlogp}, mean={mean}, std={std}')
    nlogp.backward()
    optim.step()

Neg logP: 4.557194709777832, mean=tensor([6.3509]), std=tensor([37.3025])
Neg logP: 3.725822687149048, mean=tensor([5.8476]), std=tensor([14.1823])
Neg logP: 3.412493944168091, mean=tensor([5.0580]), std=tensor([8.1510])
Neg logP: 3.450681447982788, mean=tensor([4.7119]), std=tensor([7.9666])
Neg logP: 3.5525712966918945, mean=tensor([4.7050]), std=tensor([8.1479])
Neg logP: 3.6055593490600586, mean=tensor([5.1155]), std=tensor([8.0523])
Neg logP: 3.575485944747925, mean=tensor([4.8290]), std=tensor([7.9291])
Neg logP: 3.5787813663482666, mean=tensor([4.8853]), std=tensor([7.9831])
Neg logP: 3.444620370864868, mean=tensor([4.5961]), std=tensor([7.7306])
Neg logP: 3.4562206268310547, mean=tensor([5.1633]), std=tensor([7.7799])


# Conditional case

## Simple conditioning

Now let us learn a more complex, conditional relationship $p(y|x)$. Specifically, let $p(y|x) = \mathcal{N}(f(x), \sigma^2)$.

For simplicity, we'll assume $f(x)$ is linear; the data below are generated with $f(x) = -5x + 9$ and $\sigma = 7$.

Prepare data generator

In [26]:
def get_batch(batch_size):
    """Draw `batch_size` (x, y) pairs with x ~ U[0, 1) and y | x ~ N(-5x + 9, 7^2).

    Returns:
        x, y: tensors of shape (batch_size, 1).
    """
    x = torch.rand((batch_size, 1))
    mu = -5 * x + 9
    # mu already has one row per x, so draw a single sample per row.
    # (.sample((batch_size,)) would prepend another batch dimension and return
    # a (batch_size, batch_size, 1) tensor, i.e. batch_size**2 samples.)
    y = Normal(mu, scale=7).sample()
    return x, y

Set up the conditional network:

In [27]:
# Conditional model p(y | x):
#  - loc is an affine function of x (nn.Linear(1, 1))
#  - scale is |theta| for a learnable theta, which keeps it non-negative
normal = TrainableDistributionAdapter(
    Normal,
    loc=nn.Linear(1, 1),
    scale=TransformedParameter(torch.tensor([1.0]), torch.abs),
)

In [28]:
# Adam over the Linear layer's weight/bias and the raw scale parameter
optim = Adam(params=normal.parameters(), lr=1.0)

In [29]:
for i in range(1000):
    optim.zero_grad()
    x, y = get_batch(100)
    # negative mean conditional log-likelihood log p(y | x)
    nlogp = -normal(y, cond=x).mean()
    if (i + 1) % 100 == 0:
        # snapshot of every stored value, only needed when logging
        params = {name: value.detach() for name, value in normal.state_dict().items()}
        print(f'Neg logP: {nlogp}, params={params}')
    nlogp.backward()
    optim.step()

Neg logP: 3.406355142593384, params={'loc.weight': tensor([[0.6536]]), 'loc.bias': tensor([5.7188]), 'scale.parameter': tensor([7.9252])}
Neg logP: 3.37754225730896, params={'loc.weight': tensor([[-1.3806]]), 'loc.bias': tensor([6.9814]), 'scale.parameter': tensor([7.6032])}
Neg logP: 3.378467559814453, params={'loc.weight': tensor([[-3.2067]]), 'loc.bias': tensor([7.9680]), 'scale.parameter': tensor([7.2977])}
Neg logP: 3.3775687217712402, params={'loc.weight': tensor([[-4.3148]]), 'loc.bias': tensor([8.5797]), 'scale.parameter': tensor([7.1092])}
Neg logP: 3.3665573596954346, params={'loc.weight': tensor([[-4.7426]]), 'loc.bias': tensor([8.8968]), 'scale.parameter': tensor([7.0250])}
Neg logP: 3.3628203868865967, params={'loc.weight': tensor([[-4.9615]]), 'loc.bias': tensor([8.9708]), 'scale.parameter': tensor([7.0068])}
Neg logP: 3.3707962036132812, params={'loc.weight': tensor([[-4.9755]]), 'loc.bias': tensor([9.0007]), 'scale.parameter': tensor([7.0055])}
Neg logP: 3.3539960384368

## One network with multiple outputs (returning dict):

In [30]:
class NormalParams(nn.Module):
    """Map x to Normal's keyword arguments as {'loc': ..., 'scale': ...}.

    A single nn.Linear(1, 2) produces two columns: the first is used as loc,
    and the square of the second as scale (squaring keeps scale non-negative).
    """

    def __init__(self):
        super().__init__()
        self.core = nn.Linear(1, 2)

    def forward(self, x):
        out = self.core(x)
        loc = out[:, :1]
        scale = out[:, 1:].pow(2)
        return {'loc': loc, 'scale': scale}

In [31]:
def get_batch(batch_size):
    """Draw `batch_size` (x, y) pairs with input-dependent noise.

    x ~ U[0, 1) and y | x ~ N(-5x + 9, (3x + 1)^2).

    Returns:
        x, y, logp: tensors of shape (batch_size, 1). logp is the ground-truth
        log p(y | x), used as a reference for the learned model.
    """
    x = torch.rand((batch_size, 1))
    mu = -5 * x + 9
    scale = 3 * x + 1
    model = Normal(mu, scale=scale)
    # mu/scale already carry the batch dimension, so draw one sample per row;
    # .sample((batch_size,)) would return a (batch_size, batch_size, 1) tensor.
    y = model.sample()
    return x, y, model.log_prob(y)

In [32]:
# Both loc and scale come from the NormalParams network, which returns them as
# a dict of keyword arguments. scale is the square of a network output (not of
# a free parameter), so it is non-negative.
normal = TrainableDistributionAdapter(Normal, _parameters=NormalParams())

In [33]:
# Adam over the parameter-generating network's weights
optim = Adam(params=normal.parameters(), lr=1)

In [34]:
for i in range(1000):
    optim.zero_grad()
    x, y, true_logp = get_batch(100)
    # negative mean conditional log-likelihood under the learned model
    nlogp = -normal(y, cond=x).mean()
    if i % 100 == 0:
        # compare against the ground-truth negative log-likelihood of the same batch
        params = {name: value.detach() for name, value in normal.state_dict().items()}
        print(f'Neg logP: {nlogp:0.3f} (gt={-true_logp.mean():0.3f}), params={params}')
    nlogp.backward()
    optim.step()

Neg logP: 652.502 (gt=2.303), params={'parameter_generator.core.weight': tensor([[-0.3438],
        [-0.8978]]), 'parameter_generator.core.bias': tensor([ 0.9711, -0.2252])}
Neg logP: 5.496 (gt=2.332), params={'parameter_generator.core.weight': tensor([[ 5.6848],
        [-6.8267]]), 'parameter_generator.core.bias': tensor([ 7.0001, -6.2144])}
Neg logP: 5.469 (gt=2.330), params={'parameter_generator.core.weight': tensor([[ 5.6762],
        [-6.6893]]), 'parameter_generator.core.bias': tensor([ 6.9992, -6.1804])}
Neg logP: 5.387 (gt=2.290), params={'parameter_generator.core.weight': tensor([[ 5.6639],
        [-6.5045]]), 'parameter_generator.core.bias': tensor([ 6.9974, -6.1344])}
Neg logP: 5.334 (gt=2.267), params={'parameter_generator.core.weight': tensor([[ 5.6478],
        [-6.2751]]), 'parameter_generator.core.bias': tensor([ 6.9950, -6.0778])}
Neg logP: 5.280 (gt=2.261), params={'parameter_generator.core.weight': tensor([[ 5.6273],
        [-6.0016]]), 'parameter_generator.core.b

## One network with multiple outputs (returning positionally):

In [35]:
class NormalParams(nn.Module):
    """Map x to Normal's positional arguments as a (loc, scale) tuple.

    A single nn.Linear(1, 2) produces two columns: the first is used as loc,
    and the square of the second as scale (squaring keeps scale non-negative).
    """

    def __init__(self):
        super().__init__()
        self.core = nn.Linear(1, 2)

    def forward(self, x):
        out = self.core(x)
        loc = out[:, :1]
        scale = out[:, 1:].pow(2)
        return loc, scale

In [36]:
def get_batch(batch_size):
    """Draw `batch_size` (x, y) pairs with input-dependent noise.

    x ~ U[0, 1) and y | x ~ N(-5x + 9, (3x + 1)^2).

    Returns:
        x, y, logp: tensors of shape (batch_size, 1). logp is the ground-truth
        log p(y | x), used as a reference for the learned model.
    """
    x = torch.rand((batch_size, 1))
    mu = -5 * x + 9
    scale = 3 * x + 1
    model = Normal(mu, scale=scale)
    # mu/scale already carry the batch dimension, so draw one sample per row;
    # .sample((batch_size,)) would return a (batch_size, batch_size, 1) tensor.
    y = model.sample()
    return x, y, model.log_prob(y)

In [37]:
# Both loc and scale come from the NormalParams network, this time returned as a
# (loc, scale) tuple that is passed positionally. scale is the square of a
# network output (not of a free parameter), so it is non-negative.
normal = TrainableDistributionAdapter(Normal, _parameters=NormalParams())

In [38]:
# Adam over the parameter-generating network's weights
optim = Adam(params=normal.parameters(), lr=1)

In [39]:
for i in range(1000):
    optim.zero_grad()
    x, y, true_logp = get_batch(100)
    # negative mean conditional log-likelihood under the learned model
    nlogp = -normal(y, cond=x).mean()
    if i % 100 == 0:
        # compare against the ground-truth negative log-likelihood of the same batch
        params = {name: value.detach() for name, value in normal.state_dict().items()}
        print(f'Neg logP: {nlogp:0.3f} (gt={-true_logp.mean():0.3f}), params={params}')
    nlogp.backward()
    optim.step()

Neg logP: 3160.783 (gt=2.258), params={'parameter_generator.core.weight': tensor([[0.7723],
        [0.6880]]), 'parameter_generator.core.bias': tensor([-0.7084, -0.8293])}
Neg logP: 5.407 (gt=2.285), params={'parameter_generator.core.weight': tensor([[ 6.7786],
        [-5.3161]]), 'parameter_generator.core.bias': tensor([ 5.3003, -6.8324])}
Neg logP: 5.445 (gt=2.324), params={'parameter_generator.core.weight': tensor([[ 6.7787],
        [-5.3145]]), 'parameter_generator.core.bias': tensor([ 5.3005, -6.8287])}
Neg logP: 5.390 (gt=2.273), params={'parameter_generator.core.weight': tensor([[ 6.7784],
        [-5.3119]]), 'parameter_generator.core.bias': tensor([ 5.3005, -6.8234])}
Neg logP: 5.346 (gt=2.231), params={'parameter_generator.core.weight': tensor([[ 6.7780],
        [-5.3087]]), 'parameter_generator.core.bias': tensor([ 5.3005, -6.8169])}
Neg logP: 5.402 (gt=2.274), params={'parameter_generator.core.weight': tensor([[ 6.7776],
        [-5.3051]]), 'parameter_generator.core.bi

# Test sampling and joint

In [40]:
# prior: x ~ N(5, 2^2), with a learnable mean and a fixed scale
prior = TrainableDistributionAdapter(Normal,
                                     loc=nn.Parameter(torch.tensor([5.0])),
                                     scale=torch.tensor([2.0]))

# conditional: y | x ~ N(-2x + 6, 1); the mean comes from a Linear layer whose
# weight/bias are set explicitly (writes done under no_grad, in place)
linear = nn.Linear(1, 1)
with torch.no_grad():
    linear.weight.copy_(torch.tensor([[-2.0]]))
    linear.bias.copy_(torch.tensor([6.0]))
conditional = TrainableDistributionAdapter(Normal, loc=linear, scale=torch.tensor([1.0]))

# joint distribution built from the prior and the conditional
joint = Joint(prior, conditional)

In [41]:
# draw 10000 (x, y) pairs from the joint: x from the prior, y from the conditional given x
x, y = joint.sample((10000,))

Should be approximately 5 and 2 (the prior's mean and standard deviation)

In [42]:
# empirical mean / std of x; compare with the prior's loc=5 and scale=2
x.mean(), x.std()

(tensor(4.9957), tensor(1.9761))

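Should be approximately $-4$ and $\sqrt{17} \approx 4.12$: with $y = -2x + 6 + \epsilon$, $\epsilon \sim \mathcal{N}(0, 1)$, we get $\mathbb{E}[y] = -2 \cdot 5 + 6 = -4$ and $\mathrm{Var}[y] = (-2)^2 \cdot 2^2 + 1 = 17$.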

In [43]:
# empirical mean / std of y = -2x + 6 + noise: expected -4 and sqrt(4 * 2**2 + 1) = sqrt(17) ~ 4.12
y.mean(), y.std()

(tensor(-3.9767), tensor(4.0760))

# Simulating a multi-dimensional joint distribution

### Define the ground-truth generative model

In [44]:
# Ground-truth generative model over latents z (n_latents) and observations x (n_obs):
#   prior        z ~ N(1, I)
#   conditional  x | z ~ N(W z, I), W a fixed random (n_obs x n_latents) matrix, zero bias
n_latents = 5
gt_prior = TrainableDistributionAdapter(MultivariateNormal,
                                        torch.ones([n_latents]),
                                        torch.eye(n_latents))

n_obs = 10
features = torch.randn([n_latents, n_obs])

feature_map = nn.Linear(n_latents, n_obs)
with torch.no_grad():
    # copy (rather than alias via .data =) so the ground-truth weight does not
    # share storage with the `features` tensor
    feature_map.weight.copy_(features.T)
    feature_map.bias.zero_()
# The ground truth is never trained. Freezing it stops gradients from piling up
# in feature_map when later cells backpropagate through gt_joint (e.g. the ELBO fit).
feature_map.requires_grad_(False)

gt_conditional = TrainableDistributionAdapter(MultivariateNormal,
                                              feature_map,
                                              torch.eye(n_obs))

gt_joint = Joint(gt_prior, gt_conditional)

### Now prepare a trainable model

# The posterior covariance below uses the imported `gensn.parameters.Covariance`
# (constructed with `rank=` and `eps=`). No local Covariance class is defined
# here: an unfinished stub (empty __init__ body, a SyntaxError) would also have
# shadowed the imported class and broken those constructor calls.

In [45]:
# Trainable model with the same structure as the ground truth. It reuses
# n_latents / n_obs from the ground-truth cell so the shapes are guaranteed to match.
#   prior        z ~ N(mu, D): mu learnable (starts at 0), D from PositiveDiagonal
#   conditional  x | z ~ N(W z + b, I): W, b learnable, randomly initialized
model_prior = TrainableDistributionAdapter(MultivariateNormal,
                                           nn.Parameter(torch.zeros([n_latents])),
                                           PositiveDiagonal(n_latents))

feature_map = nn.Linear(n_latents, n_obs)

model_conditional = TrainableDistributionAdapter(MultivariateNormal,
                                                 feature_map,
                                                 torch.eye(n_obs))

model_joint = Joint(model_prior, model_conditional)

### Go ahead and train the joint model

In [46]:
# Adam over every trainable parameter of the model joint (prior + conditional)
optim = Adam(params=model_joint.parameters(), lr=0.5)

In [47]:
for i in range(1000):
    optim.zero_grad()
    # fresh batch of 100 (z, x) pairs from the ground-truth joint
    samples = gt_joint.sample((100,))
    # mean log-likelihood of the batch under the trained model and, as a reference, the ground truth
    model_logl = model_joint.log_prob(*samples).mean()
    gt_logl = gt_joint.log_prob(*samples).mean()
    if i % 100 == 0:
        print(f'Model logp: {model_logl:.3f} / GT logp: {gt_logl:.3f}')
    # maximize the model log-likelihood
    model_logl.neg().backward()
    optim.step()

Model logp: -204.999 / GT logp: -21.587
Model logp: -23.608 / GT logp: -21.388
Model logp: -22.945 / GT logp: -21.200
Model logp: -22.840 / GT logp: -21.416
Model logp: -22.374 / GT logp: -21.205
Model logp: -21.950 / GT logp: -20.999
Model logp: -22.480 / GT logp: -21.573
Model logp: -22.451 / GT logp: -21.343
Model logp: -21.975 / GT logp: -21.006
Model logp: -21.781 / GT logp: -21.408


Checking the learned parameters

In [48]:
# evaluate the prior's PositiveDiagonal module (its second positional argument, _arg1)
# to get the learned covariance matrix; the ground-truth prior covariance is the identity
model_joint.prior._arg1()

tensor([[0.8752, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.2800, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0109, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.6556, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.0013]], grad_fn=<DiagBackward0>)

In [49]:
# raw prior state: _arg0 is the learned mean (ground truth: all ones); _arg1.D holds the
# raw values PositiveDiagonal maps onto the diagonal (judging by the output above,
# by squaring them; NOTE(review): confirm in gensn.parameters)
model_joint.prior.state_dict()

OrderedDict([('_arg0', tensor([0.8684, 0.7502, 1.0640, 1.0752, 0.9583])),
             ('_arg1.D',
              tensor([ 0.9355,  1.1314, -1.0055, -0.8097, -1.0006]))])

In [50]:
# conditional state: _arg0.* are the learned Linear weight/bias of the mean map,
# _arg1 is the fixed identity covariance
model_joint.conditional.state_dict()

OrderedDict([('_arg1',
              tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                      [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
                      [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
                      [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
                      [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
                      [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
                      [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])),
             ('_arg0.weight',
              tensor([[-1.3747, -0.5574,  0.0672, -0.4904,  0.2244],
                      [-0.6747, -0.0906, -0.9939, -0.3891,  0.8754],
                      [ 0.6459,  1.1428,  0.3236, -0.0247, -0.9512],
                      [ 0.0231, -0.2420,  1.3344,  0.3168, -0.4036],
                      [ 1.00

## Compute the posterior distribution via ELBO

First, we'll train an approximate posterior for the ground-truth model by maximizing the ELBO; only the posterior's parameters are optimized.

In [51]:
# Approximate posterior q(z | x):
#  - mean is an affine function of x (nn.Linear(n_obs, n_latents))
#  - covariance comes from gensn's Covariance module (rank=5, eps=1e-6)
linear_map = nn.Linear(n_obs, n_latents)

posterior = TrainableDistributionAdapter(MultivariateNormal,
                                         linear_map,
                                         Covariance(n_latents, rank=5, eps=1e-6))


In [55]:
# ELBO of the ground-truth joint, using `posterior` as the variational distribution q(z | x)
elbo_x = ELBOMarginal(gt_joint, posterior)

# optimize the posterior only; the ground-truth joint is not given to the optimizer
optim = Adam(params=posterior.parameters(), lr=1e-2)

In [56]:
# fixed training set of 1000 draws; only the observations x are used by the ELBO
z_sample, x_sample = gt_joint.sample((1000,))

for i in range(1000):
    optim.zero_grad()
    # full-batch ELBO averaged over the fixed observations
    elbo = elbo_x(x_sample).mean()
    if i % 100 == 0:
        print(f'Model elbo: {elbo:.3f}')
    # gradient ascent on the ELBO
    elbo.neg().backward()
    optim.step()

Model elbo: -726.289
Model elbo: -42.695
Model elbo: -24.010
Model elbo: -20.559
Model elbo: -20.051
Model elbo: -19.918
Model elbo: -19.868
Model elbo: -19.838
Model elbo: -19.816
Model elbo: -19.827


### Evaluate the posterior
Here we roughly evaluate the posterior by sampling $\hat{z}$ from the learned approximate posterior $q(z|x)$ and computing the expected $\log p(\hat{z}, x)$ under the ground-truth joint.
If $q(z|x)$ approximates the true posterior $p(z|x)$ well, then $\mathbb{E}[\log p(\hat{z}, x)]$ will closely approximate $\mathbb{E}[\log p(z, x)]$, which is the negative entropy of the joint distribution.

In [57]:
# fresh evaluation samples (z, x) from the ground-truth joint
z_sample, x_sample = gt_joint.sample((1000,))

In [58]:
# Monte-Carlo estimate of E[log p(z, x)] under the ground-truth joint (its negative entropy)
gt_joint.log_prob(z_sample, x_sample).mean()

tensor(-21.3407, grad_fn=<MeanBackward0>)

In [59]:
# sample z_hat from the trained posterior, conditioned on the evaluation observations
z_hat = posterior.sample(cond=x_sample)

# score (z_hat, x) under the ground-truth joint; if the posterior is accurate,
# this should be close to the negative entropy computed in the previous cell
gt_joint.log_prob(z_hat, x_sample).mean()

tensor(-21.3697, grad_fn=<MeanBackward0>)

## Compute the posterior distribution via direct fit to samples

Now, we'll train the posterior for the ground-truth model directly on samples: we draw $(z, x)$ pairs from the ground-truth joint and maximize $\log q(z|x)$.

In [60]:
# Fresh approximate posterior q(z | x), same structure as in the ELBO section:
#  - mean is an affine function of x (nn.Linear(n_obs, n_latents))
#  - covariance comes from gensn's Covariance module (rank=5, eps=1e-6)
linear_map = nn.Linear(n_obs, n_latents)

posterior = TrainableDistributionAdapter(MultivariateNormal,
                                         linear_map,
                                         Covariance(n_latents, rank=5, eps=1e-6))


In [61]:
# optimize the posterior's parameters only
optim = Adam(params=posterior.parameters(), lr=1e-2)

In [62]:

for i in range(1000):
    optim.zero_grad()
    # fresh (z, x) pairs every step; fit q(z | x) by maximum likelihood on them
    z_sample, x_sample = gt_joint.sample((100,))
    logp = posterior.log_prob(z_sample, cond=x_sample).mean()
    if i % 100 == 0:
        print(f'LogP: {logp:.3f}')
    logp.neg().backward()
    optim.step()

LogP: -142.656
LogP: -7.445
LogP: -6.255
LogP: -4.870
LogP: -4.178
LogP: -3.914
LogP: -3.563
LogP: -3.272
LogP: -3.157
LogP: -2.958


### Evaluate the posterior
Here we roughly evaluate the posterior by sampling $\hat{z}$ from the learned approximate posterior $q(z|x)$ and computing the expected $\log p(\hat{z}, x)$ under the ground-truth joint.
If $q(z|x)$ approximates the true posterior $p(z|x)$ well, then $\mathbb{E}[\log p(\hat{z}, x)]$ will closely approximate $\mathbb{E}[\log p(z, x)]$, which is the negative entropy of the joint distribution.

In [63]:
# fresh evaluation samples (z, x) from the ground-truth joint
z_sample, x_sample = gt_joint.sample((1000,))

In [64]:
# Monte-Carlo estimate of E[log p(z, x)] under the ground-truth joint (its negative entropy)
gt_joint.log_prob(z_sample, x_sample).mean()

tensor(-21.2337, grad_fn=<MeanBackward0>)

In [65]:
# sample z_hat from the directly-fitted posterior, conditioned on the evaluation observations
z_hat = posterior.sample(cond=x_sample)

# score (z_hat, x) under the ground-truth joint; if the posterior is accurate,
# this should be close to the negative entropy computed in the previous cell
gt_joint.log_prob(z_hat, x_sample).mean()

tensor(-37.6097, grad_fn=<MeanBackward0>)