Skip to content

Commit

Permalink
save the kernel selection policy
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday committed Jul 13, 2016
1 parent cf21fd1 commit 4d4625f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 46 deletions.
107 changes: 61 additions & 46 deletions src/shogun/statistical_testing/KernelSelectionStrategy.cpp
Expand Up @@ -50,12 +50,15 @@ struct CKernelSelectionStrategy::Self
Self();

KernelManager kernel_mgr;
std::unique_ptr<KernelSelection> policy;

EKernelSelectionMethod method;
bool weighted;
index_t num_runs;
float64_t alpha;

void init_policy(CMMD* estimator);

const static EKernelSelectionMethod default_method;
const static bool default_weighted;
const static index_t default_num_runs;
Expand All @@ -67,12 +70,56 @@ const bool CKernelSelectionStrategy::Self::default_weighted=false;
const index_t CKernelSelectionStrategy::Self::default_num_runs=10;
const float64_t CKernelSelectionStrategy::Self::default_alpha=0.5;

CKernelSelectionStrategy::Self::Self()
CKernelSelectionStrategy::Self::Self() : policy(nullptr), method(default_method),
weighted(default_weighted), num_runs(default_num_runs), alpha(default_alpha)
{
method=default_method;
weighted=default_weighted;
num_runs=default_num_runs;
alpha=default_alpha;
}

void CKernelSelectionStrategy::Self::init_policy(CMMD* estimator)
{
switch (method)
{
case KSM_MEDIAN_HEURISTIC:
{
REQUIRE(!weighted, "Weighted kernel selection is not possible with MEDIAN_HEURISTIC!\n");
auto distance=estimator->compute_distance();
policy=std::unique_ptr<MedianHeuristic>(new MedianHeuristic(kernel_mgr, distance));
SG_UNREF(distance);
}
break;
case KSM_MAXIMIZE_XVALIDATION:
{
REQUIRE(!weighted, "Weighted kernel selection is not possible with MAXIMIZE_XVALIDATION!\n");
policy=std::unique_ptr<MaxXValidation>(new MaxXValidation(kernel_mgr, estimator,
num_runs, alpha));
}
break;
case KSM_MAXIMIZE_MMD:
{
if (weighted)
policy=std::unique_ptr<WeightedMaxMeasure>(new WeightedMaxMeasure(kernel_mgr, estimator));
else
policy=std::unique_ptr<MaxMeasure>(new MaxMeasure(kernel_mgr, estimator));
}
break;
case KSM_MAXIMIZE_POWER:
{
if (weighted)
policy=std::unique_ptr<WeightedMaxTestPower>(new WeightedMaxTestPower(kernel_mgr, estimator));
else
policy=std::unique_ptr<MaxTestPower>(new MaxTestPower(kernel_mgr, estimator));
}
break;
default:
{
SG_SERROR("Unsupported kernel selection method specified! Accepted strategies are "
"MAXIMIZE_MMD (single, weighted), "
"MAXIMIZE_POWER (single, weighted), "
"MAXIMIZE_XVALIDATION (single) and "
"MEDIAN_HEURISTIC (single)!\n");
}
break;
}
}

CKernelSelectionStrategy::CKernelSelectionStrategy()
Expand Down Expand Up @@ -142,52 +189,20 @@ void CKernelSelectionStrategy::add_kernel(CKernel* kernel)

CKernel* CKernelSelectionStrategy::select_kernel(CMMD* estimator)
{
SG_DEBUG("Entering!\n");
auto num_kernels=self->kernel_mgr.num_kernels();
REQUIRE(num_kernels>0, "Number of kernels is 0. Please add kernels using add_kernel method!\n");

SG_DEBUG("Selecting kernels from a total of %d kernels!\n", num_kernels);
std::unique_ptr<KernelSelection> policy=nullptr;

switch (self->method)
{
case KSM_MEDIAN_HEURISTIC:
{
REQUIRE(!self->weighted, "Weighted kernel selection is not possible with MEDIAN_HEURISTIC!\n");
auto distance=estimator->compute_distance();
policy=std::unique_ptr<MedianHeuristic>(new MedianHeuristic(self->kernel_mgr, distance));
SG_UNREF(distance);
// estimator->set_train_test_ratio(0);
}
break;
case KSM_MAXIMIZE_XVALIDATION:
{
REQUIRE(!self->weighted, "Weighted kernel selection is not possible with MAXIMIZE_XVALIDATION!\n");
policy=std::unique_ptr<MaxXValidation>(new MaxXValidation(self->kernel_mgr, estimator,
self->num_runs, self->alpha));
}
break;
case KSM_MAXIMIZE_MMD:
if (self->weighted)
policy=std::unique_ptr<WeightedMaxMeasure>(new WeightedMaxMeasure(self->kernel_mgr, estimator));
else
policy=std::unique_ptr<MaxMeasure>(new MaxMeasure(self->kernel_mgr, estimator));
break;
case KSM_MAXIMIZE_POWER:
if (self->weighted)
policy=std::unique_ptr<WeightedMaxTestPower>(new WeightedMaxTestPower(self->kernel_mgr, estimator));
else
policy=std::unique_ptr<MaxTestPower>(new MaxTestPower(self->kernel_mgr, estimator));
break;
default:
SG_ERROR("Unsupported kernel selection method specified! "
"Presently only accepted values are MAXIMIZE_MMD, MAXIMIZE_POWER and MEDIAN_HEURISTIC!\n");
break;
}
self->init_policy(estimator);
ASSERT(self->policy!=nullptr);

return self->policy->select_kernel();
}

ASSERT(policy!=nullptr);
SG_DEBUG("Leaving!\n");
return policy->select_kernel();
void CKernelSelectionStrategy::erase_intermediate_results()
{
self->policy=nullptr;
self->kernel_mgr.clear();
}

const char* CKernelSelectionStrategy::get_name() const
Expand Down
1 change: 1 addition & 0 deletions src/shogun/statistical_testing/KernelSelectionStrategy.h
Expand Up @@ -75,6 +75,7 @@ class CKernelSelectionStrategy : public CSGObject
void add_kernel(CKernel* kernel);
CKernel* select_kernel(CMMD* estimator);
virtual const char* get_name() const;
void erase_intermediate_results();
private:
struct Self;
std::unique_ptr<Self> self;
Expand Down

0 comments on commit 4d4625f

Please sign in to comment.