-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
symbolic_convert.py
2547 lines (2234 loc) · 91.5 KB
/
symbolic_convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import collections
import contextlib
import copy
import dataclasses
import dis
import functools
import importlib
import inspect
import itertools
import linecache
import logging
import operator
import sys
import textwrap
import threading
import traceback
import types
import typing
import weakref
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type
from unittest.mock import patch
import torch
import torch._logging
from torch._guards import Checkpointable, tracing, TracingContext
from . import config, exc, logging as torchdynamo_logging, skipfiles, variables
from .allowed_functions import is_builtin_constant, is_forbidden
from .bytecode_analysis import (
get_indexof,
JUMP_OPNAMES,
livevars_analysis,
propagate_line_nums,
)
from .bytecode_transformation import (
cleaned_instructions,
create_call_function,
create_instruction,
create_jump_absolute,
Instruction,
is_generator,
unique_id,
)
from .code_context import code_context
from .codegen import PyCodegen
from .current_scope_id import current_scope_id
from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported
from .funcname_cache import get_funcname
from .guards import GuardBuilder, install_guard
from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
from .replay_record import DummyModule, ExecutionRecorder
from .resume_execution import ContinueExecutionCache, ReenterWith
from .source import (
AttrSource,
GetItemSource,
GlobalSource,
GlobalWeakRefSource,
LocalSource,
Source,
)
from .utils import (
counters,
get_fake_value,
get_instruction_source_311,
graph_break_dup_warning_checker,
istype,
LazyString,
proxy_args_kwargs,
)
from .variables.base import (
_is_top_level_scope,
is_side_effect_safe,
MutableLocal,
typestr,
VariableTracker,
)
from .variables.builder import VariableBuilder, wrap_fx_proxy
from .variables.builtin import BuiltinVariable
from .variables.constant import ConstantVariable, EnumVariable
from .variables.ctx_manager import (
ContextWrappingVariable,
GenericContextWrappingVariable,
WithExitFunctionVariable,
)
from .variables.dicts import ConstDictVariable, SetVariable
from .variables.functions import (
BaseUserFunctionVariable,
NestedUserFunctionVariable,
UserFunctionVariable,
UserMethodVariable,
)
from .variables.lists import (
BaseListVariable,
ListIteratorVariable,
ListVariable,
SliceVariable,
TupleVariable,
)
from .variables.misc import (
ClosureVariable,
GetAttrVariable,
InlinedClosureVariable,
NullVariable,
PythonModuleVariable,
UnknownVariable,
)
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
supported_const_comparison_ops,
supported_tensor_comparison_ops,
SymNodeVariable,
TensorVariable,
)
from .variables.user_defined import (
RemovableHandleVariable,
UserDefinedClassVariable,
UserDefinedObjectVariable,
UserDefinedVariable,
)
# Module-level loggers.  The artifact loggers are keyed by artifact name
# ("graph_breaks", "trace_call", "trace_source") via torch._logging.
log = logging.getLogger(__name__)
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")

# Thread-local scratch storage; populated elsewhere in this module.
tls = threading.local()
@dataclasses.dataclass
class SpeculationEntry:
    # Location of the instruction where this speculation started; used only
    # for the divergence debug check in SpeculationLog.next().
    filename: str
    lineno: int
    instruction_pointer: int
    # Set True once this speculation has failed; a subsequent re-run of the
    # analysis will graph-break here instead of taking the branch again.
    failed: bool = False
    reason: Optional[GraphCompileReason] = None

    def fail_and_restart_analysis(self):
        """
        Start tracing of the current frame over again, and don't take this branch.
        """
        self.failed = True
        raise exc.SpeculationRestartAnalysis()
@dataclasses.dataclass
class SpeculationLog:
    """
    SpeculationLog replaces the prior copy_graphstate/restore_graphstate
    checkpointing. Rather than saving/restoring state, we restart the
    dynamo conversion process over from the beginning -- but when we
    hit the start of the speculation that failed, we instead generate
    a graph break.
    """

    entries: List[SpeculationEntry] = dataclasses.field(default_factory=list)
    index: int = 0

    def restart(self):
        """Rewind to the start of the log for a re-run of the analysis."""
        self.index = 0

    def clear(self):
        """Drop all recorded entries and rewind."""
        self.entries.clear()
        self.index = 0

    def next(self, filename: str, lineno: int, instruction_pointer) -> SpeculationEntry:
        """
        Lookup or create a SpeculationEntry() that is shared across
        RestartAnalysis calls. Args are used only for debug checks.
        """
        if len(self.entries) == self.index:
            self.entries.append(SpeculationEntry(filename, lineno, instruction_pointer))
        entry = self.entries[self.index]
        self.index += 1
        # A re-run must replay the exact same instruction stream; if the
        # location diverges from the prior run, tracing is non-deterministic.
        # Fixes vs. original: "SpecuationLog" typo in the message, and the
        # Run2 line printed "(unknown)" even though `filename` is available
        # (and is exactly what the assertion compares against).
        assert (
            entry.instruction_pointer == instruction_pointer
            and entry.filename == filename
            and entry.lineno == lineno
        ), textwrap.dedent(
            f"""
            SpeculationLog diverged at {self.index} of {len(self.entries)}:
            - Run1: {entry.filename}:{entry.lineno} (ip={entry.instruction_pointer})
            - Run2: {filename}:{lineno} (ip={instruction_pointer})
            Please submit a bug report.
            """
        )
        return entry
