102102
103103 from torch import Tensor
104104 from torch ._subclasses .fake_tensor import FakeTensor
105- from torch .types import BoolLikeType
105+ from torch .types import BoolLikeType , IntLikeType
106106
107107
108108InputList = list
@@ -240,8 +240,8 @@ def __hash__(self) -> int:
240240
241241
242242def _nested_int_aware_sort (
243- tup : tuple [Union [ SymInt , int ] , int ]
244- ) -> tuple [int , Union [ SymInt , int ] , int ]:
243+ tup : tuple [IntLikeType , int ]
244+ ) -> tuple [int , IntLikeType , int ]:
245245 return (
246246 # Order nested ints by their coefficients.
247247 # 1 here to order nested ints after non-nested-ints.
@@ -369,7 +369,7 @@ def has_hint(a: Scalar) -> bool:
369369 return True
370370
371371
372- def is_concrete_int (a : Union [ int , SymInt ] ) -> bool :
372+ def is_concrete_int (a : IntLikeType ) -> bool :
373373 """
374374 Utility to check if underlying object
375375 in SymInt is concrete value. Also returns
@@ -824,7 +824,7 @@ def div_by_factor(x: sympy.Expr, factor: int) -> sympy.Expr:
824824 return expr
825825
826826
827- def is_nested_int (s : Union [ int , SymInt ] ) -> TypeGuard [SymInt ]:
827+ def is_nested_int (s : IntLikeType ) -> TypeGuard [SymInt ]:
828828 return isinstance (s , torch .SymInt ) and s .node .is_nested_int ()
829829
830830
@@ -938,7 +938,7 @@ class ConvertIntKey:
938938 def __str__ (self ) -> str :
939939 return ".cast_symbool_to_symint_guardless()"
940940
941- def get (self , b : bool ) -> Union [ int , SymInt ] :
941+ def get (self , b : bool ) -> IntLikeType :
942942 """Get the int value from bool"""
943943 return cast_symbool_to_symint_guardless (b )
944944
@@ -969,7 +969,7 @@ def get(self, o: Any) -> Any:
969969
970970@dataclass (frozen = True )
971971class DivideByKey :
972- divisor : Union [ int , SymInt ]
972+ divisor : IntLikeType
973973
974974 def __str__ (self ) -> str :
975975 return f".__floordiv__({ self .divisor } )"
@@ -1097,7 +1097,7 @@ def _symint_wrap(s: sympy.Symbol) -> SymInt:
10971097 )
10981098
10991099 unbacked = lhs if lhs in pending else rhs
1100- divisor : Union [ int , SymInt ] = (
1100+ divisor : IntLikeType = (
11011101 int (coeff )
11021102 if shape_env and isinstance (coeff , sympy .Integer )
11031103 else _symint_wrap (coeff )
@@ -1439,7 +1439,7 @@ def _advise_is_size(a: SymInt) -> None:
14391439 _constrain_range_for_size (a )
14401440
14411441
1442- def _advise_is_bounded (a : SymInt , upper_bound : Union [ int , SymInt ] ) -> None :
1442+ def _advise_is_bounded (a : SymInt , upper_bound : IntLikeType ) -> None :
14431443 if (
14441444 isinstance (a , SymInt )
14451445 and isinstance (a .node , SymNode )
@@ -1585,7 +1585,7 @@ def guard_bool(a: Union[SymBool, bool]) -> bool:
15851585 return a
15861586
15871587
1588- def guard_int (a : Union [ SymInt , int ] ) -> int :
1588+ def guard_int (a : IntLikeType ) -> int :
15891589 if isinstance (a , SymInt ):
15901590 return a .node .guard_int ("" , 0 ) # NB: uses Python backtrace
15911591 assert type (a ) is int , a
@@ -3987,7 +3987,7 @@ def _update_version_counter(self) -> None:
39873987
39883988 def _produce_dyn_sizes (
39893989 self ,
3990- ex_size : Sequence [Union [ int , SymInt ] ],
3990+ ex_size : Sequence [IntLikeType ],
39913991 source : Source ,
39923992 symbolic_context : SymbolicContext ,
39933993 ) -> list [sympy .Expr ]:
@@ -3997,7 +3997,7 @@ def _produce_dyn_sizes(
39973997
39983998 def _produce_dyn_sizes_from_int_tuple (
39993999 self ,
4000- tensor_size : Sequence [Union [ int , SymInt ] ],
4000+ tensor_size : Sequence [IntLikeType ],
40014001 source : Source ,
40024002 symbolic_context : SymbolicContext ,
40034003 ) -> list [sympy .Expr ]:
@@ -4034,11 +4034,7 @@ def create_symbolic_sizes_strides_storage_offset(
40344034 source : Source ,
40354035 * ,
40364036 symbolic_context : Optional [SymbolicContext ] = None ,
4037- ) -> tuple [
4038- tuple [Union [int , SymInt ], ...],
4039- tuple [Union [int , SymInt ], ...],
4040- Union [int , SymInt ],
4041- ]:
4037+ ) -> tuple [tuple [IntLikeType , ...], tuple [IntLikeType , ...], IntLikeType ,]:
40424038 """
40434039 Returns a list of symbolic sizes and strides for the given tensor.
40444040 We try our best to express stride in terms of the sizes, so as to not
@@ -4099,8 +4095,8 @@ def create_symbolic_sizes_strides_storage_offset(
40994095 # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
41004096 # we may have an unnessary shape speciliazation for y.
41014097 def _maybe_specialize_sym_int_with_hint (
4102- self , maybe_sym : Union [ int , SymInt ]
4103- ) -> Union [ int , SymInt ] :
4098+ self , maybe_sym : IntLikeType
4099+ ) -> IntLikeType :
41044100 assert isinstance (maybe_sym , (int , torch .SymInt ))
41054101 if is_symbolic (maybe_sym ):
41064102 assert (
@@ -4114,18 +4110,14 @@ def _create_symbolic_sizes_strides_storage_offset(
41144110 self ,
41154111 # NB: SymInt is allowed here due to nested int, normally you don't
41164112 # actually pass true symbolic sizes to this function
4117- ex_size : Sequence [Union [ int , SymInt ] ],
4118- ex_stride : Sequence [Union [ int , SymInt ] ],
4119- ex_storage_offset : Union [ int , SymInt ] ,
4113+ ex_size : Sequence [IntLikeType ],
4114+ ex_stride : Sequence [IntLikeType ],
4115+ ex_storage_offset : IntLikeType ,
41204116 is_dim_dynamic : Sequence [bool ],
41214117 source : Source ,
41224118 * ,
41234119 symbolic_context : Optional [SymbolicContext ] = None ,
4124- ) -> tuple [
4125- tuple [Union [int , SymInt ], ...],
4126- tuple [Union [int , SymInt ], ...],
4127- Union [int , SymInt ],
4128- ]:
4120+ ) -> tuple [tuple [IntLikeType , ...], tuple [IntLikeType , ...], IntLikeType ,]:
41294121 dim = len (ex_size )
41304122
41314123 # Reimplement the legacy behavior
@@ -4232,8 +4224,8 @@ def _compute_symbolic_stride(
42324224 self ,
42334225 source : Source ,
42344226 size : Sequence [sympy .Expr ],
4235- ex_size : Sequence [Union [ int , SymInt ] ],
4236- ex_stride : Sequence [Union [ int , SymInt ] ],
4227+ ex_size : Sequence [IntLikeType ],
4228+ ex_stride : Sequence [IntLikeType ],
42374229 dynamic_strides : Sequence [DimDynamic ],
42384230 constraint_strides : Sequence [
42394231 Optional [Union [StrictMinMaxConstraint , RelaxedUnspecConstraint ]]
@@ -4244,7 +4236,7 @@ def _compute_symbolic_stride(
42444236 from torch ._dynamo .source import TensorProperty , TensorPropertySource
42454237
42464238 stride : list [Optional [sympy .Expr ]] = [None ] * len (size )
4247- candidates : dict [Union [ int , SymInt ] , sympy .Expr ] = {}
4239+ candidates : dict [IntLikeType , sympy .Expr ] = {}
42484240
42494241 # iterate over unbound strides in val ascending order with
42504242 # index descending as a tie breaker since for cases like
@@ -4293,7 +4285,7 @@ def create_symintnode(
42934285 * ,
42944286 hint : Optional [int ],
42954287 source : Optional [Source ] = None ,
4296- ) -> Union [ int , SymInt ] :
4288+ ) -> IntLikeType :
42974289 """Create a SymInt value from a symbolic expression
42984290
42994291 If you know what the current hint value of the SymInt to be created
@@ -4314,7 +4306,7 @@ def create_symintnode(
43144306 else :
43154307 fx_node = None
43164308
4317- out : Union [ int , SymInt ]
4309+ out : IntLikeType
43184310 if isinstance (sym , sympy .Integer ):
43194311 if hint is not None :
43204312 assert int (sym ) == hint
@@ -4370,7 +4362,7 @@ def create_symfloatnode(
43704362 @record_shapeenv_event ()
43714363 def create_unspecified_symint_and_symbol (
43724364 self , value : int , source : Source , dynamic_dim : DimDynamic
4373- ) -> Union [ int , SymInt ] :
4365+ ) -> IntLikeType :
43744366 """Create a SymInt wrapping a new unspecified symbol"""
43754367 return self .create_symintnode (
43764368 self .create_unspecified_symbol (
@@ -5084,7 +5076,7 @@ def get_expression(tensor_dim_src: Source) -> sympy.Expr:
50845076 # tensors that never actually become graph arguments (they are
50855077 # pruned). In this case, only Dynamo knows about these arguments.
50865078 def track_symint (
5087- source : Source , val : Union [ SymInt , int ] , constraint : DimConstraint = None
5079+ source : Source , val : IntLikeType , constraint : DimConstraint = None
50885080 ) -> None :
50895081 log .debug ("track_symint %s %s %s" , LazyString (source .name ), val , constraint )
50905082 assert not isinstance (val , SymInt ) or is_symbolic (val )
0 commit comments