/
aot_autograd.py
2898 lines (2596 loc) · 126 KB
/
aot_autograd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import collections
import dataclasses
import itertools
import logging
import warnings
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from functorch import make_fx
import torch
import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import dynamo_timed
from torch._subclasses import CrossRefFakeMode, FakeTensor, FakeTensorMode
from torch.fx import immutable_collections, Interpreter
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.multiprocessing.reductions import StorageWeakRef
from torch.nn.utils import stateless
from . import config
from .partitioners import default_partition
from torch._guards import TracingContext, DuplicateInputs
log = logging.getLogger(__name__)


class MutationType(Enum):
    """Kinds of mutation: none, metadata-only, data-only, or data and metadata."""

    none = 1
    metadata_only = 2
    data = 3
    data_and_metadata = 4
class OutputType(Enum):
    """How a user output of the traced forward relates to inputs and intermediates."""

    # The output does not alias anything.
    non_alias = 1
    # The output aliases one of the inputs.
    alias_of_input = 2
    # The output **is** one of the input tensors.
    is_input = 3
    # The output has a ._base tensor, which is a graph intermediate.
    # Its ._base has to be returned as a graph output too, so that its
    # requires_grad info is populated correctly.
    # Instructs the runtime code to regenerate the current output
    # from a base tensor, graph_intermediates[base_idx]
    alias_of_intermediate_save_as_output = 4
    # Same as above, except the ._base does not need to be explicitly added
    # as a graph output, because it already **is** a graph output.
    alias_of_intermediate = 5
    # Same as above, except the output's ._base is **already** a user output.
    # Instructs the runtime code to regenerate the current output from
    # a base tensor, user_outputs[base_idx]
    alias_of_intermediate_base_is_user_output = 6
# Teach pytree how to flatten / unflatten the immutable containers used by torch.fx.
def _flatten_immutable_list(items):
    # The children are the list elements; no extra context is needed.
    return list(items), None


def _unflatten_immutable_list(children, _context):
    return immutable_collections.immutable_list(children)


def _flatten_immutable_dict(mapping):
    # The children are the values; the keys are carried as the context.
    return list(mapping.values()), list(mapping.keys())


def _unflatten_immutable_dict(children, keys):
    return immutable_collections.immutable_dict(dict(zip(keys, children)))


pytree._register_pytree_node(
    immutable_collections.immutable_list,
    _flatten_immutable_list,
    _unflatten_immutable_list,
)
pytree._register_pytree_node(
    immutable_collections.immutable_dict,
    _flatten_immutable_dict,
    _unflatten_immutable_dict,
)
aten = torch.ops.aten

# Global counter that is bumped every time a graph is compiled with AOTAutograd.
# It lets you correlate a runtime error message with compile time: if an error
# at runtime says that compiled graph 3 failed, you can set a breakpoint at
# compile time for that graph number and investigate further there.
#
# NB: this is not the same thing as get_aot_compilation_context, which tracks
# each underlying graph that is compiled. AOT_COUNTER instead corresponds to
# top-level invocations of aot_module/aot_function: one counter value is
# allocated per entire compiled block (and that block may involve compiling
# multiple subgraphs, e.g. for forwards/backwards).
AOT_COUNTER = itertools.count()

# Types accepted as flat arguments: tensors, Python scalars/strings/None, and symbolic types.
KNOWN_TYPES = (torch.Tensor, int, str, float, bool, type(None), *py_sym_types)
@contextmanager
def preserve_rng_state():
    """Snapshot the RNG state on entry and restore it on exit.

    The CPU generator state is always saved. The CUDA generator state is saved
    only when CUDA is available at entry. Both are restored in a ``finally``
    clause, so the state is put back even if the body raises.
    """
    saved_cpu_state = torch.clone(torch.random.get_rng_state())
    # Decide once, at entry, whether there is CUDA state to carry. Restoring is
    # keyed off what was actually saved, so the exit path can never touch an
    # unbound variable if CUDA availability were to differ between the two checks.
    saved_cuda_state = (
        torch.clone(torch.cuda.get_rng_state())
        if torch.cuda.is_available()
        else None
    )
    try:
        yield
    finally:
        torch.random.set_rng_state(saved_cpu_state)
        if saved_cuda_state is not None:
            torch.cuda.set_rng_state(saved_cuda_state)
# Set up hooks so that during backward the fx's stack_trace is properly set.
# `callback_set` records whether the "restore the stack trace" callback has
# already been queued on the autograd engine.
callback_set = False


