In [49]:
import torch
from torch import nn
from torch.distributions import Normal, MultivariateNormal
from torch.optim import SGD, Adam

In [113]:
# TODO: overall, the handling of returned sample tuples is very loose.
# This stems primarily from the inconsistent interface for the random variables
# of a distribution (which can be a single non-tuple value OR a tuple) vs. the
# random variables it is conditioned on (generally expected to be a tuple).

In [655]:
def register_to_module(module, field, value):
    """Attach ``value`` to ``module`` under the attribute name ``field``.

    A plain tensor (a ``torch.Tensor`` that is not an ``nn.Parameter``) is
    registered as a buffer. It therefore appears in ``state_dict()`` but is not
    trainable. Anything else (parameters, sub-modules, Python constants) is
    assigned with ``setattr``, leaving the bookkeeping to ``nn.Module.__setattr__``.
    """
    is_plain_tensor = isinstance(value, torch.Tensor) and not isinstance(value, nn.Parameter)
    if not is_plain_tensor:
        setattr(module, field, value)
        return
    module.register_buffer(field, value)

In [656]:
class TrainableDistributionAdapter(nn.Module):
    """Make a ``torch.distributions`` class trainable as an ``nn.Module``.

    Every positional (``dist_args``) and keyword (``dist_kwargs``) argument of
    the wrapped distribution is stored on the module via ``register_to_module``:

    * plain tensors become (non-trainable) buffers,
    * ``nn.Parameter`` objects are trained directly,
    * ``nn.Module`` objects are *called* with the conditioning variables each
      time the distribution is built, so the argument may depend on ``cond``,
    * any other value is stored as an ordinary attribute.

    ``_parameters`` is an alternative interface: a single module called with
    the conditioning variables. Its output is merged with the arguments above
    via ``make_args``:

    * a dict becomes keyword arguments,
    * a tuple becomes leading positional arguments,
    * any other value becomes one leading positional argument.
    """

    # Number of random variables this distribution models
    # (Joint / ELBOMarginal use it to split observation tuples).
    n_rvs = 1

    def __init__(self, distribution_class, *dist_args, _parameters=None, **dist_kwargs):
        super().__init__()
        self.distribution_class = distribution_class
        self.param_counts = len(dist_args)
        self.param_keys = list(dist_kwargs.keys())

        # Positional arguments are stored as _arg0, _arg1, ... so they can be
        # registered (buffer / parameter / sub-module) just like keyword ones.
        for pos, val in enumerate(dist_args):
            register_to_module(self, f'_arg{pos}', val)
        for key, val in dist_kwargs.items():
            register_to_module(self, key, val)

        # The alternative interface to allow for a single module to
        # flexibly output multiple parameters for the distribution
        # (dict -> keyword args, tuple -> positional args; see make_args).
        if _parameters is not None:
            self.parameter_generator = _parameters

    @property
    def distrbrituion_class(self):
        """Deprecated, misspelled alias of ``distribution_class`` (kept for compatibility)."""
        return self.distribution_class

    @distrbrituion_class.setter
    def distrbrituion_class(self, value):
        self.distribution_class = value

    def distribution(self, cond=None):
        """Build the wrapped distribution, conditioned on ``cond``.

        ``cond`` may be None, a single value or a tuple. It is normalized with
        ``turn_to_tuple`` and unpacked into every module-valued argument, and
        into ``parameter_generator`` if present.
        """
        cond = turn_to_tuple(cond)

        # a helper function to visit the target field
        # and invoke the field with cond if it is a nn.Module.
        # Otherwise, simply return the field content
        def parse_attr(field, cond=None):
            attr = getattr(self, field)
            if isinstance(attr, nn.Module):
                attr = attr(*cond)
            return attr

        dist_args = tuple(parse_attr(f'_arg{pos}', cond=cond) for pos in range(self.param_counts))
        dist_kwargs = {k: parse_attr(k, cond=cond) for k in self.param_keys}

        # TODO: consider flipping the order of this with
        # init specified parameters
        if hasattr(self, 'parameter_generator'):
            # Generated positional args go *before* the init-specified ones.
            # Generated keyword args override init-specified ones of the same name.
            dist_args, dist_kwargs = make_args(self.parameter_generator(*cond), *dist_args, **dist_kwargs)

        return self.distribution_class(*dist_args, **dist_kwargs)

    def log_prob(self, *obs, cond=None):
        """Log-density of ``obs`` under the distribution built for ``cond``."""
        return self.distribution(cond=cond).log_prob(*obs)

    def forward(self, *obs, cond=None):
        """Alias of ``log_prob`` so the module can be called directly."""
        return self.log_prob(*obs, cond=cond)

    def sample(self, sample_shape=torch.Size([]), cond=None):
        """Draw (non-differentiable) samples of shape ``sample_shape``."""
        return self.distribution(cond=cond).sample(sample_shape=sample_shape)

    def rsample(self, sample_shape=torch.Size([]), cond=None):
        """Draw reparameterized (differentiable) samples of shape ``sample_shape``."""
        return self.distribution(cond=cond).rsample(sample_shape=sample_shape)

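Change of variables: for an invertible map $y = f(x)$,

$$p_Y(y) = \left|\det \frac{\partial x}{\partial y}\right| p_X(x)$$

In one dimension this reads $p_Y(y) = \frac{1}{|f'(x)|}\, p_X(x)$,

or equivalently $p_Y(f(x))\,|f'(x)| = p_X(x)$.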

In [657]:
# TODO: come up with a better name.
# Conditioning of the joint distribution is assumed to happen entirely
# through the prior.
class Joint(nn.Module):
    """Joint distribution p(x, y) = p(x) * p(y | x) assembled from two parts.

    ``prior`` models the first ``prior.n_rvs`` random variables.
    ``conditional`` models the remaining ones, conditioned on the prior's
    variables. External conditioning (``cond``) is forwarded to the prior only.
    """

    def __init__(self, prior, conditional):
        super().__init__()
        self.prior = prior
        self.conditional = conditional
        # index separating the prior's variables from the conditional's
        self.split = prior.n_rvs
        self.n_rvs = prior.n_rvs + conditional.n_rvs

    def log_prob(self, *obs, cond=None):
        """log p(x, y | cond) = log p(x | cond) + log p(y | x)."""
        head = obs[:self.split]
        tail = obs[self.split:]
        log_p_head = self.prior(*head, cond=cond)
        log_p_tail = self.conditional(*tail, cond=head)
        return log_p_head + log_p_tail

    def forward(self, *obs, cond=None):
        """Alias of ``log_prob``."""
        return self.log_prob(*obs, cond=cond)

    def _sample_with(self, draw_name, sample_shape, cond):
        # Ancestral sampling: draw x from the prior, then y given x.
        # The conditional receives no sample_shape: it is already batched by x.
        head = getattr(self.prior, draw_name)(sample_shape=sample_shape, cond=cond)
        tail = getattr(self.conditional, draw_name)(cond=head)
        return turn_to_tuple(head) + turn_to_tuple(tail)

    def sample(self, sample_shape=torch.Size([]), cond=None):
        """Return a flat tuple of (prior samples..., conditional samples...)."""
        return self._sample_with('sample', sample_shape, cond)

    def rsample(self, sample_shape=torch.Size([]), cond=None):
        """Reparameterized version of ``sample``."""
        return self._sample_with('rsample', sample_shape, cond)
    

In [658]:
# conceptual template
class InvertibleTransform(nn.Module):
    def forward(self, x, logL=0):
        return y, logL + log_det_f_prime
    
    def reverse(self, y, logL=0):
        return x, logL - log_det_f_prime

In [659]:
class SequentialTransform(nn.Module):
    """Compose invertible transforms.

    ``forward`` applies the transforms in order. ``reverse`` undoes them in the
    opposite order, calling each transform's own ``reverse``.

    Each transform follows the ``(value, logL, cond=...) -> (value, logL)``
    convention of ``InvertibleTransform``, so log-Jacobian terms accumulate in
    ``logL``.
    """

    def __init__(self, *transforms):
        # nn.Module.__init__ must run before a ModuleList can be assigned
        super().__init__()
        self.transforms = nn.ModuleList(transforms)

    # TODO: check the content of nn.Module __call__ and create something similar for reverse
    def forward(self, x, logL=0, cond=None):
        for t in self.transforms:
            x, logL = t(x, logL, cond=cond)
        return x, logL

    def reverse(self, y, logL=0, cond=None):
        # Invert last-to-first, using each transform's inverse (not its forward).
        for t in reversed(list(self.transforms)):
            y, logL = t.reverse(y, logL, cond=cond)
        return y, logL

In [660]:
class FlowDistribution(nn.Module):
    """Normalizing-flow distribution: ``y = transform(x)`` with ``x ~ base_distribution``.

    ``transform`` provides two methods:

    * ``forward(x, cond=...) -> (y, logL)``,
    * ``reverse(y, cond=...) -> (x, logL)``, where ``logL`` is the log-Jacobian
      correction added to the base log-density.
    """

    def __init__(self, base_distribution, transform):
        super().__init__()
        self.base_distribution = base_distribution
        self.transform = transform

    def forward(self, *obs, cond=None):
        """Alias of ``log_prob``."""
        return self.log_prob(*obs, cond=cond)

    def log_prob(self, *obs, cond=None):
        """log p(y) = log p_base(transform^{-1}(y)) + logL returned by ``reverse``."""
        base_value, correction = self.transform.reverse(*obs, cond=cond)
        base_logp = self.base_distribution.log_prob(*turn_to_tuple(base_value), cond=cond)
        return base_logp + correction

    def _push_forward(self, base_samples, cond):
        # map base samples through the transform; the log-Jacobian is not needed here
        pushed, _unused_logL = self.transform(base_samples, cond=cond)
        return pushed

    def sample(self, sample_shape=torch.Size([]), cond=None):
        """Sample from the base distribution and push through the transform."""
        base_samples = self.base_distribution.sample(sample_shape=sample_shape, cond=cond)
        return self._push_forward(base_samples, cond)

    def rsample(self, sample_shape=torch.Size([]), cond=None):
        """Reparameterized version of ``sample``."""
        base_samples = self.base_distribution.rsample(sample_shape=sample_shape, cond=cond)
        return self._push_forward(base_samples, cond)

In [661]:
def ELBO_joint(joint, posterior, *obs, n_samples=1):
    """Monte-Carlo estimate of the evidence lower bound (ELBO) for ``obs``.

    ELBO = E_{z ~ q(z | x)}[log p(z, x) - log q(z | x)]

    Args:
        joint: p(z, x), called as ``joint(*z, *obs)`` with the latents first.
        posterior: approximate posterior q(z | x), conditioned on ``obs``.
        *obs: observed variables x.
        n_samples: number of reparameterized draws (``rsample``) of z, so
            gradients flow into the posterior's parameters.

    Returns:
        Per-draw ELBO terms (not averaged). Presumably the leading dimension
        has size ``n_samples``, with the exact shape depending on the
        distributions (TODO confirm); callers reduce it, e.g. with ``.mean()``.
    """
    z_samples = turn_to_tuple(posterior.rsample((n_samples,), cond=obs))
    # TODO: take care of case where KL is known for the posterior
    log_q = posterior(*z_samples, cond=obs)
    log_p = joint(*z_samples, *obs)
    # Out-of-place subtraction, so log_p may broadcast to a larger shape than
    # log_q (the previous in-place `+=` raised in that case).
    return log_p - log_q

def ELBO_parts(prior, conditional, posterior, *obs, n_samples=1):
    """ELBO for a model given as a separate ``prior`` p(z) and ``conditional`` p(x | z).

    Convenience wrapper: combines the two into ``Joint(prior, conditional)``
    and delegates to ``ELBO_joint``.
    """
    return ELBO_joint(Joint(prior, conditional), posterior, *obs, n_samples=n_samples)


In [662]:
class ELBOMarginal(nn.Module):
    """Marginal over the observed variables of a joint p(z, x), scored via the ELBO.

    ``joint`` models (z, x) with the latent variables first. ``posterior`` is
    an approximate posterior q(z | x) over the latents. Calling the module
    returns an ELBO estimate, which is a lower bound on log p(x), not log p(x)
    itself.
    """

    def __init__(self, joint, posterior, n_samples=1):
        super().__init__()
        self.joint = joint
        self.posterior = posterior
        # the observed variables are whatever the joint models beyond the latents
        self.n_rvs = joint.n_rvs - posterior.n_rvs
        self.n_samples = n_samples

    def forward(self, *obs, cond=None):
        """Alias of ``elbo``."""
        return self.elbo(*obs, cond=cond)

    def elbo(self, *obs, cond=None):
        """ELBO estimate for ``obs`` using ``self.n_samples`` posterior draws.

        NOTE: ``cond`` is currently ignored.
        """
        # TODO: deal with conditioning correctly
        return ELBO_joint(self.joint, self.posterior, *obs, n_samples=self.n_samples)

    def log_prob(self, *obs, cond=None):
        """Not implemented: exact log p(x) is not available from this model.

        Use ``elbo`` for a lower bound instead.

        Raises:
            NotImplementedError: always. Previously this silently returned None.
        """
        # TODO: let this be implemented as an "approximation" with ELBO
        # but with ample warnings
        raise NotImplementedError(
            'ELBOMarginal.log_prob is not implemented; use elbo() for a lower bound on log p(x)')

    def sample(self, sample_shape=torch.Size([]), cond=None):
        """Sample from the joint and keep only the trailing (observed) variables."""
        samples = self.joint.sample(sample_shape=sample_shape, cond=cond)
        return samples[-self.n_rvs:]

    def rsample(self, sample_shape=torch.Size([]), cond=None):
        """Reparameterized version of ``sample``."""
        samples = self.joint.rsample(sample_shape=sample_shape, cond=cond)
        return samples[-self.n_rvs:]


In [663]:
# TODO: consider implementing SurVAE.
class SurVAE(nn.Module):
    """Placeholder for a SurVAE model; intentionally empty for now."""

In [664]:
def turn_to_tuple(x):
    """Normalize ``x`` into a tuple of values.

    * ``None`` becomes the empty tuple ``()``.
    * A tuple is returned unchanged (the same object).
    * Any other value, including a list, is wrapped as ``(x,)``.
    """
    if x is None:
        return ()
    if isinstance(x, tuple):
        return x
    return (x,)

In [665]:
from warnings import warn

In [666]:
def make_args(x, *args, **kwargs):
    """Merge a parameter generator's output ``x`` into (args, kwargs).

    * A mapping (any ``collections.abc.Mapping``, not just ``dict``) is merged
      into ``kwargs``. Its entries override existing keys.
    * A tuple is prepended to ``args``.
    * Any other value is prepended to ``args`` as a single positional argument.

    Returns:
        ``(args, kwargs)``, ready for ``distribution_class(*args, **kwargs)``.
    """
    from collections.abc import Mapping

    if isinstance(x, Mapping):
        kwargs.update(x)
    elif isinstance(x, tuple):
        args = x + args
    else:
        args = (x,) + args

    return args, kwargs

In [667]:
class TransformedParameter(nn.Module):
    """A trainable tensor that is always read through a fixed transformation.

    The raw value is stored as ``self.parameter`` (an ``nn.Parameter``).
    Calling the module, or reading ``.value``, returns
    ``transform_fn(self.parameter)``. ``transform_fn`` defaults to the identity.
    Any call arguments (e.g. conditioning variables) are ignored.
    """

    def __init__(self, tensor, transform_fn=None):
        super().__init__()
        self.parameter = nn.Parameter(tensor)
        self.transform_fn = (lambda x: x) if transform_fn is None else transform_fn

    @property
    def value(self):
        """The transformed parameter, same as calling the module."""
        return self()

    def forward(self, *args):
        return self.transform_fn(self.parameter)
    

In [668]:
class Covariance(nn.Module):
    """Learnable covariance matrix ``A @ A.T + eps * I``.

    ``A`` is a trainable (n_dims, rank) matrix. ``A @ A.T`` is therefore
    symmetric positive semi-definite with rank at most ``rank``, and the
    ``eps * I`` jitter is meant to make the result positive definite.

    NOTE(review): with float32 parameters, the default eps=1e-16 is far below
    float32 resolution for diagonal entries of order 1. For ``rank < n_dims``
    the result may then still be numerically singular; pass a larger eps
    (e.g. 1e-6) in that case.
    """

    def __init__(self, n_dims, rank=None, eps=1e-16):
        super().__init__()
        if rank is None:
            rank = n_dims
        self.n_dims = n_dims
        self.rank = rank
        self.eps = eps
        self.A = nn.Parameter(torch.randn(n_dims, rank))

    def forward(self, *args):
        # Extra args (conditioning variables) are ignored. The identity is built
        # on A's device/dtype so the module also works after .to(device) / .double().
        jitter = torch.eye(self.n_dims, dtype=self.A.dtype, device=self.A.device) * self.eps
        return self.A @ self.A.T + jitter

    @property
    def value(self):
        """The current covariance matrix, same as calling the module."""
        return self()

In [669]:
# TODO: generalize this so that positiveness can arise from other functions
class PositiveDiagonal(nn.Module):
    """Learnable diagonal matrix ``diag(D**2 + eps)``.

    ``D`` is a trainable vector of length ``n_dims``. Squaring it and adding
    ``eps`` keeps every diagonal entry at least ``eps``. Call arguments are
    ignored.
    """

    def __init__(self, n_dims, eps=1e-16):
        super().__init__()
        self.n_dims = n_dims
        self.eps = eps
        self.D = nn.Parameter(torch.randn(n_dims))

    def forward(self, *args):
        diagonal = self.D.pow(2) + self.eps
        return torch.diag(diagonal)

    @property
    def value(self):
        """The current diagonal matrix, same as calling the module."""
        return self()

In [670]:
class ProbabilisticSIModel(nn.Module):
    """Wrap a model whose outputs parameterize a distribution.

    ``si_model`` is used as the ``_parameters`` generator of a
    ``TrainableDistributionAdapter``. Its output becomes (part of) the
    distribution's arguments, combined with ``dist_args`` / ``dist_kwargs``.
    """
    # TODO: for this to work, the TrainableDistributionAdapter handling of
    # _parameters must be expanded. Namely, it needs to be able to accept:
    # * positional arguments (to be implemented)
    # * dict (already implemented)
    # * dict with positional arguments (to be implemented)
    # 
    # To also make this generically useful, it would be helpful to
    # allow for output conversion function to be supplied. This function then
    # should transform outputs of the SI model into format appropriate
    # to serve as _parameters for the TrainableDistributionAdapter
    # It's important that such transformation does NOT warp the output
    # Doing so will distort the probability density!
    def __init__(self, si_model, distribution_class, *dist_args, **dist_kwargs):
        super().__init__()
        # Fixed typo (was `sielf`, which raised NameError). The same module is
        # also registered inside trainable_distribution; parameters() yields
        # shared parameters only once.
        self.si_model = si_model
        self.distribution_class = distribution_class
        self.trainable_distribution = TrainableDistributionAdapter(distribution_class, *dist_args,
                                                                   _parameters=si_model, **dist_kwargs)

    def log_prob(self, *obs, cond=None):
        """Log-density of ``obs`` given conditioning inputs ``cond``."""
        return self.trainable_distribution.log_prob(*obs, cond=cond)

    def forward(self, *obs, cond=None):
        """Alias of ``log_prob``."""
        return self.log_prob(*obs, cond=cond)

    def sample(self, sample_shape=torch.Size([]), cond=None):
        """Draw samples of shape ``sample_shape`` given ``cond``."""
        return self.trainable_distribution.sample(sample_shape=sample_shape, cond=cond)

    def rsample(self, sample_shape=torch.Size([]), cond=None):
        """Reparameterized version of ``sample``."""
        return self.trainable_distribution.rsample(sample_shape=sample_shape, cond=cond)

## Example usage: learn Normal distribution

### Learn only the mean

In [671]:
# Only the mean is learnable: loc is an nn.Parameter (trained), while
# scale is a plain tensor and is therefore stored as a fixed buffer.
normal = TrainableDistributionAdapter(Normal, 
                               loc=nn.Parameter(torch.Tensor([0.0])), 
                               scale=torch.Tensor([2.0]))

In [680]:
# scale appears here as a buffer even though it is not trainable
normal.state_dict()

OrderedDict([('loc', tensor([4.7673])), ('scale', tensor([2.]))])

In [681]:
# only loc is an nn.Parameter, so it is the only trainable parameter
list(normal.parameters())

[Parameter containing:
 tensor([4.7673], requires_grad=True)]

In [682]:
# set up the target normal distribution to learn (mean 5, std 2 -- same std as the model's fixed scale)
target = Normal(torch.Tensor([5.0]), torch.Tensor([2.0]))

In [683]:
# Adam over the adapter's trainable parameters (here only loc)
optim = Adam(normal.parameters(), lr=0.5)

In [684]:
# Maximum-likelihood fit: each step draws 100 fresh target samples and
# minimizes their mean negative log-likelihood under the model.
for i in range(1000):
    optim.zero_grad()
    targets = target.sample((100,))
    nlogp = -normal(targets).mean()
    mean = normal.loc.detach()  # current estimate, before this step's update
    if (i+1) % 100 == 0:
        print(f'Neg logP: {nlogp}, mean={mean}')
    nlogp.backward()
    optim.step()

Neg logP: 2.102768659591675, mean=tensor([5.0545])
Neg logP: 2.129380226135254, mean=tensor([5.4058])
Neg logP: 2.053159475326538, mean=tensor([4.9696])
Neg logP: 2.0966243743896484, mean=tensor([5.2391])
Neg logP: 2.118515968322754, mean=tensor([5.1150])
Neg logP: 2.1634061336517334, mean=tensor([4.9011])
Neg logP: 2.0691187381744385, mean=tensor([5.4793])
Neg logP: 2.030548572540283, mean=tensor([5.2033])
Neg logP: 2.1638779640197754, mean=tensor([4.7802])
Neg logP: 2.0723798274993896, mean=tensor([5.0352])


### Learn both mean and stdev

In [685]:
# make the mean & std both learnable
# std (scale) is the square of a trainable parameter, which keeps it non-negative
normal = TrainableDistributionAdapter(Normal, 
                               loc=nn.Parameter(torch.Tensor([0.0])), 
                               scale=TransformedParameter(torch.Tensor([1.0]), lambda x: x**2))

In [686]:
# set up the target normal distribution to learn (mean 5, std 8)
target = Normal(torch.Tensor([5]), torch.Tensor([8]))

In [687]:
# trains loc and the raw (pre-squaring) scale parameter
optim = Adam(normal.parameters(), lr=1.0)

In [688]:
# Maximum-likelihood fit of both loc and scale on fresh batches of 100 samples.
for i in range(1000):
    optim.zero_grad()
    targets = target.sample((100,))
    nlogp = -normal(targets).mean()
    mean = normal.loc.detach()
    std = normal.scale.value.detach()  # transformed (squared) value, i.e. the actual scale
    if (i+1) % 100 == 0:
        print(f'Neg logP: {nlogp}, mean={mean}, std={std}')
    nlogp.backward()
    optim.step()

Neg logP: 4.520414352416992, mean=tensor([6.4031]), std=tensor([35.6068])
Neg logP: 3.544624090194702, mean=tensor([5.6331]), std=tensor([11.0698])
Neg logP: 3.5978991985321045, mean=tensor([5.1668]), std=tensor([7.9170])
Neg logP: 3.4680280685424805, mean=tensor([5.0932]), std=tensor([7.9356])
Neg logP: 3.6048929691314697, mean=tensor([4.9296]), std=tensor([8.2085])
Neg logP: 3.5616402626037598, mean=tensor([5.2902]), std=tensor([8.0100])
Neg logP: 3.5436019897460938, mean=tensor([5.1700]), std=tensor([8.1878])
Neg logP: 3.4725141525268555, mean=tensor([5.0611]), std=tensor([8.1045])
Neg logP: 3.451047897338867, mean=tensor([5.1721]), std=tensor([8.2135])
Neg logP: 3.495211601257324, mean=tensor([5.1571]), std=tensor([7.8703])


### Learn both mean and stdev (same as above, but specified positionally)

In [689]:
# Make the mean & std both learnable, passed positionally: they are stored as
# _arg0 (loc) and _arg1 (scale). scale is the square of a trainable parameter (non-negative).
normal = TrainableDistributionAdapter(Normal, 
                               nn.Parameter(torch.Tensor([0.0])), 
                               TransformedParameter(torch.Tensor([1.0]), lambda x: x**2))

In [690]:
# set up the target normal distribution to learn (mean 5, std 8)
target = Normal(torch.Tensor([5]), torch.Tensor([8]))

In [691]:
# trains _arg0 (loc) and the raw parameter behind _arg1 (scale)
optim = Adam(normal.parameters(), lr=1.0)

In [692]:
# Same fit as above; positional arguments are accessed as _arg0 / _arg1.
for i in range(1000):
    optim.zero_grad()
    targets = target.sample((100,))
    nlogp = -normal(targets).mean()
    mean = normal._arg0.detach()
    std = normal._arg1.value.detach()  # transformed (squared) value
    if (i+1) % 100 == 0:
        print(f'Neg logP: {nlogp}, mean={mean}, std={std}')
    nlogp.backward()
    optim.step()

Neg logP: 4.446357250213623, mean=tensor([6.3663]), std=tensor([33.3451])
Neg logP: 3.469963312149048, mean=tensor([5.4749]), std=tensor([8.1169])
Neg logP: 3.5838122367858887, mean=tensor([4.9612]), std=tensor([8.1047])
Neg logP: 3.582165002822876, mean=tensor([5.0933]), std=tensor([8.0675])
Neg logP: 3.4827752113342285, mean=tensor([4.7392]), std=tensor([7.8857])
Neg logP: 3.4364471435546875, mean=tensor([4.9407]), std=tensor([7.9483])
Neg logP: 3.536595106124878, mean=tensor([4.9697]), std=tensor([7.9989])
Neg logP: 3.4074501991271973, mean=tensor([5.1103]), std=tensor([7.8625])
Neg logP: 3.4930973052978516, mean=tensor([4.9504]), std=tensor([7.6086])
Neg logP: 3.5675346851348877, mean=tensor([5.1123]), std=tensor([8.0731])


# Conditional case

## Simple conditioning

Now let us learn a more complex, conditional relationship $p(y|x)$. Specifically, let $p(y|x) = \mathcal{N}(f(x), \sigma^2)$.

For simplicity, we'll assume a simple linear mapping for $f(x)$

Prepare data generator

In [693]:
def get_batch(batch_size):
    """Draw ``batch_size`` (x, y) pairs from the ground-truth conditional model.

    x ~ Uniform(0, 1) with shape (batch_size, 1), and
    y | x ~ Normal(-5 x + 9, 7), with one draw per x so that y has shape (batch_size, 1).
    """
    x = torch.rand((batch_size, 1))
    mu = -5 * x + 9
    # sample() draws exactly one y per x; sample((batch_size,)) would return
    # a (batch_size, batch_size, 1) tensor because mu is already batched
    y = Normal(mu, scale=7).sample()
    return x, y

Set up the conditional network:

In [694]:
# Make the mean & std both learnable. loc is an nn.Linear, so it is evaluated
# on the conditioning input x. std (scale) is the abs of a parameter, which keeps it non-negative.
normal = TrainableDistributionAdapter(Normal, 
                               loc=nn.Linear(1, 1), 
                               scale=TransformedParameter(torch.Tensor([1.0]), torch.abs))

In [695]:
# trains the Linear loc map and the raw scale parameter
optim = Adam(normal.parameters(), lr=1.0)

In [696]:
# Conditional maximum-likelihood fit of p(y | x) on fresh batches.
for i in range(1000):
    optim.zero_grad()
    x, y = get_batch(100)
    nlogp = -normal(y, cond=x).mean()  # cond=x is fed to the nn.Linear loc
    params = {k:v.detach() for k,v in normal.state_dict().items()}
    if (i+1) % 100 == 0:
        print(f'Neg logP: {nlogp}, params={params}')
    nlogp.backward()
    optim.step()

Neg logP: 3.4149293899536133, params={'loc.weight': tensor([[1.2547]]), 'loc.bias': tensor([5.4446]), 'scale.parameter': tensor([7.9149])}
Neg logP: 3.379789352416992, params={'loc.weight': tensor([[-0.7773]]), 'loc.bias': tensor([6.6418]), 'scale.parameter': tensor([7.6350])}
Neg logP: 3.3700051307678223, params={'loc.weight': tensor([[-2.6216]]), 'loc.bias': tensor([7.6955]), 'scale.parameter': tensor([7.3514])}
Neg logP: 3.3667819499969482, params={'loc.weight': tensor([[-3.8872]]), 'loc.bias': tensor([8.3710]), 'scale.parameter': tensor([7.1556])}
Neg logP: 3.3841280937194824, params={'loc.weight': tensor([[-4.6002]]), 'loc.bias': tensor([8.7309]), 'scale.parameter': tensor([7.0508])}
Neg logP: 3.3676373958587646, params={'loc.weight': tensor([[-4.9019]]), 'loc.bias': tensor([8.9178]), 'scale.parameter': tensor([7.0084])}
Neg logP: 3.3664793968200684, params={'loc.weight': tensor([[-4.9266]]), 'loc.bias': tensor([8.9745]), 'scale.parameter': tensor([7.0037])}
Neg logP: 3.3688745498

## One network with multiple outputs (returning dict):

In [697]:
class NormalParams(nn.Module):
    """Map x to Normal parameters, returned as a dict of keyword arguments.

    A single Linear(1, 2) layer produces two output columns. The first column
    is ``loc`` and the square of the second is ``scale`` (non-negative; it may
    be exactly 0). Assumes x is 2-D, e.g. (batch, 1) -- TODO confirm.
    """

    def __init__(self):
        super().__init__()
        self.core = nn.Linear(1, 2)

    def forward(self, x):
        out = self.core(x)
        loc, raw_scale = out[:, :1], out[:, 1:]
        return {'loc': loc, 'scale': raw_scale ** 2}

In [698]:
def get_batch(batch_size):
    """Draw ``batch_size`` (x, y) pairs plus their ground-truth log-likelihood.

    x ~ Uniform(0, 1) with shape (batch_size, 1), and
    y | x ~ Normal(-5 x + 9, 3 x + 1), with one draw per x.

    Returns:
        ``(x, y, log p(y | x))``, all of shape (batch_size, 1).
    """
    x = torch.rand((batch_size, 1))
    mu = -5 * x + 9
    scale = (3 * x + 1)
    model = Normal(mu, scale=scale)
    # sample() draws exactly one y per x; sample((batch_size,)) would return
    # a (batch_size, batch_size, 1) tensor because mu is already batched
    y = model.sample()
    return x, y, model.log_prob(y)

In [699]:
# Both loc and scale are produced from x by the NormalParams network. It
# returns a dict, which is passed to Normal as keyword arguments.
# scale is the square of the network's second output (non-negative).
normal = TrainableDistributionAdapter(Normal, _parameters=NormalParams())

In [700]:
# trains the NormalParams network (registered as parameter_generator)
optim = Adam(normal.parameters(), lr=1)

In [702]:
# Conditional fit. The printed gt value is the mean NLL under the true model (reference).
for i in range(1000):
    optim.zero_grad()
    x, y, true_logp = get_batch(100)
    nlogp = -normal(y, cond=x).mean()
    params = {k:v.detach() for k,v in normal.state_dict().items()}
    if i % 100 == 0:
        print(f'Neg logP: {nlogp:0.3f} (gt={-true_logp.mean():0.3f}), params={params}')
    nlogp.backward()
    optim.step()

Neg logP: 2.214 (gt=2.213), params={'parameter_generator.core.weight': tensor([[-5.0274],
        [-0.9990]]), 'parameter_generator.core.bias': tensor([ 9.0434, -1.0585])}
Neg logP: 2.227 (gt=2.225), params={'parameter_generator.core.weight': tensor([[-4.9925],
        [-0.9860]]), 'parameter_generator.core.bias': tensor([ 8.9737, -1.0683])}
Neg logP: 2.324 (gt=2.324), params={'parameter_generator.core.weight': tensor([[-5.0019],
        [-1.0087]]), 'parameter_generator.core.bias': tensor([ 9.0111, -1.0616])}
Neg logP: 2.262 (gt=2.260), params={'parameter_generator.core.weight': tensor([[-4.9992],
        [-0.9894]]), 'parameter_generator.core.bias': tensor([ 9.0099, -1.0683])}
Neg logP: 2.291 (gt=2.290), params={'parameter_generator.core.weight': tensor([[-5.0182],
        [-1.0027]]), 'parameter_generator.core.bias': tensor([ 8.9757, -1.0533])}
Neg logP: 2.226 (gt=2.224), params={'parameter_generator.core.weight': tensor([[-4.9768],
        [-1.0186]]), 'parameter_generator.core.bia

## One network with multiple outputs (returning positionally):

In [703]:
class NormalParams(nn.Module):
    """Map x to Normal parameters, returned positionally as ``(loc, scale)``.

    A single Linear(1, 2) layer produces two output columns. The first column
    is ``loc`` and the square of the second is ``scale`` (non-negative; it may
    be exactly 0). Assumes x is 2-D, e.g. (batch, 1) -- TODO confirm.
    """

    def __init__(self):
        super().__init__()
        self.core = nn.Linear(1, 2)

    def forward(self, x):
        out = self.core(x)
        loc = out[:, :1]
        scale = out[:, 1:] ** 2
        return loc, scale

In [704]:
def get_batch(batch_size):
    """Draw ``batch_size`` (x, y) pairs plus their ground-truth log-likelihood.

    x ~ Uniform(0, 1) with shape (batch_size, 1), and
    y | x ~ Normal(-5 x + 9, 3 x + 1), with one draw per x.

    Returns:
        ``(x, y, log p(y | x))``, all of shape (batch_size, 1).
    """
    x = torch.rand((batch_size, 1))
    mu = -5 * x + 9
    scale = (3 * x + 1)
    model = Normal(mu, scale=scale)
    # sample() draws exactly one y per x; sample((batch_size,)) would return
    # a (batch_size, batch_size, 1) tensor because mu is already batched
    y = model.sample()
    return x, y, model.log_prob(y)

In [705]:
# Both loc and scale are produced from x by the NormalParams network. It
# returns a tuple, which becomes Normal's leading positional arguments (loc, scale).
# scale is the square of the network's second output (non-negative).
normal = TrainableDistributionAdapter(Normal, _parameters=NormalParams())

In [706]:
# trains the NormalParams network (registered as parameter_generator)
optim = Adam(normal.parameters(), lr=1)

In [707]:
# Conditional fit. The printed gt value is the mean NLL under the true model (reference).
for i in range(1000):
    optim.zero_grad()
    x, y, true_logp = get_batch(100)
    nlogp = -normal(y, cond=x).mean()
    params = {k:v.detach() for k,v in normal.state_dict().items()}
    if i % 100 == 0:
        print(f'Neg logP: {nlogp:0.3f} (gt={-true_logp.mean():0.3f}), params={params}')
    nlogp.backward()
    optim.step()

Neg logP: 44.716 (gt=2.301), params={'parameter_generator.core.weight': tensor([[-0.5920],
        [ 0.8770]]), 'parameter_generator.core.bias': tensor([0.3602, 0.6186])}
Neg logP: 5.115 (gt=2.217), params={'parameter_generator.core.weight': tensor([[5.5234],
        [4.9245]]), 'parameter_generator.core.bias': tensor([6.5077, 5.9420])}
Neg logP: 3.884 (gt=2.244), params={'parameter_generator.core.weight': tensor([[5.0220],
        [0.1505]]), 'parameter_generator.core.bias': tensor([6.3939, 4.2745])}
Neg logP: 2.596 (gt=2.285), params={'parameter_generator.core.weight': tensor([[-1.9150],
        [-0.6146]]), 'parameter_generator.core.bias': tensor([6.8716, 2.2467])}
Neg logP: 2.265 (gt=2.263), params={'parameter_generator.core.weight': tensor([[-5.0224],
        [ 1.0395]]), 'parameter_generator.core.bias': tensor([8.9794, 1.0392])}
Neg logP: 2.344 (gt=2.343), params={'parameter_generator.core.weight': tensor([[-5.0002],
        [ 1.0129]]), 'parameter_generator.core.bias': tensor([8

# Test sampling and joint

In [708]:
# prior p(x) = N(5, 2^2), with a learnable loc initialized at 5
prior = TrainableDistributionAdapter(Normal, loc=nn.Parameter(torch.Tensor([5])),
                                     scale=torch.Tensor([2]))

# conditional p(y | x) = N(-2 x + 6, 1), with the linear map fixed by hand
linear = nn.Linear(1, 1)
linear.weight.data = torch.Tensor([[-2]])
linear.bias.data = torch.Tensor([6])
conditional = TrainableDistributionAdapter(Normal, loc=linear, scale=torch.Tensor([1]))

# create a joint distribution out of prior and conditional
joint = Joint(prior, conditional)

In [709]:
# ancestral sampling: 10000 draws of x from the prior, then one y per x from p(y | x)
x, y = joint.sample((10000,))

Should be close to mean 5 and standard deviation 2 (the prior $\mathcal{N}(5, 2^2)$).

In [710]:
# empirical mean and std of the prior samples
x.mean(), x.std()

(tensor(4.9870), tensor(2.0324))

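Should be close to mean $-4$ and standard deviation $\sqrt{17} \approx 4.12$. Since $y = -2x + 6 + \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$, we have $\mathbb{E}[y] = -2 \cdot 5 + 6 = -4$ and $\mathrm{Var}[y] = 4 \cdot 2^2 + 1 = 17$.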

In [711]:
# empirical mean and std of the marginal of y
y.mean(), y.std()

(tensor(-3.9677), tensor(4.1786))

# Simulating a multi-dimensional joint distribution

### Define the ground-truth generative model

In [712]:
# Ground-truth prior: z ~ MVN(mean = ones, covariance = I) over 5 latents.
# The second positional argument of MultivariateNormal is covariance_matrix.
n_latents = 5
gt_prior = TrainableDistributionAdapter(MultivariateNormal, 
                                        torch.ones([n_latents]),
                                        torch.eye(n_latents))

# random linear feature map from latents to 10 observed dimensions
n_obs = 10
features = torch.randn([n_latents, n_obs])

feature_map = nn.Linear(n_latents, n_obs)
feature_map.weight.data = features.T
feature_map.bias.data.zero_()

# ground-truth conditional: x | z ~ MVN(features^T z, I)
gt_conditional = TrainableDistributionAdapter(MultivariateNormal,
                                              feature_map,
                                              torch.eye(n_obs))

gt_joint = Joint(gt_prior, gt_conditional)

### Now prepare a trainable model

class Covariance(nn.Module):
    def __init__(self, n_dim):
        

In [713]:
# Trainable prior: learnable mean (init 0) and learnable diagonal covariance.
n_latents = 5
model_prior = TrainableDistributionAdapter(MultivariateNormal, 
                                        nn.Parameter(torch.zeros([n_latents])),
                                        PositiveDiagonal(n_latents))

n_obs = 10
# NOTE: `features` is unused below (the assignment is commented out). It still
# consumes RNG draws and rebinds the name; gt_conditional keeps its own feature_map.
features = torch.randn([n_latents, n_obs])

# freshly initialized, learnable feature map
feature_map = nn.Linear(n_latents, n_obs)
#feature_map.weight.data = features.T
#feature_map.bias.data.zero_()

# trainable conditional with a fixed identity covariance (a buffer)
model_conditional = TrainableDistributionAdapter(MultivariateNormal,
                                              feature_map,
                                              torch.eye(n_obs))

model_joint = Joint(model_prior, model_conditional)

### Go ahead and train the joint model

In [714]:
# all trainable parameters of the model joint: prior mean, prior diagonal D, feature map
optim = Adam(model_joint.parameters(), lr=0.5)

In [716]:
# Fit the model joint to fresh ground-truth samples (z, x) by maximum likelihood.
# The ground-truth log-likelihood is printed for reference only.
for i in range(1000):
    optim.zero_grad()
    samples = gt_joint.sample((100,))
    gt_logl = gt_joint.log_prob(*samples).mean()
    model_logl = model_joint.log_prob(*samples).mean()
    if i % 100 == 0:
        print(f'Model logp: {model_logl:.3f} / GT logp: {gt_logl:.3f}')
    (-model_logl).backward()
    optim.step()

Model logp: -22.338 / GT logp: -20.972
Model logp: -22.691 / GT logp: -21.712
Model logp: -22.532 / GT logp: -21.455
Model logp: -22.455 / GT logp: -21.301
Model logp: -22.525 / GT logp: -21.508
Model logp: -22.324 / GT logp: -21.521
Model logp: -22.413 / GT logp: -21.539
Model logp: -21.608 / GT logp: -20.863
Model logp: -22.002 / GT logp: -21.358
Model logp: -22.456 / GT logp: -21.556


Checking the learned parameters

In [717]:
# learned prior covariance: PositiveDiagonal returns diag(D**2 + eps)
model_joint.prior._arg1()

tensor([[1.2909, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.7904, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.3943, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.9610, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.7481]], grad_fn=<DiagBackward0>)

In [718]:
# raw learned prior parameters: _arg0 is the mean; _arg1.D holds the diagonal's square roots (up to sign, ignoring eps)
model_joint.prior.state_dict()

OrderedDict([('_arg0', tensor([0.8104, 1.0067, 0.9980, 0.9421, 0.9557])),
             ('_arg1.D',
              tensor([ 1.1362,  0.8890,  1.1808,  0.9803, -0.8649]))])

In [719]:
# learned conditional: _arg0 is the Linear feature map, _arg1 the fixed identity covariance buffer
model_joint.conditional.state_dict()

OrderedDict([('_arg1',
              tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                      [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
                      [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
                      [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
                      [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
                      [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
                      [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])),
             ('_arg0.weight',
              tensor([[ 0.5283, -0.6021,  1.1344, -0.6800, -0.1336],
                      [ 1.0862, -0.6126,  0.3790, -1.7255,  0.7884],
                      [-0.8417,  0.0308, -0.9640,  1.5167, -1.3290],
                      [-0.9635,  1.1864, -0.8965,  0.8684, -1.0919],
                      [ 0.78

## Compute the posterior distribution via ELBO

First, we'll train the posterior for the ground-truth model, using ELBO

In [720]:
# Prepare an approximate posterior q(z | x) = MVN(Linear(x), A A^T + eps I).
# NOTE: `features` is not used below; it only consumes RNG draws.
features = torch.randn([n_obs, n_latents])

linear_map = nn.Linear(n_obs, n_latents)

posterior = TrainableDistributionAdapter(MultivariateNormal,
                                              linear_map,
                                              Covariance(n_latents, rank=5, eps=1e-6))


In [721]:
# ELBO-based marginal over x for the ground-truth joint, using `posterior` as q(z | x)
elbo_x = ELBOMarginal(gt_joint, posterior)

# only training the posterior (the ground-truth joint stays fixed)
optim = Adam(posterior.parameters(), lr=1e-2)

In [722]:
# a fixed training set of 1000 observations (z_sample is not used for training)
z_sample, x_sample = gt_joint.sample((1000,))

# Maximize the ELBO w.r.t. the posterior only.
# NOTE(review): backward() also accumulates gradients on gt_joint's nn.Linear
# parameters. They are not in this optimizer and are never zeroed here.
for i in range(1000):
    optim.zero_grad()
    elbo = elbo_x(x_sample).mean()
    if i % 100 == 0:
        print(f'Model elbo: {elbo:.3f}')
    (-elbo).backward()
    optim.step()

Model elbo: -267.320
Model elbo: -42.153
Model elbo: -22.208
Model elbo: -19.550
Model elbo: -19.352
Model elbo: -19.348
Model elbo: -19.341
Model elbo: -19.347
Model elbo: -19.351
Model elbo: -19.351


### Evaluate the posterior
Here we roughly evaluate how good the learned posterior is. We sample $\hat{z}$ from the approximate posterior $q(z|x)$ and compute the average $\log p(\hat{z}, x)$ under the ground-truth joint.
If $\hat{z}$ approximately follows the true posterior $p(z|x)$, the average $\log p(\hat{z}, x)$ will closely approximate the expected $\log p(z, x)$, i.e. the negative entropy of the joint, $-H(z, x)$.

In [723]:
# fresh evaluation samples from the ground-truth joint
z_sample, x_sample = gt_joint.sample((1000,))

In [724]:
# ground-truth negative entropy: Monte-Carlo estimate of E[log p(z, x)] over true samples
gt_joint.log_prob(z_sample, x_sample).mean()

tensor(-21.1714, grad_fn=<MeanBackward0>)

In [725]:
# sample from the trained posterior (one z_hat per observed x)
z_hat = posterior.sample(cond=x_sample)

# Score (z_hat, x) under the ground-truth joint.
# negative entropy of the approximation -- should be close to the value above
gt_joint.log_prob(z_hat, x_sample).mean()

tensor(-21.2454, grad_fn=<MeanBackward0>)

## Compute the posterior distribution via direct fit to samples

Now, we'll train the posterior for the ground-truth model by training directly on the samples

In [732]:
# Prepare a fresh approximate posterior q(z | x) = MVN(Linear(x), A A^T + eps I).
# NOTE: `features` is not used below; it only consumes RNG draws.
features = torch.randn([n_obs, n_latents])

linear_map = nn.Linear(n_obs, n_latents)

posterior = TrainableDistributionAdapter(MultivariateNormal,
                                              linear_map,
                                              Covariance(n_latents, rank=5, eps=1e-6))


In [733]:
# only training the posterior (the ground-truth joint stays fixed)
optim = Adam(posterior.parameters(), lr=1e-2)

In [734]:

# Fit q(z | x) directly by maximum likelihood on fresh (z, x) pairs from the ground truth.
for i in range(1000):
    optim.zero_grad()
    z_sample, x_sample = gt_joint.sample((100,))
    logp = posterior.log_prob(z_sample, cond=x_sample).mean()
    if i % 100 == 0:
        print(f'LogP: {logp:.3f}')
    (-logp).backward()
    optim.step()

LogP: -34.669
LogP: -5.655
LogP: -4.524
LogP: -3.645
LogP: -2.407
LogP: -2.074
LogP: -2.334
LogP: -1.957
LogP: -1.962
LogP: -2.228


### Evaluate the posterior
Here we roughly evaluate how good the learned posterior is. We sample $\hat{z}$ from the approximate posterior $q(z|x)$ and compute the average $\log p(\hat{z}, x)$ under the ground-truth joint.
If $\hat{z}$ approximately follows the true posterior $p(z|x)$, the average $\log p(\hat{z}, x)$ will closely approximate the expected $\log p(z, x)$, i.e. the negative entropy of the joint, $-H(z, x)$.

In [735]:
# fresh evaluation samples from the ground-truth joint
z_sample, x_sample = gt_joint.sample((1000,))

In [736]:
# ground-truth negative entropy: Monte-Carlo estimate of E[log p(z, x)] over true samples
gt_joint.log_prob(z_sample, x_sample).mean()

tensor(-21.3203, grad_fn=<MeanBackward0>)

In [737]:
# sample from the trained posterior (one z_hat per observed x)
z_hat = posterior.sample(cond=x_sample)

# Score (z_hat, x) under the ground-truth joint.
# negative entropy of the approximation -- should be close to the value above
gt_joint.log_prob(z_hat, x_sample).mean()

tensor(-21.5950, grad_fn=<MeanBackward0>)