In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.nn import Parameter
from torch.autograd import Variable

from math import pi, log

In [2]:
class SVI():
    '''Class for stochastic variational inference
    
    '''
    
    def __init__(self, joint_distr, var_distr, opt):
        '''Initialization
        
        Args:
            joint_distr: requred methods: joint_logpdf(x, z)
            var_distr: required methods: logpdf(z), sample()
            opt: optimizer
        
        '''
        
        self.joint_distr = joint_distr
        self.var_distr = var_distr
        self.opt = opt
    
    def mc_elbo_(self, num_samples, batch):
        '''Computing naive Monte Carlo approximation of ELBO over batch
        
        Args:
            num_samples: number of samples used for approximation
            batch: batch of observed variables
        
        Returns:
            mc_elbo: naive Monte Carlo approximation of ELBO over batch
        
        '''
        
        mc_elbo = Variable(torch.zeros(1))
        
        for i in range(num_samples):
            z = self.var_distr.sample()
            mc_elbo += self.joint_distr.joint_logpdf(batch, z) - self.var_distr.logpdf(z)
                
        mc_elbo /= num_samples
        
        return mc_elbo
    
    def make_inference(self, num_steps, batch_iterator, num_samples=10, print_progress=False):
        '''Making SVI
        
        Args:
            num_steps: number of epoches
            batch_iterator: iterator over batches
            num_samples: number of samples used for ELBO approximation
            print_progress: boolean, if True progress bar is printed
            
        '''
        
        for step in range(num_steps):
            batch_elbo = []
            for batch in batch_iterator:
                loss = self.mc_elbo_(num_samples, batch)
                batch_elbo.append(float(loss))
                (-loss).backward()
                opt.step()
            opt.zero_grad()
            
            if print_progress:
                if (int(25 * step / num_steps) != int(25 * (step - 1) / num_steps)):
                    print('.', end='')
                    

**Пример:**

Мы воспроизводим простейший пример вариационного вывода, приведенный в [документации Pyro](http://pyro.ai/examples/svi_part_i.html#A-simple-example)

Есть монетка, вероятность выпадения единицы равна $p$. Дан (неупорядоченный) набор результатов подбрасывания монетки. Хотим по нему оценить $p$. Сопряженное к распределению Бернулли -- бета-распределение, поэтому мы будем считать, что априорное распределение на $p$ -- бета-распределение с параметрами (10, 10).

В качестве априорного распределения на $p$ мы берем бета-распределение с параметрами 10 и 10. В таком случае апостериорное распределение тоже будет бета-распределением, но с другими параметрами. Потому в качестве вариационного распределения мы также берем бета-распределение.

**Пример не работает на стабильной версии PyTorch, т.к. на момент ее релиза еще не был реализован проброс градиента через гамма-функцию**

In [3]:
class JointDistribution():
    
    def __init__(self, prior):
        self.parameters = prior.parameters
        self.prior = prior
    
    def cond_logpdf(self, x, z):
        return torch.sum(torch.log(z * x + (1 - z) * (1 - x)))
    
    def joint_logpdf(self, x, z):
        return self.prior.logpdf(z) + self.cond_logpdf(x, z)
   


In [4]:
class Beta():
    
    def __init__(self, alpha=1., beta=1.):
        self.log_alpha = Parameter(torch.log(torch.FloatTensor([alpha])))
        self.log_beta = Parameter(torch.log(torch.FloatTensor([beta])))
        self.parameters = [self.log_alpha, self.log_beta]
        
    def logpdf(self, x):
        eps = 1e-8
        logbeta = torch.lgamma(torch.exp(torch.clamp(self.log_alpha, -10, 10))) + \
                  torch.lgamma(torch.exp(torch.clamp(self.log_beta, -10, 10))) - \
                  torch.lgamma(torch.exp(torch.clamp(self.log_alpha, -10, 10)) + \
                               torch.exp(torch.clamp(self.log_beta, -10, 10)))
        y = torch.clamp(x, eps, 1 - eps)
        alpha_term = (torch.exp(torch.clamp(self.log_alpha, -10, 10)) - 1) * torch.log(y)
        beta_term = (torch.exp(torch.clamp(self.log_beta, -10, 10)) - 1) * torch.log(1 - y)
        return alpha_term + beta_term - logbeta
    
    def sample(self):
        return Variable(torch.FloatTensor(np.random.beta(torch.exp(torch.clamp(self.log_alpha, -10, 10)).data.numpy(),
                                                         torch.exp(torch.clamp(self.log_beta, -10, 10)).data.numpy())),
                        requires_grad=False)


In [5]:
var = Beta(15, 15)
prior = Beta(10, 10)
joint = JointDistribution(prior)

In [6]:
data = [[1.]] * 6 + [[0.]] * 4
data = Variable(torch.FloatTensor(data))
data = data.view(1, 10)

In [7]:
opt = torch.optim.Adam(var.parameters, lr=5e-3)

In [8]:
svi = SVI(joint, var, opt)

In [9]:
svi.make_inference(4000, data, 25, print_progress=True)

........................

In [10]:
# grab the learned variational parameters
alpha_q = torch.exp(var.log_alpha).data.numpy()[0]
beta_q = torch.exp(var.log_beta).data.numpy()[0]

# here we use some facts about the beta distribution
# compute the inferred mean of the coin's fairness
inferred_mean = alpha_q / (alpha_q + beta_q)
# compute inferred standard deviation
factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
inferred_std = inferred_mean * np.sqrt(factor)

print("\nbased on the data and our prior belief, the fairness " +
      "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))


based on the data and our prior belief, the fairness of the coin is 0.576 +- 0.091


In [11]:
var = Beta(15, 15)
prior = Beta(10, 10)
joint = JointDistribution(prior)

In [12]:
data = [[0.]] * 9 + [[1.]] * 1
data = Variable(torch.FloatTensor(data))
data = data.view(1, 10)

In [13]:
opt = torch.optim.Adam(var.parameters, lr=1e-2)

In [14]:
svi = SVI(joint, var, opt)

In [15]:
svi.make_inference(2000, data, 10, print_progress=True)

........................

In [16]:
# grab the learned variational parameters
alpha_q = torch.exp(var.log_alpha).data.numpy()[0]
beta_q = torch.exp(var.log_beta).data.numpy()[0]

# here we use some facts about the beta distribution
# compute the inferred mean of the coin's fairness
inferred_mean = alpha_q / (alpha_q + beta_q)
# compute inferred standard deviation
factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
inferred_std = inferred_mean * np.sqrt(factor)

print("\nbased on the data and our prior belief, the fairness " +
      "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))


based on the data and our prior belief, the fairness of the coin is 0.321 +- 0.081
