Skip to content

Commit

Permalink
updated multikernel mmd2 to have the same api as single kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday committed Jul 14, 2016
1 parent f9fa858 commit 13ddee1
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 20 deletions.
10 changes: 5 additions & 5 deletions src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.cpp
Expand Up @@ -110,20 +110,20 @@ void CMultiKernelQuadraticTimeMMD::cleanup()
self->m_dtype=D_UNKNOWN;
}

SGVector<float64_t> CMultiKernelQuadraticTimeMMD::statistic()
SGVector<float64_t> CMultiKernelQuadraticTimeMMD::compute_statistic()
{
ASSERT(self->m_owner);
return statistic(self->m_kernel_mgr);
}

SGVector<float64_t> CMultiKernelQuadraticTimeMMD::variance_h0()
SGVector<float64_t> CMultiKernelQuadraticTimeMMD::compute_variance_h0()
{
ASSERT(self->m_owner);
SG_NOTIMPLEMENTED;
return SGVector<float64_t>();
}

SGVector<float64_t> CMultiKernelQuadraticTimeMMD::variance_h1()
SGVector<float64_t> CMultiKernelQuadraticTimeMMD::compute_variance_h1()
{
ASSERT(self->m_owner);
SG_NOTIMPLEMENTED;
Expand All @@ -136,15 +136,15 @@ SGMatrix<float32_t> CMultiKernelQuadraticTimeMMD::sample_null()
return sample_null(self->m_kernel_mgr);
}

SGVector<float64_t> CMultiKernelQuadraticTimeMMD::p_values()
SGVector<float64_t> CMultiKernelQuadraticTimeMMD::compute_p_value()
{
ASSERT(self->m_owner);
return p_values(self->m_kernel_mgr);
}

SGVector<bool> CMultiKernelQuadraticTimeMMD::perform_test(float64_t alpha)
{
SGVector<float64_t> pvalues=p_values();
SGVector<float64_t> pvalues=compute_p_value();
SGVector<bool> rejections(pvalues.size());
for (auto i=0; i<pvalues.size(); ++i)
{
Expand Down
9 changes: 4 additions & 5 deletions src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.h
Expand Up @@ -69,12 +69,11 @@ class CMultiKernelQuadraticTimeMMD : public CSGObject
void add_kernel(CShiftInvariantKernel *kernel);
void cleanup();

SGVector<float64_t> statistic();
SGVector<float64_t> variance_h0();
SGVector<float64_t> variance_h1();

SGVector<float64_t> compute_statistic();
SGVector<float64_t> compute_variance_h0();
SGVector<float64_t> compute_variance_h1();
SGMatrix<float32_t> sample_null();
SGVector<float64_t> p_values();
SGVector<float64_t> compute_p_value();
SGVector<bool> perform_test(float64_t alpha);

virtual const char* get_name() const;
Expand Down
19 changes: 9 additions & 10 deletions tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc
Expand Up @@ -496,7 +496,7 @@ TEST(QuadraticTimeMMD, precomputed_vs_nonprecomputed)
EXPECT_NEAR(result_1[i], result_2[i], 1E-6);
}

TEST(QuadraticTimeMMD, compute_multiple)
TEST(QuadraticTimeMMD, multikernel_compute_statistic)
{
const index_t m=20;
const index_t n=20;
Expand All @@ -518,23 +518,23 @@ TEST(QuadraticTimeMMD, compute_multiple)
float64_t tau=pow(2, sigma);
mmd->multikernel()->add_kernel(new CGaussianKernel(10, tau));
}
SGVector<float64_t> mmd_multiple=mmd->multikernel()->statistic();
SGVector<float64_t> mmd_multiple=mmd->multikernel()->compute_statistic();
mmd->multikernel()->cleanup();

SGVector<float64_t> mmd_single(num_kernels);
for (auto i=0, sigma=-5; i<num_kernels; ++i, sigma+=1)
{
auto mmd2=some<CQuadraticTimeMMD>(feat_p, feat_q);
float64_t tau=pow(2, sigma);
mmd2->set_kernel(new CGaussianKernel(10, tau));
mmd_single[i]=mmd2->compute_statistic();
mmd->set_kernel(new CGaussianKernel(10, tau));
mmd_single[i]=mmd->compute_statistic();
}

ASSERT_EQ(mmd_multiple.size(), mmd_single.size());
for (auto i=0; i<mmd_multiple.size(); ++i)
EXPECT_NEAR(mmd_multiple[i], mmd_single[i], 1E-4);
}

TEST(QuadraticTimeMMD, perform_test_multiple)
TEST(QuadraticTimeMMD, multikernel_perform_test)
{
const index_t m=8;
const index_t n=12;
Expand All @@ -561,16 +561,15 @@ TEST(QuadraticTimeMMD, perform_test_multiple)
}
sg_rand->set_seed(12345);
SGVector<bool> rejections_multiple=mmd->multikernel()->perform_test(alpha);
mmd->multikernel()->cleanup();

SGVector<bool> rejections_single(num_kernels);
for (auto i=0, sigma=-5; i<num_kernels; ++i, sigma+=1)
{
auto mmd2=some<CQuadraticTimeMMD>(feat_p, feat_q);
mmd2->set_num_null_samples(num_null_samples);
float64_t tau=pow(2, sigma);
mmd2->set_kernel(new CGaussianKernel(cache_size, tau));
mmd->set_kernel(new CGaussianKernel(cache_size, tau));
sg_rand->set_seed(12345);
rejections_single[i]=mmd2->perform_test(alpha);
rejections_single[i]=mmd->perform_test(alpha);
}

ASSERT_EQ(rejections_multiple.size(), rejections_single.size());
Expand Down

0 comments on commit 13ddee1

Please sign in to comment.