@functools.lru_cache(None)
def _step_logger():
    # Cached so every caller shares a single step logger instance.
    return torchdynamo_logging.get_step_logger(log)
@dataclasses.dataclass
class BlockStackEntry:
    # Jump target identifying this block (on 3.11 this is the exception
    # table entry's target -- see InstructionTranslatorBase.step()).
    target: Instruction
    # Stack depth when the block was entered; needed to re-enter the block
    # from a resume function.
    stack_index: Optional[int] = None
    # Context-manager variable for with-blocks, if any.
    with_context: Optional[ContextWrappingVariable] = None

    def can_restore(self):
        # Only blocks that captured their context manager can be re-entered.
        return self.with_context is not None

    def resume_fn(self):
        """Build the ReenterWith descriptor used to re-enter this block."""
        assert self.stack_index is not None
        if self.with_context and self.with_context.target_values:
            return ReenterWith(self.stack_index, tuple(self.with_context.target_values))
        else:
            return ReenterWith(self.stack_index)

    def exit(self, tx):
        """Trace the __exit__ of this block's context manager."""
        assert self.with_context is not None
        return self.with_context.exit(tx)
class InstructionTranslatorGraphState(NamedTuple):
    """Snapshot of InstructionTranslatorBase state, used by the
    Checkpointable machinery to compare two runs of the analysis."""

    output: OutputGraphState
    symbolic_locals: Dict[str, VariableTracker]
    stack: List[VariableTracker]
    block_stack: List[BlockStackEntry]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
    next_instruction: Optional[Instruction]
    lineno: int

    def diff(self, other: "InstructionTranslatorGraphState") -> Optional[str]:
        """
        Return a human-readable description of the first mismatching field
        between the two snapshots, or None if they are equivalent.
        """
        for k in self._fields:
            if k == "output":
                # BUG FIX: the original returned unconditionally here, and
                # since "output" is the first field the remaining fields were
                # never compared.  Only bail out on an actual mismatch.
                output_diff = self.output.diff(other.output, prefix=f"{k}.")
                if output_diff is not None:
                    return output_diff
                continue
            sv = getattr(self, k)
            ov = getattr(other, k)
            if sv != ov:
                return f"{k} mismatch: {sv} != {ov}"
        return None
def stack_op(fn: typing.Callable[..., object]):
    """
    Build an opcode handler from a plain function: pop ``fn``'s arity worth
    of values off the simulated stack, apply ``fn`` through BuiltinVariable,
    and push the resulting VariableTracker.
    """
    arity = len(inspect.signature(fn).parameters)
    handler_var = BuiltinVariable(fn)

    @functools.wraps(fn)
    def impl(self: "InstructionTranslatorBase", inst: Instruction):
        args = self.popn(arity)
        self.push(handler_var.call_function(self, args, {}))

    return impl
