Allow regression with 1d targets (#974)
* Allow regression with 1d targets

This change makes it possible to pass a 1-dimensional y to
`NeuralNetRegressor`.

Problem description

Right now, skorch requires the `y` passed to `NeuralNetRegressor.fit` to
be 2-dimensional, even if there is only one target, which is the most
common case. This problem has come up a few times in the past, but
mostly it's a minor annoyance: reshape `y` to be 2d and you're good (the
error message says as much).
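
As a minimal sketch (assuming `net` is a `NeuralNetRegressor` instance
and `y` has shape `(n_samples,)`), the workaround looks like this:

```python
# add a second dimension of size 1 so that y has shape (n_samples, 1)
net.fit(X, y.reshape(-1, 1))
```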

There are, however, also cases where it's not so easy to solve. For
instance, in #972, a user reports that they cannot use skorch with
sklearn's `BaggingRegressor`. The problem is that even if `y` is
reshaped, once it is passed to the net from `BaggingRegressor`, it is 1d
again. I assume that `BaggingRegressor` internally squeezes `y` at some
point.

This PR lifts the 2d restriction.

Initial motivation

Why does skorch require `y` to be 2d? I couldn't remember the initial
reasoning and did some archeology.

I found this comment
(2f00e25#diff-66ed08bca4d171889565d0285a36b9b47e0e91e3b33d85c51352d8eb00faefac):

>         # The problem with 1-dim float y is that the pytorch DataLoader will
>         # somehow upcast it to DoubleTensor

This strange behavior should not be an issue anymore, so if that was the
only problem, we should be able to just remove the constraint, right?

Problems with removing the constraint

Unfortunately, it's not that easy. The issue comes down to the
following: When we remove the constraint and allow the target `y` to be
1d, but the prediction `y_pred` is still 2d, the criterion `nn.MSELoss`
will likely do the wrong thing. What exactly is wrong? Instead of
calculating the squared error for each corresponding pair of samples,
the criterion broadcasts the 1d target against the 2d prediction,
calculates _all pairwise squared errors_, and returns their mean. To
demonstrate, let's remove the reduction step and look at the shape:

```python
>>> import torch
>>> criterion = torch.nn.MSELoss(reduction='none')
>>> y = torch.rand(100)
>>> y_pred = torch.rand((100, 1))
>>> y.shape, y_pred.shape
(torch.Size([100]), torch.Size([100, 1]))
>>> se = criterion(y_pred, y)
/home/vinh/anaconda3/envs/skorch/lib/python3.10/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([100])) that is different to the input size (torch.Size([100, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
>>> se.shape
torch.Size([100, 100])
```

As can be seen, PyTorch broadcasts the two arrays, leading to 100x100
errors being calculated. Thankfully, PyTorch warns about potential
issues with that.
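
For contrast, squeezing the prediction down to 1d restores the intended
elementwise behavior (continuing the session above):

```python
>>> se = criterion(y_pred.squeeze(-1), y)  # both shapes are now (100,)
>>> se.shape
torch.Size([100])
```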

The current solution is to accept this behavior and hope that users
will indeed see the warning. If they miss or ignore it, it can become a
serious issue: they still get a scalar loss and might even see a small
improvement in the loss during training, but the model will not
converge, and the problem will be very hard to debug, if it's even
identified as a bug at all.

Just to be clear, existing code, which uses 2d targets, will not be
affected by the change introduced in this PR and is still the preferred
way (IMO) to use regression in skorch.
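
To make the new behavior concrete, here is a minimal sketch mirroring
the test fixture added in this PR; it assumes `MLPModule`'s default of
20 input units, and `squeeze_output=True` makes the module's output 1d
to match the 1d target:

```python
from functools import partial

import numpy as np
from skorch import NeuralNetRegressor
from skorch.toy import MLPModule

# module with a single output unit whose output is squeezed to 1d
module_cls = partial(MLPModule, output_units=1, squeeze_output=True)
net = NeuralNetRegressor(module_cls)

X = np.random.rand(100, 20).astype(np.float32)
y = np.random.rand(100).astype(np.float32)  # 1d target, now accepted
net.fit(X, y)  # no broadcasting warning, since y_pred is also 1d
```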

Rejected solutions

I did consider the following solutions but rejected them.

Raising an error when shapes mismatch

This would remove the risk of users missing the warning. The problem
with this is that mismatching shapes can be okay in certain
circumstances. Some criteria don't expect target and prediction to have
the same shape, so any such check would need to depend on the
criterion. Moreover, users may, at least in theory, actually want the
broadcasting. Raising an error would prevent that, and users might have
to resort to subclassing to circumvent it.
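
To illustrate the first point, `nn.CrossEntropyLoss` is a criterion for
which mismatching shapes are expected and correct (it's not a criterion
`NeuralNetRegressor` would use; it just shows why a blanket shape check
would be wrong):

```python
>>> criterion = torch.nn.CrossEntropyLoss()
>>> y_pred = torch.rand((16, 5))    # logits for 5 classes
>>> y = torch.randint(0, 5, (16,))  # class indices, 1d by design
>>> criterion(y_pred, y).shape
torch.Size([])
```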

Automatic reshaping

We could automatically add/remove dimensions if we see that they
mismatch. This has the same problems as the previous solution regarding
the dependence on the type of criterion. Furthermore, automatic
adjustment of the user's output is prone to run into issues in some edge
cases (e.g. when the broadcasting is actually desired).
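
As a sketch, such a hypothetical adjustment (not part of skorch) might
look like this, and it would silently misfire whenever the broadcasting
is intentional:

```python
# hypothetical helper, NOT implemented in skorch
def maybe_squeeze(y_pred, y_true):
    # only covers the MSELoss-style case; other criteria legitimately
    # expect mismatching shapes, and intentional broadcasting is broken
    if y_true.dim() == 1 and y_pred.dim() == 2 and y_pred.shape[1] == 1:
        return y_pred.squeeze(-1)
    return y_pred
```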

* Fix error when initializing BaggingRegressor

For Python 3.7, CI got the following error for `BaggingRegressor`:

    TypeError: __init__() got an unexpected keyword argument 'estimator'

Python 3.7 presumably pulls in an older version of sklearn, which uses
a different name for this argument. Passing the estimator as a
positional argument should fix it.
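
A minimal sketch of the fix (to my knowledge, the keyword was renamed
from `base_estimator` to `estimator` in sklearn 1.2, so the positional
form works on both sides of that rename):

```python
from sklearn.ensemble import BaggingRegressor

# the keyword name differs across sklearn versions, so pass positionally
regr = BaggingRegressor(net, n_estimators=2)
```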

* Reviewer comment: typo

Co-authored-by: ottonemo <marian.tietz@ottogroup.com>

* Reviewer comment: typo

Co-authored-by: ottonemo <marian.tietz@ottogroup.com>

---------

Co-authored-by: ottonemo <marian.tietz@ottogroup.com>
BenjaminBossan and ottonemo committed Jun 26, 2023
1 parent fe8286a commit df92d4d
Showing 3 changed files with 67 additions and 23 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md

```diff
@@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Moved from `pkg_resources` to `importlib` and subsequently dropping support for Python 3.7
   as PyTorch dropped support and the version itself hit EOL (#928 and #983)
 
+- `NeuralNetRegressor` can now be fitted with 1-dimensional `y`, which is necessary in some specific circumstances (e.g. in conjunction with sklearn's `BaggingRegressor`, see #972); for this to work correctly, the output of the PyTorch module should also be 1-dimensional; the existing default, i.e. having `y` and `y_pred` be 2-dimensional, remains the recommended way of using `NeuralNetRegressor`
+
 ### Fixed
 
 ## [0.13.0] - 2023-05-17
```
9 changes: 0 additions & 9 deletions skorch/regressor.py

```diff
@@ -66,15 +66,6 @@ def check_data(self, X, y):
             # The user implements its own mechanism for generating y.
             return
 
-        if get_dim(y) == 1:
-            msg = (
-                "The target data shouldn't be 1-dimensional but instead have "
-                "2 dimensions, with the second dimension having the same size "
-                "as the number of regression targets (usually 1). Please "
-                "reshape your target data to be 2-dimensional "
-                "(e.g. y = y.reshape(-1, 1).")
-            raise ValueError(msg)
-
     # pylint: disable=signature-differs
     def fit(self, X, y, **fit_params):
         """See ``NeuralNet.fit``.
```
79 changes: 65 additions & 14 deletions skorch/tests/test_regressor.py

```diff
@@ -4,6 +4,8 @@
 """
 
+from functools import partial
+
 import numpy as np
 import pytest
 from sklearn.base import clone
@@ -21,6 +23,12 @@ def module_cls(self):
         from skorch.toy import make_regressor
         return make_regressor(dropout=0.5)
 
+    @pytest.fixture(scope='module')
+    def module_pred_1d_cls(self):
+        from skorch.toy import MLPModule
+        # Module that returns 1d predictions
+        return partial(MLPModule, output_units=1, squeeze_output=True)
+
     @pytest.fixture(scope='module')
     def net_cls(self):
         from skorch import NeuralNetRegressor
@@ -57,9 +65,9 @@ def net_fit(self, net, data):
     def test_clone(self, net_fit):
         clone(net_fit)
 
-    def test_fit(self, net_fit):
-        # fitting does not raise anything
-        pass
+    def test_fit(self, net_fit, recwarn):
+        # fitting does not raise anything and does not warn
+        assert not recwarn.list
 
     @pytest.mark.parametrize('method', INFERENCE_METHODS)
     def test_not_fitted_raises(self, net_cls, module_cls, data, method):
@@ -91,17 +99,6 @@ def test_history_default_keys(self, net_fit):
         for row in net_fit.history:
             assert expected_keys.issubset(row)
 
-    def test_target_1d_raises(self, net, data):
-        X, y = data
-        with pytest.raises(ValueError) as exc:
-            net.fit(X, y.flatten())
-        assert exc.value.args[0] == (
-            "The target data shouldn't be 1-dimensional but instead have "
-            "2 dimensions, with the second dimension having the same size "
-            "as the number of regression targets (usually 1). Please "
-            "reshape your target data to be 2-dimensional "
-            "(e.g. y = y.reshape(-1, 1).")
-
     def test_predict_predict_proba(self, net_fit, data):
         X = data[0]
         y_pred = net_fit.predict(X)
@@ -123,3 +120,57 @@ def test_multioutput_score(self, multioutput_net, multioutput_regression_data):
         multioutput_net.fit(X, y)
         r2_score = multioutput_net.score(X, y)
         assert r2_score <= 1.
+
+    def test_dimension_mismatch_warning(self, net_cls, module_cls, data, recwarn):
+        # When the target and the prediction have different dimensionality, mse
+        # loss will broadcast them, calculating all pairwise errors instead of
+        # only sample-wise. Since the errors are averaged at the end, there is
+        # still a valid loss, which makes the error hard to spot. Thankfully,
+        # torch gives a warning in that case. We test that this warning exists,
+        # otherwise, skorch users could run into very hard to debug issues
+        # during training.
+        net = net_cls(module_cls)
+        X, y = data
+        X, y = X[:100], y[:100].flatten()  # make y 1d
+        net.fit(X, y)
+
+        w0, w1 = recwarn.list  # one warning for train, one for valid
+        # The warning comes from PyTorch, so checking the exact wording is prone to
+        # error in future PyTorch versions. We thus check a substring of the
+        # whole message and cross our fingers that it's not changed.
+        msg_substr = (
+            "This will likely lead to incorrect results due to broadcasting. "
+            "Please ensure they have the same size"
+        )
+        assert msg_substr in str(w0.message)
+        assert msg_substr in str(w1.message)
+
+    def test_fitting_with_1d_target_and_pred(
+            self, net_cls, module_cls, data, module_pred_1d_cls, recwarn
+    ):
+        # This test relates to the previous one. In general, users should fit
+        # with target and prediction being 2d, even if the 2nd dimension is just
+        # 1. However, in some circumstances (like when using BaggingRegressor,
+        # see next test), having the ability to fit with 1d is required. In that
+        # case, the module output also needs to be 1d for correctness.
+        X, y = data
+        X, y = X[:100], y[:100]  # less data to run faster
+        y = y.flatten()
+
+        net = net_cls(module_pred_1d_cls)
+        net.fit(X, y)
+        assert not recwarn.list
+
+    def test_bagging_regressor(
+            self, net_cls, module_cls, data, module_pred_1d_cls, recwarn
+    ):
+        # https://github.com/skorch-dev/skorch/issues/972
+        from sklearn.ensemble import BaggingRegressor
+
+        net = net_cls(module_pred_1d_cls)  # module output should be 1d too
+        X, y = data
+        X, y = X[:100], y[:100]  # less data to run faster
+        y = y.flatten()  # make y 1d or else sklearn will complain
+        regr = BaggingRegressor(net, n_estimators=2, random_state=0)
+        regr.fit(X, y)  # does not raise
+        assert not recwarn.list  # ensure there is no broadcast warning from torch
```
