Make binary classifier work with BCELoss (#868)

Until now, it was assumed that BCEWithLogitsLoss is used.
BenjaminBossan committed Jul 21, 2022
1 parent 4b9e765 commit 91a5876
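
In practice, the change means that a binary classifier whose module already applies a sigmoid can now be trained with BCELoss and still return well-formed probabilities from predict_proba. A minimal sketch of the new usage (the toy module and data below are hypothetical; only the criterion=BCELoss plus sigmoid-output wiring comes from this commit's tests):

import numpy as np
import torch
from skorch import NeuralNetBinaryClassifier

class ToyModule(torch.nn.Module):
    # ends in a sigmoid, so it outputs probabilities, not logits
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(20, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, X):
        return self.sigmoid(self.linear(X)).squeeze(-1)

X = np.random.randn(100, 20).astype(np.float32)
y = (X[:, 0] > 0).astype(np.float32)

# before this commit, correct predict_proba output required the default
# BCEWithLogitsLoss (i.e. a module that returns raw logits)
net = NeuralNetBinaryClassifier(ToyModule, criterion=torch.nn.BCELoss, max_epochs=5)
net.fit(X, y)
y_proba = net.predict_proba(X)  # shape (100, 2), rows sum to 1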
Showing 2 changed files with 74 additions and 3 deletions.
58 changes: 57 additions & 1 deletion skorch/tests/test_classifier.py
@@ -9,6 +9,7 @@
import numpy as np
import pytest
import torch
from scipy.special import expit
from sklearn.base import clone
from torch import nn

@@ -88,7 +89,7 @@ def test_takes_log_with_nllloss(self, net_cls, module_cls, data):

    # classifier-specific test
    def test_takes_no_log_without_nllloss(self, net_cls, module_cls, data):
-        net = net_cls(module_cls, criterion=nn.BCELoss, max_epochs=1)
+        net = net_cls(module_cls, criterion=nn.CrossEntropyLoss, max_epochs=1)
        net.initialize()

        mock_loss = Mock(side_effect=nn.NLLLoss())
@@ -255,6 +256,17 @@ def test_predict_predict_proba(self, net, data, threshold):
        y_pred_proba = net.predict_proba(X)
        assert y_pred_proba.shape == (X.shape[0], 2)

        # The tests below check that we don't accidentally apply sigmoid twice,
        # which would result in probabilities constrained to [expit(-1),
        # expit(1)]. The lower bound is not expit(0), as one may think at first,
        # because we create the probabilities as:
        # torch.stack((1 - prob, prob), 1)
        # So the lowest value that could be achieved by applying sigmoid twice
        # is 1 - expit(1), which is equal to expit(-1).
        prob_min, prob_max = expit(-1), expit(1)
        assert (y_pred_proba < prob_min).any()
        assert (y_pred_proba > prob_max).any()

        y_pred_exp = (y_pred_proba[:, 1] > threshold).astype('uint8')

        y_pred_actual = net.predict(X)
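
The bound invoked in the comment above follows from the logistic identity 1 - expit(x) = expit(-x): sigmoid applied twice lands every value in (expit(0), expit(1)), and stacking (1 - prob, prob) extends the range downward to expit(-1). A quick numerical check (illustrative only, not part of the commit):

from scipy.special import expit

# sigmoid of a sigmoid lies in (expit(0), expit(1)) ~ (0.5, 0.731); after
# stacking (1 - prob, prob), all entries fall inside (expit(-1), expit(1))
assert abs((1 - expit(1)) - expit(-1)) < 1e-12
print(expit(-1), expit(1))  # ~0.2689 ~0.7311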
@@ -352,3 +364,47 @@ def test_module_output_2d_raises(self, net_cls, data):
        expected = ("Expected module output to have shape (n,) or "
                    "(n, 1), got (128, 2) instead")
        assert msg == expected

    @pytest.fixture(scope='module')
    def net_with_bceloss(self, net_cls, module_cls, data):
        # binary classification should also work with BCELoss
        net = net_cls(
            module_cls,
            module__output_nonlin=torch.nn.Sigmoid(),
            criterion=torch.nn.BCELoss,
            lr=1,
        )
        X, y = data
        net.fit(X, y)
        return net

    def test_net_with_bceloss_learns(self, net_with_bceloss):
        train_losses = net_with_bceloss.history[:, 'train_loss']
        assert train_losses[0] > 1.3 * train_losses[-1]

    def test_predict_proba_with_bceloss(self, net_with_bceloss, data):
        X, _ = data
        y_proba = net_with_bceloss.predict_proba(X)

        assert y_proba.shape == (X.shape[0], 2)
        assert (y_proba >= 0).all()
        assert (y_proba <= 1).all()

        # The tests below check that we don't accidentally apply sigmoid twice,
        # which would result in probabilities constrained to [expit(-1),
        # expit(1)]. The lower bound is not expit(0), as one may think at first,
        # because we create the probabilities as:
        # torch.stack((1 - prob, prob), 1)
        # So the lowest value that could be achieved by applying sigmoid twice
        # is 1 - expit(1), which is equal to expit(-1).
        prob_min, prob_max = expit(-1), expit(1)
        assert (y_proba < prob_min).any()
        assert (y_proba > prob_max).any()

    def test_predict_with_bceloss(self, net_with_bceloss, data):
        X, _ = data

        y_pred_proba = net_with_bceloss.predict_proba(X)
        y_pred_exp = (y_pred_proba[:, 1] > net_with_bceloss.threshold).astype('uint8')
        y_pred_actual = net_with_bceloss.predict(X)
        assert np.allclose(y_pred_exp, y_pred_actual)
19 changes: 17 additions & 2 deletions skorch/utils.py
@@ -19,6 +19,7 @@
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted as sk_check_is_fitted
import torch
from torch.nn import BCELoss
from torch.nn import BCEWithLogitsLoss
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import PackedSequence
@@ -585,6 +586,18 @@ def _identity(x):
    return x


def _make_2d_probs(prob):
    """Create a 2d probability array from a 1d vector

    This is needed because by convention, even for binary classification
    problems, sklearn expects 2 probabilities to be returned per row, one for
    class 0 and one for class 1.

    """
    y_proba = torch.stack((1 - prob, prob), 1)
    return y_proba


def _sigmoid_then_2d(x):
    """Transform 1-dim logits to valid y_proba
@@ -607,8 +620,7 @@ def _sigmoid_then_2d(x):
"""
prob = torch.sigmoid(x)
y_proba = torch.stack((1 - prob, prob), 1)
return y_proba
return _make_2d_probs(prob)


# TODO only needed if multiclass GP classification is added
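
Concretely, _make_2d_probs turns a vector of class-1 probabilities into the (n, 2) layout sklearn expects; an illustrative call (not part of the diff):

import torch

prob = torch.tensor([0.1, 0.5, 0.9])
# column 0 holds P(class 0) = 1 - p, column 1 holds P(class 1) = p
torch.stack((1 - prob, prob), 1)
# tensor([[0.9000, 0.1000],
#         [0.5000, 0.5000],
#         [0.1000, 0.9000]])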
@@ -636,6 +648,9 @@ def _infer_predict_nonlinearity(net):
    if isinstance(criterion, BCEWithLogitsLoss):
        return _sigmoid_then_2d

    if isinstance(criterion, BCELoss):
        return _make_2d_probs

    # TODO only needed if multiclass GP classification is added
    # likelihood = getattr(net, 'likelihood_', None)
    # if (
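
Put together, the dispatch means: with BCEWithLogitsLoss the module's raw logits are passed through a sigmoid before stacking, while with BCELoss the module output is taken as a probability as-is. A simplified sketch of that behavior (infer_proba is a hypothetical helper condensing the hunks above, not skorch's actual code path):

import torch
from torch import nn

def infer_proba(criterion, module_out):
    if isinstance(criterion, nn.BCEWithLogitsLoss):
        prob = torch.sigmoid(module_out)  # module returned logits
    elif isinstance(criterion, nn.BCELoss):
        prob = module_out  # module already returned probabilities
    else:
        raise ValueError("sketch only covers the two binary criteria")
    return torch.stack((1 - prob, prob), 1)

logits = torch.tensor([-2.0, 0.0, 2.0])
assert torch.allclose(
    infer_proba(nn.BCEWithLogitsLoss(), logits),
    infer_proba(nn.BCELoss(), torch.sigmoid(logits)),
)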
