Commit

initial commit of libsvm stuff

yamins81 committed Feb 28, 2012
1 parent 266400a commit 12220a7
Showing 3 changed files with 146 additions and 3 deletions.
126 changes: 126 additions & 0 deletions eccv12/classifier.py
@@ -1,6 +1,132 @@
import numpy as np
import scipy as sp

#############
###scikits###
#############
try:
    ##use yamins81/feature/multiclass_kwargs branch of sklearn
    ##XXX need to include as git submodule
    from sklearn import svm as sklearn_svm
    from sklearn.multiclass import OneVsRestClassifier
except ImportError:
    print("Can't import scikits stuff")


def train_scikits(train_Xyd,
                  labelset,
                  model_type,
                  model_kwargs=None,
                  fit_kwargs=None,
                  normalization=True,
                  trace_normalize=False,
                  sample_weight_opts=None):

"""
construct and train a scikits svm model
model_type = which svm model to use (e.g. "svm.LinearSVC")
model_kwargs = args sent to classifier initialization
fit_kwargs = args sent to classifier fit
use_decisions = whether to use "raw decisions" in weights or not
normalization = do feature-wise train data normalization
trace_normalize = do example-wise trace normalization
"""
    if model_kwargs is None:
        model_kwargs = {}
    if fit_kwargs is None:
        fit_kwargs = {}

    train_features, train_labels, train_decisions = train_Xyd
    assert labelset == [-1, 1] or labelset == range(len(labelset)), labelset
    assert set(train_labels) == set(labelset)

    #do normalization
    if normalization:
        train_features, train_mean, train_std, trace = normalize(
            [train_features], trace_normalize=trace_normalize)
    else:
        train_mean = None
        train_std = None
        trace = None

    if sample_weight_opts is not None:
        #NB: if sample_weight_opts is not None, the classifier better support
        #the sample_weight fit argument, e.g. svm.SVC is fine but svm.LinearSVC
        #is NOT.
        assert 'sample_weight' not in fit_kwargs
        use_raw_decisions = sample_weight_opts['use_raw_decisions']
        alpha = sample_weight_opts['alpha']
        sample_weights = sample_weights_from_decisions(
            decisions=train_decisions,
            labels=train_labels,
            labelset=labelset,
            use_raw_decisions=use_raw_decisions,
            alpha=alpha)
        fit_kwargs['sample_weight'] = sample_weights

    model = train_scikits_core(train_features=train_features,
                               train_labels=train_labels,
                               model_type=model_type,
                               labelset=labelset,
                               model_kwargs=model_kwargs,
                               fit_kwargs=fit_kwargs)

    train_data = {'train_mean': train_mean,
                  'train_std': train_std,
                  'trace': trace}

    return model, train_data


def train_scikits_core(train_features,
                       train_labels,
                       model_type,
                       labelset,
                       model_kwargs,
                       fit_kwargs):
    """
    Instantiate the requested sklearn svm class and fit it to the training
    data; non-binary labelsets are wrapped in a OneVsRestClassifier.
    """
    if model_type.startswith('svm.'):
        ct = model_type.split('.')[-1]
        cls = getattr(sklearn_svm, ct)
    else:
        raise ValueError('Model type %s not recognized' % model_type)
    if labelset == [-1, 1]:
        clf = cls(**model_kwargs)
    else:
        clf = OneVsRestClassifier(cls(**model_kwargs))
    clf.fit(train_features, train_labels, **fit_kwargs)
    return clf


def sample_weights_from_decisions(decisions,
                                  labels,
                                  labelset,
                                  use_raw_decisions,
                                  alpha):
    """
    Return per-example weights proportional to exp(-alpha * margin), where
    the margin is either the raw decision margin (use_raw_decisions=True)
    or +/-1 according to whether the prediction is correct.
    """
    assert labelset == [-1, 1] or labelset == range(len(labelset))
    assert decisions.shape[0] == labels.shape[0]

    if labelset == [-1, 1]:
        decisions = np.column_stack([-decisions, decisions]) / 2.
        labels = ((1 + labels) / 2).astype(np.int)

    if use_raw_decisions:
        actual = decisions[range(len(labels)), labels]
        decisions_c = decisions.copy()
        decisions_c[range(len(labels)), labels] = -np.inf
        max_c = decisions_c.max(1)
        margins = actual - max_c
    else:
        predictions = decisions.argmax(1)
        margins = 2 * (predictions == labels).astype(np.int) - 1
    weights = np.exp(-alpha * margins)
    weights = weights / weights.sum()

    return weights


#########
##stats##
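A minimal worked example of the weighting scheme in sample_weights_from_decisions above, on made-up numbers (it assumes the module is importable as eccv12.classifier and that the sklearn import succeeded). Each weight is proportional to exp(-alpha * margin), so the larger the signed margin of the current decisions, the smaller the weight, and the weights are normalized to sum to one:

import numpy as np
from eccv12.classifier import sample_weights_from_decisions

# Four binary examples: the sign of `decisions` is the current prediction,
# so the second and third examples are misclassified.
decisions = np.array([2.0, -1.0, 0.5, -3.0])
labels = np.array([1, 1, -1, -1])

weights = sample_weights_from_decisions(decisions=decisions,
                                        labels=labels,
                                        labelset=[-1, 1],
                                        use_raw_decisions=True,
                                        alpha=0.5)
# `weights` sums to 1, and the two misclassified examples receive the
# largest entries, since each weight is proportional to exp(-alpha * margin).
print(weights)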
14 changes: 11 additions & 3 deletions eccv12/lfw.py
@@ -20,7 +20,7 @@

from .bandits import BaseBandit, validate_config, validate_result
from .utils import ImgLoaderResizer
from .classifier import get_result
from .classifier import get_result, train_scikits

# -- register symbols in pyll.scope
import toyproblem
@@ -209,8 +209,11 @@ def verification_pairs(split, test=None):
    lidxs, ridxs = _verification_pairs_helper(all_paths, lpaths, rpaths)
    if test is None:
        return lidxs, ridxs, (matches * 2 - 1)
    else:
    elif isinstance(test, int):
        return lidxs[:test], ridxs[:test], (matches[:test] * 2 - 1)
    else:
        assert all([isinstance(_t, int) for _t in test])
        return lidxs[test], ridxs[test], (matches[test] * 2 - 1)


@scope.define
@@ -523,7 +526,12 @@ def train_view2(namebases, basedirs, test=None, use_libsvm=False):

        print ('Training split %d ...' % ind)
        if use_libsvm:
            pass
            svm, _ = train_scikits(train_Xyd_n,
                                   labelset=[-1, 1],
                                   model_type='svm.SVC',
                                   model_kwargs={'kernel': 'linear'},
                                   normalization=False
                                   )
        else:
            svm = toyproblem.train_svm(train_Xyd_n,
                                       l2_regularization=1e-3,
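As a standalone sketch of the new use_libsvm branch above, the following trains the same svm.SVC linear model through train_scikits on tiny synthetic data instead of LFW features (the random data, shapes, and the zero decisions placeholder are assumptions made for illustration; the call signature mirrors the diff, and the package is assumed importable as eccv12.classifier):

import numpy as np
from eccv12.classifier import train_scikits

# Tiny synthetic binary problem in the (features, labels, decisions) triple
# format that train_view2 passes around as train_Xyd_n.
rng = np.random.RandomState(0)
X = rng.randn(40, 8)
y = np.where(X[:, 0] > 0, 1, -1)   # labels in {-1, 1}, matching labelset below
d = np.zeros(40)                   # decisions are unused without sample_weight_opts

svm, train_data = train_scikits((X, y, d),
                                labelset=[-1, 1],
                                model_type='svm.SVC',
                                model_kwargs={'kernel': 'linear'},
                                normalization=False)
print(svm.predict(X[:5]))

The decisions slot is only consulted when sample_weight_opts is passed, so a zero vector is enough here.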
9 changes: 9 additions & 0 deletions eccv12/tests/test_lfw.py
@@ -259,3 +259,12 @@ def test_baby_view2():
                           test=50)
    return lfw.train_view2([''], [os.getcwd()], test=50)


@attr('slow')  # takes about 30 sec with cpu
def test_baby_view2_libsvm():
    c = config_tiny_rnd0
    test_set = range(20) + range(500, 520)
    lfw.get_view2_features(c['slm'], c['preproc'], 'mult', 'libsvm', os.getcwd(),
                           test=test_set)
    return lfw.train_view2(['libsvm'], [os.getcwd()],
                           test=test_set, use_libsvm=True)
