From 8a325c7a71d98bdf1a6b3fb73280b16ddb314651 Mon Sep 17 00:00:00 2001 From: lambday Date: Tue, 31 May 2016 12:15:16 +0100 Subject: [PATCH] refactor API (incomplete) --- .../statistical_testing/HypothesisTest.cpp | 104 +++++++++++------- .../statistical_testing/HypothesisTest.h | 60 ++++++---- .../statistical_testing/TwoDistributionTest.h | 10 +- .../internals/DataManager.cpp | 11 ++ .../internals/DataManager.h | 21 +++- .../internals/MaxXValidation.cpp | 3 +- 6 files changed, 137 insertions(+), 72 deletions(-) diff --git a/src/shogun/statistical_testing/HypothesisTest.cpp b/src/shogun/statistical_testing/HypothesisTest.cpp index afadb52450d..f813f2d6f45 100644 --- a/src/shogun/statistical_testing/HypothesisTest.cpp +++ b/src/shogun/statistical_testing/HypothesisTest.cpp @@ -1,89 +1,111 @@ /* - * Restructuring Shogun's statistical hypothesis testing framework. - * Copyright (C) 2016 Soumyajit De + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (w) 2012 - 2013 Heiko Strathmann + * Written (w) 2014 - 2016 Soumyajit De + * All rights reserved. * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the selfied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. */ #include #include #include #include -#include #include -#include using namespace shogun; using namespace internal; struct CHypothesisTest::Self { - Self(index_t num_distributions, index_t num_kernels) - : data_manager(num_distributions), kernel_manager(num_kernels) - { - } - DataManager data_manager; - KernelManager kernel_manager; + explicit Self(index_t num_distributions); + DataManager data_mgr; }; -CHypothesisTest::CHypothesisTest(index_t num_distributions, index_t num_kernels) : CSGObject() -{ - self = std::unique_ptr(new CHypothesisTest::Self(num_distributions, num_kernels)); -} - -CHypothesisTest::~CHypothesisTest() +CHypothesisTest::Self::Self(index_t num_distributions) : data_mgr(num_distributions) { } -DataManager& CHypothesisTest::get_data_manager() +CHypothesisTest::CHypothesisTest(index_t num_distributions) : CSGObject() { - return self->data_manager; + self=std::unique_ptr(new CHypothesisTest::Self(num_distributions)); } -const DataManager& CHypothesisTest::get_data_manager() const +CHypothesisTest::~CHypothesisTest() { - return self->data_manager; } -KernelManager& CHypothesisTest::get_kernel_manager() +void CHypothesisTest::set_train_test_mode(bool on) { - return self->kernel_manager; + self->data_mgr.set_train_test_mode(on); } -const KernelManager& CHypothesisTest::get_kernel_manager() const +void CHypothesisTest::set_train_test_ratio(float64_t ratio) { - return self->kernel_manager; + self->data_mgr.train_test_ratio(ratio); } float64_t CHypothesisTest::compute_p_value(float64_t statistic) { - SGVector values = sample_null(); - + SGVector values=sample_null(); std::sort(values.vector, values.vector + values.vlen); - float64_t i = values.find_position_to_insert(statistic); - - return 1.0 - i / values.vlen; + float64_t i=values.find_position_to_insert(statistic); + return 1.0-i/values.vlen; } float64_t CHypothesisTest::compute_threshold(float64_t alpha) { - SGVector values = sample_null(); + SGVector values=sample_null(); std::sort(values.vector, values.vector + values.vlen); - return values[index_t(CMath::floor(values.vlen * (1 - alpha)))]; + return values[index_t(CMath::floor(values.vlen*(1-alpha)))]; +} + +bool CHypothesisTest::perform_test(float64_t alpha) +{ + auto statistic=compute_statistic(); + auto p_value=compute_p_value(statistic); + return p_valuedata_mgr; +} + +const DataManager& CHypothesisTest::get_data_mgr() const +{ + return self->data_mgr; +} diff --git a/src/shogun/statistical_testing/HypothesisTest.h b/src/shogun/statistical_testing/HypothesisTest.h index 3a7408d34f6..c278edfbb8c 100644 --- a/src/shogun/statistical_testing/HypothesisTest.h +++ b/src/shogun/statistical_testing/HypothesisTest.h @@ -1,19 +1,32 @@ /* - * Restructuring Shogun's statistical hypothesis testing framework. - * Copyright (C) 2016 Soumyajit De + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (w) 2012 - 2013 Heiko Strathmann + * Written (w) 2014 - 2016 Soumyajit De + * All rights reserved. * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. */ #ifndef HYPOTHESIS_TEST_H_ @@ -32,33 +45,36 @@ namespace internal { class DataManager; -class KernelManager; } class CHypothesisTest : public CSGObject { public: - CHypothesisTest(index_t num_distributions, index_t num_kernels); + explicit CHypothesisTest(index_t num_distributions); virtual ~CHypothesisTest(); - virtual float64_t compute_statistic() = 0; + CHypothesisTest(const CHypothesisTest& other)=delete; + CHypothesisTest& operator=(const CHypothesisTest& other)=delete; + + void set_train_test_mode(bool on); + void set_train_test_ratio(float64_t ratio); virtual float64_t compute_p_value(float64_t statistic); virtual float64_t compute_threshold(float64_t alpha); + virtual bool perform_test(float64_t alpha); - virtual SGVector sample_null() = 0; + virtual float64_t compute_statistic()=0; + virtual SGVector sample_null()=0; virtual const char* get_name() const; + virtual CSGObject* clone(); +protected: + internal::DataManager& get_data_mgr(); + const internal::DataManager& get_data_mgr() const; private: struct Self; std::unique_ptr self; -protected: - internal::DataManager& get_data_manager(); - const internal::DataManager& get_data_manager() const; - - internal::KernelManager& get_kernel_manager(); - const internal::KernelManager& get_kernel_manager() const; }; } diff --git a/src/shogun/statistical_testing/TwoDistributionTest.h b/src/shogun/statistical_testing/TwoDistributionTest.h index 33e96b188db..f343d32ae45 100644 --- a/src/shogun/statistical_testing/TwoDistributionTest.h +++ b/src/shogun/statistical_testing/TwoDistributionTest.h @@ -32,9 +32,6 @@ class CTwoDistributionTest : public CHypothesisTest CTwoDistributionTest(index_t num_kernels); virtual ~CTwoDistributionTest(); - void set_p(CFeatures* samples_from_p); - void set_q(CFeatures* samples_from_q); - CFeatures* get_p() const; CFeatures* get_q() const; @@ -44,10 +41,13 @@ class CTwoDistributionTest : public CHypothesisTest const index_t get_num_samples_p() const; const index_t get_num_samples_q() const; - virtual float64_t compute_statistic() = 0; - virtual SGVector sample_null() = 0; + virtual float64_t compute_statistic()=0; + virtual SGVector sample_null()=0; virtual const char* get_name() const; +protected: + void set_p(CFeatures* samples_from_p); + void set_q(CFeatures* samples_from_q); }; } diff --git a/src/shogun/statistical_testing/internals/DataManager.cpp b/src/shogun/statistical_testing/internals/DataManager.cpp index 6ba4b9629fa..b5a9acce2d8 100644 --- a/src/shogun/statistical_testing/internals/DataManager.cpp +++ b/src/shogun/statistical_testing/internals/DataManager.cpp @@ -49,6 +49,7 @@ DataManager::DataManager(size_t num_distributions) SG_SDEBUG("Data manager instance initialized with %d data sources!\n", num_distributions); fetchers.resize(num_distributions); std::fill(fetchers.begin(), fetchers.end(), nullptr); + train_test_mode=default_train_test_mode; } DataManager::~DataManager() @@ -324,3 +325,13 @@ void DataManager::reset() std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->reset(); }); SG_SDEBUG("Leaving!\n"); } + +void DataManager::set_train_test_mode(bool on) +{ + train_test_mode_on=on; +} + +void DataManager::set_train_test_ratio(float64_t ratio) +{ + train_test_ratio=ratio; +} diff --git a/src/shogun/statistical_testing/internals/DataManager.h b/src/shogun/statistical_testing/internals/DataManager.h index 7b5616a5ba0..df08b359af1 100644 --- a/src/shogun/statistical_testing/internals/DataManager.h +++ b/src/shogun/statistical_testing/internals/DataManager.h @@ -232,6 +232,14 @@ class DataManager */ index_t get_num_folds() const; + /** + * Permutes the feature vectors. Useful for cross-validation set-up. Everytime + * TODO + * + void shuffle_features() + void unshuffle_features() + */ + /** * @param idx The index of the fold in X-validation scenario, has to be within the range of * \f$[0, N)\f$, where N is the number of folds as returned by get_num_folds() method. @@ -257,11 +265,18 @@ class DataManager * Resets the fetchers to the initial states. */ void reset(); + + void set_train_test_mode(bool on); + void set_train_test_ratio(float64_t ratio); + + bool is_train_test_mode() const; + float64_t get_train_test_ratio() const; private: - /** - * The internal data fetcher instances. - */ std::vector > fetchers; + bool train_test_mode; + float64_t train_test_ratio; + constexpr static bool default_train_test_mode=false; + constexpr static float64_t default_train_test_ratio=1.0; }; } diff --git a/src/shogun/statistical_testing/internals/MaxXValidation.cpp b/src/shogun/statistical_testing/internals/MaxXValidation.cpp index b38b10fab6f..3c2a6244cc4 100644 --- a/src/shogun/statistical_testing/internals/MaxXValidation.cpp +++ b/src/shogun/statistical_testing/internals/MaxXValidation.cpp @@ -88,7 +88,7 @@ void MaxXValidation::compute_measures() auto existing_kernel=estimator->get_kernel(); for (auto i=0; icleanup(); } } + dm.unshuffle_features(); } dm.set_xvalidation_mode(false); estimator->set_kernel(existing_kernel);