Closed
Description
🐛 Bug description
When calling Accuracy.update() on a multilabel metric with two inputs whose second dimension is 1 (in my case torch.Size([256, 1])), the error message that is raised is misleading.
To reproduce
from ignite.metrics import Accuracy
import torch
acc = Accuracy(is_multilabel=True)
acc.update((torch.zeros((256,1)), torch.zeros((256,1))))
ValueError: y and y_pred must have same shape of (batch_size, num_categories, ...).
In this case y and y_pred do have the same shape. The actual problem is that this is not an accepted multilabel input, because of the y.shape[1] != 1 condition in the following code block from _check_shape in _BaseClassification. The error message should indicate this (or the if statement should be changed).
What is the rationale for not allowing a y.shape[1] of 1?
if self._is_multilabel and not (y.shape == y_pred.shape and y.ndimension() > 1 and y.shape[1] != 1):
    raise ValueError("y and y_pred must have same shape of (batch_size, num_categories, ...).")
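One way the message could be made more specific is to check each part of the condition quoted above separately. The following is only a minimal sketch, not Ignite's implementation: check_multilabel_shape is a hypothetical standalone helper, the message wording is an assumption, and it takes plain shape tuples so that it runs without torch (a tensor's .shape can also be passed, since torch.Size is a tuple subclass).

```python
def check_multilabel_shape(y_shape, y_pred_shape):
    """Hypothetical helper: validate multilabel shapes with one message per failure."""
    y_shape, y_pred_shape = tuple(y_shape), tuple(y_pred_shape)

    # 1) Both inputs must have identical shapes.
    if y_shape != y_pred_shape:
        raise ValueError(
            f"y and y_pred must have the same shape, got {y_shape} and {y_pred_shape}."
        )

    # 2) A multilabel input needs a category dimension after the batch dimension.
    if len(y_shape) < 2:
        raise ValueError(
            f"y and y_pred must have at least 2 dimensions "
            f"(batch_size, num_categories, ...), got shape {y_shape}."
        )

    # 3) The case from this report: shapes match, but num_categories is 1.
    if y_shape[1] == 1:
        raise ValueError(
            f"num_categories (second dimension) must be greater than 1 "
            f"for multilabel input, got shape {y_shape}."
        )


# Usage with the shapes from this report: the error now names the real cause.
try:
    check_multilabel_shape((256, 1), (256, 1))
except ValueError as err:
    print(err)
```

With a split like this, the (256, 1) case reports the second dimension as the cause instead of suggesting that the shapes differ.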
Environment
- PyTorch Version (e.g., 1.4):
- Ignite Version (e.g., 0.3.0): 0.3.0
- OS (e.g., Linux): Linux
- How you installed Ignite (conda, pip, source): conda
- Python version:
- Any other relevant information: