diff --git a/src/shogun/statistical_testing/KernelSelectionStrategy.cpp b/src/shogun/statistical_testing/KernelSelectionStrategy.cpp index f28a373a551..7ddbbb1993a 100644 --- a/src/shogun/statistical_testing/KernelSelectionStrategy.cpp +++ b/src/shogun/statistical_testing/KernelSelectionStrategy.cpp @@ -30,6 +30,8 @@ */ #include +#include +#include #include #include #include @@ -82,9 +84,7 @@ void CKernelSelectionStrategy::Self::init_policy(CMMD* estimator) case KSM_MEDIAN_HEURISTIC: { REQUIRE(!weighted, "Weighted kernel selection is not possible with MEDIAN_HEURISTIC!\n"); - auto distance=estimator->compute_distance(); - policy=std::unique_ptr(new MedianHeuristic(kernel_mgr, distance)); - SG_UNREF(distance); + policy=std::unique_ptr(new MedianHeuristic(kernel_mgr, estimator)); } break; case KSM_MAXIMIZE_XVALIDATION: @@ -205,6 +205,18 @@ void CKernelSelectionStrategy::erase_intermediate_results() self->kernel_mgr.clear(); } +SGMatrix CKernelSelectionStrategy::get_measure_matrix() +{ + REQUIRE(self->policy!=nullptr, "The kernel selection policy is not initialized!\n"); + return self->policy->get_measure_matrix(); +} + +SGVector CKernelSelectionStrategy::get_measure_vector() +{ + REQUIRE(self->policy!=nullptr, "The kernel selection policy is not initialized!\n"); + return self->policy->get_measure_vector(); +} + const char* CKernelSelectionStrategy::get_name() const { return "KernelSelectionStrategy"; diff --git a/src/shogun/statistical_testing/KernelSelectionStrategy.h b/src/shogun/statistical_testing/KernelSelectionStrategy.h index 2f8dcfa1b8c..b2f032d2d9d 100644 --- a/src/shogun/statistical_testing/KernelSelectionStrategy.h +++ b/src/shogun/statistical_testing/KernelSelectionStrategy.h @@ -40,6 +40,8 @@ namespace shogun class CKernel; class CMMD; +template class SGVector; +template class SGMatrix; namespace internal { @@ -76,6 +78,9 @@ class CKernelSelectionStrategy : public CSGObject CKernel* select_kernel(CMMD* estimator); virtual const char* get_name() const; void erase_intermediate_results(); + + SGMatrix get_measure_matrix(); + SGVector get_measure_vector(); private: struct Self; std::unique_ptr self; diff --git a/src/shogun/statistical_testing/internals/KernelSelection.cpp b/src/shogun/statistical_testing/internals/KernelSelection.cpp index 6daea828d98..a9970558ef5 100644 --- a/src/shogun/statistical_testing/internals/KernelSelection.cpp +++ b/src/shogun/statistical_testing/internals/KernelSelection.cpp @@ -29,14 +29,18 @@ * either expressed or implied, of the Shogun Development Team. */ +#include +#include #include #include using namespace shogun; using namespace internal; -KernelSelection::KernelSelection(KernelManager& km) : kernel_mgr(km) +KernelSelection::KernelSelection(KernelManager& km, CMMD* est) : kernel_mgr(km), estimator(est) { + REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels()); + REQUIRE(estimator!=nullptr, "Estimator is not set!\n"); } KernelSelection::~KernelSelection() diff --git a/src/shogun/statistical_testing/internals/KernelSelection.h b/src/shogun/statistical_testing/internals/KernelSelection.h index 9a3ae3a4230..c06cd2b4593 100644 --- a/src/shogun/statistical_testing/internals/KernelSelection.h +++ b/src/shogun/statistical_testing/internals/KernelSelection.h @@ -38,6 +38,9 @@ namespace shogun { class CKernel; +class CMMD; +template class SGVector; +template class SGMatrix; namespace internal { @@ -47,13 +50,16 @@ class KernelManager; class KernelSelection { public: - explicit KernelSelection(KernelManager&); + KernelSelection(KernelManager&, CMMD*); KernelSelection(const KernelSelection& other)=delete; virtual ~KernelSelection(); KernelSelection& operator=(const KernelSelection& other)=delete; virtual CKernel* select_kernel()=0; + virtual SGMatrix get_measure_matrix()=0; + virtual SGVector get_measure_vector()=0; protected: const KernelManager& kernel_mgr; + CMMD* estimator; }; } diff --git a/src/shogun/statistical_testing/internals/MaxMeasure.cpp b/src/shogun/statistical_testing/internals/MaxMeasure.cpp index 4409a98cc86..82a23f66d37 100644 --- a/src/shogun/statistical_testing/internals/MaxMeasure.cpp +++ b/src/shogun/statistical_testing/internals/MaxMeasure.cpp @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -39,7 +40,7 @@ using namespace shogun; using namespace internal; -MaxMeasure::MaxMeasure(KernelManager& km, CMMD* est) : KernelSelection(km), estimator(est) +MaxMeasure::MaxMeasure(KernelManager& km, CMMD* est) : KernelSelection(km, est) { } @@ -47,29 +48,45 @@ MaxMeasure::~MaxMeasure() { } -SGVector MaxMeasure::compute_measures() +SGVector MaxMeasure::get_measure_vector() { - REQUIRE(estimator!=nullptr, "Estimator is not set!\n"); + return measures; +} + +SGMatrix MaxMeasure::get_measure_matrix() +{ + SG_SNOTIMPLEMENTED; + return SGMatrix(); +} +void MaxMeasure::init_measures() +{ const size_t num_kernels=kernel_mgr.num_kernels(); REQUIRE(num_kernels>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels()); + if (measures.size()!=num_kernels) + measures=SGVector(num_kernels); + std::fill(measures.data(), measures.data()+measures.size(), 0); +} - SGVector result(num_kernels); +void MaxMeasure::compute_measures() +{ + init_measures(); + REQUIRE(estimator!=nullptr, "Estimator is not set!\n"); auto existing_kernel=estimator->get_kernel(); + const size_t num_kernels=kernel_mgr.num_kernels(); for (size_t i=0; iset_kernel(kernel); - result[i]=estimator->compute_statistic(); + measures[i]=estimator->compute_statistic(); estimator->cleanup(); } estimator->set_kernel(existing_kernel); - return result; } CKernel* MaxMeasure::select_kernel() { - SGVector measures=compute_measures(); + compute_measures(); auto max_element=std::max_element(measures.vector, measures.vector+measures.vlen); auto max_idx=std::distance(measures.vector, max_element); SG_SDEBUG("Selected kernel at %d position!\n", max_idx); diff --git a/src/shogun/statistical_testing/internals/MaxMeasure.h b/src/shogun/statistical_testing/internals/MaxMeasure.h index 36652e5455c..2f47d2bd87f 100644 --- a/src/shogun/statistical_testing/internals/MaxMeasure.h +++ b/src/shogun/statistical_testing/internals/MaxMeasure.h @@ -41,6 +41,7 @@ namespace shogun class CKernel; class CMMD; template class SGVector; +template class SGMatrix; namespace internal { @@ -52,10 +53,14 @@ class MaxMeasure : public KernelSelection MaxMeasure(const MaxMeasure& other)=delete; ~MaxMeasure(); MaxMeasure& operator=(const MaxMeasure& other)=delete; - virtual CKernel* select_kernel() override; + virtual CKernel* select_kernel(); + virtual SGVector get_measure_vector(); + virtual SGMatrix get_measure_matrix(); protected: - SGVector compute_measures(); - CMMD* estimator; + virtual void compute_measures(); + SGVector measures; + + virtual void init_measures(); }; } diff --git a/src/shogun/statistical_testing/internals/MaxTestPower.cpp b/src/shogun/statistical_testing/internals/MaxTestPower.cpp index 8510807c50b..f65d1b3f670 100644 --- a/src/shogun/statistical_testing/internals/MaxTestPower.cpp +++ b/src/shogun/statistical_testing/internals/MaxTestPower.cpp @@ -40,7 +40,7 @@ using namespace shogun; using namespace internal; -MaxTestPower::MaxTestPower(KernelManager& km, CMMD* est) : KernelSelection(km), estimator(est), lambda(1E-5) +MaxTestPower::MaxTestPower(KernelManager& km, CMMD* est) : MaxMeasure(km, est), lambda(1E-5) { } @@ -48,28 +48,19 @@ MaxTestPower::~MaxTestPower() { } -SGVector MaxTestPower::compute_measures() +void MaxTestPower::compute_measures() { + init_measures(); REQUIRE(estimator!=nullptr, "Estimator is not set!\n"); - REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels()); - - SGVector result(kernel_mgr.num_kernels()); - for (size_t i=0; iget_kernel(); + const size_t num_kernels=kernel_mgr.num_kernels(); + for (size_t i=0; iset_kernel(kernel); auto estimates=estimator->compute_statistic_variance(); - result[i]=estimates.first/CMath::sqrt(estimates.second+lambda); + measures[i]=estimates.first/CMath::sqrt(estimates.second+lambda); estimator->cleanup(); } - return result; -} - -CKernel* MaxTestPower::select_kernel() -{ - SGVector measures=compute_measures(); - auto max_element=std::max_element(measures.vector, measures.vector+measures.vlen); - auto max_idx=std::distance(measures.vector, max_element); - SG_SDEBUG("Selected kernel at %d position!\n", max_idx); - return kernel_mgr.kernel_at(max_idx); + estimator->set_kernel(existing_kernel); } diff --git a/src/shogun/statistical_testing/internals/MaxTestPower.h b/src/shogun/statistical_testing/internals/MaxTestPower.h index 0487ae1f7d0..10164a1d856 100644 --- a/src/shogun/statistical_testing/internals/MaxTestPower.h +++ b/src/shogun/statistical_testing/internals/MaxTestPower.h @@ -33,29 +33,26 @@ #define MAX_TEST_POWER_H__ #include -#include +#include namespace shogun { class CKernel; class CMMD; -template class SGVector; namespace internal { -class MaxTestPower : public KernelSelection +class MaxTestPower : public MaxMeasure { public: MaxTestPower(KernelManager&, CMMD*); MaxTestPower(const MaxTestPower& other)=delete; ~MaxTestPower(); MaxTestPower& operator=(const MaxTestPower& other)=delete; - virtual CKernel* select_kernel() override; protected: - SGVector compute_measures(); - CMMD* estimator; + virtual void compute_measures(); float64_t lambda; }; diff --git a/src/shogun/statistical_testing/internals/MaxXValidation.cpp b/src/shogun/statistical_testing/internals/MaxXValidation.cpp index 820960f8d5d..6a97de144ce 100644 --- a/src/shogun/statistical_testing/internals/MaxXValidation.cpp +++ b/src/shogun/statistical_testing/internals/MaxXValidation.cpp @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -41,11 +42,8 @@ using namespace shogun; using namespace internal; MaxXValidation::MaxXValidation(KernelManager& km, CMMD* est, const index_t& M, const float64_t& alp) -: KernelSelection(km), estimator(est), num_run(M), alpha(alp) +: KernelSelection(km, est), num_run(M), alpha(alp) { - // TODO write a more meaningful error message - REQUIRE(estimator!=nullptr, "Estimator is not set!\n"); - REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels()); REQUIRE(num_run>0, "Number of runs is %d!\n", num_run); REQUIRE(alpha>=0.0 && alpha<=1.0, "Threshold is %f!\n", alpha); } @@ -54,6 +52,18 @@ MaxXValidation::~MaxXValidation() { } +SGVector MaxXValidation::get_measure_vector() +{ + SG_SNOTIMPLEMENTED; + return SGVector(); +} + +SGMatrix MaxXValidation::get_measure_matrix() +{ + SG_SNOTIMPLEMENTED; + return SGMatrix(); +} + void MaxXValidation::compute_measures(SGVector& measures, SGVector& term_counters) { const size_t num_kernels=kernel_mgr.num_kernels(); diff --git a/src/shogun/statistical_testing/internals/MaxXValidation.h b/src/shogun/statistical_testing/internals/MaxXValidation.h index a24068ac1d4..418bc5ae9a0 100644 --- a/src/shogun/statistical_testing/internals/MaxXValidation.h +++ b/src/shogun/statistical_testing/internals/MaxXValidation.h @@ -53,9 +53,10 @@ class MaxXValidation : public KernelSelection ~MaxXValidation(); MaxXValidation& operator=(const MaxXValidation& other)=delete; virtual CKernel* select_kernel() override; + virtual SGVector get_measure_vector(); + virtual SGMatrix get_measure_matrix(); protected: void compute_measures(SGVector&, SGVector&); - CMMD* estimator; const index_t num_run; const float64_t alpha; }; diff --git a/src/shogun/statistical_testing/internals/MedianHeuristic.cpp b/src/shogun/statistical_testing/internals/MedianHeuristic.cpp index 5dfdc6c89bf..2a97022c8c7 100644 --- a/src/shogun/statistical_testing/internals/MedianHeuristic.cpp +++ b/src/shogun/statistical_testing/internals/MedianHeuristic.cpp @@ -30,31 +30,25 @@ */ #include -#include // TODO remove #include #include #include #include #include +#include #include #include using namespace shogun; using namespace internal; -MedianHeuristic::MedianHeuristic(KernelManager& km, CCustomDistance* dist) : KernelSelection(km), distance(dist) +MedianHeuristic::MedianHeuristic(KernelManager& km, CMMD* est) : KernelSelection(km, est) { - SG_REF(distance); - n=dist->get_num_vec_lhs(); - REQUIRE(distance->get_num_vec_lhs()==distance->get_num_vec_rhs(), - "Distance matrix is supposed to be a square matrix (was of dimension %dX%d)!\n", - distance->get_num_vec_lhs(), distance->get_num_vec_rhs()); - for (size_t i=0; iget_kernel_type()==K_GAUSSIAN, - "The underlying kernel has to be a GaussianKernel (was %s)!\n", - kernel_mgr.kernel_at(i)->get_name()); + "The underlying kernel has to be a GaussianKernel (was %s)!\n", + kernel_mgr.kernel_at(i)->get_name()); } } @@ -63,29 +57,50 @@ MedianHeuristic::~MedianHeuristic() SG_UNREF(distance); } -CKernel* MedianHeuristic::select_kernel() +void MedianHeuristic::init_measures() { - std::vector measures((n*(n-1))/2); + distance=estimator->compute_distance(); + SG_REF(distance); + n=distance->get_num_vec_lhs(); + REQUIRE(distance->get_num_vec_lhs()==distance->get_num_vec_rhs(), + "Distance matrix is supposed to be a square matrix (was of dimension %dX%d)!\n", + distance->get_num_vec_lhs(), distance->get_num_vec_rhs()); + measures=SGVector((n*(n-1))/2); size_t write_idx=0; for (auto j=0; jdistance(i, j); } - std::sort(measures.begin(), measures.end()); + std::sort(measures.data(), measures.data()+measures.size()); +} + +SGVector MedianHeuristic::get_measure_vector() +{ + SG_SNOTIMPLEMENTED; + return SGVector(); +} + +SGMatrix MedianHeuristic::get_measure_matrix() +{ + return distance->get_distance_matrix(); +} + +CKernel* MedianHeuristic::select_kernel() +{ + init_measures(); auto median_distance=measures[measures.size()/2]; SG_SDEBUG("kernel width (shogun): %f\n", median_distance); const size_t num_kernels=kernel_mgr.num_kernels(); - measures.resize(num_kernels); + measures=SGVector(num_kernels); for (size_t i=0; i(kernel_mgr.kernel_at(i)); - ASSERT(kernel!=nullptr); + CGaussianKernel *kernel=static_cast(kernel_mgr.kernel_at(i)); measures[i]=CMath::abs(kernel->get_width()-median_distance); } - size_t kernel_idx=std::distance(measures.begin(), std::min_element(measures.begin(), measures.end())); + size_t kernel_idx=std::distance(measures.data(), std::min_element(measures.data(), measures.data()+measures.size())); SG_SDEBUG("Selected kernel at %d position!\n", kernel_idx); return kernel_mgr.kernel_at(kernel_idx); } diff --git a/src/shogun/statistical_testing/internals/MedianHeuristic.h b/src/shogun/statistical_testing/internals/MedianHeuristic.h index 9aa8c8b9c83..e13b0c38338 100644 --- a/src/shogun/statistical_testing/internals/MedianHeuristic.h +++ b/src/shogun/statistical_testing/internals/MedianHeuristic.h @@ -42,6 +42,7 @@ class CKernel; class CMMD; class CCustomDistance; template class SGVector; +template class SGMatrix; namespace internal { @@ -49,13 +50,18 @@ namespace internal class MedianHeuristic : public KernelSelection { public: - MedianHeuristic(KernelManager&, CCustomDistance*); + MedianHeuristic(KernelManager&, CMMD*); MedianHeuristic(const MedianHeuristic& other)=delete; ~MedianHeuristic(); MedianHeuristic& operator=(const MedianHeuristic& other)=delete; virtual CKernel* select_kernel() override; + virtual SGVector get_measure_vector(); + virtual SGMatrix get_measure_matrix(); protected: + void init_measures(); + void compute_measures(); CCustomDistance* distance; + SGVector measures; int32_t n; }; diff --git a/src/shogun/statistical_testing/internals/WeightedMaxMeasure.cpp b/src/shogun/statistical_testing/internals/WeightedMaxMeasure.cpp index a6f550fe27f..52ace838d09 100644 --- a/src/shogun/statistical_testing/internals/WeightedMaxMeasure.cpp +++ b/src/shogun/statistical_testing/internals/WeightedMaxMeasure.cpp @@ -49,19 +49,32 @@ WeightedMaxMeasure::~WeightedMaxMeasure() { } -CKernel* WeightedMaxMeasure::select_kernel() +void WeightedMaxMeasure::compute_measures() { + MaxMeasure::compute_measures(); const size_t num_kernels=kernel_mgr.num_kernels(); - SGVector measures=compute_measures(); - SGMatrix Q(num_kernels, num_kernels); + if (Q.num_rows!=num_kernels || Q.num_cols!=num_kernels) + Q=SGMatrix(num_kernels, num_kernels); std::fill(Q.data(), Q.data()+Q.size(), 0); for (size_t i=0; i WeightedMaxMeasure::get_measure_matrix() +{ + return Q; +} + +CKernel* WeightedMaxMeasure::select_kernel() +{ + init_measures(); + compute_measures(); OptimizationSolver solver(measures, Q); SGVector weights=solver.solve(); CCombinedKernel* kernel=new CCombinedKernel(); + const size_t num_kernels=kernel_mgr.num_kernels(); for (size_t i=0; iappend_kernel(kernel_mgr.kernel_at(i))) diff --git a/src/shogun/statistical_testing/internals/WeightedMaxMeasure.h b/src/shogun/statistical_testing/internals/WeightedMaxMeasure.h index be17efe38dd..8b365836d30 100644 --- a/src/shogun/statistical_testing/internals/WeightedMaxMeasure.h +++ b/src/shogun/statistical_testing/internals/WeightedMaxMeasure.h @@ -51,7 +51,11 @@ class WeightedMaxMeasure : public MaxMeasure WeightedMaxMeasure(const WeightedMaxMeasure& other)=delete; ~WeightedMaxMeasure(); WeightedMaxMeasure& operator=(const WeightedMaxMeasure& other)=delete; - virtual CKernel* select_kernel() override; + virtual CKernel* select_kernel(); + virtual SGMatrix get_measure_matrix(); +protected: + virtual void compute_measures(); + SGMatrix Q; }; } diff --git a/src/shogun/statistical_testing/internals/WeightedMaxTestPower.cpp b/src/shogun/statistical_testing/internals/WeightedMaxTestPower.cpp index 7c2860d0ac7..6f8e7e3aa5e 100644 --- a/src/shogun/statistical_testing/internals/WeightedMaxTestPower.cpp +++ b/src/shogun/statistical_testing/internals/WeightedMaxTestPower.cpp @@ -41,7 +41,7 @@ using namespace shogun; using namespace internal; -WeightedMaxTestPower::WeightedMaxTestPower(KernelManager& km, CMMD* est) : MaxTestPower(km, est) +WeightedMaxTestPower::WeightedMaxTestPower(KernelManager& km, CMMD* est) : WeightedMaxMeasure(km, est), lambda(1E-5) { } @@ -49,28 +49,15 @@ WeightedMaxTestPower::~WeightedMaxTestPower() { } -CKernel* WeightedMaxTestPower::select_kernel() +void WeightedMaxTestPower::init_measures() { - REQUIRE(estimator!=nullptr, "Estimator is not set!\n"); - REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels()); - - auto estimates=estimator->compute_statistic_and_Q(kernel_mgr); - SGVector measures=estimates.first; - SGMatrix Q=estimates.second; +} +void WeightedMaxTestPower::compute_measures() +{ + const auto& estimates=estimator->compute_statistic_and_Q(kernel_mgr); + measures=estimates.first; + Q=estimates.second; for (index_t i=0; i weights=solver.solve(); - - CCombinedKernel* kernel=new CCombinedKernel(); - for (size_t i=0; iappend_kernel(kernel_mgr.kernel_at(i))) - SG_SERROR("Error while creating a combined kernel! Please contact Shogun developers!\n"); - } - kernel->set_subkernel_weights(weights); - SG_SDEBUG("Created a weighted kernel!\n"); - return kernel; } diff --git a/src/shogun/statistical_testing/internals/WeightedMaxTestPower.h b/src/shogun/statistical_testing/internals/WeightedMaxTestPower.h index df54320d062..ec964812f3e 100644 --- a/src/shogun/statistical_testing/internals/WeightedMaxTestPower.h +++ b/src/shogun/statistical_testing/internals/WeightedMaxTestPower.h @@ -33,7 +33,7 @@ #define WEIGHTED_MAX_TEST_POWER_H__ #include -#include +#include namespace shogun { @@ -45,14 +45,17 @@ template class SGVector; namespace internal { -class WeightedMaxTestPower : public MaxTestPower +class WeightedMaxTestPower : public WeightedMaxMeasure { public: WeightedMaxTestPower(KernelManager&, CMMD*); WeightedMaxTestPower(const WeightedMaxTestPower& other)=delete; ~WeightedMaxTestPower(); WeightedMaxTestPower& operator=(const WeightedMaxTestPower& other)=delete; - virtual CKernel* select_kernel() override; +protected: + virtual void init_measures(); + virtual void compute_measures(); + float64_t lambda; }; }