Skip to content

Commit

Permalink
Improve GMM by registering m_components within the new parameter
Browse files Browse the repository at this point in the history
framework.

* Update also the associated observe() directives;
* Improve GMM observable meta example;
  • Loading branch information
geektoni committed May 30, 2019
1 parent 4bbf17d commit 13287b8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
17 changes: 11 additions & 6 deletions examples/meta/src/observers/gaussian_mixture_models.sg
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,22 @@ gmm.subscribe(obs)
#![train_sample]
gmm.put("features", features_train)
gmm.train_em()
RealVector output = gmm.sample()
#![train_sample]

#![training_smem]
obs.clear()
gmm.train_smem()
#![training_smem]

#![extract_params]
#![extract_coeff]
int num_observations = obs.get_int("num_observations")
ObservedValue obs_val = obs.get_observation(num_observations-1)
RealVector coeff = obs_val.get_real_vector("coefficients")
real log_likelihood = obs_val.get_real("log_likelihood")
RealVector coef = gmm.get_real_vector("coefficients")
#![extract_params]
ObservedValue obs_val = obs.get_observation(num_observations-2)
RealVector observed_coeff = obs_val.get_real_vector("coefficients")
RealVector coeff = gmm.get_real_vector("coefficients")
#![extract_coeff]

#![extract_likelihood]
ObservedValue obs_val_2 = obs.get_observation(num_observations-1)
real log_likelihood = obs_val_2.get_real("log_likelihood")
#![extract_likelihood]
7 changes: 3 additions & 4 deletions src/shogun/clustering/GMM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ float64_t CGMM::train_em(float64_t min_cov, int32_t max_iter, float64_t min_chan

this->observe<float64_t>(iter, "log_likelihood", "Log Likelihood", log_likelihood_cur);
this->observe<SGVector<float64_t>>(iter, "coefficients", "Mixture Coefficients", alpha);
this->observe<std::vector<CGaussian*>>(iter, "gaussians", "Gaussians Components", m_components);
this->observe<std::vector<CGaussian*>>(iter, "components");

iter++;
}
Expand Down Expand Up @@ -342,7 +342,7 @@ float64_t CGMM::train_smem(int32_t max_iter, int32_t max_cand, float64_t min_cov

this->observe<float64_t>(iter, "log_likelihood", "Log Likelihood", cur_likelihood);
this->observe<SGVector<float64_t>>(iter, "coefficients");
this->observe<std::vector<CGaussian*>>(iter, "gaussians", "Gaussians Components",m_components);
this->observe<std::vector<CGaussian*>>(iter, "components");

iter++;
pb.print_progress();
Expand Down Expand Up @@ -832,8 +832,7 @@ SGVector<float64_t> CGMM::cluster(SGVector<float64_t> point)

void CGMM::register_params()
{
//TODO serialization broken
//m_parameters->add((SGVector<CSGObject*>*) &m_components, "m_components", "Mixture components");
this->watch_param("components", &m_components, AnyParameterProperties("Mixture components"));
SG_ADD(
&m_coefficients, "coefficients", "Mixture coefficients.");
}

0 comments on commit 13287b8

Please sign in to comment.