def _detect_and_normalize_assert_statement(
    self: "InstructionTranslatorBase",
    truth_fn: typing.Callable[[object], bool],
    push: bool,
):
    """
    Detect whether the conditional jump being traced is the one emitted by an
    ``assert`` statement.  If so, push the assert message (or a dummy one) as
    a ConstantVariable and return True; otherwise return False and leave the
    stack untouched.
    """
    # Detect if this jump instruction is assert and normalize the assert
    # by pushing dummy error message when nothing is given.
    #
    # Python 3.9 assertion is in following format:
    # 18 POP_JUMP_IF_TRUE 28
    # 20 LOAD_ASSERTION_ERROR
    # 22 LOAD_CONST 3 ('Assert message') -> optional instruction
    # 24 CALL_FUNCTION 1 -> optional instruction
    # 26 RAISE_VARARGS
    #
    # Python 3.8 assertion is in following format:
    # 18 POP_JUMP_IF_TRUE 28
    # 20 LOAD_GLOBAL 0 (Assertion type)
    # 22 LOAD_CONST 3 ('Assert message') -> optional instruction
    # 24 CALL_FUNCTION 1 -> optional instruction
    # 26 RAISE_VARARGS 1

    # asserts compile to a plain truth-test jump that does not leave the
    # value on the stack
    if (truth_fn is not operator.truth) or push:
        return False

    assert isinstance(self.instruction_pointer, int)
    current_instruction_pointer = self.instruction_pointer
    inst = self.instructions[current_instruction_pointer]
    # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0
    if sys.version_info < (3, 9):
        if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError":
            return False
    else:
        if inst.opname != "LOAD_ASSERTION_ERROR":
            return False

    current_instruction_pointer += 1

    # Use dummy error message if its hard to extract
    error_msg = "assertion error"

    inst = self.instructions[current_instruction_pointer]
    # DETECT RAISE_VARARGS or LOAD CONST
    if inst.opname == "LOAD_CONST":
        if not isinstance(inst.argval, str):
            return False
        error_msg = inst.argval

        # if it is LOAD_CONSTANT, it must be followed by CALL_FUNCTION
        # (PRECALL for Python 3.11+)
        current_instruction_pointer += 1
        inst = self.instructions[current_instruction_pointer]
        if inst.opname not in ("CALL_FUNCTION", "PRECALL"):
            return False

        # for Python 3.11+, PRECALL should be followed by CALL, then RAISE_VARARGS
        # for Python < 3.11, CALL_FUNCTION should be followed by RAISE_VARARGS
        current_instruction_pointer += 1
        if inst.opname == "PRECALL":
            current_instruction_pointer += 1
        inst = self.instructions[current_instruction_pointer]

    if inst.opname != "RAISE_VARARGS":
        return False

    self.push(ConstantVariable.create(error_msg))

    return True
def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
    """
    Build a handler for a conditional-jump opcode.

    ``truth_fn`` decides whether the jump is taken given a concrete value;
    ``push`` means the opcode leaves the tested value on the stack when the
    jump is taken (the JUMP_IF_*_OR_POP family) -- hence the repeated
    ``push and self.push(value)`` before every ``self.jump``.
    """

    def inner(self: "InstructionTranslatorBase", inst: Instruction):
        value: VariableTracker = self.pop()
        if (
            config.rewrite_assert_with_torch_assert
            and _detect_and_normalize_assert_statement(self, truth_fn, push)
        ):
            error_msg: VariableTracker = self.pop()
            # Skip over things like `assert True`
            if value.is_python_constant() and bool(value.as_python_constant()):
                self.jump(inst)
                return

            # TODO maybe should respect DtoH sync intention of users later??
            # Manually insert torch._assert_async instead of python assert and jump over
            # assert related instructions as we don't need them anymore.

            # if we see Tensor as assert statement, no need to call scalar_tensor
            if isinstance(value, TensorVariable):
                self.output.create_proxy(
                    "call_function",
                    torch._assert_async,
                    *proxy_args_kwargs((value, error_msg), {}),
                )
                self.jump(inst)
                return

            scalar_to_tensor_proxy = self.output.create_proxy(
                "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {})
            )

            scalar_to_tensor = wrap_fx_proxy(
                self,
                scalar_to_tensor_proxy,
                example_value=get_fake_value(scalar_to_tensor_proxy.node, self),
            )

            self.output.create_proxy(
                "call_function",
                torch._assert_async,
                *proxy_args_kwargs((scalar_to_tensor, error_msg), {}),
            )
            self.jump(inst)
            return

        if value.is_python_constant():
            if truth_fn(value.as_python_constant()):
                push and self.push(value)
                self.jump(inst)
        elif (
            isinstance(value, (TensorVariable)) and self.should_compile_partial_graph()
        ):
            # compile a partial subgraph prefix then jump into user code
            if self.has_backedge():
                msg = (
                    "Skipping frame because there is a graph break in a for/while loop\n"
                    f"{self.frame_summary()}"
                )
                log.info(msg)
                raise exc.SkipFrame(msg)

            self.push(value)
            log.debug("generic_jump triggered compile")
            self.output.compile_subgraph(
                self,
                reason=GraphCompileReason(
                    f"generic_jump {typestr(value)}", [self.frame_summary()]
                ),
            )
            self.pop()

            # Emit two resume paths: fall-through (if_next) and jump taken
            # (if_jump), then re-emit the jump choosing between them.
            if_next = self.create_call_resume_at(self.next_instruction)
            push and self.push(value)
            if_jump = self.create_call_resume_at(inst.target)

            self.output.add_output_instructions(
                [create_instruction(inst.opname, target=if_jump[0])] + if_next + if_jump
            )
        elif isinstance(value, NNModuleVariable):
            # Equivalent of "self.nn_module is not None"
            mod = self.output.get_submodule(value.module_key)
            if truth_fn(mod):
                push and self.push(value)
                self.jump(inst)
        elif isinstance(value, UserDefinedObjectVariable):
            x = value.var_getattr(self, "__bool__")
            # if __bool__ is missing, trying __len__ to infer a truth value.
            if isinstance(x, GetAttrVariable):
                x = value.var_getattr(self, "__len__")

            # __bool__ or __len__ is function
            if isinstance(x, UserMethodVariable):
                result = x.call_function(self, [], {})
                if isinstance(result, ConstantVariable) and isinstance(
                    result.value, (bool, int)
                ):
                    if truth_fn(result.value):
                        push and self.push(value)
                        self.jump(inst)
                else:
                    unimplemented(
                        "generic_jump on UserDefined with __bool__ returning non-constant"
                    )
            # __bool__ or __len__ is non-function or not existed in the user defined object
            else:
                if truth_fn(True):
                    push and self.push(value)
                    self.jump(inst)
        elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
            self
        ):
            if truth_fn(len(value.unpack_var_sequence(self))):
                push and self.push(value)
                self.jump(inst)
        elif isinstance(value, SymNodeVariable):
            eval_result = value.evaluate_expr(self.output)
            if truth_fn(eval_result):
                push and self.push(value)
                self.jump(inst)
        else:
            # TODO link the torch.cond doc later
            raise exc.UserError(
                exc.UserErrorType.DYNAMIC_CONTROL_FLOW,
                "Dynamic control flow is not supported at the moment. Please use "
                "functorch.experimental.control_flow.cond to explicitly capture the control flow.",
                case_name="cond_operands",
            )

    return inner
# When True, the per-break debug logging in break_graph_if_unsupported is
# skipped.  NOTE(review): appears to be toggled externally (presumably by
# torch._dynamo.explain) -- confirm before relying on this.
explain = False
def break_graph_if_unsupported(*, push):
    """
    Decorator factory for opcode handlers that may graph-break mid-instruction.

    If the wrapped handler raises Unsupported, the traced prefix is compiled,
    the original instruction is re-emitted to run eagerly, ``push`` placeholder
    values are pushed for its results, and tracing resumes after it.  The
    break itself happens on the *next* run of the analysis, driven by the
    speculation log.
    """

    def decorator(inner_fn):
        @functools.wraps(inner_fn)
        def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
            speculation = self.speculate()
            if speculation.failed:
                # This instruction already failed on a prior run: break now.
                assert speculation.reason is not None
                return handle_graph_break(self, inst, speculation.reason)
            try:
                TracingContext.set_current_loc(
                    self.f_code.co_filename, self.lineno, self.f_code.co_name
                )
                return inner_fn(self, inst)
            except Unsupported as excp:
                if self.should_compile_partial_graph() and self.has_backedge():
                    msg = (
                        "Skipping frame because there is a graph break in a for/while loop\n"
                        f"{self.frame_summary()}"
                    )
                    log.info(msg)
                    raise exc.SkipFrame(msg) from excp

                if self.generic_context_manager_depth > 0:
                    # We don't support graph break under GenericContextWrappingVariable,
                    # If there is, we roll back to the checkpoint and fall back.
                    excp.remove_from_stats()
                    unimplemented("Graph break under GenericContextWrappingVariable")

                if isinstance(excp, exc.UncapturedHigherOrderOpError):
                    raise

                if not self.should_compile_partial_graph():
                    raise

                log.debug("break_graph_if_unsupported triggered compile", exc_info=True)

                user_stack = excp.real_stack
                # TODO: Also report the traceback from the parent frame
                user_stack_formatted = "".join(traceback.format_list(user_stack))
                frame_loc = (user_stack[-1].filename, user_stack[-1].lineno)
                # torch._dynamo.explain() formats this a little nicer, and presents a slightly
                # more actionable user code pointer
                if (
                    graph_break_log.isEnabledFor(logging.DEBUG)
                    and not explain
                    and graph_break_dup_warning_checker.add(frame_loc)
                ):
                    graph_break_log.debug(
                        "Graph break: %s from user code at:\n%s",
                        excp,
                        user_stack_formatted,
                    )

                excp.remove_from_stats()
                excp.add_to_stats("graph_break")
                speculation.reason = GraphCompileReason(excp.msg, user_stack)
            # Outside the except block so the restart exception does not
            # chain off of the Unsupported that triggered it.
            speculation.fail_and_restart_analysis()

        def handle_graph_break(
            self: "InstructionTranslatorBase",
            inst: Instruction,
            reason: GraphCompileReason,
        ):
            self.output.compile_subgraph(self, reason=reason)
            cg = PyCodegen(self)
            cleanup: List[Instruction] = []
            # Reconstruct the context variables in the block stack
            for b in self.block_stack:
                assert b.with_context is not None
                self.output.add_output_instructions(
                    [
                        *b.with_context.reconstruct(cg),
                        *b.resume_fn().try_except(cg.code_options, cleanup),
                    ]
                )

            if sys.version_info >= (3, 11) and inst.opname == "CALL":
                kw_names = (
                    self.kw_names.as_python_constant()
                    if self.kw_names is not None
                    else ()
                )
                if len(kw_names) > 0:
                    self.output.add_output_instructions(
                        [create_instruction("KW_NAMES", argval=kw_names)]
                    )
                self.output.add_output_instructions(
                    create_call_function(inst.arg, False)
                )
            else:
                # copy instruction, but without exception table data
                assert inst.target is None
                inst_copy = copy.copy(inst)
                inst_copy.exn_tab_entry = None
                self.output.add_output_instructions([inst_copy])

            self.output.add_output_instructions(cleanup)

            if sys.version_info >= (3, 11) and inst.opname == "CALL":
                # stack effect for PRECALL + CALL is split between the two instructions
                stack_effect = dis.stack_effect(
                    dis.opmap["PRECALL"], inst.arg
                ) + dis.stack_effect(dis.opmap["CALL"], inst.arg)
            else:
                stack_effect = dis.stack_effect(inst.opcode, inst.arg)
            self.popn(push - stack_effect)

            for _ in range(push):
                self.push(UnknownVariable())
            self.output.add_output_instructions(
                self.create_call_resume_at(self.next_instruction)
            )

        return wrapper

    return decorator
