From 501e46fc10631f2e3795c58a96b43a84be4e6592 Mon Sep 17 00:00:00 2001 From: lambday Date: Thu, 2 Jun 2016 18:05:13 +0100 Subject: [PATCH] refactored the interface for specifying kernel selection strategy --- .../KernelSelectionStrategy.cpp | 23 +++++++++++++- .../KernelSelectionStrategy.h | 6 ++-- src/shogun/statistical_testing/MMD.cpp | 30 +++++++++++++++++-- src/shogun/statistical_testing/MMD.h | 9 ++++-- .../KernelSelectionMaxMMD_unittest.cc | 4 +-- .../KernelSelectionMaxPower_unittest.cc | 4 +-- .../KernelSelectionMaxXValidation_unittest.cc | 2 +- ...KernelSelectionMedianHeuristic_unittest.cc | 4 +-- 8 files changed, 66 insertions(+), 16 deletions(-) diff --git a/src/shogun/statistical_testing/KernelSelectionStrategy.cpp b/src/shogun/statistical_testing/KernelSelectionStrategy.cpp index 96c13d1df51..b93729e1976 100644 --- a/src/shogun/statistical_testing/KernelSelectionStrategy.cpp +++ b/src/shogun/statistical_testing/KernelSelectionStrategy.cpp @@ -70,7 +70,7 @@ struct CKernelSelectionStrategy::Self const EKernelSelectionMethod CKernelSelectionStrategy::Self::default_method=KSM_AUTO; const bool CKernelSelectionStrategy::Self::default_weighted=false; const index_t CKernelSelectionStrategy::Self::default_num_runs=10; -const float64_t CKernelSelectionStrategy::Self::default_alpha=0.5; +const float64_t CKernelSelectionStrategy::Self::default_alpha=0.05; CKernelSelectionStrategy::Self::Self() : policy(nullptr), method(default_method), weighted(default_weighted), num_runs(default_num_runs), alpha(default_alpha) @@ -148,6 +148,27 @@ CKernelSelectionStrategy::CKernelSelectionStrategy(EKernelSelectionMethod method self->alpha=alpha; } +//CKernelSelectionStrategy::CKernelSelectionStrategy(const CKernelSelectionStrategy& other) +//{ +// init(); +// self->method=other.self->method; +// self->num_runs=other.self->num_runs; +// self->alpha=other.self->alpha; +// for (size_t i=0; ikernel_mgr.num_kernels(); ++i) +// self->kernel_mgr.push_back(other.self->kernel_mgr.kernel_at(i)); +//} +// +//CKernelSelectionStrategy& CKernelSelectionStrategy::operator=(const CKernelSelectionStrategy& other) +//{ +// init(); +// self->method=other.self->method; +// self->num_runs=other.self->num_runs; +// self->alpha=other.self->alpha; +// for (size_t i=0; ikernel_mgr.num_kernels(); ++i) +// self->kernel_mgr.push_back(other.self->kernel_mgr.kernel_at(i)); +// return *this; +//} + void CKernelSelectionStrategy::init() { self=std::unique_ptr(new Self()); diff --git a/src/shogun/statistical_testing/KernelSelectionStrategy.h b/src/shogun/statistical_testing/KernelSelectionStrategy.h index eac801046a0..4c89ec4b3ca 100644 --- a/src/shogun/statistical_testing/KernelSelectionStrategy.h +++ b/src/shogun/statistical_testing/KernelSelectionStrategy.h @@ -48,13 +48,13 @@ namespace internal class KernelManager; } -enum EKernelSelectionMethod +enum EKernelSelectionMethod : uint32_t { KSM_MEDIAN_HEURISTIC, KSM_MAXIMIZE_MMD, KSM_MAXIMIZE_POWER, KSM_MAXIMIZE_XVALIDATION, - KSM_AUTO + KSM_AUTO=KSM_MAXIMIZE_POWER }; class CKernelSelectionStrategy : public CSGObject @@ -67,7 +67,7 @@ class CKernelSelectionStrategy : public CSGObject CKernelSelectionStrategy(EKernelSelectionMethod method, index_t num_runs, float64_t alpha); CKernelSelectionStrategy(const CKernelSelectionStrategy& other)=delete; CKernelSelectionStrategy& operator=(const CKernelSelectionStrategy& other)=delete; - ~CKernelSelectionStrategy(); + virtual ~CKernelSelectionStrategy(); CKernelSelectionStrategy& use_method(EKernelSelectionMethod method); CKernelSelectionStrategy& use_num_runs(index_t num_runs); diff --git a/src/shogun/statistical_testing/MMD.cpp b/src/shogun/statistical_testing/MMD.cpp index b656473fa98..1289492b6a3 100644 --- a/src/shogun/statistical_testing/MMD.cpp +++ b/src/shogun/statistical_testing/MMD.cpp @@ -41,6 +41,7 @@ #include #include #include +#include #include #include #include @@ -456,13 +457,36 @@ CMMD::~CMMD() { } -void CMMD::set_kernel_selection_strategy(CKernelSelectionStrategy* strategy) +void CMMD::set_kernel_selection_strategy(EKernelSelectionMethod method) { - SG_REF(strategy); + auto strategy=std::shared_ptr(new CKernelSelectionStrategy(method)); const auto& kernel_mgr=self->strategy->get_kernel_mgr(); for (size_t i=0; iadd_kernel(kernel_mgr.kernel_at(i)); - self->strategy=std::shared_ptr(strategy, [](CKernelSelectionStrategy* ptr) { SG_UNREF(ptr); }); + self->strategy=strategy; +} + +void CMMD::set_kernel_selection_strategy(EKernelSelectionMethod method, bool weighted) +{ + auto strategy=std::shared_ptr(new CKernelSelectionStrategy(method, weighted)); + const auto& kernel_mgr=self->strategy->get_kernel_mgr(); + for (size_t i=0; iadd_kernel(kernel_mgr.kernel_at(i)); + self->strategy=strategy; +} + +void CMMD::set_kernel_selection_strategy(EKernelSelectionMethod method, index_t num_runs, float64_t alpha) +{ + auto strategy=std::shared_ptr(new CKernelSelectionStrategy(method, num_runs, alpha)); + const auto& kernel_mgr=self->strategy->get_kernel_mgr(); + for (size_t i=0; iadd_kernel(kernel_mgr.kernel_at(i)); + self->strategy=strategy; +} + +CKernelSelectionStrategy* CMMD::get_kernel_selection_strategy() const +{ + return self->strategy.get(); } void CMMD::add_kernel(CKernel* kernel) diff --git a/src/shogun/statistical_testing/MMD.h b/src/shogun/statistical_testing/MMD.h index 56f27b525da..2072f18101f 100644 --- a/src/shogun/statistical_testing/MMD.h +++ b/src/shogun/statistical_testing/MMD.h @@ -35,7 +35,6 @@ #include #include #include -#include namespace shogun { @@ -44,6 +43,8 @@ class CKernel; class CCustomDistance; template class SGVector; template class SGMatrix; +class CKernelSelectionStrategy; +enum EKernelSelectionMethod : uint32_t; namespace internal { @@ -87,7 +88,11 @@ class CMMD : public CTwoSampleTest CMMD(); virtual ~CMMD(); - void set_kernel_selection_strategy(CKernelSelectionStrategy* strategy); + void set_kernel_selection_strategy(EKernelSelectionMethod method); + void set_kernel_selection_strategy(EKernelSelectionMethod method, bool weighted); + void set_kernel_selection_strategy(EKernelSelectionMethod method, index_t num_runs, float64_t alpha); + CKernelSelectionStrategy* get_kernel_selection_strategy() const; + void add_kernel(CKernel *kernel); void select_kernel(); diff --git a/tests/unit/statistical_testing/KernelSelectionMaxMMD_unittest.cc b/tests/unit/statistical_testing/KernelSelectionMaxMMD_unittest.cc index 62c42b9ac35..cbd77d8599d 100644 --- a/tests/unit/statistical_testing/KernelSelectionMaxMMD_unittest.cc +++ b/tests/unit/statistical_testing/KernelSelectionMaxMMD_unittest.cc @@ -73,7 +73,7 @@ TEST(KernelSelectionMaxMMD, single_kernel) mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice)); } - mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_MMD)); + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_MMD); mmd->set_train_test_mode(true); mmd->select_kernel(); auto selected_kernel=static_cast(mmd->get_kernel()); @@ -110,7 +110,7 @@ TEST(KernelSelectionMaxMMD, weighted_kernel) mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice)); } - mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_MMD, true)); + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_MMD, true); mmd->set_train_test_mode(true); mmd->select_kernel(); auto weighted_kernel=dynamic_cast(mmd->get_kernel()); diff --git a/tests/unit/statistical_testing/KernelSelectionMaxPower_unittest.cc b/tests/unit/statistical_testing/KernelSelectionMaxPower_unittest.cc index c7bcb7d4929..fa7ff51716b 100644 --- a/tests/unit/statistical_testing/KernelSelectionMaxPower_unittest.cc +++ b/tests/unit/statistical_testing/KernelSelectionMaxPower_unittest.cc @@ -73,7 +73,7 @@ TEST(KernelSelectionMaxPower, single_kernel) mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice)); } - mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_POWER)); + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_POWER); mmd->set_train_test_mode(true); mmd->select_kernel(); auto selected_kernel=static_cast(mmd->get_kernel()); @@ -110,7 +110,7 @@ TEST(KernelSelectionMaxPower, weighted_kernel) mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice)); } - mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_POWER, true)); + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_POWER, true); mmd->set_train_test_mode(true); mmd->select_kernel(); auto weighted_kernel=dynamic_cast(mmd->get_kernel()); diff --git a/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc b/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc index 429503fce72..9f17ac926da 100644 --- a/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc +++ b/tests/unit/statistical_testing/KernelSelectionMaxXValidation_unittest.cc @@ -76,7 +76,7 @@ TEST(KernelSelectionMaxXValidation, single_kernel) mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice)); } - mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_XVALIDATION, 5, 0.05)); + mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_XVALIDATION, 5, 0.05); mmd->set_train_test_mode(true); mmd->set_train_test_ratio(4); mmd->select_kernel(); diff --git a/tests/unit/statistical_testing/KernelSelectionMedianHeuristic_unittest.cc b/tests/unit/statistical_testing/KernelSelectionMedianHeuristic_unittest.cc index 7ec92a9de05..6c3298668c9 100644 --- a/tests/unit/statistical_testing/KernelSelectionMedianHeuristic_unittest.cc +++ b/tests/unit/statistical_testing/KernelSelectionMedianHeuristic_unittest.cc @@ -69,7 +69,7 @@ TEST(KernelSelectionMedianHeuristic, quadratic_time_mmd) mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice)); } - mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MEDIAN_HEURISTIC)); + mmd->set_kernel_selection_strategy(KSM_MEDIAN_HEURISTIC); mmd->set_train_test_mode(true); mmd->select_kernel(); auto selected_kernel=static_cast(mmd->get_kernel()); @@ -105,7 +105,7 @@ TEST(KernelSelectionMedianHeuristic, linear_time_mmd) mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice)); } - mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MEDIAN_HEURISTIC)); + mmd->set_kernel_selection_strategy(KSM_MEDIAN_HEURISTIC); mmd->set_train_test_mode(true); mmd->select_kernel(); auto selected_kernel=static_cast(mmd->get_kernel());