Skip to content

Commit

Permalink
Allow NeuralNetClassifier to fit tensor labels #802 (#803)
Browse files Browse the repository at this point in the history
Apply to_numpy on y in NeuralNetClassifier.check_data
before applying np.unique in case that y is a torch tensor.

Co-authored-by: Autumnii <nuaazhouyi@nuaa.edu.cn>
  • Loading branch information
TheAutumnOfRice and Autumnii committed Oct 1, 2021
1 parent 32577c5 commit fe4fcc4
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions skorch/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
from skorch.callbacks import EpochScoring
from skorch.callbacks import PassthroughScoring
from skorch.dataset import CVSplit
from skorch.utils import get_dim
from skorch.utils import get_dim, to_numpy
from skorch.utils import is_dataset


neural_net_clf_doc_start = """NeuralNet for classification tasks
Use this specifically if you have a standard classification task,
Expand Down Expand Up @@ -117,7 +116,7 @@ def check_data(self, X, y):
raise ValueError(msg)
if y is not None:
# pylint: disable=attribute-defined-outside-init
self.classes_inferred_ = np.unique(y)
self.classes_inferred_ = np.unique(to_numpy(y))

# pylint: disable=arguments-differ
def get_loss(self, y_pred, y_true, *args, **kwargs):
Expand Down

0 comments on commit fe4fcc4

Please sign in to comment.