class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState]):
# --- Per-instance interpreter state ---
output: OutputGraph
symbolic_locals: Dict[str, VariableTracker]  # local name -> tracked value
symbolic_globals: Dict[str, VariableTracker]  # global name -> tracked value
stack: List[VariableTracker]  # simulated Python value stack
instruction_pointer: Optional[int]  # index into instructions; None when done
current_instruction: Instruction
next_instruction: Optional[Instruction]
block_stack: List[BlockStackEntry]  # active with/exception blocks
lineno: int
kw_names: Optional[ConstantVariable]  # pending KW_NAMES for a 3.11 CALL
accept_prefix_inst: bool
prefix_insts: List[Instruction]
inline_depth: int  # 0 for the top-level frame; >0 while inlining
inconsistent_side_effects: bool  # see mark_inconsistent_side_effects()
current_speculation: Optional[SpeculationEntry]
# Recorded random() calls as (callable, args, kwargs) triples.
random_calls: List[
    Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
]
def mark_inconsistent_side_effects(self):
    """
    Record that tracing hit instructions which may cause dynamo to observe
    a different version of history than eager execution would.
    See: https://github.com/pytorch/pytorch/issues/110765
    """
    self.inconsistent_side_effects = True
def has_backedge(self):
    """Return True if any not-yet-executed jump targets an offset before the
    current instruction, i.e. we are inside a loop body."""
    current_offset = self.current_instruction.offset
    assert self.instruction_pointer is not None
    remaining = self.instructions[self.instruction_pointer :]
    return any(
        inst.opname in JUMP_OPNAMES and inst.argval < current_offset
        for inst in remaining
    )
def cell_and_freevars(self):
    """Return (and lazily cache) the tuple of cell + free variable names."""
    cached = getattr(self, "_cell_and_freevars", None)
    if cached is None:
        cellvars = tuple(self.code_options["co_cellvars"] or [])
        freevars = tuple(self.code_options["co_freevars"] or [])
        cached = cellvars + freevars
        self._cell_and_freevars = cached
    return cached
def prune_dead_locals(self):
    """Drop symbolic locals that liveness analysis proves are never read again."""
    live = livevars_analysis(self.instructions, self.current_instruction)
    # implicit use by super()
    # live = live | {"__class__"}
    # output variables?
    # Cell and free variables must always stay alive.
    live = live | set(self.cell_and_freevars())
    self.symbolic_locals = {
        name: var for name, var in self.symbolic_locals.items() if name in live
    }
    self.output.side_effects.prune_dead_object_new(self)
def call_function(
    self,
    fn: VariableTracker,
    args: List[VariableTracker],
    kwargs: Dict[str, VariableTracker],
):
    """Trace a call of ``fn`` with tracked args/kwargs; push the result."""
    assert isinstance(fn, VariableTracker)
    assert isinstance(args, list)
    assert isinstance(kwargs, dict)
    assert all(
        isinstance(x, VariableTracker)
        for x in itertools.chain(args, kwargs.values())
    )
    # Unwrap the underlying callable (if the tracker exposes one) so we can
    # refuse to trace anything on the forbidden list.
    target = None
    if hasattr(fn, "value"):
        target = fn.value
    if hasattr(fn, "fn"):
        target = fn.fn
    if target and callable(target) and is_forbidden(target):
        raise AssertionError(f"Attempt to trace forbidden callable {target}")
    self.push(fn.call_function(self, args, kwargs))
