Skip to content

Commit

Permalink
Merge pull request #163 from shengshuyang/master
Browse files Browse the repository at this point in the history
create a SaveLogger.save() method.
  • Loading branch information
amueller committed Dec 14, 2015
2 parents 6504cd8 + 45e13b1 commit f103fb9
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 9 deletions.
44 changes: 44 additions & 0 deletions pystruct/tests/test_utils/test_utils_logging.py
@@ -0,0 +1,44 @@
import numpy as np
from tempfile import mkstemp

from sklearn.datasets import load_iris
from sklearn.cross_validation import train_test_split

from pystruct.models import GraphCRF
from pystruct.learners import NSlackSSVM
from pystruct.utils import SaveLogger
from pystruct.inference import get_installed

from nose.tools import assert_less, assert_almost_equal

# we always try to get the fastest installed inference method
inference_method = get_installed(["qpbo", "ad3", "lp"])[0]

def test_logging():
iris = load_iris()
X, y = iris.data, iris.target

X_ = [(np.atleast_2d(x), np.empty((0, 2), dtype=np.int)) for x in X]
Y = y.reshape(-1, 1)

X_train, X_test, y_train, y_test = train_test_split(X_, Y, random_state=1)
_, file_name = mkstemp()

pbl = GraphCRF(n_features=4, n_states=3, inference_method='lp')
logger = SaveLogger(file_name)
svm = NSlackSSVM(pbl, C=100, n_jobs=1, logger=logger)
svm.fit(X_train, y_train)

score_current = svm.score(X_test, y_test)
score_auto_saved = logger.load().score(X_test, y_test)

alt_file_name = file_name + "alt"
logger.save(svm, alt_file_name)
logger.file_name = alt_file_name
logger.load()
score_manual_saved = logger.load().score(X_test, y_test)

assert_less(.97, score_current)
assert_less(.97, score_auto_saved)
assert_less(.97, score_manual_saved)
assert_almost_equal(score_auto_saved, score_manual_saved)
22 changes: 13 additions & 9 deletions pystruct/utils/logging.py
Expand Up @@ -50,15 +50,19 @@ def __call__(self, learner, iteration=0):
file_name = file_name % iteration
if self.verbose > 0:
print("saving %s to file %s" % (learner, file_name))
with open(file_name, "wb") as f:
if hasattr(learner, 'inference_cache_'):
# don't store the large inference cache!
learner.inference_cache_, tmp = (None,
learner.inference_cache_)
pickle.dump(learner, f, -1)
learner.inference_cache_ = tmp
else:
pickle.dump(learner, f, -1)
self.save(learner, file_name)

def save(self, learner, file_name):
"""Save the model to location specified in file_name."""
with open(file_name, "wb") as f:
if hasattr(learner, 'inference_cache_'):
# don't store the large inference cache!
learner.inference_cache_, tmp = (None,
learner.inference_cache_)
pickle.dump(learner, f, -1)
learner.inference_cache_ = tmp
else:
pickle.dump(learner, f, -1)

def load(self):
"""Load the model stoed in file_name and return it."""
Expand Down

0 comments on commit f103fb9

Please sign in to comment.