Skip to content

Commit

Permalink
save the kernel selection measures for later query
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 1, 2016
1 parent 499f633 commit 0b54a9c
Show file tree
Hide file tree
Showing 16 changed files with 165 additions and 89 deletions.
18 changes: 15 additions & 3 deletions src/shogun/statistical_testing/KernelSelectionStrategy.cpp
Expand Up @@ -30,6 +30,8 @@
*/

#include <shogun/io/SGIO.h>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/distance/CustomDistance.h>
#include <shogun/statistical_testing/MMD.h>
#include <shogun/statistical_testing/KernelSelectionStrategy.h>
Expand Down Expand Up @@ -82,9 +84,7 @@ void CKernelSelectionStrategy::Self::init_policy(CMMD* estimator)
case KSM_MEDIAN_HEURISTIC:
{
REQUIRE(!weighted, "Weighted kernel selection is not possible with MEDIAN_HEURISTIC!\n");
auto distance=estimator->compute_distance();
policy=std::unique_ptr<MedianHeuristic>(new MedianHeuristic(kernel_mgr, distance));
SG_UNREF(distance);
policy=std::unique_ptr<MedianHeuristic>(new MedianHeuristic(kernel_mgr, estimator));
}
break;
case KSM_MAXIMIZE_XVALIDATION:
Expand Down Expand Up @@ -205,6 +205,18 @@ void CKernelSelectionStrategy::erase_intermediate_results()
self->kernel_mgr.clear();
}

SGMatrix<float64_t> CKernelSelectionStrategy::get_measure_matrix()
{
REQUIRE(self->policy!=nullptr, "The kernel selection policy is not initialized!\n");
return self->policy->get_measure_matrix();
}

SGVector<float64_t> CKernelSelectionStrategy::get_measure_vector()
{
REQUIRE(self->policy!=nullptr, "The kernel selection policy is not initialized!\n");
return self->policy->get_measure_vector();
}

const char* CKernelSelectionStrategy::get_name() const
{
return "KernelSelectionStrategy";
Expand Down
5 changes: 5 additions & 0 deletions src/shogun/statistical_testing/KernelSelectionStrategy.h
Expand Up @@ -40,6 +40,8 @@ namespace shogun

class CKernel;
class CMMD;
template <class> class SGVector;
template <class> class SGMatrix;

namespace internal
{
Expand Down Expand Up @@ -76,6 +78,9 @@ class CKernelSelectionStrategy : public CSGObject
CKernel* select_kernel(CMMD* estimator);
virtual const char* get_name() const;
void erase_intermediate_results();

SGMatrix<float64_t> get_measure_matrix();
SGVector<float64_t> get_measure_vector();
private:
struct Self;
std::unique_ptr<Self> self;
Expand Down
6 changes: 5 additions & 1 deletion src/shogun/statistical_testing/internals/KernelSelection.cpp
Expand Up @@ -29,14 +29,18 @@
* either expressed or implied, of the Shogun Development Team.
*/

#include <shogun/io/SGIO.h>
#include <shogun/statistical_testing/MMD.h>
#include <shogun/statistical_testing/internals/KernelManager.h>
#include <shogun/statistical_testing/internals/KernelSelection.h>

using namespace shogun;
using namespace internal;

KernelSelection::KernelSelection(KernelManager& km) : kernel_mgr(km)
KernelSelection::KernelSelection(KernelManager& km, CMMD* est) : kernel_mgr(km), estimator(est)
{
REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels());
REQUIRE(estimator!=nullptr, "Estimator is not set!\n");
}

KernelSelection::~KernelSelection()
Expand Down
8 changes: 7 additions & 1 deletion src/shogun/statistical_testing/internals/KernelSelection.h
Expand Up @@ -38,6 +38,9 @@ namespace shogun
{

class CKernel;
class CMMD;
template <class> class SGVector;
template <class> class SGMatrix;

namespace internal
{
Expand All @@ -47,13 +50,16 @@ class KernelManager;
class KernelSelection
{
public:
explicit KernelSelection(KernelManager&);
KernelSelection(KernelManager&, CMMD*);
KernelSelection(const KernelSelection& other)=delete;
virtual ~KernelSelection();
KernelSelection& operator=(const KernelSelection& other)=delete;
virtual CKernel* select_kernel()=0;
virtual SGMatrix<float64_t> get_measure_matrix()=0;
virtual SGVector<float64_t> get_measure_vector()=0;
protected:
const KernelManager& kernel_mgr;
CMMD* estimator;
};

}
Expand Down
31 changes: 24 additions & 7 deletions src/shogun/statistical_testing/internals/MaxMeasure.cpp
Expand Up @@ -31,6 +31,7 @@

#include <algorithm>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/kernel/Kernel.h>
#include <shogun/statistical_testing/MMD.h>
#include <shogun/statistical_testing/internals/MaxMeasure.h>
Expand All @@ -39,37 +40,53 @@
using namespace shogun;
using namespace internal;

MaxMeasure::MaxMeasure(KernelManager& km, CMMD* est) : KernelSelection(km), estimator(est)
MaxMeasure::MaxMeasure(KernelManager& km, CMMD* est) : KernelSelection(km, est)
{
}

MaxMeasure::~MaxMeasure()
{
}

SGVector<float64_t> MaxMeasure::compute_measures()
SGVector<float64_t> MaxMeasure::get_measure_vector()
{
REQUIRE(estimator!=nullptr, "Estimator is not set!\n");
return measures;
}

SGMatrix<float64_t> MaxMeasure::get_measure_matrix()
{
SG_SNOTIMPLEMENTED;
return SGMatrix<float64_t>();
}

void MaxMeasure::init_measures()
{
const size_t num_kernels=kernel_mgr.num_kernels();
REQUIRE(num_kernels>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels());
if (measures.size()!=num_kernels)
measures=SGVector<float64_t>(num_kernels);
std::fill(measures.data(), measures.data()+measures.size(), 0);
}

SGVector<float64_t> result(num_kernels);
void MaxMeasure::compute_measures()
{
init_measures();
REQUIRE(estimator!=nullptr, "Estimator is not set!\n");
auto existing_kernel=estimator->get_kernel();
const size_t num_kernels=kernel_mgr.num_kernels();
for (size_t i=0; i<num_kernels; ++i)
{
auto kernel=kernel_mgr.kernel_at(i);
estimator->set_kernel(kernel);
result[i]=estimator->compute_statistic();
measures[i]=estimator->compute_statistic();
estimator->cleanup();
}
estimator->set_kernel(existing_kernel);
return result;
}

CKernel* MaxMeasure::select_kernel()
{
SGVector<float64_t> measures=compute_measures();
compute_measures();
auto max_element=std::max_element(measures.vector, measures.vector+measures.vlen);
auto max_idx=std::distance(measures.vector, max_element);
SG_SDEBUG("Selected kernel at %d position!\n", max_idx);
Expand Down
11 changes: 8 additions & 3 deletions src/shogun/statistical_testing/internals/MaxMeasure.h
Expand Up @@ -41,6 +41,7 @@ namespace shogun
class CKernel;
class CMMD;
template <typename T> class SGVector;
template <typename T> class SGMatrix;

namespace internal
{
Expand All @@ -52,10 +53,14 @@ class MaxMeasure : public KernelSelection
MaxMeasure(const MaxMeasure& other)=delete;
~MaxMeasure();
MaxMeasure& operator=(const MaxMeasure& other)=delete;
virtual CKernel* select_kernel() override;
virtual CKernel* select_kernel();
virtual SGVector<float64_t> get_measure_vector();
virtual SGMatrix<float64_t> get_measure_matrix();
protected:
SGVector<float64_t> compute_measures();
CMMD* estimator;
virtual void compute_measures();
SGVector<float64_t> measures;

virtual void init_measures();
};

}
Expand Down
25 changes: 8 additions & 17 deletions src/shogun/statistical_testing/internals/MaxTestPower.cpp
Expand Up @@ -40,36 +40,27 @@
using namespace shogun;
using namespace internal;

MaxTestPower::MaxTestPower(KernelManager& km, CMMD* est) : KernelSelection(km), estimator(est), lambda(1E-5)
MaxTestPower::MaxTestPower(KernelManager& km, CMMD* est) : MaxMeasure(km, est), lambda(1E-5)
{
}

MaxTestPower::~MaxTestPower()
{
}

SGVector<float64_t> MaxTestPower::compute_measures()
void MaxTestPower::compute_measures()
{
init_measures();
REQUIRE(estimator!=nullptr, "Estimator is not set!\n");
REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels());

SGVector<float64_t> result(kernel_mgr.num_kernels());
for (size_t i=0; i<kernel_mgr.num_kernels(); ++i)
auto existing_kernel=estimator->get_kernel();
const size_t num_kernels=kernel_mgr.num_kernels();
for (size_t i=0; i<num_kernels; ++i)
{
auto kernel=kernel_mgr.kernel_at(i);
estimator->set_kernel(kernel);
auto estimates=estimator->compute_statistic_variance();
result[i]=estimates.first/CMath::sqrt(estimates.second+lambda);
measures[i]=estimates.first/CMath::sqrt(estimates.second+lambda);
estimator->cleanup();
}
return result;
}

CKernel* MaxTestPower::select_kernel()
{
SGVector<float64_t> measures=compute_measures();
auto max_element=std::max_element(measures.vector, measures.vector+measures.vlen);
auto max_idx=std::distance(measures.vector, max_element);
SG_SDEBUG("Selected kernel at %d position!\n", max_idx);
return kernel_mgr.kernel_at(max_idx);
estimator->set_kernel(existing_kernel);
}
9 changes: 3 additions & 6 deletions src/shogun/statistical_testing/internals/MaxTestPower.h
Expand Up @@ -33,29 +33,26 @@
#define MAX_TEST_POWER_H__

#include <shogun/lib/common.h>
#include <shogun/statistical_testing/internals/KernelSelection.h>
#include <shogun/statistical_testing/internals/MaxMeasure.h>

namespace shogun
{

class CKernel;
class CMMD;
template <typename T> class SGVector;

namespace internal
{

class MaxTestPower : public KernelSelection
class MaxTestPower : public MaxMeasure
{
public:
MaxTestPower(KernelManager&, CMMD*);
MaxTestPower(const MaxTestPower& other)=delete;
~MaxTestPower();
MaxTestPower& operator=(const MaxTestPower& other)=delete;
virtual CKernel* select_kernel() override;
protected:
SGVector<float64_t> compute_measures();
CMMD* estimator;
virtual void compute_measures();
float64_t lambda;
};

Expand Down
18 changes: 14 additions & 4 deletions src/shogun/statistical_testing/internals/MaxXValidation.cpp
Expand Up @@ -31,6 +31,7 @@

#include <algorithm>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/kernel/Kernel.h>
#include <shogun/statistical_testing/MMD.h>
#include <shogun/statistical_testing/internals/MaxXValidation.h>
Expand All @@ -41,11 +42,8 @@ using namespace shogun;
using namespace internal;

MaxXValidation::MaxXValidation(KernelManager& km, CMMD* est, const index_t& M, const float64_t& alp)
: KernelSelection(km), estimator(est), num_run(M), alpha(alp)
: KernelSelection(km, est), num_run(M), alpha(alp)
{
// TODO write a more meaningful error message
REQUIRE(estimator!=nullptr, "Estimator is not set!\n");
REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels());
REQUIRE(num_run>0, "Number of runs is %d!\n", num_run);
REQUIRE(alpha>=0.0 && alpha<=1.0, "Threshold is %f!\n", alpha);
}
Expand All @@ -54,6 +52,18 @@ MaxXValidation::~MaxXValidation()
{
}

SGVector<float64_t> MaxXValidation::get_measure_vector()
{
SG_SNOTIMPLEMENTED;
return SGVector<float64_t>();
}

SGMatrix<float64_t> MaxXValidation::get_measure_matrix()
{
SG_SNOTIMPLEMENTED;
return SGMatrix<float64_t>();
}

void MaxXValidation::compute_measures(SGVector<float64_t>& measures, SGVector<index_t>& term_counters)
{
const size_t num_kernels=kernel_mgr.num_kernels();
Expand Down
3 changes: 2 additions & 1 deletion src/shogun/statistical_testing/internals/MaxXValidation.h
Expand Up @@ -53,9 +53,10 @@ class MaxXValidation : public KernelSelection
~MaxXValidation();
MaxXValidation& operator=(const MaxXValidation& other)=delete;
virtual CKernel* select_kernel() override;
virtual SGVector<float64_t> get_measure_vector();
virtual SGMatrix<float64_t> get_measure_matrix();
protected:
void compute_measures(SGVector<float64_t>&, SGVector<index_t>&);
CMMD* estimator;
const index_t num_run;
const float64_t alpha;
};
Expand Down

0 comments on commit 0b54a9c

Please sign in to comment.