Skip to content

Commit

Permalink
Updated linear kernel API
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Oct 23, 2012
1 parent e2f28de commit 0e9806f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 64 deletions.
25 changes: 5 additions & 20 deletions src/shogun/kernel/LinearKernel.cpp
Expand Up @@ -18,13 +18,13 @@
using namespace shogun;

/** Default constructor: linear kernel with no features attached.
 *
 * The normal vector starts empty (SGVector default-constructs to an
 * empty vector), so no explicit member initialization is needed.
 */
CLinearKernel::CLinearKernel()
: CDotKernel(0)
{
	// mark this kernel as supporting the KP_LINADD optimization
	// (maintaining an explicit normal/w vector)
	properties |= KP_LINADD;
}

CLinearKernel::CLinearKernel(CDotFeatures* l, CDotFeatures* r)
: CDotKernel(0), normal(NULL), normal_length(0)
: CDotKernel(0)
{
properties |= KP_LINADD;
init(l,r);
Expand All @@ -49,24 +49,10 @@ void CLinearKernel::cleanup()
CKernel::cleanup();
}

/** Reset the normal (w) vector to all zeros, allocating it on first use.
 *
 * The vector is sized to the dimensionality of the lhs feature space.
 * NOTE(review): if lhs changes to a space of different dimension after the
 * first allocation, normal_length is not updated here — confirm callers
 * re-create the kernel in that case.
 */
void CLinearKernel::clear_normal()
{
int32_t num = ((CDotFeatures*) lhs)->get_dim_feature_space();
// lazily allocate the normal vector the first time we are called
if (normal==NULL)
{
normal = SG_MALLOC(float64_t, num);
normal_length=num;
}

// zero out all coefficients
memset(normal, 0, sizeof(float64_t)*normal_length);

set_is_initialized(true);
}

/** Add a weighted lhs feature vector to the normal (w) vector.
 *
 * The weight is passed through the kernel normalizer before the dense
 * accumulation, so the normal stays consistent with normalized kernel values.
 *
 * @param idx index of the feature vector in lhs to add
 * @param weight scale factor for the added vector
 */
void CLinearKernel::add_to_normal(int32_t idx, float64_t weight)
{
	((CDotFeatures*) lhs)->add_to_dense_vec(
		normalizer->normalize_lhs(weight, idx), idx, normal.vector, normal.size());
	set_is_initialized(true);
}

Expand Down Expand Up @@ -98,8 +84,7 @@ bool CLinearKernel::init_optimization(CKernelMachine* km)
bool CLinearKernel::delete_optimization()
{
SG_FREE(normal);
normal_length=0;
normal=NULL;
normal = SGVector<float64_t>();
set_is_initialized(false);

return true;
Expand All @@ -109,6 +94,6 @@ float64_t CLinearKernel::compute_optimized(int32_t idx)
{
ASSERT(get_is_initialized());
float64_t result = ((CDotFeatures*) rhs)->
dense_dot(idx, normal, normal_length);
dense_dot(idx, normal.vector, normal.size());
return normalizer->normalize_rhs(result, idx);
}
51 changes: 10 additions & 41 deletions src/shogun/kernel/LinearKernel.h
Expand Up @@ -98,68 +98,37 @@ class CLinearKernel: public CDotKernel
*/
virtual float64_t compute_optimized(int32_t idx);

/** clear normal vector */
virtual void clear_normal();

/** add to normal vector
*
* @param idx where to add
* @param weight what to add
*/
virtual void add_to_normal(int32_t idx, float64_t weight);

/** get normal
 *
 * @param len where length of normal vector will be stored
 * @return normal vector, or NULL (with len set to 0) when no lhs
 *         features are set or no normal has been allocated
 */
inline const float64_t* get_normal(int32_t& len)
{
	// guard: without lhs features or an allocated normal there is nothing to return
	if (!lhs || !normal)
	{
		len = 0;
		return NULL;
	}

	len = ((CDotFeatures*) lhs)->get_dim_feature_space();
	return normal;
}

/** get normal vector (swig compatible)
/** get normal vector
*
* @param dst_w store w in this argument
* @param dst_dims dimension of w
*/
inline void get_w(float64_t** dst_w, int32_t* dst_dims)
SGVector<float64_t> get_w() const
{
ASSERT(lhs && normal);
int32_t len = ((CDotFeatures*) lhs)->get_dim_feature_space();
ASSERT(dst_w && dst_dims);
*dst_dims=len;
*dst_w=SG_MALLOC(float64_t, *dst_dims);
ASSERT(*dst_w);
memcpy(*dst_w, normal, sizeof(float64_t) * (*dst_dims));
ASSERT(lhs);
return normal;
}

/** set normal vector (swig compatible)
/** set normal vector
*
* @param src_w new w
* @param src_w_dim dimension of new w - must fit dim of lhs
* @param w new normal
*/
inline void set_w(float64_t* src_w, int32_t src_w_dim)
void set_w(SGVector<float64_t> w)
{
ASSERT(lhs && src_w_dim==((CDotFeatures*) lhs)->get_dim_feature_space());
clear_normal();
memcpy(normal, src_w, sizeof(float64_t) * src_w_dim);
ASSERT(lhs && w.size()==((CDotFeatures*) lhs)->get_dim_feature_space());
this->normal = w;
}

protected:
	/** normal (w) vector, used when the kernel is optimized via
	 * linadd; empty until initialized */
	SGVector<float64_t> normal;
};
}
#endif /* _LINEARKERNEL_H__ */
5 changes: 2 additions & 3 deletions src/shogun/ui/SGInterface.cpp
Expand Up @@ -4080,10 +4080,9 @@ bool CSGInterface::cmd_get_kernel_optimization()
case K_LINEAR:
{
CLinearKernel* k=(CLinearKernel*) kernel;
int32_t len=0;
const float64_t* weights=k->get_normal(len);
SGVector<float64_t> weights=k->get_w();

set_vector(weights, len);
set_vector(weights.vector, weights.size());
return true;
}
default:
Expand Down

0 comments on commit 0e9806f

Please sign in to comment.