From a5221bd7e849fc994cf34c118abc61a4ada0d586 Mon Sep 17 00:00:00 2001 From: lambday Date: Mon, 25 Jul 2016 12:38:30 +0100 Subject: [PATCH] added option for save/query permutation inds in quadratic time MMD --- .../statistical_testing/QuadraticTimeMMD.cpp | 10 ++++++++++ src/shogun/statistical_testing/QuadraticTimeMMD.h | 2 ++ .../internals/mmd/PermutationMMD.h | 14 ++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/src/shogun/statistical_testing/QuadraticTimeMMD.cpp b/src/shogun/statistical_testing/QuadraticTimeMMD.cpp index a14bbac3909..b1596a29a1a 100644 --- a/src/shogun/statistical_testing/QuadraticTimeMMD.cpp +++ b/src/shogun/statistical_testing/QuadraticTimeMMD.cpp @@ -619,6 +619,16 @@ void CQuadraticTimeMMD::precompute_kernel_matrix(bool precompute) self->precompute=precompute; } +void CQuadraticTimeMMD::save_permutation_inds(bool save_inds) +{ + self->permutation_job.m_save_inds=save_inds; +} + +SGMatrix CQuadraticTimeMMD::get_permutation_inds() const +{ + return self->permutation_job.m_all_inds; +} + const char* CQuadraticTimeMMD::get_name() const { return "QuadraticTimeMMD"; diff --git a/src/shogun/statistical_testing/QuadraticTimeMMD.h b/src/shogun/statistical_testing/QuadraticTimeMMD.h index 544eb9bf56a..01b686fcc88 100644 --- a/src/shogun/statistical_testing/QuadraticTimeMMD.h +++ b/src/shogun/statistical_testing/QuadraticTimeMMD.h @@ -70,6 +70,8 @@ class CQuadraticTimeMMD : public CMMD index_t spectrum_get_num_eigenvalues() const; void precompute_kernel_matrix(bool precompute); + void save_permutation_inds(bool save_inds); + SGMatrix get_permutation_inds() const; virtual const char* get_name() const; diff --git a/src/shogun/statistical_testing/internals/mmd/PermutationMMD.h b/src/shogun/statistical_testing/internals/mmd/PermutationMMD.h index 08ff81e4d1d..be26765d722 100644 --- a/src/shogun/statistical_testing/internals/mmd/PermutationMMD.h +++ b/src/shogun/statistical_testing/internals/mmd/PermutationMMD.h @@ -48,6 +48,10 @@ namespace mmd struct PermutationMMD : ComputeMMD { + PermutationMMD() : m_save_inds(false) + { + } + template SGVector operator()(const Kernel& kernel) { @@ -197,6 +201,11 @@ struct PermutationMMD : ComputeMMD { std::iota(m_permuted_inds.data(), m_permuted_inds.data()+m_permuted_inds.size(), 0); CMath::permute(sg_wrapper); + if (m_save_inds) + { + auto offset=n*m_num_null_samples; + std::copy(sg_wrapper.data(), sg_wrapper.data()+sg_wrapper.size(), &m_all_inds.matrix[offset]); + } for (size_t i=0; i(size, m_num_null_samples); } index_t m_num_null_samples; + bool m_save_inds; std::vector m_permuted_inds; std::vector > m_inverted_permuted_inds; + SGMatrix m_all_inds; }; }