From 10ed7983c0cbfc887d8ca4969ce9bf3457bb5c19 Mon Sep 17 00:00:00 2001 From: lambday Date: Thu, 2 Jun 2016 02:14:48 +0100 Subject: [PATCH] refactored train/test subsetting and hypothesis test framework --- src/shogun/statistical_testing/BTestMMD.cpp | 20 +- .../statistical_testing/HypothesisTest.cpp | 2 +- .../statistical_testing/IndependenceTest.cpp | 35 ++- .../statistical_testing/IndependenceTest.h | 16 +- .../KernelSelectionStrategy.cpp | 2 +- .../KernelSelectionStrategy.h | 2 +- .../statistical_testing/LinearTimeMMD.cpp | 30 +- src/shogun/statistical_testing/MMD.cpp | 94 +++---- src/shogun/statistical_testing/MMD.h | 2 +- .../OneDistributionTest.cpp | 19 +- .../statistical_testing/OneDistributionTest.h | 8 +- .../statistical_testing/QuadraticTimeMMD.cpp | 71 +++-- .../TwoDistributionTest.cpp | 19 +- .../statistical_testing/TwoDistributionTest.h | 12 +- .../statistical_testing/TwoSampleTest.cpp | 38 ++- .../statistical_testing/TwoSampleTest.h | 18 +- .../internals/DataFetcher.cpp | 266 +++++++++++------- .../internals/DataFetcher.h | 46 ++- .../internals/DataManager.cpp | 174 ++++++++---- .../internals/DataManager.h | 93 ++---- .../internals/MaxXValidation.cpp | 22 +- .../internals/MedianHeuristic.cpp | 4 +- .../internals/StreamingDataFetcher.cpp | 51 ++-- .../internals/StreamingDataFetcher.h | 26 +- .../internals/TrainTestDetails.cpp | 78 ----- .../internals/TrainTestDetails.h | 73 ----- .../KernelSelectionMaxMMD_unittest.cc | 2 + .../KernelSelectionMaxPower_unittest.cc | 2 + .../KernelSelectionMaxXValidation_unittest.cc | 9 +- ...KernelSelectionMedianHeuristic_unittest.cc | 5 +- .../internals/DataManager_unittest.cc | 33 ++- 31 files changed, 655 insertions(+), 617 deletions(-) delete mode 100644 src/shogun/statistical_testing/internals/TrainTestDetails.cpp delete mode 100644 src/shogun/statistical_testing/internals/TrainTestDetails.h diff --git a/src/shogun/statistical_testing/BTestMMD.cpp b/src/shogun/statistical_testing/BTestMMD.cpp index e3c758f037e..cacc801c5db 100644 --- a/src/shogun/statistical_testing/BTestMMD.cpp +++ b/src/shogun/statistical_testing/BTestMMD.cpp @@ -37,12 +37,12 @@ CBTestMMD::~CBTestMMD() void CBTestMMD::set_blocksize(index_t blocksize) { - get_data_manager().set_blocksize(blocksize); + get_data_mgr().set_blocksize(blocksize); } void CBTestMMD::set_num_blocks_per_burst(index_t num_blocks_per_burst) { - get_data_manager().set_num_blocks_per_burst(num_blocks_per_burst); + get_data_mgr().set_num_blocks_per_burst(num_blocks_per_burst); } const std::function)> CBTestMMD::get_direct_estimation_method() const @@ -52,19 +52,19 @@ const std::function)> CBTestMMD::get_direct_estima const float64_t CBTestMMD::normalize_statistic(float64_t statistic) const { - const DataManager& dm=get_data_manager(); - const index_t Nx=dm.num_samples_at(0); - const index_t Ny=dm.num_samples_at(1); - const index_t Bx=dm.blocksize_at(0); - const index_t By=dm.blocksize_at(1); + const DataManager& data_mgr=get_data_mgr(); + const index_t Nx=data_mgr.num_samples_at(0); + const index_t Ny=data_mgr.num_samples_at(1); + const index_t Bx=data_mgr.blocksize_at(0); + const index_t By=data_mgr.blocksize_at(1); return Nx*Ny*statistic*CMath::sqrt((Bx+By)/float64_t(Nx+Ny))/(Nx+Ny); } const float64_t CBTestMMD::normalize_variance(float64_t variance) const { - const DataManager& dm=get_data_manager(); - const index_t Bx=dm.blocksize_at(0); - const index_t By=dm.blocksize_at(1); + const DataManager& data_mgr=get_data_mgr(); + const index_t Bx=data_mgr.blocksize_at(0); + const index_t By=data_mgr.blocksize_at(1); return variance*CMath::sq(Bx*By/float64_t(Bx+By)); } diff --git a/src/shogun/statistical_testing/HypothesisTest.cpp b/src/shogun/statistical_testing/HypothesisTest.cpp index f813f2d6f45..2804d61d552 100644 --- a/src/shogun/statistical_testing/HypothesisTest.cpp +++ b/src/shogun/statistical_testing/HypothesisTest.cpp @@ -64,7 +64,7 @@ void CHypothesisTest::set_train_test_mode(bool on) void CHypothesisTest::set_train_test_ratio(float64_t ratio) { - self->data_mgr.train_test_ratio(ratio); + self->data_mgr.set_train_test_ratio(ratio); } float64_t CHypothesisTest::compute_p_value(float64_t statistic) diff --git a/src/shogun/statistical_testing/IndependenceTest.cpp b/src/shogun/statistical_testing/IndependenceTest.cpp index a6fa3d2c5fa..77b8013a401 100644 --- a/src/shogun/statistical_testing/IndependenceTest.cpp +++ b/src/shogun/statistical_testing/IndependenceTest.cpp @@ -24,8 +24,19 @@ using namespace shogun; using namespace internal; -CIndependenceTest::CIndependenceTest() : CTwoDistributionTest(IndependenceTest::num_kernels) +struct CIndependenceTest::Self { + Self(index_t num_kernels); + KernelManager kernel_mgr; +}; + +CIndependenceTest::Self::Self(index_t num_kernels) : kernel_mgr(num_kernels) +{ +} + +CIndependenceTest::CIndependenceTest() : CTwoDistributionTest() +{ + self=std::unique_ptr(new Self(IndependenceTest::num_kernels)); } CIndependenceTest::~CIndependenceTest() @@ -34,29 +45,35 @@ CIndependenceTest::~CIndependenceTest() void CIndependenceTest::set_kernel_p(CKernel* kernel_p) { - auto& km = get_kernel_manager(); - km.kernel_at(0) = kernel_p; + self->kernel_mgr.kernel_at(0)=kernel_p; } CKernel* CIndependenceTest::get_kernel_p() const { - const auto& km = get_kernel_manager(); - return km.kernel_at(0); + return self->kernel_mgr.kernel_at(0); } void CIndependenceTest::set_kernel_q(CKernel* kernel_q) { - auto& km = get_kernel_manager(); - km.kernel_at(1) = kernel_q; + self->kernel_mgr.kernel_at(1)=kernel_q; } CKernel* CIndependenceTest::get_kernel_q() const { - const auto& km = get_kernel_manager(); - return km.kernel_at(1); + return self->kernel_mgr.kernel_at(1); } const char* CIndependenceTest::get_name() const { return "IndependenceTest"; } + +KernelManager& CIndependenceTest::get_kernel_mgr() +{ + return self->kernel_mgr; +} + +const KernelManager& CIndependenceTest::get_kernel_mgr() const +{ + return self->kernel_mgr; +} diff --git a/src/shogun/statistical_testing/IndependenceTest.h b/src/shogun/statistical_testing/IndependenceTest.h index 23a776dbd27..f1cdbd41134 100644 --- a/src/shogun/statistical_testing/IndependenceTest.h +++ b/src/shogun/statistical_testing/IndependenceTest.h @@ -19,6 +19,7 @@ #ifndef INDEPENDENCE_TEST_H_ #define INDEPENDENCE_TEST_H_ +#include #include namespace shogun @@ -26,6 +27,11 @@ namespace shogun class CKernel; +namespace internal +{ + class KernelManager; +} + class CIndependenceTest : public CTwoDistributionTest { public: @@ -38,10 +44,16 @@ class CIndependenceTest : public CTwoDistributionTest void set_kernel_q(CKernel* kernel_q); CKernel* get_kernel_q() const; - virtual float64_t compute_statistic() = 0; - virtual SGVector sample_null() = 0; + virtual float64_t compute_statistic()=0; + virtual SGVector sample_null()=0; virtual const char* get_name() const; +protected: + internal::KernelManager& get_kernel_mgr(); + const internal::KernelManager& get_kernel_mgr() const; +private: + struct Self; + std::unique_ptr self; }; } diff --git a/src/shogun/statistical_testing/KernelSelectionStrategy.cpp b/src/shogun/statistical_testing/KernelSelectionStrategy.cpp index 7ddbbb1993a..96c13d1df51 100644 --- a/src/shogun/statistical_testing/KernelSelectionStrategy.cpp +++ b/src/shogun/statistical_testing/KernelSelectionStrategy.cpp @@ -222,7 +222,7 @@ const char* CKernelSelectionStrategy::get_name() const return "KernelSelectionStrategy"; } -const KernelManager& CKernelSelectionStrategy::get_kernel_manager() const +const KernelManager& CKernelSelectionStrategy::get_kernel_mgr() const { return self->kernel_mgr; } diff --git a/src/shogun/statistical_testing/KernelSelectionStrategy.h b/src/shogun/statistical_testing/KernelSelectionStrategy.h index b2f032d2d9d..eac801046a0 100644 --- a/src/shogun/statistical_testing/KernelSelectionStrategy.h +++ b/src/shogun/statistical_testing/KernelSelectionStrategy.h @@ -85,7 +85,7 @@ class CKernelSelectionStrategy : public CSGObject struct Self; std::unique_ptr self; void init(); - const internal::KernelManager& get_kernel_manager() const; + const internal::KernelManager& get_kernel_mgr() const; }; } diff --git a/src/shogun/statistical_testing/LinearTimeMMD.cpp b/src/shogun/statistical_testing/LinearTimeMMD.cpp index 3c6a6bbe651..927dcf5961e 100644 --- a/src/shogun/statistical_testing/LinearTimeMMD.cpp +++ b/src/shogun/statistical_testing/LinearTimeMMD.cpp @@ -44,12 +44,12 @@ CLinearTimeMMD::~CLinearTimeMMD() void CLinearTimeMMD::set_num_blocks_per_burst(index_t num_blocks_per_burst) { - auto& dm=get_data_manager(); - auto min_blocksize=dm.get_min_blocksize(); + auto& data_mgr=get_data_mgr(); + auto min_blocksize=data_mgr.get_min_blocksize(); if (min_blocksize==2) { // only possible when number of samples from both the distributions are the same - auto N=dm.num_samples_at(0); + auto N=data_mgr.num_samples_at(0); for (auto i=2; i)> CLinearTimeMMD::get_direct_estimation_method() const @@ -71,17 +71,17 @@ const std::function)> CLinearTimeMMD::get_direct_e const float64_t CLinearTimeMMD::normalize_statistic(float64_t statistic) const { - const DataManager& dm = get_data_manager(); - const index_t Nx = dm.num_samples_at(0); - const index_t Ny = dm.num_samples_at(1); + const DataManager& data_mgr = get_data_mgr(); + const index_t Nx = data_mgr.num_samples_at(0); + const index_t Ny = data_mgr.num_samples_at(1); return CMath::sqrt(Nx * Ny / float64_t(Nx + Ny)) * statistic; } const float64_t CLinearTimeMMD::normalize_variance(float64_t variance) const { - const DataManager& dm = get_data_manager(); - const index_t Bx = dm.blocksize_at(0); - const index_t By = dm.blocksize_at(1); + const DataManager& data_mgr = get_data_mgr(); + const index_t Bx = data_mgr.blocksize_at(0); + const index_t By = data_mgr.blocksize_at(1); const index_t B = Bx + By; if (get_statistic_type() == ST_UNBIASED_INCOMPLETE) { @@ -92,9 +92,9 @@ const float64_t CLinearTimeMMD::normalize_variance(float64_t variance) const const float64_t CLinearTimeMMD::gaussian_variance(float64_t variance) const { - const DataManager& dm = get_data_manager(); - const index_t Bx = dm.blocksize_at(0); - const index_t By = dm.blocksize_at(1); + const DataManager& data_mgr = get_data_mgr(); + const index_t Bx = data_mgr.blocksize_at(0); + const index_t By = data_mgr.blocksize_at(1); const index_t B = Bx + By; if (get_statistic_type() == ST_UNBIASED_INCOMPLETE) { diff --git a/src/shogun/statistical_testing/MMD.cpp b/src/shogun/statistical_testing/MMD.cpp index c250ce39372..b656473fa98 100644 --- a/src/shogun/statistical_testing/MMD.cpp +++ b/src/shogun/statistical_testing/MMD.cpp @@ -115,9 +115,9 @@ void CMMD::Self::create_computation_jobs() void CMMD::Self::create_statistic_job() { - const DataManager& dm=owner.get_data_manager(); - auto Bx=dm.blocksize_at(0); - auto By=dm.blocksize_at(1); + const DataManager& data_mgr=owner.get_data_mgr(); + auto Bx=data_mgr.blocksize_at(0); + auto By=data_mgr.blocksize_at(1); switch (statistic_type) { case ST_UNBIASED_FULL: @@ -193,8 +193,8 @@ void CMMD::Self::compute_jobs(ComputationManager& cm) const std::pair CMMD::Self::compute_statistic_variance() { - const KernelManager& km=owner.get_kernel_manager(); - auto kernel=km.kernel_at(0); + const KernelManager& kernel_mgr=owner.get_kernel_mgr(); + auto kernel=kernel_mgr.kernel_at(0); REQUIRE(kernel != nullptr, "Kernel is not set!\n"); float64_t statistic=0; @@ -203,7 +203,7 @@ std::pair CMMD::Self::compute_statistic_variance() index_t statistic_term_counter=1; index_t variance_term_counter=1; - DataManager& dm=owner.get_data_manager(); + DataManager& data_mgr=owner.get_data_mgr(); ComputationManager cm; create_computation_jobs(); @@ -212,8 +212,8 @@ std::pair CMMD::Self::compute_statistic_variance() std::vector blocks; - dm.start(); - auto next_burst=dm.next(); + data_mgr.start(); + auto next_burst=data_mgr.next(); while (!next_burst.empty()) { merge_samples(next_burst, blocks); @@ -250,10 +250,10 @@ std::pair CMMD::Self::compute_statistic_variance() variance_term_counter++; } } - next_burst=dm.next(); + next_burst=data_mgr.next(); } - dm.end(); + data_mgr.end(); cm.done(); // normalize statistic and variance @@ -284,13 +284,13 @@ std::pair, SGMatrix > CMMD::Self::compute_statist SGMatrix term_counters_Q(num_kernels, num_kernels); std::fill(term_counters_Q.data(), term_counters_Q.data()+term_counters_Q.size(), 1); - DataManager& dm=owner.get_data_manager(); + DataManager& data_mgr=owner.get_data_mgr(); ComputationManager cm; create_computation_jobs(); cm.enqueue_job(statistic_job); - dm.start(); - auto next_burst=dm.next(); + data_mgr.start(); + auto next_burst=data_mgr.next(); std::vector blocks; std::vector > mmds(num_kernels); while (!next_burst.empty()) @@ -327,11 +327,11 @@ std::pair, SGMatrix > CMMD::Self::compute_statist Q(j, i)=Q(i, j); } } - next_burst=dm.next(); + next_burst=data_mgr.next(); } mmds.clear(); - dm.end(); + data_mgr.end(); cm.done(); std::for_each(statistic.data(), statistic.data()+statistic.size(), [this](float64_t val) @@ -343,8 +343,8 @@ std::pair, SGMatrix > CMMD::Self::compute_statist SGVector CMMD::Self::sample_null() { - const KernelManager& km=owner.get_kernel_manager(); - auto kernel=km.kernel_at(0); + const KernelManager& kernel_mgr=owner.get_kernel_mgr(); + auto kernel=kernel_mgr.kernel_at(0); REQUIRE(kernel != nullptr, "Kernel is not set!\n"); SGVector statistic(num_null_samples); @@ -353,7 +353,7 @@ SGVector CMMD::Self::sample_null() std::fill(statistic.vector, statistic.vector+statistic.vlen, 0); std::fill(term_counters.data(), term_counters.data()+term_counters.size(), 1); - DataManager& dm=owner.get_data_manager(); + DataManager& data_mgr=owner.get_data_mgr(); ComputationManager cm; create_statistic_job(); @@ -361,8 +361,8 @@ SGVector CMMD::Self::sample_null() std::vector blocks; - dm.start(); - auto next_burst=dm.next(); + data_mgr.start(); + auto next_burst=data_mgr.next(); while (!next_burst.empty()) { @@ -381,10 +381,10 @@ SGVector CMMD::Self::sample_null() term_counters[j]++; } } - next_burst=dm.next(); + next_burst=data_mgr.next(); } - dm.end(); + data_mgr.end(); cm.done(); // normalize statistic @@ -399,19 +399,17 @@ SGVector CMMD::Self::sample_null() CCustomDistance* CMMD::Self::compute_distance() { auto distance=new CCustomDistance(); - DataManager& dm=owner.get_data_manager(); + DataManager& data_mgr=owner.get_data_mgr(); - bool blockwise=dm.is_blockwise(); - dm.set_blockwise(false); + bool blockwise=data_mgr.is_blockwise(); + data_mgr.set_blockwise(false); // using data manager next() API in order to make it work with // streaming samples as well. - dm.start(); - auto samples=dm.next(); + data_mgr.start(); + auto samples=data_mgr.next(); if (!samples.empty()) { - dm.end(); - // use 0th block from each distribution (since there is only one block // for quadratic time MMD CFeatures *samples_p=samples[0][0].get(); @@ -438,14 +436,11 @@ CCustomDistance* CMMD::Self::compute_distance() } } else - { - dm.end(); SG_SERROR("Could not fetch samples!\n"); - } - dm.set_blockwise(blockwise); + data_mgr.end(); + data_mgr.set_blockwise(blockwise); - SG_REF(distance); return distance; } @@ -464,9 +459,9 @@ CMMD::~CMMD() void CMMD::set_kernel_selection_strategy(CKernelSelectionStrategy* strategy) { SG_REF(strategy); - const auto& km=self->strategy->get_kernel_manager(); - for (size_t i=0; iadd_kernel(km.kernel_at(i)); + const auto& kernel_mgr=self->strategy->get_kernel_mgr(); + for (size_t i=0; iadd_kernel(kernel_mgr.kernel_at(i)); self->strategy=std::shared_ptr(strategy, [](CKernelSelectionStrategy* ptr) { SG_UNREF(ptr); }); } @@ -475,18 +470,17 @@ void CMMD::add_kernel(CKernel* kernel) self->strategy->add_kernel(kernel); } -void CMMD::select_kernel(float64_t ratio) +void CMMD::select_kernel() { SG_DEBUG("Entering!\n"); - auto& dm=get_data_manager(); - dm.set_train_test_ratio(ratio); - dm.set_train_mode(true); + auto& data_mgr=get_data_mgr(); + data_mgr.set_train_mode(true); - auto& km=get_kernel_manager(); - km.kernel_at(0)=self->strategy->select_kernel(this); - km.restore_kernel_at(0); + auto& kernel_mgr=get_kernel_mgr(); + kernel_mgr.kernel_at(0)=self->strategy->select_kernel(this); + kernel_mgr.restore_kernel_at(0); - dm.set_train_mode(false); + data_mgr.set_train_mode(false); SG_DEBUG("Leaving!\n"); } @@ -507,9 +501,9 @@ float64_t CMMD::compute_variance() void CMMD::set_train_test_ratio(float64_t ratio) { - auto& dm=get_data_manager(); - dm.set_train_test_ratio(ratio); - dm.reset(); + auto& data_mgr=get_data_mgr(); + data_mgr.set_train_test_ratio(ratio); + data_mgr.reset(); } std::pair CMMD::compute_statistic_variance() @@ -549,8 +543,8 @@ bool CMMD::use_gpu() const void CMMD::cleanup() { - for (size_t i=0; i #include -#include namespace shogun { @@ -29,7 +27,7 @@ namespace shogun class COneDistributionTest : public CHypothesisTest { public: - COneDistributionTest(index_t num_kernels); + COneDistributionTest(); virtual ~COneDistributionTest(); void set_samples(CFeatures* samples); @@ -38,8 +36,8 @@ class COneDistributionTest : public CHypothesisTest void set_num_samples(index_t num_samples); index_t get_num_samples() const; - virtual float64_t compute_statistic() = 0; - virtual SGVector sample_null() = 0; + virtual float64_t compute_statistic()=0; + virtual SGVector sample_null()=0; virtual const char* get_name() const; }; diff --git a/src/shogun/statistical_testing/QuadraticTimeMMD.cpp b/src/shogun/statistical_testing/QuadraticTimeMMD.cpp index ec7203a0fef..dd099608400 100644 --- a/src/shogun/statistical_testing/QuadraticTimeMMD.cpp +++ b/src/shogun/statistical_testing/QuadraticTimeMMD.cpp @@ -94,8 +94,8 @@ void CQuadraticTimeMMD::Self::create_computation_jobs() void CQuadraticTimeMMD::Self::create_statistic_job() { SG_SDEBUG("Entering\n"); - const DataManager& dm=owner.get_data_manager(); - auto Nx=dm.num_samples_at(0); + const DataManager& data_mgr=owner.get_data_mgr(); + auto Nx=data_mgr.num_samples_at(0); switch (owner.get_statistic_type()) { case ST_UNBIASED_FULL: @@ -141,18 +141,17 @@ void CQuadraticTimeMMD::Self::compute_jobs(ComputationManager& cm) const void CQuadraticTimeMMD::Self::init_kernel() { SG_SDEBUG("Entering\n"); - const KernelManager& km=owner.get_kernel_manager(); - auto kernel=km.kernel_at(0); + const KernelManager& kernel_mgr=owner.get_kernel_mgr(); + auto kernel=kernel_mgr.kernel_at(0); REQUIRE(kernel!=nullptr, "Kernel is not set!\n"); if (!is_kernel_initialized && !(kernel->get_kernel_type()==K_CUSTOM)) { - DataManager& dm=owner.get_data_manager(); - dm.start(); - auto samples=dm.next(); + DataManager& data_mgr=owner.get_data_mgr(); + data_mgr.start(); + auto samples=data_mgr.next(); if (!samples.empty()) { - dm.end(); CFeatures *samples_p=samples[0][0].get(); CFeatures *samples_q=samples[1][0].get(); auto samples_p_and_q=FeaturesUtil::create_merged_copy(samples_p, samples_q); @@ -162,10 +161,8 @@ void CQuadraticTimeMMD::Self::init_kernel() SG_SDEBUG("Kernel is initialized with joint features of %d total samples!\n", samples_p_and_q->get_num_vectors()); } else - { - dm.end(); SG_SERROR("Could not fetch samples!\n"); - } + data_mgr.end(); } SG_SDEBUG("Leaving\n"); } @@ -173,8 +170,8 @@ void CQuadraticTimeMMD::Self::init_kernel() SGMatrix CQuadraticTimeMMD::Self::get_kernel_matrix() { SG_SDEBUG("Entering\n"); - const KernelManager& km=owner.get_kernel_manager(); - auto kernel=km.kernel_at(0); + const KernelManager& kernel_mgr=owner.get_kernel_mgr(); + auto kernel=kernel_mgr.kernel_at(0); REQUIRE(kernel!=nullptr, "Kernel is not set!\n"); SGMatrix kernel_matrix; @@ -189,14 +186,14 @@ SGMatrix CQuadraticTimeMMD::Self::get_kernel_matrix() init_kernel(); try { - owner.get_kernel_manager().precompute_kernel_at(0); + owner.get_kernel_mgr().precompute_kernel_at(0); } catch (ShogunException e) { SG_SERROR("%s, Data is too large! Computing kernel matrix was not possible!\n", e.get_exception_string()); } kernel->remove_lhs_and_rhs(); - auto precomputed_kernel=dynamic_cast(km.kernel_at(0)); + auto precomputed_kernel=dynamic_cast(kernel_mgr.kernel_at(0)); ASSERT(precomputed_kernel!=nullptr); kernel_matrix=precomputed_kernel->get_float32_kernel_matrix(); } @@ -238,9 +235,9 @@ SGVector CQuadraticTimeMMD::Self::sample_null() { SG_SDEBUG("Entering\n"); - const DataManager& dm=owner.get_data_manager(); - auto Nx=dm.num_samples_at(0); - auto Ny=dm.num_samples_at(1); + const DataManager& data_mgr=owner.get_data_mgr(); + auto Nx=data_mgr.num_samples_at(0); + auto Ny=data_mgr.num_samples_at(1); WithinBlockPermutationBatch compute(Nx, Ny, owner.get_num_null_samples(), owner.get_statistic_type()); SGVector result; @@ -251,8 +248,8 @@ SGVector CQuadraticTimeMMD::Self::sample_null() } else { - const KernelManager& km=owner.get_kernel_manager(); - auto kernel=km.kernel_at(0); + const KernelManager& kernel_mgr=owner.get_kernel_mgr(); + auto kernel=kernel_mgr.kernel_at(0); REQUIRE(kernel!=nullptr, "Kernel is not set!\n"); if (kernel->get_kernel_type()==K_CUSTOM) { @@ -284,7 +281,7 @@ CQuadraticTimeMMD::CQuadraticTimeMMD(CFeatures* samples_from_p, CFeatures* sampl CQuadraticTimeMMD::~CQuadraticTimeMMD() { - get_kernel_manager().restore_kernel_at(0); + get_kernel_mgr().restore_kernel_at(0); } void CQuadraticTimeMMD::set_kernel(CKernel* kernel) @@ -300,9 +297,9 @@ const std::function)> CQuadraticTimeMMD::get_direc const float64_t CQuadraticTimeMMD::normalize_statistic(float64_t statistic) const { - const DataManager& dm=get_data_manager(); - const index_t Nx=dm.num_samples_at(0); - const index_t Ny=dm.num_samples_at(1); + const DataManager& data_mgr=get_data_mgr(); + const index_t Nx=data_mgr.num_samples_at(0); + const index_t Ny=data_mgr.num_samples_at(1); return Nx*Ny*statistic/(Nx+Ny); } @@ -388,9 +385,9 @@ SGVector CQuadraticTimeMMD::sample_null() SGVector CQuadraticTimeMMD::gamma_fit_null() { SG_DEBUG("Entering\n"); - DataManager& dm=get_data_manager(); - index_t m=dm.num_samples_at(0); - index_t n=dm.num_samples_at(1); + DataManager& data_mgr=get_data_mgr(); + index_t m=data_mgr.num_samples_at(0); + index_t n=data_mgr.num_samples_at(1); REQUIRE(m==n, "Number of samples from p (%d) and q (%d) must be equal.\n", n, m) @@ -460,9 +457,9 @@ SGVector CQuadraticTimeMMD::gamma_fit_null() SGVector CQuadraticTimeMMD::spectrum_sample_null() { SG_DEBUG("Entering\n"); - DataManager& dm=get_data_manager(); - index_t m=dm.num_samples_at(0); - index_t n=dm.num_samples_at(1); + DataManager& data_mgr=get_data_mgr(); + index_t m=data_mgr.num_samples_at(0); + index_t n=data_mgr.num_samples_at(1); if (self->num_eigenvalues > m+n - 1) { @@ -483,9 +480,9 @@ SGVector CQuadraticTimeMMD::spectrum_sample_null() /* imaginary matrix K=[K KL; KL' L] (MATLAB notation) * K is matrix for XX, L is matrix for YY, KL is XY, LK is YX * works since X and Y are concatenated here */ - SGMatrix precomputed_km=self->get_kernel_matrix(); - SGMatrix K(precomputed_km.num_rows, precomputed_km.num_cols); - std::copy(precomputed_km.matrix, precomputed_km.matrix+precomputed_km.num_rows*precomputed_km.num_cols, K.matrix); + SGMatrix precomputed_kernel_mgr=self->get_kernel_matrix(); + SGMatrix K(precomputed_kernel_mgr.num_rows, precomputed_kernel_mgr.num_cols); + std::copy(precomputed_kernel_mgr.matrix, precomputed_kernel_mgr.matrix+precomputed_kernel_mgr.num_rows*precomputed_kernel_mgr.num_cols, K.matrix); /* center matrix K=H*K*H */ K.center(); @@ -525,15 +522,15 @@ void CQuadraticTimeMMD::precompute_kernel_matrix(bool precompute) { if (self->precompute && !precompute) { - const KernelManager& km=get_kernel_manager(); - auto kernel=km.kernel_at(0); + const KernelManager& kernel_mgr=get_kernel_mgr(); + auto kernel=kernel_mgr.kernel_at(0); REQUIRE(kernel!=nullptr, "Kernel is not set!\n"); if (kernel->get_kernel_type()==K_CUSTOM) { SG_SINFO("Precomputed kernel matrix exists! Removing the existing matrix!\n"); - get_kernel_manager().restore_kernel_at(0); - kernel=km.kernel_at(0); + get_kernel_mgr().restore_kernel_at(0); + kernel=kernel_mgr.kernel_at(0); REQUIRE(kernel!=nullptr, "Kernel is not set!\n"); if (kernel->get_kernel_type()==K_CUSTOM) { diff --git a/src/shogun/statistical_testing/TwoDistributionTest.cpp b/src/shogun/statistical_testing/TwoDistributionTest.cpp index 7c02658dd5e..021fb6e02f5 100644 --- a/src/shogun/statistical_testing/TwoDistributionTest.cpp +++ b/src/shogun/statistical_testing/TwoDistributionTest.cpp @@ -23,8 +23,7 @@ using namespace shogun; using namespace internal; -CTwoDistributionTest::CTwoDistributionTest(index_t num_kernels) -: CHypothesisTest(TwoDistributionTest::num_feats, num_kernels) +CTwoDistributionTest::CTwoDistributionTest() : CHypothesisTest(TwoDistributionTest::num_feats) { } @@ -34,49 +33,49 @@ CTwoDistributionTest::~CTwoDistributionTest() void CTwoDistributionTest::set_p(CFeatures* samples_from_p) { - auto& dm=get_data_manager(); + auto& dm=get_data_mgr(); dm.samples_at(0)=samples_from_p; } CFeatures* CTwoDistributionTest::get_p() const { - const auto& dm=get_data_manager(); + const auto& dm=get_data_mgr(); return dm.samples_at(0); } void CTwoDistributionTest::set_q(CFeatures* samples_from_q) { - auto& dm=get_data_manager(); + auto& dm=get_data_mgr(); dm.samples_at(1)=samples_from_q; } CFeatures* CTwoDistributionTest::get_q() const { - const auto& dm=get_data_manager(); + const auto& dm=get_data_mgr(); return dm.samples_at(1); } void CTwoDistributionTest::set_num_samples_p(index_t num_samples_from_p) { - auto& dm=get_data_manager(); + auto& dm=get_data_mgr(); dm.num_samples_at(0)=num_samples_from_p; } const index_t CTwoDistributionTest::get_num_samples_p() const { - const auto& dm=get_data_manager(); + const auto& dm=get_data_mgr(); return dm.num_samples_at(0); } void CTwoDistributionTest::set_num_samples_q(index_t num_samples_from_q) { - auto& dm=get_data_manager(); + auto& dm=get_data_mgr(); dm.num_samples_at(1)=num_samples_from_q; } const index_t CTwoDistributionTest::get_num_samples_q() const { - const auto& dm=get_data_manager(); + const auto& dm=get_data_mgr(); return dm.num_samples_at(1); } diff --git a/src/shogun/statistical_testing/TwoDistributionTest.h b/src/shogun/statistical_testing/TwoDistributionTest.h index f343d32ae45..120410d2acf 100644 --- a/src/shogun/statistical_testing/TwoDistributionTest.h +++ b/src/shogun/statistical_testing/TwoDistributionTest.h @@ -29,25 +29,25 @@ namespace shogun class CTwoDistributionTest : public CHypothesisTest { public: - CTwoDistributionTest(index_t num_kernels); + CTwoDistributionTest(); virtual ~CTwoDistributionTest(); + void set_p(CFeatures* samples_from_p); CFeatures* get_p() const; + + void set_q(CFeatures* samples_from_q); CFeatures* get_q() const; void set_num_samples_p(index_t num_samples_from_p); - void set_num_samples_q(index_t num_samples_from_q); - const index_t get_num_samples_p() const; + + void set_num_samples_q(index_t num_samples_from_q); const index_t get_num_samples_q() const; virtual float64_t compute_statistic()=0; virtual SGVector sample_null()=0; virtual const char* get_name() const; -protected: - void set_p(CFeatures* samples_from_p); - void set_q(CFeatures* samples_from_q); }; } diff --git a/src/shogun/statistical_testing/TwoSampleTest.cpp b/src/shogun/statistical_testing/TwoSampleTest.cpp index a5d98f5a95d..6dc906aa826 100644 --- a/src/shogun/statistical_testing/TwoSampleTest.cpp +++ b/src/shogun/statistical_testing/TwoSampleTest.cpp @@ -24,8 +24,26 @@ using namespace shogun; using namespace internal; -CTwoSampleTest::CTwoSampleTest() : CTwoDistributionTest(TwoSampleTest::num_kernels) +struct CTwoSampleTest::Self { + Self(index_t num_kernels); + KernelManager kernel_mgr; +}; + +CTwoSampleTest::Self::Self(index_t num_kernels) : kernel_mgr(num_kernels) +{ +} + +CTwoSampleTest::CTwoSampleTest() : CTwoDistributionTest() +{ + self=std::unique_ptr(new Self(TwoSampleTest::num_kernels)); +} + +CTwoSampleTest::CTwoSampleTest(CFeatures* samples_from_p, CFeatures* samples_from_q) : CTwoDistributionTest() +{ + self=std::unique_ptr(new Self(TwoSampleTest::num_kernels)); + set_p(samples_from_p); + set_q(samples_from_q); } CTwoSampleTest::~CTwoSampleTest() @@ -34,18 +52,26 @@ CTwoSampleTest::~CTwoSampleTest() void CTwoSampleTest::set_kernel(CKernel* kernel) { - auto& km=get_kernel_manager(); - km.kernel_at(0)=kernel; - km.restore_kernel_at(0); + self->kernel_mgr.kernel_at(0)=kernel; + self->kernel_mgr.restore_kernel_at(0); } CKernel* CTwoSampleTest::get_kernel() const { - const auto& km=get_kernel_manager(); - return km.kernel_at(0); + return self->kernel_mgr.kernel_at(0); } const char* CTwoSampleTest::get_name() const { return "TwoSampleTest"; } + +KernelManager& CTwoSampleTest::get_kernel_mgr() +{ + return self->kernel_mgr; +} + +const KernelManager& CTwoSampleTest::get_kernel_mgr() const +{ + return self->kernel_mgr; +} diff --git a/src/shogun/statistical_testing/TwoSampleTest.h b/src/shogun/statistical_testing/TwoSampleTest.h index f8008ce72fd..476cc07ba8f 100644 --- a/src/shogun/statistical_testing/TwoSampleTest.h +++ b/src/shogun/statistical_testing/TwoSampleTest.h @@ -19,26 +19,40 @@ #ifndef TWO_SAMPLE_TEST_H_ #define TWO_SAMPLE_TEST_H_ +#include #include namespace shogun { class CKernel; +class CFeatures; + +namespace internal +{ + class KernelManager; +} class CTwoSampleTest : public CTwoDistributionTest { public: CTwoSampleTest(); + CTwoSampleTest(CFeatures* samples_from_p, CFeatures* samples_from_q); virtual ~CTwoSampleTest(); virtual void set_kernel(CKernel* kernel); CKernel* get_kernel() const; - virtual float64_t compute_statistic() = 0; - virtual SGVector sample_null() = 0; + virtual float64_t compute_statistic()=0; + virtual SGVector sample_null()=0; virtual const char* get_name() const; +protected: + internal::KernelManager& get_kernel_mgr(); + const internal::KernelManager& get_kernel_mgr() const; +private: + struct Self; + std::unique_ptr self; }; } diff --git a/src/shogun/statistical_testing/internals/DataFetcher.cpp b/src/shogun/statistical_testing/internals/DataFetcher.cpp index b770b9e7de1..92742ef28a3 100644 --- a/src/shogun/statistical_testing/internals/DataFetcher.cpp +++ b/src/shogun/statistical_testing/internals/DataFetcher.cpp @@ -24,158 +24,161 @@ using namespace shogun; using namespace internal; -DataFetcher::DataFetcher() : m_num_samples(0), m_samples(nullptr), - train_test_subset_used(false) +DataFetcher::DataFetcher() : m_num_samples(0), train_test_mode(false), + train_mode(false), m_samples(nullptr), features_shuffled(false) { } -DataFetcher::DataFetcher(CFeatures* samples) : m_samples(samples), - train_test_subset_used(false) +DataFetcher::DataFetcher(CFeatures* samples) : train_test_mode(false), + train_mode(false), m_samples(samples), features_shuffled(false) { REQUIRE(m_samples!=nullptr, "Samples cannot be null!\n"); SG_REF(m_samples); m_num_samples=m_samples->get_num_vectors(); - m_train_test_details.set_total_num_samples(m_num_samples); } DataFetcher::~DataFetcher() { - end(); SG_UNREF(m_samples); } -const char* DataFetcher::get_name() const +void DataFetcher::set_blockwise(bool blockwise) +{ + if (blockwise) + { + m_block_details=last_blockwise_details; + SG_SDEBUG("Restoring the blockwise details!\n"); + m_block_details.m_full_data=false; + } + else + { + last_blockwise_details=m_block_details; + SG_SDEBUG("Saving the blockwise details!\n"); + m_block_details=BlockwiseDetails(); + } +} + +void DataFetcher::set_train_test_mode(bool on) +{ + train_test_mode=on; +} + +bool DataFetcher::is_train_test_mode() const +{ + return train_test_mode; +} + +void DataFetcher::set_train_mode(bool on) +{ + train_mode=on; +} + +bool DataFetcher::is_train_mode() const { - return "DataFetcher"; + return train_mode; } -void DataFetcher::set_train_test_ratio(float64_t train_test_ratio) +void DataFetcher::set_train_test_ratio(float64_t ratio) { - m_num_samples=m_train_test_details.get_total_num_samples(); - REQUIRE(m_num_samples>0, "Number of samples is not set!\n"); - index_t num_training_samples=m_num_samples*train_test_ratio/(train_test_ratio+1); - m_train_test_details.set_num_training_samples(num_training_samples); - SG_SINFO("Must set the train/test mode by calling set_train_mode(True/False)!\n"); + train_test_ratio=ratio; } float64_t DataFetcher::get_train_test_ratio() const { - return float64_t(m_train_test_details.get_num_training_samples())/m_train_test_details.get_num_test_samples(); + return train_test_ratio; } -void DataFetcher::set_train_mode(bool train_mode) +void DataFetcher::shuffle_features() { - m_train_test_details.train_mode=train_mode; - // TODO put the following in another methods - index_t start_index=0; - if (m_train_test_details.train_mode) + REQUIRE(train_test_mode, "This method is allowed only when Train/Test method is active!\n"); + if (features_shuffled) { - m_num_samples=m_train_test_details.get_num_training_samples(); - if (m_num_samples==0) - SG_SERROR("The number of training samples is 0! Please set a valid train-test ratio\n"); - SG_SINFO("Using %d number of samples for training!\n", m_num_samples); + SG_SWARNING("Features are already shuffled! Call to shuffle_features() has no effect." + "If you want to reshuffle, please call unshuffle_features() first and then call this method!\n"); } else { - m_num_samples=m_train_test_details.get_num_test_samples(); - SG_SINFO("Using %d number of samples for testing!\n", m_num_samples); - start_index=m_train_test_details.get_num_training_samples(); - if (start_index==0) + const index_t size=m_samples->get_num_vectors(); + SG_SDEBUG("Current number of feature vectors = %d\n", size); + if (shuffle_subset.size()remove_subset(); - train_test_subset_used=false; - } - return; + SG_SDEBUG("Resizing the shuffle indices vector (from %d to %d)\n", shuffle_subset.size(), size); + shuffle_subset=SGVector(size); } - } - SGVector inds(m_num_samples); - std::iota(inds.data(), inds.data()+inds.size(), start_index); - if (train_test_subset_used) - m_samples->remove_subset(); - m_samples->add_subset(inds); - train_test_subset_used=true; -} + std::iota(shuffle_subset.data(), shuffle_subset.data()+shuffle_subset.size(), 0); + CMath::permute(shuffle_subset); +// shuffle_subset.display_vector("shuffle_subset"); -void DataFetcher::set_xvalidation_mode(bool xvalidation_mode) -{ -// using fetcher_type=std::unique_ptr; -// std::for_each(fetchers.begin(), fetchers.end(), [&train_mode](fetcher_type& f) -// { -// f->set_xvalidation_mode(xvalidation_mode); -// }); + SG_SDEBUG("Shuffling %d feature vectors\n", size); + m_samples->add_subset(shuffle_subset); + + features_shuffled=true; + } } -index_t DataFetcher::get_num_folds() const +void DataFetcher::unshuffle_features() { - return 1+ceil(get_train_test_ratio()); + REQUIRE(train_test_mode, "This method is allowed only when Train/Test method is active!\n"); + if (features_shuffled) + { + m_samples->remove_subset(); + features_shuffled=false; + } + else + { + SG_SWARNING("Features are NOT shuffled! Call to unshuffle_features() has no effect." + "If you want to reshuffle, please call shuffle_features() instead!\n"); + } } void DataFetcher::use_fold(index_t idx) { - auto num_folds=get_num_folds(); - REQUIRE(idx>=0, "The index (%d) has to be between 0 and %d, both inclusive!\n", idx, num_folds-1); - REQUIRE(idxremove_subset(); - - SGVector inds; - auto start_idx=idx*num_per_fold; - auto num_samples=0; - - if (m_train_test_details.train_mode) + allocate_active_subset(); + auto num_samples_per_fold=get_num_samples()/get_num_folds(); + auto start_idx=idx*num_samples_per_fold; + if (train_mode) { - num_samples=m_train_test_details.get_num_training_samples(); - inds=SGVector(num_samples); - std::iota(inds.data(), inds.data()+inds.size(), 0); - if (start_idx(num_samples); - std::iota(inds.data(), inds.data()+inds.size(), start_idx); - m_samples->add_subset(inds); - } - inds.display_vector("inds"); - m_samples->add_subset(inds); + std::iota(active_subset.data(), active_subset.data()+active_subset.size(), start_idx); +// active_subset.display_vector("active_subset"); } -void DataFetcher::set_blockwise(bool blockwise) +void DataFetcher::init_active_subset() { - if (blockwise) - { - m_block_details=last_blockwise_details; - SG_SDEBUG("Restoring the blockwise details!\n"); - m_block_details.m_full_data=false; - } - else - { - last_blockwise_details=m_block_details; - SG_SDEBUG("Saving the blockwise details!\n"); - m_block_details=BlockwiseDetails(); - } + allocate_active_subset(); + index_t start_index=0; + if (!train_mode) + start_index=m_samples->get_num_vectors()*train_test_ratio/(train_test_ratio+1); + std::iota(active_subset.data(), active_subset.data()+active_subset.size(), start_index); +// active_subset.display_vector("active_subset"); } void DataFetcher::start() { - REQUIRE(m_num_samples>0, "Number of samples is 0!\n"); - if (m_block_details.m_full_data || m_block_details.m_blocksize>m_num_samples) + REQUIRE(get_num_samples()>0, "Number of samples is 0!\n"); + if (train_test_mode) + { + m_samples->add_subset(active_subset); + SG_SDEBUG("Added active subset!\n"); + SG_SINFO("Currently active number of samples is %d\n", get_num_samples()); + } + + if (m_block_details.m_full_data || m_block_details.m_blocksize>get_num_samples()) { - SG_SINFO("Fetching entire data (%d samples)!\n", m_num_samples); - m_block_details.with_blocksize(m_num_samples); + SG_SINFO("Fetching entire data (%d samples)!\n", get_num_samples()); + m_block_details.with_blocksize(get_num_samples()); } - m_block_details.m_total_num_blocks=m_num_samples/m_block_details.m_blocksize; + m_block_details.m_total_num_blocks=get_num_samples()/m_block_details.m_blocksize; reset(); } @@ -184,16 +187,18 @@ CFeatures* DataFetcher::next() CFeatures* next_samples=nullptr; // figure out how many samples to fetch in this burst auto num_already_fetched=m_block_details.m_next_block_index*m_block_details.m_blocksize; - auto num_more_samples=m_num_samples-num_already_fetched; + auto num_more_samples=get_num_samples()-num_already_fetched; if (num_more_samples>0) { - auto num_samples_this_burst=std::min(m_block_details.m_max_num_samples_per_burst, num_more_samples); // create a shallow copy and add proper index subset next_samples=FeaturesUtil::create_shallow_copy(m_samples); - SGVector inds(num_samples_this_burst); - std::iota(inds.vector, inds.vector+inds.vlen, num_already_fetched); - next_samples->add_subset(inds); - + auto num_samples_this_burst=std::min(m_block_details.m_max_num_samples_per_burst, num_more_samples); + if (num_samples_this_burstget_num_vectors()) + { + SGVector inds(num_samples_this_burst); + std::iota(inds.vector, inds.vector+inds.vlen, num_already_fetched); + next_samples->add_subset(inds); + } m_block_details.m_next_block_index+=m_block_details.m_num_blocks_per_burst; } return next_samples; @@ -206,11 +211,39 @@ void DataFetcher::reset() void DataFetcher::end() { + if (train_test_mode) + { + m_samples->remove_subset(); + SG_SDEBUG("Removed active subset!\n"); + SG_SINFO("Currently active number of samples is %d\n", get_num_samples()); + } +} + +index_t DataFetcher::get_num_samples() const +{ + if (train_test_mode) + { + if (train_mode) + return m_num_samples*train_test_ratio/(train_test_ratio+1); + else + return m_num_samples/(train_test_ratio+1); + } + return m_samples->get_num_vectors(); +} + +index_t DataFetcher::get_num_folds() const +{ + return 1+ceil(get_train_test_ratio()); +} + +index_t DataFetcher::get_num_training_samples() const +{ + return get_num_samples()*get_train_test_ratio()/(get_train_test_ratio()+1); } -const index_t DataFetcher::get_num_samples() const +index_t DataFetcher::get_num_testing_samples() const { - return m_num_samples; + return get_num_samples()/(get_train_test_ratio()+1); } BlockwiseDetails& DataFetcher::fetch_blockwise() @@ -218,3 +251,26 @@ BlockwiseDetails& DataFetcher::fetch_blockwise() m_block_details.m_full_data=false; return m_block_details; } + +void DataFetcher::allocate_active_subset() +{ + REQUIRE(train_test_mode, "This method is allowed only when Train/Test method is active!\n"); + index_t num_active_samples=0; + if (train_mode) + { + num_active_samples=m_samples->get_num_vectors()*train_test_ratio/(train_test_ratio+1); + SG_SINFO("Using %d number of samples for this fold as training samples!\n", num_active_samples); + } + else + { + num_active_samples=m_samples->get_num_vectors()/(train_test_ratio+1); + SG_SINFO("Using %d number of samples for this fold as testing samples!\n", num_active_samples); + } + + ASSERT(num_active_samples>0); + if (active_subset.size()!=num_active_samples) + { + SG_SDEBUG("Resizing the active subset from %d to %d\n", active_subset.size(), num_active_samples); + active_subset=SGVector(num_active_samples); + } +} diff --git a/src/shogun/statistical_testing/internals/DataFetcher.h b/src/shogun/statistical_testing/internals/DataFetcher.h index 155b90db668..e0cd6b2dec3 100644 --- a/src/shogun/statistical_testing/internals/DataFetcher.h +++ b/src/shogun/statistical_testing/internals/DataFetcher.h @@ -30,8 +30,8 @@ #include #include +#include #include -#include #ifndef DATA_FETCHER_H__ #define DATA_FETCHER_H__ @@ -53,30 +53,54 @@ class DataFetcher public: DataFetcher(CFeatures* samples); virtual ~DataFetcher(); - virtual void set_train_test_ratio(float64_t train_test_ratio); - float64_t get_train_test_ratio() const; - virtual void set_train_mode(bool train_mode); - void set_xvalidation_mode(bool xvalidation_mode); - index_t get_num_folds() const; - void use_fold(index_t idx); + void set_blockwise(bool blockwise); + void set_train_test_mode(bool on); + bool is_train_test_mode() const; + + void set_train_mode(bool on); + bool is_train_mode() const; + + void set_train_test_ratio(float64_t ratio); + float64_t get_train_test_ratio() const; + + virtual void shuffle_features(); + virtual void unshuffle_features(); + + virtual void use_fold(index_t i); + virtual void init_active_subset(); + virtual void start(); virtual CFeatures* next(); virtual void reset(); virtual void end(); - const index_t get_num_samples() const; + + virtual index_t get_num_samples() const; + + index_t get_num_folds() const; + index_t get_num_training_samples() const; + index_t get_num_testing_samples() const; + BlockwiseDetails& fetch_blockwise(); - virtual const char* get_name() const; + virtual const char* get_name() const + { + return "DataFetcher"; + } protected: DataFetcher(); BlockwiseDetails m_block_details; - TrainTestDetails m_train_test_details; index_t m_num_samples; + bool train_test_mode; + bool train_mode; + float64_t train_test_ratio; private: CFeatures* m_samples; - bool train_test_subset_used; + SGVector shuffle_subset; + SGVector active_subset; + bool features_shuffled; BlockwiseDetails last_blockwise_details; + void allocate_active_subset(); }; } diff --git a/src/shogun/statistical_testing/internals/DataManager.cpp b/src/shogun/statistical_testing/internals/DataManager.cpp index b5a9acce2d8..9679bf8515a 100644 --- a/src/shogun/statistical_testing/internals/DataManager.cpp +++ b/src/shogun/statistical_testing/internals/DataManager.cpp @@ -41,15 +41,16 @@ using namespace shogun; using namespace internal; -// TODO add nullptr check before calling the methods on actual fetchers -// this would be where someone calls the other methofds before setiing the sameples - DataManager::DataManager(size_t num_distributions) { SG_SDEBUG("Data manager instance initialized with %d data sources!\n", num_distributions); fetchers.resize(num_distributions); std::fill(fetchers.begin(), fetchers.end(), nullptr); + train_test_mode=default_train_test_mode; + train_mode=default_train_mode; + train_test_ratio=default_train_test_ratio; + cross_validation_mode=default_cross_validation_mode; } DataManager::~DataManager() @@ -103,15 +104,15 @@ void DataManager::set_blocksize(index_t blocksize) "The blocksize has to be within [0, %d], given = %d!\n", n, blocksize); REQUIRE(n%blocksize==0, - "Total number of samples (%d) has to be divisble by the blocksize (%d)!\n", - n, blocksize); + "Total number of samples (%d) has to be divisble by the blocksize (%d)!\n", + n, blocksize); for (size_t i=0; im_num_samples; REQUIRE((blocksize*m)%n==0, - "Blocksize (%d) cannot be even distributed with a ratio of %f!\n", - blocksize, m/n); + "Blocksize (%d) cannot be even distributed with a ratio of %f!\n", + blocksize, m/n); fetchers[i]->fetch_blockwise().with_blocksize(blocksize*m/n); SG_SDEBUG("block[%d].size = ", i, blocksize*m/n); } @@ -122,8 +123,8 @@ void DataManager::set_num_blocks_per_burst(index_t num_blocks_per_burst) { SG_SDEBUG("Entering!\n"); REQUIRE(num_blocks_per_burst>0, - "Number of blocks per burst (%d) has to be greater than 0!\n", - num_blocks_per_burst); + "Number of blocks per burst (%d) has to be greater than 0!\n", + num_blocks_per_burst); index_t blocksize=0; typedef std::unique_ptr fetcher_type; @@ -184,7 +185,7 @@ const index_t DataManager::num_samples_at(size_t i) const "Value of i (%d) should be between 0 and %d, inclusive!", i, fetchers.size()-1); SG_SDEBUG("Leaving!\n"); - return fetchers[i]->m_num_samples; + return fetchers[i]->get_num_samples(); } const index_t DataManager::blocksize_at(size_t i) const @@ -197,75 +198,157 @@ const index_t DataManager::blocksize_at(size_t i) const return fetchers[i]->m_block_details.m_blocksize; } -const bool DataManager::is_blockwise() const +void DataManager::set_blockwise(bool blockwise) { SG_SDEBUG("Entering!\n"); - bool blockwise=true; for (size_t i=0; im_block_details.m_full_data; + fetchers[i]->set_blockwise(blockwise); SG_SDEBUG("Leaving!\n"); - return blockwise; } -void DataManager::set_blockwise(bool blockwise) +const bool DataManager::is_blockwise() const { SG_SDEBUG("Entering!\n"); + bool blockwise=true; for (size_t i=0; iset_blockwise(blockwise); + blockwise&=!fetchers[i]->m_block_details.m_full_data; SG_SDEBUG("Leaving!\n"); + return blockwise; } -void DataManager::set_train_test_ratio(float64_t train_test_ratio) +void DataManager::set_train_test_mode(bool on) { - SG_SDEBUG("Entering!\n"); + train_test_mode=on; + if (!train_test_mode) + { + train_mode=default_train_mode; + train_test_ratio=default_train_test_ratio; + cross_validation_mode=default_cross_validation_mode; + } + REQUIRE(fetchers.size()>0, "Features are not set!"); typedef std::unique_ptr fetcher_type; - std::for_each(fetchers.begin(), fetchers.end(), [&train_test_ratio](fetcher_type& f) + std::for_each(fetchers.begin(), fetchers.end(), [this, on](fetcher_type& f) { - f->set_train_test_ratio(train_test_ratio); + f->set_train_test_mode(on); + if (on) + { + f->set_train_mode(train_mode); + f->set_train_test_ratio(train_test_ratio); + } }); - SG_SDEBUG("Leaving!\n"); +} + +bool DataManager::is_train_test_mode() const +{ + return train_test_mode; +} + +void DataManager::set_train_mode(bool on) +{ + if (train_test_mode) + train_mode=on; + else + { + SG_SERROR("Train mode cannot be used without turning on Train/Test mode first!" + "Please call set_train_test_mode(True) before using this method!\n"); + } +} + +bool DataManager::is_train_mode() const +{ + return train_mode; +} + +void DataManager::set_cross_validation_mode(bool on) +{ + if (train_test_mode) + cross_validation_mode=on; + else + { + SG_SERROR("Cross-validation mode cannot be used without turning on Train/Test mode first!" + "Please call set_train_test_mode(True) before using this method!\n"); + } +} + +bool DataManager::is_cross_validation_mode() const +{ + return cross_validation_mode; +} + +void DataManager::set_train_test_ratio(float64_t ratio) +{ + if (train_test_mode) + train_test_ratio=ratio; + else + { + SG_SERROR("Train-test ratio cannot be set without turning on Train/Test mode first!" + "Please call set_train_test_mode(True) before using this method!\n"); + } } float64_t DataManager::get_train_test_ratio() const { - REQUIRE(fetchers[0]!=nullptr, "Please set the samples first!\n"); - return fetchers[0]->get_train_test_ratio(); + return train_test_ratio; } -void DataManager::set_train_mode(bool train_mode) +index_t DataManager::get_num_folds() const +{ + return ceil(get_train_test_ratio())+1; +} + +void DataManager::shuffle_features() { SG_SDEBUG("Entering!\n"); + REQUIRE(fetchers.size()>0, "Features are not set!"); typedef std::unique_ptr fetcher_type; - std::for_each(fetchers.begin(), fetchers.end(), [&train_mode](fetcher_type& f) - { - f->set_train_mode(train_mode); - }); + std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->shuffle_features(); }); SG_SDEBUG("Leaving!\n"); } -void DataManager::set_xvalidation_mode(bool xvalidation_mode) +void DataManager::unshuffle_features() { SG_SDEBUG("Entering!\n"); + REQUIRE(fetchers.size()>0, "Features are not set!"); typedef std::unique_ptr fetcher_type; - std::for_each(fetchers.begin(), fetchers.end(), [&xvalidation_mode](fetcher_type& f) - { - f->set_xvalidation_mode(xvalidation_mode); - }); + std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->unshuffle_features(); }); SG_SDEBUG("Leaving!\n"); } -index_t DataManager::get_num_folds() const +void DataManager::init_active_subset() { - REQUIRE(fetchers[0]!=nullptr, "Please set the samples first!\n"); - return fetchers[0]->get_num_folds(); + SG_SDEBUG("Entering!\n"); + + REQUIRE(train_test_mode, + "Train-test subset cannot be used without turning on Train/Test mode first!" + "Please call set_train_test_mode(True) before using this method!\n"); + REQUIRE(fetchers.size()>0, "Features are not set!"); + + typedef std::unique_ptr fetcher_type; + std::for_each(fetchers.begin(), fetchers.end(), [this](fetcher_type& f) + { + f->set_train_mode(train_mode); + f->set_train_test_ratio(train_test_ratio); + f->init_active_subset(); + }); + SG_SDEBUG("Leaving!\n"); } void DataManager::use_fold(index_t idx) { SG_SDEBUG("Entering!\n"); + + REQUIRE(train_test_mode, + "Fold subset cannot be used without turning on Train/Test mode first!" + "Please call set_train_test_mode(True) before using this method!\n"); + REQUIRE(fetchers.size()>0, "Features are not set!"); + REQUIRE(idx>=0, "Fold index has to be in [0, %d]!", get_num_folds()-1); + REQUIRE(idx fetcher_type; - std::for_each(fetchers.begin(), fetchers.end(), [&idx](fetcher_type& f) + std::for_each(fetchers.begin(), fetchers.end(), [this, idx](fetcher_type& f) { + f->set_train_mode(train_mode); + f->set_train_test_ratio(train_test_ratio); f->use_fold(idx); }); SG_SDEBUG("Leaving!\n"); @@ -274,6 +357,11 @@ void DataManager::use_fold(index_t idx) void DataManager::start() { SG_SDEBUG("Entering!\n"); + REQUIRE(fetchers.size()>0, "Features are not set!"); + + if (train_test_mode && !cross_validation_mode) + init_active_subset(); + typedef std::unique_ptr fetcher_type; std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->start(); }); SG_SDEBUG("Leaving!\n"); @@ -313,6 +401,7 @@ NextSamples DataManager::next() void DataManager::end() { SG_SDEBUG("Entering!\n"); + REQUIRE(fetchers.size()>0, "Features are not set!"); typedef std::unique_ptr fetcher_type; std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->end(); }); SG_SDEBUG("Leaving!\n"); @@ -321,17 +410,8 @@ void DataManager::end() void DataManager::reset() { SG_SDEBUG("Entering!\n"); + REQUIRE(fetchers.size()>0, "Features are not set!"); typedef std::unique_ptr fetcher_type; std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->reset(); }); SG_SDEBUG("Leaving!\n"); } - -void DataManager::set_train_test_mode(bool on) -{ - train_test_mode_on=on; -} - -void DataManager::set_train_test_ratio(float64_t ratio) -{ - train_test_ratio=ratio; -} diff --git a/src/shogun/statistical_testing/internals/DataManager.h b/src/shogun/statistical_testing/internals/DataManager.h index df08b359af1..f5368d7de46 100644 --- a/src/shogun/statistical_testing/internals/DataManager.h +++ b/src/shogun/statistical_testing/internals/DataManager.h @@ -176,19 +176,6 @@ class DataManager */ const index_t blocksize_at(size_t i) const; - /** - * @return True if block-wise fetching is on, False otherwise. - */ - const bool is_blockwise() const; - - /** - * Turns on blockwise fetching if True is passed. Turns off blockwise fetching if - * False is passed. The blockwise details are not destroyed when set to False, i.e. - * turning blockwise fetching back on again, we can get blocks as we would have got - * in the original setup. - */ - void set_blockwise(bool blockwise); - /** * @return Total number of samples that can be fetched from all the data sources. */ @@ -202,80 +189,44 @@ class DataManager */ index_t get_min_blocksize() const; - /** - * @param train_test_ratio The split ratio for train-test data. The default value is 0 - * which means that all of the data would be used for testing. - */ - void set_train_test_ratio(float64_t train_test_ratio); + void set_blockwise(bool blockwise); + const bool is_blockwise() const; - /** - * @return The split ratio for train-test data. The default value is 0, which means - * that all of the data would be used for testing. - */ - float64_t get_train_test_ratio() const; + void set_train_test_mode(bool on); + bool is_train_test_mode() const; - /** - * @param train_mode If set to true, then the training data would be returned by the data - * fetching API of this data manager. Otherwise, test data would be returend. - */ - void set_train_mode(bool train_mode); + void set_train_mode(bool on); + bool is_train_mode() const; - /** - * @param xvalidation_mode If set to true, then the data would be split in N fold (the value - * of N is determined from the train_test_ratio). - */ - void set_xvalidation_mode(bool xvalidation_mode); + void set_cross_validation_mode(bool on); + bool is_cross_validation_mode() const; + + void set_train_test_ratio(float64_t ratio); + float64_t get_train_test_ratio() const; - /** - * @return The number of folds that can be used based on the train-test ratio. Returns - * an integer if xvalidation mode is ON, 0 otherwise. - */ index_t get_num_folds() const; - /** - * Permutes the feature vectors. Useful for cross-validation set-up. Everytime - * TODO - * - void shuffle_features() - void unshuffle_features() - */ + void shuffle_features(); + void unshuffle_features(); - /** - * @param idx The index of the fold in X-validation scenario, has to be within the range of - * \f$[0, N)\f$, where N is the number of folds as returned by get_num_folds() method. - */ - void use_fold(index_t idx); + void use_fold(index_t i); + void init_active_subset(); - /** - * Call this method before fetching the data from the data manager - */ void start(); - - /** - * @return The next bunch of blocks fetched at any given burst. - */ NextSamples next(); - - /** - * call this method after fetching the data is done. - */ void end(); - - /** - * Resets the fetchers to the initial states. - */ void reset(); - - void set_train_test_mode(bool on); - void set_train_test_ratio(float64_t ratio); - - bool is_train_test_mode() const; - float64_t get_train_test_ratio() const; private: std::vector > fetchers; - bool train_test_mode; + + bool train_test_mode; // -> if ON, then train/test/fold subset is used (in start()) in end() method, we remove these subsets. + bool cross_validation_mode; // -> if ON, then shuffle subset is used, remove it after train_test mode in end() + bool train_mode; // -> if train/test mode ON or cross-validation mode on, this one is used. float64_t train_test_ratio; + constexpr static bool default_train_test_mode=false; + constexpr static bool default_train_mode=false; + constexpr static bool default_cross_validation_mode=false; constexpr static float64_t default_train_test_ratio=1.0; }; diff --git a/src/shogun/statistical_testing/internals/MaxXValidation.cpp b/src/shogun/statistical_testing/internals/MaxXValidation.cpp index 3c2a6244cc4..e0a635e5629 100644 --- a/src/shogun/statistical_testing/internals/MaxXValidation.cpp +++ b/src/shogun/statistical_testing/internals/MaxXValidation.cpp @@ -65,8 +65,8 @@ SGMatrix MaxXValidation::get_measure_matrix() void MaxXValidation::init_measures() { const index_t num_kernels=kernel_mgr.num_kernels(); - auto& dm=estimator->get_data_manager(); - const index_t N=dm.get_num_folds(); + auto& data_mgr=estimator->get_data_mgr(); + const index_t N=data_mgr.get_num_folds(); REQUIRE(N!=0, "Number of folds is not set!\n"); if (rejections.num_rows!=N*num_run || rejections.num_cols!=num_kernels) rejections=SGMatrix(N*num_run, num_kernels); @@ -78,31 +78,33 @@ void MaxXValidation::init_measures() void MaxXValidation::compute_measures() { - auto& dm=estimator->get_data_manager(); - dm.set_xvalidation_mode(true); + auto& data_mgr=estimator->get_data_mgr(); + data_mgr.set_cross_validation_mode(true); - const index_t N=dm.get_num_folds(); + const index_t N=data_mgr.get_num_folds(); SG_SINFO("Performing %d fold cross-validattion!\n", N); const size_t num_kernels=kernel_mgr.num_kernels(); auto existing_kernel=estimator->get_kernel(); for (auto i=0; iset_kernel(kernel); - rejections(i*N+j, k)=estimator->compute_p_value(estimator->compute_statistic())compute_statistic(); + rejections(i*N+j, k)=estimator->compute_p_value(statistic)cleanup(); } } - dm.unshuffle_features(); + data_mgr.unshuffle_features(); } - dm.set_xvalidation_mode(false); + data_mgr.set_cross_validation_mode(false); estimator->set_kernel(existing_kernel); for (auto j=0; jcompute_distance(); SG_REF(distance); n=distance->get_num_vec_lhs(); @@ -87,6 +88,7 @@ SGVector MedianHeuristic::get_measure_vector() SGMatrix MedianHeuristic::get_measure_matrix() { + REQUIRE(distance!=nullptr, "Distance is not initialized!\n"); return distance->get_distance_matrix(); } diff --git a/src/shogun/statistical_testing/internals/StreamingDataFetcher.cpp b/src/shogun/statistical_testing/internals/StreamingDataFetcher.cpp index 37457b94c8a..378b9c8c9e8 100644 --- a/src/shogun/statistical_testing/internals/StreamingDataFetcher.cpp +++ b/src/shogun/statistical_testing/internals/StreamingDataFetcher.cpp @@ -40,49 +40,48 @@ StreamingDataFetcher::~StreamingDataFetcher() end(); } -const char* StreamingDataFetcher::get_name() const +void StreamingDataFetcher::set_num_samples(index_t num_samples) { - return "StreamingDataFetcher"; + m_num_samples=num_samples; } -void StreamingDataFetcher::set_num_samples(index_t num_samples) +void StreamingDataFetcher::shuffle_features() { - m_num_samples=num_samples; - m_train_test_details.set_total_num_samples(m_num_samples); } -void StreamingDataFetcher::set_train_test_ratio(float64_t train_test_ratio) +void StreamingDataFetcher::unshuffle_features() { - if (m_train_test_details.get_total_num_samples()==0) - m_train_test_details.set_total_num_samples(m_num_samples); - DataFetcher::set_train_test_ratio(train_test_ratio); } -void StreamingDataFetcher::set_train_mode(bool train_mode) +void StreamingDataFetcher::use_fold(index_t i) { - if (train_mode) - { - m_num_samples=m_train_test_details.get_num_training_samples(); - if (m_num_samples==0) - SG_SERROR("The number of training samples is 0! Please set a valid train-test ratio\n"); - SG_SINFO("Using %d number of samples for training!\n", m_num_samples); - } - else +} + +void StreamingDataFetcher::init_active_subset() +{ +} + +index_t StreamingDataFetcher::get_num_samples() const +{ + if (train_test_mode) { - m_num_samples=m_train_test_details.get_num_test_samples(); - SG_SINFO("Using %d number of samples for testing!\n", m_num_samples); + if (train_mode) + return m_num_samples*train_test_ratio/(train_test_ratio+1); + else + return m_num_samples/(train_test_ratio+1); } + return m_num_samples; } void StreamingDataFetcher::start() { - REQUIRE(m_num_samples>0, "Number of samples is not set! It is MANDATORY for streaming features!\n"); - if (m_block_details.m_full_data || m_block_details.m_blocksize>m_num_samples) + REQUIRE(get_num_samples()>0, "Number of samples is not set! It is MANDATORY for streaming features!\n"); + if (m_block_details.m_full_data || m_block_details.m_blocksize>get_num_samples()) { - SG_SINFO("Fetching entire data (%d samples)!\n", m_num_samples); - m_block_details.with_blocksize(m_num_samples); + SG_SINFO("Fetching entire data (%d samples)!\n", get_num_samples()); + m_block_details.with_blocksize(get_num_samples()); } - m_block_details.m_total_num_blocks=m_num_samples/m_block_details.m_blocksize; + m_block_details.m_total_num_blocks=get_num_samples()/m_block_details.m_blocksize; m_block_details.m_next_block_index=0; if (!parser_running) { @@ -96,7 +95,7 @@ CFeatures* StreamingDataFetcher::next() CFeatures* next_samples=nullptr; // figure out how many samples to fetch in this burst auto num_already_fetched=m_block_details.m_next_block_index*m_block_details.m_blocksize; - auto num_more_samples=m_num_samples-num_already_fetched; + auto num_more_samples=get_num_samples()-num_already_fetched; if (num_more_samples>0) { auto num_samples_this_burst=std::min(m_block_details.m_max_num_samples_per_burst, num_more_samples); diff --git a/src/shogun/statistical_testing/internals/StreamingDataFetcher.h b/src/shogun/statistical_testing/internals/StreamingDataFetcher.h index 91941b0f2ef..b7f3453ab50 100644 --- a/src/shogun/statistical_testing/internals/StreamingDataFetcher.h +++ b/src/shogun/statistical_testing/internals/StreamingDataFetcher.h @@ -38,15 +38,25 @@ class StreamingDataFetcher : public DataFetcher friend class DataManager; public: StreamingDataFetcher(CStreamingFeatures* samples); - virtual ~StreamingDataFetcher() override; - virtual void set_train_test_ratio(float64_t train_test_ratio) override; - virtual void set_train_mode(bool train_mode) override; - virtual void start() override; - virtual CFeatures* next() override; - virtual void reset() override; - virtual void end() override; + virtual ~StreamingDataFetcher(); void set_num_samples(index_t num_samples); - virtual const char* get_name() const override; + + virtual void shuffle_features(); + virtual void unshuffle_features(); + + virtual void use_fold(index_t i); + virtual void init_active_subset(); + + virtual void start(); + virtual CFeatures* next(); + virtual void reset(); + virtual void end(); + + virtual index_t get_num_samples() const; + virtual const char* get_name() const + { + return "StreamingDataFetcher"; + } private: std::shared_ptr m_samples; bool parser_running; diff --git a/src/shogun/statistical_testing/internals/TrainTestDetails.cpp b/src/shogun/statistical_testing/internals/TrainTestDetails.cpp deleted file mode 100644 index 9df301f0786..00000000000 --- a/src/shogun/statistical_testing/internals/TrainTestDetails.cpp +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright (c) The Shogun Machine Learning Toolbox - * Written (w) 2016 Soumyajit De - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR - * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * The views and conclusions contained in the software and documentation are those - * of the authors and should not be interpreted as representing official policies, - * either expressed or implied, of the Shogun Development Team. - */ - -#include -#include - -using namespace shogun; -using namespace internal; - -TrainTestDetails::TrainTestDetails() : m_total_num_samples(0), m_num_training_samples(0) -{ -} - -void TrainTestDetails::set_total_num_samples(index_t total_num_samples) -{ - m_total_num_samples=total_num_samples; -} - -index_t TrainTestDetails::get_total_num_samples() const -{ - return m_total_num_samples; -} - -void TrainTestDetails::set_num_training_samples(index_t num_training_samples) -{ - REQUIRE(m_total_num_samples>=num_training_samples, - "Number of training samples cannot be greater than the total number of samples!\n"); - m_num_training_samples=num_training_samples; -} - -index_t TrainTestDetails::get_num_training_samples() const -{ - return m_num_training_samples; -} - -index_t TrainTestDetails::get_num_test_samples() const -{ - return m_total_num_samples-m_num_training_samples; -} -// -//bool TrainTestDetails::is_training_mode() const -//{ -//} -// -//void TrainTestDetails::set_train_mode(bool train_mode) -//{ -//} -// -//void TrainTestDetails::set_xvalidation_mode(bool xvalidation_mode) -//{ -//} diff --git a/src/shogun/statistical_testing/internals/TrainTestDetails.h b/src/shogun/statistical_testing/internals/TrainTestDetails.h deleted file mode 100644 index b4823c2125f..00000000000 --- a/src/shogun/statistical_testing/internals/TrainTestDetails.h +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) The Shogun Machine Learning Toolbox - * Written (w) 2016 Soumyajit De - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR - * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * The views and conclusions contained in the software and documentation are those - * of the authors and should not be interpreted as representing official policies, - * either expressed or implied, of the Shogun Development Team. - */ - -#include - -#ifndef TRAIN_TEST_DETAILS_H__ -#define TRAIN_TEST_DETAILS_H__ - -namespace shogun -{ - -namespace internal -{ - -/** - * @brief Class that holds train-test details for the data-fetchers. - * There are one instance of this class per fetcher. - */ -class TrainTestDetails -{ - friend class DataFetcher; - friend class StreamingDataFetcher; - -public: - TrainTestDetails(); - - void set_total_num_samples(index_t total_num_sampels); - index_t get_total_num_samples() const; - - void set_num_training_samples(index_t num_training_samples); - index_t get_num_training_samples() const; - index_t get_num_test_samples() const; - -// bool is_training_mode() const; -// void set_train_mode(bool train_mode); -// void set_xvalidation_mode(bool xvalidation_mode); -private: - bool train_mode; - index_t m_total_num_samples; - index_t m_num_training_samples; -}; - -} - -} -#endif // TRAIN_TEST_DETAILS_H__ diff --git a/tests/unit/statistical_testing/KernelSelectionMaxMMD_unittest.cc b/tests/unit/statistical_testing/KernelSelectionMaxMMD_unittest.cc index 249a53adc8e..62c42b9ac35 100644 --- a/tests/unit/statistical_testing/KernelSelectionMaxMMD_unittest.cc +++ b/tests/unit/statistical_testing/KernelSelectionMaxMMD_unittest.cc @@ -74,6 +74,7 @@ TEST(KernelSelectionMaxMMD, single_kernel) } mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_MMD)); + mmd->set_train_test_mode(true); mmd->select_kernel(); auto selected_kernel=static_cast(mmd->get_kernel()); EXPECT_NEAR(selected_kernel->get_width(), 0.5, 1E-10); @@ -110,6 +111,7 @@ TEST(KernelSelectionMaxMMD, weighted_kernel) } mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_MMD, true)); + mmd->set_train_test_mode(true); mmd->select_kernel(); auto weighted_kernel=dynamic_cast(mmd->get_kernel()); ASSERT_TRUE(weighted_kernel!=nullptr); diff --git a/tests/unit/statistical_testing/KernelSelectionMaxPower_unittest.cc b/tests/unit/statistical_testing/KernelSelectionMaxPower_unittest.cc index a5dac71ef56..c7bcb7d4929 100644 --- a/tests/unit/statistical_testing/KernelSelectionMaxPower_unittest.cc +++ b/tests/unit/statistical_testing/KernelSelectionMaxPower_unittest.cc @@ -74,6 +74,7 @@ TEST(KernelSelectionMaxPower, single_kernel) } mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_POWER)); + mmd->set_train_test_mode(true); mmd->select_kernel(); auto selected_kernel=static_cast(mmd->get_kernel()); EXPECT_NEAR(selected_kernel->get_width(), 0.5, 1E-10); @@ -110,6 +111,7 @@ TEST(KernelSelectionMaxPower, weighted_kernel) } mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_POWER, true)); + mmd->set_train_test_mode(true); mmd->select_kernel(); auto weighted_kernel=dynamic_cast(mmd->get_kernel()); ASSERT_TRUE(weighted_kernel!=nullptr); diff --git a/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc b/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc index 78f1c1d1e17..429503fce72 100644 --- a/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc +++ b/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc @@ -51,9 +51,6 @@ TEST(KernelSelectionMaxXValidation, single_kernel) const float64_t difference=0.5; const index_t num_kernels=10; -// sg_io->set_loglevel(MSG_DEBUG); -// sg_io->set_location_info(MSG_FUNCTION); - // use fixed seed sg_rand->set_seed(12345); @@ -79,8 +76,10 @@ TEST(KernelSelectionMaxXValidation, single_kernel) mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice)); } - mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_XVALIDATION, 1, 0.05)); - mmd->select_kernel(4); + mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_XVALIDATION, 5, 0.05)); + mmd->set_train_test_mode(true); + mmd->set_train_test_ratio(4); + mmd->select_kernel(); auto selected_kernel=static_cast(mmd->get_kernel()); EXPECT_NEAR(selected_kernel->get_width(), 0.5, 1E-10); } diff --git a/tests/unit/statistical_testing/KernelSelectionMedianHeuristic_unittest.cc b/tests/unit/statistical_testing/KernelSelectionMedianHeuristic_unittest.cc index 0d3b2394459..7ec92a9de05 100644 --- a/tests/unit/statistical_testing/KernelSelectionMedianHeuristic_unittest.cc +++ b/tests/unit/statistical_testing/KernelSelectionMedianHeuristic_unittest.cc @@ -69,7 +69,9 @@ TEST(KernelSelectionMedianHeuristic, quadratic_time_mmd) mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice)); } - mmd->select_kernel(KSM_MEDIAN_HEURISTIC); + mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MEDIAN_HEURISTIC)); + mmd->set_train_test_mode(true); + mmd->select_kernel(); auto selected_kernel=static_cast(mmd->get_kernel()); EXPECT_NEAR(selected_kernel->get_width(), 1.62, 1E-10); } @@ -104,6 +106,7 @@ TEST(KernelSelectionMedianHeuristic, linear_time_mmd) } mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MEDIAN_HEURISTIC)); + mmd->set_train_test_mode(true); mmd->select_kernel(); auto selected_kernel=static_cast(mmd->get_kernel()); EXPECT_NEAR(selected_kernel->get_width(), 1.62, 1E-10); diff --git a/tests/unit/statistical_testing/internals/DataManager_unittest.cc b/tests/unit/statistical_testing/internals/DataManager_unittest.cc index f1d847b0c46..fab238474ae 100644 --- a/tests/unit/statistical_testing/internals/DataManager_unittest.cc +++ b/tests/unit/statistical_testing/internals/DataManager_unittest.cc @@ -482,7 +482,7 @@ TEST(DataManager, block_data_two_distributions_streaming_feats_different_blocksi ASSERT_TRUE(total_q==num_vec_q); } -TEST(DataManager, train_data_two_distributions_normal_feats) +TEST(DataManager, train_test_whole_dense) { const index_t dim=3; const index_t num_vec=8; @@ -503,8 +503,10 @@ TEST(DataManager, train_data_two_distributions_normal_feats) mgr.samples_at(0)=feats_p; mgr.samples_at(1)=feats_q; - // training data + mgr.set_train_test_mode(true); mgr.set_train_test_ratio(train_test_ratio); + + // training data mgr.set_train_mode(true); mgr.start(); @@ -545,8 +547,7 @@ TEST(DataManager, train_data_two_distributions_normal_feats) mgr.end(); // full data - mgr.set_train_test_ratio(0); - mgr.set_train_mode(false); + mgr.set_train_test_mode(false); mgr.start(); next_burst=mgr.next(); @@ -566,7 +567,7 @@ TEST(DataManager, train_data_two_distributions_normal_feats) mgr.end(); } -TEST(DataManager, train_data_two_distributions_normal_feats_blockwise) +TEST(DataManager, train_test_blockwise_dense) { const index_t dim=3; const index_t num_vec=8; @@ -591,6 +592,7 @@ TEST(DataManager, train_data_two_distributions_normal_feats_blockwise) mgr.set_blocksize(blocksize); mgr.set_num_blocks_per_burst(num_blocks_per_burst); + mgr.set_train_test_mode(true); mgr.set_train_test_ratio(train_test_ratio); // train data @@ -648,8 +650,7 @@ TEST(DataManager, train_data_two_distributions_normal_feats_blockwise) mgr.end(); // full data - mgr.set_train_test_ratio(0); - mgr.set_train_mode(false); + mgr.set_train_test_mode(false); mgr.start(); next_burst=mgr.next(); @@ -676,7 +677,7 @@ TEST(DataManager, train_data_two_distributions_normal_feats_blockwise) mgr.end(); } -TEST(DataManager, train_data_two_distributions_streaming_feats) +TEST(DataManager, train_test_whole_streaming) { const index_t dim=3; const index_t num_vec=8; @@ -692,8 +693,10 @@ TEST(DataManager, train_data_two_distributions_streaming_feats) typedef CDenseFeatures feat_type; - // training data + mgr.set_train_test_mode(true); mgr.set_train_test_ratio(train_test_ratio); + + // training data mgr.set_train_mode(true); mgr.start(); @@ -734,8 +737,7 @@ TEST(DataManager, train_data_two_distributions_streaming_feats) mgr.end(); // full data - mgr.set_train_test_ratio(0); - mgr.set_train_mode(false); + mgr.set_train_test_mode(false); mgr.reset(); mgr.start(); @@ -756,7 +758,7 @@ TEST(DataManager, train_data_two_distributions_streaming_feats) mgr.end(); } -TEST(DataManager, train_data_two_distributions_streaming_feats_blockwise) +TEST(DataManager, train_test_blockwise_streaming) { const index_t dim=3; const index_t num_vec=8; @@ -773,10 +775,12 @@ TEST(DataManager, train_data_two_distributions_streaming_feats_blockwise) mgr.num_samples_at(1)=num_vec; mgr.set_blocksize(blocksize); mgr.set_num_blocks_per_burst(num_blocks_per_burst); - mgr.set_train_test_ratio(train_test_ratio); typedef CDenseFeatures feat_type; + mgr.set_train_test_mode(true); + mgr.set_train_test_ratio(train_test_ratio); + // train data mgr.set_train_mode(true); mgr.start(); @@ -832,8 +836,7 @@ TEST(DataManager, train_data_two_distributions_streaming_feats_blockwise) mgr.end(); // full data - mgr.set_train_test_ratio(0); - mgr.set_train_mode(false); + mgr.set_train_test_mode(false); mgr.reset(); mgr.start();