From 48d50b95b09da8a10ffb8859f092f6f3247b6b5b Mon Sep 17 00:00:00 2001 From: lambday Date: Fri, 1 Apr 2016 02:29:04 +0530 Subject: [PATCH] removed statistic computation API for multiple kernels --- src/shogun/statistical_testing/MMD.cpp | 172 ++++++++----------------- src/shogun/statistical_testing/MMD.h | 21 ++- 2 files changed, 68 insertions(+), 125 deletions(-) diff --git a/src/shogun/statistical_testing/MMD.cpp b/src/shogun/statistical_testing/MMD.cpp index 2cb462d62fd..4b4452b9a5b 100644 --- a/src/shogun/statistical_testing/MMD.cpp +++ b/src/shogun/statistical_testing/MMD.cpp @@ -49,7 +49,7 @@ struct CMMD::Self void create_statistic_job(index_t Bx); void create_variance_job(index_t Bx); - std::pair, SGVector> compute_statistic_variance(); + std::pair compute_statistic_variance(); CMMD& owner; @@ -120,47 +120,19 @@ void CMMD::Self::create_variance_job(index_t Bx) }; } -std::pair, SGVector> CMMD::Self::compute_statistic_variance() +std::pair CMMD::Self::compute_statistic_variance() { DataManager& dm = owner.get_data_manager(); const KernelManager& km = owner.get_kernel_manager(); - SGVector statistic; - SGVector stat_perm; - SGVector variance; + float64_t statistic = 0; + float64_t permuted_samples_statistic = 0; + float64_t variance = 0; auto kernel = km.kernel_at(0); - ASSERT(kernel != nullptr); - auto num_kernels = 1; + REQUIRE(kernel != nullptr, "Kernel is not set!\n"); - std::vector kernels; - - if (kernel->get_kernel_type() == K_COMBINED) - { - auto combined_kernel = static_cast(kernel); - num_kernels = combined_kernel->get_num_subkernels(); - - kernels = std::vector(num_kernels); - for (auto i = 0; i < num_kernels; ++i) - { - kernels[i] = combined_kernel->get_kernel(i); - } - } - else - { - kernels.push_back(kernel); - } - - statistic = SGVector(num_kernels); - stat_perm = SGVector(num_kernels); - variance = SGVector(num_kernels); - - std::fill(statistic.vector, statistic.vector + statistic.vlen, 0); - std::fill(stat_perm.vector, stat_perm.vector + stat_perm.vlen, 0); - std::fill(variance.vector, variance.vector + variance.vlen, 0); - - std::vector term_counters(statistic.vlen); - std::fill(term_counters.data(), term_counters.data() + term_counters.size(), 1); + index_t term_counters = 1; ComputationManager cm; dm.start(); @@ -176,7 +148,7 @@ std::pair, SGVector> CMMD::Self::compute_statisti blocks.resize(next_burst.num_blocks()); #pragma omp parallel for - for (auto i = 0; i < next_burst.num_blocks(); ++i) + for (size_t i = 0; i < blocks.size(); ++i) { auto block_p = next_burst[0][i]; auto block_q = next_burst[1][i]; @@ -197,65 +169,62 @@ std::pair, SGVector> CMMD::Self::compute_statisti blocks[i] = std::shared_ptr(block_p_q, [](CFeatures* ptr) { SG_UNREF(ptr); }); } - for (auto i = 0; i < kernels.size(); ++i) - { #pragma omp parallel for - for (auto j = 0; j < blocks.size(); ++j) - { - try - { - auto curr_kernel = std::unique_ptr(static_cast(kernels[i]->clone())); - curr_kernel->init(blocks[j].get(), blocks[j].get()); - cm.data(j) = std::unique_ptr(new CCustomKernel(curr_kernel.get()))->get_kernel_matrix(); - curr_kernel->remove_lhs_and_rhs(); - } - catch (ShogunException e) - { - SG_SERROR("%s, Try using less number of blocks per burst!\n", e.get_exception_string()); - } - } - - // enqueue statistic and variance computation jobs on the computed kernel matrices - cm.enqueue_job(statistic_job); - cm.enqueue_job(variance_job); - - if (use_gpu_for_computation) + for (size_t i = 0; i < blocks.size(); ++i) + { + try { - cm.use_gpu().compute(); + auto kernel_clone = std::unique_ptr(static_cast(kernel->clone())); + kernel_clone->init(blocks[i].get(), blocks[i].get()); + cm.data(i) = std::unique_ptr(new CCustomKernel(kernel_clone.get()))->get_kernel_matrix(); + kernel_clone->remove_lhs_and_rhs(); } - else + catch (ShogunException e) { - cm.use_cpu().compute(); + SG_SERROR("%s, Try using less number of blocks per burst!\n", e.get_exception_string()); } + } - auto mmds = cm.next_result(); - auto vars = cm.next_result(); + // enqueue statistic and variance computation jobs on the computed kernel matrices + cm.enqueue_job(statistic_job); + cm.enqueue_job(variance_job); - for (auto j = 0; j < mmds.size(); ++j) - { - auto delta = mmds[j] - statistic[i]; - statistic[i] += delta / term_counters[i]; - } + if (use_gpu_for_computation) + { + cm.use_gpu().compute(); + } + else + { + cm.use_cpu().compute(); + } - if (variance_estimation_method == EVarianceEstimationMethod::DIRECT) + auto mmds = cm.next_result(); + auto vars = cm.next_result(); + + for (size_t i = 0; i < mmds.size(); ++i) + { + auto delta = mmds[i] - statistic; + statistic += delta / term_counters; + } + + if (variance_estimation_method == EVarianceEstimationMethod::DIRECT) + { + for (size_t i = 0; i < mmds.size(); ++i) { - for (auto j = 0; j < mmds.size(); ++j) - { - auto delta = vars[j] - variance[i]; - variance[i] += delta / term_counters[i]; - } + auto delta = vars[i] - variance; + variance += delta / term_counters; } - else + } + else + { + for (size_t i = 0; i < mmds.size(); ++i) { - for (auto j = 0; j < mmds.size(); ++j) - { - auto delta = vars[j] - stat_perm[i]; - stat_perm[i] += delta / term_counters[i]; - variance[i] += delta * (vars[j] - stat_perm[i]); - } + auto delta = vars[i] - permuted_samples_statistic; + permuted_samples_statistic += delta / term_counters; + variance += delta * (vars[i] - permuted_samples_statistic); } - term_counters[i]++; } + term_counters++; next_burst = dm.next(); } @@ -263,17 +232,11 @@ std::pair, SGVector> CMMD::Self::compute_statisti dm.end(); // normalize statistic and variance - std::for_each(statistic.vector, statistic.vector + statistic.vlen, [this](float64_t& v) - { - v = owner.normalize_statistic(v); - }); + statistic = owner.normalize_statistic(statistic); if (variance_estimation_method == EVarianceEstimationMethod::PERMUTATION) { - std::for_each(variance.vector, variance.vector + variance.vlen, [this](float64_t& v) - { - v = owner.normalize_variance(v); - }); + variance = owner.normalize_variance(variance); } return std::make_pair(statistic, variance); @@ -290,33 +253,11 @@ CMMD::~CMMD() float64_t CMMD::compute_statistic() { - return self->compute_statistic_variance().first[0]; -} - -float64_t CMMD::compute_variance() -{ - return self->compute_statistic_variance().second[0]; -} - -SGVector CMMD::compute_statistic(bool multiple_kernels) -{ - if (multiple_kernels) - { - const KernelManager& km = get_kernel_manager(); - auto kernel = km.kernel_at(0); - ASSERT(kernel->get_kernel_type() == K_COMBINED); - } return self->compute_statistic_variance().first; } -SGVector CMMD::compute_variance(bool multiple_kernels) +float64_t CMMD::compute_variance() { - if (multiple_kernels) - { - const KernelManager& km = get_kernel_manager(); - auto kernel = km.kernel_at(0); - ASSERT(kernel->get_kernel_type() == K_COMBINED); - } return self->compute_statistic_variance().second; } @@ -348,11 +289,6 @@ void CMMD::use_gpu(bool gpu) self->use_gpu_for_computation = gpu; } -void CMMD::set_simulate_null(bool simulate_null) -{ - self->simulate_null = simulate_null; -} - void CMMD::set_statistic_type(EStatisticType stype) { self->statistic_type = stype; diff --git a/src/shogun/statistical_testing/MMD.h b/src/shogun/statistical_testing/MMD.h index dc26e01d64d..c18876545f2 100644 --- a/src/shogun/statistical_testing/MMD.h +++ b/src/shogun/statistical_testing/MMD.h @@ -51,18 +51,26 @@ enum class ENullApproximationMethod MMD2_GAMMA }; +enum class EKernelSelectionMethod +{ + MEDIAN_HEURISRIC, + MAXIMIZE_MMD, + MAXIMIZE_POWER +}; + class CMMD : public CTwoSampleTest { using operation = std::function)>; public: CMMD(); virtual ~CMMD(); - +/* + void add_kernel(CKernel *kernel); + void select_kernel(EKernelSelectionMethod kmethod); + CKernel* get_kernel() const; +*/ virtual float64_t compute_statistic() override; - SGVector compute_statistic(bool multiple_kernels); - float64_t compute_variance(); - SGVector compute_variance(bool multiple_kernels); void set_statistic_type(EStatisticType stype); const EStatisticType get_statistic_type() const; @@ -70,15 +78,14 @@ class CMMD : public CTwoSampleTest void set_variance_estimation_method(EVarianceEstimationMethod vmethod); const EVarianceEstimationMethod get_variance_estimation_method() const; - void set_simulate_null(bool simulate_null); void set_num_null_samples(index_t null_samples); const index_t get_num_null_samples() const; - virtual SGVector sample_null() override; - void set_null_approximation_method(ENullApproximationMethod nmethod); const ENullApproximationMethod get_null_approximation_method() const; + virtual SGVector sample_null() override; + void use_gpu(bool gpu); virtual const char* get_name() const;