In [None]:
import variational_bayes as vb
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook
import sklearn.metrics
%matplotlib inline

In [None]:
# Seed NumPy's global RNG so the synthetic graph (and the random model
# initialisations in later cells) are reproducible under Restart & Run All.
# A bare `np.random.normal(1)` would only draw and discard one sample.
np.random.seed(1)
num_nodes = 200
num_groups = 5

# Ground-truth group label for every node.
z = np.random.choice(num_groups, num_nodes)

# Group-to-group connection probabilities: off-diagonal entries in [0, 0.1),
# diagonal (within-group) entries in [0.2, 1), i.e. assortative blocks.
proba = np.random.uniform(0, .1, (num_groups, num_groups))
i = np.arange(num_groups)
proba[i, i] = np.random.uniform(.2, 1, num_groups)
# Alternative, deterministic block structure:
# proba = np.eye(num_groups) * .8 + .1


# Edge probability for every ordered node pair: entry (a, b) is
# proba[z[b], z[a]] -- note the group indices are transposed relative to the
# node indices. NOTE(review): confirm this matches the (row, column)
# convention of vb.InteractingMixtureLikelihood before comparing the inferred
# `proba` with this ground truth.
_proba = proba[z[None, :], z[:, None]]
# Directed boolean adjacency matrix; it is not symmetrised and the diagonal
# (self-loops) is sampled as well.
adjacency = np.random.uniform(0, 1, _proba.shape) < _proba
# One-hot encoding of the true labels, shape (num_nodes, num_groups).
onehot = np.zeros((num_nodes, num_groups))
onehot[np.arange(num_nodes), z] = 1

In [None]:
class StochasticBlockModel(vb.Model):
    """Variational stochastic block model for an observed adjacency matrix.

    Two factors are inferred:

    * ``'z'`` -- a categorical distribution over ``num_groups`` groups for each node;
    * ``'proba'`` -- a ``(num_groups, num_groups)`` array of beta distributions over
      the edge probability between each pair of groups.

    Parameters
    ----------
    adjacency : ndarray
        ``(num_nodes, num_nodes)`` matrix of observed edges.
    num_groups : int
        Number of blocks to fit.
    """

    def __init__(self, adjacency, num_groups):
        num_nodes, _ = adjacency.shape
        pair_shape = (num_groups, num_groups)

        # Assignments start close to uniform: a Dirichlet with concentration
        # 1000 per group gives rows near 1 / num_groups, with a little noise to
        # break the symmetry between groups.
        concentration = 1000 * np.ones(num_groups)
        assignments = vb.CategoricalDistribution(np.random.dirichlet(concentration, num_nodes))

        # Block probabilities start near Beta(10, 10), with a tiny random jitter
        # on both shape parameters.
        shape_a = 10 + np.random.exponential(1e-3, pair_shape)
        shape_b = 10 + np.random.exponential(1e-3, pair_shape)
        block_proba = vb.BetaDistribution(shape_a, shape_b)

        # Observed edges get two trailing singleton axes, presumably so they
        # broadcast against the (num_groups, num_groups) axes of `proba`.
        edges = adjacency[..., None, None]
        mixture = vb.InteractingMixtureLikelihood(
            assignments, vb.BernoulliLikelihood, x=edges, proba=block_proba)
        # Priors: presumably uniform over groups and Beta(1, 1) on each probability.
        assignment_prior = vb.CategoricalLikelihood(assignments, np.ones(num_groups) / num_groups)
        proba_prior = vb.BetaLikelihood(block_proba, 1, 1)

        # The mixture must come first: `update_factor` fetches it as
        # `self._likelihoods[0]` (assumes vb.Model keeps this order).
        factors = {'z': assignments, 'proba': block_proba}
        super(StochasticBlockModel, self).__init__(
            factors, [mixture, assignment_prior, proba_prior])

    def update_factor(self, factor):
        """Update a single factor, given either as an object or by its name.

        The assignment factor ``'z'`` uses a custom update that routes the
        interacting mixture through its indicator natural parameters; every
        other factor is delegated to ``vb.Model.update_factor``.
        """
        if isinstance(factor, str):
            factor = self._factors[factor]

        if factor is not self._factors.get('z'):
            super(StochasticBlockModel, self).update_factor(factor)
            return

        # Aggregate natural parameters from all likelihoods except the
        # interacting mixture, then (presumably) let the mixture add its own
        # contribution for the indicator variables before updating `z`.
        mixture = self._likelihoods[0]
        natural = self.aggregate_natural_parameters(factor, [mixture])
        natural = mixture.indicator_natural_parameters(natural, **mixture.parameters)
        factor.update_from_natural_parameters(natural)

In [None]:
# Fit an ensemble of StochasticBlockModel instances on the synthetic graph
# (presumably several random restarts; the best one is inspected below).
# NOTE(review): the positional `20, None` and the ConvergencePredicate
# arguments `(1e-3, 10)` are interpreted by the vb library -- confirm their
# meaning and consider passing them by keyword.
ensemble = vb.ModelEnsemble(StochasticBlockModel, (adjacency, num_groups))
convergence = vb.ConvergencePredicate(1e-3, 10)
ensemble.update(20, None, tqdm_notebook, convergence_predicate=convergence)

In [None]:
# Inspect the best model of the ensemble.
# Left: mean of q(proba), the inferred group-to-group connection probabilities.
# Right: mean of q(z) for each node, with rows ordered by the true group so
# that a good fit appears as contiguous bands.
best_model = ensemble.best_model
fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True)
ax1.imshow(best_model['proba'].mean)
ax1.set(title='Inferred block probabilities', xlabel='group', ylabel='group')
ax2.imshow(best_model['z'].mean[np.argsort(z)], aspect='auto')
ax2.set(title='Posterior group memberships', xlabel='group',
        ylabel='node (sorted by true group)')
fig.tight_layout()

# Agreement between the most probable inferred labels and the true labels;
# the adjusted Rand index is invariant to relabelling of the groups.
sklearn.metrics.adjusted_rand_score(z, np.argmax(best_model['z'].mean, axis=1))