Skip to content

Commit

Permalink
fix scalar ops again on "[te] Ban uint8 tensors from fusion groups"
Browse files Browse the repository at this point in the history
uint8's expose all kind of corner cases in type promotion.  As an example, consider:
```
>>> torch.tensor([1], dtype=torch.uint8).lt(-1)
tensor([True])
>>> torch.tensor([1], dtype=torch.uint8).lt(torch.tensor(-1))
tensor([True])
>>> torch.tensor([1], dtype=torch.uint8).lt(torch.tensor([-1]))
tensor([False])
```
the difference is how promotions involving scalars (or 0-dim tensors, which are treated like scalars) are prioritized compared to tensor dtypes.
Per @eellison, the order is something like:
1. Tensor FP types
2. Scalar FP types
3. Tensor Int types
4. Scalar Int types

The logic for this is here: https://github.com/pytorch/pytorch/blob/c73e97033a3aef97a5685588ea014d54a5cc11cc/aten/src/ATen/native/TypeProperties.cpp#L93

AFAICT the effects are mainly visible for the unsigned byte type (the only unsigned type, besides bool) since the others degrade more or less gracefully.

It's hard to re-use this logic as is in TensorIterator/TypeProperties, and it's complicated enough that it's not worth re-implementing in TE unless there's evidence that it matters for real models.

Differential Revision: [D25489035](https://our.internmc.facebook.com/intern/diff/D25489035/)

[ghstack-poisoned]
  • Loading branch information
bertmaher committed Dec 14, 2020
2 parents 52899c3 + d807908 commit f5bdf85
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions torch/csrc/jit/passes/tensorexpr_fuser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,11 @@ class TensorExprFuser {
if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) {
return false;
}
} else if (node->isMemberOf(float_only_operator_set)) {
// Check scalar operands of float-only ops.
if (!v->type()->cast<FloatType>()) {
return false;
}
}
}

Expand Down

0 comments on commit f5bdf85

Please sign in to comment.