From b192ec5b992256d4c47de92c851b0254427cedd4 Mon Sep 17 00:00:00 2001 From: lambday Date: Thu, 12 May 2016 15:55:34 +0100 Subject: [PATCH] added support for cross-validation setting for kernel selection (incomplete) --- src/shogun/statistical_testing/MMD.cpp | 52 +++++---- src/shogun/statistical_testing/MMD.h | 7 +- .../internals/DataManager.cpp | 22 ++++ .../internals/DataManager.h | 36 +++++++ .../internals/MaxXValidation.cpp | 102 ++++++++++++++++++ .../internals/MaxXValidation.h | 67 ++++++++++++ 6 files changed, 266 insertions(+), 20 deletions(-) create mode 100644 src/shogun/statistical_testing/internals/MaxXValidation.cpp create mode 100644 src/shogun/statistical_testing/internals/MaxXValidation.h diff --git a/src/shogun/statistical_testing/MMD.cpp b/src/shogun/statistical_testing/MMD.cpp index a45953fd1db..bb78147b754 100644 --- a/src/shogun/statistical_testing/MMD.cpp +++ b/src/shogun/statistical_testing/MMD.cpp @@ -46,6 +46,7 @@ #include #include #include +#include #include #include #include @@ -402,43 +403,56 @@ void CMMD::add_kernel(CKernel* kernel) self->kernel_selection_mgr.push_back(kernel); } -void CMMD::select_kernel(EKernelSelectionMethod kmethod, bool weighted_kernel) +void CMMD::select_kernel(EKernelSelectionMethod kmethod, bool weighted_kernel, float64_t train_test_ratio, + index_t num_run, float64_t alpha) { SG_DEBUG("Entering!\n"); SG_DEBUG("Selecting kernels from a total of %d kernels!\n", self->kernel_selection_mgr.num_kernels()); - std::shared_ptr policy=nullptr; + std::unique_ptr policy=nullptr; + + auto& dm=get_data_manager(); + dm.set_train_test_ratio(train_test_ratio); + dm.set_train_mode(true); + switch (kmethod) { + case EKernelSelectionMethod::MEDIAN_HEURISTIC: + { + REQUIRE(!weighted_kernel, "Weighted kernel selection is not possible with MEDIAN_HEURISTIC!\n"); + auto distance=compute_distance(); + policy=std::unique_ptr(new MedianHeuristic(self->kernel_selection_mgr, distance)); + dm.set_train_test_ratio(0); + } + break; + case 
EKernelSelectionMethod::MAXIMIZE_XVALIDATION: + { + REQUIRE(!weighted_kernel, "Weighted kernel selection is not possible with MAXIMIZE_XVALIDATION!\n"); + policy=std::unique_ptr(new MaxXValidation(self->kernel_selection_mgr, this, num_run, alpha)); + } + break; case EKernelSelectionMethod::MAXIMIZE_MMD: if (weighted_kernel) - policy=std::shared_ptr(new WeightedMaxMeasure(self->kernel_selection_mgr, this)); + policy=std::unique_ptr(new WeightedMaxMeasure(self->kernel_selection_mgr, this)); else - policy=std::shared_ptr(new MaxMeasure(self->kernel_selection_mgr, this)); + policy=std::unique_ptr(new MaxMeasure(self->kernel_selection_mgr, this)); break; case EKernelSelectionMethod::MAXIMIZE_POWER: if (weighted_kernel) - policy=std::shared_ptr(new WeightedMaxTestPower(self->kernel_selection_mgr, this)); + policy=std::unique_ptr(new WeightedMaxTestPower(self->kernel_selection_mgr, this)); else - policy=std::shared_ptr(new MaxTestPower(self->kernel_selection_mgr, this)); - break; - case EKernelSelectionMethod::MEDIAN_HEURISTIC: - { - REQUIRE(!weighted_kernel, "Weighted kernel selection is not possible with MEDIAN_HEURISTIC!\n"); - auto distance=compute_distance(); - policy=std::shared_ptr(new MedianHeuristic(self->kernel_selection_mgr, distance)); - } + policy=std::unique_ptr(new MaxTestPower(self->kernel_selection_mgr, this)); break; default: SG_ERROR("Unsupported kernel selection method specified! 
" "Presently only accepted values are MAXIMIZE_MMD, MAXIMIZE_POWER and MEDIAN_HEURISTIC!\n"); break; } - if (policy!=nullptr) - { - auto& km=get_kernel_manager(); - km.kernel_at(0)=policy->select_kernel(); - km.restore_kernel_at(0); - } + ASSERT(policy!=nullptr); + auto& km=get_kernel_manager(); + km.kernel_at(0)=policy->select_kernel(); + km.restore_kernel_at(0); + + dm.set_train_mode(false); SG_DEBUG("Leaving!\n"); } diff --git a/src/shogun/statistical_testing/MMD.h b/src/shogun/statistical_testing/MMD.h index 22baf178215..f57617ad510 100644 --- a/src/shogun/statistical_testing/MMD.h +++ b/src/shogun/statistical_testing/MMD.h @@ -48,10 +48,12 @@ namespace internal { class MaxTestPower; +class MaxXValidation; class WeightedMaxTestPower; } +// TODO change enum class to enum in order to co-operate with python swig etc blah enum class EStatisticType { UNBIASED_FULL, @@ -87,12 +89,15 @@ class CMMD : public CTwoSampleTest using operation=std::function)>; friend class internal::MaxTestPower; friend class internal::WeightedMaxTestPower; + friend class internal::MaxXValidation; public: CMMD(); virtual ~CMMD(); void add_kernel(CKernel *kernel); - void select_kernel(EKernelSelectionMethod kmethod=EKernelSelectionMethod::AUTO, bool weighted_kernel=false); + void select_kernel(EKernelSelectionMethod kmethod=EKernelSelectionMethod::AUTO, + bool weighted_kernel=false, float64_t train_test_ratio=1.0, + index_t num_run=10, float64_t alpha=0.05); virtual float64_t compute_statistic() override; virtual float64_t compute_variance(); diff --git a/src/shogun/statistical_testing/internals/DataManager.cpp b/src/shogun/statistical_testing/internals/DataManager.cpp index 464245f8d03..567827386e5 100644 --- a/src/shogun/statistical_testing/internals/DataManager.cpp +++ b/src/shogun/statistical_testing/internals/DataManager.cpp @@ -194,6 +194,28 @@ const index_t DataManager::blocksize_at(size_t i) const return fetchers[i]->m_block_details.m_blocksize; } +void 
DataManager::set_train_test_ratio(float64_t train_test_ratio) +{ +} + +float64_t DataManager::get_train_test_ratio() const +{ + return 0; +} + +void DataManager::set_train_mode(bool train_mode) +{ +} + +void DataManager::set_xvalidation_mode(bool xvalidation_mode) +{ +} + +index_t DataManager::get_num_folds() const +{ + return 0; +} + void DataManager::start() { SG_SDEBUG("Entering!\n"); diff --git a/src/shogun/statistical_testing/internals/DataManager.h b/src/shogun/statistical_testing/internals/DataManager.h index 22936e6ddd5..ee87a544d60 100644 --- a/src/shogun/statistical_testing/internals/DataManager.h +++ b/src/shogun/statistical_testing/internals/DataManager.h @@ -189,6 +189,42 @@ class DataManager */ index_t get_min_blocksize() const; + /** + * @param train_test_ratio The split ratio for train-test data. The default value is 0 + * which means that all of the data would be used for testing. + */ + void set_train_test_ratio(float64_t train_test_ratio); + + /** + * @return The split ratio for train-test data. The default value is 0, which means + * that all of the data would be used for testing. + */ + float64_t get_train_test_ratio() const; + + /** + * @param train_mode If set to true, then the training data would be returned by the data + * fetching API of this data manager. Otherwise, test data would be returned. + */ + void set_train_mode(bool train_mode); + + /** + * @param xvalidation_mode If set to true, then the data would be split into N folds (the value + * of N is determined from the train_test_ratio). + */ + void set_xvalidation_mode(bool xvalidation_mode); + + /** + * @return The number of folds that can be used based on the train-test ratio. Returns + * an integer if xvalidation mode is ON, 0 otherwise. + */ + index_t get_num_folds() const; + + /** + * @param idx The index of the fold in X-validation scenario, has to be within the range of + * \f$[0, N)\f$, where N is the number of folds as returned by get_num_folds() method.
+ */ + void use_fold(index_t idx); + /** * Call this method before fetching the data from the data manager */ diff --git a/src/shogun/statistical_testing/internals/MaxXValidation.cpp b/src/shogun/statistical_testing/internals/MaxXValidation.cpp new file mode 100644 index 00000000000..820960f8d5d --- /dev/null +++ b/src/shogun/statistical_testing/internals/MaxXValidation.cpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (W) 2013 Heiko Strathmann + * Written (w) 2014 - 2016 Soumyajit De + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#include +#include +#include +#include +#include +#include +#include + +using namespace shogun; +using namespace internal; + +MaxXValidation::MaxXValidation(KernelManager& km, CMMD* est, const index_t& M, const float64_t& alp) +: KernelSelection(km), estimator(est), num_run(M), alpha(alp) +{ + // TODO write a more meaningful error message + REQUIRE(estimator!=nullptr, "Estimator is not set!\n"); + REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels()); + REQUIRE(num_run>0, "Number of runs is %d!\n", num_run); + REQUIRE(alpha>=0.0 && alpha<=1.0, "Threshold is %f!\n", alpha); +} + +MaxXValidation::~MaxXValidation() +{ +} + +void MaxXValidation::compute_measures(SGVector& measures, SGVector& term_counters) +{ + const size_t num_kernels=kernel_mgr.num_kernels(); + for (size_t i=0; iset_kernel(kernel); + bool rejected=estimator->compute_p_value(estimator->compute_statistic())cleanup(); + } +} + +CKernel* MaxXValidation::select_kernel() +{ + auto& dm=estimator->get_data_manager(); + dm.set_xvalidation_mode(true); + auto existing_kernel=estimator->get_kernel(); + + const index_t N=dm.get_num_folds(); + // TODO write a more meaningful error message + REQUIRE(N!=0, "Number of folds is not set!\n"); + SG_SINFO("Performing %d fold cross-validattion!\n", N); + // train mode is already ON by now! 
set by the caller + SGVector measures(kernel_mgr.num_kernels()); + std::fill(measures.data(), measures.data()+measures.size(), 0); + SGVector term_counters(measures.size()); + std::fill(term_counters.data(), term_counters.data()+term_counters.size(), 1); + for (auto i=0; iset_kernel(existing_kernel); + dm.set_xvalidation_mode(false); + + auto min_element=std::min_element(measures.vector, measures.vector+measures.vlen); + auto min_idx=std::distance(measures.vector, min_element); + SG_SDEBUG("Selected kernel at %d position!\n", min_idx); + return kernel_mgr.kernel_at(min_idx); +} diff --git a/src/shogun/statistical_testing/internals/MaxXValidation.h b/src/shogun/statistical_testing/internals/MaxXValidation.h new file mode 100644 index 00000000000..a24068ac1d4 --- /dev/null +++ b/src/shogun/statistical_testing/internals/MaxXValidation.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (W) 2013 Heiko Strathmann + * Written (w) 2014 - 2016 Soumyajit De + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#ifndef MAX_XVALIDATIN_H__ +#define MAX_XVALIDATIN_H__ + +#include +#include + +namespace shogun +{ + +class CKernel; +class CMMD; +template class SGVector; + +namespace internal +{ + +class MaxXValidation : public KernelSelection +{ +public: + MaxXValidation(KernelManager&, CMMD*, const index_t&, const float64_t&); + MaxXValidation(const MaxXValidation& other)=delete; + ~MaxXValidation(); + MaxXValidation& operator=(const MaxXValidation& other)=delete; + virtual CKernel* select_kernel() override; +protected: + void compute_measures(SGVector&, SGVector&); + CMMD* estimator; + const index_t num_run; + const float64_t alpha; +}; + +} + +} + +#endif // MAX_XVALIDATIN_H__