-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[te] Ban uint8 tensors from fusion groups
Pull Request resolved: #49247 uint8's expose all kind of corner cases in type promotion. As an example, consider: ``` >>> torch.tensor([1], dtype=torch.uint8).lt(-1) tensor([True]) >>> torch.tensor([1], dtype=torch.uint8).lt(torch.tensor(-1)) tensor([True]) >>> torch.tensor([1], dtype=torch.uint8).lt(torch.tensor([-1])) tensor([False]) ``` the difference is how promotions involving scalars (or 0-dim tensors, which are treated like scalars) are prioritized compared to tensor dtypes. Per @eellison, the order is something like: 1. Tensor FP types 2. Scalar FP types 3. Tensor Int types 4. Scalar Int types The logic for this is here: https://github.com/pytorch/pytorch/blob/c73e97033a3aef97a5685588ea014d54a5cc11cc/aten/src/ATen/native/TypeProperties.cpp#L93 AFAICT the effects are mainly visible for the unsigned byte type (the only unsigned type, besides bool) since the others degrade more or less gracefully. It's hard to re-use this logic as is in TensorIterator/TypeProperties, and it's complicated enough that it's not worth re-implementing in TE unless there's evidence that it matters for real models. ghstack-source-id: 118430556 Differential Revision: [D25489035](https://our.internmc.facebook.com/intern/diff/D25489035/)
- Loading branch information
Showing
2 changed files
with
46 additions
and
118 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters