Skip to content

Commit

Permalink
Better error message when data conversion fails.
Browse files Browse the repository at this point in the history
Specifically, when conversion to torch tensor or to numpy array fails,
users would often only get "AttributeError: bla has no attribute
.cuda". Now a more specific TypeError is shown.
  • Loading branch information
benjamin-work authored and ottonemo committed Dec 4, 2017
1 parent 2bbb638 commit 9eb1a6c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion skorch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def check_cv(self, y):
# doesn't work, still try.
try:
y_arr = to_numpy(y)
except AttributeError:
except (AttributeError, TypeError):
y_arr = y

if self._is_float(self.cv):
Expand Down
11 changes: 11 additions & 0 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ class Ansi(Enum):
ENDC = '\033[0m'


def is_torch_data_type(x):
# pylint: disable=protected-access
return isinstance(x, (torch.tensor._TensorBase, Variable))


def to_var(X, use_cuda):
"""Generic function to convert a input data to pytorch Variables.
Expand Down Expand Up @@ -73,6 +78,9 @@ def to_tensor(X, use_cuda):
elif np.isscalar(X):
X = torch.from_numpy(np.array([X]))

if not is_torch_data_type(X):
raise TypeError("Cannot convert this data type to a torch tensor.")

if use_cuda:
X = X.cuda()
return X
Expand All @@ -91,6 +99,9 @@ def to_numpy(X):
if is_pandas_ndframe(X):
return X.values

if not is_torch_data_type(X):
raise TypeError("Cannot convert this data type to a numpy array.")

if X.is_cuda:
X = X.cpu()

Expand Down

0 comments on commit 9eb1a6c

Please sign in to comment.