In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.nn import Parameter
from torch.autograd import Variable

from math import pi, log
from BBSVI import SVI

# Gaussian Mixture Example

Assume we have mixture of some Gaussians and we want to distinguish them. Here we do it using Black Box SVI.

In [2]:
# generating data

std = 1
mu = np.array([-5, 5])
num_components = 2

np.random.seed(42)
num_samples = 1000

components = np.random.choice(num_components, size=num_samples, p=np.array([1/2, 1/2]))
data = torch.Tensor(np.random.normal(mu[components], std))

Below we define two classes for prior and variational distributions:

In [3]:
class GaussianMixture:
    def __init__(self, num_components, log_std=0):
        
        self.log_std = nn.Parameter(torch.Tensor([log_std]))
        self.means = nn.Parameter(1 * torch.randn(num_components, requires_grad=True))
        self.weights = nn.Parameter(torch.ones(num_components, requires_grad=True) / num_components)
        self.weights_distr = torch.distributions.Categorical(probs=self.weights)
        self.weights_distr.probs = self.weights
        self.components_distrs = [torch.distributions.Normal(self.means[i], 
                                                             torch.exp(self.log_std)) for i in range(num_components)]
        self.parameters = [self.log_std, self.means]

    def log_likelihood_global(self, beta):
        return torch.zeros(1, requires_grad=True)
    
    def log_likelihood_local(self, z, beta):
        return self.weights_distr.log_prob(z)
    
    def log_likelihood_joint(self, x, z, beta):
        mix_term = self.weights_distr.log_prob(z)
        normal_term = self.components_distrs[z].log_prob(x)
        return mix_term + normal_term

In [4]:
class VariationalDistribution:
    def __init__(self, num_components, data_size):
        self.num_components = num_components
        self.probs = [torch.nn.Parameter(torch.ones(num_components) / num_components) for _ in range(data_size)]
        self.distrs = [torch.distributions.Categorical(probs=self.probs[i]) for i in range(data_size)]
        self.global_parameters = []
        self.local_parameters = self.probs
        
    def sample_global(self, num_samples = 1):
        return torch.zeros(1)
    
    def sample_local(self, beta, idx):
        return self.distrs[idx].sample()
    
    def log_likelihood_global(self, beta):
        return torch.zeros(1, requires_grad=True)
    
    def log_likelihood_local(self, z, idx):
        return self.distrs[idx].log_prob(z)

Define prior, variational distribution and optimizer

In [5]:
prior = GaussianMixture(num_components)
var = VariationalDistribution(num_components, num_samples)
opt = torch.optim.Adam([{'params': var.local_parameters},
                        {'params': var.global_parameters},
                        {'params': prior.parameters}], lr=1e-3)

Do the inference

In [6]:
svi = SVI(data, prior, var, opt)

Results:

In [7]:
predicted_mu = prior.means.data.numpy()
print('Initial means:   \t %.2f \t %.2f' % (predicted_mu[0], predicted_mu[1]))
svi.make_inference(num_steps=50, shuffle=False, loss='bb2', print_progress=True)

predicted_mu = prior.means.data.numpy()
print('Predicted means: \t %.2f \t %.2f' % (predicted_mu[0], predicted_mu[1]))
print('Real means:      \t %.2f \t %.2f' % (mu[0], mu[1]))

predicted_components = np.array([torch.max(var.probs[i], dim=-1)[1].data.numpy() for i in range(num_samples)])
accuracy = np.sum(predicted_components == components) / len(predicted_components)
accuracy = max(accuracy, 1 - accuracy)
print('Mixture components detecting accuracy: %.2f %%' % (accuracy * 100))

Initial means:   	 -1.01 	 0.55
..................................................
Predicted means: 	 -5.94 	 4.63
Real means:      	 -5.00 	 5.00
Mixture components detecting accuracy: 100.00 %


**NOTE:** SVI strongly depends on the initialization configuration. If you fail to reproduce good result, try to reinitialize initial parameters