From 13ddee141f631de427efc7908a83323a79df9208 Mon Sep 17 00:00:00 2001 From: lambday Date: Thu, 14 Jul 2016 02:27:18 +0100 Subject: [PATCH] updated multikernel mmd2 to have the same api as single kernel --- .../MultiKernelQuadraticTimeMMD.cpp | 10 +++++----- .../MultiKernelQuadraticTimeMMD.h | 9 ++++----- .../QuadraticTimeMMD_unittest.cc | 19 +++++++++---------- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.cpp b/src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.cpp index 46ebc3d9dd7..25f800d883a 100644 --- a/src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.cpp +++ b/src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.cpp @@ -110,20 +110,20 @@ void CMultiKernelQuadraticTimeMMD::cleanup() self->m_dtype=D_UNKNOWN; } -SGVector CMultiKernelQuadraticTimeMMD::statistic() +SGVector CMultiKernelQuadraticTimeMMD::compute_statistic() { ASSERT(self->m_owner); return statistic(self->m_kernel_mgr); } -SGVector CMultiKernelQuadraticTimeMMD::variance_h0() +SGVector CMultiKernelQuadraticTimeMMD::compute_variance_h0() { ASSERT(self->m_owner); SG_NOTIMPLEMENTED; return SGVector(); } -SGVector CMultiKernelQuadraticTimeMMD::variance_h1() +SGVector CMultiKernelQuadraticTimeMMD::compute_variance_h1() { ASSERT(self->m_owner); SG_NOTIMPLEMENTED; @@ -136,7 +136,7 @@ SGMatrix CMultiKernelQuadraticTimeMMD::sample_null() return sample_null(self->m_kernel_mgr); } -SGVector CMultiKernelQuadraticTimeMMD::p_values() +SGVector CMultiKernelQuadraticTimeMMD::compute_p_value() { ASSERT(self->m_owner); return p_values(self->m_kernel_mgr); @@ -144,7 +144,7 @@ SGVector CMultiKernelQuadraticTimeMMD::p_values() SGVector CMultiKernelQuadraticTimeMMD::perform_test(float64_t alpha) { - SGVector pvalues=p_values(); + SGVector pvalues=compute_p_value(); SGVector rejections(pvalues.size()); for (auto i=0; i statistic(); - SGVector variance_h0(); - SGVector variance_h1(); - + SGVector compute_statistic(); + SGVector compute_variance_h0(); + SGVector compute_variance_h1(); SGMatrix sample_null(); - SGVector p_values(); + SGVector compute_p_value(); SGVector perform_test(float64_t alpha); virtual const char* get_name() const; diff --git a/tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc b/tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc index b37c8523df8..376931547e0 100644 --- a/tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc +++ b/tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc @@ -496,7 +496,7 @@ TEST(QuadraticTimeMMD, precomputed_vs_nonprecomputed) EXPECT_NEAR(result_1[i], result_2[i], 1E-6); } -TEST(QuadraticTimeMMD, compute_multiple) +TEST(QuadraticTimeMMD, multikernel_compute_statistic) { const index_t m=20; const index_t n=20; @@ -518,15 +518,15 @@ TEST(QuadraticTimeMMD, compute_multiple) float64_t tau=pow(2, sigma); mmd->multikernel()->add_kernel(new CGaussianKernel(10, tau)); } - SGVector mmd_multiple=mmd->multikernel()->statistic(); + SGVector mmd_multiple=mmd->multikernel()->compute_statistic(); + mmd->multikernel()->cleanup(); SGVector mmd_single(num_kernels); for (auto i=0, sigma=-5; i(feat_p, feat_q); float64_t tau=pow(2, sigma); - mmd2->set_kernel(new CGaussianKernel(10, tau)); - mmd_single[i]=mmd2->compute_statistic(); + mmd->set_kernel(new CGaussianKernel(10, tau)); + mmd_single[i]=mmd->compute_statistic(); } ASSERT_EQ(mmd_multiple.size(), mmd_single.size()); @@ -534,7 +534,7 @@ TEST(QuadraticTimeMMD, compute_multiple) EXPECT_NEAR(mmd_multiple[i], mmd_single[i], 1E-4); } -TEST(QuadraticTimeMMD, perform_test_multiple) +TEST(QuadraticTimeMMD, multikernel_perform_test) { const index_t m=8; const index_t n=12; @@ -561,16 +561,15 @@ TEST(QuadraticTimeMMD, perform_test_multiple) } sg_rand->set_seed(12345); SGVector rejections_multiple=mmd->multikernel()->perform_test(alpha); + mmd->multikernel()->cleanup(); SGVector rejections_single(num_kernels); for (auto i=0, sigma=-5; i(feat_p, feat_q); - mmd2->set_num_null_samples(num_null_samples); float64_t tau=pow(2, sigma); - mmd2->set_kernel(new CGaussianKernel(cache_size, tau)); + mmd->set_kernel(new CGaussianKernel(cache_size, tau)); sg_rand->set_seed(12345); - rejections_single[i]=mmd2->perform_test(alpha); + rejections_single[i]=mmd->perform_test(alpha); } ASSERT_EQ(rejections_multiple.size(), rejections_single.size());