Skip to content

Commit

Permalink
refactored the interface for specifying kernel selection strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 3, 2016
1 parent e793278 commit 501e46f
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 16 deletions.
23 changes: 22 additions & 1 deletion src/shogun/statistical_testing/KernelSelectionStrategy.cpp
Expand Up @@ -70,7 +70,7 @@ struct CKernelSelectionStrategy::Self
const EKernelSelectionMethod CKernelSelectionStrategy::Self::default_method=KSM_AUTO;
const bool CKernelSelectionStrategy::Self::default_weighted=false;
const index_t CKernelSelectionStrategy::Self::default_num_runs=10;
const float64_t CKernelSelectionStrategy::Self::default_alpha=0.5;
const float64_t CKernelSelectionStrategy::Self::default_alpha=0.05;

CKernelSelectionStrategy::Self::Self() : policy(nullptr), method(default_method),
weighted(default_weighted), num_runs(default_num_runs), alpha(default_alpha)
Expand Down Expand Up @@ -148,6 +148,27 @@ CKernelSelectionStrategy::CKernelSelectionStrategy(EKernelSelectionMethod method
self->alpha=alpha;
}

//CKernelSelectionStrategy::CKernelSelectionStrategy(const CKernelSelectionStrategy& other)
//{
// init();
// self->method=other.self->method;
// self->num_runs=other.self->num_runs;
// self->alpha=other.self->alpha;
// for (size_t i=0; i<other.self->kernel_mgr.num_kernels(); ++i)
// self->kernel_mgr.push_back(other.self->kernel_mgr.kernel_at(i));
//}
//
//CKernelSelectionStrategy& CKernelSelectionStrategy::operator=(const CKernelSelectionStrategy& other)
//{
// init();
// self->method=other.self->method;
// self->num_runs=other.self->num_runs;
// self->alpha=other.self->alpha;
// for (size_t i=0; i<other.self->kernel_mgr.num_kernels(); ++i)
// self->kernel_mgr.push_back(other.self->kernel_mgr.kernel_at(i));
// return *this;
//}

void CKernelSelectionStrategy::init()
{
self=std::unique_ptr<Self>(new Self());
Expand Down
6 changes: 3 additions & 3 deletions src/shogun/statistical_testing/KernelSelectionStrategy.h
Expand Up @@ -48,13 +48,13 @@ namespace internal
class KernelManager;
}

enum EKernelSelectionMethod
enum EKernelSelectionMethod : uint32_t
{
KSM_MEDIAN_HEURISTIC,
KSM_MAXIMIZE_MMD,
KSM_MAXIMIZE_POWER,
KSM_MAXIMIZE_XVALIDATION,
KSM_AUTO
KSM_AUTO=KSM_MAXIMIZE_POWER
};

class CKernelSelectionStrategy : public CSGObject
Expand All @@ -67,7 +67,7 @@ class CKernelSelectionStrategy : public CSGObject
CKernelSelectionStrategy(EKernelSelectionMethod method, index_t num_runs, float64_t alpha);
CKernelSelectionStrategy(const CKernelSelectionStrategy& other)=delete;
CKernelSelectionStrategy& operator=(const CKernelSelectionStrategy& other)=delete;
~CKernelSelectionStrategy();
virtual ~CKernelSelectionStrategy();

CKernelSelectionStrategy& use_method(EKernelSelectionMethod method);
CKernelSelectionStrategy& use_num_runs(index_t num_runs);
Expand Down
30 changes: 27 additions & 3 deletions src/shogun/statistical_testing/MMD.cpp
Expand Up @@ -41,6 +41,7 @@
#include <shogun/statistical_testing/QuadraticTimeMMD.h>
#include <shogun/statistical_testing/BTestMMD.h>
#include <shogun/statistical_testing/LinearTimeMMD.h>
#include <shogun/statistical_testing/KernelSelectionStrategy.h>
#include <shogun/statistical_testing/internals/NextSamples.h>
#include <shogun/statistical_testing/internals/DataManager.h>
#include <shogun/statistical_testing/internals/FeaturesUtil.h>
Expand Down Expand Up @@ -456,13 +457,36 @@ CMMD::~CMMD()
{
}

void CMMD::set_kernel_selection_strategy(CKernelSelectionStrategy* strategy)
void CMMD::set_kernel_selection_strategy(EKernelSelectionMethod method)
{
SG_REF(strategy);
auto strategy=std::shared_ptr<CKernelSelectionStrategy>(new CKernelSelectionStrategy(method));
const auto& kernel_mgr=self->strategy->get_kernel_mgr();
for (size_t i=0; i<kernel_mgr.num_kernels(); ++i)
strategy->add_kernel(kernel_mgr.kernel_at(i));
self->strategy=std::shared_ptr<CKernelSelectionStrategy>(strategy, [](CKernelSelectionStrategy* ptr) { SG_UNREF(ptr); });
self->strategy=strategy;
}

void CMMD::set_kernel_selection_strategy(EKernelSelectionMethod method, bool weighted)
{
auto strategy=std::shared_ptr<CKernelSelectionStrategy>(new CKernelSelectionStrategy(method, weighted));
const auto& kernel_mgr=self->strategy->get_kernel_mgr();
for (size_t i=0; i<kernel_mgr.num_kernels(); ++i)
strategy->add_kernel(kernel_mgr.kernel_at(i));
self->strategy=strategy;
}