def setup_stacktrace_preservation_hooks(roots: List):
    """Register hooks on every autograd node reachable from ``roots``.

    The graph is walked breadth-first through ``next_functions``. For each node:
      - a pre-hook sets the fx stack trace to the node's ``"traceback_"``
        metadata (an empty list if absent). The first pre-hook to fire while
        ``callback_set`` is False also queues an engine callback that puts back
        the stack trace captured at that moment and resets ``callback_set``.
      - a post-hook sets the fx stack trace to the same forward stack with an
        extra "Gradient addition node ..." marker line appended.

    Args:
        roots: autograd nodes to start from; ``None`` entries are ignored.
    """

    def walk_autograd_graph(start_nodes):
        # Breadth-first traversal that yields each reachable node exactly once.
        if not start_nodes:
            return
        visited = set()
        pending = collections.deque()
        for root in start_nodes:
            if root is None:
                continue
            visited.add(root)
            pending.append(root)

        while pending:
            current = pending.popleft()
            for next_fn, _ in current.next_functions:
                if next_fn is None or next_fn in visited:
                    continue
                visited.add(next_fn)
                pending.append(next_fn)
            yield current

    def make_restore_callback(stack_to_restore):
        def restore():
            global callback_set
            fx_traceback.set_stack_trace(stack_to_restore)
            callback_set = False

        return restore

    def make_prehook(forward_stack):
        def prehook(grad_output):
            global callback_set

            if not callback_set:
                engine = torch.autograd.variable.Variable._execution_engine
                engine.queue_callback(
                    make_restore_callback(fx_traceback.format_stack())
                )
                callback_set = True

            fx_traceback.set_stack_trace(forward_stack)

        return prehook

    def make_posthook(accumulation_stack):
        def posthook(grad_input, grad_output):
            fx_traceback.set_stack_trace(accumulation_stack)

        return posthook

    for node in walk_autograd_graph(roots):
        forward_stack = node.metadata.get("traceback_", [])
        node.register_prehook(make_prehook(forward_stack))

        # Built as a new list so the forward stack itself is left unmodified.
        accumulation_stack = forward_stack + [
            "Gradient addition node due to multiple use of tensor around:"
        ]
        node.register_hook(make_posthook(accumulation_stack))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# AOT Autograd contains a pretty non-trivial amount of logic to handle edge cases around aliasing and mutation
