Skip to content

Commit

Permalink
added permutation test with precomputed kernel matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 3, 2016
1 parent debc0c4 commit b75450f
Showing 1 changed file with 100 additions and 17 deletions.
117 changes: 100 additions & 17 deletions src/shogun/statistical_testing/MMD.cpp
Expand Up @@ -47,7 +47,7 @@ struct CMMD::Self

void create_computation_jobs(index_t Bx);
void create_statistic_job(index_t Bx);
void create_variance_job(index_t Bx);
void create_variance_job();

std::pair<float64_t, float64_t> compute_statistic_variance();
SGVector<float64_t> sample_null();
Expand All @@ -62,6 +62,7 @@ struct CMMD::Self
ENullApproximationMethod null_approximation_method;

std::function<float64_t(SGMatrix<float64_t>)> statistic_job;
std::function<float64_t(SGMatrix<float64_t>)> permutation_job;
std::function<float64_t(SGMatrix<float64_t>)> variance_job;
};

Expand All @@ -77,7 +78,7 @@ CMMD::Self::Self(CMMD& cmmd) : owner(cmmd),
/**
 * Sets up the callable jobs used by the computation routines.
 *
 * @param Bx forwarded unchanged to create_statistic_job()
 */
void CMMD::Self::create_computation_jobs(index_t Bx)
{
	// The statistic job has to be created first: create_statistic_job() also
	// assigns permutation_job, which create_variance_job() copies into
	// variance_job when the PERMUTATION variance estimation method is selected.
	create_statistic_job(Bx);

	// create_variance_job() no longer takes the block size; the former
	// one-argument call was a leftover of the old signature and is dropped.
	create_variance_job();
}

void CMMD::Self::create_statistic_job(index_t Bx)
Expand All @@ -86,38 +87,29 @@ void CMMD::Self::create_statistic_job(index_t Bx)
{
case EStatisticType::UNBIASED_FULL:
statistic_job = mmd::UnbiasedFull(Bx);
permutation_job = mmd::WithinBlockPermutation<mmd::UnbiasedFull>(Bx);
break;
case EStatisticType::UNBIASED_INCOMPLETE:
statistic_job = mmd::UnbiasedIncomplete(Bx);
permutation_job = mmd::WithinBlockPermutation<mmd::UnbiasedIncomplete>(Bx);
break;
case EStatisticType::BIASED_FULL:
statistic_job = mmd::BiasedFull(Bx);
permutation_job = mmd::WithinBlockPermutation<mmd::BiasedFull>(Bx);
break;
default : break;
};
}

