From a9f069381e88461a644d59799028777194afcd02 Mon Sep 17 00:00:00 2001 From: lambday Date: Thu, 7 Jul 2016 16:52:21 +0100 Subject: [PATCH] added multikernel() API in quadratic time MMD --- src/interfaces/modular/Statistics.i | 2 + src/interfaces/modular/Statistics_includes.i | 1 + .../MultiKernelQuadraticTimeMMD.cpp | 203 ++++++++++++++++++ .../MultiKernelQuadraticTimeMMD.h | 88 ++++++++ .../statistical_testing/QuadraticTimeMMD.cpp | 61 +----- .../statistical_testing/QuadraticTimeMMD.h | 9 +- .../mmd/MultiKernelPermutationTest.cpp | 14 +- .../mmd/MultiKernelPermutationTest.h | 4 +- .../kernelselection/internals/MaxMeasure.cpp | 3 +- .../QuadraticTimeMMD_unittest.cc | 11 +- 10 files changed, 316 insertions(+), 80 deletions(-) create mode 100644 src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.cpp create mode 100644 src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.h diff --git a/src/interfaces/modular/Statistics.i b/src/interfaces/modular/Statistics.i index 6c866e6aef8..4aa3ef4a696 100644 --- a/src/interfaces/modular/Statistics.i +++ b/src/interfaces/modular/Statistics.i @@ -21,6 +21,7 @@ %rename(LinearTimeMMD) CLinearTimeMMD; %rename(BTestMMD) CBTestMMD; %rename(QuadraticTimeMMD) CQuadraticTimeMMD; +%rename(MultiKernelQuadraticTimeMMD) CMultiKernelQuadraticTimeMMD; %rename(KernelSelectionStrategy) CKernelSelectionStrategy; /* Include Class Headers to make them visible from within the target language */ @@ -33,4 +34,5 @@ %include %include %include +%include %include diff --git a/src/interfaces/modular/Statistics_includes.i b/src/interfaces/modular/Statistics_includes.i index f70052fb46c..899a1d5ec53 100644 --- a/src/interfaces/modular/Statistics_includes.i +++ b/src/interfaces/modular/Statistics_includes.i @@ -8,6 +8,7 @@ #include #include #include + #include #include %} diff --git a/src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.cpp b/src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.cpp new file mode 100644 index 00000000000..20eb0498e81 --- /dev/null +++ b/src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.cpp @@ -0,0 +1,203 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (w) 2012 - 2013 Heiko Strathmann + * Written (w) 2014 - 2016 Soumyajit De + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace shogun; +using namespace internal; +using namespace mmd; +using std::unique_ptr; + +struct CMultiKernelQuadraticTimeMMD::Self +{ + Self(CQuadraticTimeMMD* owner); + void update_pairwise_distance(CDistance *distance); + + CQuadraticTimeMMD *m_owner; + unique_ptr m_pairwise_distance; + EDistanceType m_dtype; + KernelManager m_kernel_mgr; +}; + +CMultiKernelQuadraticTimeMMD::Self::Self(CQuadraticTimeMMD *owner) : m_owner(owner), + m_pairwise_distance(nullptr), m_dtype(D_UNKNOWN) +{ +} + +void CMultiKernelQuadraticTimeMMD::Self::update_pairwise_distance(CDistance* distance) +{ + ASSERT(distance); + if (m_dtype==distance->get_distance_type()) + { + ASSERT(m_pairwise_distance!=nullptr); + SG_SINFO("Precomputed distance exists for %s!\n", distance->get_name()); + } + else + { + auto precomputed_distance=m_owner->compute_joint_distance(distance); + m_pairwise_distance=unique_ptr(precomputed_distance); + m_dtype=distance->get_distance_type(); + } +} + +CMultiKernelQuadraticTimeMMD::CMultiKernelQuadraticTimeMMD() : CSGObject() +{ + self=unique_ptr(new Self(nullptr)); +} + +CMultiKernelQuadraticTimeMMD::CMultiKernelQuadraticTimeMMD(CQuadraticTimeMMD* owner) : CSGObject() +{ + self=unique_ptr(new Self(owner)); +} + +CMultiKernelQuadraticTimeMMD::~CMultiKernelQuadraticTimeMMD() +{ + cleanup(); +} + +void CMultiKernelQuadraticTimeMMD::add_kernel(CShiftInvariantKernel *kernel) +{ + ASSERT(self->m_owner); + REQUIRE(kernel, "Kernel instance cannot be NULL!\n"); + self->m_kernel_mgr.push_back(kernel); +} + +void CMultiKernelQuadraticTimeMMD::cleanup() +{ + ASSERT(self->m_owner); + self->m_kernel_mgr.clear(); + self->m_pairwise_distance=nullptr; + self->m_dtype=D_UNKNOWN; +} + +SGVector CMultiKernelQuadraticTimeMMD::statistic() +{ + ASSERT(self->m_owner); + return statistic(self->m_kernel_mgr); +} + +SGVector CMultiKernelQuadraticTimeMMD::variance_h0() +{ + ASSERT(self->m_owner); + SG_NOTIMPLEMENTED; + return SGVector(); +} + +SGVector CMultiKernelQuadraticTimeMMD::variance_h1() +{ + ASSERT(self->m_owner); + SG_NOTIMPLEMENTED; + return SGVector(); +} + +SGVector CMultiKernelQuadraticTimeMMD::p_values() +{ + ASSERT(self->m_owner); + return p_values(self->m_kernel_mgr); +} + +SGVector CMultiKernelQuadraticTimeMMD::perform_test(float64_t alpha) +{ + SGVector pvalues=p_values(); + SGVector rejections(pvalues.size()); + for (auto i=0; i CMultiKernelQuadraticTimeMMD::statistic(const KernelManager& kernel_mgr) +{ + SG_DEBUG("Entering"); + REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels (%d) have to be greater than 0!\n", kernel_mgr.num_kernels()); + + const auto nx=self->m_owner->get_num_samples_p(); + const auto ny=self->m_owner->get_num_samples_q(); + const auto stype = self->m_owner->get_statistic_type(); + + CDistance* distance=kernel_mgr.get_distance_instance(); + self->update_pairwise_distance(distance); + kernel_mgr.set_precomputed_distance(self->m_pairwise_distance.get()); + SG_UNREF(distance); + + MultiKernelMMD compute(nx, ny, stype); + SGVector result=compute(kernel_mgr); + + kernel_mgr.unset_precomputed_distance(); + + for (auto i=0; im_owner->normalize_statistic(result[i]); + + SG_DEBUG("Leaving"); + return result; +} + +SGVector CMultiKernelQuadraticTimeMMD::p_values(const KernelManager& kernel_mgr) +{ + SG_DEBUG("Entering"); + REQUIRE(self->m_owner->get_null_approximation_method()==ENullApproximationMethod::NAM_PERMUTATION, + "Multi-kernel tests requires the H0 approximation method to be PERMUTATION!\n"); + + REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels (%d) have to be greater than 0!\n", kernel_mgr.num_kernels()); + + const auto nx=self->m_owner->get_num_samples_p(); + const auto ny=self->m_owner->get_num_samples_q(); + const auto stype = self->m_owner->get_statistic_type(); + const auto num_null_samples = self->m_owner->get_num_null_samples(); + + CDistance* distance=kernel_mgr.get_distance_instance(); + self->update_pairwise_distance(distance); + kernel_mgr.set_precomputed_distance(self->m_pairwise_distance.get()); + SG_UNREF(distance); + + MultiKernelPermutationTest compute(nx, ny, num_null_samples, stype); + SGVector result=compute(kernel_mgr); + + kernel_mgr.unset_precomputed_distance(); + + SG_DEBUG("Leaving"); + return result; +} + +const char* CMultiKernelQuadraticTimeMMD::get_name() const +{ + return "MultiKernelQuadraticTimeMMD"; +} diff --git a/src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.h b/src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.h new file mode 100644 index 00000000000..d85c067c4b8 --- /dev/null +++ b/src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (w) 2012 - 2013 Heiko Strathmann + * Written (w) 2014 - 2016 Soumyajit De + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#ifndef MULTI_KERNEL_QUADRATIC_TIME_MMD_H_ +#define MULTI_KERNEL_QUADRATIC_TIME_MMD_H_ + +#include +#include + +namespace shogun +{ + +class CFeatures; +class CQuadraticTimeMMD; +class CShiftInvariantKernel; +template class SGVector; + +namespace internal +{ +class KernelManager; +class MaxMeasure; +} + +/** + * @brief Class that performs quadratic time MMD test optimized for multiple + * shift-invariant kernels. If the kernels are not shift-invariant, then the + * class CQuadraticTimeMMD should be used multiple times instead of this one. + * + * This implementation assumes that features are never updated. If new features + * are to be used, new instance of this class should be created. + */ +class CMultiKernelQuadraticTimeMMD : public CSGObject +{ + friend class CQuadraticTimeMMD; + friend class internal::MaxMeasure; +private: + CMultiKernelQuadraticTimeMMD(CQuadraticTimeMMD* owner); +public: + CMultiKernelQuadraticTimeMMD(); + virtual ~CMultiKernelQuadraticTimeMMD(); + void add_kernel(CShiftInvariantKernel *kernel); + void cleanup(); + + SGVector statistic(); + SGVector variance_h0(); + SGVector variance_h1(); + + SGVector p_values(); + SGVector perform_test(float64_t alpha); + + virtual const char* get_name() const; +private: + struct Self; + std::unique_ptr self; + SGVector statistic(const internal::KernelManager& kernel_mgr); + SGVector p_values(const internal::KernelManager& kernel_mgr); +}; + +} +#endif // MULTI_KERNEL_QUADRATIC_TIME_MMD_H_ diff --git a/src/shogun/statistical_testing/QuadraticTimeMMD.cpp b/src/shogun/statistical_testing/QuadraticTimeMMD.cpp index d549162e228..3848ec2fa7b 100644 --- a/src/shogun/statistical_testing/QuadraticTimeMMD.cpp +++ b/src/shogun/statistical_testing/QuadraticTimeMMD.cpp @@ -39,6 +39,7 @@ #include #include #include +#include #include #include #include @@ -74,6 +75,7 @@ struct CQuadraticTimeMMD::Self void compute_jobs(ComputationManager&) const; CQuadraticTimeMMD& owner; + std::unique_ptr multi_kernel; index_t num_eigenvalues; bool precompute; bool is_kernel_initialized; @@ -275,11 +277,13 @@ SGVector CQuadraticTimeMMD::Self::sample_null() CQuadraticTimeMMD::CQuadraticTimeMMD() : CMMD() { self=std::unique_ptr(new Self(*this)); + self->multi_kernel=std::unique_ptr(new CMultiKernelQuadraticTimeMMD(this)); } CQuadraticTimeMMD::CQuadraticTimeMMD(CFeatures* samples_from_p, CFeatures* samples_from_q) : CMMD() { self=std::unique_ptr(new Self(*this)); + self->multi_kernel=std::unique_ptr(new CMultiKernelQuadraticTimeMMD(this)); set_p(samples_from_p); set_q(samples_from_q); } @@ -329,16 +333,6 @@ float64_t CQuadraticTimeMMD::compute_variance() return self->compute_statistic_variance().second; } -SGVector CQuadraticTimeMMD::compute_multiple() -{ - return compute_statistic(get_strategy()->get_kernel_mgr()); -} - -SGVector CQuadraticTimeMMD::perform_test_multiple(float64_t alpha) -{ - return perform_test_multiple(get_strategy()->get_kernel_mgr(), alpha); -} - float64_t CQuadraticTimeMMD::compute_p_value(float64_t statistic) { SG_DEBUG("Entering\n"); @@ -557,52 +551,9 @@ void CQuadraticTimeMMD::precompute_kernel_matrix(bool precompute) self->precompute=precompute; } -SGVector CQuadraticTimeMMD::compute_statistic(const internal::KernelManager& kernel_mgr) +CMultiKernelQuadraticTimeMMD* CQuadraticTimeMMD::multikernel() { - SG_DEBUG("Entering"); - REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels (%d) have to be greater than 0!\n", kernel_mgr.num_kernels()); - - CDistance* distance=kernel_mgr.get_distance_instance(); - SG_REF(distance); - kernel_mgr.set_precomputed_distance(compute_joint_distance(distance)); - SG_UNREF(distance); - - const index_t nx=get_num_samples_p(); - const index_t ny=get_num_samples_q(); - - MultiKernelMMD compute(nx, ny, get_statistic_type()); - SGVector result=compute(kernel_mgr); - - kernel_mgr.unset_precomputed_distance(); - - for (auto i=0; i CQuadraticTimeMMD::perform_test_multiple(const internal::KernelManager& kernel_mgr, float64_t alpha) -{ - SG_DEBUG("Entering"); - REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels (%d) have to be greater than 0!\n", kernel_mgr.num_kernels()); - - CDistance* distance=kernel_mgr.get_distance_instance(); - SG_REF(distance); - kernel_mgr.set_precomputed_distance(compute_joint_distance(distance)); - SG_UNREF(distance); - - const index_t nx=get_num_samples_p(); - const index_t ny=get_num_samples_q(); - - MultiKernelPermutationTest compute(nx, ny, get_num_null_samples(), get_statistic_type()); - compute.set_alpha(alpha); - SGVector result=compute(kernel_mgr); - - kernel_mgr.unset_precomputed_distance(); - - SG_DEBUG("Leaving"); - return result; + return self->multi_kernel.get(); } const char* CQuadraticTimeMMD::get_name() const diff --git a/src/shogun/statistical_testing/QuadraticTimeMMD.h b/src/shogun/statistical_testing/QuadraticTimeMMD.h index ae480640aab..c92f2ba60f3 100644 --- a/src/shogun/statistical_testing/QuadraticTimeMMD.h +++ b/src/shogun/statistical_testing/QuadraticTimeMMD.h @@ -39,6 +39,7 @@ namespace shogun { +class CMultiKernelQuadraticTimeMMD; template class SGVector; namespace internal @@ -49,7 +50,7 @@ class MaxMeasure; class CQuadraticTimeMMD : public CMMD { - friend class internal::MaxMeasure; + friend class CMultiKernelQuadraticTimeMMD; public: typedef std::function)> operation; CQuadraticTimeMMD(); @@ -61,8 +62,7 @@ class CQuadraticTimeMMD : public CMMD virtual float64_t compute_statistic(); virtual float64_t compute_variance(); - virtual SGVector compute_multiple(); - virtual SGVector perform_test_multiple(float64_t alpha); + CMultiKernelQuadraticTimeMMD* multikernel(); virtual SGVector sample_null(); void spectrum_set_num_eigenvalues(index_t num_eigenvalues); @@ -82,9 +82,6 @@ class CQuadraticTimeMMD : public CMMD virtual const float64_t normalize_variance(float64_t variance) const; SGVector gamma_fit_null(); SGVector spectrum_sample_null(); - - SGVector compute_statistic(const internal::KernelManager& kernel_mgr); - SGVector perform_test_multiple(const internal::KernelManager& kernel_mgr, float64_t alpha); }; } diff --git a/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTest.cpp b/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTest.cpp index b16d13acf2c..cf44d02fd2c 100644 --- a/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTest.cpp +++ b/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTest.cpp @@ -122,7 +122,7 @@ float64_t MultiKernelPermutationTest::compute_mmd(terms_t& terms) return terms.term[0]+terms.term[1]-2*terms.term[2]; } -SGVector MultiKernelPermutationTest::operator()(const KernelManager& kernel_mgr) +SGVector MultiKernelPermutationTest::operator()(const KernelManager& kernel_mgr) { SG_SDEBUG("Entering!\n"); @@ -135,7 +135,7 @@ SGVector MultiKernelPermutationTest::operator()(const KernelManager& kerne } SGVector null_samples(num_null_samples); - SGVector result(kernel_mgr.num_kernels()); + SGVector result(kernel_mgr.num_kernels()); const index_t size=n_x+n_y; SGVector km(size*(size+1)/2); @@ -177,17 +177,11 @@ SGVector MultiKernelPermutationTest::operator()(const KernelManager& kerne float64_t idx=null_samples.find_position_to_insert(statistic); SG_SDEBUG("Kernel(%d): index=%f\n", k, idx); auto p_value=1.0-idx/num_null_samples; - bool rejected=p_value operator()(const KernelManager& km); - void set_alpha(float64_t alp); + SGVector operator()(const KernelManager& km); private: struct terms_t; @@ -69,7 +68,6 @@ class MultiKernelPermutationTest const index_t n_y; const index_t num_null_samples; const EStatisticType stype; - float64_t alpha; SGVector permuted_inds; std::vector > inverted_permuted_inds; diff --git a/src/shogun/statistical_testing/kernelselection/internals/MaxMeasure.cpp b/src/shogun/statistical_testing/kernelselection/internals/MaxMeasure.cpp index f36b9e5f264..8e2e3b3e1ed 100644 --- a/src/shogun/statistical_testing/kernelselection/internals/MaxMeasure.cpp +++ b/src/shogun/statistical_testing/kernelselection/internals/MaxMeasure.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include @@ -74,7 +75,7 @@ void MaxMeasure::compute_measures() REQUIRE(estimator!=nullptr, "Estimator is not set!\n"); CQuadraticTimeMMD* mmd=dynamic_cast(estimator); if (mmd!=nullptr && kernel_mgr.same_distance_type()) - measures=mmd->compute_statistic(kernel_mgr); + measures=mmd->multikernel()->statistic(kernel_mgr); else { init_measures(); diff --git a/tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc b/tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc index e195427990e..36b6092cfaa 100644 --- a/tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc +++ b/tests/unit/statistical_testing/QuadraticTimeMMD_unittest.cc @@ -38,6 +38,7 @@ #include #include #include +#include #include using namespace shogun; @@ -514,9 +515,9 @@ TEST(QuadraticTimeMMD, compute_multiple) for (auto i=0, sigma=-5; iadd_kernel(new CGaussianKernel(10, tau)); + mmd->multikernel()->add_kernel(new CGaussianKernel(10, tau)); } - SGVector mmd_multiple=mmd->compute_multiple(); + SGVector mmd_multiple=mmd->multikernel()->statistic(); SGVector mmd_single(num_kernels); for (auto i=0, sigma=-5; iadd_kernel(new CGaussianKernel(10, tau)); + mmd->multikernel()->add_kernel(new CGaussianKernel(10, tau)); } sg_rand->set_seed(12345); - SGVector rejections_multiple=mmd->perform_test_multiple(alpha); + SGVector rejections_multiple=mmd->multikernel()->perform_test(alpha); SGVector rejections_single(num_kernels); - sg_rand->set_seed(12345); for (auto i=0, sigma=-5; iset_seed(12345); auto mmd2=some(feat_p, feat_q); float64_t tau=pow(2, sigma); mmd2->set_kernel(new CGaussianKernel(10, tau));