Skip to content

Commit

Permalink
cleaned up and optimized multi-kernel permutation test
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday committed Jul 13, 2016
1 parent 2f8912e commit ada1f1d
Showing 1 changed file with 38 additions and 60 deletions.
Expand Up @@ -28,29 +28,13 @@
* either expressed or implied, of the Shogun Development Team.
*/

#include <shogun/base/some.h>
#include <shogun/io/SGIO.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/mathematics/Math.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/lib/SGVector.h>
#include <shogun/kernel/Kernel.h>
#include <shogun/features/SubsetStack.h>
#include <shogun/distance/CustomDistance.h>
#include <shogun/statistical_testing/MMD.h>
#include <shogun/statistical_testing/internals/KernelManager.h>
#include <shogun/statistical_testing/internals/mmd/BiasedFull.h>
#include <shogun/statistical_testing/internals/mmd/MultiKernelPermutationTest.h>

// TODO remove
#include <shogun/mathematics/eigen3.h>
#include <iostream>

using Eigen::MatrixXd;
using Eigen::Map;
using std::cout;
using std::endl;
// TODO remove

using namespace shogun;
using namespace internal;
using namespace mmd;
Expand Down Expand Up @@ -150,59 +134,53 @@ SGVector<bool> MultiKernelPermutationTest::operator()(const KernelManager& kerne
inverted_permuted_inds[n][permuted_inds[i]]=i;
}

SGVector<float64_t> statistic(kernel_mgr.num_kernels());
std::vector<terms_t> stat_terms(kernel_mgr.num_kernels());
for (auto col=0; col<n_x+n_y; ++col)
SGVector<float64_t> null_samples(num_null_samples);
SGVector<bool> result(kernel_mgr.num_kernels());

const index_t size=n_x+n_y;
SGVector<float32_t> km(size*(size+1)/2);
#pragma omp parallel
{
for (auto row=col; row<n_x+n_y; ++row)
for (size_t k=0; k<kernel_mgr.num_kernels(); ++k)
{
for (size_t k=0; k<kernel_mgr.num_kernels(); ++k)
auto kernel=kernel_mgr.kernel_at(k);
terms_t stat_terms;
for (auto row=0; row<size; ++row)
{
float64_t kernel_value=kernel_mgr.kernel_at(k)->kernel(row, col);
add_term(stat_terms[k], kernel_value, row, col);
for (auto col=row; col<size; ++col)
{
auto index=row*size-row*(row+1)/2+col;
km[index]=kernel->kernel(row, col);
add_term(stat_terms, km[index], row, col);
}
}
}
}
for (size_t k=0; k<kernel_mgr.num_kernels(); ++k)
statistic[k]=compute_mmd(stat_terms[k]);
stat_terms.resize(0);
auto statistic=compute_mmd(stat_terms);
SG_SDEBUG("Kernel(%d): statistic=%f\n", k, statistic);

SGMatrix<float64_t> null_samples(kernel_mgr.num_kernels(), num_null_samples);
Map<MatrixXd> null_samples_map(null_samples.data(), null_samples.num_rows, null_samples.num_cols);
#pragma omp parallel for
for (auto n=0; n<num_null_samples; ++n)
{
std::vector<terms_t> terms(kernel_mgr.num_kernels());
for (auto col=0; col<n_x+n_y; ++col)
{
for (auto row=col; row<n_x+n_y; ++row)
#pragma omp for
for (auto n=0; n<num_null_samples; ++n)
{
auto row_inds_inv=inverted_permuted_inds[n][row];
auto col_inds_inv=inverted_permuted_inds[n][col];
for (size_t k=0; k<kernel_mgr.num_kernels(); ++k)
terms_t null_sample_terms;
for (auto row=0; row<size; ++row)
{
float64_t kernel_value=kernel_mgr.kernel_at(k)->kernel(row, col);
add_term(terms[k], kernel_value, row_inds_inv, col_inds_inv);
for (auto col=row; col<size; ++col)
{
float64_t kernel_value=km[row*size-row*(row+1)/2+col];
auto row_inds_inv=inverted_permuted_inds[n][row];
auto col_inds_inv=inverted_permuted_inds[n][col];
add_term(null_sample_terms, kernel_value, row_inds_inv, col_inds_inv);
}
}
null_samples[n]=compute_mmd(null_sample_terms);
}
std::sort(null_samples.data(), null_samples.data()+null_samples.size());
float64_t idx=null_samples.find_position_to_insert(statistic);
SG_SDEBUG("Kernel(%d): index=%f\n", k, idx);
auto p_value=1.0-idx/num_null_samples;
bool rejected=p_value<alpha;
SG_SDEBUG("Kernel(%d): p-value=%f, alpha=%f, rejected=%d\n", k, p_value, alpha, rejected);
result[k]=rejected;
}
for (size_t k=0; k<kernel_mgr.num_kernels(); ++k)
null_samples(k, n)=compute_mmd(terms[k]);
}

SGVector<bool> result(kernel_mgr.num_kernels());
MatrixXd transposed_null_samples=null_samples_map.transpose();
for (size_t k=0; k<kernel_mgr.num_kernels(); ++k)
{
SGVector<float64_t> null_samples_k(transposed_null_samples.col(k).data(), num_null_samples, false);
std::sort(null_samples_k.data(), null_samples_k.data()+null_samples_k.size());
SG_SDEBUG("statistic=%f\n", statistic[k]);
float64_t idx=null_samples_k.find_position_to_insert(statistic[k]);
SG_SDEBUG("index=%f\n", idx);
auto p_value=1.0-idx/num_null_samples;
bool rejected=p_value<alpha;
SG_SDEBUG("p-value=%f, alpha=%f, rejected=%d\n", p_value, alpha, rejected);
result[k]=rejected;
}

SG_SDEBUG("Leaving!\n");
Expand Down

0 comments on commit ada1f1d

Please sign in to comment.