From 4648ea0ee3f5b302f81b0bc81fc6592545cf8bb7 Mon Sep 17 00:00:00 2001
From: lambday
Date: Fri, 1 Apr 2016 03:12:22 +0530
Subject: [PATCH] added permutation test with precomputed kernel matrices

---
Review notes:
 - sample_null(): term_counters[j] is incremented once per block value that
   is folded into statistic[j]. Advancing it only once per burst does not
   yield the mean of the block values when a burst holds several blocks.
 - sample_null(): ShogunException is caught by reference, not by value.
 - NOTE(review): if SG_SERROR throws, the exception would leave the OpenMP
   parallel region -- confirm, and if so report the error after the loop.
 - NOTE(review): the template arguments shown here (float64_t, float32_t,
   index_t, CFeatures, CKernel, CCustomKernel and the mmd:: statistic
   types) are inferred from usage; verify them against MMD.cpp before
   applying. The blob ids on the "index" line belong to the first version
   of this patch.

 src/shogun/statistical_testing/MMD.cpp | 125 +++++++++++++++++++++----
 1 file changed, 108 insertions(+), 17 deletions(-)

diff --git a/src/shogun/statistical_testing/MMD.cpp b/src/shogun/statistical_testing/MMD.cpp
index fecb88666fc..5ab8bab384e 100644
--- a/src/shogun/statistical_testing/MMD.cpp
+++ b/src/shogun/statistical_testing/MMD.cpp
@@ -47,7 +47,7 @@ struct CMMD::Self
 
 	void create_computation_jobs(index_t Bx);
 	void create_statistic_job(index_t Bx);
-	void create_variance_job(index_t Bx);
+	void create_variance_job();
 
 	std::pair<float64_t, float64_t> compute_statistic_variance();
 	SGVector<float64_t> sample_null();
@@ -62,6 +62,7 @@ struct CMMD::Self
 	ENullApproximationMethod null_approximation_method;
 
 	std::function<float32_t(SGMatrix<float32_t>)> statistic_job;
+	std::function<float32_t(SGMatrix<float32_t>)> permutation_job;
 	std::function<float32_t(SGMatrix<float32_t>)> variance_job;
 };
 
@@ -77,7 +78,7 @@ CMMD::Self::Self(CMMD& cmmd) : owner(cmmd),
 void CMMD::Self::create_computation_jobs(index_t Bx)
 {
 	create_statistic_job(Bx);
-	create_variance_job(Bx);
+	create_variance_job();
 }
 
 void CMMD::Self::create_statistic_job(index_t Bx)
@@ -86,18 +87,21 @@ void CMMD::Self::create_statistic_job(index_t Bx)
 	{
 		case EStatisticType::UNBIASED_FULL:
 			statistic_job = mmd::UnbiasedFull(Bx);
+			permutation_job = mmd::WithinBlockPermutation<mmd::UnbiasedFull>(Bx);
 			break;
 		case EStatisticType::UNBIASED_INCOMPLETE:
 			statistic_job = mmd::UnbiasedIncomplete(Bx);
+			permutation_job = mmd::WithinBlockPermutation<mmd::UnbiasedIncomplete>(Bx);
 			break;
 		case EStatisticType::BIASED_FULL:
 			statistic_job = mmd::BiasedFull(Bx);
+			permutation_job = mmd::WithinBlockPermutation<mmd::BiasedFull>(Bx);
 			break;
 		default : break;
 	};
 }
 
