Skip to content

Commit

Permalink
fixed multi-kernel perform test unit-test
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday committed Jul 13, 2016
1 parent a9f0693 commit 2e36040
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
9 changes: 7 additions & 2 deletions src/shogun/statistical_testing/internals/KernelManager.cpp
Expand Up @@ -137,16 +137,21 @@ void KernelManager::restore_kernel_at(size_t i)

bool KernelManager::same_distance_type() const
{
ASSERT(num_kernels()>0);
bool same=false;
EDistanceType distance_type=D_UNKNOWN;
for (size_t i=0; i<num_kernels(); ++i)
{
CShiftInvariantKernel* shift_invariant_kernel=dynamic_cast<CShiftInvariantKernel*>(kernel_at(i));
if (shift_invariant_kernel!=nullptr)
{
auto current_distance_type=shift_invariant_kernel->get_distance_type();
if (distance_type==D_UNKNOWN)
distance_type=shift_invariant_kernel->get_distance_type();
else if (distance_type==shift_invariant_kernel->get_distance_type())
{
distance_type=current_distance_type;
same=true;
}
else if (distance_type==current_distance_type)
same=true;
else
{
Expand Down
14 changes: 9 additions & 5 deletions tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc
Expand Up @@ -535,11 +535,13 @@ TEST(QuadraticTimeMMD, compute_multiple)

TEST(QuadraticTimeMMD, perform_test_multiple)
{
const index_t m=20;
const index_t n=20;
const index_t m=8;
const index_t n=12;
const index_t dim=1;
const index_t num_kernels=10;
const float64_t alpha=0.05;
const index_t num_null_samples=200;
const index_t cache_size=10;

float64_t difference=0.5;

Expand All @@ -550,21 +552,23 @@ TEST(QuadraticTimeMMD, perform_test_multiple)
CFeatures* feat_q=gen_q->get_streamed_features(n);

auto mmd=some<CQuadraticTimeMMD>(feat_p, feat_q);
mmd->set_num_null_samples(num_null_samples);
for (auto i=0, sigma=-5; i<num_kernels; ++i, sigma+=1)
{
float64_t tau=pow(2, sigma);
mmd->multikernel()->add_kernel(new CGaussianKernel(10, tau));
mmd->multikernel()->add_kernel(new CGaussianKernel(cache_size, tau));
}
sg_rand->set_seed(12345);
SGVector<bool> rejections_multiple=mmd->multikernel()->perform_test(alpha);

SGVector<bool> rejections_single(num_kernels);
for (auto i=0, sigma=-5; i<num_kernels; ++i, sigma+=1)
{
sg_rand->set_seed(12345);
auto mmd2=some<CQuadraticTimeMMD>(feat_p, feat_q);
mmd2->set_num_null_samples(num_null_samples);
float64_t tau=pow(2, sigma);
mmd2->set_kernel(new CGaussianKernel(10, tau));
mmd2->set_kernel(new CGaussianKernel(cache_size, tau));
sg_rand->set_seed(12345);
rejections_single[i]=mmd2->perform_test(alpha);
}

Expand Down

0 comments on commit 2e36040

Please sign in to comment.