# that are external to the graph (they show up as side effects in some way when you run the graph).
#
# Take a look at the `test_aotdispatch.py TestAOTAutograd.test_input_mutation*` tests for some example functions
# and what their compiled graphs look like.
# Below is a very long comment detailing several edge cases, and showing how AOT Autograd handles them.
#
# Note [AOT Autograd: input data mutations]
#
# If we compile a function that mutates inputs, then those input mutations are real side effects
# that a user expects to see after running the compiled graph.
# However, the graph that we want to send to a backend needs to be *entirely* functional.
# The way we reconcile this difference is that we remove the mutations completely from the graph that we compile
# but we update the graph to return (updated_inputs, user_outputs).
# In the epilogue that runs after the compiled graph is executed, we copy the updated inputs back to the originals.
#
# Example: original user code:
# def f(x):
# x.mul_(2)
# out = x.mul(3)
# return out
#
# After AOT Autograd compiles, we end up with a:
# (a) compiled graph
# (b) autograd.Function.forward() method, that executes the compiled graph
# (c) wrapper function, that calls the autograd.Function.forward() and performs the epilogue
#
# The output of (a, b, c) are all written below.
#
# def compiled_forward_graph(x):
# x_updated = x.mul(2)
# out = x_updated.mul(3)
# return x_updated, out
#
# # x_updated gets a gradient in the compiled backward
# def compiled_backward_graph(grad_x_updated, grad_out):
# grad_x = ...
# return grad_x
#
# def autograd.Function.forward(x):
# x_updated, out = compiled_forward_graph(x)
# return x_updated, out
#
# def compiled_wrapper(x):
# x_updated, out = autograd.Function.apply(x)
# x.copy_(x_updated)
# return out
#
# Another important thing to note is that updated inputs (due to data mutations) *do* participate
# in the compiled backward graph! Since the compiled forward graph gets N extra outputs
# (due to updated inputs showing up as graph outputs),
# The compiled backward gets an additional N inputs.
# That way, during the x.copy_(x_updated) bit in the epilogue, gradients will flow from the updated input
# back to the original input.
# Note [AOT Autograd: input metadata mutations]
#
# For the same reason as input mutations, we also don't put input metadata mutations in the graph.
# Instead, we return the updated version of the input (a view), and mutate the input's metadata outside of the graph
#
# Example: original user code:
# def f(x):
# x.t_()
# out = x.mul(3)
# return out
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
# x_updated = x.t()
# out = x_updated.mul(3)
# return x_updated, out
#
# # x_updated does *not* get a gradient in the compiled backward
# def compiled_backward_graph(grad_out):
# grad_x = ...
# return grad_x
#
# def autograd.Function.forward(x):
# x_updated, out = compiled_forward_graph(x)
# return x_updated, out
#
# def compiled_wrapper(x):
# x_updated, out = autograd.Function.apply(x)
# x.as_strided_(x_updated)
# return out
# Note [AOT Autograd: outputs aliasing inputs or intermediates!]
#
# AOT Autograd needs special handling for outputs that alias graph inputs or intermediates!
# Why?
# (1) autograd.Function.forward() has a limitation, where views that are returned in the forward cannot later be mutated.
# (2) views don't need to be compiled in the graph anyway - it's cheap to generate them outside of the compiled graph,
# in an epilogue.
# For outputs that alias inputs, we do the following:
# (a) *still* return the aliased output as a graph output
# (b) In the AOT Autograd wrapper/epilogue, we don't return that aliased output. Instead, we use it to regenerate the output.
#
# For outputs that alias *intermediates*, we do the following:
# (a) Return the output in the compiled forward, **and** return its ._base (a graph intermediate) as an output in the forward
# (b) Use (output, graph_intermediate) to regenerate the alias, and return that to the user (instead of the compiled fw output).
# You might wonder why we return the aliased output directly in the graph (and making the graph compute it),
# only to not return it and instead generate a fresh alias off of the intermediate,
# instead of (say) just storing metadata about the size/stride of the output somewhere to generate the alias. There are two reasons:
# (1) Getting the actual alias tensor allows us to use view-replay to generate the alias, instead of an as_strided() call
# (2) Inductor (and other backends) are free to change the memory format of graph outputs, if it results in better performance.
# This can result in problems if a user later tries to .view() that output expecting it to have one set of strides,
# when it has a different set of strides.
# By including the view op directly in the graph, inductor takes that into account when deciding what memory format
# the graph intermediate should be.
#
# Another important thing to note is how our traced backward() graph handles aliases.
# (this applies to outputs aliasing inputs, outputs aliasing intermediates,
# *and* updated inputs returned in the compiled forward due to metadata-only mutations).
# Any outputs that alias (either inputs or intermediates) do NOT participate in the compiled backward graph
# It would be wasteful to include them in the compiled backward(), because we regenerate them eagerly
# at the end of the forward.
#
# Example: original user code:
# def f(x):
# out1 = x.t()
# intermediate = x.mul(2)
# out2 = intermediate.view(-1)
# return out1, out2
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
# out1 = x.t()
# intermediate = x.mul(2)
# out2 = intermediate.view(-1)
# # the compiled graph also returns the intermediate
# return out1, out2, intermediate
#
# # intermediate gets a gradient in the compiled backward.
# # both output aliases (out1 and out2) do not.
# def compiled_backward_graph(grad_intermediate):
# grad_x = ...
# return grad_x
#
# def autograd.Function.forward(x):
# out1, out2, intermediate = compiled_forward_graph(x)
# return out1, out2, intermediate
#
# def compiled_wrapper(x):
# out1, out2, intermediate = autograd.Function.apply(x)
# # regenerate out1 from the input
# out1_regenerated = out1._view_func(x)
# # regenerate out2 from the intermediate
# out2_regenerated = out2._view_func(intermediate)
# return out1_regenerated, out2_regenerated
# Note [AOT Autograd: mutations to inputs that alias other inputs]
#
# Another edge case that is (only partially) handled today is when an input is mutated, but itself aliases another input.
# AOT Autograd needs to **ensure** that functionalization knows that the two inputs are aliased to each other.
# That way, when the aliased input is accessed later in the graph, functionalization knows to "update" the alias
# given the mutation that occurred.
#
# This is handled by updating the calling convention: we create a "synthetic base" that becomes a new input
# in the compiled function, and we regenerate the original (aliased) inputs directly off of the base
# inside of the compiled function.
#
# See merge_view_inputs() for more detailed info.
#
# Example: original user code:
# def f(x, x_view):
# x.mul_(2)
# out = x * x_view
# return out
# f(x, x.view(-1))
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(base):
# x = generate_x(base)
# x_view = generate_x_view(base)
# x_updated = x.mul(2)
# x_view_updated = x_updated.view(-1)
# out = x_updated * x_view_updated
# return x_updated, out
#
# # The calling convention change from (aliases) -> (base) happens
# # *outside* of the autograd.Function.forward().
# # That means the forward() only has 1 input (base),
# # and the backward() only has 1 output (grad_base)
# def compiled_backward_graph(grad_out):
# grad_base = ...
# return grad_base
#
# def autograd.Function.forward(base):
# x_updated, out = compiled_forward_graph(base)
# return x_updated, out
#
# # The compiled wrapper is where we create synthetic bases.
# # The info on which inputs are mutated is also tracked *before* synthetic base creation.
# def compiled_wrapper(x, x_view):
# base = merge_view_inputs(x, x_view)
# x_updated, out = autograd.Function.apply(base)
# # x and x_view are aliased in eager mode, so this mutation to x will automatically affect x_view.
# x.copy_(x_updated)
# return out
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# This class stores info about every user output.
# One instance describes a single output: how (if at all) it aliases an input
# or an intermediate, its Python type, and where to find the base it aliases.
@dataclass(frozen=True)
class OutputAliasInfo:
    # Tells us if this output is:
    # (1) a regular (non-aliased) output
    # (2) an alias of a forward input
    # (3) **is** a forward input (special case of "alias_of_input")
    # (4) an alias of an intermediate (aka an alias of an output of the inner traced forward)
    # (5) an alias of an intermediate, that explicitly requires returning the intermediate
    #     as a graph output
    # (6) an alias of an intermediate, where that intermediate is also a user output
    output_type: OutputType
    # The raw type of the output (torch.Tensor, SymInt, etc)
    raw_type: type
    # If (1) above, then
    # - base_idx is None
    # If (2) or (3) above, then
    # - Tells us that the base of this alias is user_fwd_input[base_idx]
    #   (This is an index into the inputs *before* we make synthetic bases)
    # If (4) or (5) above, then
    # - Tells us that the base of this alias is output_graph_intermediates[base_idx]
    #   here, this refers to the index into the list of intermediate bases
    #   that are returned as extra outputs of the traced forward
    # If (6) above, then:
    # - Tells us that the base of this alias is output_user_fwds[base_idx]
    #   here, this refers to the index into the user outputs of the traced forward
    base_idx: Optional[int]
# This class tells us info about user inputs.
# One instance describes a single input of the forward function.
@dataclass(frozen=True)
class InputAliasInfo:
    # True when the input is a tensor and it is an autograd leaf (`is_leaf`).
    is_leaf: bool
    # True when running the forward mutated the input's data.
    mutates_data: bool
    # True when running the forward mutated the input's metadata
    # (its size, stride or storage offset).
    mutates_metadata: bool
