Skip to content
Permalink
Browse files

BUG Adds eps to loss to prevent nans (#491)

Adds eps to loss to prevent nans
  • Loading branch information...
thomasjpfan authored and BenjaminBossan committed Jul 8, 2019
1 parent afb6654 commit 873b481049a376c9836056140c121fe0ee1bd64e
Showing with 11 additions and 1 deletion.
  1. +2 −0 CHANGES.md
  2. +2 −1 skorch/classifier.py
  3. +7 −0 skorch/tests/test_classifier.py
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Improve numerical stability when using `NLLLoss` in `NeuralNetClassifier` (#491)

### Fixed


@@ -103,7 +103,8 @@ def check_data(self, X, y):
# pylint: disable=arguments-differ
def get_loss(self, y_pred, y_true, *args, **kwargs):
    """Return the loss for the given predictions and targets.

    When the criterion is ``torch.nn.NLLLoss``, the module is expected
    to output probabilities, so they are converted to log-probabilities
    first. A small ``eps`` (the machine epsilon of the prediction
    dtype) is added before taking the log so that a probability of
    exactly zero does not produce ``-inf`` and, subsequently, a nan
    loss (regression for issue #491).

    Parameters
    ----------
    y_pred : torch.Tensor
      Predicted probabilities (assumed non-negative when NLLLoss is
      used — TODO confirm against module output).
    y_true : torch.Tensor
      Ground-truth targets, forwarded unchanged to the parent loss.

    Returns
    -------
    loss : torch.Tensor
      The loss as computed by the parent class's ``get_loss``.

    """
    if isinstance(self.criterion_, torch.nn.NLLLoss):
        # eps matched to the dtype so the shift is as small as possible
        # while still preventing log(0) -> -inf.
        eps = torch.finfo(y_pred.dtype).eps
        y_pred = torch.log(y_pred + eps)
    return super().get_loss(y_pred, y_true, *args, **kwargs)

# pylint: disable=signature-differs
@@ -98,6 +98,13 @@ def test_takes_no_log_without_nllloss(self, net_cls, module_cls, data):
assert not (y_out < 0).all()
assert torch.isclose(torch.ones(len(y_out)), y_out.sum(1)).all()

# classifier-specific test
def test_high_learning_rate(self, net_cls, module_cls, data):
    """Regression test for nan loss with high learning rates (issue #481)."""
    clf = net_cls(module_cls, max_epochs=2, lr=2, optimizer=torch.optim.Adam)
    clf.fit(*data)
    # With the eps-stabilized log, at least one recorded train loss
    # must remain finite even at an aggressive learning rate.
    train_losses = clf.history[:, 'train_loss']
    assert not np.isnan(train_losses).all()


class TestNeuralNetBinaryClassifier:
@pytest.fixture(scope='module')

0 comments on commit 873b481

Please sign in to comment.
You can’t perform that action at this time.