Skip to content

Commit

Permalink
added median heuristic for kernel selection for quadratic time mmd
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday committed Jul 13, 2016
1 parent d941c9f commit 7494caa
Show file tree
Hide file tree
Showing 11 changed files with 313 additions and 6 deletions.
7 changes: 7 additions & 0 deletions src/shogun/statistical_testing/BTestMMD.cpp
Expand Up @@ -19,6 +19,7 @@
#include <shogun/lib/SGMatrix.h>
#include <shogun/mathematics/Math.h>
#include <shogun/mathematics/Statistics.h>
#include <shogun/distance/CustomDistance.h>
#include <shogun/statistical_testing/BTestMMD.h>
#include <shogun/statistical_testing/internals/DataManager.h>
#include <shogun/statistical_testing/internals/mmd/WithinBlockDirect.h>
Expand Down Expand Up @@ -109,6 +110,12 @@ float64_t CBTestMMD::compute_threshold(float64_t alpha)
return result;
}

/**
 * Provides the distance instance used by distance-based kernel selection.
 *
 * This implementation does not compute any pairwise distances; it only
 * hands back a freshly default-constructed CCustomDistance.
 *
 * @return shared pointer owning a default-constructed CCustomDistance
 */
std::shared_ptr<CCustomDistance> CBTestMMD::compute_distance()
{
	return std::shared_ptr<CCustomDistance>(new CCustomDistance());
}

const char* CBTestMMD::get_name() const
{
return "BTestMMD";
Expand Down
1 change: 1 addition & 0 deletions src/shogun/statistical_testing/BTestMMD.h
Expand Up @@ -42,6 +42,7 @@ class CBTestMMD : public CMMD
virtual const operation get_direct_estimation_method() const override;
virtual const float64_t normalize_statistic(float64_t statistic) const override;
virtual const float64_t normalize_variance(float64_t variance) const override;
virtual std::shared_ptr<CCustomDistance> compute_distance() override;
};

}
Expand Down
7 changes: 7 additions & 0 deletions src/shogun/statistical_testing/LinearTimeMMD.cpp
Expand Up @@ -20,6 +20,7 @@
#include <shogun/lib/SGMatrix.h>
#include <shogun/mathematics/Math.h>
#include <shogun/mathematics/Statistics.h>
#include <shogun/distance/CustomDistance.h>
#include <shogun/statistical_testing/LinearTimeMMD.h>
#include <shogun/statistical_testing/internals/DataManager.h>
#include <shogun/statistical_testing/internals/mmd/WithinBlockDirect.h>
Expand Down Expand Up @@ -144,6 +145,12 @@ float64_t CLinearTimeMMD::compute_threshold(float64_t alpha)
return result;
}

/**
 * Provides the distance instance used by distance-based kernel selection.
 *
 * This implementation does not compute any pairwise distances; it only
 * hands back a freshly default-constructed CCustomDistance.
 *
 * @return shared pointer owning a default-constructed CCustomDistance
 */
std::shared_ptr<CCustomDistance> CLinearTimeMMD::compute_distance()
{
	return std::shared_ptr<CCustomDistance>(new CCustomDistance());
}

const char* CLinearTimeMMD::get_name() const
{
return "LinearTimeMMD";
Expand Down
1 change: 1 addition & 0 deletions src/shogun/statistical_testing/LinearTimeMMD.h
Expand Up @@ -42,6 +42,7 @@ class CLinearTimeMMD : public CMMD
virtual const operation get_direct_estimation_method() const override;
virtual const float64_t normalize_statistic(float64_t statistic) const override;
virtual const float64_t normalize_variance(float64_t variance) const override;
virtual std::shared_ptr<CCustomDistance> compute_distance() override;
const float64_t gaussian_variance(float64_t variance) const;
};

Expand Down
11 changes: 9 additions & 2 deletions src/shogun/statistical_testing/MMD.cpp
Expand Up @@ -46,6 +46,7 @@
#include <shogun/statistical_testing/internals/ComputationManager.h>
#include <shogun/statistical_testing/internals/MaxMeasure.h>
#include <shogun/statistical_testing/internals/MaxTestPower.h>
#include <shogun/statistical_testing/internals/MedianHeuristic.h>
#include <shogun/statistical_testing/internals/WeightedMaxMeasure.h>
#include <shogun/statistical_testing/internals/WeightedMaxTestPower.h>
#include <shogun/statistical_testing/internals/mmd/BiasedFull.h>
Expand Down Expand Up @@ -390,7 +391,6 @@ CMMD::CMMD() : CTwoSampleTest()
Eigen::initParallel();
#endif
self=std::unique_ptr<Self>(new Self(*this));
Eigen::initParallel();
}

