Skip to content

Commit

Permalink
more refactoring and internals cleanup for MMD
Browse files Browse the repository at this point in the history
 - removed BiasedFull, UnbiasedFull and UnbiasedIncomplete classes
 - different statistic types are taken care of in ComputeMMD class
  • Loading branch information
lambday committed Jul 7, 2016
1 parent bcc9cf5 commit 810e4d9
Show file tree
Hide file tree
Showing 19 changed files with 142 additions and 420 deletions.
101 changes: 49 additions & 52 deletions src/shogun/statistical_testing/MMD.cpp
Expand Up @@ -45,9 +45,7 @@
#include <shogun/statistical_testing/internals/FeaturesUtil.h>
#include <shogun/statistical_testing/internals/KernelManager.h>
#include <shogun/statistical_testing/internals/ComputationManager.h>
#include <shogun/statistical_testing/internals/mmd/BiasedFull.h>
#include <shogun/statistical_testing/internals/mmd/UnbiasedFull.h>
#include <shogun/statistical_testing/internals/mmd/UnbiasedIncomplete.h>
#include <shogun/statistical_testing/internals/mmd/ComputeMMD.h>
#include <shogun/statistical_testing/internals/mmd/WithinBlockDirect.h>
#include <shogun/statistical_testing/internals/mmd/WithinBlockPermutation.h>
#include <shogun/mathematics/eigen3.h>
Expand Down Expand Up @@ -108,21 +106,19 @@ void CMMD::Self::create_computation_jobs()
void CMMD::Self::create_statistic_job()
{
const DataManager& data_mgr=owner.get_data_mgr();

auto Bx=data_mgr.blocksize_at(0);
auto By=data_mgr.blocksize_at(1);
switch (statistic_type)
{
case EStatisticType::ST_UNBIASED_FULL:
statistic_job=mmd::UnbiasedFull(Bx);
break;
case EStatisticType::ST_UNBIASED_INCOMPLETE:
statistic_job=mmd::UnbiasedIncomplete(Bx);
break;
case EStatisticType::ST_BIASED_FULL:
statistic_job=mmd::BiasedFull(Bx);
break;
default : break;
};

REQUIRE(Bx>0, "Blocksize for samples from P cannot be 0!\n");
REQUIRE(By>0, "Blocksize for samples from Q cannot be 0!\n");

auto mmd=mmd::ComputeMMD();
mmd.m_n_x=Bx;
mmd.m_n_y=By;
mmd.m_stype=statistic_type;

statistic_job=mmd;
permutation_job=mmd::WithinBlockPermutation(Bx, By, statistic_type);
}

Expand Down Expand Up @@ -196,57 +192,58 @@ std::pair<float64_t, float64_t> CMMD::Self::compute_statistic_variance()
index_t variance_term_counter=1;

DataManager& data_mgr=owner.get_data_mgr();
ComputationManager cm;

create_computation_jobs();
cm.enqueue_job(statistic_job);
cm.enqueue_job(variance_job);

std::vector<CFeatures*> blocks;

data_mgr.start();
auto next_burst=data_mgr.next();
while (!next_burst.empty())
if (!next_burst.empty())
{
merge_samples(next_burst, blocks);
compute_kernel(cm, blocks, kernel);
blocks.resize(0);
compute_jobs(cm);
ComputationManager cm;
create_computation_jobs();
cm.enqueue_job(statistic_job);
cm.enqueue_job(variance_job);

auto mmds=cm.result(0);
auto vars=cm.result(1);
std::vector<CFeatures*> blocks;

for (size_t i=0; i<mmds.size(); ++i)
while (!next_burst.empty())
{
auto delta=mmds[i]-statistic;
statistic+=delta/statistic_term_counter;
statistic_term_counter++;
}
merge_samples(next_burst, blocks);
compute_kernel(cm, blocks, kernel);
blocks.resize(0);
compute_jobs(cm);

auto mmds=cm.result(0);
auto vars=cm.result(1);

if (variance_estimation_method==EVarianceEstimationMethod::VEM_DIRECT)
{
for (size_t i=0; i<mmds.size(); ++i)
{
auto delta=vars[i]-variance;
variance+=delta/variance_term_counter;
variance_term_counter++;
auto delta=mmds[i]-statistic;
statistic+=delta/statistic_term_counter;
statistic_term_counter++;
}
}
else
{
for (size_t i=0; i<vars.size(); ++i)

if (variance_estimation_method==EVarianceEstimationMethod::VEM_DIRECT)
{
auto delta=vars[i]-permuted_samples_statistic;
permuted_samples_statistic+=delta/variance_term_counter;
variance+=delta*(vars[i]-permuted_samples_statistic);
variance_term_counter++;
for (size_t i=0; i<mmds.size(); ++i)
{
auto delta=vars[i]-variance;
variance+=delta/variance_term_counter;
variance_term_counter++;
}
}
else
{
for (size_t i=0; i<vars.size(); ++i)
{
auto delta=vars[i]-permuted_samples_statistic;
permuted_samples_statistic+=delta/variance_term_counter;
variance+=delta*(vars[i]-permuted_samples_statistic);
variance_term_counter++;
}
}
next_burst=data_mgr.next();
}
next_burst=data_mgr.next();
cm.done();
}

