Skip to content

Commit

Permalink
refactored preset distance matrix in multi kernel MMD
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday committed Jun 28, 2016
1 parent 14a8c55 commit bfab28f
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 28 deletions.
28 changes: 24 additions & 4 deletions src/shogun/statistical_testing/QuadraticTimeMMD.cpp
Expand Up @@ -34,7 +34,9 @@
#include <shogun/lib/SGVector.h>
#include <shogun/kernel/Kernel.h>
#include <shogun/kernel/CustomKernel.h>
#include <shogun/kernel/ShiftInvariantKernel.h>
#include <shogun/distance/EuclideanDistance.h>
#include <shogun/distance/ManhattanMetric.h>
#include <shogun/mathematics/eigen3.h>
#include <shogun/mathematics/Statistics.h>
#include <shogun/statistical_testing/QuadraticTimeMMD.h>
Expand Down Expand Up @@ -548,15 +550,33 @@ void CQuadraticTimeMMD::precompute_kernel_matrix(bool precompute)
SGVector<float64_t> CQuadraticTimeMMD::compute_statistic(const internal::KernelManager& kernel_mgr)
{
SG_DEBUG("Entering");
REQUIRE(kernel_mgr.same_distance_type(), "The kernels have to have same distance type!\n");
REQUIRE(kernel_mgr.num_kernels()>0, "Number of kernels (%d) have to be greater than 0!\n", kernel_mgr.num_kernels());
const auto& data_mgr=get_data_mgr();
const index_t nx=data_mgr.num_samples_at(0);
const index_t ny=data_mgr.num_samples_at(1);
MultiKernelMMD compute(nx, ny, get_statistic_type());
// TODO refactor and remove
auto distance=new CEuclideanDistance();
distance->set_disable_sqrt(true);
SG_REF(distance);

CDistance* distance=nullptr;
CShiftInvariantKernel* kernel_0=dynamic_cast<CShiftInvariantKernel*>(kernel_mgr.kernel_at(0));
REQUIRE(kernel_0, "Kernel (%s) must be of CShiftInvariantKernel type!\n", kernel_mgr.kernel_at(0)->get_name());
if (kernel_0->get_distance_type()==D_EUCLIDEAN)
{
auto euclidean_distance=new CEuclideanDistance();
euclidean_distance->set_disable_sqrt(true);
distance=euclidean_distance;
}
else if (kernel_0->get_distance_type()==D_MANHATTAN)
{
auto manhattan_distance=new CManhattanMetric();
distance=manhattan_distance;
}
else
{
SG_ERROR("Unsupported distance type!\n");
}

SG_REF(distance);
compute.set_distance(compute_joint_distance(distance));
SGVector<float64_t> result=compute(kernel_mgr);
SG_UNREF(distance);
Expand Down
32 changes: 32 additions & 0 deletions src/shogun/statistical_testing/internals/KernelManager.cpp
Expand Up @@ -31,8 +31,10 @@
#include <vector>
#include <memory>
#include <shogun/io/SGIO.h>
#include <shogun/distance/Distance.h>
#include <shogun/kernel/Kernel.h>
#include <shogun/kernel/CustomKernel.h>
#include <shogun/kernel/ShiftInvariantKernel.h>
#include <shogun/statistical_testing/internals/KernelManager.h>

using namespace shogun;
Expand Down Expand Up @@ -130,3 +132,33 @@ void KernelManager::restore_kernel_at(size_t i)
m_precomputed_kernels[i]=nullptr;
SG_SDEBUG("Leaving!\n");
}

bool KernelManager::same_distance_type() const
{
bool same=false;
EDistanceType distance_type=D_UNKNOWN;
for (size_t i=0; i<num_kernels(); ++i)
{
CShiftInvariantKernel* shift_invariant_kernel=dynamic_cast<CShiftInvariantKernel*>(kernel_at(i));
if (shift_invariant_kernel!=nullptr)
{
if (distance_type==D_UNKNOWN)
distance_type=shift_invariant_kernel->get_distance_type();
else if (distance_type==shift_invariant_kernel->get_distance_type())
same=true;
else
{
same=false;
break;
}
}
else
{
same=false;
SG_SINFO("Kernel at location %d is not of CShiftInvariantKernel type (was of %s type)!\n",
i, kernel_at(i)->get_name());
break;
}
}
return same;
}
1 change: 1 addition & 0 deletions src/shogun/statistical_testing/internals/KernelManager.h
Expand Up @@ -62,6 +62,7 @@ class KernelManager
void restore_kernel_at(size_t i);

void clear();
bool same_distance_type() const;
private:
std::vector<std::shared_ptr<CKernel> > m_kernels;
std::vector<std::shared_ptr<CCustomKernel> > m_precomputed_kernels;
Expand Down
Expand Up @@ -49,25 +49,3 @@ KernelSelection::KernelSelection(KernelManager& km, CMMD* est) : kernel_mgr(km),
KernelSelection::~KernelSelection()
{
}

bool KernelSelection::same_distance_type() const
{
bool same=false;
EDistanceType distance_type=D_UNKNOWN;
for (size_t i=0; i<kernel_mgr.num_kernels(); ++i)
{
CShiftInvariantKernel* shift_invariant_kernel=dynamic_cast<CShiftInvariantKernel*>(kernel_mgr.kernel_at(i));
if (shift_invariant_kernel!=nullptr)
{
if (distance_type==D_UNKNOWN)
distance_type=shift_invariant_kernel->get_distance_type();
else if (distance_type==shift_invariant_kernel->get_distance_type())
same=true;
else
break;
}
else
break;
}
return same;
}
Expand Up @@ -62,7 +62,6 @@ class KernelSelection
CMMD* estimator;
virtual void init_measures()=0;
virtual void compute_measures()=0;
bool same_distance_type() const;
};

}
Expand Down
Expand Up @@ -73,7 +73,7 @@ void MaxMeasure::compute_measures()
{
REQUIRE(estimator!=nullptr, "Estimator is not set!\n");
CQuadraticTimeMMD* mmd=dynamic_cast<CQuadraticTimeMMD*>(estimator);
if (mmd!=nullptr && same_distance_type())
if (mmd!=nullptr && kernel_mgr.same_distance_type())
measures=mmd->compute_statistic(kernel_mgr);
else
{
Expand Down

0 comments on commit bfab28f

Please sign in to comment.