Skip to content

Commit d885d4f

Browse files
committed
enable IntLikeType
1 parent 1a6d50d commit d885d4f

File tree

3 files changed

+33
-39
lines changed

3 files changed

+33
-39
lines changed

torch/_functorch/_aot_autograd/subclass_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch.utils._pytree as pytree
1515
from torch import SymInt, Tensor
1616
from torch._subclasses.fake_tensor import get_plain_tensors
17+
from torch.types import IntLikeType
1718
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
1819

1920
from .schemas import (
@@ -166,9 +167,9 @@ def create_subclass_meta(
166167
return infos
167168

168169

169-
def filter_symints(lst: Iterable[Union[int, SymInt]]):
170+
def filter_symints(lst: Iterable[IntLikeType]):
170171
# Capture all SymInts from the iterable.
171-
def symint_check(s: Union[int, SymInt]) -> bool:
172+
def symint_check(s: IntLikeType) -> bool:
172173
return isinstance(s, SymInt) and not s.node.is_nested_int()
173174

174175
return [s for s in lst if symint_check(s)]

torch/_higher_order_ops/triton_kernel_wrap.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
track_tensor_tree,
2626
)
2727
from torch.fx.experimental.symbolic_shapes import guard_scalar
28+
from torch.types import IntLikeType
2829

2930

3031
if TYPE_CHECKING:
@@ -71,9 +72,9 @@ class JITFunction: # type: ignore[no-redef]
7172
TMADescriptorMetadata = dict[
7273
str, # kernel parameter name
7374
tuple[
74-
list[Union[int, SymInt]], # dims
75-
list[Union[int, SymInt]], # block_dims
76-
Union[int, SymInt], # element_size
75+
list[IntLikeType], # dims
76+
list[IntLikeType], # block_dims
77+
IntLikeType, # element_size
7778
],
7879
]
7980

torch/fx/experimental/symbolic_shapes.py

Lines changed: 26 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102

103103
from torch import Tensor
104104
from torch._subclasses.fake_tensor import FakeTensor
105-
from torch.types import BoolLikeType
105+
from torch.types import BoolLikeType, IntLikeType
106106

107107

108108
InputList = list
@@ -240,8 +240,8 @@ def __hash__(self) -> int:
240240

241241

242242
def _nested_int_aware_sort(
243-
tup: tuple[Union[SymInt, int], int]
244-
) -> tuple[int, Union[SymInt, int], int]:
243+
tup: tuple[IntLikeType, int]
244+
) -> tuple[int, IntLikeType, int]:
245245
return (
246246
# Order nested ints by their coefficients.
247247
# 1 here to order nested ints after non-nested-ints.
@@ -369,7 +369,7 @@ def has_hint(a: Scalar) -> bool:
369369
return True
370370

371371

372-
def is_concrete_int(a: Union[int, SymInt]) -> bool:
372+
def is_concrete_int(a: IntLikeType) -> bool:
373373
"""
374374
Utility to check if underlying object
375375
in SymInt is concrete value. Also returns
@@ -824,7 +824,7 @@ def div_by_factor(x: sympy.Expr, factor: int) -> sympy.Expr:
824824
return expr
825825

826826

827-
def is_nested_int(s: Union[int, SymInt]) -> TypeGuard[SymInt]:
827+
def is_nested_int(s: IntLikeType) -> TypeGuard[SymInt]:
828828
return isinstance(s, torch.SymInt) and s.node.is_nested_int()
829829

830830

@@ -938,7 +938,7 @@ class ConvertIntKey:
938938
def __str__(self) -> str:
939939
return ".cast_symbool_to_symint_guardless()"
940940

941-
def get(self, b: bool) -> Union[int, SymInt]:
941+
def get(self, b: bool) -> IntLikeType:
942942
"""Get the int value from bool"""
943943
return cast_symbool_to_symint_guardless(b)
944944

@@ -969,7 +969,7 @@ def get(self, o: Any) -> Any:
969969

970970
@dataclass(frozen=True)
971971
class DivideByKey:
972-
divisor: Union[int, SymInt]
972+
divisor: IntLikeType
973973

974974
def __str__(self) -> str:
975975
return f".__floordiv__({self.divisor})"
@@ -1097,7 +1097,7 @@ def _symint_wrap(s: sympy.Symbol) -> SymInt:
10971097
)
10981098

10991099
unbacked = lhs if lhs in pending else rhs
1100-
divisor: Union[int, SymInt] = (
1100+
divisor: IntLikeType = (
11011101
int(coeff)
11021102
if shape_env and isinstance(coeff, sympy.Integer)
11031103
else _symint_wrap(coeff)
@@ -1439,7 +1439,7 @@ def _advise_is_size(a: SymInt) -> None:
14391439
_constrain_range_for_size(a)
14401440

14411441

1442-
def _advise_is_bounded(a: SymInt, upper_bound: Union[int, SymInt]) -> None:
1442+
def _advise_is_bounded(a: SymInt, upper_bound: IntLikeType) -> None:
14431443
if (
14441444
isinstance(a, SymInt)
14451445
and isinstance(a.node, SymNode)
@@ -1585,7 +1585,7 @@ def guard_bool(a: Union[SymBool, bool]) -> bool:
15851585
return a
15861586

15871587

1588-
def guard_int(a: Union[SymInt, int]) -> int:
1588+
def guard_int(a: IntLikeType) -> int:
15891589
if isinstance(a, SymInt):
15901590
return a.node.guard_int("", 0) # NB: uses Python backtrace
15911591
assert type(a) is int, a
@@ -3987,7 +3987,7 @@ def _update_version_counter(self) -> None:
39873987

39883988
def _produce_dyn_sizes(
39893989
self,
3990-
ex_size: Sequence[Union[int, SymInt]],
3990+
ex_size: Sequence[IntLikeType],
39913991
source: Source,
39923992
symbolic_context: SymbolicContext,
39933993
) -> list[sympy.Expr]:
@@ -3997,7 +3997,7 @@ def _produce_dyn_sizes(
39973997

39983998
def _produce_dyn_sizes_from_int_tuple(
39993999
self,
4000-
tensor_size: Sequence[Union[int, SymInt]],
4000+
tensor_size: Sequence[IntLikeType],
40014001
source: Source,
40024002
symbolic_context: SymbolicContext,
40034003
) -> list[sympy.Expr]:
@@ -4034,11 +4034,7 @@ def create_symbolic_sizes_strides_storage_offset(
40344034
source: Source,
40354035
*,
40364036
symbolic_context: Optional[SymbolicContext] = None,
4037-
) -> tuple[
4038-
tuple[Union[int, SymInt], ...],
4039-
tuple[Union[int, SymInt], ...],
4040-
Union[int, SymInt],
4041-
]:
4037+
) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]:
40424038
"""
40434039
Returns a list of symbolic sizes and strides for the given tensor.
40444040
We try our best to express stride in terms of the sizes, so as to not
@@ -4099,8 +4095,8 @@ def create_symbolic_sizes_strides_storage_offset(
40994095
# If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
41004096
# we may have an unnessary shape speciliazation for y.
41014097
def _maybe_specialize_sym_int_with_hint(
4102-
self, maybe_sym: Union[int, SymInt]
4103-
) -> Union[int, SymInt]:
4098+
self, maybe_sym: IntLikeType
4099+
) -> IntLikeType:
41044100
assert isinstance(maybe_sym, (int, torch.SymInt))
41054101
if is_symbolic(maybe_sym):
41064102
assert (
@@ -4114,18 +4110,14 @@ def _create_symbolic_sizes_strides_storage_offset(
41144110
self,
41154111
# NB: SymInt is allowed here due to nested int, normally you don't
41164112
# actually pass true symbolic sizes to this function
4117-
ex_size: Sequence[Union[int, SymInt]],
4118-
ex_stride: Sequence[Union[int, SymInt]],
4119-
ex_storage_offset: Union[int, SymInt],
4113+
ex_size: Sequence[IntLikeType],
4114+
ex_stride: Sequence[IntLikeType],
4115+
ex_storage_offset: IntLikeType,
41204116
is_dim_dynamic: Sequence[bool],
41214117
source: Source,
41224118
*,
41234119
symbolic_context: Optional[SymbolicContext] = None,
4124-
) -> tuple[
4125-
tuple[Union[int, SymInt], ...],
4126-
tuple[Union[int, SymInt], ...],
4127-
Union[int, SymInt],
4128-
]:
4120+
) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]:
41294121
dim = len(ex_size)
41304122

