Skip to content

Commit

Permalink
Introduced Kernel Mean Matching
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Aug 4, 2012
1 parent a79170d commit a686d8c
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/interfaces/modular/Statistics.i
Expand Up @@ -15,6 +15,7 @@
%rename(QuadraticTimeMMD) CQuadraticTimeMMD;
%rename(KernelIndependenceTestStatistic) CKernelIndependenceTestStatistic;
%rename(HSIC) CHSIC;
%rename(KernelMeanMatching) CKernelMeanMatching;

/* Include Class Headers to make them visible from within the target language */
%include <shogun/statistics/TestStatistic.h>
Expand All @@ -24,3 +25,4 @@
%include <shogun/statistics/QuadraticTimeMMD.h>
%include <shogun/statistics/KernelIndependenceTestStatistic.h>
%include <shogun/statistics/HSIC.h>
%include <shogun/statistics/KernelMeanMatching.h>
1 change: 1 addition & 0 deletions src/interfaces/modular/Statistics_includes.i
Expand Up @@ -6,5 +6,6 @@
#include <shogun/statistics/QuadraticTimeMMD.h>
#include <shogun/statistics/KernelIndependenceTestStatistic.h>
#include <shogun/statistics/HSIC.h>
#include <shogun/statistics/KernelMeanMatching.h>
%}

6 changes: 6 additions & 0 deletions src/shogun/lib/SGVector.h
Expand Up @@ -37,6 +37,12 @@ template<class T> class SGVector : public SGReferencedData
/** empty destructor */
virtual ~SGVector();

/** size */
inline int32_t size() const { return vlen; };

/** cast to pointer */
operator T*() { return vector; };

/** fill vector with zeros */
void zero();

Expand Down
3 changes: 2 additions & 1 deletion src/shogun/lib/external/libqp_gsmo.cpp
Expand Up @@ -120,8 +120,9 @@ libqp_state_T libqp_gsmo_solver(const float64_t* (*get_col)(uint32_t),
// check equality constraint
for (i=0; i<n; i++)
atx += a[i]*x[i];
if (b != atx)
if (fabs(b-atx)>1e-9)
{
printf("%f \ne %f\n",b,atx);
state.exitflag = -3;
goto cleanup;
}
Expand Down
105 changes: 105 additions & 0 deletions src/shogun/statistics/KernelMeanMatching.cpp
@@ -0,0 +1,105 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Copyright (W) 2012 Sergey Lisitsyn
*/

#include <shogun/statistics/KernelMeanMatching.h>
#include <shogun/lib/external/libqp.h>


static float64_t* kmm_K = NULL;
static int32_t kmm_K_ld = 0;

static const float64_t* kmm_get_col(uint32_t i)
{
return kmm_K + kmm_K_ld*i;
}

namespace shogun
{
CKernelMeanMatching::CKernelMeanMatching() :
CSGObject(), m_kernel(NULL)
{
}

CKernelMeanMatching::CKernelMeanMatching(CKernel* kernel, SGVector<index_t> training_indices,
SGVector<index_t> test_indices) :
CSGObject(), m_kernel(NULL)
{
set_kernel(kernel);
set_training_indices(training_indices);
set_test_indices(test_indices);
}

SGVector<float64_t> CKernelMeanMatching::compute_weights()
{
int32_t i,j;
ASSERT(m_kernel);
ASSERT(m_training_indices.size());
ASSERT(m_test_indices.size());

int32_t n_tr = m_training_indices.size();
int32_t n_te = m_test_indices.size();

SGVector<float64_t> weights(n_tr);
weights.zero();

kmm_K = SG_MALLOC(float64_t, n_tr*n_tr);
kmm_K_ld = n_tr;
float64_t* diag_K = SG_MALLOC(float64_t, n_tr);
for (i=0; i<n_tr; i++)
{
float64_t d = m_kernel->kernel(m_training_indices[i], m_training_indices[i]);
diag_K[i] = d;
kmm_K[i*n_tr+i] = d;
for (j=i+1; j<n_tr; j++)
{
d = m_kernel->kernel(m_training_indices[i],m_training_indices[j]);
kmm_K[i*n_tr+j] = d;
kmm_K[j*n_tr+i] = d;
}
}
float64_t* kappa = SG_MALLOC(float64_t, n_tr);
for (i=0; i<n_tr; i++)
{
float64_t avg = 0.0;
for (j=0; j<n_te; j++)
avg+= m_kernel->kernel(m_training_indices[i],m_test_indices[j]);

avg *= float64_t(n_tr)/n_te;
kappa[i] = avg;
}
float64_t* a = SG_MALLOC(float64_t, n_tr);
for (i=0; i<n_tr; i++) a[i] = 1.0;
float64_t* LB = SG_MALLOC(float64_t, n_tr);
float64_t* UB = SG_MALLOC(float64_t, n_tr);
float64_t B = 2.0;
for (i=0; i<n_tr; i++)
{
LB[i] = 0.0;
UB[i] = B;
}
for (i=0; i<n_tr; i++)
weights[i] = 1.0/float64_t(n_tr);

libqp_state_T result =
libqp_gsmo_solver(&kmm_get_col,diag_K,kappa,a,1.0,LB,UB,weights,n_tr,1000,1e-9,NULL);

SG_DEBUG("libqp exitflag=%d, %d iterations passed, primal objective=%f\n",
result.exitflag,result.nIter,result.QP);

SG_FREE(kappa);
SG_FREE(a);
SG_FREE(LB);
SG_FREE(UB);
SG_FREE(diag_K);
SG_FREE(kmm_K);

return weights;
}

}
59 changes: 59 additions & 0 deletions src/shogun/statistics/KernelMeanMatching.h
@@ -0,0 +1,59 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Copyright (W) 2012 Sergey Lisitsyn
*/

#ifndef KERNELMEANMATCHING_H_
#define KERNELMEANMATCHING_H_

#include <shogun/base/SGObject.h>
#include <shogun/kernel/Kernel.h>

namespace shogun
{

/** @brief Kernel Mean Matching */
class CKernelMeanMatching: public CSGObject
{
public:

/** constructor */
CKernelMeanMatching();

/** constructor */
CKernelMeanMatching(CKernel* kernel, SGVector<index_t> training_indices, SGVector<index_t> test_indices);

/** get kernel */
CKernel* get_kernel() const { SG_REF(m_kernel); return m_kernel; }
/** set kernel */
void set_kernel(CKernel* kernel) { SG_REF(kernel); SG_UNREF(m_kernel); m_kernel = kernel; }
/** get training indices */
SGVector<index_t> get_training_indices() const { return m_training_indices; }
/** set training indices */
void set_training_indices(SGVector<index_t> training_indices) { m_training_indices = training_indices; }
/** get test indices */
SGVector<index_t> get_test_indices() const { return m_test_indices; }
/** set test indices */
void set_test_indices(SGVector<index_t> test_indices) { m_test_indices = test_indices; }

/** compute weights */
SGVector<float64_t> compute_weights();

virtual const char* get_name() const { return "KernelMeanMatching"; }

protected:

/** kernel */
CKernel* m_kernel;
/** training indices */
SGVector<index_t> m_training_indices;
/** test indices */
SGVector<index_t> m_test_indices;
};

}
#endif

0 comments on commit a686d8c

Please sign in to comment.