Skip to content

Commit

Permalink
added option for save/query permutation inds in quadratic time MMD
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday committed Jul 25, 2016
1 parent 888fbb1 commit a5221bd
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/shogun/statistical_testing/QuadraticTimeMMD.cpp
Expand Up @@ -619,6 +619,16 @@ void CQuadraticTimeMMD::precompute_kernel_matrix(bool precompute)
self->precompute=precompute;
}

void CQuadraticTimeMMD::save_permutation_inds(bool save_inds)
{
self->permutation_job.m_save_inds=save_inds;
}

SGMatrix<index_t> CQuadraticTimeMMD::get_permutation_inds() const
{
return self->permutation_job.m_all_inds;
}

const char* CQuadraticTimeMMD::get_name() const
{
return "QuadraticTimeMMD";
Expand Down
2 changes: 2 additions & 0 deletions src/shogun/statistical_testing/QuadraticTimeMMD.h
Expand Up @@ -70,6 +70,8 @@ class CQuadraticTimeMMD : public CMMD
index_t spectrum_get_num_eigenvalues() const;

void precompute_kernel_matrix(bool precompute);
void save_permutation_inds(bool save_inds);
SGMatrix<index_t> get_permutation_inds() const;

virtual const char* get_name() const;

Expand Down
14 changes: 14 additions & 0 deletions src/shogun/statistical_testing/internals/mmd/PermutationMMD.h
Expand Up @@ -48,6 +48,10 @@ namespace mmd

struct PermutationMMD : ComputeMMD
{
PermutationMMD() : m_save_inds(false)
{
}

template <class Kernel>
SGVector<float32_t> operator()(const Kernel& kernel)
{
Expand Down Expand Up @@ -197,6 +201,11 @@ struct PermutationMMD : ComputeMMD
{
std::iota(m_permuted_inds.data(), m_permuted_inds.data()+m_permuted_inds.size(), 0);
CMath::permute(sg_wrapper);
if (m_save_inds)
{
auto offset=n*m_num_null_samples;
std::copy(sg_wrapper.data(), sg_wrapper.data()+sg_wrapper.size(), &m_all_inds.matrix[offset]);
}
for (size_t i=0; i<m_permuted_inds.size(); ++i)
m_inverted_permuted_inds[n][m_permuted_inds[i]]=i;
}
Expand All @@ -223,11 +232,16 @@ struct PermutationMMD : ComputeMMD
if (m_inverted_permuted_inds[i].size()!=size_t(size))
m_inverted_permuted_inds[i].resize(size);
}

if (m_save_inds)
m_all_inds=SGMatrix<index_t>(size, m_num_null_samples);
}

index_t m_num_null_samples;
bool m_save_inds;
std::vector<index_t> m_permuted_inds;
std::vector<std::vector<index_t> > m_inverted_permuted_inds;
SGMatrix<index_t> m_all_inds;
};

}
Expand Down

0 comments on commit a5221bd

Please sign in to comment.