Skip to content

Commit

Permalink
refactored train/test subsetting and hypothesis test framework
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 1, 2016
1 parent cef312c commit cfa8dcb
Show file tree
Hide file tree
Showing 31 changed files with 655 additions and 617 deletions.
20 changes: 10 additions & 10 deletions src/shogun/statistical_testing/BTestMMD.cpp
Expand Up @@ -37,12 +37,12 @@ CBTestMMD::~CBTestMMD()

/**
 * Sets the blocksize by forwarding the value to the data manager.
 *
 * @param blocksize value passed unchanged to the data manager's
 * set_blocksize
 */
void CBTestMMD::set_blocksize(index_t blocksize)
{
	get_data_mgr().set_blocksize(blocksize);
}

/**
 * Sets the number of blocks per burst by forwarding the value to the
 * data manager.
 *
 * @param num_blocks_per_burst value passed unchanged to the data manager's
 * set_num_blocks_per_burst
 */
void CBTestMMD::set_num_blocks_per_burst(index_t num_blocks_per_burst)
{
	get_data_mgr().set_num_blocks_per_burst(num_blocks_per_burst);
}

const std::function<float32_t(SGMatrix<float32_t>)> CBTestMMD::get_direct_estimation_method() const
Expand All @@ -52,19 +52,19 @@ const std::function<float32_t(SGMatrix<float32_t>)> CBTestMMD::get_direct_estima

/**
 * Normalizes a statistic using the sample counts and blocksizes reported
 * by the data manager.
 *
 * With Nx, Ny = num_samples_at(0), num_samples_at(1) and
 * Bx, By = blocksize_at(0), blocksize_at(1), the returned value is
 * Nx*Ny*statistic*sqrt((Bx+By)/(Nx+Ny))/(Nx+Ny).
 *
 * @param statistic the statistic to be normalized
 * @return the normalized statistic
 */
const float64_t CBTestMMD::normalize_statistic(float64_t statistic) const
{
	const DataManager& data_mgr=get_data_mgr();
	const index_t Nx=data_mgr.num_samples_at(0);
	const index_t Ny=data_mgr.num_samples_at(1);
	const index_t Bx=data_mgr.blocksize_at(0);
	const index_t By=data_mgr.blocksize_at(1);
	// Form the Nx*Ny product in floating point: as an integer product it
	// could overflow for large sample counts (assumes index_t is a
	// fixed-width integer type -- TODO confirm).
	return float64_t(Nx)*Ny*statistic*CMath::sqrt((Bx+By)/float64_t(Nx+Ny))/(Nx+Ny);
}

/**
 * Normalizes a variance using the blocksizes reported by the data manager.
 *
 * With Bx, By = blocksize_at(0), blocksize_at(1), the returned value is
 * variance*(Bx*By/(Bx+By))^2.
 *
 * @param variance the variance to be normalized
 * @return the normalized variance
 */
const float64_t CBTestMMD::normalize_variance(float64_t variance) const
{
	const DataManager& data_mgr=get_data_mgr();
	const index_t Bx=data_mgr.blocksize_at(0);
	const index_t By=data_mgr.blocksize_at(1);
	// The Bx*By product is formed in floating point so that it cannot
	// overflow an integer type.
	return variance*CMath::sq(float64_t(Bx)*By/float64_t(Bx+By));
}

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/statistical_testing/HypothesisTest.cpp
Expand Up @@ -64,7 +64,7 @@ void CHypothesisTest::set_train_test_mode(bool on)

/**
 * Sets the train/test ratio by forwarding it to the data manager held in
 * the private implementation.
 *
 * @param ratio value passed unchanged to the data manager's
 * set_train_test_ratio
 */
void CHypothesisTest::set_train_test_ratio(float64_t ratio)
{
	self->data_mgr.set_train_test_ratio(ratio);
}

float64_t CHypothesisTest::compute_p_value(float64_t statistic)
Expand Down
35 changes: 26 additions & 9 deletions src/shogun/statistical_testing/IndependenceTest.cpp
Expand Up @@ -24,8 +24,19 @@
using namespace shogun;
using namespace internal;

