diff --git a/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.cpp b/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.cpp index 8a787eb37cf..7ad4efa9e11 100644 --- a/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.cpp +++ b/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.cpp @@ -232,16 +232,20 @@ void MultiKernelPermutationTestCrossValidation::operator()(const KernelManager& // transpose the null_samples matrix for faster access MatrixXd transposed_null_samples=null_samples_map.transpose(); +// cout << transposed_null_samples << endl; #pragma omp parallel for for (size_t k=0; k null_samples_k(transposed_null_samples.col(k).data(), num_null_samples, false); std::sort(null_samples_k.data(), null_samples_k.data()+null_samples_k.size()); +// null_samples_k.display_vector("null_samples_k"); SG_SDEBUG("statistic=%f\n", statistic[k]); float64_t idx=null_samples_k.find_position_to_insert(statistic[k]); + SG_SDEBUG("index=%f\n", idx); auto p_value=1.0-idx/num_null_samples; - SG_SDEBUG("p-value=%f, rejected=%d\n", p_value, p_value measures); void set_distance(CCustomDistance* distance); private: diff --git a/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.cpp b/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.cpp index e3ef7df58b9..e52f3806926 100644 --- a/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.cpp +++ b/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.cpp @@ -258,7 +258,7 @@ void PermutationTestCrossValidation::set_num_folds(index_t nf) num_folds=nf; } -void PermutationTestCrossValidation::set_alpha(index_t alp) +void PermutationTestCrossValidation::set_alpha(float64_t alp) { alpha=alp; } diff --git a/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.h b/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.h index d211b718f49..27c46f669fa 100644 --- a/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.h +++ b/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.h @@ -56,7 +56,7 @@ class PermutationTestCrossValidation template void operator()(const SGMatrix& km, index_t k); void set_num_runs(index_t nr); void set_num_folds(index_t nf); - void set_alpha(index_t alp); + void set_alpha(float64_t alp); void set_measure_matrix(SGMatrix measures); private: struct terms_t; diff --git a/tests/unit/statistical_testing/KernelSelection_unittest.cc b/tests/unit/statistical_testing/KernelSelection_unittest.cc index 7f8121390d9..c5399e6166f 100644 --- a/tests/unit/statistical_testing/KernelSelection_unittest.cc +++ b/tests/unit/statistical_testing/KernelSelection_unittest.cc @@ -238,7 +238,7 @@ TEST(KernelSelectionMaxCrossValidation, quadratic_time_single_kernel_dense) auto mmd=some(feats_p, feats_q); mmd->set_statistic_type(ST_BIASED_FULL); mmd->set_null_approximation_method(NAM_PERMUTATION); - mmd->set_num_null_samples(1); + mmd->set_num_null_samples(10); // mmd->io->set_loglevel(MSG_DEBUG); for (auto i=0, sigma=-5; i