def inline_user_function_return(self, fn, args, kwargs):
    """
    A call to some user defined function by inlining it.
    """
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
def get_line_of_code_header(self, lineno=None):
    """Format '<file>:<line> in <code name>' plus optional funcname and
    inline-depth annotations, for trace logging."""
    if lineno is None:
        lineno = self.lineno
    depth_suffix = ""
    if self.inline_depth > 0:
        depth_suffix = f" (inline depth: {self.inline_depth})"
    funcname = get_funcname(self.f_code.co_filename, lineno)
    funcname_suffix = f" ({funcname})" if funcname is not None else ""
    return (
        f"{self.f_code.co_filename}:{lineno} in "
        f"{self.f_code.co_name}{funcname_suffix}{depth_suffix}"
    )
def get_log_starts_line_log_str(self):
    """Build the 'TRACE starts_line' message, including the source text."""
    header = self.get_line_of_code_header()
    source_line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip()
    return f"TRACE starts_line {header}\n {source_line}"
def log_starts_line(self):
    # LazyString defers formatting until the record is actually emitted.
    trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str))
def step(self):
    """Process exactly one instruction; return False if we should exit."""
    assert isinstance(self.instruction_pointer, int)
    inst = self.instructions[self.instruction_pointer]
    self.current_instruction = inst
    self.instruction_pointer += 1
    if self.instruction_pointer < len(self.instructions):
        self.next_instruction = self.instructions[self.instruction_pointer]
    else:
        # Ran off the end of the instruction stream.
        self.instruction_pointer = None
        self.next_instruction = None
    if inst.starts_line and self.lineno != inst.starts_line:
        self.lineno = inst.starts_line
        self.log_starts_line()

    if (
        len(self.stack) == 0
        and self.should_compile_partial_graph()
        and self.is_non_empty_graph()
    ):
        # An empty stack is a safe restart point; record a speculation so a
        # later failure can fall back to a graph break here.
        self.current_speculation = self.speculate()
        if self.current_speculation.failed:
            return self.step_graph_break(inst)

    log.debug("TRACE %s %s %s", inst.opname, inst.argval, self.stack)

    # 3.11 no longer uses a block stack, but we still keep track of one
    # so that we know which contexts are currently active.
    # For our purposes, all exception table entries with the same target
    # are considered to be part of the same "block".
    if sys.version_info >= (3, 11):
        entry = inst.exn_tab_entry
        if not (
            # still in the same block
            self.block_stack
            and entry
            and self.block_stack[-1].target is entry.target
        ):
            if not entry:
                # no longer in any block
                # It is possible for NOPs to be between two instructions
                # in the same block, but the NOPs are not covered by an
                # exception table entry. In this case, assume that we
                # are still in the same block.
                if self.block_stack and inst.opname != "NOP":
                    # If we really escape from a block and the current
                    # instruction is not in another block, then there
                    # should be no other nested blocks that we are in.
                    assert len(self.block_stack) == 1
                    self.block_stack.pop()
            elif (
                # current instruction is in the previous block
                len(self.block_stack) > 1
                and self.block_stack[-2].target is entry.target
            ):
                # exit the current block
                self.block_stack.pop()
            else:
                # current instruction is in a new block
                # push block to stack - note, BEFORE_WITH blocks won't
                # be pushed here since BEFORE_WITH pushes the block, and
                # the current instruction would be counted as being in that block.
                self.block_stack.append(
                    BlockStackEntry(entry.target, len(self.stack))
                )

    try:
        # Dispatch to the handler method named after the opcode.
        if not hasattr(self, inst.opname):
            unimplemented(f"missing: {inst.opname}")
        TracingContext.set_current_loc(
            self.f_code.co_filename, self.lineno, self.f_code.co_name
        )
        getattr(self, inst.opname)(inst)
        return inst.opname != "RETURN_VALUE"
    except Unsupported:
        if self.current_speculation is None:
            log.debug("empty checkpoint")
            raise
        log.debug("step triggered compile", exc_info=True)
        self.current_speculation.fail_and_restart_analysis()
def step_graph_break(self, continue_inst):
    """Compile the traced prefix, then emit a jump to ``continue_inst`` so
    the remainder of the frame runs via the original bytecode."""
    # generate code from checkpoint
    assert not self.output.output_instructions
    assert self.current_speculation is not None
    self.output.compile_subgraph(
        self,
        partial_convert=True,
        reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
    )
    self.output.add_output_instructions(
        [create_jump_absolute(continue_inst)] + self.instructions
    )
