diff --git a/src/shogun/statistical_testing/QuadraticTimeMMD.cpp b/src/shogun/statistical_testing/QuadraticTimeMMD.cpp index acde91a22e2..ec7203a0fef 100644 --- a/src/shogun/statistical_testing/QuadraticTimeMMD.cpp +++ b/src/shogun/statistical_testing/QuadraticTimeMMD.cpp @@ -287,6 +287,12 @@ CQuadraticTimeMMD::~CQuadraticTimeMMD() get_kernel_manager().restore_kernel_at(0); } +void CQuadraticTimeMMD::set_kernel(CKernel* kernel) +{ + CTwoSampleTest::set_kernel(kernel); + self->is_kernel_initialized=false; +} + const std::function)> CQuadraticTimeMMD::get_direct_estimation_method() const { return FullDirect(); diff --git a/src/shogun/statistical_testing/QuadraticTimeMMD.h b/src/shogun/statistical_testing/QuadraticTimeMMD.h index 869534f1d4a..18570381436 100644 --- a/src/shogun/statistical_testing/QuadraticTimeMMD.h +++ b/src/shogun/statistical_testing/QuadraticTimeMMD.h @@ -49,6 +49,7 @@ class CQuadraticTimeMMD : public CMMD CQuadraticTimeMMD(CFeatures* samples_from_p, CFeatures* samples_from_q); virtual ~CQuadraticTimeMMD(); + virtual void set_kernel(CKernel* kernel); virtual float64_t compute_statistic(); virtual float64_t compute_variance(); diff --git a/src/shogun/statistical_testing/TwoSampleTest.cpp b/src/shogun/statistical_testing/TwoSampleTest.cpp index 7d85bbffd3d..a5d98f5a95d 100644 --- a/src/shogun/statistical_testing/TwoSampleTest.cpp +++ b/src/shogun/statistical_testing/TwoSampleTest.cpp @@ -34,13 +34,14 @@ CTwoSampleTest::~CTwoSampleTest() void CTwoSampleTest::set_kernel(CKernel* kernel) { - auto& km = get_kernel_manager(); - km.kernel_at(0) = kernel; + auto& km=get_kernel_manager(); + km.kernel_at(0)=kernel; + km.restore_kernel_at(0); } CKernel* CTwoSampleTest::get_kernel() const { - const auto& km = get_kernel_manager(); + const auto& km=get_kernel_manager(); return km.kernel_at(0); } diff --git a/src/shogun/statistical_testing/TwoSampleTest.h b/src/shogun/statistical_testing/TwoSampleTest.h index 865acdcb0a2..f8008ce72fd 100644 --- a/src/shogun/statistical_testing/TwoSampleTest.h +++ b/src/shogun/statistical_testing/TwoSampleTest.h @@ -32,7 +32,7 @@ class CTwoSampleTest : public CTwoDistributionTest CTwoSampleTest(); virtual ~CTwoSampleTest(); - void set_kernel(CKernel* kernel); + virtual void set_kernel(CKernel* kernel); CKernel* get_kernel() const; virtual float64_t compute_statistic() = 0; diff --git a/src/shogun/statistical_testing/internals/DataFetcher.cpp b/src/shogun/statistical_testing/internals/DataFetcher.cpp index c2cb3693b9a..b770b9e7de1 100644 --- a/src/shogun/statistical_testing/internals/DataFetcher.cpp +++ b/src/shogun/statistical_testing/internals/DataFetcher.cpp @@ -65,8 +65,10 @@ float64_t DataFetcher::get_train_test_ratio() const void DataFetcher::set_train_mode(bool train_mode) { + m_train_test_details.train_mode=train_mode; + // TODO put the following in another methods index_t start_index=0; - if (train_mode) + if (m_train_test_details.train_mode) { m_num_samples=m_train_test_details.get_num_training_samples(); if (m_num_samples==0) @@ -107,18 +109,46 @@ void DataFetcher::set_xvalidation_mode(bool xvalidation_mode) index_t DataFetcher::get_num_folds() const { -// REQUIRE(fetchers[0]!=nullptr, "Please set the samples first!\n"); -// return fetchers[0]->get_train_test_ratio(); - return 0; + return 1+ceil(get_train_test_ratio()); } void DataFetcher::use_fold(index_t idx) { -// using fetcher_type=std::unique_ptr; -// std::for_each(fetchers.begin(), fetchers.end(), [&train_mode](fetcher_type& f) -// { -// f->use_fold(idx); -// }); + auto num_folds=get_num_folds(); + REQUIRE(idx>=0, "The index (%d) has to be between 0 and %d, both inclusive!\n", idx, num_folds-1); + REQUIRE(idxremove_subset(); + + SGVector inds; + auto start_idx=idx*num_per_fold; + auto num_samples=0; + + if (m_train_test_details.train_mode) + { + num_samples=m_train_test_details.get_num_training_samples(); + inds=SGVector(num_samples); + std::iota(inds.data(), inds.data()+inds.size(), 0); + if (start_idx(num_samples); + std::iota(inds.data(), inds.data()+inds.size(), start_idx); + m_samples->add_subset(inds); + } + inds.display_vector("inds"); + m_samples->add_subset(inds); } void DataFetcher::set_blockwise(bool blockwise) diff --git a/src/shogun/statistical_testing/internals/DataManager.cpp b/src/shogun/statistical_testing/internals/DataManager.cpp index 1a3a7db588e..6ba4b9629fa 100644 --- a/src/shogun/statistical_testing/internals/DataManager.cpp +++ b/src/shogun/statistical_testing/internals/DataManager.cpp @@ -256,7 +256,7 @@ void DataManager::set_xvalidation_mode(bool xvalidation_mode) index_t DataManager::get_num_folds() const { REQUIRE(fetchers[0]!=nullptr, "Please set the samples first!\n"); - return fetchers[0]->get_train_test_ratio(); + return fetchers[0]->get_num_folds(); } void DataManager::use_fold(index_t idx) diff --git a/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc b/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc new file mode 100644 index 00000000000..1b4ea1f29a4 --- /dev/null +++ b/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc @@ -0,0 +1,84 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (W) 2012-2013 Heiko Strathmann + * Written (w) 2016 Soumyajit De + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace shogun; + +TEST(KernelSelectionMaxXValidation, single_kernel) +{ + const index_t m=10; + const index_t n=15; + const index_t dim=2; + const float64_t difference=0.5; + const index_t num_kernels=10; + +// sg_io->set_loglevel(MSG_DEBUG); +// sg_io->set_location_info(MSG_FUNCTION); + + // use fixed seed + sg_rand->set_seed(12345); + + // streaming data generator for mean shift distributions + auto gen_p=new CMeanShiftDataGenerator(0, dim, 0); + auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0); + + auto feats_p=gen_p->get_streamed_features(m); + auto feats_q=gen_q->get_streamed_features(n); + + SG_UNREF(gen_p); + SG_UNREF(gen_q); + + // create MMD instance, convienience constructor + auto mmd=some(feats_p, feats_q); + mmd->set_statistic_type(ST_BIASED_FULL); + + for (auto i=0; iadd_kernel(new CGaussianKernel(10, sq_sigma_twice)); + } + + mmd->select_kernel(KSM_MAXIMIZE_XVALIDATION, false, 4, 1, 0.05); + auto selected_kernel=static_cast(mmd->get_kernel()); + EXPECT_NEAR(selected_kernel->get_width(), 0.5, 1E-10); +}