From 37f34f6716cf4488dee492cd9961fc8d3a480892 Mon Sep 17 00:00:00 2001 From: lambday Date: Thu, 30 Jun 2016 18:00:59 +0100 Subject: [PATCH] added multi kernel permutation test cross validation --- src/shogun/kernel/ShiftInvariantKernel.h | 7 +- .../statistical_testing/QuadraticTimeMMD.cpp | 36 +-- .../internals/KernelManager.cpp | 53 ++++ .../internals/KernelManager.h | 5 + .../internals/mmd/MultiKernelMMD.cpp | 22 +- ...tiKernelPermutationTestCrossValidation.cpp | 278 ++++++++++++++++++ ...ultiKernelPermutationTestCrossValidation.h | 91 ++++++ .../internals/MaxCrossValidation.cpp | 35 ++- 8 files changed, 466 insertions(+), 61 deletions(-) create mode 100644 src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.cpp create mode 100644 src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.h diff --git a/src/shogun/kernel/ShiftInvariantKernel.h b/src/shogun/kernel/ShiftInvariantKernel.h index c310ef91e77..3f75646a45d 100644 --- a/src/shogun/kernel/ShiftInvariantKernel.h +++ b/src/shogun/kernel/ShiftInvariantKernel.h @@ -41,10 +41,7 @@ namespace shogun namespace internal { - namespace mmd - { - class MultiKernelMMD; - } + class KernelManager; } /** @brief Base class for the family of kernel functions that only depend on @@ -58,7 +55,7 @@ namespace internal class CShiftInvariantKernel: public CKernel { - friend class internal::mmd::MultiKernelMMD; + friend class internal::KernelManager; public: /** Default constructor. */ diff --git a/src/shogun/statistical_testing/QuadraticTimeMMD.cpp b/src/shogun/statistical_testing/QuadraticTimeMMD.cpp index 762f9957fff..25acff7bf8a 100644 --- a/src/shogun/statistical_testing/QuadraticTimeMMD.cpp +++ b/src/shogun/statistical_testing/QuadraticTimeMMD.cpp @@ -35,8 +35,7 @@ #include #include #include -#include -#include +#include #include #include #include @@ -550,39 +549,22 @@ void CQuadraticTimeMMD::precompute_kernel_matrix(bool precompute) SGVector CQuadraticTimeMMD::compute_statistic(const internal::KernelManager& kernel_mgr) { SG_DEBUG("Entering"); - REQUIRE(kernel_mgr.same_distance_type(), "The kernels have to have same distance type!\n"); REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels (%d) have to be greater than 0!\n", kernel_mgr.num_kernels()); - const auto& data_mgr=get_data_mgr(); - const index_t nx=data_mgr.num_samples_at(0); - const index_t ny=data_mgr.num_samples_at(1); - MultiKernelMMD compute(nx, ny, get_statistic_type()); - - CDistance* distance=nullptr; - CShiftInvariantKernel* kernel_0=dynamic_cast(kernel_mgr.kernel_at(0)); - REQUIRE(kernel_0, "Kernel (%s) must be of CShiftInvariantKernel type!\n", kernel_mgr.kernel_at(0)->get_name()); - if (kernel_0->get_distance_type()==D_EUCLIDEAN) - { - auto euclidean_distance=new CEuclideanDistance(); - euclidean_distance->set_disable_sqrt(true); - distance=euclidean_distance; - } - else if (kernel_0->get_distance_type()==D_MANHATTAN) - { - auto manhattan_distance=new CManhattanMetric(); - distance=manhattan_distance; - } - else - { - SG_ERROR("Unsupported distance type!\n"); - } + CDistance* distance=kernel_mgr.get_distance_instance(); SG_REF(distance); + + const index_t nx=get_num_samples_p(); + const index_t ny=get_num_samples_q(); + + MultiKernelMMD compute(nx, ny, get_statistic_type()); compute.set_distance(compute_joint_distance(distance)); SGVector result=compute(kernel_mgr); - SG_UNREF(distance); for (auto i=0; i #include #include +#include +#include #include #include @@ -162,3 +164,54 @@ bool KernelManager::same_distance_type() const } return same; } + +CDistance* KernelManager::get_distance_instance() const +{ + REQUIRE(same_distance_type(), "Distance types for all the kernels are not the same!\n"); + + CDistance* distance=nullptr; + CShiftInvariantKernel* kernel_0=dynamic_cast(kernel_at(0)); + REQUIRE(kernel_0, "Kernel (%s) must be of CShiftInvariantKernel type!\n", kernel_at(0)->get_name()); + if (kernel_0->get_distance_type()==D_EUCLIDEAN) + { + auto euclidean_distance=new CEuclideanDistance(); + euclidean_distance->set_disable_sqrt(true); + distance=euclidean_distance; + } + else if (kernel_0->get_distance_type()==D_MANHATTAN) + { + auto manhattan_distance=new CManhattanMetric(); + distance=manhattan_distance; + } + else + { + SG_SERROR("Unsupported distance type!\n"); + } + return distance; +} + +void KernelManager::set_precomputed_distance(CCustomDistance* distance) const +{ + for (size_t i=0; i(kernel); + REQUIRE(shift_inv_kernel!=nullptr, "Kernel instance (was %s) must be of CShiftInvarintKernel type!\n", kernel->get_name()); + shift_inv_kernel->m_precomputed_distance=distance; + shift_inv_kernel->num_lhs=distance->get_num_vec_lhs(); + shift_inv_kernel->num_rhs=distance->get_num_vec_rhs(); + } +} + +void KernelManager::unset_precomputed_distance() const +{ + for (size_t i=0; i(kernel); + REQUIRE(shift_inv_kernel!=nullptr, "Kernel instance (was %s) must be of CShiftInvarintKernel type!\n", kernel->get_name()); + shift_inv_kernel->m_precomputed_distance=nullptr; + shift_inv_kernel->num_lhs=0; + shift_inv_kernel->num_rhs=0; + } +} diff --git a/src/shogun/statistical_testing/internals/KernelManager.h b/src/shogun/statistical_testing/internals/KernelManager.h index 1faaef303b4..ee9f817a8e4 100644 --- a/src/shogun/statistical_testing/internals/KernelManager.h +++ b/src/shogun/statistical_testing/internals/KernelManager.h @@ -40,6 +40,8 @@ namespace shogun { class CKernel; +class CDistance; +class CCustomDistance; class CCustomKernel; namespace internal @@ -63,6 +65,9 @@ class KernelManager void clear(); bool same_distance_type() const; + CDistance* get_distance_instance() const; + void set_precomputed_distance(CCustomDistance* distance) const; + void unset_precomputed_distance() const; private: std::vector > m_kernels; std::vector > m_precomputed_kernels; diff --git a/src/shogun/statistical_testing/internals/mmd/MultiKernelMMD.cpp b/src/shogun/statistical_testing/internals/mmd/MultiKernelMMD.cpp index 78810da4f33..cf81991e206 100644 --- a/src/shogun/statistical_testing/internals/mmd/MultiKernelMMD.cpp +++ b/src/shogun/statistical_testing/internals/mmd/MultiKernelMMD.cpp @@ -85,15 +85,8 @@ void MultiKernelMMD::add_term(terms_t& t, float32_t val, index_t i, index_t j) c SGVector MultiKernelMMD::operator()(const KernelManager& kernel_mgr) const { SG_SDEBUG("Entering!\n"); - for (size_t i=0; i(kernel_mgr.kernel_at(i)); - REQUIRE(kernel!=nullptr, "Kernel instance (was %s) must be of CShiftInvarintKernel type!\n", - kernel_mgr.kernel_at(i)->get_name()); - kernel->m_precomputed_distance=m_distance.get(); - kernel->num_lhs=n_x+n_y; - kernel->num_rhs=n_x+n_y; - } + REQUIRE(m_distance, "Distance instace is not set!\n"); + kernel_mgr.set_precomputed_distance(m_distance.get()); SGVector result(kernel_mgr.num_kernels()); #pragma omp parallel for @@ -145,16 +138,7 @@ SGVector MultiKernelMMD::operator()(const KernelManager& kernel_mgr) SG_SDEBUG("result[%d] = %f!\n", k, result[k]); } - for (size_t i=0; i(kernel_mgr.kernel_at(i)); - REQUIRE(kernel!=nullptr, "Kernel instance (was %s) must be of CShiftInvarintKernel type!\n", - kernel_mgr.kernel_at(i)->get_name()); - kernel->m_precomputed_distance=nullptr; - kernel->num_lhs=0; - kernel->num_rhs=0; - } - + kernel_mgr.unset_precomputed_distance(); SG_SDEBUG("Leaving!\n"); return result; } diff --git a/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.cpp b/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.cpp new file mode 100644 index 00000000000..8a787eb37cf --- /dev/null +++ b/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.cpp @@ -0,0 +1,278 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (w) 2016 Soumyajit De + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO remove +#include +#include + +using Eigen::MatrixXd; +using Eigen::Map; +using std::cout; +using std::endl; +// TODO remove + +using namespace shogun; +using namespace internal; +using namespace mmd; + +struct MultiKernelPermutationTestCrossValidation::terms_t +{ + std::array term{}; + std::array diag{}; +}; + +MultiKernelPermutationTestCrossValidation::MultiKernelPermutationTestCrossValidation(index_t nx, index_t ny, index_t nns, EStatisticType type) +: n_x(nx), n_y(ny), num_null_samples(nns), stype(type) +{ + SG_SDEBUG("number of samples are %d and %d!\n", n_x, n_y); + SG_SDEBUG("Number of null samples is %d!\n", num_null_samples); +} + +MultiKernelPermutationTestCrossValidation::~MultiKernelPermutationTestCrossValidation() +{ +} + +void MultiKernelPermutationTestCrossValidation::add_term(terms_t& terms, float64_t val, index_t i, index_t j) +{ + if (i=n_x && j>=n_x && i<=j) + { + SG_SDEBUG("Adding Kernel(%d,%d)=%f to term_1!\n", i, j, val); + terms.term[1]+=val; + if (i==j) + terms.diag[1]+=val; + } + else if (i>=n_x && j= %d*%d = %d!\n", + rejections.num_rows, num_runs, num_folds, num_runs*num_folds); + kernel_mgr.set_precomputed_distance(m_distance.get()); + + SGVector dummy_labels_x(n_x); + SGVector dummy_labels_y(n_y); + auto kfold_x=some(new CBinaryLabels(dummy_labels_x), num_folds); + auto kfold_y=some(new CBinaryLabels(dummy_labels_y), num_folds); + + SGVector statistic(kernel_mgr.num_kernels()); + SGMatrix null_samples(kernel_mgr.num_kernels(), num_null_samples); + Map null_samples_map(null_samples.data(), null_samples.num_rows, null_samples.num_cols); + + for (auto i=0; ibuild_subsets(); + kfold_y->build_subsets(); + for (auto j=0; j x_inds=kfold_x->generate_subset_inverse(j); + SGVector y_inds=kfold_y->generate_subset_inverse(j); + std::for_each(y_inds.data(), y_inds.data()+y_inds.size(), [this](index_t& val) { val += n_x; }); + + SGVector xy_inds(x_inds.size()+y_inds.size()); + std::copy(x_inds.data(), x_inds.data()+x_inds.size(), xy_inds.data()); + std::copy(y_inds.data(), y_inds.data()+y_inds.size(), xy_inds.data()+x_inds.size()); + + SGVector inverted_inds(n_x+n_y); + std::fill(inverted_inds.data(), inverted_inds.data()+n_x+n_y, -1); + for (int idx=0; idxkernel(row, col), inverted_row, inverted_col); + } + } + } + statistic[k]=compute_mmd(stat_terms); + } + + // compute the null samples + for (auto n=0; n(); + stack->add_subset(xy_inds); + + SGVector permutation_inds(xy_inds.size()); + std::iota(permutation_inds.data(), permutation_inds.data()+permutation_inds.size(), 0); + CMath::permute(permutation_inds); + stack->add_subset(permutation_inds); + + SGVector inds=stack->get_last_subset()->get_subset_idx(); + + SGVector inverted_permutation_inds(n_x+n_y); + std::fill(inverted_permutation_inds.data(), inverted_permutation_inds.data()+n_x+n_y, -1); + for (int idx=0; idxkernel(row, col), permuted_row, permuted_col); + } + } + } + null_samples(k, n)=compute_mmd(terms); + } + } + + // transpose the null_samples matrix for faster access + MatrixXd transposed_null_samples=null_samples_map.transpose(); +#pragma omp parallel for + for (size_t k=0; k null_samples_k(transposed_null_samples.col(k).data(), num_null_samples, false); + std::sort(null_samples_k.data(), null_samples_k.data()+null_samples_k.size()); + SG_SDEBUG("statistic=%f\n", statistic[k]); + float64_t idx=null_samples_k.find_position_to_insert(statistic[k]); + auto p_value=1.0-idx/num_null_samples; + SG_SDEBUG("p-value=%f, rejected=%d\n", p_value, p_value measures) +{ + rejections=measures; +} + +void MultiKernelPermutationTestCrossValidation::set_distance(CCustomDistance* distance) +{ + REQUIRE(distance, "Distance instace cannot be NULL!\n"); + SG_REF(distance); + m_distance=std::shared_ptr(distance, [](CCustomDistance* ptr) { SG_UNREF(ptr); }); +} diff --git a/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.h b/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.h new file mode 100644 index 00000000000..6a33efc2928 --- /dev/null +++ b/src/shogun/statistical_testing/internals/mmd/MultiKernelPermutationTestCrossValidation.h @@ -0,0 +1,91 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (w) 2016 Soumyajit De + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#ifndef MULTIKERNEL_PERMUTATION_TEST_CROSS_VALIDATION +#define MULTIKERNEL_PERMUTATION_TEST_CROSS_VALIDATION + +#include +#include +#include + +namespace shogun +{ + +enum EStatisticType : uint32_t; +class CCustomDistance; + +namespace internal +{ + +class KernelManager; + +namespace mmd +{ + +/** + * @brief class that runs cross-validation test for MMD for multiple kernels. + */ +class MultiKernelPermutationTestCrossValidation +{ +public: + MultiKernelPermutationTestCrossValidation(index_t nx, index_t ny, index_t null_samples, EStatisticType type); + ~MultiKernelPermutationTestCrossValidation(); + + void operator()(const KernelManager& km); + + void set_num_runs(index_t nr); + void set_num_folds(index_t nf); + void set_alpha(index_t alp); + void set_measure_matrix(SGMatrix measures); + void set_distance(CCustomDistance* distance); +private: + struct terms_t; + + void add_term(terms_t&, float64_t kernel, index_t i, index_t j); + float64_t compute_mmd(terms_t&); + + const index_t n_x; + const index_t n_y; + const index_t num_null_samples; + const EStatisticType stype; + index_t num_runs; + index_t num_folds; + float64_t alpha; + SGMatrix rejections; + std::shared_ptr m_distance; +}; + +} + +} + +} + +#endif // MULTIKERNEL_PERMUTATION_TEST_CROSS_VALIDATION diff --git a/src/shogun/statistical_testing/kernelselection/internals/MaxCrossValidation.cpp b/src/shogun/statistical_testing/kernelselection/internals/MaxCrossValidation.cpp index b299ef0c4e3..1152797ff4e 100644 --- a/src/shogun/statistical_testing/kernelselection/internals/MaxCrossValidation.cpp +++ b/src/shogun/statistical_testing/kernelselection/internals/MaxCrossValidation.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -40,6 +41,7 @@ #include #include #include +#include #include using namespace shogun; @@ -81,22 +83,35 @@ void MaxCrossValidation::init_measures() void MaxCrossValidation::compute_measures() { - auto& data_mgr=estimator->get_data_mgr(); - SG_SINFO("Performing %d fold cross-validattion!\n", num_folds); - + SG_SDEBUG("Performing %d fold cross-validattion!\n", num_folds); const size_t num_kernels=kernel_mgr.num_kernels(); CQuadraticTimeMMD* quadratic_time_mmd=dynamic_cast(estimator); if (quadratic_time_mmd) { -// if (kernel_mgr.same_distance_type()) -// { -// // compute distance on estimator and set the distance to the kernels -// MultiKernelPermutationTestCrossValidation compute(num_runs, num_folds, rejections); -// compute(kernel_mgr); -// } -// else + if (kernel_mgr.same_distance_type()) + { + auto Nx=estimator->get_num_samples_p(); + auto Ny=estimator->get_num_samples_q(); + auto num_null_samples=estimator->get_num_null_samples(); + auto stype=estimator->get_statistic_type(); + + MultiKernelPermutationTestCrossValidation compute(Nx, Ny, num_null_samples, stype); + compute.set_num_runs(num_runs); + compute.set_num_folds(num_folds); + compute.set_alpha(alpha); + compute.set_measure_matrix(rejections); + + CDistance* distance=kernel_mgr.get_distance_instance(); + SG_REF(distance); + compute.set_distance(estimator->compute_joint_distance(distance)); + SG_UNREF(distance); + + compute(kernel_mgr); + } + else { + auto& data_mgr=estimator->get_data_mgr(); data_mgr.start(); auto samples=data_mgr.next(); if (!samples.empty())