fix scalar ops again on "[te] Ban uint8 tensors from fusion groups"
uint8s expose all kinds of corner cases in type promotion. As an example, consider:

```
>>> torch.tensor([1], dtype=torch.uint8).lt(-1)
tensor([True])
>>> torch.tensor([1], dtype=torch.uint8).lt(torch.tensor(-1))
tensor([True])
>>> torch.tensor([1], dtype=torch.uint8).lt(torch.tensor([-1]))
tensor([False])
```

The difference is in how promotions involving scalars (or 0-dim tensors, which are treated like scalars) are prioritized relative to tensor dtypes. Per @eellison, the order is roughly:

1. Tensor FP types
2. Scalar FP types
3. Tensor Int types
4. Scalar Int types

The logic for this is here: https://github.com/pytorch/pytorch/blob/c73e97033a3aef97a5685588ea014d54a5cc11cc/aten/src/ATen/native/TypeProperties.cpp#L93

AFAICT the effects are mainly visible for the unsigned byte type (the only unsigned type, besides bool), since the others degrade more or less gracefully. It's hard to reuse this logic as-is from TensorIterator/TypeProperties, and it's complicated enough that it's not worth re-implementing in TE unless there's evidence that it matters for real models.

Differential Revision: [D25489035](https://our.internmc.facebook.com/intern/diff/D25489035/)

[ghstack-poisoned]
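To make the priority order concrete, here is a minimal sketch (not PyTorch's actual implementation; `Participant` and `promotion_winner` are hypothetical names) that ranks promotion participants by the four categories above. It deliberately ignores within-category widening (e.g. uint8 vs int64 tensors both being "Tensor Int"), which the real logic in TypeProperties.cpp also handles:

```python
# Hedged sketch of the promotion *priority* described above, not
# PyTorch's real algorithm. Real promotion also widens within a
# category; this only picks the winning category.
from dataclasses import dataclass

@dataclass
class Participant:
    is_tensor: bool   # 0-dim tensors count as scalars here
    is_floating: bool
    dtype: str        # e.g. "uint8", "int64", "float32"

def promotion_winner(parts):
    # Priority: tensor FP (3) > scalar FP (2) > tensor int (1) > scalar int (0)
    def category(p):
        if p.is_floating:
            return 3 if p.is_tensor else 2
        return 1 if p.is_tensor else 0
    return max(parts, key=category)

# The uint8 example from the commit message: the uint8 *tensor*
# outranks the *scalar* -1, so the scalar is cast to uint8
# (wrapping -1 to 255), and 1 < 255 gives True.
u8_tensor = Participant(is_tensor=True, is_floating=False, dtype="uint8")
neg1_scalar = Participant(is_tensor=False, is_floating=False, dtype="int64")
print(promotion_winner([u8_tensor, neg1_scalar]).dtype)  # uint8
```

In the third REPL line above, `torch.tensor([-1])` is a 1-dim tensor, so both participants sit in the "Tensor Int" category and within-category widening picks int64 instead, which is exactly the case this sketch does not model.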