
Commit

fix a bug for linear coef
QinbinLi committed Apr 25, 2019
1 parent e6d365a commit a7442a3
Showing 8 changed files with 32 additions and 14 deletions.
2 changes: 2 additions & 0 deletions include/thundersvm/dataset.h
@@ -80,6 +80,7 @@ class DataSet {
 
     const vector<int> original_index(int y_i, int y_j) const;
 
+    const bool is_zero_based() const;
 private:
     vector<float_type> y_;
     node2d instances_;
@@ -89,5 +90,6 @@ class DataSet {
     vector<int> count_; //the number of instances of each class
     vector<int> label_;
     vector<int> perm_;
+    bool zero_based = 0; //is zero_based format dataset?
 };
 #endif //THUNDERSVM_DATASET_H
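For reference, the new flag distinguishes the two index conventions found in LIBSVM-style sparse files: most files number features starting at 1, but some exporters start at 0. The lines below are illustrative data (not taken from the repository) showing the same instance in both conventions:

```
+1 1:0.5 3:1.2    # one-based indices (default assumption, zero_based stays 0)
+1 0:0.5 2:1.2    # zero-based indices (index 0 appears, so zero_based becomes 1)
```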
2 changes: 1 addition & 1 deletion include/thundersvm/model/svmmodel.h
@@ -105,7 +105,7 @@ class SvmModel {
     //return prob_predict
     const vector<float> &get_prob_predict() const;
 
-    void compute_linear_coef_single_model(size_t n_feature);
+    void compute_linear_coef_single_model(size_t n_feature, const bool zero_based);
     //get the params, for scikit load params
     void get_param(char* kernel_type, int* degree, float* gamma, float* coef0, int* probability);
 
7 changes: 5 additions & 2 deletions src/thundersvm/dataset.cpp
@@ -77,6 +77,7 @@ void DataSet::load_from_file(string file_name) {
         float v;
         CHECK_EQ(sscanf(tuple.c_str(), "%d:%f", &i, &v), 2) << "read error, using [index]:[value] format";
         instances_thread[tid].back().emplace_back(i, v);
+        if(i == 0 && zero_based == 0) zero_based = 1;
         if (i > local_feature[tid]) local_feature[tid] = i;
     };
 
@@ -158,7 +159,7 @@ void DataSet::load_from_dense(int row_size, int features, float* data, float* la
         if(label != NULL)
             y_.push_back(label[i]);
         instances_.emplace_back();
-        for(int j = 0; j < features; j++){
+        for(int j = 1; j <= features; j++){
             ind = j;
             v = data[off];
             off++;
@@ -285,4 +286,6 @@ const vector<float_type> &DataSet::y() const {
     return y_;
 }
 
-
+const bool DataSet::is_zero_based() const{
+    return zero_based;
+}
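The detection itself is a single check in the parsing lambda: the first time a feature index of 0 is seen, the whole dataset is treated as zero-based. A minimal standalone sketch of that rule (hypothetical helper, not ThunderSVM's parser):

```cpp
#include <vector>

// Sketch of the detection rule added above: a dataset counts as zero-based
// as soon as any parsed feature index equals 0.
bool detect_zero_based(const std::vector<std::vector<int>> &instance_indices) {
    for (const auto &instance : instance_indices)
        for (int idx : instance)
            if (idx == 0)
                return true;   // mirrors: if(i == 0 && zero_based == 0) zero_based = 1;
    return false;
}
```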
2 changes: 1 addition & 1 deletion src/thundersvm/model/nusvr.cpp
@@ -36,7 +36,7 @@ void NuSVR::train(const DataSet &dataset, SvmParam param) {
     save_svr_coef(alpha_2, dataset.instances());
 
     if(param.kernel_type == SvmParam::LINEAR){
-        compute_linear_coef_single_model(dataset.n_features());
+        compute_linear_coef_single_model(dataset.n_features(), dataset.is_zero_based());
     }
 }
 
2 changes: 1 addition & 1 deletion src/thundersvm/model/oneclass_svc.cpp
@@ -51,7 +51,7 @@ void OneClassSVC::train(const DataSet &dataset, SvmParam param) {
     coef.copy_from(coef_vec.data(), coef_vec.size());
 
     if(param.kernel_type == SvmParam::LINEAR){
-        compute_linear_coef_single_model(dataset.n_features());
+        compute_linear_coef_single_model(dataset.n_features(), dataset.is_zero_based());
     }
 }
 
15 changes: 11 additions & 4 deletions src/thundersvm/model/svc.cpp
@@ -109,17 +109,24 @@ void SVC::train(const DataSet &dataset, SvmParam param) {
     ///TODO: Use coef instead of alpha_data to compute linear_coef_data
     if(param.kernel_type == SvmParam::LINEAR){
         int k = 0;
-        linear_coef.resize(n_binary_models * dataset_.n_features());
+        if(dataset_.is_zero_based())
+            linear_coef.resize(n_binary_models * (dataset_.n_features()+1));
+        else
+            linear_coef.resize(n_binary_models * dataset_.n_features());
         float_type *linear_coef_data = linear_coef.host_data();
         for (int i = 0; i < n_classes; i++){
             for (int j = i + 1; j < n_classes; j++){
                 const float_type *alpha_data = alpha[k].host_data();
                 DataSet::node2d ins = dataset_.instances(i, j);//get instances of class i and j
                 for(int iid = 0; iid < ins.size(); iid++) {
                     for (int fid = 0; fid < ins[iid].size(); fid++) {
-                        if(alpha_data[iid] != 0)
-                            linear_coef_data[k * dataset_.n_features() + ins[iid][fid].index - 1] += alpha_data[iid] * ins[iid][fid].value;
-                    }
+                        if(alpha_data[iid] != 0){
+                            if(dataset_.is_zero_based())
+                                linear_coef_data[k * dataset_.n_features() + ins[iid][fid].index] += alpha_data[iid] * ins[iid][fid].value;
+                            else
+                                linear_coef_data[k * dataset_.n_features() + ins[iid][fid].index - 1] += alpha_data[iid] * ins[iid][fid].value;
+                        }
+                    }
                 }
                 k++;
             }
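The change in SVC::train adjusts how a sparse feature index is mapped to a slot in each binary model's weight block: one-based indices land at index - 1, zero-based indices land at index directly, and because n_features() is derived from the largest index seen during loading, a zero-based dataset needs one extra slot. A standalone sketch of that mapping for a single weight vector (hypothetical names, not the project's code):

```cpp
#include <cstddef>
#include <vector>

struct Node { int index; double value; };   // sparse feature: index:value

// Accumulate one support vector's contribution to a dense weight vector,
// using the same slot mapping as the patched code: slot = index for
// zero-based data, slot = index - 1 for one-based data.
void accumulate(std::vector<double> &w, const std::vector<Node> &sv,
                double coef, bool zero_based) {
    for (const Node &n : sv) {
        std::size_t slot = zero_based ? n.index : n.index - 1;
        w[slot] += coef * n.value;
    }
}

// Usage sketch: size the vector before accumulating, mirroring the resize above.
// std::vector<double> w(zero_based ? n_feature + 1 : n_feature, 0.0);
```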
14 changes: 10 additions & 4 deletions src/thundersvm/model/svmmodel.cpp
@@ -381,17 +381,23 @@ void SvmModel::get_param(char* kernel_type, int* degree, float* gamma, float* co
     *probability = param.probability;
 }
 
-void SvmModel::compute_linear_coef_single_model(size_t n_feature){
-    linear_coef.resize(n_feature);
+void SvmModel::compute_linear_coef_single_model(size_t n_feature, const bool zero_based){
+    if(zero_based)
+        linear_coef.resize(n_feature+1);
+    else
+        linear_coef.resize(n_feature);
     float_type* linear_coef_data = linear_coef.host_data();
     float_type* coef_data = coef.host_data();
     for(int i = 0; i < n_total_sv; i++){
         for(int j = 0; j < sv[i].size(); j++){
-            linear_coef_data[sv[i][j].index - 1] += coef_data[i] * sv[i][j].value;
+            if(zero_based)
+                linear_coef_data[sv[i][j].index] += coef_data[i] * sv[i][j].value;
+            else
+                linear_coef_data[sv[i][j].index - 1] += coef_data[i] * sv[i][j].value;
         }
     }
 }
 
 int SvmModel::get_sv_max_index() const{
     return sv_max_index;
-}
+}
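Both the SVC branch and compute_linear_coef_single_model compute the standard primal weight vector of a linear SVM from its dual coefficients and support vectors, in the usual LIBSVM-style convention (stated here as a summary, not quoted from the source):

```latex
w_j = \sum_{i=1}^{n_{\mathrm{sv}}} \mathrm{coef}_i \, x_{i,j}, \qquad f(x) = w^{\top} x - \rho
```

where coef_i is the signed dual coefficient of support vector x_i and rho is the bias. An off-by-one in the feature-index-to-slot mapping would shift every weight by one position, which is what the zero_based handling above avoids.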
2 changes: 1 addition & 1 deletion src/thundersvm/model/svr.cpp
@@ -36,7 +36,7 @@ void SVR::train(const DataSet &dataset, SvmParam param) {
     save_svr_coef(alpha_2, dataset.instances());
 
     if(param.kernel_type == SvmParam::LINEAR){
-        compute_linear_coef_single_model(dataset.n_features());
+        compute_linear_coef_single_model(dataset.n_features(), dataset.is_zero_based());
     }
 }
 
