Commit
If a model to persist has 'cv_results', store them in the metadata
dnouri committed May 24, 2018
1 parent 4ef2cee commit 3db09ce
Showing 3 changed files with 49 additions and 3 deletions.
10 changes: 8 additions & 2 deletions palladium/fit.py
@@ -21,7 +21,13 @@


 def _persist_model(model, model_persister, activate=True):
-    annotate(model, {'train_timestamp': datetime.now().isoformat()})
+    metadata = {
+        'train_timestamp': datetime.now().isoformat(),
+    }
+    cv_results = getattr(model, 'cv_results_', None)
+    if cv_results is not None:
+        metadata['cv_results'] = pandas.DataFrame(cv_results).to_json()
+    annotate(model, metadata)
     with timer(logger.info, "Writing model"):
         version = model_persister.write(model)
         logger.info("Wrote model with version {}.".format(version))
@@ -215,7 +221,7 @@ def grid_search(dataset_loader_train, model, grid_search, scoring=None,
     if save_results:
         results.to_csv(save_results, index=False)
     if persist_best:
-        _persist_model(gs.best_estimator_, model_persister, activate=True)
+        _persist_model(gs, model_persister, activate=True)
     return gs


41 changes: 40 additions & 1 deletion palladium/tests/test_fit.py
@@ -24,6 +24,7 @@ def dataset_loader(self):

def test_it(self, fit):
model, dataset_loader_train, model_persister = Mock(), Mock(), Mock()
del model.cv_results_
X, y = object(), object()
dataset_loader_train.return_value = X, y
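
Every pre-existing test in this file gains a del model.cv_results_ line. A Mock invents attributes on access, so without the del, the new getattr(model, 'cv_results_', None) call in _persist_model would receive a Mock instead of None and try to build a DataFrame from it; deleting the attribute makes access raise AttributeError, the documented way to hide an attribute on a Mock. A minimal sketch, independent of the commit:

from unittest.mock import Mock

model = Mock()
del model.cv_results_   # attribute access now raises AttributeError
assert getattr(model, 'cv_results_', None) is None
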

@@ -38,6 +39,7 @@ def test_it(self, fit):

def test_no_persist(self, fit):
model, dataset_loader_train, model_persister = Mock(), Mock(), Mock()
del model.cv_results_
X, y = object(), object()
dataset_loader_train.return_value = X, y

@@ -51,6 +53,7 @@ def test_no_persist(self, fit):

def test_evaluate_no_test_dataset(self, fit):
model, dataset_loader_train, model_persister = Mock(), Mock(), Mock()
del model.cv_results_
X, y = object(), object()
dataset_loader_train.return_value = X, y

@@ -66,6 +69,7 @@ def test_evaluate_no_test_dataset(self, fit):

def test_evaluate_with_test_dataset(self, fit):
model, dataset_loader_train, model_persister = Mock(), Mock(), Mock()
del model.cv_results_
dataset_loader_test = Mock()
X, y, X_test, y_test = object(), object(), object(), object()
dataset_loader_train.return_value = X, y
@@ -86,6 +90,7 @@ def test_evaluate_with_test_dataset(self, fit):

def test_evaluate_annotations(self, fit, dataset_loader):
model = Mock()
del model.cv_results_
model.score.side_effect = [0.9, 0.8]

result = fit(
@@ -101,6 +106,7 @@ def test_evaluate_annotations(self, fit, dataset_loader):

def test_evaluate_scoring(self, fit, dataset_loader):
model = Mock()
del model.cv_results_
scorer = Mock()
scorer.side_effect = [0.99, 0.01]

@@ -118,6 +124,7 @@ def test_evaluate_scoring(self, fit, dataset_loader):
def test_evaluate_no_score(self, fit, dataset_loader):
model = Mock()
del model.score
del model.cv_results_

with pytest.raises(ValueError):
fit(
@@ -131,6 +138,7 @@ def test_evaluate_no_score(self, fit, dataset_loader):
def test_persist_if_better_than(self, fit, dataset_loader):
model, model_persister = Mock(), Mock()
model.score.return_value = 0.9
del model.cv_results_

result = fit(
dataset_loader_train=dataset_loader,
@@ -146,6 +154,7 @@ def test_persist_if_better_than(self, fit, dataset_loader):
def test_persist_if_better_than_false(self, fit, dataset_loader):
model, model_persister = Mock(), Mock()
model.score.return_value = 0.9
del model.cv_results_

result = fit(
dataset_loader_train=dataset_loader,
@@ -161,6 +170,7 @@ def test_persist_if_better_than_false(self, fit, dataset_loader):
def test_persist_if_better_than_persist_false(self, fit, dataset_loader):
model, model_persister = Mock(), Mock()
model.score.return_value = 0.9
del model.cv_results_

result = fit(
dataset_loader_train=dataset_loader,
@@ -177,6 +187,7 @@ def test_persist_if_better_than_persist_false(self, fit, dataset_loader):
def test_persist_if_better_than_no_dataset_test(self, fit, dataset_loader):
model, model_persister = Mock(), Mock()
model.score.return_value = 0.9
del model.cv_results_

with pytest.raises(ValueError):
fit(
@@ -189,6 +200,7 @@ def test_persist_if_better_than_no_dataset_test(self, fit, dataset_loader):

def test_activate_no_persist(self, fit, dataset_loader):
model, model_persister = Mock(), Mock()
del model.cv_results_

result = fit(
dataset_loader_train=dataset_loader,
@@ -201,6 +213,7 @@ def test_activate_no_persist(self, fit, dataset_loader):

def test_timestamp(self, fit, dataset_loader):
model, model_persister = Mock(), Mock()
del model.cv_results_

def persist(model):
assert 'train_timestamp' in model.__metadata__
@@ -220,6 +233,30 @@ def persist(model):
assert before_fit < timestamp < after_fit
model_persister.write.assert_called_with(model)

def test_cv_results(self, fit, dataset_loader):
    model, model_persister = Mock(), Mock()
    model.cv_results_ = {
        'mean_train_score': [3, 2, 1],
        'mean_test_score': [1, 2, 3],
    }

    def persist(model):
        assert 'cv_results' in model.__metadata__

    model_persister.write.side_effect = persist

    result = fit(
        dataset_loader,
        model,
        model_persister,
    )
    assert result is model

    cv_results = model.__metadata__['cv_results']
    cv_results = pandas.read_json(cv_results).to_dict(orient='list')
    assert cv_results == model.cv_results_
    model_persister.write.assert_called_with(model)


def test_activate():
from palladium.fit import activate
@@ -321,6 +358,7 @@ def test_deprecated_scoring(self, grid_search, GridSearchCVWithScores):

def test_persist_best_requires_persister(self, grid_search):
model = Mock(spec=['fit', 'predict'])
del model.cv_results_
dataset_loader_train = Mock()
scoring = Mock()
dataset_loader_train.return_value = object(), object()
@@ -331,6 +369,7 @@ def test_persist_best_requires_persister(self, grid_search):

def test_persist_best(self, grid_search, GridSearchCVWithScores):
model = Mock(spec=['fit', 'predict'])
del model.cv_results_
dataset_loader_train = Mock()
scoring = Mock()
model_persister = Mock()
@@ -341,7 +380,7 @@ def test_persist_best(self, grid_search, GridSearchCVWithScores):
GridSearchCVWithScores.assert_called_with(
model, refit=True, scoring=scoring)
model_persister.write.assert_called_with(
-    GridSearchCVWithScores().best_estimator_)
+    GridSearchCVWithScores())

def test_grid_search(self, grid_search):
model, dataset_loader_train = Mock(), Mock()
1 change: 1 addition & 0 deletions palladium/tests/test_server.py
@@ -596,6 +596,7 @@ def jobs(self, process_store):

def test_it(self, fit, config, jobs, flask_app):
dsl, model, model_persister = Mock(), Mock(), Mock()
del model.cv_results_
X, y = Mock(), Mock()
dsl.return_value = X, y
config['dataset_loader_train'] = dsl
