Also use custom scorer if available for fit --evaluate
Daniel Nouri authored and dnouri committed Nov 9, 2017
1 parent 3ec0a6b commit 6b587a6
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions palladium/fit.py
@@ -7,6 +7,7 @@
 from datetime import datetime
 from docopt import docopt
 from pprint import pformat
+from sklearn.metrics import get_scorer
 from sklearn.model_selection import GridSearchCV
 
 from .interfaces import annotate
@@ -22,7 +23,7 @@
 @args_from_config
 def fit(dataset_loader_train, model, model_persister, persist=True,
         activate=True, dataset_loader_test=None, evaluate=False,
-        persist_if_better_than=None):
+        persist_if_better_than=None, scoring=None):
 
     if persist_if_better_than is not None:
         evaluate = True
@@ -32,6 +33,18 @@ def fit(dataset_loader_train, model, model_persister, persist=True,
                 "provide a 'dataset_loader_test'."
             )
 
+    if evaluate and not (hasattr(model, 'score') or scoring is not None):
+        raise ValueError(
+            "Your model doesn't seem to implement a 'score' method. You may "
+            "want to define a 'scoring' option in the configuration."
+        )
+
+    if scoring is not None:
+        scorer = get_scorer(scoring)
+    else:
+        def scorer(model, X, y):
+            return model.score(X, y)
+
     with timer(logger.info, "Loading data"):
         X, y = dataset_loader_train()
 
@@ -40,15 +53,15 @@ def fit(dataset_loader_train, model, model_persister, persist=True,
 
     if evaluate:
         with timer(logger.debug, "Evaluating model on train set"):
-            score_train = model.score(X, y)
+            score_train = scorer(model, X, y)
         annotate(model, {'score_train': score_train})
         logger.info("Train score: {}".format(score_train))
 
     score_test = None
     if evaluate and dataset_loader_test is not None:
         X_test, y_test = dataset_loader_test()
         with timer(logger.debug, "Evaluating model on test set"):
-            score_test = model.score(X_test, y_test)
+            score_test = scorer(model, X_test, y_test)
         annotate(model, {'score_test': score_test})
         logger.info("Test score: {}".format(score_test))
 