In [1]:
import numpy as np
import riemann

from gmm import GMM, GMMParameters

# RJMCMC Example with univariate mixture of Gaussians

In [2]:
gmm = GMM(weights=np.array([0.4,0.6]), means=np.array([[0.], [0.9]]), covariances=np.array([[[0.01]], [[0.01]]]))

In [3]:
import matplotlib.pyplot as plt

query_x = np.linspace(-1,2,100)[:,None]
pdf_query_x = gmm.density(query_x)

In [4]:
n_obs = 100
obs = gmm.sample(n_obs)

plt.plot(query_x, pdf_query_x, 'k--', label="true density")
plt.plot(obs, np.zeros(len(obs)), 'k+', ms=10, label="observations")
plt.hist(obs, density=True, histtype='step', label="empirical density")
plt.legend()

<matplotlib.legend.Legend at 0x7f66b076f550>

## Model

In [5]:
from riemann.proposals import rj_proposals, rj_generic

import importlib
importlib.reload(rj_proposals)

<module 'riemann.proposals.rj_proposals' from '/home/rafael/Projects/riemann/riemann/proposals/rj_proposals.py'>

In [6]:
true_cov_factors = np.linalg.cholesky(gmm.covariances)
true_params = GMMParameters.serialise(gmm.weights, gmm.means, true_cov_factors)
true_k = len(gmm.weights)
true_state = rj_proposals.RJState(true_params, true_k)

In [7]:
max_n_comp = 3
dim = int(len(true_params)/len(gmm.weights))
g_means = [np.zeros(dim*n) for n in range(1,max_n_comp+1)]
g_covs = [0.1*np.eye(dim*n) for n in range(1,max_n_comp+1)]
generic_map = rj_generic.GenericMapping(g_means, g_covs)

In [8]:
jump_prop = rj_proposals.JumpProposal(generic_map, rj_generic.GenericMatchingProp(dim))

In [9]:
no_jump_ratio = 1.1
move_prop = rj_generic.MoveProp(max_n_comp, no_jump_ratio)

In [10]:
rj_prop = rj_proposals.RJProposal(move_prop, jump_prop, [rj_generic.RandomWalk(0.01)] * max_n_comp)

In [11]:
n_samples = 100
sample = rj_proposals.RJState(np.random.rand(3), 1)
for i in range(n_samples):
    sample, _ = rj_prop.propose(sample)
    print(sample.param)

[ 0.49345285  0.46130768  0.40900528  0.16651237  0.09508007 -0.03085757
  0.00769313  0.01068609 -0.14811529]
[0.49345285 0.46130768 0.40900528]
[ 0.49345285  0.46130768  0.40900528 -0.01687694  0.02691142  0.0798387
 -0.05336457 -0.00140953 -0.19004706]
[0.49345285 0.46130768 0.40900528]
[0.49716035 0.45629537 0.40626787]
[ 0.49716035  0.45629537  0.40626787 -0.11078803  0.13381682 -0.2193061
  0.1175269  -0.05923439 -0.06100514]
[0.49716035 0.45629537 0.40626787]
[ 0.49716035  0.45629537  0.40626787  0.10064401  0.04011073 -0.02345776
 -0.04081855 -0.14160846 -0.12780408]
[ 0.49716035  0.45629537  0.40626787  0.10064401  0.04011073 -0.02345776]
[0.49716035 0.45629537 0.40626787]
[0.49908897 0.45085139 0.41405085]
[0.49482248 0.45943968 0.41420973]
[ 0.49482248  0.45943968  0.41420973  0.08028451 -0.2225823  -0.2698027
  0.00310035  0.14788107  0.04271556]
[0.49482248 0.45943968 0.41420973]
[0.48533394 0.44981459 0.41876148]
[ 0.48533394  0.44981459  0.41876148  0.07780021  0.0225315