41314123
# Reimplement the legacy behavior
@@ -4232,8 +4224,8 @@ def _compute_symbolic_stride(
42324224
self,
42334225
source: Source,
42344226
size: Sequence[sympy.Expr],
4235-
ex_size: Sequence[Union[int, SymInt]],
4236-
ex_stride: Sequence[Union[int, SymInt]],
4227+
ex_size: Sequence[IntLikeType],
4228+
ex_stride: Sequence[IntLikeType],
42374229
dynamic_strides: Sequence[DimDynamic],
42384230
constraint_strides: Sequence[
42394231
Optional[Union[StrictMinMaxConstraint, RelaxedUnspecConstraint]]
@@ -4244,7 +4236,7 @@ def _compute_symbolic_stride(
42444236
from torch._dynamo.source import TensorProperty, TensorPropertySource
42454237

42464238
stride: list[Optional[sympy.Expr]] = [None] * len(size)
4247-
candidates: dict[Union[int, SymInt], sympy.Expr] = {}
4239+
candidates: dict[IntLikeType, sympy.Expr] = {}
42484240

42494241
# iterate over unbound strides in val ascending order with
42504242
# index descending as a tie breaker since for cases like
@@ -4293,7 +4285,7 @@ def create_symintnode(
42934285
*,
42944286
hint: Optional[int],
42954287
source: Optional[Source] = None,
4296-
) -> Union[int, SymInt]:
4288+
) -> IntLikeType:
42974289
"""Create a SymInt value from a symbolic expression
42984290
42994291
If you know what the current hint value of the SymInt to be created
@@ -4314,7 +4306,7 @@ def create_symintnode(
43144306
else:
43154307
fx_node = None
43164308

4317-
out: Union[int, SymInt]
4309+
out: IntLikeType
43184310
if isinstance(sym, sympy.Integer):
43194311
if hint is not None:
43204312
assert int(sym) == hint
@@ -4370,7 +4362,7 @@ def create_symfloatnode(
43704362
@record_shapeenv_event()
43714363
def create_unspecified_symint_and_symbol(
43724364
self, value: int, source: Source, dynamic_dim: DimDynamic
4373-
) -> Union[int, SymInt]:
4365+
) -> IntLikeType:
43744366
"""Create a SymInt wrapping a new unspecified symbol"""
43754367
return self.create_symintnode(
43764368
self.create_unspecified_symbol(
@@ -5084,7 +5076,7 @@ def get_expression(tensor_dim_src: Source) -> sympy.Expr:
50845076
# tensors that never actually become graph arguments (they are
50855077
# pruned). In this case, only Dynamo knows about these arguments.
50865078
def track_symint(
5087-
source: Source, val: Union[SymInt, int], constraint: DimConstraint = None
5079+
source: Source, val: IntLikeType, constraint: DimConstraint = None
50885080
) -> None:
50895081
log.debug("track_symint %s %s %s", LazyString(source.name), val, constraint)
50905082
assert not isinstance(val, SymInt) or is_symbolic(val)

0 commit comments

Comments
 (0)