def run_ctx_mgr(self):
    """Context manager wrapping the main run() loop."""
    # NB: Don't push the top level frame summary; set_current_loc will
    # take care of it. However, DO make sure we attach real_stack to
    # exceptions
    return TracingContext.current_frame(None)
def run(self):
    """Main interpreter loop: step() until the frame returns, the output
    graph asks to exit, or an exception propagates."""
    with self.run_ctx_mgr():
        try:
            self.output.push_tx(self)
            while (
                self.instruction_pointer is not None
                and not self.output.should_exit
                and self.step()
            ):
                pass
        except BackendCompilerFailed:
            raise
        except Exception as e:
            # Attach the replay record for offline debugging when enabled.
            if config.replay_record_enabled:
                e.exec_record = self.exec_recorder.get_record()  # type: ignore[attr-defined]
            raise
        finally:
            self.output.pop_tx()

            # Cleanup the outputGraph to delete the held tensors. We perform the
            # cleanup only for InstructionTranslator and not
            # InliningInstructionTranslator. The InliningInstructionTranslator
            # mutates the output object and is restored to original state if
            # there was an exception.
            if isinstance(self, InstructionTranslator):
                self.output.cleanup()
def push(self, val: Optional[VariableTracker]):
    """Push a tracked value (or a None placeholder) onto the simulated stack."""
    assert val is None or isinstance(
        val, VariableTracker
    ), f"push expects VariableTracker, got {typestr(val)}"
    self.stack.append(val)
def push_many(self, vals: List[VariableTracker]):
    """Push each value in order onto the simulated stack."""
    for val in vals:
        self.push(val)
def pop(self) -> VariableTracker:
    """Pop and return the top of the simulated stack."""
    return self.stack.pop()
def popn(self, n: int) -> List[VariableTracker]:
    """Pop ``n`` values, returned in stack order (deepest first)."""
    assert n >= 0
    popped = [self.pop() for _ in range(n)]
    popped.reverse()
    return popped
def LOAD_FAST(self, inst):
    """Push the local variable named by the instruction."""
    name = inst.argval

    if name in self.f_locals and config.replay_record_enabled:
        self.exec_recorder.add_local_var(name, self.f_locals[name])

    if name.startswith(".") and name not in self.symbolic_locals:
        # This happens in dict/list comprehensions
        name = name.replace(".", "implicit")
    assert name not in self.cell_and_freevars()
    if name not in self.symbolic_locals:
        unimplemented("undefined LOAD_FAST")
    self.push(self.symbolic_locals[name])
    if name.startswith("___stack"):
        # ___stack* locals are dropped after first use -- presumably
        # single-use values passed in by resume functions; TODO confirm
        # against resume_execution.
        self.symbolic_locals.pop(name)
def LOAD_DEREF(self, inst):
    """Push the cell/free variable named by the instruction."""
    assert inst.argval in self.cell_and_freevars()

    if inst.argval in self.f_locals and config.replay_record_enabled:
        self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval])

    if inst.argval not in self.symbolic_locals:
        unimplemented(f"undefined LOAD_DEREF {inst.argval}")
    self.push(self.symbolic_locals[inst.argval])
def STORE_FAST(self, inst):
    """Bind the popped value to a local name in symbolic_locals."""
    loaded_vt = self.pop()
    name = inst.argval
    # Only rename at the top-level scope, this is to avoid the confusion between
    # mutating a variable vs renaming it (e.g. a = b) during speculating a higher order op,
    # where mutation is prohibited and it's difficult to differentiate it with renaming.
    if _is_top_level_scope(current_scope_id()):
        loaded_vt = loaded_vt.rename(self, name)
    self.symbolic_locals[name] = loaded_vt
def DELETE_FAST(self, inst):
    """Remove the named local from symbolic_locals."""
    del self.symbolic_locals[inst.argval]

# Writes to cell/free variables are modeled exactly like local writes.
STORE_DEREF = STORE_FAST
def LOAD_CLOSURE(self, inst):
    # Cells are modeled symbolically: push a placeholder naming the cell.
    self.push(ClosureVariable(name=inst.argval))
def LOAD_CONST(self, inst):
    """Push the instruction's constant; empty tuples are represented as a
    TupleVariable rather than a ConstantVariable."""
    const = inst.argval
    if isinstance(const, tuple) and len(const) == 0:
        self.push(TupleVariable([]))
    else:
        self.push(ConstantVariable.create(value=const))
