In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.nn import Parameter
from torch.autograd import Variable

from math import pi, log
# from BBSVI import SVI

In [102]:
import torch
import numpy as np


class SVI():
    '''Class for black box stochastic variational inference
        https://arxiv.org/abs/1401.0118
    
    '''
    
    def __init__(self, data, prior_distr, var_distr, opt, scheduler=None):
        '''Initialization
        
        Args:
            data: oserved data
            prior_distr: class for prior probabilistic model
                requred methods: log_likelihood_global(beta)
                                 log_likelihood_local(z, beta)
                                 log_likelihood_joint(x, z, beta)
            var_distr: class for variational distribution
                required methods: log_likelihood_global(beta)
                                  log_likelihood_local(z, idx)
                                  sample_global()
                                  sample_local(idx)
            opt: optimizer
            scheduler: scheduler for an optimizer
        
        '''
        
        self.data = data
        self.prior_distr = prior_distr
        self.var_distr = var_distr
        self.opt = opt
        self.scheduler = scheduler      

    def bb1_loss_(self, num_samples, batch_indices):
        '''Computing loss of BB SVI 1
        
        Args:
            num_samples: number of samples used for approximation
            batch_indices: indices of batch
        
        Returns:
            loss: Black Box loss function
        
        '''
        
        global_loss = torch.zeros(1, requires_grad=True)
        local_loss = torch.zeros(1, requires_grad=True)
        
        for i in range(num_samples):
            
            beta = self.var_distr.sample_global()
            global_const_term = torch.zeros(1, requires_grad=False)
            
            for idx in batch_indices:
                x = self.data[idx]
                z = self.var_distr.sample_local(beta, idx)
                
                local_const_term = self.prior_distr.log_likelihood_local(z, beta) + \
                                   self.prior_distr.log_likelihood_joint(x, z, beta) - \
                                   self.var_distr.log_likelihood_local(z, idx)
                
                local_var_term = self.var_distr.log_likelihood_local(z, idx)
                
                local_loss = local_loss + local_var_term * local_const_term.data
                
                global_const_term += self.prior_distr.log_likelihood_local(z, beta) + \
                                     self.prior_distr.log_likelihood_joint(x, z, beta)
            
            global_const_term *= self.data.shape[0] / batch_indices.shape[0]
            global_const_term += self.prior_distr.log_likelihood_global(beta) - \
                                 self.var_distr.log_likelihood_global(beta)
            
            global_var_term = self.var_distr.log_likelihood_global(beta)
            global_loss = global_loss + global_var_term * global_const_term.data
                    
        loss = -(global_loss + local_loss) / num_samples
        
        return loss
    
    def bb2_loss_(self, num_samples, batch_indices):
        '''Computing loss of BB SVI 2, which has lower variance compare to BB SVI 1
        
        Args:
            num_samples: number of samples used for approximation
            batch_indices: indices of batch
        
        Returns:
            loss: Black Box loss function
        
        '''
        
        loss_global = torch.zeros(1, requires_grad=True)
        loss_local = torch.zeros(1, requires_grad=True)
        
        global_samples = [self.var_distr.sample_global() for _ in range(num_samples)]
        a = torch.autograd.grad(self.var_distr.log_likelihood_global(global_samples[0]), 
                                self.var_distr.parameters, 
                                allow_unused=True)
        a = [torch.zeros(1) if x is None else x for x in a]
        
        b = torch.autograd.grad(self.var_distr.log_likelihood_global(global_samples[1]), 
                                self.var_distr.parameters, 
                                allow_unused=True)
        b = [torch.zeros(1) if x is None else x for x in b]
        c = tuple(map(operator.add, a, b))
        print(a)
        print(b)
        print(c)
        raise Exception
        for idx in batch_indices:
            pass

    def make_inference(self, num_steps=100, num_samples=10, batch_size=10, shuffle=False, print_progress=True):
        '''Making SVI
        
        Args:
            num_steps: int, maximum number of epoches
            tol: required tolerance
            num_samples: int, number of samples used for ELBO approximation
            batch_size: int, size of one batch
            shuffle: boolean, if batch is shuffled every epoch or not
            print_progress: boolean, if True then progrss bar is printed
            
        '''
        
        
        for step in range(num_steps):
            
            if shuffle:
                indices = np.random.choice(self.data.shape[0], self.data.shape[0], False)
            else:
                indices = np.arange(self.data.shape[0])
                
            indices = np.split(indices, np.arange(batch_size, self.data.shape[0], batch_size))
                
            for batch_indices in indices:
                self.opt.zero_grad()
                loss = self.bb1_loss_(num_samples, batch_indices)
                loss.backward()
                self.opt.step()

            if print_progress:
                if (int(25 * step / num_steps) != int(25 * (step - 1) / num_steps)):
                    print('.', end='')
        
        if print_progress:
            print()