CMMD::~CMMD()
Expand Down Expand Up @@ -421,9 +421,16 @@ void CMMD::select_kernel(EKernelSelectionMethod kmethod, bool weighted_kernel)
else
policy=std::shared_ptr<MaxTestPower>(new MaxTestPower(self->kernel_selection_mgr, this));
break;
case EKernelSelectionMethod::MEDIAN_HEURISTIC:
{
REQUIRE(!weighted_kernel, "Weighted kernel selection is not possible with MEDIAN_HEURISTIC!\n");
auto distance=compute_distance();
policy=std::shared_ptr<MedianHeuristic>(new MedianHeuristic(self->kernel_selection_mgr, distance));
}
break;
default:
SG_ERROR("Unsupported kernel selection method specified! "
"Presently only accepted values are MAXIMIZE_MMD, MAXIMIZE_POWER!\n");
"Presently only accepted values are MAXIMIZE_MMD, MAXIMIZE_POWER and MEDIAN_HEURISTIC!\n");
break;
}
if (policy!=nullptr)
Expand Down
10 changes: 6 additions & 4 deletions src/shogun/statistical_testing/MMD.h
Expand Up @@ -40,6 +40,7 @@ namespace shogun
{

class CKernel;
class CCustomDistance;
template <typename> class SGVector;
template <typename> class SGMatrix;

Expand Down Expand Up @@ -74,7 +75,7 @@ enum class ENullApproximationMethod

enum class EKernelSelectionMethod
{
MEDIAN_HEURISRIC,
MEDIAN_HEURISTIC,
MAXIMIZE_MMD,
MAXIMIZE_POWER,
MAXIMIZE_XVALIDATION,
Expand Down Expand Up @@ -115,9 +116,10 @@ class CMMD : public CTwoSampleTest

virtual const char* get_name() const;
protected:
virtual const operation get_direct_estimation_method() const = 0;
virtual const float64_t normalize_statistic(float64_t statistic) const = 0;
virtual const float64_t normalize_variance(float64_t variance) const = 0;
virtual const operation get_direct_estimation_method() const=0;
virtual const float64_t normalize_statistic(float64_t statistic) const=0;
virtual const float64_t normalize_variance(float64_t variance) const=0;
virtual std::shared_ptr<CCustomDistance> compute_distance()=0;
bool use_gpu() const;
private:
struct Self;
Expand Down
52 changes: 52 additions & 0 deletions src/shogun/statistical_testing/QuadraticTimeMMD.cpp
Expand Up @@ -36,6 +36,8 @@
#include <shogun/kernel/CustomKernel.h>
#include <shogun/mathematics/eigen3.h>
#include <shogun/mathematics/Statistics.h>
#include <shogun/distance/EuclideanDistance.h>
#include <shogun/distance/CustomDistance.h>
#include <shogun/statistical_testing/QuadraticTimeMMD.h>
#include <shogun/statistical_testing/internals/FeaturesUtil.h>
#include <shogun/statistical_testing/internals/NextSamples.h>
Expand Down Expand Up @@ -490,6 +492,56 @@ SGVector<float64_t> CQuadraticTimeMMD::spectrum_sample_null()
return null_samples;
}

/**
 * Builds a custom distance holding the pairwise Euclidean distances between
 * all samples of both distributions.
 *
 * One burst of samples is fetched through the data manager, the 0th block of
 * each distribution is merged into a single feature object, the Euclidean
 * distance matrix of the merged features against themselves is computed (in
 * float32_t) and stored into the returned CCustomDistance.
 *
 * Failures (no samples fetched, distance initialisation failing, or a
 * ShogunException raised while merging/computing) are reported via SG_SERROR.
 *
 * @return shared pointer to the CCustomDistance; it is only filled with the
 * distance matrix if every step above succeeded
 */
std::shared_ptr<CCustomDistance> CQuadraticTimeMMD::compute_distance()
{
	auto distance=std::shared_ptr<CCustomDistance>(new CCustomDistance());
	DataManager& dm=get_data_manager();

	// using data manager next() API in order to make it work with
	// streaming samples as well.
	dm.start();
	auto samples=dm.next();
	if (!samples.empty())
	{
		dm.end();

		// use 0th block from each distribution (since there is only one block
		// for quadratic time MMD)
		CFeatures *samples_p=samples[0][0].get();
		CFeatures *samples_q=samples[1][0].get();

		try
		{
			auto p_and_q=FeaturesUtil::create_merged_copy(samples_p, samples_q);
			// the merged copy is all that is needed from here on, so the
			// fetched blocks are released before the distance matrix is built
			samples.clear();
			auto euclidean_distance=std::unique_ptr<CEuclideanDistance>(new CEuclideanDistance());
			if (euclidean_distance->init(p_and_q, p_and_q))
			{
				auto dist_mat=euclidean_distance->get_distance_matrix<float32_t>();
				if (io->get_loglevel()==MSG_DEBUG)
				{
					dist_mat.display_matrix("distance_matrix");
				}
				distance->set_triangle_distance_matrix_from_full(dist_mat.data(), dist_mat.num_rows, dist_mat.num_cols);
			}
			else
			{
				SG_SERROR("Computing distance matrix was not possible! Please contact Shogun developers.\n");
			}
		}
		// Caught by reference so the exception object is neither copied nor
		// sliced. The reference is non-const so that get_exception_string()
		// can be called regardless of its const-qualification.
		// NOTE(review): if SG_SERROR itself raises a ShogunException, the
		// init-failure message above would land here and be re-reported with
		// the "Data is too large" text - confirm and narrow the try scope if so.
		catch (ShogunException& e)
		{
			SG_SERROR("%s, Data is too large! Computing distance matrix was not possible!\n", e.get_exception_string());
		}
	}
	else
	{
		dm.end();
		SG_SERROR("Could not fetch samples!\n");
	}
	return distance;
}

const char* CQuadraticTimeMMD::get_name() const
{
return "QuadraticTimeMMD";
Expand Down
1 change: 1 addition & 0 deletions src/shogun/statistical_testing/QuadraticTimeMMD.h
Expand Up @@ -67,6 +67,7 @@ class CQuadraticTimeMMD : public CMMD
virtual const operation get_direct_estimation_method() const override;
virtual const float64_t normalize_statistic(float64_t statistic) const override;
virtual const float64_t normalize_variance(float64_t variance) const override;
virtual std::shared_ptr<CCustomDistance> compute_distance() override;
SGVector<float64_t> gamma_fit_null();
SGVector<float64_t> spectrum_sample_null();
};
Expand Down
89 changes: 89 additions & 0 deletions src/shogun/statistical_testing/internals/MedianHeuristic.cpp
@@ -0,0 +1,89 @@
/*
* Copyright (c) The Shogun Machine Learning Toolbox
* Written (W) 2013 Heiko Strathmann
* Written (w) 2014 - 2016 Soumyajit De
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* The views and conclusions contained in the software and documentation are those
* of the authors and should not be interpreted as representing official policies,
* either expressed or implied, of the Shogun Development Team.
*/

#include <vector>
#include <iostream> // TODO remove
#include <algorithm>
#include <shogun/io/SGIO.h>
#include <shogun/kernel/Kernel.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/distance/CustomDistance.h>
#include <shogun/statistical_testing/internals/MedianHeuristic.h>
#include <shogun/statistical_testing/internals/KernelManager.h>

using namespace shogun;
using namespace internal;

/**
 * Sets up median-heuristic kernel selection.
 *
 * Validates that the supplied distance is square (same number of left-hand
 * and right-hand side vectors) and that every kernel registered with the
 * kernel manager reports the type K_GAUSSIAN.
 *
 * @param km kernel manager holding the candidate kernels
 * @param dist pairwise distances; it is dereferenced here, so it must not
 * be null
 */
MedianHeuristic::MedianHeuristic(KernelManager& km, std::shared_ptr<CCustomDistance> dist)
	: KernelSelection(km), distance(dist), n(dist->get_num_vec_lhs())
{
	REQUIRE(distance->get_num_vec_lhs()==distance->get_num_vec_rhs(),
		"Distance matrix is supposed to be a square matrix (was of dimension %dX%d)!\n",
		distance->get_num_vec_lhs(), distance->get_num_vec_rhs());

	// every candidate must be a Gaussian kernel, since selection later
	// compares kernel widths
	const auto num_candidates=kernel_mgr.num_kernels();
	for (size_t idx=0; idx<num_candidates; ++idx)
	{
		auto* candidate=kernel_mgr.kernel_at(idx);
		REQUIRE(candidate->get_kernel_type()==K_GAUSSIAN,
			"The underlying kernel has to be a GaussianKernel (was %s)!\n",
			candidate->get_name());
	}
}

/** Nothing to release explicitly; members clean up after themselves. */
MedianHeuristic::~MedianHeuristic()=default;

/**
 * Selects the kernel whose width is closest to the median pairwise distance.
 *
 * The distances of all unordered sample pairs (i, j) with i > j are
 * collected, their median (the element at position size/2 in sorted order)
 * is determined, and the Gaussian kernel whose get_width() has the smallest
 * absolute difference to that median is returned.
 *
 * Reports an error via REQUIRE if there are fewer than two samples (no
 * pairwise distance exists) or if the kernel manager holds no kernels.
 *
 * @return the selected kernel, as stored in the kernel manager
 */
CKernel* MedianHeuristic::select_kernel()
{
	REQUIRE(n>1, "At least two samples are required for computing the median distance (had %d)!\n", n);
	const size_t num_kernels=kernel_mgr.num_kernels();
	REQUIRE(num_kernels>0, "No kernel is available for selection!\n");

	// pair count is computed in size_t: n*(n-1) can overflow int32_t
	const size_t num_samples=static_cast<size_t>(n);
	std::vector<float64_t> pairwise((num_samples*(num_samples-1))/2);
	size_t write_idx=0;
	for (int32_t j=0; j<n; ++j)
	{
		for (int32_t i=j+1; i<n; ++i)
			pairwise[write_idx++]=distance->distance(i, j);
	}

	// only the element at the median position is needed, so a partial
	// ordering suffices; it yields the same value a full sort would
	auto median_it=pairwise.begin()+pairwise.size()/2;
	std::nth_element(pairwise.begin(), median_it, pairwise.end());
	const float64_t median_distance=*median_it;
	SG_SDEBUG("kernel width (shogun): %f\n", median_distance);

	std::vector<float64_t> width_gaps(num_kernels);
	for (size_t i=0; i<num_kernels; ++i)
	{
		CGaussianKernel *kernel=dynamic_cast<CGaussianKernel*>(kernel_mgr.kernel_at(i));
		ASSERT(kernel!=nullptr);
		width_gaps[i]=CMath::abs(kernel->get_width()-median_distance);
	}

	const size_t kernel_idx=std::distance(width_gaps.begin(), std::min_element(width_gaps.begin(), width_gaps.end()));
	// cast so the argument matches the %d conversion (kernel_idx is size_t)
	SG_SDEBUG("Selected kernel at %d position!\n", static_cast<int32_t>(kernel_idx));
	return kernel_mgr.kernel_at(kernel_idx);
}
67 changes: 67 additions & 0 deletions src/shogun/statistical_testing/internals/MedianHeuristic.h
@@ -0,0 +1,67 @@
/*
* Copyright (c) The Shogun Machine Learning Toolbox
* Written (W) 2013 Heiko Strathmann
* Written (w) 2014 - 2016 Soumyajit De
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* The views and conclusions contained in the software and documentation are those
* of the authors and should not be interpreted as representing official policies,
* either expressed or implied, of the Shogun Development Team.
*/

#ifndef MEDIAN_HEURISTIC_H__
#define MEDIAN_HEURISTIC_H__

#include <memory>
#include <shogun/lib/common.h>
#include <shogun/statistical_testing/internals/KernelSelection.h>

namespace shogun
{

class CKernel;
class CMMD;
class CCustomDistance;
template <typename T> class SGVector;

namespace internal
{

/**
 * Kernel selection policy that picks, among the kernels held by a kernel
 * manager, the one whose width is closest to the median of the pairwise
 * distances supplied at construction. Instances are not copyable.
 */
class MedianHeuristic : public KernelSelection
{
public:
	/**
	 * @param km kernel manager holding the candidate kernels
	 * @param dist pairwise distances between the samples
	 */
	MedianHeuristic(KernelManager& km, std::shared_ptr<CCustomDistance> dist);
	~MedianHeuristic();

	// copying is disabled
	MedianHeuristic(const MedianHeuristic& other)=delete;
	MedianHeuristic& operator=(const MedianHeuristic& other)=delete;

	/** @return the kernel chosen by the median heuristic */
	virtual CKernel* select_kernel() override;
protected:
	/** pairwise distances the median is taken from */
	std::shared_ptr<CCustomDistance> distance;
	/** number of left-hand side vectors of the distance, fixed at construction */
	const int32_t n;
};

}

}

#endif // MEDIAN_HEURISTIC_H__

0 comments on commit 7494caa

Please sign in to comment.