data_mgr.end();
cm.done();

// normalize statistic and variance
statistic=owner.normalize_statistic(statistic);
Expand Down
33 changes: 9 additions & 24 deletions src/shogun/statistical_testing/QuadraticTimeMMD.cpp
Expand Up @@ -46,9 +46,7 @@
#include <shogun/statistical_testing/internals/DataManager.h>
#include <shogun/statistical_testing/internals/KernelManager.h>
#include <shogun/statistical_testing/internals/ComputationManager.h>
#include <shogun/statistical_testing/internals/mmd/BiasedFull.h>
#include <shogun/statistical_testing/internals/mmd/UnbiasedFull.h>
#include <shogun/statistical_testing/internals/mmd/UnbiasedIncomplete.h>
#include <shogun/statistical_testing/internals/mmd/ComputeMMD.h>
#include <shogun/statistical_testing/internals/mmd/FullDirect.h>
#include <shogun/statistical_testing/internals/mmd/WithinBlockPermutationBatch.h>
#include <shogun/statistical_testing/internals/mmd/MultiKernelMMD.h>
Expand Down Expand Up @@ -92,36 +90,24 @@ CQuadraticTimeMMD::Self::Self(CQuadraticTimeMMD& mmd) : owner(mmd), num_eigenval

void CQuadraticTimeMMD::Self::create_computation_jobs()
{
SG_SDEBUG("Entering\n");
create_statistic_job();
create_variance_job();
SG_SDEBUG("Leaving\n");
}

void CQuadraticTimeMMD::Self::create_statistic_job()
{
SG_SDEBUG("Entering\n");
const DataManager& data_mgr=owner.get_data_mgr();
auto Nx=data_mgr.num_samples_at(0);
switch (owner.get_statistic_type())
{
case EStatisticType::ST_UNBIASED_FULL:
statistic_job=UnbiasedFull(Nx);
break;
case EStatisticType::ST_UNBIASED_INCOMPLETE:
statistic_job=UnbiasedIncomplete(Nx);
break;
case EStatisticType::ST_BIASED_FULL:
statistic_job=BiasedFull(Nx);
break;
default : break;
};
SG_SDEBUG("Leaving\n");
auto mmd=mmd::ComputeMMD();
REQUIRE(owner.get_num_samples_p()>0, "Number of samples from P cannot be 0!\n");
REQUIRE(owner.get_num_samples_q()>0, "Number of samples from Q cannot be 0!\n");

mmd.m_n_x=owner.get_num_samples_p();
mmd.m_n_y=owner.get_num_samples_q();
mmd.m_stype=owner.get_statistic_type();
statistic_job=mmd;
}

void CQuadraticTimeMMD::Self::create_variance_job()
{
SG_SDEBUG("Entering\n");
switch (owner.get_variance_estimation_method())
{
case EVarianceEstimationMethod::VEM_DIRECT:
Expand All @@ -132,7 +118,6 @@ void CQuadraticTimeMMD::Self::create_variance_job()
break;
default : break;
};
SG_SDEBUG("Leaving\n");
}

void CQuadraticTimeMMD::Self::compute_jobs(ComputationManager& cm) const
Expand Down
Expand Up @@ -18,7 +18,6 @@

#include <shogun/lib/SGMatrix.h>
#include <shogun/statistical_testing/internals/ComputationManager.h>
#include <shogun/statistical_testing/internals/mmd/UnbiasedFull.h>

using namespace shogun;
using namespace internal;
Expand Down
43 changes: 0 additions & 43 deletions src/shogun/statistical_testing/internals/mmd/BiasedFull.cpp

This file was deleted.

53 changes: 0 additions & 53 deletions src/shogun/statistical_testing/internals/mmd/BiasedFull.h

This file was deleted.

0 comments on commit 810e4d9

Please sign in to comment.