CIndependenceTest::CIndependenceTest() : CTwoDistributionTest(IndependenceTest::num_kernels)
struct CIndependenceTest::Self
{
Self(index_t num_kernels);
KernelManager kernel_mgr;
};

/**
 * Builds the implementation state; the kernel manager is constructed
 * from the supplied count.
 *
 * @param num_kernels argument for the kernel manager's constructor
 */
CIndependenceTest::Self::Self(index_t num_kernels)
	: kernel_mgr(num_kernels)
{
}

/**
 * Default constructor. Allocates the private implementation, passing
 * IndependenceTest::num_kernels to its constructor.
 */
CIndependenceTest::CIndependenceTest()
	: CTwoDistributionTest(),
	  self(new Self(IndependenceTest::num_kernels))
{
}

CIndependenceTest::~CIndependenceTest()
Expand All @@ -34,29 +45,35 @@ CIndependenceTest::~CIndependenceTest()

/**
 * Stores the given kernel at index 0 of the kernel manager.
 *
 * @param kernel_p the kernel to store
 */
void CIndependenceTest::set_kernel_p(CKernel* kernel_p)
{
	self->kernel_mgr.kernel_at(0)=kernel_p;
}

/**
 * @return the kernel stored at index 0 of the kernel manager
 */
CKernel* CIndependenceTest::get_kernel_p() const
{
	return self->kernel_mgr.kernel_at(0);
}

/**
 * Stores the given kernel at index 1 of the kernel manager.
 *
 * @param kernel_q the kernel to store
 */
void CIndependenceTest::set_kernel_q(CKernel* kernel_q)
{
	self->kernel_mgr.kernel_at(1)=kernel_q;
}

/**
 * @return the kernel stored at index 1 of the kernel manager
 */
CKernel* CIndependenceTest::get_kernel_q() const
{
	return self->kernel_mgr.kernel_at(1);
}

/**
 * @return the name of this class as a string literal
 */
const char* CIndependenceTest::get_name() const
{
	const char* const class_name="IndependenceTest";
	return class_name;
}

/**
 * @return mutable reference to the kernel manager held by the private
 * implementation
 */
KernelManager& CIndependenceTest::get_kernel_mgr()
{
	Self& impl=*self;
	return impl.kernel_mgr;
}

/**
 * @return read-only reference to the kernel manager held by the private
 * implementation
 */
const KernelManager& CIndependenceTest::get_kernel_mgr() const
{
	const Self& impl=*self;
	return impl.kernel_mgr;
}
16 changes: 14 additions & 2 deletions src/shogun/statistical_testing/IndependenceTest.h
Expand Up @@ -19,13 +19,19 @@
#ifndef INDEPENDENCE_TEST_H_
#define INDEPENDENCE_TEST_H_

#include <memory>
#include <shogun/statistical_testing/TwoDistributionTest.h>

namespace shogun
{

class CKernel;

namespace internal
{
class KernelManager;
}

class CIndependenceTest : public CTwoDistributionTest
{
public:
Expand All @@ -38,10 +44,16 @@ class CIndependenceTest : public CTwoDistributionTest
void set_kernel_q(CKernel* kernel_q);
CKernel* get_kernel_q() const;

/** Pure virtual: subclasses compute and return the test statistic. */
virtual float64_t compute_statistic()=0;

/**
 * Pure virtual: subclasses return a vector of float64_t values
 * (presumably samples of the statistic under the null hypothesis --
 * confirm against the implementations).
 */
virtual SGVector<float64_t> sample_null()=0;

virtual const char* get_name() const;
protected:
internal::KernelManager& get_kernel_mgr();
const internal::KernelManager& get_kernel_mgr() const;
private:
struct Self;
std::unique_ptr<Self> self;
};

}
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/statistical_testing/KernelSelectionStrategy.cpp
Expand Up @@ -222,7 +222,7 @@ const char* CKernelSelectionStrategy::get_name() const
return "KernelSelectionStrategy";
}

/**
 * @return read-only reference to the kernel manager held by the private
 * implementation
 */