# This class encapsulates all aliasing + mutation info we need about the forward graph
# See a more detailed overview of the edge case handling at
# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit
@dataclass()
class ViewAndMutationMeta:
    # length = # user inputs
    # This gives us info about every input, and what sort of mutation happened to it (if any)
    input_info: List[InputAliasInfo]

    # length = # user outputs
    # This gives us info about every output (mostly around whether it aliases other tensors)
    output_info: List[OutputAliasInfo]

    # length = # mutated inps + # user outputs
    # For every output *and* mutated input returned from the forward,
    # tells us whether or not the output should require gradients or not
    requires_grad_info: List[bool]

    # length = the number of intermediate bases appended as outputs to the end of the forward graph.
    # Note: this is not necessarily the same thing as:
    #   len([x for x in output_info if x.output_type == OutputType.alias_of_intermediate])
    # Because outputs might share a ._base, or an output's ._base might itself be
    # another user output (in both cases, we won't redundantly append bases to the end of the graph)
    num_intermediate_bases: int

    # For inference only: instructs us to keep data-only input mutations directly in the graph.
    # This is a flag, so it is annotated as bool.
    keep_input_mutations: bool

    # These are the FakeTensor (or potential SymInt) outputs that we traced from our
    # metadata pass of the user's forward function.
    # Their only use today is to pass them as a best-guess for tangents when tracing the joint.
    # Stashing them as part of our "metadata" makes it simpler if we want to run our analysis
    # pass once, and re-use the output throughout AOTAutograd
    traced_tangents: List[Any]

    def __post_init__(self):
        # Pre-computed here for perf: indices into input_info of the inputs whose
        # mutation is handled outside the graph. Metadata mutations always count.
        # Data mutations count only when keep_input_mutations is not set, because
        # when it is set they are kept directly in the graph and the epilogue
        # does not need to handle them.
        self.mutated_inp_indices: List[int] = [
            i
            for i, m in enumerate(self.input_info)
            if m.mutates_metadata
            or (not self.keep_input_mutations and m.mutates_data)
        ]

        # Pre-computed here for perf: indices into output_info of every output
        # that is an alias (of either an input or an intermediate).
        self.aliased_out_indices: List[int] = [
            i
            for i, m in enumerate(self.output_info)
            if m.output_type != OutputType.non_alias
        ]
# This class exists because:
# - the autograd.Function.forward() in aot autograd returns outputs that might alias inputs
# - we only care about the metadata on those aliases, so we can regenerate them.
#   We do not want them to participate in the autograd.Function.
# We do that by wrapping them in an opaque class, so the autograd.Function
# does not know to treat them as tensors.
@dataclass(frozen=True)
class TensorAlias:
    # The aliased tensor being wrapped.
    alias: torch.Tensor
def has_same_metadata(t1, t2):
    """Return whether two tensors agree on size, stride and storage offset."""
    if not t1.size() == t2.size():
        return False
    if not t1.stride() == t2.stride():
        return False
    return t1.storage_offset() == t2.storage_offset()
def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires_grad):
    """Create an alias of ``aliased_base_tensor`` that looks like ``target_meta_tensor``.

    View-replay is attempted first, when ``target_meta_tensor`` has a ``._base``.
    If that is not possible, or produces the wrong shape, the alias is built
    with ``as_strided()`` from the target's size, stride and storage offset.

    Args:
        aliased_base_tensor: the tensor the alias is generated from.
        target_meta_tensor: the tensor whose view / layout metadata is reproduced.
        target_requires_grad: whether the returned alias should require grad.

    Returns:
        The alias, with its requires_grad flag made consistent with
        ``target_requires_grad``.
    """

    def _match_requires_grad(alias):
        # The requires-gradness of the alias can differ from that of the base.
        base_requires_grad = aliased_base_tensor.requires_grad
        if base_requires_grad and not target_requires_grad:
            return alias.detach()
        if target_requires_grad and not base_requires_grad:
            alias.requires_grad_(True)
        return alias

    original_base = target_meta_tensor._base
    if original_base is not None:
        # The base we replay the view off of might have a different shape than
        # the view's original base. Only call as_strided when something really
        # changed: as_strided's backward is poorly implemented and slow.
        replay_base = aliased_base_tensor
        layout_differs = aliased_base_tensor is not original_base and (
            aliased_base_tensor.size() != original_base.size()
            or aliased_base_tensor.stride() != original_base.stride()
            or aliased_base_tensor.storage_offset() != original_base.storage_offset()
        )
        if layout_differs:
            replay_base = aliased_base_tensor.as_strided(
                original_base.size(),
                original_base.stride(),
                original_base.storage_offset(),
            )
        replayed = target_meta_tensor._view_func(replay_base)
        # A shape mismatch can happen due to a bug in inplace/view handling in autograd.
        # Try putting a breakpoint here and running
        # `test/functorch/test_aotdispatch TestAOTAutograd.test_output_all_alias_types`
        # Also, https://github.com/pytorch/pytorch/issues/49825
        #
        # As a stopgap, anything else falls through to the as_strided path below.
        if replayed is not None and replayed.shape == target_meta_tensor.shape:
            return _match_requires_grad(replayed)

    size = target_meta_tensor.size()
    stride = target_meta_tensor.stride()
    storage_offset = target_meta_tensor.storage_offset()

    # Bridge a complex <-> real dtype mismatch before taking the strided view.
    base_is_complex = aliased_base_tensor.is_complex()
    target_is_complex = target_meta_tensor.is_complex()
    if base_is_complex and not target_is_complex:
        strided_source = torch.view_as_real(aliased_base_tensor)
    elif target_is_complex and not base_is_complex:
        strided_source = torch.view_as_complex(aliased_base_tensor)
    else:
        strided_source = aliased_base_tensor

    return _match_requires_grad(
        strided_source.as_strided(size, stride, storage_offset)
    )
