Skip to content

Commit

Permalink
Merge pull request #207 from litxio/pickle
Browse files Browse the repository at this point in the history
Fix pickling
  • Loading branch information
QinbinLi committed Feb 18, 2020
2 parents cf708d5 + c44f535 commit 8ac8ec3
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 48 deletions.
11 changes: 11 additions & 0 deletions include/thundersvm/model/svmmodel.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,17 @@ class SvmModel {
*/
virtual void load_from_file(string path);

/**
* Serialize this SvmModel to a string.
* SvmModel::save_to_file writes exactly the text returned here to disk.
* @return the serialized model text
*/
virtual string save_to_string();

/**
* Restore this SvmModel from a string created by save_to_string.
* SvmModel::load_from_file reads the whole file and forwards its content here.
* @param data string created by save_to_string
*/
virtual void load_from_string(string data);

//return n_total_sv
int total_sv() const;

Expand Down
41 changes: 40 additions & 1 deletion python/thundersvm/thundersvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,15 +408,38 @@ def _sparse_decision_function(self, X):
return self.dec_values

def save_to_file(self, path):
    """Write the fitted model to ``path`` through the native library.

    Args:
        path: destination file name; it is passed to the library UTF-8 encoded.

    Raises:
        ValueError: if ``self.model`` is None (nothing fitted or loaded yet).
    """
    if self.model is None:
        raise ValueError("Cannot serialize model before fitting")
    thundersvm.save_to_file_scikit(c_void_p(self.model), path.encode('utf-8'))

def save_to_string(self):
    """Serialize the fitted model to bytes through the native library.

    The native call returns a heap-allocated C string. Its content is copied
    into a Python bytes object and the native buffer is then handed back to
    ``free_string``.

    Returns:
        bytes: the serialized model.

    Raises:
        ValueError: if ``self.model`` is None (nothing fitted or loaded yet).
        MemoryError: if the native call returns a NULL pointer.
    """
    if self.model is None:
        raise ValueError("Cannot serialize model before fitting")
    # ctypes assumes a C int return value by default; declare a pointer so
    # the address is not truncated on 64-bit platforms.
    thundersvm.save_to_string_scikit.restype = c_void_p
    sp = thundersvm.save_to_string_scikit(c_void_p(self.model))
    if not sp:
        # A NULL result must never reach string_at.
        raise MemoryError("save_to_string_scikit returned a NULL pointer")
    try:
        return string_at(sp)
    finally:
        # Release the native buffer even if the copy above fails.
        thundersvm.free_string(c_void_p(sp))

def load_from_file(self, path):
    """Load a model from ``path`` into this estimator.

    A native model handle is created first when the estimator has none, then
    the file is parsed by the native library and ``_post_load_init`` is run.

    Args:
        path: file name to read; it is passed to the library UTF-8 encoded.
    """
    if self.model is None:
        thundersvm.model_new.restype = c_void_p
        self.model = thundersvm.model_new(SVM_TYPE.index(self._impl))
        # NOTE(review): the source this was restored from had lost its
        # indentation; the memory-size call is assumed to apply only to a
        # freshly created handle -- confirm against the repository.
        if self.max_mem_size != -1:
            thundersvm.set_memory_size(c_void_p(self.model), self.max_mem_size)
    thundersvm.load_from_file_scikit(c_void_p(self.model), path.encode('utf-8'))
    self._post_load_init()

def load_from_string(self, data):
    """Load a model from the serialized form produced by ``save_to_string``.

    A native model handle is created first when the estimator has none, then
    the data is parsed by the native library and ``_post_load_init`` is run.

    Args:
        data: serialized model as bytes. Text (``str``) is accepted too and
            is encoded as UTF-8 before being passed on.
    """
    # ctypes passes a Python str as wchar_t*, while the native function takes
    # a char*; make sure a byte string is what crosses the boundary.
    if not isinstance(data, bytes):
        data = data.encode('utf-8')
    if self.model is None:
        thundersvm.model_new.restype = c_void_p
        self.model = thundersvm.model_new(SVM_TYPE.index(self._impl))
        # NOTE(review): the source this was restored from had lost its
        # indentation; the memory-size call is assumed to apply only to a
        # freshly created handle -- confirm against the repository.
        if self.max_mem_size != -1:
            thundersvm.set_memory_size(c_void_p(self.model), self.max_mem_size)
    thundersvm.load_from_string_scikit(c_void_p(self.model), data)
    self._post_load_init()

def _post_load_init(self):
degree = (c_int * 1)()
gamma = (c_float * 1)()
coef0 = (c_float * 1)()
Expand Down Expand Up @@ -466,13 +489,27 @@ def load_from_file(self, path):
# self.coef_ = np.array([coef[index] for index in range(0, self.n_binary_model * self.n_features)]).astype(float)
# self.coef_ = np.reshape(self.coef_, (self.n_binary_model, self.n_features))

self.kernel = kernel.value
self.kernel = kernel.value.decode()
self.degree = degree[0]
if gamma[0] != 0.0:
self.gamma = gamma[0]
self.coef0 = coef0[0]
self.probability = probability[0]

def __getstate__(self):
    """Build the state used for pickling.

    Works on a copy of the instance ``__dict__``. ``predict_label_ptr`` and
    ``_train_succeed`` are reset to None, and a non-None ``model`` is replaced
    by its serialized form under the ``_saved_as_str`` key.
    """
    state = self.__dict__.copy()
    # NOTE(review): presumably these hold native/ctypes objects that cannot be
    # pickled -- confirm against where they are assigned.
    state['predict_label_ptr'] = None
    state['_train_succeed'] = None
    if state['model'] is not None:
        # The native handle is only meaningful in this process; ship the
        # serialized model instead so __setstate__ can rebuild it.
        state['_saved_as_str'] = self.save_to_string()
        state['model'] = None
    return state

def __setstate__(self, state):
    """Restore the instance from pickled state.

    All entries of ``state`` are copied onto the instance. When the state
    carries a serialized model under ``_saved_as_str``, the native model is
    rebuilt from it with ``load_from_string``.
    """
    self.__dict__.update(state)
    if '_saved_as_str' in state:
        # Take the blob off the instance (the caller's dict is left alone) so
        # the serialized copy is not kept alive next to the rebuilt model.
        saved = self.__dict__.pop('_saved_as_str')
        self.load_from_string(saved)


class SVC(SvmModel, ClassifierMixin):
_impl = 'c_svc'
Expand All @@ -492,6 +529,7 @@ def __init__(self, kernel='rbf', degree=3,
max_iter=max_iter, n_jobs=n_jobs, max_mem_size=max_mem_size, random_state=random_state, gpu_id=gpu_id)



class NuSVC(SvmModel, ClassifierMixin):
_impl = 'nu_svc'

Expand Down Expand Up @@ -559,3 +597,4 @@ def __init__(self, kernel='rbf', degree=3, gamma='auto',
shrinking=shrinking, cache_size=cache_size, verbose=verbose,
max_iter=max_iter, n_jobs=n_jobs, max_mem_size=max_mem_size, random_state=None, gpu_id=gpu_id
)

107 changes: 60 additions & 47 deletions src/thundersvm/model/svmmodel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,143 +131,157 @@ vector<float_type> SvmModel::predict(const DataSet::node2d &instances, int batch
return dec_values_vec;
}


/**
 * Write the model to a file, truncating any existing content.
 * The file receives exactly the text returned by save_to_string().
 * @param path destination file path
 */
void SvmModel::save_to_file(string path) {
    ofstream fs_model;
    fs_model.open(path.c_str(), std::ios_base::out | std::ios_base::trunc);
    // The message needs a space before "failed", otherwise it runs into the path.
    CHECK(fs_model.is_open()) << "create file " << path << " failed";
    fs_model << save_to_string();
    fs_model.close();
}

/**
 * Serialize the model as text.
 *
 * Layout: one "key value..." line per header field (svm_type, kernel_type,
 * the kernel parameters that apply to the kernel in use, nr_class, total_sv,
 * rho, and for classifiers label / nr_sv, plus probA / probB when probability
 * output is enabled), then a line "SV " followed by one line per support
 * vector holding its n_classes - 1 coefficients and its index:value pairs.
 *
 * Everything is written to the local string stream only; the stale writes to
 * a file stream that is not declared in this function are gone.
 *
 * @return the serialized model text
 */
string SvmModel::save_to_string() {
    std::ostringstream s_model;
    const SvmParam &param = this->param;

    s_model << "svm_type " << SvmParam::svm_type_name[param.svm_type] << '\n';
    s_model << "kernel_type " << SvmParam::kernel_type_name[param.kernel_type] << '\n';
    if (param.kernel_type == SvmParam::POLY)
        s_model << "degree " << param.degree << '\n';
    if (param.kernel_type == SvmParam::POLY
        || param.kernel_type == SvmParam::RBF
        || param.kernel_type == SvmParam::SIGMOID)
        s_model << "gamma " << param.gamma << '\n';
    if (param.kernel_type == SvmParam::POLY || param.kernel_type == SvmParam::SIGMOID)
        s_model << "coef0 " << param.coef0 << '\n';
    s_model << "nr_class " << n_classes << '\n';
    s_model << "total_sv " << sv.size() << '\n';

    s_model << "rho ";
    for (int i = 0; i < n_binary_models; ++i) {
        s_model << rho.host_data()[i] << " ";
    }
    s_model << '\n';

    if (param.svm_type == SvmParam::NU_SVC || param.svm_type == SvmParam::C_SVC) {
        s_model << "label ";
        for (int i = 0; i < n_classes; ++i) {
            s_model << label[i] << " ";
        }
        s_model << '\n';
        s_model << "nr_sv ";
        for (int i = 0; i < n_classes; ++i) {
            s_model << n_sv.host_data()[i] << " ";
        }
        s_model << '\n';
    }

    if (param.probability == 1) {
        s_model << "probA ";
        for (int i = 0; i < n_binary_models; ++i) {
            s_model << probA[i] << " ";
        }
        s_model << '\n';
        s_model << "probB ";
        for (int i = 0; i < n_binary_models; ++i) {
            s_model << probB[i] << " ";
        }
        s_model << '\n';
    }

    s_model << "SV " << '\n';
    const float_type *coef_data = coef.host_data();
    const size_t n_vectors = sv.size();
    for (size_t i = 0; i < n_vectors; i++) {
        // Coefficients are stored one row per class pair slot: element
        // (j, i) lives at j * n_vectors + i. They are written with 16 digits.
        for (int j = 0; j < n_classes - 1; ++j) {
            s_model << setprecision(16) << coef_data[j * n_vectors + i] << " ";
        }

        // Bind by reference: the support vector is only read here.
        const vector<DataSet::node> &p = sv[i];
        // if (param.kernel_type == SvmParam::PRECOMPUTED)
        // s_model << "0:" << p[k].value << " ";
        // else
        for (size_t k = 0; k < p.size(); k++) {
            // Feature values are written with 8 digits.
            s_model << p[k].index << ":" << setprecision(8) << p[k].value << " ";
        }
        s_model << '\n';
    }

    return s_model.str();
}

/**
 * Read a model file and restore the model from its content.
 * The whole file is read into memory and handed to load_from_string().
 * If the file cannot be opened, a message is logged and the process exits
 * with status 1.
 * @param path file to read
 */
void SvmModel::load_from_file(string path) {
    ifstream model_file(path.c_str());
    if (model_file.is_open()) {
        // Slurp the file in one go; parsing happens on the in-memory text.
        std::stringstream contents;
        contents << model_file.rdbuf();
        load_from_string(contents.str());
        model_file.close();
        return;
    }
    LOG(INFO) << "file " << path << " not found";
    exit(1);
}

void SvmModel::load_from_string(string data) {
stringstream iss(data);
string feature;
while (ifs >> feature) {
while (iss >> feature) {
if (feature == "svm_type") {
string value;
ifs >> value;
iss >> value;
for (int i = 0; i < 6; i++) {
if (value == SvmParam::svm_type_name[i])
param.svm_type = static_cast<SvmParam::SVM_TYPE>(i);
}
} else if (feature == "kernel_type") {
string value;
ifs >> value;
iss >> value;
for (int i = 0; i < 6; i++) {
if (value == SvmParam::kernel_type_name[i])
param.kernel_type = static_cast<SvmParam::KERNEL_TYPE>(i);
}
} else if (feature == "degree") {
ifs >> param.degree;
iss >> param.degree;
} else if (feature == "nr_class") {
ifs >> n_classes;
iss >> n_classes;
n_binary_models = n_classes * (n_classes - 1) / 2;
rho.resize(n_binary_models);
n_sv.resize(n_classes);
} else if (feature == "coef0") {
ifs >> param.coef0;
iss >> param.coef0;
} else if (feature == "gamma") {
ifs >> param.gamma;
iss >> param.gamma;

} else if (feature == "total_sv") {
ifs >> n_total_sv;
iss >> n_total_sv;
} else if (feature == "rho") {
for (int i = 0; i < n_binary_models; ++i) {
ifs >> rho.host_data()[i];
iss >> rho.host_data()[i];
}
} else if (feature == "label") {
label = vector<int>(n_classes);
for (int i = 0; i < n_classes; ++i) {
ifs >> label[i];
iss >> label[i];
}
} else if (feature == "nr_sv") {
for (int i = 0; i < n_classes; ++i) {
ifs >> n_sv.host_data()[i];
iss >> n_sv.host_data()[i];
}
} else if (feature == "probA") {
param.probability = 1;
probA = vector<float_type>(n_binary_models);
for (int i = 0; i < n_binary_models; ++i) {
ifs >> probA[i];
iss >> probA[i];
}
} else if (feature == "probB") {
probB = vector<float_type>(n_binary_models);
for (int i = 0; i < n_binary_models; ++i) {
ifs >> probB[i];
iss >> probB[i];
}
} else if (feature == "SV") {
sv.clear();
coef.resize((n_classes - 1) * n_total_sv);
float_type *coef_data = coef.host_data();
string line;
getline(ifs, line);
getline(iss, line);
for (int i = 0; i < n_total_sv; i++) {
getline(ifs, line);
getline(iss, line);
stringstream ss(line);
for (int j = 0; j < n_classes - 1; ++j) {
ss >> coef_data[j * n_total_sv + i];
Expand All @@ -287,7 +301,6 @@ void SvmModel::load_from_file(string path) {
sv_max_index = sv.back().back().index;
};
}
ifs.close();
}
}
if (param.svm_type != SvmParam::C_SVC && param.svm_type != SvmParam::NU_SVC) {
Expand Down Expand Up @@ -413,4 +426,4 @@ int SvmModel::get_sv_max_index() const{

// Return a read-only reference to the model's label vector.
const vector<int> &SvmModel::get_label() const{
return label;
}
}
20 changes: 20 additions & 0 deletions src/thundersvm/thundersvm-scikit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,26 @@ extern "C" {
model->load_from_file(path);
}

/* Serialize `model` and return the text as a NUL-terminated C string
 * allocated with malloc. The caller releases it with free_string().
 * Returns NULL if the buffer cannot be allocated. */
char* save_to_string_scikit(SvmModel *model){
    const string s = model->save_to_string();
    // Reserve room for the terminating NUL as well: the Python side reads
    // the buffer as a C string, so without it the read runs past the end.
    const size_t n_bytes = s.length() + 1;
    char* buf = static_cast<char *>(malloc(n_bytes));
    if (buf == NULL) {
        return NULL;
    }
    // c_str() is NUL-terminated, so copying length + 1 bytes includes it.
    memcpy(buf, s.c_str(), n_bytes);
    return buf;
}

/* Release a buffer returned by save_to_string_scikit.
 * That function allocates its result on the heap with malloc, so Python
 * code hands the pointer back here to have it freed. */
void free_string(char* s) {
free(s);
}

/* Restore `model` from the NUL-terminated serialized text in `mstring`. */
void load_from_string_scikit(SvmModel *model, char *mstring) {
    // Build the string once and pass it on; the former unused local plus an
    // implicit conversion at the call site copied the text twice.
    model->load_from_string(string(mstring));
}

void get_pro(SvmModel *model, float* prob){
vector<float> prob_predict;
prob_predict = model->get_prob_predict();
Expand Down

0 comments on commit 8ac8ec3

Please sign in to comment.