diff --git a/src/shogun/classifier/svm/SVMOcas.cpp b/src/shogun/classifier/svm/SVMOcas.cpp index 917eadb5835..280ffabb9dc 100644 --- a/src/shogun/classifier/svm/SVMOcas.cpp +++ b/src/shogun/classifier/svm/SVMOcas.cpp @@ -74,18 +74,18 @@ bool CSVMOcas::train_machine(CFeatures* data) for (int32_t i=0; iget_label(i); - SGVector w(features->get_dim_feature_space()); - w.zero(); + current_w = SGVector(features->get_dim_feature_space()); + current_w.zero(); if (num_vec!=lab.vlen || num_vec<=0) SG_ERROR("num_vec=%d num_train_labels=%d\n", num_vec, lab.vlen) SG_FREE(old_w); - old_w=SG_CALLOC(float64_t, w.vlen); + old_w=SG_CALLOC(float64_t, current_w.vlen); bias=0; old_bias=0; - tmp_a_buf=SG_CALLOC(float64_t, w.vlen); + tmp_a_buf=SG_CALLOC(float64_t, current_w.vlen); cp_value=SG_CALLOC(float64_t*, bufsize); cp_index=SG_CALLOC(uint32_t*, bufsize); cp_nz_dims=SG_CALLOC(uint32_t, bufsize); @@ -144,7 +144,7 @@ bool CSVMOcas::train_machine(CFeatures* data) SG_FREE(old_w); old_w=NULL; - set_w(w); + set_w(current_w); return true; } @@ -160,9 +160,8 @@ float64_t CSVMOcas::update_W( float64_t t, void* ptr ) { float64_t sq_norm_W = 0; CSVMOcas* o = (CSVMOcas*) ptr; - SGVector w = o->get_w(); - uint32_t nDim = (uint32_t) w.vlen; - float64_t* W = w.vector; + uint32_t nDim = (uint32_t) o->current_w.vlen; + float64_t* W = o->current_w.vector; float64_t* oldW=o->old_w; for(uint32_t j=0; j features; - SGVector w = o->get_w(); - uint32_t nDim=(uint32_t) w.vlen; + uint32_t nDim=(uint32_t) o->current_w.vlen; float64_t* y = o->lab.vector; float64_t** c_val = o->cp_value; @@ -277,16 +275,14 @@ int CSVMOcas::compute_output(float64_t *output, void* ptr) CSVMOcas* o = (CSVMOcas*) ptr; CDotFeatures* f=o->features; int32_t nData=f->get_num_vectors(); - SGVector w = o->get_w(); + SGVector w(o->get_w()); float64_t* y = o->lab.vector; - f->dense_dot_range(output, 0, nData, y, w.vector, w.vlen, 0.0); + f->dense_dot_range(output, 0, nData, y, o->current_w.vector, o->current_w.vlen, 0.0); for (int32_t i=0; ibias; - //CMath::display_vector(w, w.vlen, "w"); - //CMath::display_vector(output, nData, "out"); return 0; } @@ -304,7 +300,7 @@ void CSVMOcas::compute_W( void* ptr ) { CSVMOcas* o = (CSVMOcas*) ptr; - SGVector w_vector = o->get_w(); + SGVector w_vector(o->get_w()); uint32_t nDim= (uint32_t) w_vector.vlen; CMath::swap(w_vector.vector, o->old_w); float64_t* W=w_vector.vector; diff --git a/src/shogun/classifier/svm/SVMOcas.h b/src/shogun/classifier/svm/SVMOcas.h index 0695d7b8c5a..838385a9406 100644 --- a/src/shogun/classifier/svm/SVMOcas.h +++ b/src/shogun/classifier/svm/SVMOcas.h @@ -212,6 +212,8 @@ class CSVMOcas : public CLinearMachine /** method */ E_SVM_TYPE method; + /** current W */ + SGVector current_w; /** old W */ float64_t* old_w; /** old bias */