In [None]:
import variational_bayes as vb
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook
import sklearn.metrics
%matplotlib inline

In [None]:
# Simulate a graph with known group structure so the fit can be checked later.
np.random.seed(1)
num_nodes, num_groups = 100, 3

# Ground-truth group label of every node.
z = np.random.choice(num_groups, size=num_nodes)

# Connection probability for every ordered pair of groups.
proba = np.random.uniform(0, 1, size=(num_groups, num_groups))

# Expand to one probability per node pair: entry [i, j] is proba[z[j], z[i]].
_proba = proba[z[np.newaxis, :], z[:, np.newaxis]]
# Bernoulli draw per node pair (boolean matrix; not symmetrised).
adjacency = np.random.uniform(0, 1, size=_proba.shape) < _proba

# One-hot encoding of the labels: row i has a single 1 in column z[i].
onehot = np.eye(num_groups)[z]

In [None]:
def stochastic_block_model(adjacency, num_groups):
    """
    Build one randomly initialised variational stochastic block model.

    Parameters
    ----------
    adjacency : array with a two-dimensional shape
        Adjacency matrix; its first dimension is used as the number of nodes.
    num_groups : int
        Number of groups the model may assign nodes to.

    Returns
    -------
    vb.InteractingMixtureModel
        Model holding the factors ``'z'`` and ``'proba'`` together with three
        likelihood terms.

    Notes
    -----
    The initial factor parameters are drawn from the global numpy random
    state, so repeated calls start from different values unless the seed is
    reset in between.
    """
    num_nodes, _num_columns = adjacency.shape
    block_shape = (num_groups, num_groups)

    # Group-membership factor: one Dirichlet draw per node. The large
    # concentration keeps every row close to uniform with a small perturbation.
    initial_membership = np.random.dirichlet(np.full(num_groups, 1000.0), num_nodes)
    q_z = vb.CategoricalDistribution(initial_membership)

    # Connection-probability factor: both Beta parameters start at 10 plus a
    # tiny random offset. Draw order (first parameter, then second) is kept.
    first_param = 10 + np.random.exponential(1e-3, block_shape)
    second_param = 10 + np.random.exponential(1e-3, block_shape)
    q_proba = vb.BetaDistribution(first_param, second_param)

    # Term for the observed edges given the memberships and probabilities.
    # The two trailing axes are added to the adjacency matrix, presumably so it
    # broadcasts against the (num_groups, num_groups) block -- TODO confirm.
    edge_mixture = vb.InteractingMixtureDistribution(q_z, vb.BernoulliDistribution(q_proba))
    edge_term = edge_mixture.likelihood(adjacency[..., None, None])

    # Terms on the factors themselves (these look like the priors): equal
    # weight on every group for z, and Beta(1, 1) for proba.
    membership_term = vb.CategoricalDistribution(np.ones(num_groups) / num_groups).likelihood(q_z)
    proba_term = vb.BetaDistribution(1, 1).likelihood(q_proba)

    factors = {'z': q_z, 'proba': q_proba}
    return vb.InteractingMixtureModel(factors, [edge_term, membership_term, proba_term])

In [None]:
# Fit an ensemble of models built by `stochastic_block_model` with the true
# number of groups.
ensemble = vb.ModelEnsemble(stochastic_block_model, (adjacency, num_groups))
# NOTE(review): the meaning of the positional arguments (20, None) and of the
# predicate settings (1e-3, 10) is defined by the vb library -- confirm there.
_convergence = vb.ConvergencePredicate(1e-3, 10)
ensemble.update(20, None, tqdm_notebook, convergence_predicate=_convergence)

In [None]:
# Inspect the best model of the ensemble against the simulated ground truth.
fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True)
_best_model = ensemble.best_model
_z_mean = _best_model['z'].mean

# Left: mean of the 'proba' factor. Right: mean of the 'z' factor with the
# rows reordered so nodes of the same true group sit next to each other.
ax1.imshow(_best_model['proba'].mean)
ax2.imshow(_z_mean[np.argsort(z)], aspect='auto')

# Agreement between the true labels and the most probable inferred group;
# left as the last expression so the notebook displays the score.
sklearn.metrics.adjusted_rand_score(z, np.argmax(_z_mean, axis=1))

In [None]:
# Model selection: fit an ensemble for each candidate number of groups and
# record the best ELBO reached by each.
list_num_groups = [1, 2, 3, 4, 5, 6]
elbos = []
for _num_groups in tqdm_notebook(list_num_groups):
    # Use a scratch name here (same convention as `_num_groups`): rebinding the
    # global `ensemble` would silently replace the model fitted with the true
    # number of groups, so re-running the cell that plots and scores
    # `ensemble.best_model` would show the last model of this sweep instead.
    _ensemble = vb.ModelEnsemble(stochastic_block_model, (adjacency, _num_groups))
    _ensemble.update(20, None, convergence_predicate=vb.ConvergencePredicate(1e-3, 10))
    elbos.append(_ensemble.best_elbo)

In [None]:
# Best ELBO per candidate group count, shifted so the largest value sits at zero.
_relative_elbos = np.asarray(elbos) - np.max(elbos)
plt.plot(list_num_groups, _relative_elbos, marker='.')
# Vertical line at `num_groups`, the group count used to simulate the data.
plt.axvline(num_groups)