In [None]:
from gensn.distributions import TrainableDistributionAdapter, Joint
from gensn.variational import ELBOMarginal
from gensn.parameters import TransformedParameter, PositiveDiagonal, Covariance

In [None]:
import torch
from torch import nn
from torch.optim import Adam
from torch.distributions import Normal, MultivariateNormal

## Example usage: learn Normal distribution

### Learn only the mean

In [None]:
# Only the mean (loc) is wrapped in nn.Parameter here; the std (scale) is passed
# as a plain tensor with value 2.0.
normal = TrainableDistributionAdapter(
    Normal,
    loc=nn.Parameter(torch.Tensor([0.0])),
    scale=torch.Tensor([2.0]),
)

In [None]:
# Inspect the entries registered in the adapter's state dict.
normal.state_dict()

In [None]:
# List the parameters the adapter exposes (these are what the optimizer receives below).
list(normal.parameters())

In [None]:
# Target distribution to recover: Normal with mean 5 and std 2.
target = Normal(loc=torch.Tensor([5.0]), scale=torch.Tensor([2.0]))

In [None]:
# Adam over the adapter's parameters, with learning rate 0.5.
optim = Adam(normal.parameters(), lr=0.5)

In [None]:
# Fit the mean by minimising the average negative log-likelihood of fresh
# samples drawn from the target on every step.
for step in range(1, 1001):
    optim.zero_grad()
    targets = target.sample((100,))
    nlogp = -normal(targets).mean()
    if step % 100 == 0:
        # Read the mean before the update so it matches the loss being printed.
        mean = normal.loc.detach()
        print(f'Neg logP: {nlogp}, mean={mean}')
    nlogp.backward()
    optim.step()

### Learn both mean and stdev

In [None]:
# Both the mean and the std are learnable.
# The std (scale) is the square of an underlying parameter, which keeps it non-negative.
normal = TrainableDistributionAdapter(
    Normal,
    loc=nn.Parameter(torch.Tensor([0.0])),
    scale=TransformedParameter(torch.Tensor([1.0]), lambda x: x**2),
)

In [None]:
# Target distribution to recover: Normal with mean 5 and std 8.
target = Normal(loc=torch.Tensor([5]), scale=torch.Tensor([8]))

In [None]:
# Adam over the adapter's parameters, with learning rate 1.0.
optim = Adam(normal.parameters(), lr=1.0)

In [None]:
# Fit mean and std by minimising the average negative log-likelihood of fresh
# samples drawn from the target on every step.
for step in range(1, 1001):
    optim.zero_grad()
    targets = target.sample((100,))
    nlogp = -normal(targets).mean()
    if step % 100 == 0:
        # Read the current values before the update so they match the printed loss.
        mean = normal.loc.detach()
        # NOTE(review): `.value` is reported as the std here -- confirm it is the
        # transformed (squared) value rather than the raw parameter.
        std = normal.scale.value.detach()
        print(f'Neg logP: {nlogp}, mean={mean}, std={std}')
    nlogp.backward()
    optim.step()

### Learn both mean and stdev (same as above, but specified positionally)

In [None]:
# Same model as the keyword version, but loc and scale are given positionally.
# The std (scale) is the square of an underlying parameter, which keeps it non-negative.
normal = TrainableDistributionAdapter(
    Normal,
    nn.Parameter(torch.Tensor([0.0])),
    TransformedParameter(torch.Tensor([1.0]), lambda x: x**2),
)

In [None]:
# Target distribution to recover: Normal with mean 5 and std 8.
target = Normal(loc=torch.Tensor([5]), scale=torch.Tensor([8]))

In [None]:
# Adam over the adapter's parameters, with learning rate 1.0.
optim = Adam(normal.parameters(), lr=1.0)

In [None]:
# Same training loop as the keyword version; positional arguments are read back
# through the adapter's `_arg0` / `_arg1` attributes.
for step in range(1, 1001):
    optim.zero_grad()
    targets = target.sample((100,))
    nlogp = -normal(targets).mean()
    if step % 100 == 0:
        # Read the current values before the update so they match the printed loss.
        mean = normal._arg0.detach()
        std = normal._arg1.value.detach()
        print(f'Neg logP: {nlogp}, mean={mean}, std={std}')
    nlogp.backward()
    optim.step()

# Conditional case

## Simple conditioning

Now let us learn a more complex relationship $p(z|x)$. Specifically, let $p(z|x) = \mathcal{N}(f(x), \sigma^2)$.

For simplicity, we'll assume a simple linear mapping for $f(x)$

Prepare data generator

In [None]:
def get_batch(batch_size):
    """Draw one batch of paired (x, y) samples from the ground-truth model.

    x is uniform on [0, 1) with shape (batch_size, 1) and
    y ~ Normal(-5 * x + 9, 7), one draw per x, so y is also (batch_size, 1).

    Returns:
        Tuple (x, y).
    """
    x = torch.rand((batch_size, 1))
    mu = -5 * x + 9
    # mu already carries the batch dimension, so a single draw gives one y per x.
    # Passing (batch_size,) as sample_shape would return (batch_size, batch_size, 1).
    y = Normal(mu, scale=7).sample()
    return x, y

Set up the conditional network:

In [None]:
# The mean is produced by a linear layer applied to the conditioning input.
# The std (scale) is the absolute value of a parameter, which keeps it non-negative.
normal = TrainableDistributionAdapter(
    Normal,
    loc=nn.Linear(1, 1),
    scale=TransformedParameter(torch.Tensor([1.0]), torch.abs),
)

In [None]:
# Adam over the adapter's parameters, with learning rate 1.0.
optim = Adam(normal.parameters(), lr=1.0)

In [None]:
# Fit the conditional Normal by minimising the average negative log-likelihood
# of y given x on a fresh batch every step.
for step in range(1, 1001):
    optim.zero_grad()
    x, y = get_batch(100)
    nlogp = -normal(y, cond=x).mean()
    if step % 100 == 0:
        # Build the state-dict snapshot only on logging steps (it was previously
        # rebuilt on every iteration); taken before the update so it matches nlogp.
        params = {k: v.detach() for k, v in normal.state_dict().items()}
        print(f'Neg logP: {nlogp}, params={params}')
    nlogp.backward()
    optim.step()

## One network with multiple outputs (returning dict):

In [None]:
class NormalParams(nn.Module):
    """Single linear layer that emits both Normal arguments as a dict."""

    def __init__(self):
        super().__init__()
        # One input feature in, two outputs: column 0 -> loc, column 1 -> pre-square scale.
        self.core = nn.Linear(1, 2)

    def forward(self, x):
        """Return ``{'loc': ..., 'scale': ...}``; scale is squared to be non-negative."""
        out = self.core(x)
        loc = out[:, :1]
        scale = out[:, 1:].pow(2)
        return {'loc': loc, 'scale': scale}

In [None]:
def get_batch(batch_size):
    """Draw one batch of paired (x, y) samples with an input-dependent std.

    x is uniform on [0, 1) with shape (batch_size, 1) and
    y ~ Normal(-5 * x + 9, 3 * x + 1), one draw per x.

    Returns:
        Tuple (x, y, log_prob) where log_prob is the ground-truth log-density
        of each y under the generating model; all three are (batch_size, 1).
    """
    x = torch.rand((batch_size, 1))
    mu = -5 * x + 9
    scale = 3 * x + 1
    model = Normal(mu, scale=scale)
    # mu/scale already carry the batch dimension, so a single draw gives one y per x.
    # Passing (batch_size,) as sample_shape would return (batch_size, batch_size, 1).
    y = model.sample()
    return x, y, model.log_prob(y)

In [None]:
# A single NormalParams network supplies the distribution arguments via `_parameters`.
# Its forward returns a dict with `loc` and `scale`; scale is a squared network output,
# which keeps it non-negative.
normal = TrainableDistributionAdapter(Normal, _parameters=NormalParams())

In [None]:
# Adam over the adapter's parameters, with learning rate 1.
optim = Adam(normal.parameters(), lr=1)

In [None]:
# Fit the network-parameterised Normal and report its loss next to the
# ground-truth negative log-likelihood of the same batch.
for step in range(1000):
    optim.zero_grad()
    x, y, true_logp = get_batch(100)
    nlogp = -normal(y, cond=x).mean()
    if step % 100 == 0:
        # Build the state-dict snapshot only on logging steps (it was previously
        # rebuilt on every iteration); taken before the update so it matches nlogp.
        params = {k: v.detach() for k, v in normal.state_dict().items()}
        print(f'Neg logP: {nlogp:0.3f} (gt={-true_logp.mean():0.3f}), params={params}')
    nlogp.backward()
    optim.step()

## One network with multiple outputs (returning positionally):

In [None]:
class NormalParams(nn.Module):
    """Single linear layer that emits both Normal arguments as a tuple."""

    def __init__(self):
        super().__init__()
        # One input feature in, two outputs: column 0 -> loc, column 1 -> pre-square scale.
        self.core = nn.Linear(1, 2)

    def forward(self, x):
        """Return ``(loc, scale)`` positionally; scale is squared to be non-negative."""
        out = self.core(x)
        loc = out[:, :1]
        scale = out[:, 1:].pow(2)
        return loc, scale

In [None]:
def get_batch(batch_size):
    """Draw one batch of paired (x, y) samples with an input-dependent std.

    x is uniform on [0, 1) with shape (batch_size, 1) and
    y ~ Normal(-5 * x + 9, 3 * x + 1), one draw per x.

    Returns:
        Tuple (x, y, log_prob) where log_prob is the ground-truth log-density
        of each y under the generating model; all three are (batch_size, 1).
    """
    x = torch.rand((batch_size, 1))
    mu = -5 * x + 9
    scale = 3 * x + 1
    model = Normal(mu, scale=scale)
    # mu/scale already carry the batch dimension, so a single draw gives one y per x.
    # Passing (batch_size,) as sample_shape would return (batch_size, batch_size, 1).
    y = model.sample()
    return x, y, model.log_prob(y)

In [None]:
# A single NormalParams network supplies the distribution arguments via `_parameters`.
# Its forward returns a (loc, scale) tuple; scale is a squared network output,
# which keeps it non-negative.
normal = TrainableDistributionAdapter(Normal, _parameters=NormalParams())

In [None]:
# Adam over the adapter's parameters, with learning rate 1.
optim = Adam(normal.parameters(), lr=1)

In [None]:
# Fit the network-parameterised Normal and report its loss next to the
# ground-truth negative log-likelihood of the same batch.
for step in range(1000):
    optim.zero_grad()
    x, y, true_logp = get_batch(100)
    nlogp = -normal(y, cond=x).mean()
    if step % 100 == 0:
        # Build the state-dict snapshot only on logging steps (it was previously
        # rebuilt on every iteration); taken before the update so it matches nlogp.
        params = {k: v.detach() for k, v in normal.state_dict().items()}
        print(f'Neg logP: {nlogp:0.3f} (gt={-true_logp.mean():0.3f}), params={params}')
    nlogp.backward()
    optim.step()

# Test sampling and joint

In [None]:
# Prior: Normal with mean initialised at 5 (an nn.Parameter) and std given as the tensor 2.
prior = TrainableDistributionAdapter(
    Normal,
    loc=nn.Parameter(torch.Tensor([5])),
    scale=torch.Tensor([2]),
)

# Conditional mean is the linear map -2 * x + 6, with std given as the tensor 1.
linear = nn.Linear(1, 1)
# Overwrite the initial values in place under no_grad rather than rebinding `.data`,
# so the layer keeps its own parameter storage.
with torch.no_grad():
    linear.weight.fill_(-2.0)
    linear.bias.fill_(6.0)
conditional = TrainableDistributionAdapter(Normal, loc=linear, scale=torch.Tensor([1]))

# create a joint distribution out of prior and conditional
joint = Joint(prior, conditional)

In [None]:
# Draw 10,000 samples from the joint; the result unpacks into two tensors
# (presumably x from the prior and y from the conditional).
x, y = joint.sample((10000,))

Should be 5, 2

In [None]:
# Empirical mean and standard deviation of the x samples.
x.mean(), x.std()

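Should be approximately -4 and $\sqrt{17} \approx 4.12$ (mean $-2 \cdot 5 + 6 = -4$; variance $(-2)^2 \cdot 2^2 + 1^2 = 17$)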

In [None]:
# Empirical mean and standard deviation of the y samples.
y.mean(), y.std()

# Simulating a multi-dimensional joint distribution

### Define the ground-truth generative model

In [None]:
n_latents = 5
# Ground-truth prior over the latents: mean of ones and identity covariance,
# both passed as plain tensors.
gt_prior = TrainableDistributionAdapter(
    MultivariateNormal,
    torch.ones([n_latents]),
    torch.eye(n_latents),
)

n_obs = 10
features = torch.randn([n_latents, n_obs])

# Linear map from latents to observation means: weight = features.T, bias = 0.
feature_map = nn.Linear(n_latents, n_obs)
with torch.no_grad():
    # copy_ writes the values into the layer's own (contiguous) weight instead of
    # rebinding `.data` to a transposed view that aliases `features`.
    feature_map.weight.copy_(features.T)
    feature_map.bias.zero_()

# Observations given latents: mean from feature_map, identity covariance.
gt_conditional = TrainableDistributionAdapter(
    MultivariateNormal,
    feature_map,
    torch.eye(n_obs),
)

gt_joint = Joint(gt_prior, gt_conditional)

### Now prepare a trainable model

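The trainable prior below parameterizes its covariance with `PositiveDiagonal` (imported from `gensn.parameters`), so no custom covariance module is defined here.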

In [None]:
n_latents = 5
# Trainable prior: learnable mean initialised at zero; the covariance argument is
# supplied by a PositiveDiagonal module.
model_prior = TrainableDistributionAdapter(
    MultivariateNormal,
    nn.Parameter(torch.zeros([n_latents])),
    PositiveDiagonal(n_latents),
)

n_obs = 10

# Trainable latent-to-observation map, left at nn.Linear's default initialisation.
feature_map = nn.Linear(n_latents, n_obs)

# Observations given latents: mean from feature_map, identity covariance passed
# as a plain tensor.
model_conditional = TrainableDistributionAdapter(
    MultivariateNormal,
    feature_map,
    torch.eye(n_obs),
)

model_joint = Joint(model_prior, model_conditional)

### Go ahead and train the joint model

In [None]:
# Adam over all parameters exposed by the model joint, with learning rate 0.5.
optim = Adam(model_joint.parameters(), lr=0.5)

In [None]:
# Maximise the model's average joint log-probability on samples drawn from the
# ground-truth joint; the ground-truth log-probability is printed for reference.
for step in range(1000):
    optim.zero_grad()
    samples = gt_joint.sample((100,))
    gt_logl = gt_joint.log_prob(*samples).mean()
    model_logl = model_joint.log_prob(*samples).mean()
    if step % 100 == 0:
        print(f'Model logp: {model_logl:.3f} / GT logp: {gt_logl:.3f}')
    loss = -model_logl
    loss.backward()
    optim.step()

Checking the learned parameters

In [None]:
# Call the prior's second positional argument (the PositiveDiagonal module passed
# when model_prior was built) to view its current output.
model_joint.prior._arg1()

In [None]:
# State dict of the trained prior.
model_joint.prior.state_dict()

In [None]:
# State dict of the trained conditional.
model_joint.conditional.state_dict()

## Compute the posterior distribution via ELBO

First, we'll train the posterior for the ground-truth model, using ELBO

In [None]:
# Posterior over the latents given an observation: the mean is a linear map from
# observation space (n_obs) to latent space (n_latents).
# NOTE(review): Covariance(n_latents, rank=5, eps=1e-6) presumably builds a learnable
# covariance matrix -- confirm the meaning of `rank` and `eps` in gensn.parameters.
posterior = TrainableDistributionAdapter(
    MultivariateNormal,
    nn.Linear(n_obs, n_latents),
    Covariance(n_latents, rank=5, eps=1e-6),
)


In [None]:
# ELBO objective built from the ground-truth joint and the trainable posterior.
elbo_x = ELBOMarginal(gt_joint, posterior)

# Only the posterior's parameters are handed to the optimizer.
optim = Adam(posterior.parameters(), lr=1e-2)

In [None]:
# Draw one fixed set of 1000 samples from the ground-truth joint; only the
# observations are fed to the ELBO.
z_sample, x_sample = gt_joint.sample((1000,))

# Gradient ascent on the average ELBO (implemented as descent on its negative).
for step in range(1000):
    optim.zero_grad()
    elbo = elbo_x(x_sample).mean()
    if step % 100 == 0:
        print(f'Model elbo: {elbo:.3f}')
    loss = -elbo
    loss.backward()
    optim.step()

### Evaluate the posterior
Here we will (roughly) evaluate how good the posterior is by sampling $\hat{z}$ from the posterior $p(z|x)$ and evaluating the expected $\log p(\hat{z}, x)$.
If $\hat{z}$ approximates the true distribution over $z$, then $\log p(\hat{z}, x)$ will closely approximate the expected $\log p(z, x)$, which is the negative entropy.

In [None]:
# Fresh evaluation set: 1000 (z, x) pairs drawn from the ground-truth joint.
z_sample, x_sample = gt_joint.sample((1000,))

In [None]:
# Ground-truth negative entropy: average log p(z, x) over samples drawn from the joint itself.
gt_joint.log_prob(z_sample, x_sample).mean()

In [None]:
# sample from the trained posterior
z_hat = posterior.sample(cond=x_sample)

# Negative entropy of the approximation: average ground-truth log p(z_hat, x)
# with z_hat drawn from the trained posterior instead of the true latents.
gt_joint.log_prob(z_hat, x_sample).mean()

## Compute the posterior distribution via direct fit to samples

Now, we'll train the posterior for the ground-truth model by training directly on the samples

In [None]:
# Fresh posterior over the latents given an observation: the mean is a linear map
# from observation space (n_obs) to latent space (n_latents).
# NOTE(review): Covariance(n_latents, rank=5, eps=1e-6) presumably builds a learnable
# covariance matrix -- confirm the meaning of `rank` and `eps` in gensn.parameters.
posterior = TrainableDistributionAdapter(
    MultivariateNormal,
    nn.Linear(n_obs, n_latents),
    Covariance(n_latents, rank=5, eps=1e-6),
)


In [None]:
# Only the posterior's parameters are handed to the optimizer.
optim = Adam(posterior.parameters(), lr=1e-2)

In [None]:

# Fit the posterior directly: maximise the average log q(z | x) on fresh (z, x)
# pairs drawn from the ground-truth joint on every step.
for step in range(1000):
    optim.zero_grad()
    z_sample, x_sample = gt_joint.sample((100,))
    logp = posterior.log_prob(z_sample, cond=x_sample).mean()
    if step % 100 == 0:
        print(f'LogP: {logp:.3f}')
    loss = -logp
    loss.backward()
    optim.step()

### Evaluate the posterior
Here we will (roughly) evaluate how good the posterior is by sampling $\hat{z}$ from the posterior $p(z|x)$ and evaluating the expected $\log p(\hat{z}, x)$.
If $\hat{z}$ approximates the true distribution over $z$, then $\log p(\hat{z}, x)$ will closely approximate the expected $\log p(z, x)$, which is the negative entropy.

In [None]:
# Fresh evaluation set: 1000 (z, x) pairs drawn from the ground-truth joint.
z_sample, x_sample = gt_joint.sample((1000,))

In [None]:
# Ground-truth negative entropy: average log p(z, x) over samples drawn from the joint itself.
gt_joint.log_prob(z_sample, x_sample).mean()

In [None]:
# sample from the trained posterior
z_hat = posterior.sample(cond=x_sample)

# Negative entropy of the approximation: average ground-truth log p(z_hat, x)
# with z_hat drawn from the trained posterior instead of the true latents.
gt_joint.log_prob(z_hat, x_sample).mean()