Skip to content

Commit

Permalink
fix bug #123
Browse files Browse the repository at this point in the history
  • Loading branch information
QinbinLi committed Feb 9, 2019
1 parent 1ab487f commit d06b21c
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions python/thundersvmScikit.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,12 @@ def __init__(self, kernel, degree,
self.n_jobs = n_jobs
self.random_state = random_state
self.max_mem_size = max_mem_size
thundersvm.model_new.restype = c_void_p
self.model = thundersvm.model_new(SVM_TYPE.index(self._impl))
self.gpu_id = gpu_id
if self.max_mem_size != -1:
thundersvm.set_memory_size(c_void_p(self.model), self.max_mem_size)
self.model = None
thundersvm.model_new.restype = c_void_p
# self.model = thundersvm.model_new(SVM_TYPE.index(self._impl))
# if self.max_mem_size != -1:
# thundersvm.set_memory_size(c_void_p(self.model), self.max_mem_size)

def label_validate(self, y):

Expand All @@ -83,6 +84,7 @@ def label_validate(self, y):
def fit(self, X, y):
if self.model is not None:
thundersvm.model_free(c_void_p(self.model))
self.model = None
sparse = sp.isspmatrix(X)
self._sparse = sparse and not callable(self.kernel)
X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
Expand All @@ -102,7 +104,10 @@ def fit(self, X, y):
kernel = KERNEL_TYPE.index(self.kernel)

fit = self._sparse_fit if self._sparse else self._dense_fit
thundersvm.model_new.restype = c_void_p
self.model = thundersvm.model_new(solver_type)
if self.max_mem_size != -1:
thundersvm.set_memory_size(c_void_p(self.model), self.max_mem_size)
fit(X, y, solver_type, kernel)
if self._train_succeed[0] == -1:
print ("Training failed!")
Expand Down Expand Up @@ -375,6 +380,11 @@ def save_to_file(self, path):
thundersvm.save_to_file_scikit(c_void_p(self.model), path.encode('utf-8'))

def load_from_file(self, path):
if self.model is None:
thundersvm.model_new.restype = c_void_p
self.model = thundersvm.model_new(SVM_TYPE.index(self._impl))
if self.max_mem_size != -1:
thundersvm.set_memory_size(c_void_p(self.model), self.max_mem_size)
thundersvm.load_from_file_scikit(c_void_p(self.model), path.encode('utf-8'))
degree = (c_int * 1)()
gamma = (c_float * 1)()
Expand Down

0 comments on commit d06b21c

Please sign in to comment.