diff --git a/src/shogun/statistical_testing/internals/DataFetcher.cpp b/src/shogun/statistical_testing/internals/DataFetcher.cpp index b60c14555bd..f870ab8bd30 100644 --- a/src/shogun/statistical_testing/internals/DataFetcher.cpp +++ b/src/shogun/statistical_testing/internals/DataFetcher.cpp @@ -50,6 +50,7 @@ const char* DataFetcher::get_name() const void DataFetcher::set_train_test_ratio(float64_t train_test_ratio) { m_num_samples=m_train_test_details.get_total_num_samples(); + REQUIRE(m_num_samples>0, "Number of samples is not set!\n"); index_t num_training_samples=m_num_samples*train_test_ratio/(train_test_ratio+1); m_train_test_details.set_num_training_samples(num_training_samples); SG_SINFO("Must set the train/test mode by calling set_train_mode(True/False)!\n"); @@ -68,10 +69,12 @@ void DataFetcher::set_train_mode(bool train_mode) m_num_samples=m_train_test_details.get_num_training_samples(); if (m_num_samples==0) SG_SERROR("The number of training samples is 0! Please set a valid train-test ratio\n"); + SG_SINFO("Using %d number of samples for training!\n", m_num_samples); } else { m_num_samples=m_train_test_details.get_num_test_samples(); + SG_SINFO("Using %d number of samples for testing!\n", m_num_samples); start_index=m_train_test_details.get_num_training_samples(); if (start_index==0) { diff --git a/src/shogun/statistical_testing/internals/DataFetcher.h b/src/shogun/statistical_testing/internals/DataFetcher.h index 40c675552b5..4e8c9849d80 100644 --- a/src/shogun/statistical_testing/internals/DataFetcher.h +++ b/src/shogun/statistical_testing/internals/DataFetcher.h @@ -52,9 +52,9 @@ class DataFetcher public: DataFetcher(CFeatures* samples); virtual ~DataFetcher(); - void set_train_test_ratio(float64_t train_test_ratio); + virtual void set_train_test_ratio(float64_t train_test_ratio); float64_t get_train_test_ratio() const; - void set_train_mode(bool train_mode); + virtual void set_train_mode(bool train_mode); void set_xvalidation_mode(bool xvalidation_mode); index_t get_num_folds() const; void use_fold(index_t idx); diff --git a/src/shogun/statistical_testing/internals/StreamingDataFetcher.cpp b/src/shogun/statistical_testing/internals/StreamingDataFetcher.cpp index a811120cb91..37457b94c8a 100644 --- a/src/shogun/statistical_testing/internals/StreamingDataFetcher.cpp +++ b/src/shogun/statistical_testing/internals/StreamingDataFetcher.cpp @@ -48,14 +48,38 @@ const char* StreamingDataFetcher::get_name() const void StreamingDataFetcher::set_num_samples(index_t num_samples) { m_num_samples=num_samples; + m_train_test_details.set_total_num_samples(m_num_samples); +} + +void StreamingDataFetcher::set_train_test_ratio(float64_t train_test_ratio) +{ + if (m_train_test_details.get_total_num_samples()==0) + m_train_test_details.set_total_num_samples(m_num_samples); + DataFetcher::set_train_test_ratio(train_test_ratio); +} + +void StreamingDataFetcher::set_train_mode(bool train_mode) +{ + if (train_mode) + { + m_num_samples=m_train_test_details.get_num_training_samples(); + if (m_num_samples==0) + SG_SERROR("The number of training samples is 0! Please set a valid train-test ratio\n"); + SG_SINFO("Using %d number of samples for training!\n", m_num_samples); + } + else + { + m_num_samples=m_train_test_details.get_num_test_samples(); + SG_SINFO("Using %d number of samples for testing!\n", m_num_samples); + } } void StreamingDataFetcher::start() { REQUIRE(m_num_samples>0, "Number of samples is not set! It is MANDATORY for streaming features!\n"); - if (m_block_details.m_blocksize==0) + if (m_block_details.m_full_data || m_block_details.m_blocksize>m_num_samples) { - SG_SINFO("Block details not set! Fetching entire data (%d samples)!\n", m_num_samples); + SG_SINFO("Fetching entire data (%d samples)!\n", m_num_samples); m_block_details.with_blocksize(m_num_samples); } m_block_details.m_total_num_blocks=m_num_samples/m_block_details.m_blocksize; @@ -64,7 +88,6 @@ void StreamingDataFetcher::start() { m_samples->start_parser(); parser_running=true; - // TODO check if resetting the stream is required } } diff --git a/src/shogun/statistical_testing/internals/StreamingDataFetcher.h b/src/shogun/statistical_testing/internals/StreamingDataFetcher.h index b5fcf7fdfce..91941b0f2ef 100644 --- a/src/shogun/statistical_testing/internals/StreamingDataFetcher.h +++ b/src/shogun/statistical_testing/internals/StreamingDataFetcher.h @@ -39,6 +39,8 @@ class StreamingDataFetcher : public DataFetcher public: StreamingDataFetcher(CStreamingFeatures* samples); virtual ~StreamingDataFetcher() override; + virtual void set_train_test_ratio(float64_t train_test_ratio) override; + virtual void set_train_mode(bool train_mode) override; virtual void start() override; virtual CFeatures* next() override; virtual void reset() override; diff --git a/tests/unit/statistical_testing/internals/DataManager_unittest.cc b/tests/unit/statistical_testing/internals/DataManager_unittest.cc index 0cf7b1af4e9..31920f531f0 100644 --- a/tests/unit/statistical_testing/internals/DataManager_unittest.cc +++ b/tests/unit/statistical_testing/internals/DataManager_unittest.cc @@ -674,3 +674,204 @@ TEST(DataManager, train_data_two_distributions_normal_feats_blockwise) ASSERT_TRUE(total==num_vec); mgr.end(); } + +TEST(DataManager, train_data_two_distributions_streaming_feats) +{ + const index_t dim=3; + const index_t num_vec=8; + const index_t num_distributions=2; + const index_t train_test_ratio=3; + + SGMatrix data_p(dim, num_vec); + std::iota(data_p.matrix, data_p.matrix+dim*num_vec, 0); + + SGMatrix data_q(dim, num_vec); + std::iota(data_q.matrix, data_q.matrix+dim*num_vec, dim*num_vec); + + using feat_type=CDenseFeatures; + auto feats_p=new feat_type(data_p); + auto feats_q=new feat_type(data_q); + + DataManager mgr(num_distributions); + mgr.samples_at(0)=new CStreamingDenseFeatures(feats_p); + mgr.samples_at(1)=new CStreamingDenseFeatures(feats_q); + mgr.num_samples_at(0)=num_vec; + mgr.num_samples_at(1)=num_vec; + + // training data + mgr.set_train_test_ratio(train_test_ratio); + mgr.set_train_mode(true); + mgr.start(); + + auto next_burst=mgr.next(); + ASSERT_TRUE(!next_burst.empty()); + ASSERT_TRUE(next_burst.num_blocks()==1); + + auto tmp_p=dynamic_cast(next_burst[0][0].get()); + auto tmp_q=dynamic_cast(next_burst[1][0].get()); + + ASSERT_TRUE(tmp_p!=nullptr); + ASSERT_TRUE(tmp_q!=nullptr); + ASSERT_TRUE(tmp_p->get_num_vectors()==num_vec*train_test_ratio/(train_test_ratio+1)); + ASSERT_TRUE(tmp_q->get_num_vectors()==num_vec*train_test_ratio/(train_test_ratio+1)); + + next_burst=mgr.next(); + ASSERT_TRUE(next_burst.empty()); + mgr.end(); + + // test data + mgr.set_train_mode(false); + mgr.start(); + + next_burst=mgr.next(); + ASSERT_TRUE(!next_burst.empty()); + ASSERT_TRUE(next_burst.num_blocks()==1); + + tmp_p=dynamic_cast(next_burst[0][0].get()); + tmp_q=dynamic_cast(next_burst[1][0].get()); + + ASSERT_TRUE(tmp_p!=nullptr); + ASSERT_TRUE(tmp_q!=nullptr); + ASSERT_TRUE(tmp_p->get_num_vectors()==num_vec/(train_test_ratio+1)); + ASSERT_TRUE(tmp_q->get_num_vectors()==num_vec/(train_test_ratio+1)); + + next_burst=mgr.next(); + ASSERT_TRUE(next_burst.empty()); + mgr.end(); + + // full data + mgr.set_train_test_ratio(0); + mgr.set_train_mode(false); + mgr.reset(); + mgr.start(); + + next_burst=mgr.next(); + ASSERT_TRUE(!next_burst.empty()); + ASSERT_TRUE(next_burst.num_blocks()==1); + + tmp_p=dynamic_cast(next_burst[0][0].get()); + tmp_q=dynamic_cast(next_burst[1][0].get()); + + ASSERT_TRUE(tmp_p!=nullptr); + ASSERT_TRUE(tmp_q!=nullptr); + ASSERT_TRUE(tmp_p->get_num_vectors()==num_vec); + ASSERT_TRUE(tmp_q->get_num_vectors()==num_vec); + + next_burst=mgr.next(); + ASSERT_TRUE(next_burst.empty()); + mgr.end(); +} + +TEST(DataManager, train_data_two_distributions_streaming_feats_blockwise) +{ + const index_t dim=3; + const index_t num_vec=8; + const index_t blocksize=2; + const index_t num_blocks_per_burst=2; + const index_t num_distributions=2; + const index_t train_test_ratio=3; + + SGMatrix data_p(dim, num_vec); + std::iota(data_p.matrix, data_p.matrix+dim*num_vec, 0); + + SGMatrix data_q(dim, num_vec); + std::iota(data_q.matrix, data_q.matrix+dim*num_vec, dim*num_vec); + + using feat_type=CDenseFeatures; + auto feats_p=new feat_type(data_p); + auto feats_q=new feat_type(data_q); + + DataManager mgr(num_distributions); + mgr.samples_at(0)=new CStreamingDenseFeatures(feats_p); + mgr.samples_at(1)=new CStreamingDenseFeatures(feats_q); + mgr.num_samples_at(0)=num_vec; + mgr.num_samples_at(1)=num_vec; + + mgr.set_blocksize(blocksize); + mgr.set_num_blocks_per_burst(num_blocks_per_burst); + + mgr.set_train_test_ratio(train_test_ratio); + + // train data + mgr.set_train_mode(true); + mgr.start(); + + auto next_burst=mgr.next(); + ASSERT_TRUE(!next_burst.empty()); + + auto total=0; + + while (!next_burst.empty()) + { + ASSERT_TRUE(next_burst.num_blocks()==num_blocks_per_burst); + for (auto i=0; i(next_burst[0][i].get()); + auto tmp_q=dynamic_cast(next_burst[1][i].get()); + ASSERT_TRUE(tmp_p!=nullptr); + ASSERT_TRUE(tmp_q!=nullptr); + ASSERT_TRUE(tmp_p->get_num_vectors()==blocksize/2); + ASSERT_TRUE(tmp_q->get_num_vectors()==blocksize/2); + total+=tmp_p->get_num_vectors(); + } + next_burst=mgr.next(); + } + ASSERT_TRUE(total==num_vec*train_test_ratio/(train_test_ratio+1)); + mgr.end(); + + // test data + mgr.set_train_mode(false); + mgr.start(); + + next_burst=mgr.next(); + ASSERT_TRUE(!next_burst.empty()); + + total=0; + + while (!next_burst.empty()) + { + ASSERT_TRUE(next_burst.num_blocks()==num_blocks_per_burst); + for (auto i=0; i(next_burst[0][i].get()); + auto tmp_q=dynamic_cast(next_burst[1][i].get()); + ASSERT_TRUE(tmp_p!=nullptr); + ASSERT_TRUE(tmp_q!=nullptr); + ASSERT_TRUE(tmp_p->get_num_vectors()==blocksize/2); + ASSERT_TRUE(tmp_q->get_num_vectors()==blocksize/2); + total+=tmp_p->get_num_vectors(); + } + next_burst=mgr.next(); + } + ASSERT_TRUE(total==num_vec/(train_test_ratio+1)); + mgr.end(); + + // full data + mgr.set_train_test_ratio(0); + mgr.set_train_mode(false); + mgr.reset(); + mgr.start(); + + next_burst=mgr.next(); + ASSERT_TRUE(!next_burst.empty()); + + total=0; + + while (!next_burst.empty()) + { + ASSERT_TRUE(next_burst.num_blocks()==num_blocks_per_burst); + for (auto i=0; i(next_burst[0][i].get()); + auto tmp_q=dynamic_cast(next_burst[1][i].get()); + ASSERT_TRUE(tmp_p!=nullptr); + ASSERT_TRUE(tmp_q!=nullptr); + ASSERT_TRUE(tmp_p->get_num_vectors()==blocksize/2); + ASSERT_TRUE(tmp_q->get_num_vectors()==blocksize/2); + total+=tmp_p->get_num_vectors(); + } + next_burst=mgr.next(); + } + ASSERT_TRUE(total==num_vec); + mgr.end(); +}