In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.nn import Parameter
from torch.autograd import Variable

from math import pi, log
from BBSVI import SVI

In [2]:
# Synthetic observations: num_samples draws from Normal(mu, 1).
# Seed BOTH NumPy (used here for the data) and PyTorch (used by the
# reparameterised draws in VariationalDistribution.sample_global during SVI)
# so that Restart & Run All reproduces the printed results.
mu = 2.5
np.random.seed(42)
torch.manual_seed(42)
num_samples = 1000
# torch.tensor with an explicit dtype replaces the legacy torch.Tensor(ndarray)
# constructor; casting to the default dtype keeps the exact same float32 values.
data = torch.tensor(np.random.normal(mu, 1, size=num_samples),
                    dtype=torch.get_default_dtype())

In [3]:
class Prior:
    """Generative model for the normal-mean example.

    beta ~ Normal(prior_mu, prior_sigma)        (global latent)
    x_i | beta ~ Normal(beta, noise_sigma)      (observations)

    The defaults (0, 1, 1) reproduce the original hard-coded model. Note that
    the closed-form posterior computed elsewhere in the notebook assumes these
    defaults.

    Args:
        prior_mu: mean of the prior over beta (scalar).
        prior_sigma: standard deviation of the prior over beta (scalar, > 0).
        noise_sigma: standard deviation of each observation given beta (scalar, > 0).
    """

    def __init__(self, prior_mu=0.0, prior_sigma=1.0, noise_sigma=1.0):
        self.prior = torch.distributions.Normal(
            torch.full((1,), float(prior_mu)),
            torch.full((1,), float(prior_sigma)),
        )
        # Observation-noise scale, built once rather than on every likelihood call.
        self.noise_sigma = torch.full((1,), float(noise_sigma))

    def log_likelihood_global(self, beta):
        """Return log p(beta) under the prior, elementwise."""
        return self.prior.log_prob(beta)

    def log_likelihood_joint(self, x, z, beta):
        """Return log p(x | beta), elementwise over x (broadcast against beta).

        ``z`` is accepted for interface compatibility and is unused here.
        """
        cond = torch.distributions.Normal(beta, self.noise_sigma)
        return cond.log_prob(x)


In [4]:
class VariationalDistribution:
    """Gaussian variational family q(beta) = Normal(mu, sigma) over the global latent.

    ``mu`` and ``sigma`` are trainable leaf tensors. ``parameters`` lists them
    for the optimiser. This guide has no local latent variables.
    """

    def __init__(self):
        self.mu = nn.Parameter(torch.zeros(1), requires_grad=True)
        # sigma is optimised unconstrained. If an update drives it to <= 0, the
        # distribution construction in `distr` raises a ValueError instead of
        # silently producing NaN entropies / sign-flipped samples.
        self.sigma = nn.Parameter(torch.ones(1), requires_grad=True)
        self.parameters = [self.mu, self.sigma]

    @property
    def distr(self):
        """Normal(mu, sigma) built from the *current* parameter values.

        Rebuilt on every access, so argument validation (scale > 0) re-runs
        after each optimiser step. The distribution therefore cannot go stale
        if ``mu`` or ``sigma`` is reassigned.
        """
        return torch.distributions.Normal(self.mu, self.sigma, validate_args=True)

    def sample_global(self):
        """Return one reparameterised draw of beta (differentiable w.r.t. mu and sigma)."""
        return self.distr.rsample()

    def sample_local(self, beta, idx):
        """Return None: this guide has no local latent variables."""
        return None

    def entropy(self):
        """Return the entropy of q(beta), differentiable w.r.t. sigma."""
        return self.distr.entropy()

In [5]:
# Closed-form posterior for the conjugate model beta ~ N(0, 1), x_i | beta ~ N(beta, 1):
# the posterior precision is 1 (prior) plus one unit per observation, and the
# posterior mean is the sum of the observations divided by that precision.
posterior_precision = 1 + num_samples
analytical_mu = data.sum() / posterior_precision
analytical_sigma = np.sqrt(1 / posterior_precision)

In [6]:
# Wire together the variational family, the generative model and the optimiser
# over the variational parameters, then hand everything to the black-box SVI driver.
var = VariationalDistribution()
prior = Prior()
opt = torch.optim.Adam(params=var.parameters, lr=1e-3)
svi = SVI(data, prior, var, opt)

In [7]:
num_steps = 50
# One value per step, ramping linearly from 0 to 1. Presumably an annealing
# weight; its exact semantics live in BBSVI.SVI (TODO confirm).
# torch.tensor with the default dtype replaces the legacy torch.Tensor(ndarray)
# constructor and yields the same float32 values.
discounter_schedule = torch.tensor(np.linspace(0, 1, num_steps),
                                   dtype=torch.get_default_dtype())
# Report the actual prior rather than the variational initialisation; the two
# only coincide because VariationalDistribution happens to start at N(0, 1).
print('Prior params:                \t mu=%.2f \t sigma=%.2f'
      % (prior.prior.loc.item(), prior.prior.scale.item()))
svi.make_inference(num_steps=num_steps, loss='entropy', print_progress=True,
                   discounter_schedule=discounter_schedule)
print('VI Posterior params:         \t mu=%.2f \t sigma=%.2f'
      % (var.mu.item(), var.sigma.item()))
print('Analytical Posterior params: \t mu=%.2f \t sigma=%.2f'
      % (float(analytical_mu), float(analytical_sigma)))

Prior params:                	 mu=0.00 	 sigma=1.00
..................................................
VI Posterior params:         	 mu=2.50 	 sigma=0.02
Analytical Posterior params: 	 mu=2.52 	 sigma=0.03
