Skip to content
Browse files
Bugfix: sklearn 0.22 (#571)
* Remove import-outside-toplevel from pylint

Very unnecessary :/

* Make check_is_fitted work again

`check_is_fitted` was changed (see #570) in a way that we won't raise
a `NotInitializedError` anymore despite the net not being
initialized. This is now solved by basically porting the old
`check_is_fitted` behavior to skorch.

* Fix multi_indexing with sklearn 0.22

sklearn's (_)safe_indexing used to work with tuples as indices; as of 0.22 it no longer does, so skorch now handles tuple indices itself.

* Remove unnecessary import

* Make convert_sklearn_metric_function work again

In sklearn 0.22, make_scorer returns a special object that must be
treated differently in skorch.

  • Loading branch information
BenjaminBossan authored and ottonemo committed Dec 16, 2019
1 parent f7e74af commit ab8e6edc487f0bef8109bfeacaae8e4cb070dac8
Showing 6 changed files with 40 additions and 19 deletions.
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](

### Fixed

- Make skorch compatible with sklearn 0.22

## [0.7.0] - 2019-11-29

@@ -56,6 +56,7 @@ disable =

@@ -1,7 +1,6 @@
"""Contains learning rate scheduler callbacks"""

import sys
import warnings

# pylint: disable=unused-import
import numpy as np
def convert_sklearn_metric_function(scoring):
    """If ``scoring`` is a sklearn metric function, convert it to a
    sklearn scorer and return it. Otherwise, return ``scoring`` unchanged.

    Parameters
    ----------
    scoring : callable, str, or None
      A metric function from ``sklearn.metrics``, an already-made scorer
      object, a scorer name string, or None.

    Returns
    -------
    The scorer produced by ``make_scorer`` if ``scoring`` is a plain
    sklearn metric function; otherwise ``scoring`` unchanged.
    """
    if callable(scoring):
        module = getattr(scoring, '__module__', None)

        # those are scoring objects returned by make_scorer starting
        # from sklearn 0.22; they must not be wrapped a second time
        scorer_names = ('_PredictScorer', '_ProbaScorer', '_ThresholdScorer')
        if (
                hasattr(module, 'startswith') and
                module.startswith('sklearn.metrics.') and
                not module.startswith('sklearn.metrics.scorer') and
                not module.startswith('sklearn.metrics.tests.') and
                scoring.__class__.__name__ not in scorer_names
        ):
            return make_scorer(scoring)
    return scoring
@@ -89,7 +94,12 @@ def _get_name(self):
if isinstance(self.scoring_, partial):
return self.scoring_.func.__name__
if isinstance(self.scoring_, _BaseScorer):
return self.scoring_._score_func.__name__
if hasattr(self.scoring_._score_func, '__name__'):
# sklearn < 0.22
return self.scoring_._score_func.__name__
# sklearn >= 0.22
return self.scoring_._score_func._score_func.__name__
if isinstance(self.scoring_, dict):
raise ValueError("Dict not supported as scorer for multi-metric scoring."
" Register multiple scoring callbacks instead.")
@@ -463,13 +463,20 @@ def test_index_torch_tensor_with_numpy_bool_array(self, multi_indexing):
assert (result == X[:100]).all()

def test_index_with_float_array_raises(self, multi_indexing):
    # sklearn < 0.22 raises IndexError with msg0
    # sklearn >= 0.22 raises ValueError with msg1
    X = np.zeros(10)
    # NOTE(review): the extracted source read "np.arange(3, 0.5)" (an empty
    # array); reconstructed as an offset arange so the indices are genuinely
    # non-integer floats — confirm against upstream skorch.
    i = np.arange(3) + 0.5

    with pytest.raises((IndexError, ValueError)) as exc:
        multi_indexing(X, i)

    msg0 = "arrays used as indices must be of integer (or boolean) type"
    msg1 = ("No valid specification of the columns. Only a scalar, list or "
            "slice of all integers or all strings, or boolean mask is allowed")

    result = exc.value.args[0]
    assert result in (msg0, msg1)

def test_boolean_index_2d(self, multi_indexing):
X = np.arange(9).reshape(3, 3)
@@ -9,21 +9,25 @@
from enum import Enum
from functools import partial
from itertools import tee
from packaging import version
import pathlib
import warnings

import numpy as np
from scipy import sparse
from sklearn.utils import safe_indexing
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted as sklearn_check_is_fitted
import sklearn
import torch
from torch.nn.utils.rnn import PackedSequence
from torch.utils.data.dataset import Subset

from skorch.exceptions import DeviceWarning
from skorch.exceptions import NotInitializedError

# sklearn 0.22 renamed ``safe_indexing`` to ``_safe_indexing``; alias it so
# the rest of this module uses one name regardless of the installed version.
if version.parse(sklearn.__version__) >= version.parse('0.22.0'):
    from sklearn.utils import _safe_indexing as safe_indexing
else:
    # NOTE(review): the "else:" was lost in extraction; without it this
    # import would run unconditionally and break on newer sklearn.
    from sklearn.utils import safe_indexing

class Ansi(Enum):
BLUE = '\033[94m'
@@ -178,7 +182,8 @@ def _indexing_ndframe(data, i):

def _indexing_other(data, i):
if isinstance(i, (int, np.integer, slice)):
# sklearn's safe_indexing doesn't work with tuples since 0.22
if isinstance(i, (int, np.integer, slice, tuple)):
return data[i]
return safe_indexing(data, i)

def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
    """Check whether the net is initialized.

    Ports the pre-0.22 behavior of sklearn's ``check_is_fitted``:
    verify that ``estimator`` has the given attribute(s) and raise
    ``NotInitializedError`` (not sklearn's ``NotFittedError``) otherwise.

    Parameters
    ----------
    estimator : object
      The (skorch) estimator to check.
    attributes : str or list/tuple of str
      Attribute name(s) whose presence indicates initialization.
    msg : str or None (default=None)
      Error message; may contain a ``%(name)s`` placeholder that is
      filled with the estimator's class name. A default is used if None.
    all_or_any : callable (default=all)
      Whether all or any of the attributes must be present.

    Raises
    ------
    NotInitializedError
      If the required attribute(s) are missing.
    """
    if msg is None:
        msg = ("This %(name)s instance is not initialized yet. Call "
               "'initialize' or 'fit' with appropriate arguments "
               "before using this method.")

    if not isinstance(attributes, (list, tuple)):
        attributes = [attributes]

    if not all_or_any([hasattr(estimator, attr) for attr in attributes]):
        raise NotInitializedError(msg % {'name': type(estimator).__name__})

class TeeGenerator:

0 comments on commit ab8e6ed

Please sign in to comment.