void CMMD::set_kernel_selection_strategy(EKernelSelectionMethod method, index_t num_runs, float64_t alpha)
{
auto strategy=std::shared_ptr<CKernelSelectionStrategy>(new CKernelSelectionStrategy(method, num_runs, alpha));
const auto& kernel_mgr=self->strategy->get_kernel_mgr();
for (size_t i=0; i<kernel_mgr.num_kernels(); ++i)
strategy->add_kernel(kernel_mgr.kernel_at(i));
self->strategy=strategy;
}

CKernelSelectionStrategy* CMMD::get_kernel_selection_strategy() const
{
return self->strategy.get();
}

void CMMD::add_kernel(CKernel* kernel)
Expand Down
9 changes: 7 additions & 2 deletions src/shogun/statistical_testing/MMD.h
Expand Up @@ -35,7 +35,6 @@
#include <memory>
#include <functional>
#include <shogun/statistical_testing/TwoSampleTest.h>
#include <shogun/statistical_testing/KernelSelectionStrategy.h>

namespace shogun
{
Expand All @@ -44,6 +43,8 @@ class CKernel;
class CCustomDistance;
template <typename> class SGVector;
template <typename> class SGMatrix;
class CKernelSelectionStrategy;
enum EKernelSelectionMethod : uint32_t;

namespace internal
{
Expand Down Expand Up @@ -87,7 +88,11 @@ class CMMD : public CTwoSampleTest
CMMD();
virtual ~CMMD();

void set_kernel_selection_strategy(CKernelSelectionStrategy* strategy);
void set_kernel_selection_strategy(EKernelSelectionMethod method);
void set_kernel_selection_strategy(EKernelSelectionMethod method, bool weighted);
void set_kernel_selection_strategy(EKernelSelectionMethod method, index_t num_runs, float64_t alpha);
CKernelSelectionStrategy* get_kernel_selection_strategy() const;

void add_kernel(CKernel *kernel);
void select_kernel();

Expand Down
Expand Up @@ -73,7 +73,7 @@ TEST(KernelSelectionMaxMMD, single_kernel)
mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice));
}

mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_MMD));
mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_MMD);
mmd->set_train_test_mode(true);
mmd->select_kernel();
auto selected_kernel=static_cast<CGaussianKernel*>(mmd->get_kernel());
Expand Down Expand Up @@ -110,7 +110,7 @@ TEST(KernelSelectionMaxMMD, weighted_kernel)
mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice));
}

mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_MMD, true));
mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_MMD, true);
mmd->set_train_test_mode(true);
mmd->select_kernel();
auto weighted_kernel=dynamic_cast<CCombinedKernel*>(mmd->get_kernel());
Expand Down
Expand Up @@ -73,7 +73,7 @@ TEST(KernelSelectionMaxPower, single_kernel)
mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice));
}

mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_POWER));
mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_POWER);
mmd->set_train_test_mode(true);
mmd->select_kernel();
auto selected_kernel=static_cast<CGaussianKernel*>(mmd->get_kernel());
Expand Down Expand Up @@ -110,7 +110,7 @@ TEST(KernelSelectionMaxPower, weighted_kernel)
mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice));
}

mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_POWER, true));
mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_POWER, true);
mmd->set_train_test_mode(true);
mmd->select_kernel();
auto weighted_kernel=dynamic_cast<CCombinedKernel*>(mmd->get_kernel());
Expand Down
Expand Up @@ -76,7 +76,7 @@ TEST(KernelSelectionMaxXValidation, single_kernel)
mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice));
}

mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MAXIMIZE_XVALIDATION, 5, 0.05));
mmd->set_kernel_selection_strategy(KSM_MAXIMIZE_XVALIDATION, 5, 0.05);
mmd->set_train_test_mode(true);
mmd->set_train_test_ratio(4);
mmd->select_kernel();
Expand Down
Expand Up @@ -69,7 +69,7 @@ TEST(KernelSelectionMedianHeuristic, quadratic_time_mmd)
mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice));
}

mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MEDIAN_HEURISTIC));
mmd->set_kernel_selection_strategy(KSM_MEDIAN_HEURISTIC);
mmd->set_train_test_mode(true);
mmd->select_kernel();
auto selected_kernel=static_cast<CGaussianKernel*>(mmd->get_kernel());
Expand Down Expand Up @@ -105,7 +105,7 @@ TEST(KernelSelectionMedianHeuristic, linear_time_mmd)
mmd->add_kernel(new CGaussianKernel(10, sq_sigma_twice));
}

mmd->set_kernel_selection_strategy(new CKernelSelectionStrategy(KSM_MEDIAN_HEURISTIC));
mmd->set_kernel_selection_strategy(KSM_MEDIAN_HEURISTIC);
mmd->set_train_test_mode(true);
mmd->select_kernel();
auto selected_kernel=static_cast<CGaussianKernel*>(mmd->get_kernel());
Expand Down

0 comments on commit 501e46f

Please sign in to comment.