def to_fun(t):
    """Wrap a tensor in a functional tensor; return anything else unchanged."""
    if not isinstance(t, Tensor):
        return t
    return torch._to_functional_tensor(t, mirror_autograd_meta=True)
def from_fun(t):
    """Unwrap a functional tensor; return anything else unchanged.

    A functional tensor is synced before the tensor it wraps is returned.
    """
    is_functional_tensor = isinstance(t, Tensor) and torch._is_functional_tensor(t)
    if not is_functional_tensor:
        return t
    torch._sync(t)
    return torch._from_functional_tensor(t)
# This is a version of functionalization that is specifically designed
# for the AOTAutograd use case.
#
# Unlike functorch's variant, this doesn't use the functorch level system,
# instead it directly uses PyTorch's conventional dispatcher to hit the
# functionalization key.  In particular, this means that FunctionalTensorWrapper
# can have autograd data stored directly on it.
#
# In typical AOTAutograd usage, the dispatch key order will look like:
#
#   Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor
#       outer tensor                        inner tensor
#
# The wrapper returned by this function produces a ViewAndMutationMeta, telling us
# metadata about the inputs and outputs. Its `traced_tangents` field holds the outputs
# from the forward, but **only** the ones that we need to pass in as tangents into the
# backward. Specifically, aliased outputs from the forward get regenerated, and don't
# participate in the compiled backward function.
def run_functionalized_fw_and_collect_metadata(
    f,
    *,
    keep_input_mutations: bool
) -> Callable[..., ViewAndMutationMeta]:
    """Build a wrapper that runs ``f`` under functionalization and collects metadata.

    Args:
        f: the forward function. It must already take a flat list of arguments
            and return a flat list of outputs.
        keep_input_mutations: stored as-is on the resulting metadata.

    Returns:
        A function ``inner(*flat_args)`` that runs ``f`` once on functional-tensor
        versions of ``flat_args`` and returns a ``ViewAndMutationMeta``. Note that
        the wrapper is returned, not the metadata itself.
    """
    # Maps each input tensor to its functional wrapper, so the same tensor passed
    # several times is wrapped only once.
    memo = {}

    def to_fun(t):
        if isinstance(t, Tensor):
            if t in memo:
                return memo[t]
            r = torch._to_functional_tensor(t, mirror_autograd_meta=True)
            memo[t] = r
            return r
        else:
            return t

    def from_fun(t):
        if not isinstance(t, Tensor) or not torch._is_functional_tensor(t):
            return t
        torch._sync(t)
        return torch._from_functional_tensor(t)

    @wraps(f)
    def inner(*flat_args):
        # This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args.
        assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)

        input_info: List[InputAliasInfo] = []
        output_info: List[OutputAliasInfo] = []
        input_requires_grad_info: List[bool] = []
        output_requires_grad_info: List[bool] = []

        flat_f_args = pytree.tree_map(to_fun, flat_args)

        torch._enable_functionalization(reapply_views=True)
        try:
            # precondition: The passed in function already handles unflattening inputs + flattening outputs
            flat_f_outs = f(*flat_f_args)
        finally:
            torch._disable_functionalization()

        # Inspect the state of the input tensor functional wrapper to detect input mutation info.
        # For every tensor input, sync its functional wrapper and unwrap it: if the unwrapped
        # tensor is a different object than the original input, the input was mutated.
        for arg, f_arg in zip(flat_args, flat_f_args):
            if not isinstance(arg, Tensor):
                new_arg = arg
            else:
                torch._sync(f_arg)
                new_arg = torch._from_functional_tensor(f_arg)
            if arg is not new_arg:
                if StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(new_arg.untyped_storage()):
                    # Same storage, different tensor: a metadata-only mutation.
                    mutates_data = False
                    mutates_metadata = True
                else:
                    # Different storage: the data was mutated, and possibly the metadata too.
                    mutates_data = True
                    mutates_metadata = not has_same_metadata(arg, new_arg)
                # Only track requires_grad info on *mutated* inputs,
                # because they show up in the autograd.Function.forward as outputs
                input_requires_grad_info.append(
                    isinstance(f_arg, torch.Tensor) and f_arg.requires_grad
                )
            else:
                mutates_data = False
                mutates_metadata = False

            input_info.append(InputAliasInfo(
                is_leaf=isinstance(arg, torch.Tensor) and arg.is_leaf,
                mutates_data=mutates_data,
                mutates_metadata=mutates_metadata
            ))

        # Number of inputs with any kind of mutation; checked against
        # requires_grad_info further down.
        num_mutated_inps = len(
            [x for x in input_info if x.mutates_data or x.mutates_metadata]
        )

        # If a function involves creating a tensor, and returning a view of it, such that its _base is the intermediate,
        # We need to make sure our graph returns the _base as a graph output, and we manually recreate the view
        # to return to the user. Why? The backend compiler is free to (incorrectly) not set requires_grad
        # on the base tensor, but we are obligated to properly set requires-gradness on the real output.
        inp_storage_refs = {
            StorageWeakRef(inpt.untyped_storage()): idx
            for idx, inpt in enumerate(flat_f_args)
            if isinstance(inpt, torch.Tensor)
        }
        # We need inp tensor id's to be able to tell if any outputs **are** inputs.
        inp_tensor_ids = {
            id(inpt) for inpt in flat_f_args if isinstance(inpt, torch.Tensor)
        }
        # We need output tensor id's to tell if any `output._base` attributes **are** other outputs.
        # (This is also a dict because we need to know that output's index, so we can regenerate
        # the alias from it).
        out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)}
        # maps the id of an intermediate base to its index in the output of the compiled forward
        intermediate_base_tensor_id_to_output_idx: Dict[int, int] = {}
        intermediate_bases: List[torch.Tensor] = []
        for o in flat_f_outs:
            if (
                isinstance(o, torch.Tensor)
                and StorageWeakRef(o.untyped_storage()) in inp_storage_refs
            ):
                # The output shares its storage with an input.
                base_idx = inp_storage_refs[StorageWeakRef(o.untyped_storage())]
                is_input_tensor = id(o) in inp_tensor_ids
                if is_input_tensor:
                    output_type = OutputType.is_input
                else:
                    output_type = OutputType.alias_of_input

            # We only need to handle the intermediate base case when both
            # the intermediate base and the output require gradients.
            # See Note [AOT Autograd: outputs aliasing inputs or intermediates!]
            elif (
                isinstance(o, torch.Tensor)
                and o._base is not None
                and o.requires_grad
                and o._base.requires_grad
            ):
                # First, check if o's ._base is an existing output
                maybe_existing_out_idx = out_tensor_ids.get(id(o._base), None)
                if maybe_existing_out_idx is not None:
                    # Special case where the output is an alias of a graph intermediate, but that intermediate
                    # is itself also a user output.
                    output_type = OutputType.alias_of_intermediate_base_is_user_output
                    base_idx = maybe_existing_out_idx
                else:
                    # Next, check if o's ._base is an intermediate base that we already returned
                    maybe_existing_base_output_idx = intermediate_base_tensor_id_to_output_idx.get(
                        id(o._base), None
                    )
                    if maybe_existing_base_output_idx is not None:
                        output_type = OutputType.alias_of_intermediate
                        base_idx = maybe_existing_base_output_idx
                    else:
                        # Otherwise, take o._base and explicitly return it as an output in the compiled graph
                        new_out_idx = len(intermediate_bases)
                        base_idx = new_out_idx
                        # Indicate to the logic later on (when we trace the joint)
                        # that this particular output should get its ._base appended to the forward graph outputs
                        output_type = OutputType.alias_of_intermediate_save_as_output
                        intermediate_base_tensor_id_to_output_idx[id(o._base)] = new_out_idx
                        intermediate_bases.append(o._base)
            else:
                output_type = OutputType.non_alias
                base_idx = None

            out_info = OutputAliasInfo(
                output_type=output_type,
                raw_type=type(o),
                base_idx=base_idx,
            )
            output_info.append(out_info)
            output_requires_grad_info.append(
                isinstance(o, torch.Tensor) and o.requires_grad
            )

        # Our autograd.Function.forward returns both mutated inputs and outputs,
        # so we need grad info on all of them.
        requires_grad_info = input_requires_grad_info + output_requires_grad_info
        assert len(requires_grad_info) == len(output_info) + num_mutated_inps

        # This analysis function returns *only* the outputs that are meant to be tangents to the backwards.
        # Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates)
        # are *regenerated* later, and not used directly in the autograd graph
        f_input_tangents = [
            inp
            for inp, info in zip(flat_f_args, input_info)
            if info.mutates_data
        ]
        f_output_tangents = [
            o
            for o, info in zip(flat_f_outs, output_info)
            if info.output_type == OutputType.non_alias and issubclass(info.raw_type, torch.Tensor)
        ]
        # intermediate bases are also included in the backward graph
        f_tangents = f_input_tangents + f_output_tangents + intermediate_bases
        traced_tangents = pytree.tree_map(from_fun, f_tangents)

        metadata = ViewAndMutationMeta(
            input_info=input_info,
            requires_grad_info=requires_grad_info,
            output_info=output_info,
            num_intermediate_bases=len(intermediate_bases),
            keep_input_mutations=keep_input_mutations,
            traced_tangents=traced_tangents,
        )
        return metadata

    return inner
