In [2]:
import torch
from torch import nn
from torch.distributions import Normal, MultivariateNormal
from torch.optim import SGD, Adam

In [3]:
#TODO: overall, the handling of returned tuples for the samples is very loose
# This stems primarily from the rather inconsistent interface/handling of random variables
# of the distribution (can be a singleton/non-tuple OR a tuple) vs. random variables
# you are conditioned on (generally expected to be a tuple)

In [None]:
def register_to_module(module, field, value):
    """
    Attach `value` to `module` under the attribute name `field`.

    Tensors that are not `nn.Parameter` are registered as buffers through
    `module.register_buffer`; any other value (parameters, sub-modules,
    arbitrary Python objects) is attached with a plain `setattr`.
    """
    is_plain_tensor = isinstance(value, torch.Tensor) and not isinstance(value, nn.Parameter)
    if not is_plain_tensor:
        setattr(module, field, value)
        return
    # non-parameter tensors become buffers of the module
    module.register_buffer(field, value)

: 

In [None]:
class TrainableDistributionAdapter(nn.Module):
    """
    Wrap a distribution class so that its constructor arguments can live on
    (and be trained as part of) an `nn.Module`.

    Every positional / keyword argument meant for `distribution_class` is
    stored on this module via `register_to_module`. When the distribution is
    built (see `distribution`), stored values that are `nn.Module` instances
    are called with the conditioning variables and their output is used as
    the argument; every other stored value is passed through unchanged.

    Args:
        distribution_class: callable invoked as
            `distribution_class(*args, **kwargs)` to build the distribution.
        *dist_args: positional arguments for `distribution_class`; stored as
            attributes `_arg0`, `_arg1`, ...
        _parameters: optional callable (typically an `nn.Module`) that is
            invoked with the conditioning variables; its output is merged
            into the distribution arguments by `make_args`.
        **dist_kwargs: keyword arguments for `distribution_class`; stored
            under their own names.
    """

    # number of random variables this distribution is defined over
    n_rvs = 1

    def __init__(self, distribution_class, *dist_args, _parameters=None, **dist_kwargs):
        super().__init__()
        self.distribution_class = distribution_class
        self.param_counts = len(dist_args)
        self.param_keys = list(dist_kwargs.keys())

        for pos, val in enumerate(dist_args):
            register_to_module(self, f'_arg{pos}', val)
        for key, val in dist_kwargs.items():
            register_to_module(self, key, val)

        # The alternative interface to allow for a single module to
        # flexibly output multiple parameters for the distribution.
        # The attribute is only created when a generator is supplied;
        # `distribution` checks for its presence with `hasattr`.
        if _parameters is not None:
            self.parameter_generator = _parameters

    @property
    def distrbrituion_class(self):
        """Backward-compatible alias for the originally misspelled attribute name."""
        return self.distribution_class

    @distrbrituion_class.setter
    def distrbrituion_class(self, value):
        self.distribution_class = value

    def distribution(self, cond=None):
        """
        Build the distribution object, optionally conditioned on `cond`.

        `cond` is normalized with `turn_to_tuple` and unpacked into every
        stored `nn.Module` argument and into the parameter generator (if any).
        """
        cond = turn_to_tuple(cond)

        # Visit the target field and invoke it with cond if it is an
        # nn.Module; otherwise simply return the field content.
        def parse_attr(field):
            attr = getattr(self, field)
            if isinstance(attr, nn.Module):
                attr = attr(*cond)
            return attr

        dist_args = tuple(parse_attr(f'_arg{pos}') for pos in range(self.param_counts))
        dist_kwargs = {k: parse_attr(k) for k in self.param_keys}

        # TODO: consider flipping the order of this with
        # init specified parameters
        if hasattr(self, 'parameter_generator'):
            dist_args, dist_kwargs = make_args(self.parameter_generator(*cond), *dist_args, **dist_kwargs)

        return self.distribution_class(*dist_args, **dist_kwargs)

    def log_prob(self, *obs, cond=None):
        """Return `log_prob(*obs)` of the distribution built for `cond`."""
        return self.distribution(cond=cond).log_prob(*obs)

    def forward(self, *obs, cond=None):
        """Calling the module is the same as calling `log_prob`."""
        return self.log_prob(*obs, cond=cond)

    def sample(self, sample_shape=torch.Size([]), cond=None):
        """Draw samples via the distribution's `sample` method."""
        return self.distribution(cond=cond).sample(sample_shape=sample_shape)

    def rsample(self, sample_shape=torch.Size([]), cond=None):
        """Draw samples via the distribution's `rsample` method."""
        return self.distribution(cond=cond).rsample(sample_shape=sample_shape)

: 

$y = f(x)$

$p(y) = \left|\det \frac{dx}{dy}\right| \, p(x)$

$p(y) = \frac{1}{|f'(x)|} \, p(x)$

$p(f(x)) \, |f'(x)| = p(x)$

In [None]:
# come up with a better name
# here we assume that conditioning of the joint distribution can occur
# by conditioning the prior
class Joint(nn.Module):
    """
    Joint distribution assembled from a prior and a conditional.

    The first `prior.n_rvs` observed variables belong to the prior; the
    remaining ones belong to the conditional, which is conditioned on the
    prior's variables. Any external conditioning (`cond`) is forwarded to the
    prior only.
    """

    def __init__(self, prior, conditional):
        super().__init__()
        self.prior = prior
        self.conditional = conditional
        n_prior_rvs = self.prior.n_rvs
        # index at which the observation tuple is split between the two parts
        self.split = n_prior_rvs
        self.n_rvs = n_prior_rvs + self.conditional.n_rvs

    def log_prob(self, *obs, cond=None):
        """Sum of the prior's and the conditional's outputs for `obs`."""
        # TODO: maybe just use self.prior.n_rvs
        prior_obs = obs[:self.split]
        conditional_obs = obs[self.split:]
        prior_term = self.prior(*prior_obs, cond=cond)
        conditional_term = self.conditional(*conditional_obs, cond=prior_obs)
        return prior_term + conditional_term

    def forward(self, *obs, cond=None):
        """Calling the module is the same as calling `log_prob`."""
        return self.log_prob(*obs, cond=cond)

    def sample(self, sample_shape=torch.Size([]), cond=None):
        """Sample the prior, then the conditional given those samples; return a flat tuple."""
        from_prior = self.prior.sample(sample_shape=sample_shape, cond=cond)
        from_conditional = self.conditional.sample(cond=from_prior)
        return turn_to_tuple(from_prior) + turn_to_tuple(from_conditional)

    def rsample(self, sample_shape=torch.Size([]), cond=None):
        """Same as `sample`, but using the `rsample` method of both parts."""
        from_prior = self.prior.rsample(sample_shape=sample_shape, cond=cond)
        from_conditional = self.conditional.rsample(cond=from_prior)
        return turn_to_tuple(from_prior) + turn_to_tuple(from_conditional)
    

: 

In [None]:
# conceptual template
class InvertibleTransform(nn.Module):
    """
    Conceptual template for the interface of an invertible transform.

    NOTE: not runnable as written. `y` (in `forward`), `x` (in `reverse`) and
    `log_det_f_prime` are placeholders that are never defined here; a concrete
    transform has to compute them.
    """
    def forward(self, x, logL=0):
        # placeholder: return the transformed value and logL plus the log-det term
        return y, logL + log_det_f_prime
    
    def reverse(self, y, logL=0):
        # placeholder: return the recovered input and logL minus the log-det term
        return x, logL - log_det_f_prime

: 

In [None]:
class SequentialTransform(nn.Module):
    """
    Chain several transforms into a single transform.

    Each transform is expected to be an `nn.Module` that is callable as
    `t(x, logL, cond=cond)` and offers `t.reverse(y, logL, cond=cond)`, both
    returning a `(value, logL)` pair.

    Args:
        *transforms: the transforms, in the order they are applied by `forward`.
    """

    def __init__(self, *transforms):
        # nn.Module must be initialized before sub-modules can be assigned
        super().__init__()
        self.transforms = nn.ModuleList(transforms)

    # TODO: check the content of nn.Module __call__ and create something similar for reverse
    def forward(self, x, logL=0, cond=None):
        """Apply every transform in order, threading `logL` through."""
        for t in self.transforms:
            x, logL = t(x, logL, cond=cond)
        return x, logL

    def reverse(self, y, logL=0, cond=None):
        """Undo `forward`: call each transform's `reverse`, last transform first."""
        for t in reversed(list(self.transforms)):
            y, logL = t.reverse(y, logL, cond=cond)
        return y, logL

: 

In [None]:
class FlowDistribution(nn.Module):
    """
    Distribution obtained by pushing samples of a base distribution through
    a transform.

    Sampling applies the transform to base samples; `log_prob` maps the
    observation back with `transform.reverse` and adds the returned `logL`
    term to the base distribution's log-probability.
    """

    def __init__(self, base_distribution, transform):
        super().__init__()
        self.base_distribution = base_distribution
        self.transform = transform

    def forward(self, *obs, cond=None):
        """Calling the module is the same as calling `log_prob`."""
        return self.log_prob(*obs, cond=cond)

    def log_prob(self, *obs, cond=None):
        """Base log-probability of the reversed observation plus the reverse `logL` term."""
        base_value, reverse_logL = self.transform.reverse(*obs, cond=cond)
        base_logp = self.base_distribution.log_prob(*turn_to_tuple(base_value), cond=cond)
        return base_logp + reverse_logL

    def sample(self, sample_shape=torch.Size([]), cond=None):
        """Sample the base distribution and return the transformed value."""
        base_draw = self.base_distribution.sample(sample_shape=sample_shape, cond=cond)
        transformed, _unused_logL = self.transform(base_draw, cond=cond)
        return transformed

    def rsample(self, sample_shape=torch.Size([]), cond=None):
        """Same as `sample`, but using the base distribution's `rsample`."""
        base_draw = self.base_distribution.rsample(sample_shape=sample_shape, cond=cond)
        transformed, _unused_logL = self.transform(base_draw, cond=cond)
        return transformed

: 

In [None]:
def ELBO_joint(joint, posterior, *obs, n_samples=1):
    """
    ELBO terms computed from a joint p(z, x) and a posterior p(z|x).

    Latents are drawn with `posterior.rsample((n_samples,), cond=obs)`; the
    result is `joint(z, *obs) - posterior(z, cond=obs)`, returned without any
    averaging over samples.
    """
    latents = turn_to_tuple(posterior.rsample((n_samples,), cond=obs))
    # take care of case where KL is known for the posterior
    elbo = -posterior(*latents, cond=obs)
    elbo += joint(*latents, *obs)
    return elbo

def ELBO_parts(prior, conditional, posterior, *obs, n_samples=1):
    """
    Same as `ELBO_joint`, with the joint given as its two factors: `prior`
    and `conditional` are first combined into a `Joint`.
    """
    return ELBO_joint(Joint(prior, conditional), posterior, *obs, n_samples=n_samples)


: 

In [None]:
class ELBOMarginal(nn.Module):
    """
    Marginal over the observed variables of a joint, scored through the ELBO.

    Calling the module returns `ELBO_joint(joint, posterior, *obs)`. Sampling
    draws from the joint and keeps only the trailing `n_rvs` entries (the
    variables that are not covered by the posterior).
    """

    def __init__(self, joint, posterior, n_samples=1):
        super().__init__()
        self.joint = joint
        self.posterior = posterior
        # infer how many variables are in observations
        n_observed = joint.n_rvs - posterior.n_rvs
        self.n_rvs = n_observed
        self.n_samples = n_samples

    def forward(self, *obs, cond=None):
        """Calling the module is the same as calling `elbo`."""
        return self.elbo(*obs, cond=cond)

    def elbo(self, *obs, cond=None):
        """ELBO of `obs`; `cond` is currently accepted but not used."""
        # TODO: deal with conditioning correctly
        return ELBO_joint(self.joint, self.posterior, *obs, n_samples=self.n_samples)

    def log_prob(self, *obs):
        """Not implemented yet; returns None."""
        # TODO: let this be implemented as an "approximation" with ELBO
        # but with ample warnings
        return None

    def sample(self, sample_shape=torch.Size([]), cond=None):
        """Sample the joint and return only the last `n_rvs` entries."""
        joint_draw = self.joint.sample(sample_shape=sample_shape, cond=cond)
        return joint_draw[-self.n_rvs:]

    def rsample(self, sample_shape=torch.Size([]), cond=None):
        """Same as `sample`, but using the joint's `rsample`."""
        joint_draw = self.joint.rsample(sample_shape=sample_shape, cond=cond)
        return joint_draw[-self.n_rvs:]


: 

In [None]:
#TODO Consider implementing SurVAE
class SurVAE(nn.Module):
    """Empty placeholder; nothing is implemented yet (see the TODO above)."""
    pass

: 

In [None]:
def turn_to_tuple(x):
    """
    Normalize a value into a tuple.

    * if x is None, return an empty tuple ()
    * if x is already a tuple, return it unchanged
    * otherwise, return x wrapped as a single-element tuple (x,)
    """
    if x is None:
        return ()
    if isinstance(x, tuple):
        return x
    return (x,)

: 

In [None]:
from warnings import warn

: 

In [None]:
def make_args(x, *args, **kwargs):
    """
    Merge `x` into a set of call arguments and return `(args, kwargs)`.

    * a dict is merged into `kwargs` (its entries win on key clashes)
    * a tuple is prepended to `args`
    * any other value is prepended to `args` as a single element
    """
    if isinstance(x, dict):  # TODO: consider making it a Collection.Mapping
        kwargs.update(x)
        return args, kwargs
    leading = x if isinstance(x, tuple) else (x,)
    return leading + args, kwargs

: 

In [None]:
class TransformedParameter(nn.Module):
    """
    A trainable parameter exposed through a transform.

    `tensor` is wrapped in an `nn.Parameter`; calling the module (with any
    arguments, which are ignored) returns `transform_fn(parameter)`. Without a
    `transform_fn` the parameter itself is returned.
    """

    def __init__(self, tensor, transform_fn=None):
        super().__init__()
        self.parameter = nn.Parameter(tensor)
        # fall back to the identity when no transform is supplied
        self.transform_fn = transform_fn if transform_fn is not None else (lambda x: x)

    @property
    def value(self):
        """Current transformed value (same as calling the module)."""
        return self()

    def forward(self, *args):
        raw = self.parameter
        return self.transform_fn(raw)
    

: 

In [None]:
class Covariance(nn.Module):
    """
    Trainable covariance matrix of the form `A @ A.T + eps * I`.

    Args:
        n_dims: size of the (square) covariance matrix.
        rank: number of columns of the factor `A`; defaults to `n_dims`.
        eps: value added to the diagonal.
    """

    def __init__(self, n_dims, rank=None, eps=1e-16):
        super().__init__()
        if rank is None:
            rank = n_dims
        self.n_dims = n_dims
        self.rank = rank
        self.eps = eps
        self.A = nn.Parameter(torch.randn(n_dims, rank))

    def forward(self, *args):
        """Return the `(n_dims, n_dims)` matrix; any arguments are ignored."""
        # build the identity on the same device/dtype as A so the module keeps
        # working after `.to(device)` / `.double()`
        identity = torch.eye(self.n_dims, dtype=self.A.dtype, device=self.A.device)
        return self.A @ self.A.T + identity * self.eps

    @property
    def value(self):
        """Current covariance matrix (same as calling the module)."""
        return self()

: 

In [None]:
# TODO: generalize this so that positiveness can arise from other functions
class PositiveDiagonal(nn.Module):
    """
    Trainable diagonal matrix whose diagonal entries are `D**2 + eps`.

    Args:
        n_dims: size of the (square) matrix.
        eps: value added to every squared diagonal entry.
    """

    def __init__(self, n_dims, eps=1e-16):
        super().__init__()
        self.n_dims = n_dims
        self.eps = eps
        self.D = nn.Parameter(torch.randn(n_dims))

    def forward(self, *args):
        """Return the `(n_dims, n_dims)` diagonal matrix; any arguments are ignored."""
        diagonal_entries = self.D**2 + self.eps
        return torch.diag(diagonal_entries)

    @property
    def value(self):
        """Current matrix (same as calling the module)."""
        return self()

: 

In [None]:
class ProbabilisticSIModel(nn.Module):
    """
    Turn an SI model into a probabilistic model.

    The SI model is used as the parameter generator (`_parameters`) of a
    `TrainableDistributionAdapter` built around `distribution_class`; all
    probabilistic operations are delegated to that adapter.

    Args:
        si_model: callable invoked with the conditioning variables; its output
            supplies parameters for the distribution.
        distribution_class: distribution class handed to the adapter.
        *dist_args: extra positional arguments handed to the adapter.
        **dist_kwargs: extra keyword arguments handed to the adapter.
    """
    # TODO: for this to work, the TrainableDistributionAdapter handling of
    # _parameters must be expanded. Namely, it needs to be able to accept:
    # * positional arguments (to be implemented)
    # * dict (already implemented)
    # * dict with positional arguments (to be implemented)
    # 
    # To also make this generically useful, it would be helpful to
    # allow for output conversion function to be supplied. This function then
    # should transform outputs of the SI model into format appropriate
    # to serve as _parameters for the TrainableDistributionAdapter
    # It's important that such transformation does NOT warp the output
    # Doing so will distort the probability density!
    def __init__(self, si_model, distribution_class, *dist_args, **dist_kwargs):
        super().__init__()
        self.si_model = si_model
        self.distribution_class = distribution_class
        self.trainable_distribution = TrainableDistributionAdapter(distribution_class, *dist_args,
                                                                   _parameters=si_model, **dist_kwargs)

    def log_prob(self, *obs, cond=None):
        """Log-probability of `obs` under the wrapped adapter."""
        return self.trainable_distribution.log_prob(*obs, cond=cond)

    def forward(self, *obs, cond=None):
        """Calling the module is the same as calling `log_prob`."""
        return self.log_prob(*obs, cond=cond)

    def sample(self, sample_shape=torch.Size([]), cond=None):
        """Draw samples through the wrapped adapter's `sample`."""
        return self.trainable_distribution.sample(sample_shape=sample_shape, cond=cond)

    def rsample(self, sample_shape=torch.Size([]), cond=None):
        """Draw samples through the wrapped adapter's `rsample`."""
        return self.trainable_distribution.rsample(sample_shape=sample_shape, cond=cond)

: 

## Example usage: learn Normal distribution

### Learn only the mean

In [None]:
# only the mean (loc) is learnable here: it is wrapped in nn.Parameter
# std (scale) is passed as a plain tensor, so it is registered as a buffer and stays fixed at 2.0
normal = TrainableDistributionAdapter(Normal, 
                               loc=nn.Parameter(torch.Tensor([0.0])), 
                               scale=torch.Tensor([2.0]))

: 

In [None]:
normal.state_dict()

: 

In [None]:
list(normal.parameters())

: 

In [None]:
# setup target normal distribution to learn
target = Normal(torch.Tensor([5.0]), torch.Tensor([2.0]))

: 

In [None]:
optim = Adam(normal.parameters(), lr=0.5)

: 

In [None]:
# fit `normal` to `target` by minimizing the mean negative log-likelihood
for i in range(1000):
    optim.zero_grad()
    # fresh batch of 100 draws from the target at every step
    targets = target.sample((100,))
    nlogp = -normal(targets).mean()
    # current value of the mean, detached for logging only
    mean = normal.loc.detach()
    if (i+1) % 100 == 0:
        # report every 100th iteration
        print(f'Neg logP: {nlogp}, mean={mean}')
    nlogp.backward()
    optim.step()

: 

### Learn both mean and stdev

In [None]:
# make the mean & std both learnable
# std (scale) is set to be square of a parameter, thus ensuring positive value
normal = TrainableDistributionAdapter(Normal, 
                               loc=nn.Parameter(torch.Tensor([0.0])), 
                               scale=TransformedParameter(torch.Tensor([1.0]), lambda x: x**2))

: 

In [None]:
# setup target normal distribution to learn
target = Normal(torch.Tensor([5]), torch.Tensor([8]))

: 

In [None]:
optim = Adam(normal.parameters(), lr=1.0)

: 

In [None]:
for i in range(1000):
    optim.zero_grad()
    targets = target.sample((100,))
    nlogp = -normal(targets).mean()
    mean = normal.loc.detach()
    std = normal.scale.value.detach()
    if (i+1) % 100 == 0:
        print(f'Neg logP: {nlogp}, mean={mean}, std={std}')
    nlogp.backward()
    optim.step()

: 

### Learn both mean and stdev (same as above, but specified positionally)

In [None]:
# make the mean & std both learnable
# std (scale) is set to be square of a parameter, thus ensuring positive value
normal = TrainableDistributionAdapter(Normal, 
                               nn.Parameter(torch.Tensor([0.0])), 
                               TransformedParameter(torch.Tensor([1.0]), lambda x: x**2))

: 

In [None]:
# setup target normal distribution to learn
target = Normal(torch.Tensor([5]), torch.Tensor([8]))

: 

In [None]:
optim = Adam(normal.parameters(), lr=1.0)

: 

In [None]:
for i in range(1000):
    optim.zero_grad()
    targets = target.sample((100,))
    nlogp = -normal(targets).mean()
    mean = normal._arg0.detach()
    std = normal._arg1.value.detach()
    if (i+1) % 100 == 0:
        print(f'Neg logP: {nlogp}, mean={mean}, std={std}')
    nlogp.backward()
    optim.step()

: 

# Conditional case

## Simple conditioning

Now let us learn more complex relationship $p(z|x)$. Specifically, let $p(z|x) = \mathcal{N}(f(x), \sigma^2)$.

For simplicity, we'll assume a simple linear mapping for $f(x)$

Prepare data generator

In [None]:
def get_batch(batch_size):
    """
    Generate data for the simple conditional example.

    Returns `(x, y)`: `x` is uniform in [0, 1) with shape `(batch_size, 1)`;
    `y` is sampled from `Normal(-5 * x + 9, 7)` with
    `sample_shape=(batch_size,)`, so it has shape
    `(batch_size, batch_size, 1)`.
    """
    x = torch.rand((batch_size, 1))
    noise_model = Normal(-5 * x + 9, scale=7)
    return x, noise_model.sample((batch_size,))

: 

Set up the conditional network:

In [None]:
# make the mean & std both learnable
# std (scale) is set to be abs of a parameter, thus ensuring positive value
normal = TrainableDistributionAdapter(Normal, 
                               loc=nn.Linear(1, 1), 
                               scale=TransformedParameter(torch.Tensor([1.0]), torch.abs))

: 

In [None]:
optim = Adam(normal.parameters(), lr=1.0)

: 

In [None]:
# fit the conditional model by minimizing the mean negative log-likelihood of y given x
for i in range(1000):
    optim.zero_grad()
    x, y = get_batch(100)
    nlogp = -normal(y, cond=x).mean()
    # detached snapshot of every entry of the state dict, for logging only
    params = {k:v.detach() for k,v in normal.state_dict().items()}
    if (i+1) % 100 == 0:
        # report every 100th iteration
        print(f'Neg logP: {nlogp}, params={params}')
    nlogp.backward()
    optim.step()

: 

## One network with multiple outputs (returning dict):

In [None]:
class NormalParams(nn.Module):
    """
    Single linear layer producing both parameters of a Normal as a dict.

    The first output column is returned as `loc`; the remaining column,
    squared, is returned as `scale`.
    """

    def __init__(self):
        super().__init__()
        self.core = nn.Linear(1, 2)

    def forward(self, x):
        outputs = self.core(x)
        loc = outputs[:, 0:1]
        scale = (outputs[:, 1:])**2
        return {'loc': loc, 'scale': scale}

: 

In [None]:
def get_batch(batch_size):
    """
    Generate data whose mean and scale both depend on x.

    Returns `(x, y, log_prob)`: `x` is uniform in [0, 1) with shape
    `(batch_size, 1)`; `y` is sampled from `Normal(-5 * x + 9, 3 * x + 1)`
    with `sample_shape=(batch_size,)`, so `y` and its ground-truth
    log-probability have shape `(batch_size, batch_size, 1)`.
    """
    x = torch.rand((batch_size, 1))
    ground_truth = Normal(-5 * x + 9, scale=(3 * x + 1))
    y = ground_truth.sample((batch_size,))
    return x, y, ground_truth.log_prob(y)

: 

In [None]:
# both loc and scale come from a single NormalParams network conditioned on x
# scale is the square of a network output, which keeps it non-negative
normal = TrainableDistributionAdapter(Normal, _parameters=NormalParams())

: 

In [None]:
optim = Adam(normal.parameters(), lr=1)

: 

In [None]:
# fit the network-parameterized Normal; the ground-truth log-probability is printed for comparison
for i in range(1000):
    optim.zero_grad()
    x, y, true_logp = get_batch(100)
    nlogp = -normal(y, cond=x).mean()
    # detached snapshot of every entry of the state dict, for logging only
    params = {k:v.detach() for k,v in normal.state_dict().items()}
    if i % 100 == 0:
        # report every 100th iteration, starting with the first
        print(f'Neg logP: {nlogp:0.3f} (gt={-true_logp.mean():0.3f}), params={params}')
    nlogp.backward()
    optim.step()

: 

## One network with multiple outputs (returning positionally):

In [None]:
class NormalParams(nn.Module):
    """
    Single linear layer producing both parameters of a Normal positionally.

    Returns a 2-tuple: the first output column, and the remaining column
    squared.
    """

    def __init__(self):
        super().__init__()
        self.core = nn.Linear(1, 2)

    def forward(self, x):
        outputs = self.core(x)
        first = outputs[:, 0:1]
        second_squared = (outputs[:, 1:])**2
        return first, second_squared

: 

In [None]:
def get_batch(batch_size):
    """
    Generate data whose mean and scale both depend on x.

    Returns `(x, y, log_prob)`: `x` is uniform in [0, 1) with shape
    `(batch_size, 1)`; `y` is drawn from `Normal(-5 * x + 9, 3 * x + 1)` with
    `sample_shape=(batch_size,)`, giving `y` and its ground-truth
    log-probability the shape `(batch_size, batch_size, 1)`.
    """
    x = torch.rand((batch_size, 1))
    loc, spread = -5 * x + 9, (3 * x + 1)
    generator = Normal(loc, scale=spread)
    y = generator.sample((batch_size,))
    return x, y, generator.log_prob(y)

: 

In [None]:
# both distribution arguments come from a single NormalParams network conditioned on x,
# returned positionally; the second one is the square of a network output (non-negative)
normal = TrainableDistributionAdapter(Normal, _parameters=NormalParams())

: 

In [None]:
optim = Adam(normal.parameters(), lr=1)

: 

In [None]:
# same training loop as for the dict-returning network
for i in range(1000):
    optim.zero_grad()
    x, y, true_logp = get_batch(100)
    nlogp = -normal(y, cond=x).mean()
    # detached snapshot of every entry of the state dict, for logging only
    params = {k:v.detach() for k,v in normal.state_dict().items()}
    if i % 100 == 0:
        # report every 100th iteration, starting with the first
        print(f'Neg logP: {nlogp:0.3f} (gt={-true_logp.mean():0.3f}), params={params}')
    nlogp.backward()
    optim.step()

: 

# Test sampling and joint

In [None]:
# prior: x ~ Normal(loc=5, scale=2)
prior = TrainableDistributionAdapter(Normal, loc=nn.Parameter(torch.Tensor([5])),
                                     scale=torch.Tensor([2]))

# conditional: y | x ~ Normal(loc=-2 * x + 6, scale=1); the linear layer's
# weight and bias are overwritten with fixed values
linear = nn.Linear(1, 1)
linear.weight.data = torch.Tensor([[-2]])
linear.bias.data = torch.Tensor([6])
conditional = TrainableDistributionAdapter(Normal, loc=linear, scale=torch.Tensor([1]))

# create a joint distribution out of prior and conditional
joint = Joint(prior, conditional)

: 

In [None]:
x, y = joint.sample((10000,))

: 

Should be 5, 2

In [None]:
x.mean(), x.std()

: 

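Should be approximately -4 and $\sqrt{17} \approx 4.12$ (mean $= -2 \cdot 5 + 6$, variance $= (-2)^2 \cdot 2^2 + 1^2$)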

In [None]:
y.mean(), y.std()

: 

# Simulating a multi-dimensional joint distribution

### Define the ground-truth generative model

In [None]:
n_latents = 5
# ground-truth prior over the latents: mean of ones and identity covariance,
# both given as plain tensors
gt_prior = TrainableDistributionAdapter(MultivariateNormal, 
                                        torch.ones([n_latents]),
                                        torch.eye(n_latents))

n_obs = 10
features = torch.randn([n_latents, n_obs])

# linear map from latents to the mean of the observations: the weights are
# overwritten with the random features and the bias is zeroed
feature_map = nn.Linear(n_latents, n_obs)
feature_map.weight.data = features.T
feature_map.bias.data.zero_()

# ground-truth conditional: mean given by feature_map, identity covariance
gt_conditional = TrainableDistributionAdapter(MultivariateNormal,
                                              feature_map,
                                              torch.eye(n_obs))

gt_joint = Joint(gt_prior, gt_conditional)

: 

### Now prepare a trainable model

class Covariance(nn.Module):
    def __init__(self, n_dim):
        

In [None]:
n_latents = 5
# trainable prior: the mean is a Parameter and the covariance is a PositiveDiagonal module
model_prior = TrainableDistributionAdapter(MultivariateNormal, 
                                        nn.Parameter(torch.zeros([n_latents])),
                                        PositiveDiagonal(n_latents))

n_obs = 10
# NOTE: `features` is not used below; the lines that would copy it into
# feature_map are commented out, so feature_map keeps its default initialization
features = torch.randn([n_latents, n_obs])

feature_map = nn.Linear(n_latents, n_obs)
#feature_map.weight.data = features.T
#feature_map.bias.data.zero_()

# trainable conditional: mean given by feature_map, covariance is a fixed identity tensor
model_conditional = TrainableDistributionAdapter(MultivariateNormal,
                                              feature_map,
                                              torch.eye(n_obs))

model_joint = Joint(model_prior, model_conditional)

: 

### Go ahead and train the joint model

In [None]:
optim = Adam(model_joint.parameters(), lr=0.5)

: 

In [None]:
# fit model_joint to samples of gt_joint by maximizing the mean log-probability
for i in range(1000):
    optim.zero_grad()
    # fresh batch of 100 joint samples from the ground-truth model at every step
    samples = gt_joint.sample((100,))
    # ground-truth log-probability of the same samples, for reference only
    gt_logl = gt_joint.log_prob(*samples).mean()
    model_logl = model_joint.log_prob(*samples).mean()
    if i % 100 == 0:
        # report every 100th iteration, starting with the first
        print(f'Model logp: {model_logl:.3f} / GT logp: {gt_logl:.3f}')
    (-model_logl).backward()
    optim.step()

: 

Checking the learned parameters

In [None]:
model_joint.prior._arg1()

: 

In [None]:
model_joint.prior.state_dict()

: 

In [None]:
model_joint.conditional.state_dict()

: 

## Compute the posterior distribution via ELBO

First, we'll train the posterior for the ground-truth model, using ELBO

In [None]:
# prepare a posterior distribution
features = torch.randn([n_obs, n_latents])

linear_map = nn.Linear(n_obs, n_latents)

posterior = TrainableDistributionAdapter(MultivariateNormal,
                                              linear_map,
                                              Covariance(n_latents, rank=5, eps=1e-6))


: 

In [None]:
elbo_x = ELBOMarginal(gt_joint, posterior)

# only training the posterior
optim = Adam(posterior.parameters(), lr=1e-2)

: 

In [None]:
# a single set of 1000 joint samples is drawn once and reused at every step;
# only the observations (x_sample) are used in the loop below
z_sample, x_sample = gt_joint.sample((1000,))

for i in range(1000):
    optim.zero_grad()
    elbo = elbo_x(x_sample).mean()
    if i % 100 == 0:
        # report every 100th iteration, starting with the first
        print(f'Model elbo: {elbo:.3f}')
    # maximize the ELBO by minimizing its negative
    (-elbo).backward()
    optim.step()

: 

### Evaluate the posterior
Here we (roughly) evaluate how good the posterior is by sampling $\hat{z}$ from the posterior $p(z|x)$ and evaluating the expected $\log p(\hat{z}, x)$.
If $\hat{z}$ approximates the true distribution over $z$, then the expected $\log p(\hat{z}, x)$ will closely approximate the expected $\log p(z, x)$, which is the negative of the joint entropy.

In [None]:
z_sample, x_sample = gt_joint.sample((1000,))

: 

In [None]:
# ground-truth negative entropy
gt_joint.log_prob(z_sample, x_sample).mean()

: 

In [None]:
# sample from the trained posterior
z_hat = posterior.sample(cond=x_sample)

# 
# negative entropy of the approximation
gt_joint.log_prob(z_hat, x_sample).mean()

: 

## Compute the posterior distribution via direct fit to samples

Now, we'll train the posterior for the ground-truth model by training directly on the samples

In [None]:
# prepare a posterior distribution
features = torch.randn([n_obs, n_latents])

linear_map = nn.Linear(n_obs, n_latents)

posterior = TrainableDistributionAdapter(MultivariateNormal,
                                              linear_map,
                                              Covariance(n_latents, rank=5, eps=1e-6))


: 

In [None]:
# only training the posterior
optim = Adam(posterior.parameters(), lr=1e-2)

: 

In [None]:

# fit the posterior directly on (z, x) pairs by maximizing the mean log-probability of z given x
for i in range(1000):
    optim.zero_grad()
    # fresh batch of 100 joint samples from the ground-truth model at every step
    z_sample, x_sample = gt_joint.sample((100,))
    logp = posterior.log_prob(z_sample, cond=x_sample).mean()
    if i % 100 == 0:
        # report every 100th iteration, starting with the first
        print(f'LogP: {logp:.3f}')
    (-logp).backward()
    optim.step()

: 

### Evaluate the posterior
Here we (roughly) evaluate how good the posterior is by sampling $\hat{z}$ from the posterior $p(z|x)$ and evaluating the expected $\log p(\hat{z}, x)$.
If $\hat{z}$ approximates the true distribution over $z$, then the expected $\log p(\hat{z}, x)$ will closely approximate the expected $\log p(z, x)$, which is the negative of the joint entropy.

In [None]:
z_sample, x_sample = gt_joint.sample((1000,))

: 

In [None]:
# ground-truth negative entropy
gt_joint.log_prob(z_sample, x_sample).mean()

: 

In [None]:
# sample from the trained posterior
z_hat = posterior.sample(cond=x_sample)

# 
# negative entropy of the approximation
gt_joint.log_prob(z_hat, x_sample).mean()

: 

: 