diff --git a/src/shogun/statistical_testing/internals/BlockwiseDetails.cpp b/src/shogun/statistical_testing/internals/BlockwiseDetails.cpp index 5961aa6071f..e8ceb4519b9 100644 --- a/src/shogun/statistical_testing/internals/BlockwiseDetails.cpp +++ b/src/shogun/statistical_testing/internals/BlockwiseDetails.cpp @@ -34,7 +34,8 @@ using namespace shogun; using namespace internal; BlockwiseDetails::BlockwiseDetails() : m_blocksize(0), m_num_blocks_per_burst(1), - m_max_num_samples_per_burst(0), m_next_block_index(0), m_total_num_blocks(0) + m_max_num_samples_per_burst(0), m_next_block_index(0), m_total_num_blocks(0), + m_full_data(true) { } diff --git a/src/shogun/statistical_testing/internals/BlockwiseDetails.h b/src/shogun/statistical_testing/internals/BlockwiseDetails.h index 8f742028974..916fd53e578 100644 --- a/src/shogun/statistical_testing/internals/BlockwiseDetails.h +++ b/src/shogun/statistical_testing/internals/BlockwiseDetails.h @@ -86,6 +86,9 @@ class BlockwiseDetails /** Total number of blocks to be fetched. Set by data fetchers */ index_t m_total_num_blocks; + + /** Whether the block should consist of full data (i.e. no block at all) */ + bool m_full_data; }; } diff --git a/src/shogun/statistical_testing/internals/DataFetcher.cpp b/src/shogun/statistical_testing/internals/DataFetcher.cpp index a21caca8124..b60c14555bd 100644 --- a/src/shogun/statistical_testing/internals/DataFetcher.cpp +++ b/src/shogun/statistical_testing/internals/DataFetcher.cpp @@ -52,6 +52,7 @@ void DataFetcher::set_train_test_ratio(float64_t train_test_ratio) m_num_samples=m_train_test_details.get_total_num_samples(); index_t num_training_samples=m_num_samples*train_test_ratio/(train_test_ratio+1); m_train_test_details.set_num_training_samples(num_training_samples); + SG_SINFO("Must set the train/test mode by calling set_train_mode(True/False)!\n"); } float64_t DataFetcher::get_train_test_ratio() const @@ -85,9 +86,7 @@ void DataFetcher::set_train_mode(bool train_mode) SGVector inds(m_num_samples); std::iota(inds.data(), inds.data()+inds.size(), start_index); if (train_test_subset_used) - { m_samples->remove_subset(); - } m_samples->add_subset(inds); train_test_subset_used=true; } @@ -120,9 +119,9 @@ void DataFetcher::use_fold(index_t idx) void DataFetcher::start() { REQUIRE(m_num_samples>0, "Number of samples is 0!\n"); - if (m_block_details.m_blocksize==0 || m_block_details.m_blocksize>m_num_samples) + if (m_block_details.m_full_data || m_block_details.m_blocksize>m_num_samples) { - SG_SINFO("Block details invalid! Fetching entire data (%d samples)!\n", m_num_samples); + SG_SINFO("Fetching entire data (%d samples)!\n", m_num_samples); m_block_details.with_blocksize(m_num_samples); } m_block_details.m_total_num_blocks=m_num_samples/m_block_details.m_blocksize; @@ -165,5 +164,6 @@ const index_t DataFetcher::get_num_samples() const BlockwiseDetails& DataFetcher::fetch_blockwise() { + m_block_details.m_full_data=false; return m_block_details; } diff --git a/tests/unit/statistical_testing/internals/DataManager_unittest.cc b/tests/unit/statistical_testing/internals/DataManager_unittest.cc index da752a31510..0cf7b1af4e9 100644 --- a/tests/unit/statistical_testing/internals/DataManager_unittest.cc +++ b/tests/unit/statistical_testing/internals/DataManager_unittest.cc @@ -521,6 +521,7 @@ TEST(DataManager, train_data_two_distributions_normal_feats) next_burst=mgr.next(); ASSERT_TRUE(next_burst.empty()); + mgr.end(); // test data mgr.set_train_mode(false); @@ -540,6 +541,7 @@ TEST(DataManager, train_data_two_distributions_normal_feats) next_burst=mgr.next(); ASSERT_TRUE(next_burst.empty()); + mgr.end(); // full data mgr.set_train_test_ratio(0); @@ -560,5 +562,115 @@ TEST(DataManager, train_data_two_distributions_normal_feats) next_burst=mgr.next(); ASSERT_TRUE(next_burst.empty()); + mgr.end(); } +TEST(DataManager, train_data_two_distributions_normal_feats_blockwise) +{ + const index_t dim=3; + const index_t num_vec=8; + const index_t blocksize=2; + const index_t num_blocks_per_burst=2; + const index_t num_distributions=2; + const index_t train_test_ratio=3; + + SGMatrix data_p(dim, num_vec); + std::iota(data_p.matrix, data_p.matrix+dim*num_vec, 0); + + SGMatrix data_q(dim, num_vec); + std::iota(data_q.matrix, data_q.matrix+dim*num_vec, dim*num_vec); + + using feat_type=CDenseFeatures; + auto feats_p=new feat_type(data_p); + auto feats_q=new feat_type(data_q); + + DataManager mgr(num_distributions); + mgr.samples_at(0)=feats_p; + mgr.samples_at(1)=feats_q; + mgr.set_blocksize(blocksize); + mgr.set_num_blocks_per_burst(num_blocks_per_burst); + + mgr.set_train_test_ratio(train_test_ratio); + + // train data + mgr.set_train_mode(true); + mgr.start(); + + auto next_burst=mgr.next(); + ASSERT_TRUE(!next_burst.empty()); + + auto total=0; + + while (!next_burst.empty()) + { + ASSERT_TRUE(next_burst.num_blocks()==num_blocks_per_burst); + for (auto i=0; i(next_burst[0][i].get()); + auto tmp_q=dynamic_cast(next_burst[1][i].get()); + ASSERT_TRUE(tmp_p!=nullptr); + ASSERT_TRUE(tmp_q!=nullptr); + ASSERT_TRUE(tmp_p->get_num_vectors()==blocksize/2); + ASSERT_TRUE(tmp_q->get_num_vectors()==blocksize/2); + total+=tmp_p->get_num_vectors(); + } + next_burst=mgr.next(); + } + ASSERT_TRUE(total==num_vec*train_test_ratio/(train_test_ratio+1)); + mgr.end(); + + // test data + mgr.set_train_mode(false); + mgr.start(); + + next_burst=mgr.next(); + ASSERT_TRUE(!next_burst.empty()); + + total=0; + + while (!next_burst.empty()) + { + ASSERT_TRUE(next_burst.num_blocks()==num_blocks_per_burst); + for (auto i=0; i(next_burst[0][i].get()); + auto tmp_q=dynamic_cast(next_burst[1][i].get()); + ASSERT_TRUE(tmp_p!=nullptr); + ASSERT_TRUE(tmp_q!=nullptr); + ASSERT_TRUE(tmp_p->get_num_vectors()==blocksize/2); + ASSERT_TRUE(tmp_q->get_num_vectors()==blocksize/2); + total+=tmp_p->get_num_vectors(); + } + next_burst=mgr.next(); + } + ASSERT_TRUE(total==num_vec/(train_test_ratio+1)); + mgr.end(); + + // full data + mgr.set_train_test_ratio(0); + mgr.set_train_mode(false); + mgr.start(); + + next_burst=mgr.next(); + ASSERT_TRUE(!next_burst.empty()); + + total=0; + + while (!next_burst.empty()) + { + ASSERT_TRUE(next_burst.num_blocks()==num_blocks_per_burst); + for (auto i=0; i(next_burst[0][i].get()); + auto tmp_q=dynamic_cast(next_burst[1][i].get()); + ASSERT_TRUE(tmp_p!=nullptr); + ASSERT_TRUE(tmp_q!=nullptr); + ASSERT_TRUE(tmp_p->get_num_vectors()==blocksize/2); + ASSERT_TRUE(tmp_q->get_num_vectors()==blocksize/2); + total+=tmp_p->get_num_vectors(); + } + next_burst=mgr.next(); + } + ASSERT_TRUE(total==num_vec); + mgr.end(); +}