From a1bfed952f8e8e235258262fc2253654c9fe55f9 Mon Sep 17 00:00:00 2001 From: lambday Date: Thu, 12 May 2016 19:45:25 +0100 Subject: [PATCH] added first draft of train-test data split --- .../internals/DataFetcher.cpp | 79 +++++++++++++++++- .../internals/DataFetcher.h | 10 +++ .../internals/DataManager.cpp | 37 ++++++++- .../internals/TrainTestDetails.cpp | 78 ++++++++++++++++++ .../internals/TrainTestDetails.h | 72 ++++++++++++++++ .../internals/DataManager_unittest.cc | 82 +++++++++++++++++++ 6 files changed, 352 insertions(+), 6 deletions(-) create mode 100644 src/shogun/statistical_testing/internals/TrainTestDetails.cpp create mode 100644 src/shogun/statistical_testing/internals/TrainTestDetails.h diff --git a/src/shogun/statistical_testing/internals/DataFetcher.cpp b/src/shogun/statistical_testing/internals/DataFetcher.cpp index b5492766c51..a21caca8124 100644 --- a/src/shogun/statistical_testing/internals/DataFetcher.cpp +++ b/src/shogun/statistical_testing/internals/DataFetcher.cpp @@ -24,15 +24,16 @@ using namespace shogun; using namespace internal; -DataFetcher::DataFetcher() : m_num_samples(0), m_samples(nullptr) +DataFetcher::DataFetcher() : m_num_samples(0), m_samples(nullptr), train_test_subset_used(false) { } -DataFetcher::DataFetcher(CFeatures* samples) : m_samples(samples) +DataFetcher::DataFetcher(CFeatures* samples) : m_samples(samples), train_test_subset_used(false) { REQUIRE(m_samples!=nullptr, "Samples cannot be null!\n"); SG_REF(m_samples); m_num_samples=m_samples->get_num_vectors(); + m_train_test_details.set_total_num_samples(m_num_samples); } DataFetcher::~DataFetcher() @@ -46,12 +47,82 @@ const char* DataFetcher::get_name() const return "DataFetcher"; } +void DataFetcher::set_train_test_ratio(float64_t train_test_ratio) +{ + m_num_samples=m_train_test_details.get_total_num_samples(); + index_t num_training_samples=m_num_samples*train_test_ratio/(train_test_ratio+1); + m_train_test_details.set_num_training_samples(num_training_samples); +} + +float64_t DataFetcher::get_train_test_ratio() const +{ + return float64_t(m_train_test_details.get_num_training_samples())/m_train_test_details.get_num_test_samples(); +} + +void DataFetcher::set_train_mode(bool train_mode) +{ + index_t start_index=0; + if (train_mode) + { + m_num_samples=m_train_test_details.get_num_training_samples(); + if (m_num_samples==0) + SG_SERROR("The number of training samples is 0! Please set a valid train-test ratio\n"); + } + else + { + m_num_samples=m_train_test_details.get_num_test_samples(); + start_index=m_train_test_details.get_num_training_samples(); + if (start_index==0) + { + if (train_test_subset_used) + { + m_samples->remove_subset(); + train_test_subset_used=false; + } + return; + } + } + SGVector inds(m_num_samples); + std::iota(inds.data(), inds.data()+inds.size(), start_index); + if (train_test_subset_used) + { + m_samples->remove_subset(); + } + m_samples->add_subset(inds); + train_test_subset_used=true; +} + +void DataFetcher::set_xvalidation_mode(bool xvalidation_mode) +{ +// using fetcher_type=std::unique_ptr; +// std::for_each(fetchers.begin(), fetchers.end(), [&train_mode](fetcher_type& f) +// { +// f->set_xvalidation_mode(xvalidation_mode); +// }); +} + +index_t DataFetcher::get_num_folds() const +{ +// REQUIRE(fetchers[0]!=nullptr, "Please set the samples first!\n"); +// return fetchers[0]->get_train_test_ratio(); + return 0; +} + +void DataFetcher::use_fold(index_t idx) +{ +// using fetcher_type=std::unique_ptr; +// std::for_each(fetchers.begin(), fetchers.end(), [&train_mode](fetcher_type& f) +// { +// f->use_fold(idx); +// }); +} + void DataFetcher::start() { REQUIRE(m_num_samples>0, "Number of samples is 0!\n"); - if (m_block_details.m_blocksize==0) + if (m_block_details.m_blocksize==0 || m_block_details.m_blocksize>m_num_samples) { - SG_SINFO("Block details not set! Fetching entire data (%d samples)!\n", m_num_samples); + SG_SINFO("Block details invalid! Fetching entire data (%d samples)!\n", m_num_samples); m_block_details.with_blocksize(m_num_samples); } m_block_details.m_total_num_blocks=m_num_samples/m_block_details.m_blocksize; diff --git a/src/shogun/statistical_testing/internals/DataFetcher.h b/src/shogun/statistical_testing/internals/DataFetcher.h index a225d868761..40c675552b5 100644 --- a/src/shogun/statistical_testing/internals/DataFetcher.h +++ b/src/shogun/statistical_testing/internals/DataFetcher.h @@ -30,6 +30,7 @@ #include #include +#include #ifndef DATA_FETCHER_H__ #define DATA_FETCHER_H__ @@ -51,6 +52,13 @@ class DataFetcher public: DataFetcher(CFeatures* samples); virtual ~DataFetcher(); + void set_train_test_ratio(float64_t train_test_ratio); + float64_t get_train_test_ratio() const; + void set_train_mode(bool train_mode); + void set_xvalidation_mode(bool xvalidation_mode); + index_t get_num_folds() const; + void use_fold(index_t idx); + virtual void start(); virtual CFeatures* next(); virtual void reset(); @@ -61,9 +69,11 @@ class DataFetcher protected: DataFetcher(); BlockwiseDetails m_block_details; + TrainTestDetails m_train_test_details; index_t m_num_samples; private: CFeatures* m_samples; + bool train_test_subset_used; }; } diff --git a/src/shogun/statistical_testing/internals/DataManager.cpp b/src/shogun/statistical_testing/internals/DataManager.cpp index d11c2ced185..007802c34a8 100644 --- a/src/shogun/statistical_testing/internals/DataManager.cpp +++ b/src/shogun/statistical_testing/internals/DataManager.cpp @@ -41,6 +41,9 @@ using namespace shogun; using namespace internal; +// TODO add nullptr check before calling the methods on actual fetchers +// this would be where someone calls the other methofds before setiing the sameples + DataManager::DataManager(size_t num_distributions) { SG_SDEBUG("Data manager instance initialized with %d data sources!\n", num_distributions); @@ -195,28 +198,58 @@ const index_t DataManager::blocksize_at(size_t i) const void DataManager::set_train_test_ratio(float64_t train_test_ratio) { + SG_SDEBUG("Entering!\n"); + using fetcher_type=std::unique_ptr; + std::for_each(fetchers.begin(), fetchers.end(), [&train_test_ratio](fetcher_type& f) + { + f->set_train_test_ratio(train_test_ratio); + }); + SG_SDEBUG("Leaving!\n"); } float64_t DataManager::get_train_test_ratio() const { - return 0; + REQUIRE(fetchers[0]!=nullptr, "Please set the samples first!\n"); + return fetchers[0]->get_train_test_ratio(); } void DataManager::set_train_mode(bool train_mode) { + SG_SDEBUG("Entering!\n"); + using fetcher_type=std::unique_ptr; + std::for_each(fetchers.begin(), fetchers.end(), [&train_mode](fetcher_type& f) + { + f->set_train_mode(train_mode); + }); + SG_SDEBUG("Leaving!\n"); } void DataManager::set_xvalidation_mode(bool xvalidation_mode) { + SG_SDEBUG("Entering!\n"); + using fetcher_type=std::unique_ptr; + std::for_each(fetchers.begin(), fetchers.end(), [&xvalidation_mode](fetcher_type& f) + { + f->set_xvalidation_mode(xvalidation_mode); + }); + SG_SDEBUG("Leaving!\n"); } index_t DataManager::get_num_folds() const { - return 0; + REQUIRE(fetchers[0]!=nullptr, "Please set the samples first!\n"); + return fetchers[0]->get_train_test_ratio(); } void DataManager::use_fold(index_t idx) { + SG_SDEBUG("Entering!\n"); + using fetcher_type=std::unique_ptr; + std::for_each(fetchers.begin(), fetchers.end(), [&idx](fetcher_type& f) + { + f->use_fold(idx); + }); + SG_SDEBUG("Leaving!\n"); } void DataManager::start() diff --git a/src/shogun/statistical_testing/internals/TrainTestDetails.cpp b/src/shogun/statistical_testing/internals/TrainTestDetails.cpp new file mode 100644 index 00000000000..9df301f0786 --- /dev/null +++ b/src/shogun/statistical_testing/internals/TrainTestDetails.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (w) 2016 Soumyajit De + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#include +#include + +using namespace shogun; +using namespace internal; + +TrainTestDetails::TrainTestDetails() : m_total_num_samples(0), m_num_training_samples(0) +{ +} + +void TrainTestDetails::set_total_num_samples(index_t total_num_samples) +{ + m_total_num_samples=total_num_samples; +} + +index_t TrainTestDetails::get_total_num_samples() const +{ + return m_total_num_samples; +} + +void TrainTestDetails::set_num_training_samples(index_t num_training_samples) +{ + REQUIRE(m_total_num_samples>=num_training_samples, + "Number of training samples cannot be greater than the total number of samples!\n"); + m_num_training_samples=num_training_samples; +} + +index_t TrainTestDetails::get_num_training_samples() const +{ + return m_num_training_samples; +} + +index_t TrainTestDetails::get_num_test_samples() const +{ + return m_total_num_samples-m_num_training_samples; +} +// +//bool TrainTestDetails::is_training_mode() const +//{ +//} +// +//void TrainTestDetails::set_train_mode(bool train_mode) +//{ +//} +// +//void TrainTestDetails::set_xvalidation_mode(bool xvalidation_mode) +//{ +//} diff --git a/src/shogun/statistical_testing/internals/TrainTestDetails.h b/src/shogun/statistical_testing/internals/TrainTestDetails.h new file mode 100644 index 00000000000..68acd1c6093 --- /dev/null +++ b/src/shogun/statistical_testing/internals/TrainTestDetails.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (w) 2016 Soumyajit De + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#include + +#ifndef TRAIN_TEST_DETAILS_H__ +#define TRAIN_TEST_DETAILS_H__ + +namespace shogun +{ + +namespace internal +{ + +/** + * @brief Class that holds train-test details for the data-fetchers. + * There are one instance of this class per fetcher. + */ +class TrainTestDetails +{ + friend class DataFetcher; + friend class StreamingDataFetcher; + +public: + TrainTestDetails(); + + void set_total_num_samples(index_t total_num_sampels); + index_t get_total_num_samples() const; + + void set_num_training_samples(index_t num_training_samples); + index_t get_num_training_samples() const; + index_t get_num_test_samples() const; + +// bool is_training_mode() const; +// void set_train_mode(bool train_mode); +// void set_xvalidation_mode(bool xvalidation_mode); +private: + index_t m_total_num_samples; + index_t m_num_training_samples; +}; + +} + +} +#endif // TRAIN_TEST_DETAILS_H__ diff --git a/tests/unit/statistical_testing/internals/DataManager_unittest.cc b/tests/unit/statistical_testing/internals/DataManager_unittest.cc index 5e54a167993..da752a31510 100644 --- a/tests/unit/statistical_testing/internals/DataManager_unittest.cc +++ b/tests/unit/statistical_testing/internals/DataManager_unittest.cc @@ -480,3 +480,85 @@ TEST(DataManager, block_data_two_distributions_streaming_feats_different_blocksi ASSERT_TRUE(total_p==num_vec_p); ASSERT_TRUE(total_q==num_vec_q); } + +TEST(DataManager, train_data_two_distributions_normal_feats) +{ + const index_t dim=3; + const index_t num_vec=8; + const index_t num_distributions=2; + const index_t train_test_ratio=3; + + SGMatrix data_p(dim, num_vec); + std::iota(data_p.matrix, data_p.matrix+dim*num_vec, 0); + + SGMatrix data_q(dim, num_vec); + std::iota(data_q.matrix, data_q.matrix+dim*num_vec, dim*num_vec); + + using feat_type=CDenseFeatures; + auto feats_p=new feat_type(data_p); + auto feats_q=new feat_type(data_q); + + DataManager mgr(num_distributions); + mgr.samples_at(0)=feats_p; + mgr.samples_at(1)=feats_q; + + // training data + mgr.set_train_test_ratio(train_test_ratio); + mgr.set_train_mode(true); + mgr.start(); + + auto next_burst=mgr.next(); + ASSERT_TRUE(!next_burst.empty()); + ASSERT_TRUE(next_burst.num_blocks()==1); + + auto tmp_p=dynamic_cast(next_burst[0][0].get()); + auto tmp_q=dynamic_cast(next_burst[1][0].get()); + + ASSERT_TRUE(tmp_p!=nullptr); + ASSERT_TRUE(tmp_q!=nullptr); + ASSERT_TRUE(tmp_p->get_num_vectors()==num_vec*train_test_ratio/(train_test_ratio+1)); + ASSERT_TRUE(tmp_q->get_num_vectors()==num_vec*train_test_ratio/(train_test_ratio+1)); + + next_burst=mgr.next(); + ASSERT_TRUE(next_burst.empty()); + + // test data + mgr.set_train_mode(false); + mgr.start(); + + next_burst=mgr.next(); + ASSERT_TRUE(!next_burst.empty()); + ASSERT_TRUE(next_burst.num_blocks()==1); + + tmp_p=dynamic_cast(next_burst[0][0].get()); + tmp_q=dynamic_cast(next_burst[1][0].get()); + + ASSERT_TRUE(tmp_p!=nullptr); + ASSERT_TRUE(tmp_q!=nullptr); + ASSERT_TRUE(tmp_p->get_num_vectors()==num_vec/(train_test_ratio+1)); + ASSERT_TRUE(tmp_q->get_num_vectors()==num_vec/(train_test_ratio+1)); + + next_burst=mgr.next(); + ASSERT_TRUE(next_burst.empty()); + + // full data + mgr.set_train_test_ratio(0); + mgr.set_train_mode(false); + mgr.start(); + + next_burst=mgr.next(); + ASSERT_TRUE(!next_burst.empty()); + ASSERT_TRUE(next_burst.num_blocks()==1); + + tmp_p=dynamic_cast(next_burst[0][0].get()); + tmp_q=dynamic_cast(next_burst[1][0].get()); + + ASSERT_TRUE(tmp_p!=nullptr); + ASSERT_TRUE(tmp_q!=nullptr); + ASSERT_TRUE(tmp_p->get_num_vectors()==num_vec); + ASSERT_TRUE(tmp_q->get_num_vectors()==num_vec); + + next_burst=mgr.next(); + ASSERT_TRUE(next_burst.empty()); +} +