diff --git a/src/shogun/statistical_testing/KernelSelectionStrategy.cpp b/src/shogun/statistical_testing/KernelSelectionStrategy.cpp index b6608c6cfd3..f28a373a551 100644 --- a/src/shogun/statistical_testing/KernelSelectionStrategy.cpp +++ b/src/shogun/statistical_testing/KernelSelectionStrategy.cpp @@ -50,12 +50,15 @@ struct CKernelSelectionStrategy::Self Self(); KernelManager kernel_mgr; + std::unique_ptr policy; EKernelSelectionMethod method; bool weighted; index_t num_runs; float64_t alpha; + void init_policy(CMMD* estimator); + const static EKernelSelectionMethod default_method; const static bool default_weighted; const static index_t default_num_runs; @@ -67,12 +70,56 @@ const bool CKernelSelectionStrategy::Self::default_weighted=false; const index_t CKernelSelectionStrategy::Self::default_num_runs=10; const float64_t CKernelSelectionStrategy::Self::default_alpha=0.5; -CKernelSelectionStrategy::Self::Self() +CKernelSelectionStrategy::Self::Self() : policy(nullptr), method(default_method), + weighted(default_weighted), num_runs(default_num_runs), alpha(default_alpha) { - method=default_method; - weighted=default_weighted; - num_runs=default_num_runs; - alpha=default_alpha; +} + +void CKernelSelectionStrategy::Self::init_policy(CMMD* estimator) +{ + switch (method) + { + case KSM_MEDIAN_HEURISTIC: + { + REQUIRE(!weighted, "Weighted kernel selection is not possible with MEDIAN_HEURISTIC!\n"); + auto distance=estimator->compute_distance(); + policy=std::unique_ptr(new MedianHeuristic(kernel_mgr, distance)); + SG_UNREF(distance); + } + break; + case KSM_MAXIMIZE_XVALIDATION: + { + REQUIRE(!weighted, "Weighted kernel selection is not possible with MAXIMIZE_XVALIDATION!\n"); + policy=std::unique_ptr(new MaxXValidation(kernel_mgr, estimator, + num_runs, alpha)); + } + break; + case KSM_MAXIMIZE_MMD: + { + if (weighted) + policy=std::unique_ptr(new WeightedMaxMeasure(kernel_mgr, estimator)); + else + policy=std::unique_ptr(new MaxMeasure(kernel_mgr, estimator)); + } + break; + case KSM_MAXIMIZE_POWER: + { + if (weighted) + policy=std::unique_ptr(new WeightedMaxTestPower(kernel_mgr, estimator)); + else + policy=std::unique_ptr(new MaxTestPower(kernel_mgr, estimator)); + } + break; + default: + { + SG_SERROR("Unsupported kernel selection method specified! Accepted strategies are " + "MAXIMIZE_MMD (single, weighted), " + "MAXIMIZE_POWER (single, weighted), " + "MAXIMIZE_XVALIDATION (single) and " + "MEDIAN_HEURISTIC (single)!\n"); + } + break; + } } CKernelSelectionStrategy::CKernelSelectionStrategy() @@ -142,52 +189,20 @@ void CKernelSelectionStrategy::add_kernel(CKernel* kernel) CKernel* CKernelSelectionStrategy::select_kernel(CMMD* estimator) { - SG_DEBUG("Entering!\n"); auto num_kernels=self->kernel_mgr.num_kernels(); REQUIRE(num_kernels>0, "Number of kernels is 0. Please add kernels using add_kernel method!\n"); - SG_DEBUG("Selecting kernels from a total of %d kernels!\n", num_kernels); - std::unique_ptr policy=nullptr; - switch (self->method) - { - case KSM_MEDIAN_HEURISTIC: - { - REQUIRE(!self->weighted, "Weighted kernel selection is not possible with MEDIAN_HEURISTIC!\n"); - auto distance=estimator->compute_distance(); - policy=std::unique_ptr(new MedianHeuristic(self->kernel_mgr, distance)); - SG_UNREF(distance); -// estimator->set_train_test_ratio(0); - } - break; - case KSM_MAXIMIZE_XVALIDATION: - { - REQUIRE(!self->weighted, "Weighted kernel selection is not possible with MAXIMIZE_XVALIDATION!\n"); - policy=std::unique_ptr(new MaxXValidation(self->kernel_mgr, estimator, - self->num_runs, self->alpha)); - } - break; - case KSM_MAXIMIZE_MMD: - if (self->weighted) - policy=std::unique_ptr(new WeightedMaxMeasure(self->kernel_mgr, estimator)); - else - policy=std::unique_ptr(new MaxMeasure(self->kernel_mgr, estimator)); - break; - case KSM_MAXIMIZE_POWER: - if (self->weighted) - policy=std::unique_ptr(new WeightedMaxTestPower(self->kernel_mgr, estimator)); - else - policy=std::unique_ptr(new MaxTestPower(self->kernel_mgr, estimator)); - break; - default: - SG_ERROR("Unsupported kernel selection method specified! " - "Presently only accepted values are MAXIMIZE_MMD, MAXIMIZE_POWER and MEDIAN_HEURISTIC!\n"); - break; - } + self->init_policy(estimator); + ASSERT(self->policy!=nullptr); + + return self->policy->select_kernel(); +} - ASSERT(policy!=nullptr); - SG_DEBUG("Leaving!\n"); - return policy->select_kernel(); +void CKernelSelectionStrategy::erase_intermediate_results() +{ + self->policy=nullptr; + self->kernel_mgr.clear(); } const char* CKernelSelectionStrategy::get_name() const diff --git a/src/shogun/statistical_testing/KernelSelectionStrategy.h b/src/shogun/statistical_testing/KernelSelectionStrategy.h index 6de57c1e098..2f8dcfa1b8c 100644 --- a/src/shogun/statistical_testing/KernelSelectionStrategy.h +++ b/src/shogun/statistical_testing/KernelSelectionStrategy.h @@ -75,6 +75,7 @@ class CKernelSelectionStrategy : public CSGObject void add_kernel(CKernel* kernel); CKernel* select_kernel(CMMD* estimator); virtual const char* get_name() const; + void erase_intermediate_results(); private: struct Self; std::unique_ptr self;