def unpack_synthetic_bases(
    primals: List[Any],
    synthetic_base_info: Optional[List[Union[int, Tuple[int, torch.Tensor]]]],
) -> List[Any]:
    """Expand ``primals`` into the argument list described by ``synthetic_base_info``.

    Args:
        primals: The arguments as received by the caller.
        synthetic_base_info: ``None``, or one entry per argument to produce.
            An ``int`` entry selects ``primals[entry]`` unchanged. Any other
            entry is unpacked as ``(base_index, view_tensor)`` and the argument
            is produced by ``gen_alias_from_base`` from ``primals[base_index]``.

    Returns:
        ``primals`` itself (the same list object) when ``synthetic_base_info``
        is ``None``; otherwise a new list with one element per entry.
    """
    # This is only not None if our graph mutates a graph input that aliases another graph input.
    if synthetic_base_info is None:
        return primals

    def _resolve(entry):
        # A bare index means the argument is passed straight through.
        if isinstance(entry, int):
            return primals[entry]
        # Otherwise the entry pairs the index of a base in `primals` with a
        # tensor describing the view to regenerate off of that base.
        base_position, alias_template = entry
        base_arg = primals[base_position]
        return gen_alias_from_base(
            base_arg, alias_template, alias_template.requires_grad
        )

    return [_resolve(entry) for entry in synthetic_base_info]