# Gaussian Mixture Example

Assume we have mixture of some Gaussians and we want to distinguish them. Here we do it using Black Box SVI.

In [103]:
# generating data

std = 1
mu = np.array([-5, 5])
num_components = 2

np.random.seed(42)
num_samples = 100

components = np.random.randint(num_components, size=num_samples)
data = torch.Tensor(np.random.normal(mu[components], std))

Below we define two classes for prior and variational distributions:

In [104]:
class GaussianMixture:
    def __init__(self, num_components, std=1.):
        self.std = std
        self.num_components = num_components
        self.prior = torch.distributions.Normal(0, std)
        
    def log_likelihood_global(self, beta):
        return torch.sum(self.prior.log_prob(beta))
    
    def log_likelihood_local(self, z, beta):
        return torch.ones(1) / self.num_components
    
    def log_likelihood_joint(self, x, z, beta):
        dist = torch.distributions.Normal(beta[z], 1)
        return dist.log_prob(x)

In [105]:
class VariationalDistribution:
    def __init__(self, num_components, data_size):
        self.num_components = num_components
        self.means = torch.nn.Parameter(torch.normal(torch.zeros(num_components), torch.ones(num_components)))
        self.log_std = torch.nn.Parameter(torch.zeros(num_components))
        self.probs = torch.nn.Parameter(torch.ones(data_size, num_components) / num_components)
        self.parameters = [self.means, self.log_std, self.probs]
        
    def sample_global(self):
        return torch.normal(self.means, torch.exp(self.log_std))
    
    def sample_local(self, beta, idx):
        probs = self.probs[idx].data.numpy()
        z = np.random.choice(self.num_components, p=probs / np.sum(probs))
        return torch.LongTensor([z])
    
    def log_likelihood_global(self, beta):
        return torch.sum((beta - self.means) ** 2 / (2 * torch.exp(2 * self.log_std)) - self.log_std - log(2 * pi) / 2)
    
    def log_likelihood_local(self, z, idx):
        return self.probs[idx, z] / torch.sum(self.probs[idx])

Define prior, variational distribution and optimizer

In [106]:
prior = GaussianMixture(num_components, std=10)
var = VariationalDistribution(num_components, num_samples)

opt = torch.optim.Adam(var.parameters, lr=1e-3)

Do the inference

In [107]:
svi = SVI(data, prior, var, opt)

Results:

In [108]:
predicted_mu = var.means.data.numpy()
print('Initial means:   \t %.2f \t %.2f' % (predicted_mu[0], predicted_mu[1]))
svi.make_inference(num_steps=1000, shuffle=False, print_progress=True)

predicted_mu = var.means.data.numpy()
print('Predicted means: \t %.2f \t %.2f' % (predicted_mu[0], predicted_mu[1]))
print('Real means:      \t %.2f \t %.2f' % (mu[0], mu[1]))

predicted_components = torch.max(var.probs, dim=1)[1].data.numpy()
accuracy = np.sum(predicted_components == components) / len(predicted_components)
accuracy = max(accuracy, 1 - accuracy)
print('Mixture components detecting accuracy: %.2f %%' % (accuracy * 100))

Initial means:   	 0.24 	 0.51
[tensor([ 2.4586, -0.0703]), tensor([-7.0448, -1.0049]), tensor([ 0.])]
[tensor([ 0.1047,  1.0934]), tensor([-1.0110, -2.1955]), tensor([ 0.])]
(tensor([ 2.5633,  1.0231]), tensor([-8.0557, -3.2005]), tensor([ 0.]))


Exception: 

In [101]:
var.probs

Parameter containing:
tensor([[ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000],
        [ 0.5000,  0.5000]

**NOTE:** SVI strongly depends on the initialization configuration. If you fail to reproduce good result, try to reinitialize initial parameters

In [14]:
a = torch.ones(1, requires_grad=True)

In [17]:
a.data.requires_grad

False

In [61]:
a = (0, 1, None) 
b = (2, 1, None)

In [62]:
[sum(x) for x in zip(a, b)]

TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'

In [44]:
import operator
tuple(map(operator.add, a, b))

TypeError: unsupported operand type(s) for +: 'NoneType' and 'NoneType'

In [63]:
torch.Tensor([None])

TypeError: must be real number, not NoneType

In [71]:
a

(0, 1, None)

In [None]:
a