Skip to content

Commit

Permalink
Fix issue 112: EpochScoring flaky when y_test is None.
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamin-work authored and ottonemo committed Nov 29, 2017
1 parent fe41f9d commit b5df830
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
7 changes: 5 additions & 2 deletions skorch/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ class EpochScoring(ScoringBase):
This is called on y before it is passed to scoring.
"""
# pylint: disable=unused-argument
# pylint: disable=unused-argument,arguments-differ
def on_epoch_end(
self,
net,
Expand All @@ -291,7 +291,10 @@ def on_epoch_end(
if X_test is None:
return

y_test = self.target_extractor(y_test)
if y_test is not None:
# We allow y_test to be None but the scoring function has
# to be able to deal with it (i.e. called without y_test).
y_test = self.target_extractor(y_test)
current_score = self._scoring(net, X_test, y_test)
history.record(self.name_, current_score)

Expand Down
24 changes: 24 additions & 0 deletions skorch/tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,30 @@ def test_target_extractor_is_called(

assert extractor.call_count == 2

def test_without_target_data_works(
self, net_cls, module_cls, scoring_cls, data,
):
def myscore(_, X, y=None):
assert y is None
return np.mean(X)

def mysplit(X, y):
# set y_valid to None
return X, X, y, None

X, y = data
net = net_cls(
module=module_cls,
callbacks=[scoring_cls(myscore)],
train_split=mysplit,
max_epochs=2,
)
net.fit(X, y)

expected = np.mean(X)
loss = net.history[:, 'myscore']
assert np.allclose(loss, expected)


class TestBatchScoring:
@pytest.fixture
Expand Down

0 comments on commit b5df830

Please sign in to comment.