# This class contains all the metadata we care about for the current function we're compiling.
# This data is needed both at trace time and at runtime.
@dataclass
class CompiledRuntimeMetadata:
    # Bundles the synthetic-base description with the forward metadata, and
    # caches a set of counts derived from that metadata (see __post_init__).

    # This type / object should be cleaned up
    # See Note [Synthetic Base Info Metadata]
    synthetic_base_info: Optional[List[Union[int, Tuple[int, torch.Tensor]]]]
    fw_metadata: ViewAndMutationMeta

    def __post_init__(self):
        """Derive and store per-category counts of outputs and mutated inputs."""

        def _tally(infos, keep):
            # Number of entries in `infos` for which `keep(entry)` is truthy.
            return sum(1 for info in infos if keep(info))

        # --- Outputs ---
        self.num_outputs = len(self.fw_metadata.output_info)
        self.num_outputs_non_aliased = _tally(
            self.fw_metadata.output_info,
            lambda out: out.output_type == OutputType.non_alias,
        )
        # Outputs that alias a graph input (or are a graph input).
        self.num_outputs_aliased_to_inputs = _tally(
            self.fw_metadata.output_info,
            lambda out: out.output_type
            in [
                OutputType.alias_of_input,
                OutputType.is_input,
            ],
        )
        # Outputs that alias an intermediate, in any of its three flavors.
        self.num_outputs_aliased_to_intermediates = _tally(
            self.fw_metadata.output_info,
            lambda out: out.output_type
            in [
                OutputType.alias_of_intermediate,
                OutputType.alias_of_intermediate_save_as_output,
                OutputType.alias_of_intermediate_base_is_user_output,
            ],
        )
        self.num_outputs_aliased = (
            self.num_outputs_aliased_to_inputs
            + self.num_outputs_aliased_to_intermediates
        )

        # --- Inputs ---
        self.num_mutated_data_inputs = _tally(
            self.fw_metadata.input_info,
            lambda inp: inp.mutates_data,
        )
        self.num_mutated_metadata_inputs = _tally(
            self.fw_metadata.input_info,
            lambda inp: inp.mutates_metadata,
        )
        self.num_mutated_metadata_only_inputs = _tally(
            self.fw_metadata.input_info,
            lambda inp: not inp.mutates_data and inp.mutates_metadata,
        )
        # Data-mutated and metadata-only-mutated inputs are disjoint sets,
        # so their sum counts each mutated input exactly once.
        self.num_mutated_inputs = (
            self.num_mutated_data_inputs + self.num_mutated_metadata_only_inputs
        )
# This function takes in a tensor t, and returns one of t, t.view(), or t.clone().
# When tracing the joint forward + backward, for any inputs in the graph that are mutated,
# we need to clone them first (and similarly for metadata-only mutations, we need to view them first).
# The idea is that when we trace the backward, we need to pass in the *original* primals
# to autograd.grad(), before they were mutated.
# Note: when we have synthetic base inputs, we need to clone them *before* creating views off of them.
# This means that "idx" here represents the index of the (potentially) synthetic base.
# What we need to do is:
# (1) map the current (post-synthetic-base calling convention) input argument index
# to int index pre-synthetic-base-calling-convention.
# (2) There could be multiple, if this index corresponds to a synthetic base
# that has multiple input aliases.
# (3) If any of those corresponding inputs get data mutations, then we clone the base;
# otherwise, if any of them get metadata-only mutations, we take a fresh view of the base.
def maybe_to_fresh_input(idx, t, meta):
    """Return ``t``, ``t.clone()``, or ``t.view(t.shape)`` depending on how it is mutated.

    Args:
        idx: Position of ``t`` in the current (post-synthetic-base) argument list.
        t: The argument itself; anything that is not a ``Tensor`` is returned as-is.
        meta: Object providing ``synthetic_base_info`` and
            ``fw_metadata.input_info`` (entries expose ``mutates_data`` and
            ``mutates_metadata``).

    Returns:
        A clone of ``t`` if any outer input mapped to ``idx`` has a data
        mutation; else a same-shape view of ``t`` if any of them has a
        metadata-only mutation; else ``t`` unchanged.
    """
    if not isinstance(t, Tensor):
        return t

    # Collect the indices, in the outer (pre-synthetic-base) calling convention,
    # of every argument that corresponds to position `idx`.
    if meta.synthetic_base_info is None:
        outer_positions = [idx]
    else:
        outer_positions = []
        for outer_pos, entry in enumerate(meta.synthetic_base_info):
            if isinstance(entry, int):
                # Plain pass-through argument: the entry is the inner index itself.
                hit = entry == idx
            elif isinstance(entry, tuple):
                # (base index, view tensor): compare against the base index.
                hit = entry[0] == idx
            else:
                hit = False
            if hit:
                outer_positions.append(outer_pos)

    # Data mutation on any related input takes priority.
    for outer_pos in outer_positions:
        if meta.fw_metadata.input_info[outer_pos].mutates_data:
            # Make sure the primal we pass to autograd.grad()
            # sees the tensor before the mutation
            return t.clone()

    for outer_pos in outer_positions:
        if (
            meta.fw_metadata.input_info[outer_pos].mutates_metadata
            and not meta.fw_metadata.input_info[outer_pos].mutates_data
        ):
            # Make sure the primal we pass to autograd.grad()
            # sees the tensor before the metadata mutation
            return t.view(t.shape)

    return t
