From 186ec3628de76e6f0b192157c99349bfe53df197 Mon Sep 17 00:00:00 2001 From: lambday Date: Thu, 26 May 2016 19:29:26 +0100 Subject: [PATCH] refactored maximize cross validation kernel selection method --- .../internals/KernelSelection.h | 2 + .../internals/MaxMeasure.h | 3 +- .../internals/MaxXValidation.cpp | 73 +++++++++++-------- .../internals/MaxXValidation.h | 5 +- .../internals/MedianHeuristic.cpp | 14 ++-- .../internals/MedianHeuristic.h | 4 +- 6 files changed, 60 insertions(+), 41 deletions(-) diff --git a/src/shogun/statistical_testing/internals/KernelSelection.h b/src/shogun/statistical_testing/internals/KernelSelection.h index c06cd2b4593..7a2ac9d986b 100644 --- a/src/shogun/statistical_testing/internals/KernelSelection.h +++ b/src/shogun/statistical_testing/internals/KernelSelection.h @@ -60,6 +60,8 @@ class KernelSelection protected: const KernelManager& kernel_mgr; CMMD* estimator; + virtual void init_measures()=0; + virtual void compute_measures()=0; }; } diff --git a/src/shogun/statistical_testing/internals/MaxMeasure.h b/src/shogun/statistical_testing/internals/MaxMeasure.h index 2f47d2bd87f..ba8efee57ce 100644 --- a/src/shogun/statistical_testing/internals/MaxMeasure.h +++ b/src/shogun/statistical_testing/internals/MaxMeasure.h @@ -57,10 +57,9 @@ class MaxMeasure : public KernelSelection virtual SGVector get_measure_vector(); virtual SGMatrix get_measure_matrix(); protected: + virtual void init_measures(); virtual void compute_measures(); SGVector measures; - - virtual void init_measures(); }; } diff --git a/src/shogun/statistical_testing/internals/MaxXValidation.cpp b/src/shogun/statistical_testing/internals/MaxXValidation.cpp index 6a97de144ce..b38b10fab6f 100644 --- a/src/shogun/statistical_testing/internals/MaxXValidation.cpp +++ b/src/shogun/statistical_testing/internals/MaxXValidation.cpp @@ -54,59 +54,70 @@ MaxXValidation::~MaxXValidation() SGVector MaxXValidation::get_measure_vector() { - SG_SNOTIMPLEMENTED; - return SGVector(); + return measures; } SGMatrix MaxXValidation::get_measure_matrix() { - SG_SNOTIMPLEMENTED; - return SGMatrix(); + return rejections; } -void MaxXValidation::compute_measures(SGVector& measures, SGVector& term_counters) +void MaxXValidation::init_measures() { - const size_t num_kernels=kernel_mgr.num_kernels(); - for (size_t i=0; iset_kernel(kernel); - bool rejected=estimator->compute_p_value(estimator->compute_statistic())cleanup(); - } + const index_t num_kernels=kernel_mgr.num_kernels(); + auto& dm=estimator->get_data_manager(); + const index_t N=dm.get_num_folds(); + REQUIRE(N!=0, "Number of folds is not set!\n"); + if (rejections.num_rows!=N*num_run || rejections.num_cols!=num_kernels) + rejections=SGMatrix(N*num_run, num_kernels); + std::fill(rejections.data(), rejections.data()+rejections.size(), 0); + if (measures.size()!=num_kernels) + measures=SGVector(num_kernels); + std::fill(measures.data(), measures.data()+measures.size(), 0); } -CKernel* MaxXValidation::select_kernel() +void MaxXValidation::compute_measures() { auto& dm=estimator->get_data_manager(); dm.set_xvalidation_mode(true); - auto existing_kernel=estimator->get_kernel(); const index_t N=dm.get_num_folds(); - // TODO write a more meaningful error message - REQUIRE(N!=0, "Number of folds is not set!\n"); SG_SINFO("Performing %d fold cross-validattion!\n", N); - // train mode is already ON by now! set by the caller - SGVector measures(kernel_mgr.num_kernels()); - std::fill(measures.data(), measures.data()+measures.size(), 0); - SGVector term_counters(measures.size()); - std::fill(term_counters.data(), term_counters.data()+term_counters.size(), 1); + + const size_t num_kernels=kernel_mgr.num_kernels(); + auto existing_kernel=estimator->get_kernel(); for (auto i=0; iset_kernel(kernel); + rejections(i*N+j, k)=estimator->compute_p_value(estimator->compute_statistic())cleanup(); + } } } - - estimator->set_kernel(existing_kernel); dm.set_xvalidation_mode(false); + estimator->set_kernel(existing_kernel); - auto min_element=std::min_element(measures.vector, measures.vector+measures.vlen); - auto min_idx=std::distance(measures.vector, min_element); - SG_SDEBUG("Selected kernel at %d position!\n", min_idx); - return kernel_mgr.kernel_at(min_idx); + for (auto j=0; j get_measure_vector(); virtual SGMatrix get_measure_matrix(); protected: - void compute_measures(SGVector&, SGVector&); + virtual void init_measures(); + virtual void compute_measures(); const index_t num_run; const float64_t alpha; + SGMatrix rejections; + SGVector measures; }; } diff --git a/src/shogun/statistical_testing/internals/MedianHeuristic.cpp b/src/shogun/statistical_testing/internals/MedianHeuristic.cpp index 2a97022c8c7..ca3430aef5e 100644 --- a/src/shogun/statistical_testing/internals/MedianHeuristic.cpp +++ b/src/shogun/statistical_testing/internals/MedianHeuristic.cpp @@ -58,13 +58,18 @@ MedianHeuristic::~MedianHeuristic() } void MedianHeuristic::init_measures() +{ + SG_SNOTIMPLEMENTED; +} + +void MedianHeuristic::compute_measures() { distance=estimator->compute_distance(); SG_REF(distance); n=distance->get_num_vec_lhs(); REQUIRE(distance->get_num_vec_lhs()==distance->get_num_vec_rhs(), - "Distance matrix is supposed to be a square matrix (was of dimension %dX%d)!\n", - distance->get_num_vec_lhs(), distance->get_num_vec_rhs()); + "Distance matrix is supposed to be a square matrix (was of dimension %dX%d)!\n", + distance->get_num_vec_lhs(), distance->get_num_vec_rhs()); measures=SGVector((n*(n-1))/2); size_t write_idx=0; for (auto j=0; j MedianHeuristic::get_measure_vector() { - SG_SNOTIMPLEMENTED; - return SGVector(); + return measures; } SGMatrix MedianHeuristic::get_measure_matrix() @@ -88,7 +92,7 @@ SGMatrix MedianHeuristic::get_measure_matrix() CKernel* MedianHeuristic::select_kernel() { - init_measures(); + compute_measures(); auto median_distance=measures[measures.size()/2]; SG_SDEBUG("kernel width (shogun): %f\n", median_distance); diff --git a/src/shogun/statistical_testing/internals/MedianHeuristic.h b/src/shogun/statistical_testing/internals/MedianHeuristic.h index e13b0c38338..1086c1ac258 100644 --- a/src/shogun/statistical_testing/internals/MedianHeuristic.h +++ b/src/shogun/statistical_testing/internals/MedianHeuristic.h @@ -58,8 +58,8 @@ class MedianHeuristic : public KernelSelection virtual SGVector get_measure_vector(); virtual SGMatrix get_measure_matrix(); protected: - void init_measures(); - void compute_measures(); + virtual void init_measures(); + virtual void compute_measures(); CCustomDistance* distance; SGVector measures; int32_t n;