From 2e36040b4502e6979efef6c137e2fe3599d24410 Mon Sep 17 00:00:00 2001 From: lambday Date: Thu, 7 Jul 2016 18:30:26 +0100 Subject: [PATCH] fixed multi-kernel perform test unit-test --- .../internals/KernelManager.cpp | 9 +++++++-- .../QuadraticTimeMMD_unittest.cc | 14 +++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/shogun/statistical_testing/internals/KernelManager.cpp b/src/shogun/statistical_testing/internals/KernelManager.cpp index 3106272ebfa..5ad91103181 100644 --- a/src/shogun/statistical_testing/internals/KernelManager.cpp +++ b/src/shogun/statistical_testing/internals/KernelManager.cpp @@ -137,6 +137,7 @@ void KernelManager::restore_kernel_at(size_t i) bool KernelManager::same_distance_type() const { + ASSERT(num_kernels()>0); bool same=false; EDistanceType distance_type=D_UNKNOWN; for (size_t i=0; i(kernel_at(i)); if (shift_invariant_kernel!=nullptr) { + auto current_distance_type=shift_invariant_kernel->get_distance_type(); if (distance_type==D_UNKNOWN) - distance_type=shift_invariant_kernel->get_distance_type(); - else if (distance_type==shift_invariant_kernel->get_distance_type()) + { + distance_type=current_distance_type; + same=true; + } + else if (distance_type==current_distance_type) same=true; else { diff --git a/tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc b/tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc index 36b6092cfaa..4259d45b23e 100644 --- a/tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc +++ b/tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc @@ -535,11 +535,13 @@ TEST(QuadraticTimeMMD, compute_multiple) TEST(QuadraticTimeMMD, perform_test_multiple) { - const index_t m=20; - const index_t n=20; + const index_t m=8; + const index_t n=12; const index_t dim=1; const index_t num_kernels=10; const float64_t alpha=0.05; + const index_t num_null_samples=200; + const index_t cache_size=10; float64_t difference=0.5; @@ -550,10 +552,11 @@ TEST(QuadraticTimeMMD, perform_test_multiple) CFeatures* feat_q=gen_q->get_streamed_features(n); auto mmd=some(feat_p, feat_q); + mmd->set_num_null_samples(num_null_samples); for (auto i=0, sigma=-5; imultikernel()->add_kernel(new CGaussianKernel(10, tau)); + mmd->multikernel()->add_kernel(new CGaussianKernel(cache_size, tau)); } sg_rand->set_seed(12345); SGVector rejections_multiple=mmd->multikernel()->perform_test(alpha); @@ -561,10 +564,11 @@ TEST(QuadraticTimeMMD, perform_test_multiple) SGVector rejections_single(num_kernels); for (auto i=0, sigma=-5; iset_seed(12345); auto mmd2=some(feat_p, feat_q); + mmd2->set_num_null_samples(num_null_samples); float64_t tau=pow(2, sigma); - mmd2->set_kernel(new CGaussianKernel(10, tau)); + mmd2->set_kernel(new CGaussianKernel(cache_size, tau)); + sg_rand->set_seed(12345); rejections_single[i]=mmd2->perform_test(alpha); }