# mypy: ignore-errors
"""
``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with
our symbolic shapes reasoning system that is used heavily in torch.compile. Although
this is not generally considered public API, when writing framework code in PyTorch
as well as extensions to PyTorch (e.g., in custom operator implementations), you may
need to make use of these APIs to set up dynamic shapes support appropriately.
"""
import builtins
import collections
import functools
import inspect
import itertools
import logging
import math
import operator
import re
import sys
import threading
import traceback
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
import atexit
from typing import (
Any,
cast,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
TYPE_CHECKING
)
from typing_extensions import TypeAlias
import torch
import torch.fx
import torch.fx.traceback as fx_traceback
from torch.fx.experimental import _config as config
from torch.fx.experimental.recording import (
FakeTensorMeta,
ShapeEnvEvent,
record_shapeenv_event,
replay_shape_env_events,
shape_env_check_state_equal
)
from torch.fx.experimental.sym_node import SymNode, SymTypes
from torch._logging import trace_structured, structured
# NB: The sym_* functions are used via getattr() and must be imported here.
from torch import SymBool, SymFloat, SymInt
from torch._guards import ShapeGuard, Source, TracingContext
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._sympy.functions import (
FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt
)
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._traceback import format_frame, CapturedTraceback
from torch._utils_internal import signpost_event
from torch._subclasses.meta_utils import is_sparse_any
import torch.utils._pytree as pytree
from torch.utils._sympy.symbol import SymT, make_symbol, symbol_is_type
from torch._logging import LazyString
if TYPE_CHECKING:
from torch._dynamo.source import TensorPropertySource
InputList = List
DimList = List
log = logging.getLogger(__name__)
class GuardOnDataDependentSymNode(RuntimeError):
pass
class PendingUnbackedSymbolNotFound(RuntimeError):
pass
import sympy
from sympy.printing.str import StrPrinter
from sympy.printing.precedence import precedence, PRECEDENCE
aten = torch._ops.ops.aten # type: ignore[has-type]
__all__ = [
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int",
"guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr",
"hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
"is_concrete_bool", "is_nested_int", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
"has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext",
"StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true",
"guard_size_oblivious", "check_consistent",
"compute_unbacked_bindings", "ConvertIntKey",
"rebind_unbacked", "resolve_unbacked_bindings",
]
# FX node metadata keys for symbolic shape FX graph.
SHAPEENV_EVENT_KEY = "shapeenv_event"
CURRENT_NODE_KEY = "current_node"
def log_lru_cache_stats(wrapped_f):
log.debug("lru_cache_stats %s: %s", wrapped_f.__name__, wrapped_f.cumulative_cache_info())
# Wrapper on lru_cache that reports statistics at process end
def lru_cache(maxsize):
def inner(f):
wrapped_f = functools.lru_cache(maxsize)(f)
old_cache_clear = wrapped_f.cache_clear
prev_hits = 0
prev_misses = 0
# TODO: There's a ref-cycle here (wrapped_f -> cumulative_cache_info
# -> wrapped_f) but cannot be solved with weakref as wrapped_f is not
# weakref'able on some versions of Python
def cumulative_cache_info():
cur = wrapped_f.cache_info()
return functools._CacheInfo(
prev_hits + cur.hits,
prev_misses + cur.misses,
cur.maxsize,
cur.currsize,
)
def new_cache_clear():
nonlocal prev_hits, prev_misses
cur = wrapped_f.cache_info()
prev_hits += cur.hits
prev_misses += cur.misses
old_cache_clear()
wrapped_f.cache_clear = new_cache_clear
wrapped_f.cumulative_cache_info = cumulative_cache_info
if log.isEnabledFor(logging.DEBUG):
atexit.register(log_lru_cache_stats, wrapped_f)
return wrapped_f
return inner
# These are modules that contain generic code for interacting with ShapeEnv
# which are unlikely to identify a particularly interesting guard statement
@lru_cache(None)
def uninteresting_files() -> Set[str]:
import torch._inductor.sizevars
import torch._library.abstract_impl
import torch._subclasses.meta_utils
import torch._subclasses.fake_tensor
mods = [
sys.modules[__name__],
torch.fx.experimental.recording,
torch.fx.experimental.sym_node,
torch.fx.interpreter,
torch,
torch._inductor.sizevars,
torch._library.abstract_impl,
torch._subclasses.meta_utils,
torch._subclasses.fake_tensor,
]
return {inspect.getfile(m) for m in mods}
# We don't bother with the metaclass as all of the dispatching logic happens
# entirely from Python
#
# Didn't bother with ancestors for now, unlikely to have multiple modes for
# symints right now
class ConstraintViolationError(RuntimeError):
pass
def has_symbolic_sizes_strides(elem) -> bool:
return elem._has_symbolic_sizes_strides
Int = Union[torch.SymInt, int]
def create_contiguous(shape: Sequence[Int]) -> List[Int]:
strides: List[Int] = [1]
for dim in reversed(shape[:-1]):
strides.append(dim * strides[-1])
return list(reversed(strides))
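# Worked example (illustrative) for create_contiguous: a shape of [2, 3, 4]
# yields strides [12, 4, 1]; each stride is the product of the sizes to its
# right, i.e. a row-major contiguous layout.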
def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int:
"""
    Retrieve the hint for an int (based on the underlying real values as observed
    at runtime). If no hint is available (e.g., because of data-dependent shapes),
    use fallback if it is not None; otherwise raise an error.
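
    Example (a minimal sketch; assumes ``s`` is a backed SymInt whose
    runtime-observed value is 12)::

        hint_int(s)              # -> 12
        hint_int(s, fallback=0)  # -> 12 (hint available, fallback unused)
        hint_int(3)              # -> 3  (plain ints pass through)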
"""
if isinstance(a, torch.SymInt):
return a.node.require_hint(fallback)
assert type(a) is int, a
return a
Scalar = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool]
def has_hint(a: Scalar) -> bool:
if isinstance(a, SymTypes):
return a.node.has_hint()
return True
def is_concrete_int(a: Union[int, SymInt]) -> bool:
r""" Utility to check if underlying object
in SymInt is concrete value. Also returns
true if integer is passed in.
Args:
a (SymInt or int): Object to test if it int
"""
assert isinstance(a, (SymInt, int))
if isinstance(a, int):
return True
if isinstance(a.node.expr, sympy.core.numbers.Integer):
return True
return False
# In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime.
# So make sure only type checker evaluates this alias.
# Xref: https://www.internalfb.com/diff/D53324783
SympyBoolean: TypeAlias = "sympy.logic.boolalg.Boolean"
def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool:
"""
Perform a guard on a symbolic boolean expression in a size oblivious way.
    This is typically used when a non-oblivious test would result in a guard
    on a data-dependent value whose value we don't know at compile time.
When a guard is tested this way, we may diverge in behavior from how regular
PyTorch semantics would treat it. For more information, see
https://github.com/pytorch/pytorch/pull/118579
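
    Example (sketch; assumes ``u0`` is a size-like unbacked SymInt produced by
    ``Tensor.item()``): ``guard_size_oblivious(u0 == 0)`` evaluates to False
    without emitting a guard, because size-oblivious reasoning assumes such
    sizes are not 0 or 1, whereas ``bool(u0 == 0)`` would raise a
    data-dependent guard error.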
"""
if isinstance(expr, torch.SymBool):
return expr.node.guard_size_oblivious("", 0)
else:
assert isinstance(expr, bool)
return expr
def check_consistent(new, old) -> None:
"""
Test that two "meta" values (typically either Tensor or SymInt) have
the same values, e.g., after retracing. If we don't understand the
quantities in question, we'll just skip the consistency check.
"""
# TODO: do boolean equality test too, see
# https://github.com/pytorch/pytorch/issues/124110
scalar_types = (torch.SymInt, torch.SymFloat, int, float)
if isinstance(new, torch.Tensor):
assert isinstance(old, torch.Tensor)
torch._check(old.dim() == new.dim(), lambda: f"{old.shape} != {new.shape} (old != new)")
# Do this manually so that each individual test is irrefutable
# (TODO: should be a helper for this, maybe sym_eq? That
# gives us a compound expression and I'm not sure it
# simplifies right now)
for i, j in zip(old.shape, new.shape):
torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)")
# NB: bool is subclass of int
elif isinstance(new, scalar_types) and not isinstance(new, bool):
assert isinstance(old, scalar_types) and not isinstance(old, bool), f"{old} != {new}"
torch._check(old == new, lambda: f"{old} != {new} (old != new)")
def resolve_unbacked_bindings(shape_env, bindings):
if bindings is None:
return None
return {
shape_env.unbacked_renamings.get(k, k): v
for k, v in bindings.items()
}
def rebind_unbacked(shape_env, n: torch.fx.Node, result):
"""
Suppose we are retracing a pre-existing FX graph that previously had
fake tensor propagation (and therefore unbacked SymInts). When we retrace,
we re-propagate fake tensors, which results in new unbacked SymInts.
When this happens, we need to tell the shape environment about the equivalence
    of the old and new unbacked SymInts. Pass us the old torch.fx.Node (which
    has the old binding information) and the new result (from which we can
    extract the new unbacked SymInts).
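
    Example (sketch; symbol names are hypothetical): if ``n`` originally bound
    ``u0`` at ``result.size(0)`` and retracing allocates ``u1`` there instead,
    this call renames ``u1`` back to ``u0`` in the shape environment.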
"""
from torch._dynamo.tensor_version_op import _tensor_version
# Inputs never need rebinding
if n.op == "placeholder":
return
if bindings := resolve_unbacked_bindings(shape_env, n.meta.get("unbacked_bindings")):
for raw_u0, path in bindings.items():
u1 = pytree.key_get(result, path)
# tensor_version ops get specialized after AOTAutograd, it's OK,
# we don't actually want to do asserts on them. This is all a bit
# questionable though
if isinstance(u1, int) and n.target is _tensor_version:
log.info("rebind_unbacked: discard _tensor_version %s %s -> %s", raw_u0, path, u1)
continue
raw_u1 = u1.node.expr
# Simplify SymBool binding
if (
isinstance(raw_u1, sympy.Piecewise) and
len(raw_u1.args) == 2 and
raw_u1.args[0][0] == 1 and
isinstance(eq := raw_u1.args[0][1], sympy.Eq) and
isinstance(new_raw_u1 := eq.lhs, sympy.Symbol) and
shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1)) and
eq.rhs == 1 and
raw_u1.args[1] == (0, True)
):
# This is what the pattern match above is testing
repacked = _sympy_cast_symbool_to_symint_guardless(sympy.Eq(new_raw_u1, 1))
assert repacked == raw_u1, f"{repacked} != {raw_u1}"
# Cancel the to_int(to_bool(x)). This is sound because x in
# [0, 1]
raw_u1 = new_raw_u1
assert isinstance(raw_u1, sympy.Symbol)
# The old and new could be the same if you improperly hit the memo
# while retracing. Make sure you updated FakeTensorMode.epoch
assert raw_u0 != raw_u1, f"{raw_u0} possible memo disaster"
# Reuse the OLD symbol name
shape_env._rename_unbacked_to(raw_u1, raw_u0)
def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
r""" Canonicalize a boolean expression by transforming it into a lt / le
inequality and moving all the non-constant terms to the rhs.
    We canonicalize And / Or / Not via CNF and then canonicalize their
    subexpressions recursively.

    NB: sympy.Rel.canonical is not good enough: https://github.com/sympy/sympy/issues/25924
Args:
expr (sympy.Expr): Expression to canonicalize
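
    Example (illustrative)::

        >>> a, b = sympy.symbols("a b")
        >>> canonicalize_bool_expr(a + 5 > b)
        b < a + 5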
"""
if not isinstance(expr, (sympy.Rel, sympy.And, sympy.Or, sympy.Not, sympy.Eq, sympy.Ne)):
return expr
if isinstance(expr, (sympy.And, sympy.Or, sympy.Not)):
expr = sympy.logic.boolalg.to_cnf(expr)
return _canonicalize_bool_expr_impl(expr)
def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
"""
After canonicalization, we are guaranteed to have eliminated Ge/Gt relations
(rewriting them to Le/Lt, respectively).
"""
if isinstance(expr, (sympy.And, sympy.Or)):
return type(expr)(*map(canonicalize_bool_expr, expr.args))
opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le}
if isinstance(expr, tuple(opposite.keys())):
rhs = expr.lhs - expr.rhs
t = opposite[type(expr)]
else:
assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne))
rhs = expr.rhs - expr.lhs
t = type(expr)
def is_neg(t):
return t.is_negative or (isinstance(t, sympy.Mul) and t.args[0].is_negative)
lhs = 0
if isinstance(rhs, sympy.Add):
pos = []
neg = []
for term in rhs.args:
if is_neg(term):
neg.append(-term)
else:
pos.append(term)
lhs = sympy.Add(*neg)
rhs = sympy.Add(*pos)
elif is_neg(rhs):
# lhs == 0
lhs, rhs = -rhs, 0
return t(lhs, rhs)
def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
r""" Utility to check if underlying object
in SymBool is concrete value. Also returns
true if integer is passed in.
Args:
a (SymBool or bool): Object to test if it bool
"""
assert isinstance(a, (SymBool, bool))
if isinstance(a, bool):
return True
if isinstance(a.node.expr, (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse)):
return True
return False
def is_nested_int(s):
return isinstance(s, torch.SymInt) and s.node.is_nested_int()
def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
if isinstance(val, SymTypes):
        # This also applies to the jagged layout NestedTensor case, as
        # nested ints are not symbolic
if is_symbolic(val):
yield val.node.expr
elif isinstance(val, sympy.Basic):
yield val
elif isinstance(val, (int, float, bool)):
pass
elif isinstance(val, (tuple, list)):
for s in val:
yield from _iterate_exprs(s)
elif is_sparse_any(val):
yield from _iterate_exprs(val.size())
elif isinstance(val, torch.Tensor):
yield from _iterate_exprs(val.size())
yield from _iterate_exprs(val.stride())
yield from _iterate_exprs(val.storage_offset())
elif val is None:
pass
else:
raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")
def free_symbols(val: Union[SymInt, sympy.Expr, torch.Tensor]) -> Set[sympy.Symbol]:
if val is None:
return set()
itr = _iterate_exprs(val)
# we need at least 1 to call union, so we hand code the identity
try:
first_expr = next(itr)
except StopIteration:
return set()
return first_expr.free_symbols.union(*(e.free_symbols for e in itr))
def has_free_symbols(val: Union[SymInt, torch.Tensor]) -> bool:
"""Faster version of bool(free_symbols(val))"""
return not all(e.is_number for e in _iterate_exprs(val))
# Like free_symbols, but filtered to only report unbacked symbols
def free_unbacked_symbols(x):
# NB: keep synced with is_unbacked_symint
return {s for s in free_symbols(x) if symbol_is_type(s, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))}
# WARNING: Don't use this on Dynamo produced graphs, they don't have meta
# setup!
def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]:
if (
"val" in node.meta and
isinstance(node.meta["val"], torch.SymInt) and
isinstance(node.meta["val"].node.expr, sympy.Symbol) and
(node.op == "placeholder" or free_unbacked_symbols(node.meta["val"].node.expr))
):
return node.meta["val"].node.expr
return None
def find_symbol_binding_fx_nodes(graph):
r = {}
# NB: Prefer first occurrence of symbol
for node in graph.nodes:
if is_symbol_binding_fx_node(node) and node.meta["val"].node.expr not in r:
r[node.meta["val"].node.expr] = node
return r
# Analogous to ConvertIntSource
@dataclass(frozen=True)
class ConvertIntKey:
def __str__(self) -> str:
return ".cast_symbool_to_symint_guardless()"
def get(self, b: bool) -> int:
"""Get the int value from bool"""
return cast_symbool_to_symint_guardless(b)
@dataclass(frozen=True)
class CallMethodKey:
name: str
def __str__(self) -> str:
return f".{self.name}()"
def get(self, o: Any) -> Any:
"""Call the method on object"""
return getattr(o, self.name)()
@dataclass(frozen=True)
class InnerTensorKey:
inner_name: str
def __str__(self) -> str:
return f".{self.inner_name}"
def get(self, o: Any) -> Any:
"""Get the inner tensor attribute"""
return getattr(o, self.inner_name)
@dataclass(frozen=True)
class DivideByKey:
divisor: int
def __str__(self) -> str:
return f".__floordiv__({self.divisor})"
def get(self, o: int) -> int:
"""Divide object by divisor"""
return o // self.divisor
def compute_unbacked_bindings(shape_env, example_value, old_example_value=None, peek=False):
"""
After having run fake tensor propagation and producing example_value
result, traverse example_value looking for freshly bound unbacked
symbols and record their paths for later. It is an error if
we have allocated an unbacked SymInt but it cannot be found in
    example_value. (NB: this means that if you have a multi-output
    function, you must call this on the tuple of tensor outputs; you
    cannot wait!)

    The peek parameter lets you inspect the bindings without mutating the
    pending list. This is primarily useful for ensuring unbacked_var_to_val
    is promptly populated when propagate_real_tensors is on.
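
    Example (a sketch of the returned mapping; the symbol name is
    hypothetical): if fake propagation allocated a fresh unbacked symbol
    ``u0`` for dimension 0 of the first tensor in a tuple output, the result
    would look roughly like::

        {u0: (SequenceKey(0), CallMethodKey("size"), SequenceKey(0))}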
"""
if shape_env is None:
return
if shape_env._ignore_fresh_unbacked_symbols_tls():
return
fs = shape_env.pending_fresh_unbacked_symbols
pending = set(fs)
if pending:
if not peek:
log.info("compute_unbacked_bindings %s", fs)
fs.clear()
def free_unbacked_symbols_with_path(
a, path, real=None
) -> Dict[sympy.Symbol, pytree.KeyPath]:
r = {}
if isinstance(a, (tuple, list)):
for i in range(len(a)):
r.update(
free_unbacked_symbols_with_path(
a[i], path + (pytree.SequenceKey(i),),
real=real[i] if real is not None else None
)
)
elif is_traceable_wrapper_subclass(a):
# TODO: Determine if this is correct
attrs, _ = a.__tensor_flatten__()
for attr in attrs:
sub = getattr(a, attr)
r.update(
free_unbacked_symbols_with_path(sub, path + (InnerTensorKey(attr),))
)
elif isinstance(a, torch.Tensor):
r.update(
free_unbacked_symbols_with_path(
a.size(), path + (CallMethodKey("size"),),
real=a.real_tensor.size() if a.real_tensor is not None else None
)
)
r.update(
free_unbacked_symbols_with_path(
a.stride(), path + (CallMethodKey("stride"),),
real=a.real_tensor.stride() if a.real_tensor is not None else None
)
)
r.update(
free_unbacked_symbols_with_path(
a.storage_offset(), path + (CallMethodKey("storage_offset"),),
real=a.real_tensor.storage_offset() if a.real_tensor is not None else None
)
)
# NB: Intentionally access _expr, not expr, do not want
# simplification!
elif (
isinstance(a, (torch.SymInt, torch.SymFloat))
and isinstance(s := a.node._expr, sympy.Symbol)
and s in pending
):
r[s] = path
if real is not None:
shape_env.set_unbacked_var_to_val(s, real)
pending.remove(s)
# When an unbacked SymInt is perfectly divisible by an integer
# constant, we replace it with the integer constant to improve
# reasoning capabilities. However, in synthetic examples, it is
        # then possible that the factor is never explicitly allocated.
# Fortunately, we can compute it by division.
elif (
isinstance(a, torch.SymInt)
and isinstance(s := a.node._expr, sympy.Mul)
and len(s.args) == 2
and isinstance(lhs := s.args[0], sympy.Integer)
and isinstance(rhs := s.args[1], sympy.Symbol)
and rhs in pending
):
# TODO: DivideByKey needs to test divisibility at runtime!
r[s] = path + (DivideByKey(int(lhs)),)
if real is not None:
shape_env.set_unbacked_var_to_val(s, real // int(lhs))
pending.remove(rhs)
# The annoyance here arises from the fact that SymBool is
# allocated by allocating a SymInt and then testing if it's equal
        # to one. So the binding-site logic for this is more involved.
elif (
isinstance(a, torch.SymBool)
and isinstance(s := a.node._expr, sympy.Eq)
# This must match create_unbacked_symbool EXACTLY
and isinstance(s.lhs, sympy.Symbol)
and s.rhs == 1
and s.lhs in pending
):
r[s.lhs] = path + (ConvertIntKey(),)
if real is not None:
shape_env.set_unbacked_var_to_val(s, int(real))
pending.remove(s.lhs)
return r
symbol_to_path = free_unbacked_symbols_with_path(example_value, ())
if not peek and pending:
extra = (
repr((example_value.stride(), example_value.storage_offset()))
if isinstance(example_value, torch.Tensor)
else ""
)
raise PendingUnbackedSymbolNotFound(
f"Pending unbacked symbols {pending} not in returned outputs {example_value} {extra}.\n"
"Did you accidentally call new_dynamic_size() or item() more times "
"than you needed to in your fake implementation?\n"
"For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit"
)
# Why do we have to do some rebinding here? If the original FX node
# wasn't a binding site because you had a memo hit, but post
# translation you aren't a memo hit anymore, there's now a new binding
# site... but we know (because it's the same FX node) that the value
# is actually the same, they're just not obviously equal anymore.
#
# The logic here is written carefully, because unlike the
# bind_unbacked case, we are not guaranteed to have a symbol for
# old_sym. If we have a symbol, do regular rename unbacked to; but if
# we don't, we need to specially eliminate the fresh unbacked symbol
# (NB: we are /trusting/ that the memoization is correct, and that we
# don't need to generate a new runtime assert. This is load bearing,
# as repropagation can happen after we've frozen runtime asserts.)
if old_example_value is not None:
for keypath in symbol_to_path.values():
old_sym = pytree.key_get(old_example_value, keypath)
new_sym = pytree.key_get(example_value, keypath)
if (
isinstance(new_sym, SymTypes) and
isinstance(new_s := new_sym.node.expr, sympy.Symbol)
):
if isinstance(old_sym, SymTypes) and (old_s := old_sym.node.expr) != new_s:
if isinstance(old_s, sympy.Symbol):
shape_env._rename_unbacked_to(new_s, old_s)
else:
shape_env._eliminate_unbacked(new_s, old_s)
elif not isinstance(old_sym, SymTypes):
shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym))
return symbol_to_path
def definitely_true(a):
"""
Returns True only if we can tell that a is True, possibly introducing
a guard in the process. If a depends on some unbacked SymInt, we may
return False even though there may exist a possible value of the SymInt
that would cause the expression to return True.
    When is it appropriate to use definitely_true? First, if you can use
    a higher-level combinator like parallel_or/parallel_and, prefer using
    those instead; they are definitely safe (modulo short-circuiting).
Second, it can be used if the program would behave equivalently if
definitely_true always returned False (parallel_or/parallel_and are
    examples of this pattern, modulo short-circuiting). Finally, it can even
    be OK if the program wouldn't behave equivalently, so long as the
change is semantics preserving. It can be semantics preserving if
the program errors in more cases than it did previously (but otherwise
behaves identically), or if it changes some quantity in a way that
doesn't matter (e.g., strides often fall in this bucket.)
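
    Example (sketch; assumes ``u0`` is an unbacked SymInt):
    ``definitely_true(u0 == 1)`` returns False without guarding, whereas
    ``bool(u0 == 1)`` would raise a data-dependent guard error.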
"""
if isinstance(a, SymBool):
if a.node.has_hint():
return guard_bool(a)
else:
return False
return bool(a)
def definitely_false(a):
"""
Returns True only if we can tell that a is False, possibly introducing
a guard in the process. If a depends on some unbacked SymInt, we may
return False even though there may exist a possible value of the SymInt
that would cause the expression a to be False. See definitely_true
for more usage guidance.
"""
if isinstance(a, SymBool):
if a.node.has_hint():
return not guard_bool(a)
else:
return False
return not bool(a)
def statically_known_true(x: Union[bool, SymBool]) -> bool:
"""Returns True if x can be simplified to a constant and is true.
.. note::
This function doesn't introduce new guards, so the expression may end
up evaluating to true at runtime even if this function returns False.
Args:
x (bool, SymBool): The expression to try statically evaluating
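
    Example (illustrative; assumes ``s`` is a SymInt whose value range is
    known to be [2, inf)): ``statically_known_true(s >= 2)`` returns True,
    while ``statically_known_true(s == 4)`` returns False without adding a
    guard.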
"""
if isinstance(x, SymBool):
expr = x.node.expr
shape_env = x.node.shape_env
try:
simplified = shape_env._maybe_evaluate_static(expr)
if simplified is not None:
return bool(simplified)
except Exception:
log.debug("Could not simplify %s", expr)
return False
assert isinstance(x, bool)
return x
def parallel_or(*args):
"""
Evaluate the logical OR of several arguments, avoiding guarding on
unbacked SymInts if another argument is definitely True.
"""
if any(statically_known_true(a) for a in args):
return True
if any(definitely_true(a) for a in args):
return True
return any(args)
def parallel_and(*args):
"""
    Evaluate the logical AND of several arguments, avoiding guarding on
unbacked SymInts if another argument is definitely False.
"""
if any(statically_known_true(torch.sym_not(a)) for a in args):
return False
if any(definitely_false(a) for a in args):
return False
return all(args)
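# Usage sketch for parallel_or/parallel_and (values hypothetical): with u0 an
# unbacked SymInt, parallel_or(n == 0, u0 == 0) can return True without
# guarding on u0 whenever n == 0 is statically or definitely known to hold.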
def sym_eq(x, y):
"""
Like ==, but when run on list/tuple, it will recursively test equality
and use sym_and to join the results together, without guarding.
"""
if (isinstance(x, tuple) and isinstance(y, tuple)) or (isinstance(x, list) and isinstance(y, list)):
if len(x) != len(y):
return False
return functools.reduce(operator.and_, map(sym_eq, x, y), True)
elif isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)):
return x == y
else:
raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}")
def guard_scalar(a):
if isinstance(a, (SymBool, bool)):
return guard_bool(a)
elif isinstance(a, (SymInt, int)):
return guard_int(a)
elif isinstance(a, (SymFloat, float)):
return guard_float(a)
else:
raise AssertionError(f"unrecognized scalar {a}")
def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int):
shape_env.constrain_symbol_range(s, compiler_min, compiler_max)
def _advise_is_size(a):
"""
Don't use this directly; use torch._check_is_size instead.
This is a softer version of _constrain_range_for_size (with min=0,
max=Inf). Instead of forcibly constraining a variable (and erroring if we
failed to constrain it), it will simply advise us that a size is
constrained in some way. We will always defer a runtime assert for this
    constraint if we cannot prove it at compile-time, but we only
*sometimes* learn useful extra information at compile-time with this
information. This is in contrast to constrain_range_for_size, where if
you don't call that on a fresh unbacked symint, chances are we will choke.
TODO: Make Dynamo handle this appropriately if this is seen in Dynamo-ed
code. Right now this is only really used in code with AOTAutograd trace
through, so it is not a big problem that this isn't supported, but in
principle all of this code should be Dynamo'able too.
TODO: I didn't support min/max because I didn't have a use case where this
actually helped. In principle we can support it, it just makes the
implementation below more complicated.
"""
# This must always succeed, because the sole allowed caller _check_is_size
# was responsible for expect_true'ing this
    # This assert triggers expensive sym compute, do not do it until it's cheap.
# assert a >= 0
# NB: it's important not to constrain range for size for *hinted* SymInts,
# because it is not only unsound, it will immediately trip our asserts
# that hints have to be consistent with static analysis! If you somehow
# have an unbounded SymInt that later constrains to 1, this will be
# inconsistent with the range
if (
isinstance(a, SymInt)
and isinstance(a.node, SymNode)
and isinstance(a.node.expr, sympy.Symbol)
and a.node.shape_env.is_unbacked_symint(a.node.expr)
):
_constrain_range_for_size(a)
def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None):
"""
This function is NOT INTENDED to be used by itself.
"""
if isinstance(a, (SymFloat, SymBool)):
raise ValueError("Constraining SymFloat/SymBool is nyi")
assert isinstance(a, SymInt), "can only constrain range for SymInt"
assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
a.node.shape_env._constrain_range_for_size(a.node.expr, min, max)
# inclusive both ways
def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
"""
    Applies a constraint that the passed-in SymInt must lie between min and max,
    inclusive on both ends, WITHOUT introducing a guard on the SymInt (meaning
    that it can be used on unbacked SymInts). If min/max are None, we assume
that the dimension is unbounded in that direction. Repeated application
of constrain_range intersects the ranges. This is a fairly low level API
that doesn't have a lot of safety guarantees (TODO: provide higher level
APIs).
Currently, we use this API in the following circumstance: when we allocate
an unbacked SymInt, denoting an integer quantity which is data dependent,
we ordinarily do not know anything about what values it may take. This
means that any sort of guard on it will immediately fail. However, in
many cases, we know something about the unbacked SymInt: for example, we
know that nonzero(x).size(0) must be >= 0. We use constrain_range to
narrow the possible range, declaring that negative symbols are impossible.
    This permits us to definitely answer True to queries like 'nnz >= 0', even if
we don't know what the actual (hinted) value of 'nnz' is. In fact, we
actually use constrain_range to unsoundly discharge common guards: for an
unbacked SymInt produced by nonzero, we will also assume that it is not
equal to 0/1 (even though these are perfectly possible values at runtime),
because we generally expect graphs that are valid for N=2 to also be valid
for N=1.
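
    Example (a minimal sketch; assumes ``nnz`` is an unbacked SymInt, e.g.
    from ``torch.nonzero``)::

        constrain_range(nnz, min=0)           # declare nnz >= 0, no guard
        constrain_range(nnz, min=0, max=128)  # repeated calls intersect ranges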
"""
if min is None:
min = -int_oo
if max is None:
max = int_oo
if max < min:
        raise ValueError(
            "Maximum value to constrain_as_size can't be less than the specified min value, "
            f"received min={min} and max={max}"
        )
if isinstance(a, int):
if not (min <= a <= max):
raise ValueError(f"Invalid value {a} for range [{min}:{max}]")
return
a.node.shape_env._constrain_range(a.node.expr, min, max)
def constrain_unify(a: torch.SymInt, b: torch.SymInt) -> None:
"""
Given two SymInts, constrain them so that they must be equal. NB:
this will not work with SymInts that represent nontrivial expressions
(yet!)
"""
if not isinstance(a, SymInt):
if not isinstance(b, SymInt):
assert a == b
return
else:
shape_env = b.node.shape_env
else:
shape_env = a.node.shape_env
shape_env._constrain_unify(a, b)
# Assume that a boolean is true for the purposes of subsequent symbolic
# reasoning. This will keep track of corresponding runtime checks to verify
# that the result is upheld: either as a regular guard, or as a special set
# of asserts which are triggered when an unbacked SymInt is allocated.
#
# DO NOT use this function for these cases:
#
# - This is inappropriate for "branching" conditions (where both
# true and false result in valid programs). We will always assume
# the condition evaluates true, and so it will never be possible
# to trace the false condition when you use it. For true branching
# on unbacked SymInts, you must use torch.cond; if you incorrectly
# use expect_true in this case, you will make the false branch
# unreachable (as we will simply assume that only the true branch
# is ever exercised).
#
# - This is inappropriate for situations where you know some other system
# invariant guarantees that this property holds, since you don't
# really need to insert a runtime check in that case. Use something
# like constrain_range in that case.
#
# This API has a hitch. To avoid having to reimplement error reporting
# capabilities, this function CAN return False. The invariant is that
# the surrounding code must raise an error when this function returns
# False. This is quite low level, so we recommend using other functions
# like check() which enforce this in a more intuitive way.
#
# By the way, this name is a nod to the __builtin_expect macro,
# which is used similarly (but unlike __builtin_expect, you MUST fail
# in the unlikely branch.) (I think expect is a good name; in recent
# versions of C++, this is replaced with [[likely]], which is weaker
# and not accurate for this function!)
def expect_true(a, skip: int = 0):
if isinstance(a, SymBool):
# TODO: check perf implications of this
frame = inspect.currentframe()
for _ in range(skip + 1): # always run this loop at least once
frame = frame.f_back
return a.node.expect_true(frame.f_code.co_filename, frame.f_lineno)
assert type(a) is bool, a
return a
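# Usage sketch: the caller must raise when expect_true returns False, e.g.
#
#     if not expect_true(cond):
#         raise RuntimeError("invariant violated")
#
# (Higher-level helpers like torch._check enforce this pattern for you.)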
def guard_bool(a):
if isinstance(a, SymBool):
return a.node.guard_bool("", 0) # NB: uses Python backtrace
assert type(a) is bool, a
return a
def guard_int(a):
if isinstance(a, SymInt):
return a.node.guard_int("", 0) # NB: uses Python backtrace
assert type(a) is int, a
return a
def guard_float(a):
if isinstance(a, SymFloat):
return a.node.guard_float("", 0) # NB: uses Python backtrace
assert isinstance(a, float), a
return a
# Given a GraphModule, return all the FakeTensors for all the placeholders
def fx_placeholder_vals(gm):
return [n.meta['val'] for n in gm.graph.nodes if n.op == "placeholder"]
def fx_placeholder_targets(gm):
return [n.target for n in gm.graph.nodes if n.op == "placeholder"]
# Given a GraphModule and arguments to run it with, evaluate that the guards
# for its associated ShapeEnv are satisfied by the passed arguments. This
# WILL check for duck sizing.
def eval_guards(gm, *args, ignore_static=True):
return gm.shape_env.evaluate_guards_for_args(fx_placeholder_vals(gm), args, ignore_static=ignore_static)
def bind_symbols(gm, *args):
return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args)
class DimDynamic(Enum):
"""
Controls how to perform symbol allocation for a dimension. It is always
sound to default this to DYNAMIC, but the policies DUCK and STATIC can
result in better trace-time and compile-time performance, as they reduce
the number of allocated symbols and generally make your graph more static.
NB: If we notice you've applied a constraint to the dimension, we will