void CMMD::Self::create_variance_job(index_t Bx)
void CMMD::Self::create_variance_job()
{
switch (variance_estimation_method)
{
case EVarianceEstimationMethod::DIRECT:
variance_job = owner.get_direct_estimation_method();
break;
case EVarianceEstimationMethod::PERMUTATION:
switch (statistic_type)
{
case EStatisticType::UNBIASED_FULL:
variance_job = mmd::WithinBlockPermutation<mmd::UnbiasedFull>(Bx);
break;
case EStatisticType::UNBIASED_INCOMPLETE:
variance_job = mmd::WithinBlockPermutation<mmd::UnbiasedIncomplete>(Bx);
break;
case EStatisticType::BIASED_FULL:
variance_job = mmd::WithinBlockPermutation<mmd::BiasedFull>(Bx);
break;
default : break;
}
variance_job = permutation_job;
break;
default : break;
};
Expand Down Expand Up @@ -240,7 +232,98 @@ std::pair<float64_t, float64_t> CMMD::Self::compute_statistic_variance()

/**
 * Computes num_null_samples values of the statistic using permutation_job on
 * kernel matrices that are computed once per block.
 *
 * For every burst of blocks delivered by the data manager, the two blocks of
 * each pair are merged, the first kernel of the kernel manager is evaluated on
 * the merged block, and the resulting kernel matrix is handed to the
 * computation manager. permutation_job is then run num_null_samples times on
 * these matrices, and each run's per-block results are folded into a running
 * mean for that null sample.
 *
 * @return vector of length num_null_samples; every entry is the running mean
 * over all processed blocks, passed through owner.normalize_statistic()
 */
SGVector<float64_t> CMMD::Self::sample_null()
{
	DataManager& dm = owner.get_data_manager();
	const KernelManager& km = owner.get_kernel_manager();

	SGVector<float64_t> statistic(num_null_samples);
	std::fill(statistic.vector, statistic.vector + statistic.vlen, 0);

	auto kernel = km.kernel_at(0);
	REQUIRE(kernel != nullptr, "Kernel is not set!\n");

	// term_counters[j] is the 1-based index of the next term that gets folded
	// into the running mean statistic[j]
	std::vector<index_t> term_counters(num_null_samples);
	std::fill(term_counters.data(), term_counters.data() + term_counters.size(), 1);

	ComputationManager cm;
	dm.start();
	auto next_burst = dm.next();

	// (re)creates statistic_job and permutation_job for the current block size
	create_statistic_job(owner.get_data_manager().blocksize_at(0));

	std::vector<std::shared_ptr<CFeatures>> blocks;

	while (!next_burst.empty())
	{
		cm.num_data(next_burst.num_blocks());
		blocks.resize(next_burst.num_blocks());

		// merge the i-th block of both samples into a single feature object
#pragma omp parallel for
		for (size_t i = 0; i < blocks.size(); ++i)
		{
			auto block_p = next_burst[0][i];
			auto block_q = next_burst[1][i];

			auto block_p_q = block_p->create_merged_copy(block_q.get());
			SG_REF(block_p_q);

			// release the individual blocks as soon as the merged copy exists
			block_p = nullptr;
			block_q = nullptr;

			// the custom deleter pairs the SG_REF above with an SG_UNREF
			blocks[i] = std::shared_ptr<CFeatures>(block_p_q, [](CFeatures* ptr) { SG_UNREF(ptr); });
		}

		// compute the kernel matrix of every merged block exactly once
#pragma omp parallel for
		for (size_t i = 0; i < blocks.size(); ++i)
		{
			try
			{
				auto kernel_clone = std::unique_ptr<CKernel>(static_cast<CKernel*>(kernel->clone()));
				kernel_clone->init(blocks[i].get(), blocks[i].get());
				cm.data(i) = std::unique_ptr<CCustomKernel>(new CCustomKernel(kernel_clone.get()))->get_kernel_matrix();
				kernel_clone->remove_lhs_and_rhs();
			}
			// caught by reference to avoid copying the exception object; kept
			// non-const because the constness of get_exception_string() is not
			// visible from this file
			catch (ShogunException& e)
			{
				SG_SERROR("%s, Try using less number of blocks per burst!\n", e.get_exception_string());
			}
		}

		cm.enqueue_job(permutation_job);

		for (auto j = 0; j < num_null_samples; ++j)
		{
			if (use_gpu_for_computation)
			{
				cm.use_gpu().compute();
			}
			else
			{
				cm.use_cpu().compute();
			}

			auto mmds = cm.next_result();

			for (size_t i = 0; i < mmds.size(); ++i)
			{
				// incremental mean: mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k.
				// The counter must advance with every folded term, otherwise a
				// burst containing more than one block yields a wrong mean.
				auto delta = mmds[i] - statistic[j];
				statistic[j] += delta / term_counters[j];
				term_counters[j]++;
			}
		}

		next_burst = dm.next();
	}

	dm.end();

	// normalize statistic
	std::for_each(statistic.vector, statistic.vector + statistic.vlen, [this](float64_t& value)
	{
		value = owner.normalize_statistic(value);
	});

	return statistic;
}

CMMD::CMMD() : CTwoSampleTest()
Expand Down

0 comments on commit b75450f

Please sign in to comment.