Skip to content

Commit

Permalink
refactor API (incomplete)
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 3, 2016
1 parent 186ec36 commit 8d66ed9
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 72 deletions.
104 changes: 63 additions & 41 deletions src/shogun/statistical_testing/HypothesisTest.cpp
@@ -1,89 +1,111 @@
/*
* Restructuring Shogun's statistical hypothesis testing framework.
* Copyright (C) 2016 Soumyajit De
* Copyright (c) The Shogun Machine Learning Toolbox
* Written (w) 2012 - 2013 Heiko Strathmann
* Written (w) 2014 - 2016 Soumyajit De
* All rights reserved.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* The views and conclusions contained in the software and documentation are those
* of the authors and should not be interpreted as representing official policies,
* either expressed or implied, of the Shogun Development Team.
*/

#include <algorithm>
#include <shogun/lib/SGVector.h>
#include <shogun/mathematics/Math.h>
#include <shogun/statistical_testing/HypothesisTest.h>
#include <shogun/statistical_testing/internals/TestTypes.h>
#include <shogun/statistical_testing/internals/DataManager.h>
#include <shogun/statistical_testing/internals/KernelManager.h>

using namespace shogun;
using namespace internal;

struct CHypothesisTest::Self
{
Self(index_t num_distributions, index_t num_kernels)
: data_manager(num_distributions), kernel_manager(num_kernels)
{
}
DataManager data_manager;
KernelManager kernel_manager;
explicit Self(index_t num_distributions);
DataManager data_mgr;
};

CHypothesisTest::CHypothesisTest(index_t num_distributions, index_t num_kernels) : CSGObject()
{
self = std::unique_ptr<Self>(new CHypothesisTest::Self(num_distributions, num_kernels));
}

CHypothesisTest::~CHypothesisTest()
CHypothesisTest::Self::Self(index_t num_distributions) : data_mgr(num_distributions)
{
}

DataManager& CHypothesisTest::get_data_manager()
CHypothesisTest::CHypothesisTest(index_t num_distributions) : CSGObject()
{
return self->data_manager;
self=std::unique_ptr<Self>(new CHypothesisTest::Self(num_distributions));
}

const DataManager& CHypothesisTest::get_data_manager() const
CHypothesisTest::~CHypothesisTest()
{
return self->data_manager;
}

KernelManager& CHypothesisTest::get_kernel_manager()
void CHypothesisTest::set_train_test_mode(bool on)
{
return self->kernel_manager;
self->data_mgr.set_train_test_mode(on);
}

const KernelManager& CHypothesisTest::get_kernel_manager() const
/**
 * Forwards the train/test ratio to the underlying data manager.
 *
 * @param ratio the train/test ratio handed to DataManager::set_train_test_ratio()
 */
void CHypothesisTest::set_train_test_ratio(float64_t ratio)
{
return self->kernel_manager;
// DataManager exposes this as set_train_test_ratio(); train_test_ratio itself
// is a private data member and cannot be called.
self->data_mgr.set_train_test_ratio(ratio);
}

float64_t CHypothesisTest::compute_p_value(float64_t statistic)
{
SGVector<float64_t> values = sample_null();

SGVector<float64_t> values=sample_null();
std::sort(values.vector, values.vector + values.vlen);
float64_t i = values.find_position_to_insert(statistic);

return 1.0 - i / values.vlen;
float64_t i=values.find_position_to_insert(statistic);
return 1.0-i/values.vlen;
}

float64_t CHypothesisTest::compute_threshold(float64_t alpha)
{
SGVector<float64_t> values = sample_null();
SGVector<float64_t> values=sample_null();
std::sort(values.vector, values.vector + values.vlen);
return values[index_t(CMath::floor(values.vlen * (1 - alpha)))];
return values[index_t(CMath::floor(values.vlen*(1-alpha)))];
}

/**
 * Performs the test at the given level: computes the test statistic via
 * compute_statistic(), turns it into a p-value via compute_p_value() and
 * compares that p-value against alpha.
 *
 * @param alpha threshold the p-value is compared against
 * @return true if the p-value is strictly smaller than alpha, false otherwise
 */
bool CHypothesisTest::perform_test(float64_t alpha)
{
	const float64_t p_value=compute_p_value(compute_statistic());
	return p_value<alpha;
}

/**
 * @return the name of this class as a C string, "HypothesisTest"
 */
const char* CHypothesisTest::get_name() const
{
	const char* name="HypothesisTest";
	return name;
}

/**
 * Cloning is not supported for hypothesis tests: this reports an error via
 * SG_ERROR and never produces a copy.
 *
 * @return nullptr (presumably only reached if SG_ERROR returns control to the
 * caller - TODO confirm against the SG_ERROR definition)
 */
CSGObject* CHypothesisTest::clone()
{
	CSGObject* copy=nullptr;
	SG_ERROR("Cloning CHypothesisTest instances is not supported!\n");
	return copy;
}

/**
 * @return mutable reference to the data manager held by the private
 * implementation struct
 */
DataManager& CHypothesisTest::get_data_mgr()
{
	Self& impl=*self;
	return impl.data_mgr;
}

/**
 * @return read-only reference to the data manager held by the private
 * implementation struct
 */
const DataManager& CHypothesisTest::get_data_mgr() const
{
	const Self& impl=*self;
	return impl.data_mgr;
}
60 changes: 38 additions & 22 deletions src/shogun/statistical_testing/HypothesisTest.h
@@ -1,19 +1,32 @@
/*
* Restructuring Shogun's statistical hypothesis testing framework.
* Copyright (C) 2016 Soumyajit De
* Copyright (c) The Shogun Machine Learning Toolbox
* Written (w) 2012 - 2013 Heiko Strathmann
* Written (w) 2014 - 2016 Soumyajit De
* All rights reserved.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* The views and conclusions contained in the software and documentation are those
* of the authors and should not be interpreted as representing official policies,
* either expressed or implied, of the Shogun Development Team.
*/

#ifndef HYPOTHESIS_TEST_H_
Expand All @@ -32,33 +45,36 @@ namespace internal
{

class DataManager;
class KernelManager;

}

class CHypothesisTest : public CSGObject
{
public:
CHypothesisTest(index_t num_distributions, index_t num_kernels);
explicit CHypothesisTest(index_t num_distributions);
virtual ~CHypothesisTest();

virtual float64_t compute_statistic() = 0;
CHypothesisTest(const CHypothesisTest& other)=delete;
CHypothesisTest& operator=(const CHypothesisTest& other)=delete;

void set_train_test_mode(bool on);
void set_train_test_ratio(float64_t ratio);

virtual float64_t compute_p_value(float64_t statistic);
virtual float64_t compute_threshold(float64_t alpha);
virtual bool perform_test(float64_t alpha);

virtual SGVector<float64_t> sample_null() = 0;
virtual float64_t compute_statistic()=0;
virtual SGVector<float64_t> sample_null()=0;

virtual const char* get_name() const;
virtual CSGObject* clone();
protected:
internal::DataManager& get_data_mgr();
const internal::DataManager& get_data_mgr() const;
private:
struct Self;
std::unique_ptr<Self> self;
protected:
internal::DataManager& get_data_manager();
const internal::DataManager& get_data_manager() const;

internal::KernelManager& get_kernel_manager();
const internal::KernelManager& get_kernel_manager() const;
};

}
Expand Down
10 changes: 5 additions & 5 deletions src/shogun/statistical_testing/TwoDistributionTest.h
Expand Up @@ -32,9 +32,6 @@ class CTwoDistributionTest : public CHypothesisTest
CTwoDistributionTest(index_t num_kernels);
virtual ~CTwoDistributionTest();

void set_p(CFeatures* samples_from_p);
void set_q(CFeatures* samples_from_q);

CFeatures* get_p() const;
CFeatures* get_q() const;

Expand All @@ -44,10 +41,13 @@ class CTwoDistributionTest : public CHypothesisTest
const index_t get_num_samples_p() const;
const index_t get_num_samples_q() const;

virtual float64_t compute_statistic() = 0;
virtual SGVector<float64_t> sample_null() = 0;
virtual float64_t compute_statistic()=0;
virtual SGVector<float64_t> sample_null()=0;

virtual const char* get_name() const;
protected:
void set_p(CFeatures* samples_from_p);
void set_q(CFeatures* samples_from_q);
};

}
Expand Down
11 changes: 11 additions & 0 deletions src/shogun/statistical_testing/internals/DataManager.cpp
Expand Up @@ -49,6 +49,7 @@ DataManager::DataManager(size_t num_distributions)
SG_SDEBUG("Data manager instance initialized with %d data sources!\n", num_distributions);
fetchers.resize(num_distributions);
std::fill(fetchers.begin(), fetchers.end(), nullptr);
train_test_mode=default_train_test_mode;
}

DataManager::~DataManager()
Expand Down Expand Up @@ -324,3 +325,13 @@ void DataManager::reset()
std::for_each(fetchers.begin(), fetchers.end(), [](fetcher_type& f) { f->reset(); });
SG_SDEBUG("Leaving!\n");
}

void DataManager::set_train_test_mode(bool on)
{
train_test_mode_on=on;
}

/**
 * Stores the train/test ratio. The value is taken as given; no range check is
 * performed here.
 *
 * @param ratio new value of the train/test ratio
 */
void DataManager::set_train_test_ratio(float64_t ratio)
{
	this->train_test_ratio=ratio;
}
21 changes: 18 additions & 3 deletions src/shogun/statistical_testing/internals/DataManager.h
Expand Up @@ -232,6 +232,14 @@ class DataManager
*/
index_t get_num_folds() const;

/**
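* Permutes the feature vectors. Useful for cross-validation set-up.
* TODO: complete this description and enable the two declarations below, which
* are currently still inside this comment.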
*
void shuffle_features()
void unshuffle_features()
*/

/**
* @param idx The index of the fold in X-validation scenario, has to be within the range of
* \f$[0, N)\f$, where N is the number of folds as returned by get_num_folds() method.
Expand All @@ -257,11 +265,18 @@ class DataManager
* Resets the fetchers to the initial states.
*/
void reset();

void set_train_test_mode(bool on);
void set_train_test_ratio(float64_t ratio);

bool is_train_test_mode() const;
float64_t get_train_test_ratio() const;
private:
/**
* The internal data fetcher instances.
*/
std::vector<std::unique_ptr<DataFetcher> > fetchers;
bool train_test_mode;
float64_t train_test_ratio;
constexpr static bool default_train_test_mode=false;
constexpr static float64_t default_train_test_ratio=1.0;
};

}
Expand Down
3 changes: 2 additions & 1 deletion src/shogun/statistical_testing/internals/MaxXValidation.cpp
Expand Up @@ -88,7 +88,7 @@ void MaxXValidation::compute_measures()
auto existing_kernel=estimator->get_kernel();
for (auto i=0; i<num_run; ++i)
{
// TODO set permutation beforehand
dm.shuffle_features();
for (auto j=0; j<N; ++j)
{
dm.use_fold(j);
Expand All @@ -100,6 +100,7 @@ void MaxXValidation::compute_measures()
estimator->cleanup();
}
}
dm.unshuffle_features();
}
dm.set_xvalidation_mode(false);
estimator->set_kernel(existing_kernel);
Expand Down

0 comments on commit 8d66ed9

Please sign in to comment.