Either support torch.mean on BoolTensors or fix the error message #64897
We would accept a PR updating mean to participate in integer->floating point type promotion, as well as a more short-term fix for the error message. Given PyTorch's current architecture, taking the mean of a long tensor would probably require making a float copy because of build size concerns. On CUDA we might be able to fuse the copy into the kernel and not actually materialize the float copy, however.
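For reference, a minimal sketch of what this means for users today, using the explicit float copy the comment describes (the promoted behavior of calling `.mean()` directly on an integer tensor is hypothetical):

```python
import torch

t = torch.arange(6, dtype=torch.int64)

# Short-term workaround: materialize the float copy explicitly.
mean_via_copy = t.to(torch.float32).mean()
print(mean_via_copy)  # tensor(2.5000)

# With integer -> floating-point type promotion, t.mean() would
# (hypothetically) return the same floating-point result directly,
# without a user-visible cast.
```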
As discussion for a more long-term solution, I wonder if the JIT could generate the code for these input × output dtype combinations (including on CPU)?
Sure, if you have a CPU JIT that can fuse pointwise operations into the prologue of reductions, which is a pretty common fusion scenario.
It may prevent a vectorized read, though?
It's hard to speculate on the details of hypothetical systems |
Similar error for uint8 dtype:
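A minimal reproduction sketch (the exact error wording may differ across PyTorch versions):

```python
import torch

u = torch.ones(4, dtype=torch.uint8)
try:
    u.mean()
except RuntimeError as e:
    # Raises something along the lines of "Can only calculate the
    # mean of floating types" -- wording varies by version.
    print(e)
```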
I guess this is because torch.mean first upcasts inputs to torch.int64 and then fails. But this is quite confusing, and the message uses outdated dtype names (I propose error messages should use names like torch.int64). For int64, mean is also not supported:
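A sketch of the int64 case (again, error text is approximate):

```python
import torch

t = torch.arange(6, dtype=torch.int64)
try:
    t.mean()
except RuntimeError as e:
    print(e)  # mean is rejected for int64 inputs as well
```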
It would also make sense if torch.mean were supported for torch.int64 inputs without an explicit cast (hopefully without the copy allocation of a float cast, but maybe for now this is not possible).
It is also quite wasteful to copy-allocate torch.bool to torch.int64 to perform these basic operations :( Related: #55366
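For illustration, the bool case with the explicit cast that is currently required (a sketch of today's behavior, not a proposed API):

```python
import torch

b = torch.tensor([True, False, True, True])

# Integer reductions on bool tensors already accumulate into int64 today...
print(b.sum().dtype)     # torch.int64

# ...and mean still needs an explicit floating-point cast.
print(b.float().mean())  # tensor(0.7500)
```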
cc @heitorschueroff @brianjo @mruberry