Skip to content

Commit

Permalink
added weighted kernel learning methods
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 1, 2016
1 parent 4d1d131 commit d120e78
Show file tree
Hide file tree
Showing 14 changed files with 557 additions and 56 deletions.
34 changes: 23 additions & 11 deletions src/shogun/statistical_testing/MMD.cpp
Expand Up @@ -46,6 +46,8 @@
#include <shogun/statistical_testing/internals/ComputationManager.h>
#include <shogun/statistical_testing/internals/MaxMeasure.h>
#include <shogun/statistical_testing/internals/MaxTestPower.h>
#include <shogun/statistical_testing/internals/WeightedMaxMeasure.h>
#include <shogun/statistical_testing/internals/WeightedMaxTestPower.h>
#include <shogun/statistical_testing/internals/mmd/BiasedFull.h>
#include <shogun/statistical_testing/internals/mmd/UnbiasedFull.h>
#include <shogun/statistical_testing/internals/mmd/UnbiasedIncomplete.h>
Expand Down Expand Up @@ -388,31 +390,36 @@ void CMMD::add_kernel(CKernel* kernel)
self->kernel_selection_mgr.push_back(kernel);
}

void CMMD::select_kernel(EKernelSelectionMethod kmethod)
void CMMD::select_kernel(EKernelSelectionMethod kmethod, bool weighted_kernel)
{
SG_DEBUG("Entering!\n");
SG_DEBUG("Selecting kernels from a total of %d kernels!\n", self->kernel_selection_mgr.num_kernels());
std::shared_ptr<KernelSelection> policy=nullptr;
switch (kmethod)
{
case EKernelSelectionMethod::MAXIMIZE_MMD:
{
MaxMeasure policy(self->kernel_selection_mgr, this);
get_kernel_manager().kernel_at(0)=policy.select_kernel();
get_kernel_manager().restore_kernel_at(0);
if (weighted_kernel)
policy=std::shared_ptr<WeightedMaxMeasure>(new WeightedMaxMeasure(self->kernel_selection_mgr, this));
else
policy=std::shared_ptr<MaxMeasure>(new MaxMeasure(self->kernel_selection_mgr, this));
break;
}
case EKernelSelectionMethod::MAXIMIZE_POWER:
{
MaxTestPower policy(self->kernel_selection_mgr, this);
get_kernel_manager().kernel_at(0)=policy.select_kernel();
get_kernel_manager().restore_kernel_at(0);
if (weighted_kernel)
policy=std::shared_ptr<WeightedMaxTestPower>(new WeightedMaxTestPower(self->kernel_selection_mgr, this));
else
policy=std::shared_ptr<MaxTestPower>(new MaxTestPower(self->kernel_selection_mgr, this));
break;
}
default:
SG_ERROR("Unsupported kernel selection method specified! "
"Presently only accepted values are MAXIMIZE_MMD, MAXIMIZE_POWER!\n");
break;
}
if (policy!=nullptr)
{
auto& km=get_kernel_manager();
km.kernel_at(0)=policy->select_kernel();
km.restore_kernel_at(0);
}
SG_DEBUG("Leaving!\n");
}

Expand All @@ -431,6 +438,11 @@ std::pair<float64_t, float64_t> CMMD::compute_statistic_variance()
return self->compute_statistic_variance();
}

std::pair<SGVector<float64_t>, SGMatrix<float64_t>> CMMD::compute_statistic_and_Q()
{
return self->compute_statistic_and_Q();
}

SGVector<float64_t> CMMD::sample_null()
{
return self->sample_null();
Expand Down
9 changes: 7 additions & 2 deletions src/shogun/statistical_testing/MMD.h
Expand Up @@ -47,6 +47,7 @@ namespace internal
{

class MaxTestPower;
class WeightedMaxTestPower;

}

Expand Down Expand Up @@ -75,19 +76,22 @@ enum class EKernelSelectionMethod
{
MEDIAN_HEURISRIC,
MAXIMIZE_MMD,
MAXIMIZE_POWER
MAXIMIZE_POWER,
MAXIMIZE_XVALIDATION,
AUTO
};

class CMMD : public CTwoSampleTest
{
using operation=std::function<float32_t(SGMatrix<float32_t>)>;
friend class internal::MaxTestPower;
friend class internal::WeightedMaxTestPower;
public:
CMMD();
virtual ~CMMD();

void add_kernel(CKernel *kernel);
void select_kernel(EKernelSelectionMethod kmethod);
void select_kernel(EKernelSelectionMethod kmethod=EKernelSelectionMethod::AUTO, bool weighted_kernel=false);

virtual float64_t compute_statistic() override;
virtual float64_t compute_variance();
Expand Down Expand Up @@ -119,6 +123,7 @@ class CMMD : public CTwoSampleTest
struct Self;
std::unique_ptr<Self> self;
virtual std::pair<float64_t, float64_t> compute_statistic_variance();
virtual std::pair<SGVector<float64_t>, SGMatrix<float64_t>> compute_statistic_and_Q();
};

}
Expand Down
26 changes: 4 additions & 22 deletions src/shogun/statistical_testing/internals/KernelSelection.cpp
Expand Up @@ -29,34 +29,16 @@
* either expressed or implied, of the Shogun Development Team.
*/

#include <shogun/io/SGIO.h>
#include <shogun/kernel/Kernel.h>
#include <shogun/statistical_testing/internals/KernelManager.h>
#include <shogun/statistical_testing/internals/KernelSelection.h>
#include <shogun/statistical_testing/internals/MaxMeasure.h>
#include <shogun/statistical_testing/internals/MaxTestPower.h>

namespace shogun
{

namespace internal
{
using namespace shogun;
using namespace internal;

template <class Derived>
KernelSelection<Derived>::KernelSelection(KernelManager& km) : kernel_mgr(km)
KernelSelection::KernelSelection(KernelManager& km) : kernel_mgr(km)
{
SG_SDEBUG("Kernel selection instance initialized!\n");
}

template <class Derived>
CKernel* KernelSelection<Derived>::select_kernel()
KernelSelection::~KernelSelection()
{
return static_cast<Derived*>(this)->select_kernel();
}

template class KernelSelection<MaxMeasure>;
template class KernelSelection<MaxTestPower>;

}

}
4 changes: 2 additions & 2 deletions src/shogun/statistical_testing/internals/KernelSelection.h
Expand Up @@ -44,14 +44,14 @@ namespace internal

class KernelManager;

template <class KernelSelectionPolicy>
class KernelSelection
{
public:
explicit KernelSelection(KernelManager&);
KernelSelection(const KernelSelection& other)=delete;
virtual ~KernelSelection();
KernelSelection& operator=(const KernelSelection& other)=delete;
CKernel* select_kernel();
virtual CKernel* select_kernel()=0;
protected:
const KernelManager& kernel_mgr;
};
Expand Down
21 changes: 12 additions & 9 deletions src/shogun/statistical_testing/internals/MaxMeasure.cpp
Expand Up @@ -39,16 +39,24 @@
using namespace shogun;
using namespace internal;

MaxMeasure::MaxMeasure(KernelManager& km, CMMD* est)
: KernelSelection(km), estimator(est)
MaxMeasure::MaxMeasure(KernelManager& km, CMMD* est) : KernelSelection(km), estimator(est)
{
}

MaxMeasure::~MaxMeasure()
{
}

