Skip to content

Commit

Permalink
made data-splitting work with cross-validation
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday committed Jul 13, 2016
1 parent c8a6e15 commit 989fbc1
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 14 deletions.
6 changes: 6 additions & 0 deletions src/shogun/statistical_testing/QuadraticTimeMMD.cpp
Expand Up @@ -287,6 +287,12 @@ CQuadraticTimeMMD::~CQuadraticTimeMMD()
get_kernel_manager().restore_kernel_at(0);
}

void CQuadraticTimeMMD::set_kernel(CKernel* kernel)
{
CTwoSampleTest::set_kernel(kernel);
self->is_kernel_initialized=false;
}

const std::function<float32_t(SGMatrix<float32_t>)> CQuadraticTimeMMD::get_direct_estimation_method() const
{
return FullDirect();
Expand Down
1 change: 1 addition & 0 deletions src/shogun/statistical_testing/QuadraticTimeMMD.h
Expand Up @@ -49,6 +49,7 @@ class CQuadraticTimeMMD : public CMMD
CQuadraticTimeMMD(CFeatures* samples_from_p, CFeatures* samples_from_q);

virtual ~CQuadraticTimeMMD();
virtual void set_kernel(CKernel* kernel);

virtual float64_t compute_statistic();
virtual float64_t compute_variance();
Expand Down
7 changes: 4 additions & 3 deletions src/shogun/statistical_testing/TwoSampleTest.cpp
Expand Up @@ -34,13 +34,14 @@ CTwoSampleTest::~CTwoSampleTest()

void CTwoSampleTest::set_kernel(CKernel* kernel)
{
auto& km = get_kernel_manager();
km.kernel_at(0) = kernel;
auto& km=get_kernel_manager();
km.kernel_at(0)=kernel;
km.restore_kernel_at(0);
}

CKernel* CTwoSampleTest::get_kernel() const
{
const auto& km = get_kernel_manager();
const auto& km=get_kernel_manager();
return km.kernel_at(0);
}

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/statistical_testing/TwoSampleTest.h
Expand Up @@ -32,7 +32,7 @@ class CTwoSampleTest : public CTwoDistributionTest
CTwoSampleTest();
virtual ~CTwoSampleTest();

void set_kernel(CKernel* kernel);
virtual void set_kernel(CKernel* kernel);
CKernel* get_kernel() const;

virtual float64_t compute_statistic() = 0;
Expand Down
48 changes: 39 additions & 9 deletions src/shogun/statistical_testing/internals/DataFetcher.cpp
Expand Up @@ -65,8 +65,10 @@ float64_t DataFetcher::get_train_test_ratio() const

void DataFetcher::set_train_mode(bool train_mode)
{
m_train_test_details.train_mode=train_mode;
// TODO put the following in another methods
index_t start_index=0;
if (train_mode)
if (m_train_test_details.train_mode)
{
m_num_samples=m_train_test_details.get_num_training_samples();
if (m_num_samples==0)
Expand Down Expand Up @@ -107,18 +109,46 @@ void DataFetcher::set_xvalidation_mode(bool xvalidation_mode)

index_t DataFetcher::get_num_folds() const
{
// REQUIRE(fetchers[0]!=nullptr, "Please set the samples first!\n");
// return fetchers[0]->get_train_test_ratio();
return 0;
return 1+ceil(get_train_test_ratio());
}

void DataFetcher::use_fold(index_t idx)
{
// using fetcher_type=std::unique_ptr<DataFetcher>;
// std::for_each(fetchers.begin(), fetchers.end(), [&train_mode](fetcher_type& f)
// {
// f->use_fold(idx);
// });
auto num_folds=get_num_folds();
REQUIRE(idx>=0, "The index (%d) has to be between 0 and %d, both inclusive!\n", idx, num_folds-1);
REQUIRE(idx<num_folds, "The index (%d) has to be between 0 and %d, both inclusive!\n", idx, num_folds-1);

auto num_per_fold=m_train_test_details.get_total_num_samples()/num_folds;

if (train_test_subset_used)
m_samples->remove_subset();

SGVector<index_t> inds;
auto start_idx=idx*num_per_fold;
auto num_samples=0;

if (m_train_test_details.train_mode)
{
num_samples=m_train_test_details.get_num_training_samples();
inds=SGVector<index_t>(num_samples);
std::iota(inds.data(), inds.data()+inds.size(), 0);
if (start_idx<inds.size())
{
std::for_each(inds.data()+start_idx, inds.data()+inds.size(), [&num_per_fold](index_t& val)
{
val+=num_per_fold;
});
}
}
else
{
num_samples=m_train_test_details.get_num_test_samples();
inds=SGVector<index_t>(num_samples);
std::iota(inds.data(), inds.data()+inds.size(), start_idx);
m_samples->add_subset(inds);
}
inds.display_vector("inds");
m_samples->add_subset(inds);
}

void DataFetcher::set_blockwise(bool blockwise)
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/statistical_testing/internals/DataManager.cpp
Expand Up @@ -256,7 +256,7 @@ void DataManager::set_xvalidation_mode(bool xvalidation_mode)
index_t DataManager::get_num_folds() const
{
REQUIRE(fetchers[0]!=nullptr, "Please set the samples first!\n");
return fetchers[0]->get_train_test_ratio();
return fetchers[0]->get_num_folds();
}

void DataManager::use_fold(index_t idx)
Expand Down
@@ -0,0 +1,84 @@
/*
* Copyright (c) The Shogun Machine Learning Toolbox
* Written (W) 2012-2013 Heiko Strathmann
* Written (w) 2016 Soumyajit De
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* The views and conclusions contained in the software and documentation are those
* of the authors and should not be interpreted as representing official policies,
* either expressed or implied, of the Shogun Development Team.
*/

#include <shogun/base/some.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/kernel/CombinedKernel.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/features/streaming/generators/MeanShiftDataGenerator.h>
#include <shogun/mathematics/Statistics.h>
#include <shogun/mathematics/eigen3.h>
#include <shogun/mathematics/Math.h>
#include <shogun/statistical_testing/QuadraticTimeMMD.h>
#include <gtest/gtest.h>

using namespace shogun;

TEST(KernelSelectionMaxXValidation, single_kernel)
{
const index_t m=10;
const index_t n=15;
const index_t dim=2;
const float64_t difference=0.5;
const index_t num_kernels=10;

// sg_io->set_loglevel(MSG_DEBUG);
// sg_io->set_location_info(MSG_FUNCTION);

// use fixed seed
sg_rand->set_seed(12345);

// streaming data generator for mean shift distributions
auto gen_p=new CMeanShiftDataGenerator(0, dim, 0);
auto gen_q=new CMeanShiftDataGenerator(difference, dim, 0);

auto feats_p=gen_p->get_streamed_features(m);
auto feats_q=gen_q->get_streamed_features(n);

SG_UNREF(gen_p);
SG_UNREF(gen_q);

// create MMD instance, convienience constructor
auto mmd=some<CQuadraticTimeMMD>(feats_p, feats_q);
mmd->set_statistic_type(ST_BIASED_FULL);

for (auto i=0; i<num_kernels; ++i)
{
// shoguns kernel width is different
float64_t sigma=(i+1)*0.5;
float64_t sq_sigma_twice=sigma*sigma*2;
mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice));
}

mmd->select_kernel(KSM_MAXIMIZE_XVALIDATION, false, 4, 1, 0.05);
auto selected_kernel=static_cast<CGaussianKernel*>(mmd->get_kernel());
EXPECT_NEAR(selected_kernel->get_width(), 0.5, 1E-10);
}

0 comments on commit 989fbc1

Please sign in to comment.