From b779d7d9e47b1614220dd6cdad4ceff26f26cdfb Mon Sep 17 00:00:00 2001 From: lambday Date: Wed, 20 Apr 2016 04:12:55 +0530 Subject: [PATCH] updated streaming mmd to work with blocks --- src/shogun/statistical_testing/MMD.cpp | 81 +++++++++++++++++--------- src/shogun/statistical_testing/MMD.h | 25 +++++--- 2 files changed, 72 insertions(+), 34 deletions(-) diff --git a/src/shogun/statistical_testing/MMD.cpp b/src/shogun/statistical_testing/MMD.cpp index 030b0f0a062..1882f8f7169 100644 --- a/src/shogun/statistical_testing/MMD.cpp +++ b/src/shogun/statistical_testing/MMD.cpp @@ -28,7 +28,6 @@ * either expressed or implied, of the Shogun Development Team. */ -#include #include #include #include @@ -42,8 +41,11 @@ #include #include #include +#include #include #include +#include +#include #include #include #include @@ -61,8 +63,8 @@ struct CMMD::Self void create_variance_job(); void create_computation_jobs(); - void merge_samples(NextSamples&, std::vector>&) const; - void compute_kernel(ComputationManager&, std::vector>&, CKernel*) const; + void merge_samples(NextSamples&, std::vector&) const; + void compute_kernel(ComputationManager&, std::vector&, CKernel*) const; void compute_jobs(ComputationManager&) const; std::pair compute_statistic_variance(); @@ -134,41 +136,32 @@ void CMMD::Self::create_variance_job() }; } -#define get_block_p(i) next_burst[0][i] -#define get_block_q(i) next_burst[1][i] -void CMMD::Self::merge_samples(NextSamples& next_burst, std::vector>& blocks) const +void CMMD::Self::merge_samples(NextSamples& next_burst, std::vector& blocks) const { blocks.resize(next_burst.num_blocks()); - #pragma omp parallel for - for (size_t i = 0; i < blocks.size(); ++i) + for (size_t i=0; icreate_merged_copy(block_q.get()); - SG_REF(block_p_and_q); - - block_p=nullptr; - block_q=nullptr; - - blocks[i]=std::shared_ptr(block_p_and_q, [](CFeatures* ptr) { SG_UNREF(ptr); }); + auto block_p=next_burst[0][i].get(); + auto block_q=next_burst[1][i].get(); + auto block_p_and_q=FeaturesUtil::create_merged_copy(block_p, block_q); + blocks[i]=block_p_and_q; } + next_burst.clear(); } -#undef get_block_p -#undef get_block_q -void CMMD::Self::compute_kernel(ComputationManager& cm, std::vector>& blocks, CKernel* kernel) const +void CMMD::Self::compute_kernel(ComputationManager& cm, std::vector& blocks, CKernel* kernel) const { - REQUIRE(kernel->get_kernel_type()!=K_CUSTOM, "Underlying kernel cannot be custom for streaming test!\n"); + REQUIRE(kernel->get_kernel_type()!=K_CUSTOM, "Underlying kernel cannot be custom!\n"); cm.num_data(blocks.size()); + const auto& dm=owner.get_data_manager(); #pragma omp parallel for for (size_t i=0; i(static_cast(kernel->clone())); - kernel_clone->init(blocks[i].get(), blocks[i].get()); + kernel_clone->init(blocks[i], blocks[i]); cm.data(i)=kernel_clone->get_kernel_matrix(); kernel_clone->remove_lhs_and_rhs(); } @@ -177,6 +170,7 @@ void CMMD::Self::compute_kernel(ComputationManager& cm, std::vector CMMD::Self::compute_statistic_variance() cm.enqueue_job(statistic_job); cm.enqueue_job(variance_job); - std::vector> blocks; + std::vector blocks; dm.start(); auto next_burst=dm.next(); - while (!next_burst.empty()) { merge_samples(next_burst, blocks); @@ -278,7 +271,7 @@ SGVector CMMD::Self::sample_null() create_statistic_job(); cm.enqueue_job(permutation_job); - std::vector> blocks; + std::vector blocks; dm.start(); auto next_burst=dm.next(); @@ -328,6 +321,31 @@ void CMMD::add_kernel(CKernel* kernel) self->kernel_selection_mgr.push_back(kernel); } +void CMMD::select_kernel(EKernelSelectionMethod kmethod) +{ + SG_DEBUG("Entering!\n"); + SG_DEBUG("Selecting kernels from a total of %d kernels!\n", self->kernel_selection_mgr.num_kernels()); + switch (kmethod) + { + case EKernelSelectionMethod::MAXIMIZE_MMD: + { + MaxMeasure policy(self->kernel_selection_mgr, this); + get_kernel_manager().kernel_at(0)=policy.select_kernel(); + break; + } + case EKernelSelectionMethod::OPTIMIZE_MMD: + { + OptMeasure policy(self->kernel_selection_mgr, this); + get_kernel_manager().kernel_at(0)=policy.select_kernel(); + break; + } + default: + SG_ERROR("Unsupported kernel selection method specified!\n"); + break; + } + SG_DEBUG("Leaving!\n"); +} + float64_t CMMD::compute_statistic() { return self->compute_statistic_variance().first; @@ -338,6 +356,11 @@ float64_t CMMD::compute_variance() return self->compute_statistic_variance().second; } +std::pair CMMD::compute_statistic_variance() +{ + return self->compute_statistic_variance(); +} + SGVector CMMD::sample_null() { return self->sample_null(); @@ -363,6 +386,12 @@ bool CMMD::use_gpu() const return self->use_gpu; } +void CMMD::cleanup() +{ + for (size_t i=0; istatistic_type=stype; diff --git a/src/shogun/statistical_testing/MMD.h b/src/shogun/statistical_testing/MMD.h index 9d62df93ee3..5282e5b4fa6 100644 --- a/src/shogun/statistical_testing/MMD.h +++ b/src/shogun/statistical_testing/MMD.h @@ -31,6 +31,7 @@ #ifndef MMD_H_ #define MMD_H_ +#include #include #include #include @@ -42,6 +43,13 @@ class CKernel; template class SGVector; template class SGMatrix; +namespace internal +{ + +class OptMeasure; + +} + enum class EStatisticType { UNBIASED_FULL, @@ -67,24 +75,29 @@ enum class EKernelSelectionMethod { MEDIAN_HEURISRIC, MAXIMIZE_MMD, + OPTIMIZE_MMD, MAXIMIZE_POWER }; class CMMD : public CTwoSampleTest { using operation=std::function)>; + friend class internal::OptMeasure; public: CMMD(); virtual ~CMMD(); void add_kernel(CKernel *kernel); -/* void select_kernel(EKernelSelectionMethod kmethod); - CKernel* get_kernel() const; -*/ + virtual float64_t compute_statistic() override; virtual float64_t compute_variance(); + virtual SGVector sample_null() override; + + void use_gpu(bool gpu); + void cleanup(); + void set_statistic_type(EStatisticType stype); const EStatisticType get_statistic_type() const; @@ -97,10 +110,6 @@ class CMMD : public CTwoSampleTest void set_null_approximation_method(ENullApproximationMethod nmethod); const ENullApproximationMethod get_null_approximation_method() const; - virtual SGVector sample_null() override; - - void use_gpu(bool gpu); - virtual const char* get_name() const; protected: virtual const operation get_direct_estimation_method() const = 0; @@ -110,7 +119,7 @@ class CMMD : public CTwoSampleTest private: struct Self; std::unique_ptr self; - + virtual std::pair compute_statistic_variance(); }; }