Commit

probability model
shijiashuai committed Nov 17, 2017
1 parent 0906293 commit 67685e7
Showing 2 changed files with 36 additions and 15 deletions.
29 changes: 15 additions & 14 deletions src/thundersvm/model/svc.cpp
@@ -62,37 +62,37 @@ void SVC::train(const DataSet &dataset, SvmParam param) {
         vector<int> original_index = dataset_.original_index(i);
         DataSet::node2d i_instances = dataset_.instances(i);
         for (int j = 0; j < i_instances.size(); ++j) {
-            if (is_sv[original_index[j]]){
+            if (is_sv[original_index[j]]) {
                 n_sv[i]++;
                 sv.push_back(i_instances[j]);
             }
         }
     }

     n_total_sv = sv.size();
-    LOG(INFO)<<"#total unique sv = "<<n_total_sv;
+    LOG(INFO) << "#total unique sv = " << n_total_sv;
     coef.resize((n_classes - 1) * n_total_sv);

-    vector<int> sv_start(1,0);
+    vector<int> sv_start(1, 0);
     for (int i = 1; i < n_classes; ++i) {
-        sv_start.push_back(sv_start[i-1] + n_sv[i-1]);
+        sv_start.push_back(sv_start[i - 1] + n_sv[i - 1]);
     }

     k = 0;
     for (int i = 0; i < n_classes; ++i) {
-        for (int j = i+1; j < n_classes; ++j) {
-            vector<int> original_index = dataset_.original_index(i,j);
+        for (int j = i + 1; j < n_classes; ++j) {
+            vector<int> original_index = dataset_.original_index(i, j);
             int ci = dataset_.count()[i];
             int cj = dataset_.count()[j];
             int m = sv_start[i];
             for (int l = 0; l < ci; ++l) {
-                if (is_sv[original_index[l]]){
-                    coef[(j-1) * n_total_sv + m++] = alpha[k][l];
+                if (is_sv[original_index[l]]) {
+                    coef[(j - 1) * n_total_sv + m++] = alpha[k][l];
                 }
             }
             m = sv_start[j];
             for (int l = ci; l < ci + cj; ++l) {
-                if (is_sv[original_index[l]]){
+                if (is_sv[original_index[l]]) {
                     coef[i * n_total_sv + m++] = alpha[k][l];
                 }
             }
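
For context, the `coef` buffer filled above follows the usual one-vs-one layout: `(n_classes - 1)` rows of `n_total_sv` entries, where for the binary model of classes `(i, j)` the coefficients of class-`i` support vectors go to row `j - 1` and those of class-`j` support vectors go to row `i`, each starting at its class's `sv_start` offset. A minimal standalone sketch of that indexing with made-up support-vector counts (illustration only, not repository code):

```cpp
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical example: 3 classes holding 2, 3 and 1 support vectors.
    std::vector<int> n_sv = {2, 3, 1};
    int n_classes = 3;
    int n_total_sv = 2 + 3 + 1;

    // Same prefix-sum bookkeeping as sv_start in SVC::train above.
    std::vector<int> sv_start(1, 0);
    for (int i = 1; i < n_classes; ++i)
        sv_start.push_back(sv_start[i - 1] + n_sv[i - 1]);

    // For binary model (i, j): class-i SV coefficients start at
    // coef[(j - 1) * n_total_sv + sv_start[i]], class-j SV coefficients
    // start at coef[i * n_total_sv + sv_start[j]].
    for (int i = 0; i < n_classes; ++i)
        for (int j = i + 1; j < n_classes; ++j)
            printf("model (%d,%d): class %d -> offset %d, class %d -> offset %d\n",
                   i, j, i, (j - 1) * n_total_sv + sv_start[i],
                   j, i * n_total_sv + sv_start[j]);
    return 0;
}
```
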
@@ -105,7 +105,7 @@ void SVC::train(const DataSet &dataset, SvmParam param) {
         LOG(INFO) << "performing probability train";
         probA.resize(n_binary_models);
         probB.resize(n_binary_models);
-        probability_train(dataset);
+        probability_train(dataset_);
     }
 }

@@ -127,13 +127,13 @@ void SVC::train_binary(const DataSet &dataset, int i, int j, SyncData<float_type
     int ws_size = min(max2power(ins.size()), 1024);
     CSMOSolver solver;
     solver.solve(k_mat, y, alpha, rho, f_val, param.epsilon, param.C * c_weight[i], param.C * c_weight[j], ws_size);
-    LOG(INFO)<<"rho = "<<rho;
+    LOG(INFO) << "rho = " << rho;
     int n_sv = 0;
     for (int l = 0; l < alpha.size(); ++l) {
         alpha[l] *= y[l];
         if (alpha[l] != 0) n_sv++;
     }
-    LOG(INFO)<<"#sv = "<<n_sv;
+    LOG(INFO) << "#sv = " << n_sv;
 }

 vector<float_type> SVC::predict(const DataSet::node2d &instances, int batch_size) {
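
Note that `alpha[l] *= y[l]` stores the signed dual coefficient `y_l * alpha_l`, so a binary decision value is simply a kernel-weighted sum of those coefficients minus `rho`. A CPU-side sketch of that evaluation (the repository computes decision values in `predict_dec_values`; this is only an illustration):

```cpp
#include <vector>

// Illustrative evaluation of one binary decision value:
// f(x) = sum_l (y_l * alpha_l) * K(sv_l, x) - rho.
// signed_alpha holds the y_l * alpha_l values produced after solver.solve,
// kernel_row holds K(sv_l, x) for every support vector of this binary model.
double binary_dec_value(const std::vector<double> &signed_alpha,
                        const std::vector<double> &kernel_row, double rho) {
    double sum = 0;
    for (size_t l = 0; l < signed_alpha.size(); ++l)
        sum += signed_alpha[l] * kernel_row[l];
    return sum - rho;
}
```
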
@@ -369,8 +369,9 @@ void sigmoidTrain(const float_type *decValues, const int l, const vector<int> &l
 }

 void SVC::probability_train(const DataSet &dataset) {
-    SyncData<float_type> dec_values(dataset.n_instances() * n_binary_models);
-    predict_dec_values(dataset.instances(), dec_values, 10000);
+    SvmParam param_no_prob = param;
+    param_no_prob.probability = false;
+    vector<float_type> dec_values = cross_validation(dataset, param_no_prob, 5);
     int k = 0;
     for (int i = 0; i < n_classes; ++i) {
         for (int j = i + 1; j < n_classes; ++j) {
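
The rewritten `probability_train` now obtains decision values from 5-fold cross-validation rather than from in-sample predictions, so that `sigmoidTrain` (visible in the hunk header above) can fit the Platt-scaling parameters `probA`/`probB` for each binary model on out-of-sample scores. For reference, a fitted pair `(A, B)` maps a decision value `f` to a probability via `P(y = 1 | f) = 1 / (1 + exp(A * f + B))`; a numerically stable sketch of that conversion (illustration only, not the repository's prediction code):

```cpp
#include <cmath>

// Platt scaling: map a binary decision value to a probability using the
// fitted parameters A (probA) and B (probB). Branch on the sign of
// A * dec_value + B to avoid overflow in exp(), as in libsvm's sigmoid_predict.
double sigmoid_predict(double dec_value, double A, double B) {
    double fApB = dec_value * A + B;
    if (fApB >= 0)
        return std::exp(-fApB) / (1.0 + std::exp(-fApB));
    else
        return 1.0 / (1.0 + std::exp(fApB));
}
```
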
22 changes: 21 additions & 1 deletion src/thundersvm/model/svmmodel.cpp
@@ -144,7 +144,18 @@ void SvmModel::save_to_file(string path) {
         }
         fs_model<< endl;
     }
-    //todo save probA and probB
+    if (param.probability) {
+        fs_model << "probA ";
+        for (int i = 0; i < n_binary_models; ++i) {
+            fs_model << probA[i] << " ";
+        }
+        fs_model << endl;
+        fs_model << "probB ";
+        for (int i = 0; i < n_binary_models; ++i) {
+            fs_model << probB[i] << " ";
+        }
+        fs_model << endl;
+    }
     fs_model << "SV " << endl;
     for (int i = 0; i < sv.size(); i++) {
         for (int j = 0; j < n_classes - 1; ++j) {
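
With this change, a model saved with probability estimates enabled carries two extra lines just before the `SV` section, each holding one value per binary model. An illustrative fragment with made-up values for three classes (three one-vs-one models):

```
probA -3.21 -2.87 -4.05
probB 0.12 -0.04 0.33
SV
```
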
@@ -211,6 +222,15 @@ void SvmModel::load_from_file(string path) {
             for (int i = 0; i < n_classes; ++i) {
                 ifs >> n_sv[i];
             }
+        } else if (feature == "probA") {
+            param.probability = true;
+            for (int i = 0; i < n_binary_models; ++i) {
+                ifs >> probA[i];
+            }
+        } else if (feature == "probB") {
+            for (int i = 0; i < n_binary_models; ++i) {
+                ifs >> probB[i];
+            }
         } else if (feature == "SV") {
             sv.clear();
             coef.resize((n_classes - 1) * n_total_sv);
