Skip to content
Permalink
Browse files

Add classes_ attribute NeuralNetClassifier (#546)

* Add a classes_ attribute on NeuralNetClassifier

This is inferred from y by default but can be overridden by passing
classes explicitly during initialization.

* Add more rigorous tests to classifiers and regressor

This is to catch errors caused by specific attributes that are set on
those classes but not on NeuralNet itself. Such attributes can potentially
mess with cloning, which is why cloning should be tested explicitly.
  • Loading branch information...
BenjaminBossan authored and ottonemo committed Nov 6, 2019
1 parent 4dcc8fe commit 8dd58ef996a0529dd2f407d59609c966a8cf5469
Showing with 86 additions and 3 deletions.
  1. +1 −0 CHANGES.md
  2. +29 −3 skorch/classifier.py
  3. +52 −0 skorch/tests/test_classifier.py
  4. +4 −0 skorch/tests/test_regressor.py
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Make NeuralNetBinaryClassifier work with sklearn.calibration.CalibratedClassifierCV
- Improve NeuralNetBinaryClassifier compatibility with certain sklearn metrics (#515)
- NeuralNetBinaryClassifier automatically squeezes module output if necessary (#515)
- NeuralNetClassifier now has a classes_ attribute after fit is called, which is inferred from y by default (#465, #486)

### Changed

@@ -28,18 +28,31 @@
"""

neural_net_clf_criterion_text = """
neural_net_clf_additional_text = """
criterion : torch criterion (class, default=torch.nn.NLLLoss)
Negative log likelihood loss. Note that the module should return
probabilities, the log is applied during ``get_loss``."""
probabilities, the log is applied during ``get_loss``.
classes : None or list (default=None)
If None, the ``classes_`` attribute will be inferred from the
``y`` data passed to ``fit``. If a non-empty list is passed,
that list will be returned as ``classes_``. If the initial
skorch behavior should be restored, i.e. raising an
``AttributeError``, pass an empty list."""

neural_net_clf_additional_attribute = """classes_ : array, shape (n_classes, )
A list of class labels known to the classifier.
"""


def get_neural_net_clf_doc(doc):
    """Adapt a base docstring for use on the classifier.

    The beginning of ``doc`` is swapped for ``neural_net_clf_doc_start``,
    the span matched by the ``criterion`` pattern is replaced by
    ``neural_net_clf_additional_text``, and
    ``neural_net_clf_additional_attribute`` is appended at the end.

    Parameters
    ----------
    doc : str
      The docstring to adapt.

    Returns
    -------
    doc : str
      The adapted docstring.

    """
    doc = neural_net_clf_doc_start + " " + doc.split("\n ", 4)[-1]
    pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+){1,99}')
    start, end = pattern.search(doc).span()
    # Splice exactly once: ``start`` and ``end`` are offsets into the
    # current ``doc`` and are no longer valid once the text between them
    # has been replaced.
    doc = doc[:start] + neural_net_clf_additional_text + doc[end:]
    doc = doc + neural_net_clf_additional_attribute
    return doc


@@ -53,6 +66,7 @@ def __init__(
*args,
criterion=torch.nn.NLLLoss,
train_split=CVSplit(5, stratified=True),
classes=None,
**kwargs
):
super(NeuralNetClassifier, self).__init__(
@@ -62,6 +76,7 @@ def __init__(
train_split=train_split,
**kwargs
)
self.classes = classes

@property
def _default_callbacks(self):
@@ -86,6 +101,15 @@ def _default_callbacks(self):
('print_log', PrintLog()),
]

@property
def classes_(self):
    """Class labels exposed by the classifier.

    Returns ``self.classes`` when it is a non-empty sequence and
    ``self.classes_inferred_`` when ``self.classes`` is None.

    Raises
    ------
    AttributeError
      If ``self.classes`` is not None but has length 0.

    """
    explicit = self.classes
    if explicit is None:
        return self.classes_inferred_

    # len() rather than truthiness, so that array-like values work too
    if len(explicit) == 0:
        raise AttributeError("{} has no attribute 'classes_'".format(
            self.__class__.__name__))
    return explicit

# pylint: disable=signature-differs
def check_data(self, X, y):
if (
@@ -99,6 +123,8 @@ def check_data(self, X, y):
"``iterator_train`` and ``iterator_valid`` parameters "
"respectively.")
raise ValueError(msg)
if y is not None:
self.classes_inferred_ = np.unique(y)

# pylint: disable=arguments-differ
def get_loss(self, y_pred, y_true, *args, **kwargs):
@@ -10,6 +10,7 @@
import numpy as np
import pytest
import torch
from sklearn.base import clone
from torch import nn

from skorch.tests.conftest import INFERENCE_METHODS
@@ -57,6 +58,9 @@ def net_fit(self, net, data):
X, y = data
return net.fit(X, y)

def test_clone(self, net_fit):
    """Passing the ``net_fit`` fixture to ``clone`` does not raise."""
    clone(net_fit)

def test_predict_and_predict_proba(self, net_fit, data):
X = data[0]

@@ -107,6 +111,51 @@ def test_high_learning_rate(self, net_cls, module_cls, data):
net.fit(*data)
assert np.any(~np.isnan(net.history[:, 'train_loss']))

def test_binary_classes_set_by_default(self, net_cls, module_cls, data):
    """After fitting on the ``data`` fixture, ``classes_`` is ``[0, 1]``."""
    net = net_cls(module_cls).fit(*data)
    # array_equal also compares shapes, so a wrongly shaped ``classes_``
    # cannot slip through via broadcasting
    assert np.array_equal(net.classes_, [0, 1])

def test_non_binary_classes_set_by_default(self, net_cls, module_cls, data):
    """With targets cycling through 0..9, ``classes_`` is exactly 0..9."""
    X = data[0]
    y = np.arange(len(X)) % 10
    net = net_cls(module_cls, max_epochs=0).fit(X, y)
    # array_equal also compares shapes, so a wrongly shaped ``classes_``
    # cannot slip through via broadcasting
    assert np.array_equal(net.classes_, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

def test_classes_data_torch_tensor(self, net_cls, module_cls, data):
    """``classes_`` is also 0..9 when X and y are passed as torch tensors."""
    X = torch.as_tensor(data[0])
    y = torch.as_tensor(np.arange(len(X)) % 10)

    net = net_cls(module_cls, max_epochs=0).fit(X, y)
    # array_equal also compares shapes, so a wrongly shaped ``classes_``
    # cannot slip through via broadcasting
    assert np.array_equal(net.classes_, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

def test_classes_with_gaps(self, net_cls, module_cls, data):
    """Labels that never occur in ``y`` do not show up in ``classes_``."""
    X = data[0]
    y = np.arange(len(X)) % 10
    y[(y == 0) | (y == 5)] = 4  # remove classes 0 and 5
    net = net_cls(module_cls, max_epochs=0).fit(X, y)
    # array_equal also compares shapes, so a wrongly shaped ``classes_``
    # cannot slip through via broadcasting
    assert np.array_equal(net.classes_, [1, 2, 3, 4, 6, 7, 8, 9])

def test_pass_classes_explicitly_overrides(self, net_cls, module_cls, data):
    """An explicitly passed ``classes`` list is what ``classes_`` returns."""
    expected = ['foo', 'bar']
    net = net_cls(module_cls, max_epochs=0, classes=['foo', 'bar'])
    fitted = net.fit(*data)
    assert fitted.classes_ == expected

@pytest.mark.parametrize('classes', [[], np.array([])])
def test_pass_empty_classes_raises(self, net_cls, module_cls, data, classes):
    """Empty ``classes`` makes accessing ``classes_`` raise AttributeError.

    The net is fitted twice before the attribute is accessed.

    """
    net = net_cls(module_cls, max_epochs=0, classes=classes)
    net = net.fit(*data).fit(*data)

    with pytest.raises(AttributeError) as exc_info:
        _ = net.classes_

    raised_msg = exc_info.value.args[0]
    assert raised_msg == "NeuralNetClassifier has no attribute 'classes_'"

def test_with_calibrated_classifier_cv(self, net_fit, data):
    """Fitting ``CalibratedClassifierCV`` around the fitted net does not raise."""
    from sklearn.calibration import CalibratedClassifierCV

    calibrated = CalibratedClassifierCV(net_fit, cv=2)
    calibrated.fit(*data)


class TestNeuralNetBinaryClassifier:
@pytest.fixture(scope='module')
@@ -152,6 +201,9 @@ def test_fit(self, net_fit):
# fitting does not raise anything
pass

def test_clone(self, net_fit):
    """Passing the ``net_fit`` fixture to ``clone`` does not raise."""
    clone(net_fit)

@pytest.mark.parametrize('method', INFERENCE_METHODS)
def test_not_fitted_raises(self, net_cls, module_cls, data, method):
from skorch.exceptions import NotInitializedError
@@ -7,6 +7,7 @@
from flaky import flaky
import numpy as np
import pytest
from sklearn.base import clone
import torch

from skorch.tests.conftest import INFERENCE_METHODS
@@ -58,6 +59,9 @@ def net_fit(self, net, data):
X, y = data
return net.fit(X, y)

def test_clone(self, net_fit):
    """Passing the ``net_fit`` fixture to ``clone`` does not raise."""
    clone(net_fit)

def test_fit(self, net_fit):
    """Requesting the ``net_fit`` fixture is the whole test: fitting must not raise."""

0 comments on commit 8dd58ef

Please sign in to comment.
You can’t perform that action at this time.