🐛 Describe the bug
What happens
`torch.rand((3,), dtype=torch.int32)` raises the rather enigmatic `NotImplementedError: "check_uniform_bounds" not implemented for 'Long'`.
A different but similarly uninformative exception is raised for `torch.randn`: `NotImplementedError: "normal_kernel_cpu" not implemented for 'Long'`. Later work found other examples as well, so the problem is fairly wide.
What should actually happen
Those code snippets should raise the same kind of exception with better messages: `NotImplementedError: torch.rand only returns floating point Tensors. Did you mean torch.randint?` and `NotImplementedError: torch.randn only returns floating point Tensors`.
A `ValueError` would technically be more accurate because, for example, `torch.rand` *is* implemented; it's just that the `dtype` parameter has the right type but the wrong value (even though that value is a type 😁). But other code downstream might rely on `NotImplementedError`, though this seems undocumented. Having a suboptimal exception type is not a grave sin compared with possible breakage. (If people are relying on the exact error text, well, they should expect breakage.)
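As a quick sanity check (plain Python, not PyTorch internals), overwriting `e.args` rewrites the message while keeping the exception type, so a downstream `except NotImplementedError` handler still matches:

```python
caught = None
try:
    try:
        # Stand-in for the low-level kernel error quoted above.
        raise NotImplementedError('"check_uniform_bounds" not implemented for \'Long\'')
    except NotImplementedError as e:
        # Rewrite the message in place; the exception type is unchanged.
        e.args = ("torch.rand only returns floating point Tensors. "
                  "Did you mean torch.randint?",)
        raise
except NotImplementedError as e:  # a downstream handler still catches it
    caught = e

print(type(caught).__name__, "->", caught)
```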
Scope
I examined just the `torch.*` functions whose names contain the string `rand`, and I identified only `rand`, `randn`, `rand_like`, and `randn_like` as having this problem. The `torch.randint` and `torch.randint_like` functions were well behaved, with seemingly correct results for many dtypes.
Once I get to the root of it, I might identify other ops that would benefit from this change.
The hope is that there is one fairly logical spot where we can catch the `NotImplementedError` `e` for all these ops, ascertain that the wrong `dtype` really is the cause, and if so overwrite `e.args` with our better message, then re-raise. There will be a little extra work to find a possible "Did you mean?" suggestion. Most of the work is finding that one spot; I would expect this to be small but non-trivial.
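A rough sketch of what that one spot might look like, in plain Python. Every name here (`rewrite_dtype_error`, `SUGGESTIONS`, `fake_rand_kernel`) is hypothetical; the real fix would live wherever PyTorch dispatches these ops:

```python
# Hypothetical "Did you mean?" table; only ops with an integer-friendly
# alternative get a suggestion.
SUGGESTIONS = {"rand": "torch.randint", "randn": None}

def rewrite_dtype_error(op_name, dtype_is_floating, kernel):
    """Call `kernel`; if it fails and the dtype is non-floating,
    re-raise the same NotImplementedError with a friendlier message."""
    try:
        return kernel()
    except NotImplementedError as e:
        if dtype_is_floating:
            raise  # some other cause; leave the error alone
        msg = f"torch.{op_name} only returns floating point Tensors"
        suggestion = SUGGESTIONS.get(op_name)
        if suggestion:
            msg += f". Did you mean {suggestion}?"
        e.args = (msg,)
        raise

def fake_rand_kernel():
    # Stand-in for the int32 call that fails today.
    raise NotImplementedError('"check_uniform_bounds" not implemented for \'Long\'')
```

Used as `rewrite_dtype_error("rand", False, fake_rand_kernel)`, this re-raises with the proposed `torch.rand only returns floating point Tensors. Did you mean torch.randint?` message while preserving the exception type.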
cc @malfet