Skip to content

Commit

Permalink
Added SMEM GMM examples
Browse files Browse the repository at this point in the history
  • Loading branch information
alesis committed Jul 24, 2011
1 parent 104f0f0 commit fc4baa4
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 0 deletions.
54 changes: 54 additions & 0 deletions examples/undocumented/python_modular/graphical/smem_1d_gmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from pylab import figure,show,connect,hist,plot,legend
from numpy import array, append, arange, empty
from shogun.Distribution import Gaussian, GMM
from shogun.Features import RealFeatures
import util

util.set_title('SMEM for 1d GMM example')

max_iter=100
max_cand=5
min_cov=1e-9
max_em_iter=1000
min_change=1e-9

real_gmm=GMM(3)

real_gmm.set_nth_mean(array([-2.0]), 0)
real_gmm.set_nth_mean(array([0.0]), 1)
real_gmm.set_nth_mean(array([2.0]), 2)

real_gmm.set_nth_cov(array([[0.3]]), 0)
real_gmm.set_nth_cov(array([[0.1]]), 1)
real_gmm.set_nth_cov(array([[0.2]]), 2)

real_gmm.set_coef(array([0.3, 0.5, 0.2]))

generated=array([real_gmm.sample()])
for i in range(199):
generated=append(generated, array([real_gmm.sample()]), axis=1)

feat_train=RealFeatures(generated)
est_smem_gmm=GMM(3)
est_smem_gmm.train(feat_train)
print est_smem_gmm.train_smem(max_iter, max_cand, min_cov, max_em_iter, min_change)
est_em_gmm=GMM(3)
est_em_gmm.train(feat_train)
print est_em_gmm.train_em(min_cov, max_em_iter, min_change)

min_gen=min(min(generated))
max_gen=max(max(generated))
plot_real=empty(0)
plot_est_smem=empty(0)
plot_est_em=empty(0)
for i in arange(min_gen, max_gen, 0.001):
plot_real=append(plot_real, array([real_gmm.cluster(array([i]))[3]]))
plot_est_smem=append(plot_est_smem, array([est_smem_gmm.cluster(array([i]))[3]]))
plot_est_em=append(plot_est_em, array([est_em_gmm.cluster(array([i]))[3]]))
real_plot=plot(arange(min_gen, max_gen, 0.001), plot_real, "b")
est_em_plot=plot(arange(min_gen, max_gen, 0.001), plot_est_em, "g")
est_smem_plot=plot(arange(min_gen, max_gen, 0.001), plot_est_smem, "r")
real_hist=hist(generated.transpose(), bins=50, normed=True, fc="gray")
legend(("Real GMM", "Estimated EM GMM", "Estimated SMEM GMM"))
connect('key_press_event', util.quit)
show()
67 changes: 67 additions & 0 deletions examples/undocumented/python_modular/graphical/smem_2d_gmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from pylab import figure,scatter,contour,show,legend,connect
from numpy import array, append, arange, reshape, empty
from shogun.Distribution import Gaussian, GMM
from shogun.Features import RealFeatures
import util

util.set_title('SMEM for 2d GMM example')

max_iter=100
max_cand=5
min_cov=1e-9
max_em_iter=1000
min_change=1e-9
cov_type=0

real_gmm=GMM(3)

real_gmm.set_nth_mean(array([2.0, 2.0]), 0)
real_gmm.set_nth_mean(array([-2.0, -2.0]), 1)
real_gmm.set_nth_mean(array([2.0, -2.0]), 2)

real_gmm.set_nth_cov(array([[1.0, 0.2],[0.2, 0.5]]), 0)
real_gmm.set_nth_cov(array([[0.2, 0.1],[0.1, 0.5]]), 1)
real_gmm.set_nth_cov(array([[0.3, -0.2],[-0.2, 0.8]]), 2)

real_gmm.set_coef(array([0.3, 0.4, 0.3]))

generated=array([real_gmm.sample()])
for i in range(199):
generated=append(generated, array([real_gmm.sample()]), axis=0)

generated=generated.transpose()
feat_train=RealFeatures(generated)
est_smem_gmm=GMM(3, cov_type)
est_smem_gmm.train(feat_train)
print est_smem_gmm.train_smem(max_iter, max_cand, min_cov, max_em_iter, min_change)
est_em_gmm=GMM(3, cov_type)
est_em_gmm.train(feat_train)
print est_em_gmm.train_em(min_cov, max_em_iter, min_change)

min_x_gen=min(min(generated[[0]]))-0.1
max_x_gen=max(max(generated[[0]]))+0.1
min_y_gen=min(min(generated[[1]]))-0.1
max_y_gen=max(max(generated[[1]]))+0.1

plot_real=empty(0)
plot_est_smem=empty(0)
plot_est_em=empty(0)

for i in arange(min_x_gen, max_x_gen, 0.05):
for j in arange(min_y_gen, max_y_gen, 0.05):
plot_real=append(plot_real, array([real_gmm.cluster(array([i, j]))[3]]))
plot_est_smem=append(plot_est_smem, array([est_smem_gmm.cluster(array([i, j]))[3]]))
plot_est_em=append(plot_est_em, array([est_em_gmm.cluster(array([i, j]))[3]]))

plot_real=reshape(plot_real, (arange(min_x_gen, max_x_gen, 0.05).shape[0], arange(min_y_gen, max_y_gen, 0.05).shape[0]))
plot_est_smem=reshape(plot_est_smem, (arange(min_x_gen, max_x_gen, 0.05).shape[0], arange(min_y_gen, max_y_gen, 0.05).shape[0]))
plot_est_em=reshape(plot_est_em, (arange(min_x_gen, max_x_gen, 0.05).shape[0], arange(min_y_gen, max_y_gen, 0.05).shape[0]))

real_plot=contour(arange(min_x_gen, max_x_gen, 0.05), arange(min_y_gen, max_y_gen, 0.05), plot_real.transpose(), colors="b")
est_smem_plot=contour(arange(min_x_gen, max_x_gen, 0.05), arange(min_y_gen, max_y_gen, 0.05), plot_est_smem.transpose(), 3, colors="r")
est_em_plot=contour(arange(min_x_gen, max_x_gen, 0.05), arange(min_y_gen, max_y_gen, 0.05), plot_est_em.transpose(), colors="g")
real_scatter=scatter(generated[[0]], generated[[1]], c="gray")
legend((real_plot.collections[0], est_em_plot.collections[0], est_smem_plot.collections[0]), ("Real GMM", "Estimated EM GMM", "Estimated SMEM GMM"))

connect('key_press_event', util.quit)
show()

0 comments on commit fc4baa4

Please sign in to comment.