const KernelManager& CKernelSelectionStrategy::get_kernel_mgr() const
{
	return self->kernel_mgr;
}
2 changes: 1 addition & 1 deletion src/shogun/statistical_testing/KernelSelectionStrategy.h
Expand Up @@ -85,7 +85,7 @@ class CKernelSelectionStrategy : public CSGObject
struct Self;
std::unique_ptr<Self> self;
void init();
/** @return read-only reference to the internal kernel manager */
const internal::KernelManager& get_kernel_mgr() const;
};

}
Expand Down
30 changes: 15 additions & 15 deletions src/shogun/statistical_testing/LinearTimeMMD.cpp
Expand Up @@ -44,12 +44,12 @@ CLinearTimeMMD::~CLinearTimeMMD()

void CLinearTimeMMD::set_num_blocks_per_burst(index_t num_blocks_per_burst)
{
auto& dm=get_data_manager();
auto min_blocksize=dm.get_min_blocksize();
auto& data_mgr=get_data_mgr();
auto min_blocksize=data_mgr.get_min_blocksize();
if (min_blocksize==2)
{
// only possible when number of samples from both the distributions are the same
auto N=dm.num_samples_at(0);
auto N=data_mgr.num_samples_at(0);
for (auto i=2; i<N; ++i)
{
if (N%i==0)
Expand All @@ -59,9 +59,9 @@ void CLinearTimeMMD::set_num_blocks_per_burst(index_t num_blocks_per_burst)
}
}
}
dm.set_blocksize(min_blocksize);
dm.set_num_blocks_per_burst(num_blocks_per_burst);
SG_SDEBUG("Block contains %d and %d samples, from P and Q respectively!\n", dm.blocksize_at(0), dm.blocksize_at(1));
data_mgr.set_blocksize(min_blocksize);
data_mgr.set_num_blocks_per_burst(num_blocks_per_burst);
SG_SDEBUG("Block contains %d and %d samples, from P and Q respectively!\n", data_mgr.blocksize_at(0), data_mgr.blocksize_at(1));
}

const std::function<float32_t(SGMatrix<float32_t>)> CLinearTimeMMD::get_direct_estimation_method() const
Expand All @@ -71,17 +71,17 @@ const std::function<float32_t(SGMatrix<float32_t>)> CLinearTimeMMD::get_direct_e

/**
 * Normalizes a statistic using the sample counts reported by the data
 * manager.
 *
 * With Nx, Ny = num_samples_at(0), num_samples_at(1), the returned value
 * is sqrt(Nx*Ny/(Nx+Ny))*statistic.
 *
 * @param statistic the statistic to be normalized
 * @return the normalized statistic
 */
const float64_t CLinearTimeMMD::normalize_statistic(float64_t statistic) const
{
	const DataManager& data_mgr = get_data_mgr();
	const index_t Nx = data_mgr.num_samples_at(0);
	const index_t Ny = data_mgr.num_samples_at(1);
	// Form the Nx*Ny product in floating point: as an integer product it
	// could overflow for large sample counts (assumes index_t is a
	// fixed-width integer type -- TODO confirm).
	return CMath::sqrt(float64_t(Nx) * Ny / float64_t(Nx + Ny)) * statistic;
}

const float64_t CLinearTimeMMD::normalize_variance(float64_t variance) const
{
const DataManager& dm = get_data_manager();
const index_t Bx = dm.blocksize_at(0);
const index_t By = dm.blocksize_at(1);
const DataManager& data_mgr = get_data_mgr();
const index_t Bx = data_mgr.blocksize_at(0);
const index_t By = data_mgr.blocksize_at(1);
const index_t B = Bx + By;
if (get_statistic_type() == ST_UNBIASED_INCOMPLETE)
{
Expand All @@ -92,9 +92,9 @@ const float64_t CLinearTimeMMD::normalize_variance(float64_t variance) const

const float64_t CLinearTimeMMD::gaussian_variance(float64_t variance) const
{
const DataManager& dm = get_data_manager();
const index_t Bx = dm.blocksize_at(0);
const index_t By = dm.blocksize_at(1);
const DataManager& data_mgr = get_data_mgr();
const index_t Bx = data_mgr.blocksize_at(0);
const index_t By = data_mgr.blocksize_at(1);
const index_t B = Bx + By;
if (get_statistic_type() == ST_UNBIASED_INCOMPLETE)
{
Expand Down

0 comments on commit cfa8dcb

Please sign in to comment.