SGVector<float64_t> MaxMeasure::compute_measures()
{
SGVector<float64_t> result(kernel_mgr.num_kernels());
REQUIRE(estimator!=nullptr, "Estimator is not set!\n");

const size_t num_kernels=kernel_mgr.num_kernels();
REQUIRE(num_kernels>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels());

SGVector<float64_t> result(num_kernels);
auto existing_kernel=estimator->get_kernel();
for (size_t i=0; i<kernel_mgr.num_kernels(); ++i)
for (size_t i=0; i<num_kernels; ++i)
{
auto kernel=kernel_mgr.kernel_at(i);
estimator->set_kernel(kernel);
Expand All @@ -61,14 +69,9 @@ SGVector<float64_t> MaxMeasure::compute_measures()

CKernel* MaxMeasure::select_kernel()
{
REQUIRE(estimator!=nullptr, "Estimator is not set!\n");
REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels());

SGVector<float64_t> measures=compute_measures();
auto max_element=std::max_element(measures.vector, measures.vector+measures.vlen);
auto max_idx=std::distance(measures.vector, max_element);

SG_SDEBUG("Selected kernel at %d position!\n", max_idx);

return kernel_mgr.kernel_at(max_idx);
}
5 changes: 3 additions & 2 deletions src/shogun/statistical_testing/internals/MaxMeasure.h
Expand Up @@ -45,13 +45,14 @@ template <typename T> class SGVector;
namespace internal
{

class MaxMeasure : public KernelSelection<MaxMeasure>
class MaxMeasure : public KernelSelection
{
public:
MaxMeasure(KernelManager&, CMMD*);
MaxMeasure(const MaxMeasure& other)=delete;
~MaxMeasure();
MaxMeasure& operator=(const MaxMeasure& other)=delete;
CKernel* select_kernel();
virtual CKernel* select_kernel() override;
protected:
SGVector<float64_t> compute_measures();
CMMD* estimator;
Expand Down
15 changes: 9 additions & 6 deletions src/shogun/statistical_testing/internals/MaxTestPower.cpp
Expand Up @@ -40,13 +40,19 @@
using namespace shogun;
using namespace internal;

MaxTestPower::MaxTestPower(KernelManager& km, CMMD* est)
: KernelSelection<MaxTestPower>(km), estimator(est), lambda(1E-5)
MaxTestPower::MaxTestPower(KernelManager& km, CMMD* est) : KernelSelection(km), estimator(est), lambda(1E-5)
{
}

MaxTestPower::~MaxTestPower()
{
}

SGVector<float64_t> MaxTestPower::compute_measures()
{
REQUIRE(estimator!=nullptr, "Estimator is not set!\n");
REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels());

SGVector<float64_t> result(kernel_mgr.num_kernels());
for (size_t i=0; i<kernel_mgr.num_kernels(); ++i)
{
Expand All @@ -61,12 +67,9 @@ SGVector<float64_t> MaxTestPower::compute_measures()

CKernel* MaxTestPower::select_kernel()
{
REQUIRE(estimator!=nullptr, "Estimator is not set!\n");
REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels is %d!\n", kernel_mgr.num_kernels());

SGVector<float64_t> measures=compute_measures();
auto max_element=std::max_element(measures.vector, measures.vector+measures.vlen);
auto max_idx=std::distance(measures.vector, max_element);

SG_SDEBUG("Selected kernel at %d position!\n", max_idx);
return kernel_mgr.kernel_at(max_idx);
}
5 changes: 3 additions & 2 deletions src/shogun/statistical_testing/internals/MaxTestPower.h
Expand Up @@ -45,13 +45,14 @@ template <typename T> class SGVector;
namespace internal
{

class MaxTestPower : public KernelSelection<MaxTestPower>
class MaxTestPower : public KernelSelection
{
public:
MaxTestPower(KernelManager&, CMMD*);
MaxTestPower(const MaxTestPower& other)=delete;
~MaxTestPower();
MaxTestPower& operator=(const MaxTestPower& other)=delete;
CKernel* select_kernel();
virtual CKernel* select_kernel() override;
protected:
SGVector<float64_t> compute_measures();
CMMD* estimator;
Expand Down

0 comments on commit d120e78

Please sign in to comment.