From 0b519ee536de8c2ba7a0a0d62fd6e3f75e075a70 Mon Sep 17 00:00:00 2001 From: Soumyajit De Date: Sun, 4 Feb 2018 16:42:39 +0530 Subject: [PATCH] bugfix for train-test ratio in statistical testing (#4134) --- .../internals/DataManager.cpp | 51 ++++++++++++++----- .../internals/mmd/ComputeMMD.h | 11 +++- .../KernelSelection_unittest.cc | 35 ++++++++++++- 3 files changed, 83 insertions(+), 14 deletions(-) diff --git a/src/shogun/statistical_testing/internals/DataManager.cpp b/src/shogun/statistical_testing/internals/DataManager.cpp index abb12f577e8..ebcb40c7294 100644 --- a/src/shogun/statistical_testing/internals/DataManager.cpp +++ b/src/shogun/statistical_testing/internals/DataManager.cpp @@ -222,24 +222,36 @@ const bool DataManager::is_blockwise() const void DataManager::set_train_test_mode(bool on) { - train_test_mode=on; - if (!train_test_mode) + if (!on) { train_mode=default_train_mode; train_test_ratio=default_train_test_ratio; cross_validation_mode=default_cross_validation_mode; + + set_train_mode(train_mode); + set_train_test_ratio(train_test_ratio); + + train_test_mode = on; + REQUIRE(fetchers.size() > 0, "Features are not set!\n"); + typedef std::unique_ptr fetcher_type; + std::for_each( + fetchers.begin(), fetchers.end(), [this](fetcher_type& f) { + f->set_train_test_mode(train_test_mode); + }); } - REQUIRE(fetchers.size()>0, "Features are not set!"); - typedef std::unique_ptr fetcher_type; - std::for_each(fetchers.begin(), fetchers.end(), [this, on](fetcher_type& f) + else { - f->set_train_test_mode(on); - if (on) - { - f->set_train_mode(train_mode); - f->set_train_test_ratio(train_test_ratio); - } - }); + train_test_mode = on; + REQUIRE(fetchers.size() > 0, "Features are not set!\n"); + typedef std::unique_ptr fetcher_type; + std::for_each( + fetchers.begin(), fetchers.end(), [this](fetcher_type& f) { + f->set_train_test_mode(train_test_mode); + }); + + set_train_mode(train_mode); + set_train_test_ratio(train_test_ratio); + } } bool DataManager::is_train_test_mode() const @@ -250,7 +262,14 @@ bool DataManager::is_train_test_mode() const void DataManager::set_train_mode(bool on) { if (train_test_mode) + { train_mode=on; + REQUIRE(fetchers.size() > 0, "Features are not set!\n"); + typedef std::unique_ptr fetcher_type; + std::for_each( + fetchers.begin(), fetchers.end(), + [this](fetcher_type& f) { f->set_train_mode(train_mode); }); + } else { SG_SERROR("Train mode cannot be used without turning on Train/Test mode first!" @@ -282,7 +301,15 @@ bool DataManager::is_cross_validation_mode() const void DataManager::set_train_test_ratio(float64_t ratio) { if (train_test_mode) + { train_test_ratio=ratio; + REQUIRE(fetchers.size() > 0, "Features are not set!\n"); + typedef std::unique_ptr fetcher_type; + std::for_each( + fetchers.begin(), fetchers.end(), [this](fetcher_type& f) { + f->set_train_test_ratio(train_test_ratio); + }); + } else { SG_SERROR("Train-test ratio cannot be set without turning on Train/Test mode first!" diff --git a/src/shogun/statistical_testing/internals/mmd/ComputeMMD.h b/src/shogun/statistical_testing/internals/mmd/ComputeMMD.h index bb20e88e436..fd31dea512f 100644 --- a/src/shogun/statistical_testing/internals/mmd/ComputeMMD.h +++ b/src/shogun/statistical_testing/internals/mmd/ComputeMMD.h @@ -87,7 +87,16 @@ struct ComputeMMD { ASSERT(m_n_x>0 && m_n_y>0); const index_t size=m_n_x+m_n_y; - ASSERT(kernel_matrix.num_rows==size && kernel_matrix.num_cols==size); + REQUIRE( + kernel_matrix.num_rows == size, + "Number of rows from kernel matrix (%d) did not match the total " + "number of samples from both distribution (%d)\n", + kernel_matrix.num_rows, size); + REQUIRE( + kernel_matrix.num_cols == size, + "Number of cols from kernel matrix (%d) did not match the total " + "number of samples from both distribution (%d)\n", + kernel_matrix.num_cols, size); typedef Eigen::Matrix MatrixXt; typedef Eigen::Block > BlockXt; diff --git a/tests/unit/statistical_testing/KernelSelection_unittest.cc b/tests/unit/statistical_testing/KernelSelection_unittest.cc index 21a97f07994..3439a75b982 100644 --- a/tests/unit/statistical_testing/KernelSelection_unittest.cc +++ b/tests/unit/statistical_testing/KernelSelection_unittest.cc @@ -108,6 +108,39 @@ TEST(KernelSelectionMaxMMD, quadratic_time_single_kernel_dense) EXPECT_NEAR(selected_kernel->get_width(), 0.25, 1E-10); } +TEST( + KernelSelectionMaxMMD, + quadratic_time_single_kernel_dense_unequal_train_test_ratio) +{ + const index_t m = 10; + const index_t n = 20; + const index_t dim = 1; + const float64_t difference = 0.5; + const index_t num_kernels = 10; + const float64_t train_test_ratio = 4; + + sg_rand->set_seed(12345); + + auto gen_p = some(0, dim, 0); + auto gen_q = some(difference, dim, 0); + + auto feats_p = gen_p->get_streamed_features(m); + auto feats_q = gen_q->get_streamed_features(n); + + auto mmd = some(feats_p, feats_q); + mmd->set_statistic_type(ST_BIASED_FULL); + for (auto i = 0, sigma = -5; i < num_kernels; ++i, sigma += 1) + { + float64_t tau = pow(2, sigma); + mmd->add_kernel(new CGaussianKernel(10, tau)); + } + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_MMD); + + mmd->set_train_test_mode(true); + mmd->set_train_test_ratio(train_test_ratio); + EXPECT_NO_THROW(mmd->select_kernel()); +} + #ifdef USE_GPL_SHOGUN TEST(KernelSelectionMaxMMD, linear_time_weighted_kernel_streaming) { @@ -289,7 +322,7 @@ TEST(KernelSelectionMaxCrossValidation, quadratic_time_single_kernel_dense) mmd->set_train_test_mode(false); auto selected_kernel=static_cast(mmd->get_kernel()); - EXPECT_NEAR(selected_kernel->get_width(), 0.25, 1E-10); + EXPECT_NEAR(selected_kernel->get_width(), 0.03125, 1E-10); } TEST(KernelSelectionMaxCrossValidation, linear_time_single_kernel_dense)