def get_global_source(self, name):
    """Build a Source for global ``name``, handling frames whose f_globals
    differ from the output graph's global scope."""
    source: Source
    if self.output.global_scope is self.f_globals:
        source = GlobalSource(name)
    else:
        if "__name__" in self.f_globals:
            # Reference the global through its owning module: <module>.<name>
            source = AttrSource(
                self.import_source(self.f_globals["__name__"]), name
            )
        else:
            # No module name available: install the globals dict itself
            # under a mangled alias and index into it.
            mangled_name = f"___unnamed_scope_{id(self.f_globals)}"
            if mangled_name not in self.output.global_scope:
                self.output.install_global(mangled_name, self.f_globals)
            source = GetItemSource(GlobalSource(mangled_name), name)
    return source
def LOAD_GLOBAL(self, inst):
    """Push the global (falling back to builtin) named by the instruction."""
    if sys.version_info >= (3, 11):
        # On 3.11 the low bit of arg means "push NULL first" (per the dis
        # docs for LOAD_GLOBAL).
        if inst.arg % 2:
            self.PUSH_NULL(inst)

    name = inst.argval

    if config.replay_record_enabled:
        if name in self.f_globals:
            self.exec_recorder.add_global_var(name, self.f_globals[name])
        else:
            assert name in self.f_builtins
            self.exec_recorder.builtins[name] = self.f_builtins[name]

    if inst.argval == "AssertionError":
        unimplemented("assert with non-string message")

    if name in self.symbolic_globals:
        # A prior STORE_GLOBAL tracked this name; read it back through the
        # side-effect tracker instead of the real f_globals.
        variable = self.output.side_effects[self.symbolic_globals[name]]
        self.push(self.output.side_effects.load_global(variable, name))
        return

    try:
        value = self.f_globals[name]
    except KeyError:
        return self.load_builtin(inst)

    source = self.get_global_source(name)
    self.push(VariableBuilder(self, source)(value))
def STORE_GLOBAL(self, inst):
    """Record a write to a global via the side-effect tracker."""
    value = self.pop()
    name = inst.argval
    source = self.get_global_source(name)
    if name not in self.symbolic_globals:
        # The sentinel identifies this global slot across loads/stores;
        # the tracked value itself lives in output.side_effects.
        self.symbolic_globals[name] = object()  # sentinel object
    variable = self.output.side_effects.track_global_existing(
        source, self.symbolic_globals[name]
    )
    if isinstance(value, RemovableHandleVariable):
        unimplemented("Storing handles in globals - NYI")
    self.output.side_effects.store_global(variable, name, value)
def import_source(self, module_name):
    """Create an alias to a module for use in guards"""
    if "torch_package" in module_name:
        value = torch.package.package_importer._package_imported_modules[
            module_name
        ]
        # Mangle the torch.package module name into a valid identifier.
        alias = module_name.translate(
            str.maketrans({">": "_", "<": "_", ".": "_dot_"})
        )
    else:
        value = importlib.import_module(module_name)
        alias = f"__import_{module_name.replace('.', '_dot_')}"
    f_globals = self.output.global_scope
    assert alias not in f_globals or f_globals[alias] is value
    f_globals[alias] = value
    self.output.update_co_names(alias)
    return GlobalSource(alias)
def resolve_name(self, name, package, level):
    """
    Copied from the Cpython implementation of __import__
    Resolve a relative module name to an absolute one.
    https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902
    """
    parts = package.rsplit(".", level - 1)
    if len(parts) < level:
        raise ImportError("attempted relative import beyond top-level package")
    root = parts[0]
    if not name:
        return root
    return f"{root}.{name}"
def calc_package(self):
    """
    Copied from the Cpython implementation of __import__
    https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090
    """
    package = self.f_globals.get("__package__")
    spec = self.f_globals.get("__spec__")
    if package is not None:
        # __package__ wins, but warn if it disagrees with __spec__.
        if spec is not None and package != spec.parent:
            log.warning(
                "__package__ != __spec__.parent (%r != %r)",
                package,
                spec.parent,
                stacklevel=3,
            )
        return package
    elif spec is not None:
        return spec.parent
    else:
        # Neither available: derive the package from __name__, stripping
        # the last component unless this is itself a package (__path__ set).
        log.warning(
            "can't resolve package from __spec__ or __package__, "
            "falling back on __name__ and __path__",
            stacklevel=3,
        )
        package = self.f_globals["__name__"]
        if "__path__" not in self.f_globals:
            package = package.rpartition(".")[0]
    return package
def IMPORT_NAME(self, inst):
level, fromlist = self.popn(2)
level = level.as_python_constant()
fromlist = fromlist.as_python_constant()
module_name = inst.argval
# Are we replaying? if so, load recorded module
recorded_name = (
f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}"
)
if recorded_name in self.f_globals:
value = self.f_globals[recorded_name]
source = GlobalSource(recorded_name)
else:
value = __import__(
module_name,