From f28eacdc71b4e84145356c63b6eadac75531d524 Mon Sep 17 00:00:00 2001 From: lambday Date: Wed, 29 Jun 2016 18:53:50 +0100 Subject: [PATCH] added permutation test for cross-validation for precomputed kernels --- src/shogun/statistical_testing/MMD.cpp | 4 +- src/shogun/statistical_testing/MMD.h | 8 +- .../mmd/PermutationTestCrossValidation.cpp | 278 ++++++++++++++++++ .../mmd/PermutationTestCrossValidation.h | 81 +++++ .../KernelSelectionStrategy.cpp | 38 +-- .../kernelselection/KernelSelectionStrategy.h | 3 +- .../internals/MaxCrossValidation.cpp | 102 +++++-- .../internals/MaxCrossValidation.h | 5 +- .../KernelSelection_unittest.cc | 30 +- 9 files changed, 476 insertions(+), 73 deletions(-) create mode 100644 src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.cpp create mode 100644 src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.h diff --git a/src/shogun/statistical_testing/MMD.cpp b/src/shogun/statistical_testing/MMD.cpp index 6c12a9662be..3727cf8176c 100644 --- a/src/shogun/statistical_testing/MMD.cpp +++ b/src/shogun/statistical_testing/MMD.cpp @@ -418,9 +418,9 @@ void CMMD::set_kernel_selection_strategy(EKernelSelectionMethod method, bool wei self->strategy=strategy; } -void CMMD::set_kernel_selection_strategy(EKernelSelectionMethod method, index_t num_runs, float64_t alpha) +void CMMD::set_kernel_selection_strategy(EKernelSelectionMethod method, index_t num_runs, index_t num_folds, float64_t alpha) { - auto strategy=std::shared_ptr(new CKernelSelectionStrategy(method, num_runs, alpha)); + auto strategy=std::shared_ptr(new CKernelSelectionStrategy(method, num_runs, num_folds, alpha)); const auto& kernel_mgr=self->strategy->get_kernel_mgr(); for (size_t i=0; iadd_kernel(kernel_mgr.kernel_at(i)); diff --git a/src/shogun/statistical_testing/MMD.h b/src/shogun/statistical_testing/MMD.h index 3d8ed0b9167..02b2bf081f6 100644 --- a/src/shogun/statistical_testing/MMD.h +++ b/src/shogun/statistical_testing/MMD.h @@ -55,20 +55,20 @@ class WeightedMaxTestPower; } -enum EStatisticType +enum EStatisticType : uint32_t { ST_UNBIASED_FULL, ST_UNBIASED_INCOMPLETE, ST_BIASED_FULL }; -enum EVarianceEstimationMethod +enum EVarianceEstimationMethod : uint32_t { VEM_DIRECT, VEM_PERMUTATION }; -enum ENullApproximationMethod +enum ENullApproximationMethod : uint32_t { NAM_PERMUTATION, NAM_MMD1_GAUSSIAN, @@ -89,7 +89,7 @@ class CMMD : public CTwoSampleTest void set_kernel_selection_strategy(EKernelSelectionMethod method); void set_kernel_selection_strategy(EKernelSelectionMethod method, bool weighted); - void set_kernel_selection_strategy(EKernelSelectionMethod method, index_t num_runs, float64_t alpha); + void set_kernel_selection_strategy(EKernelSelectionMethod method, index_t num_runs, index_t num_folds, float64_t alpha); CKernelSelectionStrategy* get_kernel_selection_strategy() const; void add_kernel(CKernel *kernel); diff --git a/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.cpp b/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.cpp new file mode 100644 index 00000000000..59c95fc29a7 --- /dev/null +++ b/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.cpp @@ -0,0 +1,278 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (w) 2016 Soumyajit De + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO remove +#include +#include + +using Eigen::Matrix; +using Eigen::Dynamic; +using Eigen::Map; +using std::cout; +using std::endl; +// TODO remove + +namespace shogun +{ + +namespace internal +{ + +namespace mmd +{ + +struct PermutationTestCrossValidation::terms_t +{ + std::array term{}; + std::array diag{}; +}; + +PermutationTestCrossValidation::PermutationTestCrossValidation(index_t nx, index_t ny, index_t nns, EStatisticType type) +: n_x(nx), n_y(ny), num_null_samples(nns), stype(type) +{ + SG_SDEBUG("number of samples are %d and %d!\n", n_x, n_y); + SG_SDEBUG("Number of null samples is %d!\n", num_null_samples); + +} + +PermutationTestCrossValidation::~PermutationTestCrossValidation() +{ +} + +template +void PermutationTestCrossValidation::add_term(terms_t& terms, T val, index_t i, index_t j) +{ + if (i=n_x && j>=n_x && i<=j) + { + SG_SDEBUG("Adding Kernel(%d,%d)=%f to term_1!\n", i, j, val); + terms.term[1]+=val; + if (i==j) + terms.diag[1]+=val; + } + else if (i>=n_x && j +void PermutationTestCrossValidation::operator()(const SGMatrix& km, index_t k) +{ + SG_SDEBUG("Entering!\n"); + REQUIRE(rejections.num_rows==num_runs*num_folds, + "Number of rows in the measure matrix (was %d), has to be >= %d*%d = %d!\n", + rejections.num_rows, num_runs, num_folds, num_runs*num_folds); + REQUIRE(k>=0 && k MatrixXt; + Map map(km.data(), km.num_rows, km.num_cols); + cout << map << endl; + + SGVector dummy_labels_x(n_x); + SGVector dummy_labels_y(n_y); + auto kfold_x=some(new CBinaryLabels(dummy_labels_x), num_folds); + auto kfold_y=some(new CBinaryLabels(dummy_labels_y), num_folds); + + for (auto i=0; ibuild_subsets(); + kfold_y->build_subsets(); + for (auto j=0; j x_inds=kfold_x->generate_subset_inverse(j); + SGVector y_inds=kfold_y->generate_subset_inverse(j); +// x_inds.display_vector("x_inds"); +// y_inds.display_vector("y_inds"); + std::for_each(y_inds.data(), y_inds.data()+y_inds.size(), [this](index_t& val) { val += n_x; }); + SGVector xy_inds(x_inds.size()+y_inds.size()); + std::copy(x_inds.data(), x_inds.data()+x_inds.size(), xy_inds.data()); + std::copy(y_inds.data(), y_inds.data()+y_inds.size(), xy_inds.data()+x_inds.size()); +// xy_inds.display_vector("xy_inds"); + + // compute statistic + SGVector inverted_inds(n_x+n_y); + std::fill(inverted_inds.data(), inverted_inds.data()+n_x+n_y, -1); + for (int idx=0; idx null_samples(num_null_samples); +#pragma omp parallel for + for (auto n=0; n(); + stack->add_subset(xy_inds); + + SGVector permutation_inds(xy_inds.size()); + std::iota(permutation_inds.data(), permutation_inds.data()+permutation_inds.size(), 0); + CMath::permute(permutation_inds); + stack->add_subset(permutation_inds); + + SGVector inds=stack->get_last_subset()->get_subset_idx(); +// inds.display_vector("inds (after permutation)"); + + SGVector inverted_permutation_inds(n_x+n_y); + std::fill(inverted_permutation_inds.data(), inverted_permutation_inds.data()+n_x+n_y, -1); + for (int idx=0; idx measures) +{ + rejections=measures; +} + +template void PermutationTestCrossValidation::operator()(const SGMatrix& km, index_t k); +template void PermutationTestCrossValidation::add_term(terms_t& terms, float64_t val, index_t i, index_t j); + +} + +} + +} diff --git a/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.h b/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.h new file mode 100644 index 00000000000..d211b718f49 --- /dev/null +++ b/src/shogun/statistical_testing/internals/mmd/PermutationTestCrossValidation.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (w) 2016 Soumyajit De + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#ifndef PERMUTATION_TEST_CROSS_VALIDATION +#define PERMUTATION_TEST_CROSS_VALIDATION + +#include +#include +#include +#include + +namespace shogun +{ + +namespace internal +{ + +namespace mmd +{ + +/** + * @brief class that runs cross-validation test for MMD for a single kernel. + */ +class PermutationTestCrossValidation +{ +public: + PermutationTestCrossValidation(index_t nx, index_t ny, index_t null_samples, EStatisticType type); + ~PermutationTestCrossValidation(); + template void operator()(const SGMatrix& km, index_t k); + void set_num_runs(index_t nr); + void set_num_folds(index_t nf); + void set_alpha(index_t alp); + void set_measure_matrix(SGMatrix measures); +private: + struct terms_t; + template void add_term(terms_t&, T kernel, index_t i, index_t j); + float64_t compute_mmd(terms_t&); + const index_t n_x; + const index_t n_y; + const index_t num_null_samples; + const EStatisticType stype; + index_t num_runs; + index_t num_folds; + float64_t alpha; + SGMatrix rejections; +}; + +} + +} + +} + +#endif // PERMUTATION_TEST_CROSS_VALIDATION diff --git a/src/shogun/statistical_testing/kernelselection/KernelSelectionStrategy.cpp b/src/shogun/statistical_testing/kernelselection/KernelSelectionStrategy.cpp index 9c9d01f9457..fcb31f89e3f 100644 --- a/src/shogun/statistical_testing/kernelselection/KernelSelectionStrategy.cpp +++ b/src/shogun/statistical_testing/kernelselection/KernelSelectionStrategy.cpp @@ -57,6 +57,7 @@ struct CKernelSelectionStrategy::Self EKernelSelectionMethod method; bool weighted; index_t num_runs; + index_t num_folds; float64_t alpha; void init_policy(CMMD* estimator); @@ -64,16 +65,18 @@ struct CKernelSelectionStrategy::Self const static EKernelSelectionMethod default_method; const static bool default_weighted; const static index_t default_num_runs; + const static index_t default_num_folds; const static float64_t default_alpha; }; const EKernelSelectionMethod CKernelSelectionStrategy::Self::default_method=KSM_AUTO; const bool CKernelSelectionStrategy::Self::default_weighted=false; const index_t CKernelSelectionStrategy::Self::default_num_runs=10; +const index_t CKernelSelectionStrategy::Self::default_num_folds=3; const float64_t CKernelSelectionStrategy::Self::default_alpha=0.05; CKernelSelectionStrategy::Self::Self() : policy(nullptr), method(default_method), - weighted(default_weighted), num_runs(default_num_runs), alpha(default_alpha) + weighted(default_weighted), num_runs(default_num_runs), num_folds(default_num_folds), alpha(default_alpha) { } @@ -91,7 +94,7 @@ void CKernelSelectionStrategy::Self::init_policy(CMMD* estimator) { REQUIRE(!weighted, "Weighted kernel selection is not possible with MAXIMIZE_CROSS_VALIDATION!\n"); policy=std::unique_ptr(new MaxCrossValidation(kernel_mgr, estimator, - num_runs, alpha)); + num_runs, num_folds, alpha)); } break; case KSM_MAXIMIZE_MMD: @@ -140,35 +143,16 @@ CKernelSelectionStrategy::CKernelSelectionStrategy(EKernelSelectionMethod method self->weighted=weighted; } -CKernelSelectionStrategy::CKernelSelectionStrategy(EKernelSelectionMethod method, index_t num_runs, float64_t alpha) +CKernelSelectionStrategy::CKernelSelectionStrategy(EKernelSelectionMethod method, index_t num_runs, + index_t num_folds, float64_t alpha) { init(); self->method=method; self->num_runs=num_runs; + self->num_folds=num_folds; self->alpha=alpha; } -//CKernelSelectionStrategy::CKernelSelectionStrategy(const CKernelSelectionStrategy& other) -//{ -// init(); -// self->method=other.self->method; -// self->num_runs=other.self->num_runs; -// self->alpha=other.self->alpha; -// for (size_t i=0; ikernel_mgr.num_kernels(); ++i) -// self->kernel_mgr.push_back(other.self->kernel_mgr.kernel_at(i)); -//} -// -//CKernelSelectionStrategy& CKernelSelectionStrategy::operator=(const CKernelSelectionStrategy& other) -//{ -// init(); -// self->method=other.self->method; -// self->num_runs=other.self->num_runs; -// self->alpha=other.self->alpha; -// for (size_t i=0; ikernel_mgr.num_kernels(); ++i) -// self->kernel_mgr.push_back(other.self->kernel_mgr.kernel_at(i)); -// return *this; -//} - void CKernelSelectionStrategy::init() { self=std::unique_ptr(new Self()); @@ -191,6 +175,12 @@ CKernelSelectionStrategy& CKernelSelectionStrategy::use_num_runs(index_t num_run return *this; } +CKernelSelectionStrategy& CKernelSelectionStrategy::use_num_folds(index_t num_folds) +{ + self->num_folds=num_folds; + return *this; +} + CKernelSelectionStrategy& CKernelSelectionStrategy::use_alpha(float64_t alpha) { self->alpha=alpha; diff --git a/src/shogun/statistical_testing/kernelselection/KernelSelectionStrategy.h b/src/shogun/statistical_testing/kernelselection/KernelSelectionStrategy.h index ef470ac65ba..3de5b38ad96 100644 --- a/src/shogun/statistical_testing/kernelselection/KernelSelectionStrategy.h +++ b/src/shogun/statistical_testing/kernelselection/KernelSelectionStrategy.h @@ -66,13 +66,14 @@ class CKernelSelectionStrategy : public CSGObject CKernelSelectionStrategy(); explicit CKernelSelectionStrategy(EKernelSelectionMethod method); CKernelSelectionStrategy(EKernelSelectionMethod method, bool weighted); - CKernelSelectionStrategy(EKernelSelectionMethod method, index_t num_runs, float64_t alpha); + CKernelSelectionStrategy(EKernelSelectionMethod method, index_t num_runs, index_t num_folds, float64_t alpha); CKernelSelectionStrategy(const CKernelSelectionStrategy& other)=delete; CKernelSelectionStrategy& operator=(const CKernelSelectionStrategy& other)=delete; virtual ~CKernelSelectionStrategy(); CKernelSelectionStrategy& use_method(EKernelSelectionMethod method); CKernelSelectionStrategy& use_num_runs(index_t num_runs); + CKernelSelectionStrategy& use_num_folds(index_t num_folds); CKernelSelectionStrategy& use_alpha(float64_t alpha); CKernelSelectionStrategy& use_weighted(bool weighted); diff --git a/src/shogun/statistical_testing/kernelselection/internals/MaxCrossValidation.cpp b/src/shogun/statistical_testing/kernelselection/internals/MaxCrossValidation.cpp index 91a0018edc1..b299ef0c4e3 100644 --- a/src/shogun/statistical_testing/kernelselection/internals/MaxCrossValidation.cpp +++ b/src/shogun/statistical_testing/kernelselection/internals/MaxCrossValidation.cpp @@ -34,18 +34,24 @@ #include #include #include +#include #include #include +#include +#include +#include #include using namespace shogun; using namespace internal; +using namespace mmd; -MaxCrossValidation::MaxCrossValidation(KernelManager& km, CMMD* est, const index_t& M, const float64_t& alp) -: KernelSelection(km, est), num_run(M), alpha(alp) +MaxCrossValidation::MaxCrossValidation(KernelManager& km, CMMD* est, const index_t& M, const index_t& K, const float64_t& alp) +: KernelSelection(km, est), num_runs(M), num_folds(K), alpha(alp) { - REQUIRE(num_run>0, "Number of runs is %d!\n", num_run); - REQUIRE(alpha>=0.0 && alpha<=1.0, "Threshold is %f!\n", alpha); + REQUIRE(num_runs>0, "Number of runs (%d) must be positive!\n", num_runs); + REQUIRE(num_folds>0, "Number of folds (%d) must be positive!\n", num_folds); + REQUIRE(alpha>=0.0 && alpha<=1.0, "Threshold (%f) has to be in [0, 1]!\n", alpha); } MaxCrossValidation::~MaxCrossValidation() @@ -65,11 +71,8 @@ SGMatrix MaxCrossValidation::get_measure_matrix() void MaxCrossValidation::init_measures() { const index_t num_kernels=kernel_mgr.num_kernels(); - auto& data_mgr=estimator->get_data_mgr(); - const index_t N=data_mgr.get_num_folds(); - REQUIRE(N!=0, "Number of folds is not set!\n"); - if (rejections.num_rows!=N*num_run || rejections.num_cols!=num_kernels) - rejections=SGMatrix(N*num_run, num_kernels); + if (rejections.num_rows!=num_folds*num_runs || rejections.num_cols!=num_kernels) + rejections=SGMatrix(num_folds*num_runs, num_kernels); std::fill(rejections.data(), rejections.data()+rejections.size(), 0); if (measures.size()!=num_kernels) measures=SGVector(num_kernels); @@ -79,33 +82,76 @@ void MaxCrossValidation::init_measures() void MaxCrossValidation::compute_measures() { auto& data_mgr=estimator->get_data_mgr(); - data_mgr.set_cross_validation_mode(true); - - const index_t N=data_mgr.get_num_folds(); - SG_SINFO("Performing %d fold cross-validattion!\n", N); + SG_SINFO("Performing %d fold cross-validattion!\n", num_folds); const size_t num_kernels=kernel_mgr.num_kernels(); - auto existing_kernel=estimator->get_kernel(); - for (auto i=0; i(estimator); + if (quadratic_time_mmd) + { +// if (kernel_mgr.same_distance_type()) +// { +// // compute distance on estimator and set the distance to the kernels +// MultiKernelPermutationTestCrossValidation compute(num_runs, num_folds, rejections); +// compute(kernel_mgr); +// } +// else + { + data_mgr.start(); + auto samples=data_mgr.next(); + if (!samples.empty()) + { + CFeatures *samples_p=samples[0][0].get(); + CFeatures *samples_q=samples[1][0].get(); + auto samples_p_and_q=FeaturesUtil::create_merged_copy(samples_p, samples_q); + SG_REF(samples_p_and_q); + samples.clear(); + + auto Nx=estimator->get_num_samples_p(); + auto Ny=estimator->get_num_samples_q(); + auto num_null_samples=estimator->get_num_null_samples(); + auto stype=estimator->get_statistic_type(); + + PermutationTestCrossValidation compute(Nx, Ny, num_null_samples, stype); + compute.set_num_runs(num_runs); + compute.set_num_folds(num_folds); + compute.set_alpha(alpha); + compute.set_measure_matrix(rejections); + + for (size_t k=0; kinit(samples_p_and_q, samples_p_and_q); + compute(kernel->get_kernel_matrix(), k); + kernel->remove_lhs_and_rhs(); + } + SG_UNREF(samples_p_and_q); + } + else + SG_SERROR("Could not fetch samples!\n"); + data_mgr.end(); + } + } + else // TODO put check, this one assumes infinite data { - data_mgr.shuffle_features(); - for (auto j=0; jget_kernel(); + for (auto i=0; iset_kernel(kernel); - auto statistic=estimator->compute_statistic(); - rejections(i*N+j, k)=estimator->compute_p_value(statistic)cleanup(); + SG_SDEBUG("Running fold %d\n", j); + for (size_t k=0; kset_kernel(kernel); + auto statistic=estimator->compute_statistic(); + rejections(i*num_folds+j, k)=estimator->compute_p_value(statistic)cleanup(); + } } } - data_mgr.unshuffle_features(); + estimator->set_kernel(existing_kernel); } - data_mgr.set_cross_validation_mode(false); - estimator->set_kernel(existing_kernel); for (auto j=0; j rejections; SGVector measures; diff --git a/tests/unit/statistical_testing/KernelSelection_unittest.cc b/tests/unit/statistical_testing/KernelSelection_unittest.cc index 2360b7fdefe..dadea1bda5b 100644 --- a/tests/unit/statistical_testing/KernelSelection_unittest.cc +++ b/tests/unit/statistical_testing/KernelSelection_unittest.cc @@ -218,13 +218,14 @@ TEST(KernelSelectionMaxTestPower, linear_time_weighted_kernel_streaming) TEST(KernelSelectionMaxCrossValidation, quadratic_time_single_kernel_dense) { - const index_t m=5; - const index_t n=10; + const index_t m=8; + const index_t n=12; const index_t dim=1; - const float64_t difference=0.5; - const index_t num_kernels=10; + const float64_t difference=1.0; + const index_t num_kernels=2; const index_t num_runs=1; - const index_t num_folds=5; + const index_t num_folds=3; + const float64_t train_test_ratio=3; const float64_t alpha=0.05; sg_rand->set_seed(12345); @@ -236,16 +237,20 @@ TEST(KernelSelectionMaxCrossValidation, quadratic_time_single_kernel_dense) auto mmd=some(feats_p, feats_q); mmd->set_statistic_type(ST_BIASED_FULL); + mmd->set_null_approximation_method(NAM_PERMUTATION); + mmd->set_num_null_samples(1); + mmd->io->set_loglevel(MSG_DEBUG); for (auto i=0, sigma=-5; iadd_kernel(new CGaussianKernel(10, tau)); } - mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_CROSS_VALIDATION, num_runs, alpha); + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_CROSS_VALIDATION, num_runs, num_folds, alpha); mmd->set_train_test_mode(true); - mmd->set_train_test_ratio(num_folds-1); + mmd->set_train_test_ratio(train_test_ratio); mmd->select_kernel(); + mmd->get_kernel_selection_strategy()->get_measure_matrix().display_matrix(); mmd->set_train_test_mode(false); auto selected_kernel=static_cast(mmd->get_kernel()); @@ -254,13 +259,14 @@ TEST(KernelSelectionMaxCrossValidation, quadratic_time_single_kernel_dense) TEST(KernelSelectionMaxCrossValidation, linear_time_single_kernel_dense) { - const index_t m=5; - const index_t n=10; + const index_t m=8; + const index_t n=12; const index_t dim=1; const float64_t difference=0.5; const index_t num_kernels=10; const index_t num_runs=1; - const index_t num_folds=5; + const index_t num_folds=3; + const float64_t train_test_ratio=3; const float64_t alpha=0.05; sg_rand->set_seed(12345); @@ -277,10 +283,10 @@ TEST(KernelSelectionMaxCrossValidation, linear_time_single_kernel_dense) float64_t tau=pow(2, sigma); mmd->add_kernel(new CGaussianKernel(10, tau)); } - mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_CROSS_VALIDATION, num_runs, alpha); + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_CROSS_VALIDATION, num_runs, num_folds, alpha); mmd->set_train_test_mode(true); - mmd->set_train_test_ratio(num_folds-1); + mmd->set_train_test_ratio(train_test_ratio); mmd->select_kernel(); mmd->set_train_test_mode(false);