# This function takes in a forward fn, runs it, and (optionally) runs autograd to compute the joint.
# When maybe_tangents is None, we only run the forward. Otherwise we run the "joint" forward + backward.
# Preconditions:
# - fn corresponds to the flattened user fw function, with duplicate inputs removed
# - functionalization is turned on (and inputs are wrapped in functional tensors)
# - Synthetic bases have been *removed* (we've taken views on them corresponding to the user argument views).
# - primals_after_cloning are what we run our forward function on. It is identical to primals_before_cloning,
# except that every input we know will be mutated in the forward has been cloned.
# We run our forward on primals_after_cloning (potentially mutating some inputs), and then compute our gradients
# w.r.t. primals_before_cloning (so we properly capture the mutation in our gradient computation).
# Importantly, due to functionalization + some autograd.Function constraints, this function can return EXTRA outputs
# compared to what the original user forward returns.
#
# If we are only running the forward (and not computing the joint):
# - Our function will return (updated_inputs, fw_outs)
#
# If we are running the forward + backward (computing the joint):
# - Our function will return (updated_inputs, fw_outs, intermediate_bases), (gradients)
#
# Finally, if keep_input_mutations is set, then we will explicitly *not* return updated inputs, for any inputs
# that experienced data-only mutations.
# Instead, we are relying on the logic in create_forward_or_joint_functionalized to manually perform the input mutations,
# keeping them directly in the traced graph.
def forward_or_joint(
fn: Callable,
primals_before_cloning: List[Any],
primals_after_cloning: List[Any],
maybe_tangents: Optional[List[Any]],
meta: CompiledRuntimeMetadata,
keep_input_mutations: bool,
) -> Any:
outs = fn(*primals_after_cloning)
assert len(meta.fw_metadata.output_info) == len(outs)
# The compiled fw will return mutated input tensors, *including* metadata-only mutation.
# However, if keep_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs.
# (because data-only input mutations are handled directly in the compiled graph)
if keep_input_mutations:
mutated_inputs_to_return = [
x
for (i, x) in enumerate(primals_after_cloning)
if meta.fw_metadata.input_info[i].mutates_metadata
]
else:
mutated_inputs_to_return = [
x
for (i, x) in enumerate(primals_after_cloning)
if meta.fw_metadata.input_info[i].mutates_data or meta.fw_metadata.input_info[i].mutates_metadata
]
# Case 1: We are just tracing the forward; not the joint forward + backward.
if maybe_tangents is None:
return *mutated_inputs_to_return, *outs
else:
tangents = maybe_tangents
# Case 2: We are tracing the joint forward backward.
# This also requires us to:
# - update the graph to return intermediate bases
# - Figure out what grad_outputs to pass into the backward
# - (this includes intermediate bases in the forward, and forward inputs that had data mutations)
# - actually call autograd.grad to trace the backward.
intermediate_bases = []
for o, info in zip(outs, meta.fw_metadata.output_info):
if info.output_type == OutputType.alias_of_intermediate_save_as_output:
intermediate_bases.append(o._base)
assert meta.fw_metadata.num_intermediate_bases == len(intermediate_bases)
# Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw
# For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead,
# which we *should* send to grad()
outputs_for_grad = [
x
for (i, x) in enumerate(outs)
if meta.fw_metadata.output_info[i].output_type == OutputType.non_alias
# Also, only tensor outputs should participate in the backward
# (in particular, Symint outputs in the forward graph shouldn't get tangents)
and issubclass(meta.fw_metadata.output_info[i].raw_type, torch.Tensor)
]
# Pass any (non-aliased) mutated inputs in as tangents, since they'll be returned as outputs in the fw
# Important: the traced joint fw/bw will return updated inputs with data mutations,
# but *not* with metadata mutations.
# Instead, we shunt the updated metadata around externally
# and update the input's metadata outside of the autograd.Function
mutated_inputs_for_grad = [
x
for (i, x) in enumerate(primals_after_cloning)
if meta.fw_metadata.input_info[i].mutates_data
]
# The tensors that we include in the backward graph are:
# - inputs that receive *data* mutations (not metadata-only; those are recomputed later)
# - outputs that are not aliased (aliased outputs are recomputed later)
# - intermediate ._base tensors of aliased outputs (we use those later to recompute the aliased outputs)
fw_outs_to_grad = mutated_inputs_for_grad + outputs_for_grad + intermediate_bases
assert len(tangents) == len(fw_outs_to_grad)
# the compiled forward should return (mutated_inputs, user_outs, intermediate_bases)
fw_outs_to_return = *mutated_inputs_to_return, *outs, *intermediate_bases
# Take care to grab and sync the updated inputs from primals_after_cloning (the inputs we actually mutate!)
# and not primals_before_cloning (the preserved inputs, pre-mutation, that we pass to grad())
for i, arg in enumerate(primals_after_cloning):
if not isinstance(arg, Tensor):