Skip to content

Commit

Permalink
Precomputed model training.
Browse files Browse the repository at this point in the history
  • Loading branch information
fullung committed Jul 14, 2006
1 parent 69cd19a commit 3b8723a
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Lib/sandbox/svm/classification.py
Expand Up @@ -130,7 +130,7 @@ def cross_validate(self, dataset, nr_fold):
This function returns the percentage of data that was
classified correctly over all the experiments.
"""
problem = dataset.create_svm_problem()
problem = dataset._create_svm_problem()
target = N.empty((len(dataset.data),), dtype=N.float64)
tp = cast(target.ctypes.data, POINTER(c_double))
libsvm.svm_cross_validation(problem, self.param, nr_fold, tp)
Expand Down
24 changes: 12 additions & 12 deletions Lib/sandbox/svm/dataset.py
@@ -1,4 +1,3 @@
from ctypes import c_double, POINTER, cast
import numpy as N

import libsvm
Expand All @@ -24,17 +23,12 @@ def getgamma(self):
def precompute(self, kernel):
return LibSvmPrecomputedDataSet(kernel, self.data)

def create_svm_problem(self):
problem = libsvm.svm_problem()
problem.l = len(self.data)
y = (c_double*problem.l)()
x = (POINTER(libsvm.svm_node)*problem.l)()
for i, (yi, xi) in enumerate(self.data):
y[i] = yi
x[i] = cast(xi.ctypes.data, POINTER(libsvm.svm_node))
problem.x = x
problem.y = y
return problem
def _create_svm_problem(self):
return libsvm.create_svm_problem(self.data)

def _update_svm_parameter(self, param):
# XXX we can handle gamma=None here
pass

class LibSvmPrecomputedDataSet:
def __init__(self, kernel, origdata=None):
Expand Down Expand Up @@ -126,6 +120,12 @@ def combine(self, dataset):
newdataset.grammat = newgrammat
return newdataset

def _create_svm_problem(self):
return libsvm.create_svm_problem(self.data)

def _update_svm_parameter(self, param):
param.kernel_type = libsvm.PRECOMPUTED

class LibSvmRegressionDataSet(LibSvmDataSet):
def __init__(self, origdata):
data = map(lambda x: (x[0], convert_to_svm_node(x[1])), origdata)
Expand Down
14 changes: 13 additions & 1 deletion Lib/sandbox/svm/libsvm.py
@@ -1,7 +1,7 @@
import inspect

from ctypes import *
import numpy as N
from ctypes import c_int, c_double, POINTER, Structure, c_char_p

_libsvm = N.ctypes_load_library('libsvm_', __file__)

Expand Down Expand Up @@ -124,6 +124,18 @@ class svm_model(Structure):
func.argtypes = argtypes
inspect.currentframe().f_locals[f] = func

def create_svm_problem(data):
problem = svm_problem()
problem.l = len(data)
y = (c_double*problem.l)()
x = (POINTER(svm_node)*problem.l)()
for i, (yi, xi) in enumerate(data):
y[i] = yi
x[i] = cast(xi.ctypes.data, POINTER(svm_node))
problem.x = x
problem.y = y
return problem

__all__ = [
'svm_node_dtype',
'C_SVC',
Expand Down
15 changes: 10 additions & 5 deletions Lib/sandbox/svm/model.py
@@ -1,4 +1,4 @@
from ctypes import *
from ctypes import POINTER, c_double, c_int

from kernel import *
import libsvm
Expand Down Expand Up @@ -47,9 +47,10 @@ def __init__(self, kernel,
self.param = param

def fit(self, dataset):
problem = dataset.create_svm_problem()

problem = dataset._create_svm_problem()
dataset._update_svm_parameter(self.param)
self._check_problem_param(problem, self.param)

model = libsvm.svm_train(problem, self.param)

# weights are no longer required, so remove to them as the
Expand All @@ -58,8 +59,12 @@ def fit(self, dataset):
model.contents.param.weight = c_double_null_ptr
model.contents.param.weight_label = c_int_null_ptr

# results keep a refence to the dataset because the svm_model
# refers to some of its vectors as the support vectors
# results keep a reference to the dataset because the
# svm_model refers to some of its vectors as the support
# vectors
# XXX we can hide an id in the end of record marker so that we
# can figure out which support vectors to keep references to
# even when not using precomputed kernels
return self.Results(model, dataset)

def _check_problem_param(self, problem, param):
Expand Down
2 changes: 1 addition & 1 deletion Lib/sandbox/svm/regression.py
Expand Up @@ -66,7 +66,7 @@ def cross_validate(self, dataset, nr_fold):
error and the squared correlation coefficient.
"""

problem = dataset.create_svm_problem()
problem = dataset._create_svm_problem()
target = N.empty((len(dataset.data),), dtype=N.float64)
tp = cast(target.ctypes.data, POINTER(c_double))
libsvm.svm_cross_validation(problem, self.param, nr_fold, tp)
Expand Down
1 change: 1 addition & 0 deletions Lib/sandbox/svm/tests/test_all.py
Expand Up @@ -3,6 +3,7 @@
from test_dataset import *
from test_oneclass import *
from test_libsvm import *
from test_precomputed import *

if __name__ == '__main__':
NumpyTest().run()
45 changes: 45 additions & 0 deletions Lib/sandbox/svm/tests/test_precomputed.py
@@ -0,0 +1,45 @@
from numpy.testing import *
import numpy as N

set_local_path('../..')
from svm.regression import *
from svm.dataset import *
from svm.kernel import LinearKernel
restore_path()

class test_precomputed(NumpyTestCase):
def check_precomputed(self):
kernel = LinearKernel()

# this dataset remains constant
y1 = N.random.randn(50)
x1 = N.random.randn(len(y1), 10)
data1 = LibSvmRegressionDataSet(zip(y1, x1))
pcdata1 = data1.precompute(kernel)

# in a typical problem, this dataset would be smaller than the
# part that remains constant and would differ for each model
y2 = N.random.randn(5)
x2 = N.random.randn(len(y2), x1.shape[1])
data2 = LibSvmRegressionDataSet(zip(y2, x2))

pcdata12 = pcdata1.combine(data2)
model = LibSvmEpsilonRegressionModel(kernel)
results = model.fit(pcdata12)

# reference model, calculated without involving the
# precomputed Gram matrix
refy = N.concatenate([y1, y2])
refx = N.vstack([x1, x2])
refdata = LibSvmRegressionDataSet(zip(refy, refx))
model = LibSvmEpsilonRegressionModel(kernel)
refresults = model.fit(refdata)

self.assertAlmostEqual(results.rho, refresults.rho)
assert_array_almost_equal(results.sv_coef, refresults.sv_coef)

# XXX sigmas don't match yet. need to find out why.
#self.assertAlmostEqual(results.sigma, refresults.sigma)

if __name__ == '__main__':
NumpyTest().run()

0 comments on commit 3b8723a

Please sign in to comment.