-void CMMD::Self::create_variance_job(index_t Bx)
+void CMMD::Self::create_variance_job()
 {
 	switch (variance_estimation_method)
 	{
@@ -105,19 +109,7 @@ void CMMD::Self::create_variance_job(index_t Bx)
 			variance_job = owner.get_direct_estimation_method();
 			break;
 		case EVarianceEstimationMethod::PERMUTATION:
-			switch (statistic_type)
-			{
-				case EStatisticType::UNBIASED_FULL:
-					variance_job = mmd::WithinBlockPermutation<mmd::UnbiasedFull>(Bx);
-					break;
-				case EStatisticType::UNBIASED_INCOMPLETE:
-					variance_job = mmd::WithinBlockPermutation<mmd::UnbiasedIncomplete>(Bx);
-					break;
-				case EStatisticType::BIASED_FULL:
-					variance_job = mmd::WithinBlockPermutation<mmd::BiasedFull>(Bx);
-					break;
-				default : break;
-			}
+			variance_job = permutation_job;
 			break;
 		default : break;
 	};
@@ -240,7 +232,106 @@ std::pair<float64_t, float64_t> CMMD::Self::compute_statistic_variance()
 
 SGVector<float64_t> CMMD::Self::sample_null()
 {
-	return SGVector<float64_t>();
+	// Draws num_null_samples null samples of the statistic: each burst of
+	// blocks is merged and its kernel matrices are precomputed once, then
+	// permutation_job is evaluated on them once per null sample and the
+	// per-block values are averaged into the corresponding sample.
+	DataManager& dm = owner.get_data_manager();
+	const KernelManager& km = owner.get_kernel_manager();
+
+	SGVector<float64_t> statistic(num_null_samples);
+	std::fill(statistic.vector, statistic.vector + statistic.vlen, 0);
+
+	auto kernel = km.kernel_at(0);
+	REQUIRE(kernel != nullptr, "Kernel is not set!\n");
+
+	// 1-based index of the next value to be folded into each running mean
+	std::vector<index_t> term_counters(num_null_samples);
+	std::fill(term_counters.data(), term_counters.data() + term_counters.size(), 1);
+
+	ComputationManager cm;
+	dm.start();
+	auto next_burst = dm.next();
+
+	create_statistic_job(dm.blocksize_at(0));
+
+	std::vector<std::shared_ptr<CFeatures>> blocks;
+
+	while (!next_burst.empty())
+	{
+		cm.num_data(next_burst.num_blocks());
+		blocks.resize(next_burst.num_blocks());
+
+		// merge the i-th p-block and q-block of the burst into one feature object
+#pragma omp parallel for
+		for (size_t i = 0; i < blocks.size(); ++i)
+		{
+			auto block_p = next_burst[0][i];
+			auto block_q = next_burst[1][i];
+
+			auto block_p_q = block_p->create_merged_copy(block_q.get());
+			SG_REF(block_p_q);
+
+			block_p = nullptr;
+			block_q = nullptr;
+
+			blocks[i] = std::shared_ptr<CFeatures>(block_p_q, [](CFeatures* ptr) { SG_UNREF(ptr); });
+		}
+
+		// precompute the kernel matrix of every merged block
+#pragma omp parallel for
+		for (size_t i = 0; i < blocks.size(); ++i)
+		{
+			try
+			{
+				auto kernel_clone = std::unique_ptr<CKernel>(static_cast<CKernel*>(kernel->clone()));
+				kernel_clone->init(blocks[i].get(), blocks[i].get());
+				cm.data(i) = std::unique_ptr<CCustomKernel>(new CCustomKernel(kernel_clone.get()))->get_kernel_matrix();
+				kernel_clone->remove_lhs_and_rhs();
+			}
+			catch (ShogunException& e)
+			{
+				SG_SERROR("%s, Try using less number of blocks per burst!\n", e.get_exception_string());
+			}
+		}
+
+		cm.enqueue_job(permutation_job);
+
+		for (auto j = 0; j < num_null_samples; ++j)
+		{
+			if (use_gpu_for_computation)
+			{
+				cm.use_gpu().compute();
+			}
+			else
+			{
+				cm.use_cpu().compute();
+			}
+
+			auto mmds = cm.next_result();
+
+			// running mean over every block value of null sample j; the
+			// counter has to advance with each value that is folded in
+			for (size_t i = 0; i < mmds.size(); ++i)
+			{
+				auto delta = mmds[i] - statistic[j];
+				statistic[j] += delta / term_counters[j];
+				term_counters[j]++;
+			}
+		}
+
+		next_burst = dm.next();
+	}
+
+	dm.end();
+
+	// normalize statistic
+	std::for_each(statistic.vector, statistic.vector + statistic.vlen, [this](float64_t& value)
+	{
+		value = owner.normalize_statistic(value);
+	});
+
+	return statistic;
 }
 
 CMMD::CMMD() : CTwoSampleTest()