Skip to content

Commit

Permalink
fix null spectrum approximation formula in quadratic time MMD
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday committed Mar 18, 2014
1 parent 8963986 commit e8484ff
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
31 changes: 19 additions & 12 deletions src/shogun/statistics/QuadraticTimeMMD.cpp
Expand Up @@ -394,12 +394,14 @@ SGVector<float64_t> CQuadraticTimeMMD::sample_null_spectrum(index_t num_samples,
/* precomputing terms with rho_x and rho_y of equation 10 in [1]
* (see documentation) */
float64_t rho_x=float64_t(m)/(m+n);
float64_t rho_y=1.0-rho_x;
float64_t sqrt_rho_x=CMath::sqrt(rho_x);
float64_t sqrt_rho_y=CMath::sqrt(rho_y);
float64_t inv_rho_x_y=1.0/(rho_x*rho_y);
float64_t rho_y=1-rho_x;

SG_DEBUG("sqrt_rho_x=%f, sqrt_rho_y=%f\n", sqrt_rho_x, sqrt_rho_y);
/* instead of using two Gaussian rv's ~ N(0,1), we'll use just one rv
* ~ N(0, 1/rho_x+1/rho_y) (derived from eq 10 in [1]) */
float64_t std_dev=CMath::sqrt(1/rho_x+1/rho_y);
float64_t inv_rho_x_y=1/(rho_x*rho_y);

SG_DEBUG("Using Gaussian samples ~ N(0,%f)\n", std_dev*std_dev);

/* finally, sample from null distribution */
SGVector<float64_t> null_samples(num_samples);
Expand All @@ -408,13 +410,14 @@ SGVector<float64_t> CQuadraticTimeMMD::sample_null_spectrum(index_t num_samples,
null_samples[i]=0;
for (index_t j=0; j<num_eigenvalues; ++j)
{
/* compute the right hand multiple of eq. 10 in [1] */
float64_t a_j=CMath::randn_double();
float64_t b_j=CMath::randn_double();
/* compute the right hand multiple of eq. 10 in [1] using one RV.
* randn_double() gives a sample from N(0,1), we need samples
* from N(0,1/rho_x+1/rho_y) */
float64_t z_j=std_dev*CMath::randn_double();

SG_DEBUG("a_j=%f, b_j=%f\n", a_j, b_j);
SG_DEBUG("z_j=%f\n", z_j);

float64_t multiple=CMath::pow(a_j/sqrt_rho_x-b_j/sqrt_rho_y, 2);
float64_t multiple=CMath::pow(z_j, 2);

/* take largest EV, scale by 1/(m+n) on the fly and take abs value*/
float64_t eigenvalue_estimate=CMath::abs(1.0/(m+n)
Expand All @@ -423,13 +426,17 @@ SGVector<float64_t> CQuadraticTimeMMD::sample_null_spectrum(index_t num_samples,
if (m_statistic_type==UNBIASED)
multiple-=inv_rho_x_y;

SG_DEBUG("multiple=%f, eigenvalue=%f\n", multiple, eigenvalue_estimate);
SG_DEBUG("multiple=%f, eigenvalue=%f\n", multiple,
eigenvalue_estimate);

null_samples[i]+=eigenvalue_estimate*multiple;
}
null_samples[i]*=2;
}

/* when m=n, return m*MMD^2 instead */
if (m==n)
null_samples.scale(0.5);

return null_samples;
}
#endif // HAVE_EIGEN3
Expand Down
1 change: 1 addition & 0 deletions src/shogun/statistics/QuadraticTimeMMD.h
Expand Up @@ -233,6 +233,7 @@ class CQuadraticTimeMMD : public CKernelTwoSampleTest
*
* Note that (m+n)*Null-distribution is returned,
* which is fine since the statistic is also (m+n)*MMD:
* except when m and n are equal, then m*MMD^2 is returned
*
* Works well if the kernel matrix is NOT diagonal dominant.
* See Gretton, A., Fukumizu, K., & Harchaoui, Z. (2011).
Expand Down

0 comments on commit e8484ff

Please sign in to comment.