Skip to content

Commit

Permalink
made multi-kernel MMD work when the features are updated
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday committed Jul 14, 2016
1 parent fa571a9 commit 2d974ab
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
Expand Up @@ -106,6 +106,12 @@ void CMultiKernelQuadraticTimeMMD::cleanup()
{
ASSERT(self->m_owner);
self->m_kernel_mgr.clear();
invalidate_precomputed_distance();
}

void CMultiKernelQuadraticTimeMMD::invalidate_precomputed_distance()
{
ASSERT(self->m_owner);
self->m_pairwise_distance=nullptr;
self->m_dtype=D_UNKNOWN;
}
Expand Down
5 changes: 3 additions & 2 deletions src/shogun/statistical_testing/MultiKernelQuadraticTimeMMD.h
Expand Up @@ -54,8 +54,8 @@ class MaxMeasure;
* shift-invariant kernels. If the kernels are not shift-invariant, then the
* class CQuadraticTimeMMD should be used multiple times instead of this one.
*
* This implementation assumes that features are never updated. If new features
* are to be used, new instance of this class should be created.
* If the features are updated, then (if any) existing precomputed distance
* instance has to be invalidated by the owner (CQuadraticTimeMMD instance).
*/
class CMultiKernelQuadraticTimeMMD : public CSGObject
{
Expand All @@ -80,6 +80,7 @@ class CMultiKernelQuadraticTimeMMD : public CSGObject
private:
struct Self;
std::unique_ptr<Self> self;
void invalidate_precomputed_distance();
SGVector<float64_t> statistic(const internal::KernelManager& kernel_mgr);
SGMatrix<float32_t> sample_null(const internal::KernelManager& kernel_mgr);
SGVector<float64_t> p_values(const internal::KernelManager& kernel_mgr);
Expand Down
2 changes: 2 additions & 0 deletions src/shogun/statistical_testing/QuadraticTimeMMD.cpp
Expand Up @@ -193,6 +193,7 @@ void CQuadraticTimeMMD::set_p(CFeatures* samples_from_p)
CTwoDistributionTest::set_p(samples_from_p);
get_kernel_mgr().restore_kernel_at(0);
self->is_kernel_initialized=false;
self->multi_kernel->invalidate_precomputed_distance();

if (get_kernel() && get_kernel()->get_kernel_type()==K_CUSTOM)
{
Expand All @@ -214,6 +215,7 @@ void CQuadraticTimeMMD::set_q(CFeatures* samples_from_q)
CTwoDistributionTest::set_q(samples_from_q);
get_kernel_mgr().restore_kernel_at(0);
self->is_kernel_initialized=false;
self->multi_kernel->invalidate_precomputed_distance();

if (get_kernel() && get_kernel()->get_kernel_type()==K_CUSTOM)
{
Expand Down

0 comments on commit 2d974ab

Please sign in to comment.