Skip to content

Commit

Permalink
Proper way to handle W in OCAS SVM
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Mar 22, 2017
1 parent bf1b262 commit b7b9dfe
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
26 changes: 11 additions & 15 deletions src/shogun/classifier/svm/SVMOcas.cpp
Expand Up @@ -74,18 +74,18 @@ bool CSVMOcas::train_machine(CFeatures* data)
for (int32_t i=0; i<num_vec; i++)
lab[i] = ((CBinaryLabels*)m_labels)->get_label(i);

SGVector<float64_t> w(features->get_dim_feature_space());
w.zero();
current_w = SGVector<float64_t>(features->get_dim_feature_space());
current_w.zero();

if (num_vec!=lab.vlen || num_vec<=0)
SG_ERROR("num_vec=%d num_train_labels=%d\n", num_vec, lab.vlen)

SG_FREE(old_w);
old_w=SG_CALLOC(float64_t, w.vlen);
old_w=SG_CALLOC(float64_t, current_w.vlen);
bias=0;
old_bias=0;

tmp_a_buf=SG_CALLOC(float64_t, w.vlen);
tmp_a_buf=SG_CALLOC(float64_t, current_w.vlen);
cp_value=SG_CALLOC(float64_t*, bufsize);
cp_index=SG_CALLOC(uint32_t*, bufsize);
cp_nz_dims=SG_CALLOC(uint32_t, bufsize);
Expand Down Expand Up @@ -144,7 +144,7 @@ bool CSVMOcas::train_machine(CFeatures* data)
SG_FREE(old_w);
old_w=NULL;

set_w(w);
set_w(current_w);

return true;
}
Expand All @@ -160,9 +160,8 @@ float64_t CSVMOcas::update_W( float64_t t, void* ptr )
{
float64_t sq_norm_W = 0;
CSVMOcas* o = (CSVMOcas*) ptr;
SGVector<float64_t> w = o->get_w();
uint32_t nDim = (uint32_t) w.vlen;
float64_t* W = w.vector;
uint32_t nDim = (uint32_t) o->current_w.vlen;
float64_t* W = o->current_w.vector;
float64_t* oldW=o->old_w;

for(uint32_t j=0; j <nDim; j++)
Expand Down Expand Up @@ -190,8 +189,7 @@ int CSVMOcas::add_new_cut(
{
CSVMOcas* o = (CSVMOcas*) ptr;
CDotFeatures* f = o->features;
SGVector<float64_t> w = o->get_w();
uint32_t nDim=(uint32_t) w.vlen;
uint32_t nDim=(uint32_t) o->current_w.vlen;
float64_t* y = o->lab.vector;

float64_t** c_val = o->cp_value;
Expand Down Expand Up @@ -277,16 +275,14 @@ int CSVMOcas::compute_output(float64_t *output, void* ptr)
CSVMOcas* o = (CSVMOcas*) ptr;
CDotFeatures* f=o->features;
int32_t nData=f->get_num_vectors();
SGVector<float64_t> w = o->get_w();
SGVector<float64_t> w(o->get_w());

float64_t* y = o->lab.vector;

f->dense_dot_range(output, 0, nData, y, w.vector, w.vlen, 0.0);
f->dense_dot_range(output, 0, nData, y, o->current_w.vector, o->current_w.vlen, 0.0);

for (int32_t i=0; i<nData; i++)
output[i]+=y[i]*o->bias;
//CMath::display_vector(w, w.vlen, "w");
//CMath::display_vector(output, nData, "out");
return 0;
}

Expand All @@ -304,7 +300,7 @@ void CSVMOcas::compute_W(
void* ptr )
{
CSVMOcas* o = (CSVMOcas*) ptr;
SGVector<float64_t> w_vector = o->get_w();
SGVector<float64_t> w_vector(o->get_w());
uint32_t nDim= (uint32_t) w_vector.vlen;
CMath::swap(w_vector.vector, o->old_w);
float64_t* W=w_vector.vector;
Expand Down
2 changes: 2 additions & 0 deletions src/shogun/classifier/svm/SVMOcas.h
Expand Up @@ -212,6 +212,8 @@ class CSVMOcas : public CLinearMachine
/** method */
E_SVM_TYPE method;

/** current W */
SGVector<float64_t> current_w;
/** old W */
float64_t* old_w;
/** old bias */
Expand Down

0 comments on commit b7b9dfe

Please sign in to comment.