higher_order_ops.py
import contextlib
import functools
import itertools
import logging
from typing import Dict, List, Optional
import torch._C
import torch.fx
import torch.nn
import torch.onnx.operators
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import deepcopy_to_fake_tensor, get_fake_value, get_real_value
from torch._dynamo.variables.base import VariableTracker
from torch._dynamo.variables.builtin import BuiltinVariable
from torch._dynamo.variables.functions import UserFunctionVariable
from torch._dynamo.variables.tensor import SymNodeVariable
from torch._guards import Source
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils import _pytree as pytree
from ..exc import (
UncapturedHigherOrderOpError,
unimplemented,
Unsupported,
UserError,
UserErrorType,
)
from ..source import FSDPNNModuleSource, GetItemSource, NNModuleSource
from ..utils import proxy_args_kwargs
from .dicts import ConstDictVariable
from .lists import ListVariable, TupleVariable
from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable
log = logging.getLogger(__name__)
def raise_hard_error_if_graph_break(reason):
def deco(fn):
@functools.wraps(fn)
def graph_break_as_hard_error(*args, **kwargs):
try:
return fn(*args, **kwargs)
except Unsupported as e:
msg = " Scroll up to find out what causes the graph break."
raise UncapturedHigherOrderOpError(reason + msg) from e
return graph_break_as_hard_error
return deco
@contextlib.contextmanager
def dynamo_enable_grad(tx, enable=True):
from . import GradModeVariable
org_value = torch.is_grad_enabled()
try:
GradModeVariable.create(tx, enable, initialized=True)
yield
finally:
GradModeVariable.create(tx, org_value, initialized=True)
def only_consist_of(var, types):
if isinstance(var, types):
return True
if isinstance(var, (TupleVariable, ListVariable)):
return all(only_consist_of(item, types) for item in var.items)
if isinstance(var, ConstDictVariable):
return all(only_consist_of(item, types) for item in var.items.values())
return False
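# A hedged usage sketch (hypothetical values): this helper walks pytree-like
# VariableTracker containers, so a call such as
#
#   only_consist_of(TupleVariable([t1, ListVariable([t2])]), (TensorVariable,))
#
# returns True only when every leaf (t1 and t2 above) is a TensorVariable.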
# More readable syntactic sugar for creating a UserFunctionVariable for f
# and running call_function on it. It returns a function so as to preserve
# the calling convention of the original f.
def _make_inlined(tx, f):
assert callable(f), "Expect f to be a python callable."
def inline_call(*args, **kwargs):
return UserFunctionVariable(f).call_function(tx, args, kwargs)
return inline_call
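# A hedged usage sketch: speculate_subgraph (below) uses this to trace through
# python helpers instead of calling them eagerly, e.g.
#
#   flat_vt, spec_vt = _make_inlined(tx, pytree.tree_flatten)(out_vt).unpack_var_sequence(tx)
#
# where out_vt is a VariableTracker for the subgraph output (names here are
# illustrative only).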
def _call_function_and_unflatten_output(tx, fn, args, kwargs, ret_vt, ret_treespec):
from .builder import wrap_fx_proxy
flat_example_value = pytree.tree_map_only(
torch.fx.Proxy,
lambda a: a.node.meta["example_value"],
ret_vt.as_proxy(),
)
# Store the invocation as a call
flat_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn,
args=args,
kwargs=kwargs,
),
example_value=flat_example_value,
)
# Transform the variable back into a list (it was previously made into a tuple
# by the speculate_subgraph function) so as to respect the pytree API typing.
flat_list_variable = BuiltinVariable(list).call_function(tx, [flat_variable], {})
return (
_make_inlined(tx, pytree.tree_unflatten)(flat_list_variable, ret_treespec)
if ret_treespec
else flat_variable
)
def _assert_tensors_nonaliasing(inputs, outputs):
input_tensor_ids = {
id(t) for t in pytree.tree_leaves(inputs) if isinstance(t, torch.Tensor)
}
output_tensor_ids = {
id(t) for t in pytree.tree_leaves(outputs) if isinstance(t, torch.Tensor)
}
assert input_tensor_ids.isdisjoint(
output_tensor_ids
), "inputs to function body cannot alias outputs"
def validate_args_and_maybe_create_graph_inputs(
sub_args,
tracer,
tx,
manually_set_subgraph_inputs,
description,
):
from . import AutogradFunctionContextVariable, ConstantVariable, EnumVariable
from .builder import wrap_fx_proxy_cls
assert tracer.parent is not None
args = []
for a in sub_args:
assert isinstance(a, VariableTracker)
if not manually_set_subgraph_inputs:
args.append(a)
continue
if isinstance(a, (ConstantVariable, EnumVariable)):
# This arg is not used in the body of the higher order op.
# Currently, this new input is added only to keep the calls happy,
# since they expect a fixed number of arguments. In the future, we
# can clean this up.
tracer.create_graph_input("const")
new_arg = a
# Weird special case, we probably want to delete it or fold it
# into the next case (of `a` being placeable into a graph)
elif isinstance(a, AutogradFunctionContextVariable):
tracer.create_graph_input(a.as_proxy().node.name)
new_arg = a
# If `a` can be put into a graph
elif a.maybe_fx_node() is not None:
node = a.maybe_fx_node()
new_proxy = tracer.create_graph_input(node.name)
example_value = (
node.meta["example_value"] if "example_value" in node.meta else None
)
new_arg = wrap_fx_proxy_cls(
target_cls=type(a),
tx=tx,
proxy=new_proxy,
example_value=example_value,
)
# If `a` cannot be put into a graph
else:
# HOPs work much better if they use speculate_subgraph(manually_set_subgraph_inputs=False).
raise unimplemented(
f"{description} with body that accepts non-Tensors as input. "
f"Got: {a.python_type()}"
)
args.append(new_arg)
return args
# See NOTE [HigherOrderOperator tracing design] for details of the design
def speculate_subgraph(
tx,
f,
sub_args,
sub_kwargs,
description,
*,
# source_target is the .value of HigherOrderOpVariable and is the
# target of the proxy that we created for the higherOrderOperator.
source_target=None,
always_restore=False,
enable_grad=None,
# NOTE [Temporary argument `manually_set_subgraph_inputs`]
# If manually_set_subgraph_inputs=True, then we manually add
# the `sub_args` to `subgraph`, if False then we rely
# on tracer's lifting mechanism to lift these args.
# NOTE: The default of `True` is temporary; the plan is
# to always lift args in the future and remove this
# argument.
manually_set_subgraph_inputs=True,
restore_side_effects=True,
should_flatten_outputs=False,
# Pass in an originating tracer - this is needed for preserving context
# across fwd-bwd for autograd.Function
tracer=None,
):
if sub_kwargs is None:
sub_kwargs = {}
# See NOTE [Temporary argument `manually_set_subgraph_inputs`]
if sub_kwargs and manually_set_subgraph_inputs:
unimplemented(
"Use `manually_set_subgraph_inputs=False` when passing `sub_kwargs`."
)
try:
f, sub_args, sub_kwargs = VariableTracker.apply(
# ensure guards on args get installed in parent subgraph
lambda x: x.realize(),
(f, sub_args, sub_kwargs),
)
with tx.output.subtracer(source_target, tracer) as subtracer:
args = validate_args_and_maybe_create_graph_inputs(
sub_args, subtracer, tx, manually_set_subgraph_inputs, description
)
validate_args_and_maybe_create_graph_inputs(
sub_kwargs.values(),
subtracer,
tx,
manually_set_subgraph_inputs=False,
description=description,
)
autograd_ctx = (
dynamo_enable_grad(tx, enable_grad)
if enable_grad is not None
else contextlib.nullcontext()
)
if restore_side_effects:
prev_side_effects = tx.output.side_effects.clone()
with autograd_ctx:
output = f.call_function(tx, args, sub_kwargs)
if restore_side_effects:
# Captured variables are tracked in side effects,
# and they would show up in the output graph incorrectly.
# It is OK to undo this side-effect tracking
# because speculate_subgraph only allows
# pure functions.
tx.output.side_effects = prev_side_effects
treespec = None
if should_flatten_outputs:
# Flatten the speculated subgraph output.
output, treespec = _make_inlined(tx, pytree.tree_flatten)(
output
).unpack_var_sequence(tx)
# Actually, transform the list (returned by flatten) into a tuple
# for dynamo consistency.
output = BuiltinVariable(tuple).call_function(tx, [output], {})
# Register output to graph
# Modeled off of compile_and_call_fx_graph
# TODO: support pytree output
# We check always_restore because we don't use the output or side effects of
# always_restore code, like bwd.
if always_restore:
# Nothing left to do here
return (output, treespec), tx.output.graph, subtracer.lifted_freevars
else:
from . import TensorVariable
if not only_consist_of(output, TensorVariable):
unimplemented(
"HigherOrderOperator body's output must consist of tensors only"
)
# The output proxies might not belong to this SubgraphTracer
# (if they are free variables that were never lifted)
# so lift them here.
output_proxies = output.as_proxy()
output_proxies = pytree.tree_map(
subtracer.maybe_lift_tracked_freevar_to_input, output_proxies
)
tx.output.create_node(
"output",
"output",
(subtracer.create_arg((output_proxies,))),
{},
)
graph = tx.output.graph
graph.lint()
lifted_freevars = subtracer.lifted_freevars
return (
(output, treespec),
graph,
lifted_freevars,
)
except Unsupported as ex:
f_name = f"{type(f).__name__}"
if isinstance(f, UserFunctionVariable):
f_name = f.get_name()
msg = (
f"speculate_subgraph: while introspecting {description}, we were unable "
f"to trace function `{f_name}` into a single graph. This means "
f"that Dynamo was unable to prove safety for this API and will "
f"fall back to eager-mode PyTorch, which could lead to a slowdown."
)
log.warning(msg)
log.exception(ex)
raise Unsupported(
f"{msg} Scroll up for the stack trace "
f"of the initial exception. The reason was: {ex.msg}"
) from ex
def make_attr(tx, name):
node = tx.output.create_proxy(
"get_attr",
name,
(),
{},
)
return node
def add_subgraph(tx, source, name, gm):
next_name = None
i = 0
while not next_name:
candidate = f"{name}_{i}"
if candidate in tx.output.nn_modules:
i += 1
else:
next_name = candidate
gm.__name__ = next_name
if source.guard_source().is_fsdp_module():
src = FSDPNNModuleSource(GetItemSource(source, next_name))
else:
src = NNModuleSource(GetItemSource(source, next_name))
gm.torchdynamo_force_dynamic = False
tx.output.register_attr_or_module(gm, next_name, source=src)
return next_name
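# A hedged sketch: add_subgraph installs the GraphModule on the output graph
# under a fresh, collision-free name, and make_attr creates a get_attr proxy
# for it, e.g.
#
#   name = add_subgraph(tx, self.source, "cond_true", gm)  # -> e.g. "cond_true_0"
#   node = make_attr(tx, name)
#
# (the numeric suffix depends on how many same-named subgraphs already exist).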
class TorchHigherOrderOperatorVariable(VariableTracker):
def __init__(self, value, source: Optional[Source] = None, **kwargs):
super().__init__(**kwargs)
self.value = value
self.source = source
@staticmethod
def make(value, source=None, **kwargs):
if value.__name__ == "cond":
return CondHigherOrderVariable(value, source, **kwargs)
elif value.__name__ in ("map", "map_impl"):
return MapHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "executorch_call_delegate":
return ExecutorchCallDelegateHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "out_dtype":
return OutDtypeHigherOrderVariable(value, source, **kwargs)
elif value is torch._functorch.eager_transforms.grad_impl:
return FunctorchGradHigherOrderVariable(value, source, **kwargs)
elif value is torch._functorch.vmap.vmap_impl:
return FunctorchVmapHigherOrderVariable(value, source, **kwargs)
elif value.__name__ in (
"trampoline_autograd_fwd",
"trampoline_autograd_bwd",
"trampoline_autograd_apply",
):
return AutogradFunctionMethodHigherOrderVariable(
value=value, source=source, **kwargs
)
elif value.__name__ == "wrap":
return WrapHigherOrderVariable(value, source, **kwargs)
elif value.__name__ in (
"wrap_activation_checkpoint",
"tag_activation_checkpoint",
):
return CheckpointHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "_export_tracepoint":
return ExportTracepointHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "trace_wrapped":
return TraceWrappedHigherOrderOperatorVariable(value, source, **kwargs)
else:
unimplemented(f"HigherOrderOperator {value.__name__}")
def call_function(
self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
unimplemented(f"HigherOrderOperator {self.value.__name__}")
class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
@raise_hard_error_if_graph_break(
reason="Cond doesn't work unless it is captured completely with torch.compile."
)
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import (
ConstantVariable,
ListVariable,
NestedUserFunctionVariable,
TensorVariable,
UserFunctionVariable,
)
args, kwargs = VariableTracker.apply(lambda x: x.realize(), (args, kwargs))
for i, k in enumerate(["pred", "true_fn", "false_fn", "operands"]):
if v := kwargs.pop(k, None):
assert i == len(
args
), "did not provide the right number of non-keyword args"
args.append(v)
if kwargs:
unimplemented(f"torch.cond: Got unexpected kwargs: {list(kwargs.keys())}")
# TODO(voz): Support fake tensor dispatch for recursive
# ops - see torch/dispatch/_dispatcher.py
if len(args) != 4:
unimplemented(
f"Expected 4 arguments but got {len(args)}.\n"
f"Usage: cond(pred, true_fn, false_fn, operands)",
)
# predicate
if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable):
unimplemented(
f"Expected pred to be bool or a boolean tensor with single "
f"item but got {str(type(args[0]))} "
f"with original python type {str(args[0].python_type())}.",
)
# operands
if not isinstance(args[3], (ListVariable, TupleVariable)):
unimplemented(
f"Expected a tuple but got {args[3].python_type()}",
)
operands = args[3].unpack_var_sequence(tx)
if not only_consist_of(args[3], (TensorVariable,)):
unimplemented(
"Expect operands to be a tuple of pytrees that only consists of tensor leaves."
)
# branches
assert isinstance(
args[1],
(
UserFunctionVariable,
NestedUserFunctionVariable,
NNModuleVariable,
UnspecializedNNModuleVariable,
),
), str(
type(args[1])
) # true_fn
assert isinstance(
args[2],
(
UserFunctionVariable,
NestedUserFunctionVariable,
NNModuleVariable,
UnspecializedNNModuleVariable,
),
), str(
type(args[2])
) # false_fn
# Our strategy for tracing the true/false branches of cond
# is to checkpoint our graphstate, run the true branch,
# roll it back to the checkpoint, run the false
# branch, and then merge the graphstates. Well, perhaps
# "merge" is too strong a word: we mostly assert that
# the resulting graphstates have to be the same.
#
# We only permit guards to diverge (we union the guards from
# both branches). In particular, this means that side
# effects are NOT permitted inside true/false branches; this
# would be difficult to implement, because of the path
# explosion problem.
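# Rough sketch of the end state (names are illustrative; the numeric suffix
# comes from add_subgraph): for cond(pred, true_fn, false_fn, (x,)) the output
# graph gains two submodules, e.g. cond_true_0 and cond_false_0, plus a single
# torch.ops.higher_order.cond call whose operand list is the deduplicated
# union of the free variables lifted from both branches (see
# dedup_and_sort_lifted_freevars and fixup_branch_inps below).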
def speculate_branch(branch):
# NB: 0 is predicate
ix = 1 if branch else 2
# TODO: Support kwargs
(
(ret_val, ret_treespec),
ret_graph,
ret_lifted_freevars,
) = speculate_subgraph(
tx,
args[ix],
operands,
{},
"cond",
source_target=self.value,
manually_set_subgraph_inputs=False,
should_flatten_outputs=True,
)
if not only_consist_of(ret_val, (TensorVariable,)):
unimplemented(
"Expected branches to return a possibly nested list/tuple/dict of tensors but it consists of non tensors.",
)
return ret_val, ret_treespec, ret_graph, ret_lifted_freevars
(true_r, true_treespec, true_graph, true_lifted_freevars) = speculate_branch(
True
)
true_nn_modules = dict(tx.output.nn_modules)
(
false_r,
false_treespec,
false_graph,
false_lifted_freevars,
) = speculate_branch(False)
false_nn_modules = dict(tx.output.nn_modules)
same_treespec = _make_inlined(tx, pytree.TreeSpec.__eq__)(
true_treespec, false_treespec
)
if not same_treespec.as_python_constant():
unimplemented("Expected branches to return the same pytree structure.")
def diff_meta(tensor_vars1, tensor_vars2):
assert all(
isinstance(var, TensorVariable) for var in tensor_vars1 + tensor_vars2
)
all_diffs = []
for i, (var1, var2) in enumerate(zip(tensor_vars1, tensor_vars2)):
# We check the meta data associated with meta["example_value"]
meta1 = _extract_tensor_metadata(
var1.proxy.node.meta["example_value"], include_contiguity=False
)
meta2 = _extract_tensor_metadata(
var2.proxy.node.meta["example_value"], include_contiguity=False
)
if meta1 != meta2:
all_diffs.append((f"pair{i}:", meta1, meta2))
return all_diffs
if diffs := diff_meta(
true_r.unpack_var_sequence(tx), false_r.unpack_var_sequence(tx)
):
unimplemented(
f"Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:{diffs}"
)
def dedup_and_sort_lifted_freevars(true_lifted_freevars, false_lifted_freevars):
# The nn module attributes are guaranteed to be registered into the top-level graph module during
# higher order op speculation. Therefore, get_attr nodes in two branches with the same
# target refer to the same attribute, and we can safely deduplicate them by their target.
#
# Note: ideally, dynamo should just create a single proxy for the same attribute of an nn module. But
# true_branch and false_branch belong to two separate tracing contexts, so they may register the same
# attribute to the top level separately. This creates two get_attr proxies for the same attribute
# that have different metadata such as stack_trace (one stack trace for the true_branch,
# and the other for the false_branch). It seems better to discard the duplicate proxy explicitly in cond
# than to make dynamo create a single proxy for the same get_attr target.
def shared_getattrs(true_lifted_proxies, false_lifted_proxies):
true_targets = {
proxy.node.target: proxy
for proxy in true_lifted_proxies
if proxy.node.op == "get_attr"
}
true_fn_shared_getattrs = {}
false_fn_shared_getattrs = {}
for false_proxy in false_lifted_proxies:
if (
false_proxy.node.op == "get_attr"
and false_proxy.node.target in true_targets
):
true_proxy = true_targets[false_proxy.node.target]
true_fn_shared_getattrs[true_proxy] = true_proxy
false_fn_shared_getattrs[false_proxy] = true_proxy
return true_fn_shared_getattrs, false_fn_shared_getattrs
true_fn_shared_getattrs, false_fn_shared_getattrs = shared_getattrs(
true_lifted_freevars.keys(), false_lifted_freevars.keys()
)
true_shared_freevars = (
true_lifted_freevars.keys() & false_lifted_freevars.keys()
).union(true_fn_shared_getattrs.keys())
false_shared_freevars = (
true_lifted_freevars.keys() & false_lifted_freevars.keys()
).union(false_fn_shared_getattrs.keys())
unique_true_freevars = true_lifted_freevars.keys() - true_shared_freevars
unique_false_freevars = false_lifted_freevars.keys() - false_shared_freevars
def _sort_by_name(vars):
return sorted(vars, key=lambda var: var.node.name)
return (
list(_sort_by_name(list(true_shared_freevars))),
list(_sort_by_name(list(false_shared_freevars))),
list(_sort_by_name(list(unique_true_freevars))),
list(_sort_by_name(list(unique_false_freevars))),
)
(
true_shared,
false_shared,
unique_true,
unique_false,
) = dedup_and_sort_lifted_freevars(true_lifted_freevars, false_lifted_freevars)
# Let's say we capture cond(pred, true_fn, false_fn, (x,)).
# With manually_set_subgraph_inputs set to False,
# true_fn has lifted variables x, a, b, c
# false_fn has lifted variables x, a, b, d
# Then fixup_branch_inps makes sure both branches have the same signature, i.e.:
# - true_fn(x, a, b, c_true_branch, d_false_branch)
# - false_fn(x, a, b, c_true_branch, d_false_branch)
#
# More formally, the signature has three parts in the following order:
# 1. used in both branches: x, a, b
# 2. only used in the true branch: c, suffixed with _true_branch
# 3. only used in the false branch: d, suffixed with _false_branch
# Within each part, we re-order the nodes by name to have a deterministic ordering for testing.
def fixup_branch_inps(
graph, lifted_freevars, shared, unique_true, unique_false
):
def _insert_or_replace_phs(new_args, name_suffix):
for arg in new_args:
new_ph = graph.placeholder(arg.node.name + name_suffix)
# Override with new_ph if an old placeholder exists.
if arg in lifted_freevars:
old_ph = lifted_freevars[arg].node
old_ph.replace_all_uses_with(new_ph)
# replace_all_uses_with doesn't clean up users, so clean them up manually so that we can erase the node.
old_ph.users = {}
graph.erase_node(old_ph)
first_not_ph_node = next(
node for node in graph.nodes if node.op != "placeholder"
)
with graph.inserting_before(first_not_ph_node):
_insert_or_replace_phs(shared, "")
_insert_or_replace_phs(unique_true, "_true_branch")
_insert_or_replace_phs(unique_false, "_false_branch")
fixup_branch_inps(
true_graph, true_lifted_freevars, true_shared, unique_true, unique_false
)
fixup_branch_inps(
false_graph, false_lifted_freevars, false_shared, unique_true, unique_false
)
true_name = add_subgraph(
tx,
self.source,
"cond_true",
torch.fx.GraphModule(true_nn_modules, true_graph),
)
false_name = add_subgraph(
tx,
self.source,
"cond_false",
torch.fx.GraphModule(false_nn_modules, false_graph),
)
true_node = make_attr(tx, true_name)
false_node = make_attr(tx, false_name)
p_args = (
args[0].as_proxy(),
true_node,
false_node,
# We pick true_shared but it shouldn't matter
true_shared + unique_true + unique_false,
)
return _call_function_and_unflatten_output(
tx, torch.ops.higher_order.cond, p_args, {}, true_r, true_treespec
)
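# A minimal, hedged sketch of the user-facing pattern captured above (assuming
# `cond` is importable as torch.cond; in older builds it lives under
# functorch.experimental.control_flow):
#
#   def true_fn(x):
#       return x.sin()
#
#   def false_fn(x):
#       return x.cos()
#
#   @torch.compile(fullgraph=True)
#   def f(pred, x):
#       return torch.cond(pred, true_fn, false_fn, (x,))
#
#   f(torch.tensor(True), torch.randn(4))
#
# Both branches must return pytrees of tensors with identical structure and
# metadata, as enforced by the treespec check and diff_meta above.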
def non_single_tensor_return_unsupported(api, ret):
from . import TensorVariable
if not isinstance(ret, TensorVariable):
raise Unsupported(
f"{api} over function that returns something " f"other than one Tensor"
)
class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
from . import NestedUserFunctionVariable, TensorVariable, UserFunctionVariable
from .builder import wrap_fx_proxy_cls
if len(kwargs) > 0:
unimplemented(
"torch.ops.higher_order.map: kwargs are not supported in the map operator."
)
assert type(args[0].realize()) in (
UserFunctionVariable,
NestedUserFunctionVariable,
)
assert type(args[1].realize()) is TensorVariable
sample_shape = get_fake_value(args[1].as_proxy().node, tx).size()
if len(sample_shape) < 1 or sample_shape[0] == 0:
unimplemented(
"map() operator doesn't support scalar or zero-sized tensors during tracing."
)
# To get the example output from map(), we need to provide at least one sample to
# the loop body. In our case we will always use xs[0], and our map() won't support
# zero-sized tensors during tracing.
first_dim = wrap_fx_proxy_cls(
target_cls=TensorVariable, tx=tx, proxy=args[1].as_proxy()[0]
)
# TODO: Support kwargs
(
(body_r, body_spec),
body_graph,
body_lifted_freevars,
) = speculate_subgraph(
tx,
args[0],
[
first_dim,
*args[2:],
],
{},
"torch.ops.higher_order.map",
source_target=self.value,
should_flatten_outputs=True,
)
body_nn_modules = dict(tx.output.nn_modules)
body_name = add_subgraph(
tx,
self.source,
"map_body",
torch.fx.GraphModule(body_nn_modules, body_graph),
)
body_node = make_attr(tx, body_name)
p_args = (
body_node,
1,  # right now we only support num_mapped = 1
*(arg.as_proxy() for arg in args[1:]),
*(arg for arg in body_lifted_freevars.keys()),
)
return _call_function_and_unflatten_output(
tx, torch.ops.higher_order.map_impl, p_args, {}, body_r, body_spec
)
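# A hedged usage sketch of the op this variable captures (assuming the
# control-flow `map` exported from functorch.experimental.control_flow):
#
#   from functorch.experimental.control_flow import map
#
#   def body(xi, y):
#       return xi + y
#
#   out = map(body, xs, y)  # body is applied along dim 0 of xs
#
# During tracing, `body` is speculated once on xs[0] and a single
# torch.ops.higher_order.map_impl node is emitted with num_mapped = 1.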
class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from .builder import wrap_fx_proxy
# This is the operator for delegation within Executorch, which calls a
# specific function in the given lowered module with the given
# operators. The actual operator is defined in the Executorch codebase.
# This is a bad hierarchical violation since
# executorch_call_delegate sits at a higher level than dynamo, but
# there's no real solution to this issue yet.
if len(kwargs) > 0:
unimplemented(
"executorch_call_delegate: kwargs arguments were not enabled."
)
lowered_module = tx.output.get_submodule(args[0].module_key)
lowered_node = make_attr(tx, args[0].module_key)
p_args = tuple(arg.as_proxy() for arg in args[1:])
real_sub_args = pytree.tree_map_only(
torch.fx.Proxy, lambda a: get_real_value(a.node, tx.output), p_args
)
example_res = lowered_module.original_module(*real_sub_args)
# NOTE [Guaranteeing the 1-1 correspondence of FakeTensors and real tensors]:
# executorch modules promise not to alias inputs and outputs.
# Thus, output FakeTensors will correctly not alias input FakeTensors.
_assert_tensors_nonaliasing(real_sub_args, example_res)
example_value = deepcopy_to_fake_tensor(example_res, tx.fake_mode)
p_args = (lowered_node,) + p_args
# Store the invocation as a call
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=tuple(p_args),
kwargs={},
),
example_value=example_value,
)
class FunctorchGradHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import ConstantVariable
from .builder import wrap_fx_proxy
# TODO: Support `fn` with kwargs.
if not torch._dynamo.config.capture_func_transforms:
unimplemented(
"torch.func.grad capture is disabled, "
"it can be turned on by setting "
"`torch._dynamo.config.capture_func_transforms=True`"
)
# [NOTE] Here we are (roughly) modelling the following
#
# grad_fn = torch.func.grad(fn, argnums=.., has_aux=..)
# grad_output = grad_fn(x)
grad_args = (args[0], args[1], args[2])
# get arguments
func, argnums, has_aux = grad_args
kwargs = args[4].items
if len(kwargs) > 0:
# Since speculate_subgraph doesn't support kwargs, we can't handle this for now.
unimplemented(
"torch.func.grad: kwargs arguments are currently unsupported."
)
# Trace through the `func`
# NOTE [HACK: Enable autograd while tracing function]
# `torch.func.grad` should not be affected by `no_grad` outside of `grad`.
# So, we enable_grad right before the function to which `grad` is applied
# (the parts explicitly disabled with `no_grad` inside the function are still disabled).
# Eg.
# def f(x):
# with no_grad(): # This will disable grad tracking under it.
# y = x * 2
#
# return x ** 2 - y # grad tracking should be enabled irrespective of outside `no_grad`.
#
# with no_grad(): # This will not disable grad tracking inside of grad(f).
# grad_o = torch.func.grad(f)(x)
# TODO: Support kwargs
(body_r, _), body_graph, body_lifted_freevars = speculate_subgraph(
tx,
func,
args[3].items,
{},
"torch.func.grad",
source_target=self.value,
# See NOTE [HACK: Enable autograd while tracing function]
enable_grad=True,
)
body_name = add_subgraph(
tx,
self.source,
"grad_body",
torch.fx.GraphModule(tx.output.nn_modules, body_graph),
)
body_node = make_attr(tx, body_name)
grad_proxy_args = (
body_node,
*(arg.as_proxy() for arg in grad_args[1:]),
)
# Model `grad_fn = grad(fn, *grad_args, **grad_kwargs)`
grad_fn = tx.output.create_proxy(
"call_function",
torch.func.grad,
args=tuple(grad_proxy_args),
kwargs={},
name="grad_proxy",
)
# Pass lifted freevars to the call to `grad_fn`
args = args[3].items
grad_fn_args = tuple(arg.as_proxy() for arg in args) + tuple(
body_lifted_freevars
)
# Call grad_fn with inputs.
# grad_output = grad_fn(*grad_fn_args, **grad_fn_kwargs)
grad_output = grad_fn(*grad_fn_args)
# `grad_fn(*grad_fn_args, **grad_fn_kwargs)`
# Output of grad_fn is
# For has_aux=False, Tuple[gradients of inputs indicated by argnums].
# For has_aux=True, Tuple[Tuple[gradients of inputs indicated by argnums], aux values]
# NOTE: example_value should match `grad_output`.
def _from_args(idx):
return args[idx].as_proxy().node.meta["example_value"].contiguous()
def to_python_ints(argnums):
if not isinstance(argnums, (ConstantVariable, TupleVariable)):
raise UserError(
UserErrorType.INVALID_INPUT,
f"argnums is expected to be int or tuple of ints. Got {argnums}.",
)
if isinstance(argnums, ConstantVariable):
if not isinstance(argnums.value, (int, tuple)):
raise UserError(
UserErrorType.INVALID_INPUT,
f"argnums is expected to be int or tuple of ints. Got {argnums}.",
)
return argnums.value
else:
const_vars = argnums.unpack_var_sequence(tx)
if not all(
isinstance(var, ConstantVariable) and isinstance(var.value, int)
for var in const_vars
):
raise UserError(
UserErrorType.INVALID_INPUT,
f"argnums is expected to contain int only. Got {const_vars}.",
)
return tuple(var.value for var in const_vars)
argnums_v = to_python_ints(argnums)
example_value = pytree.tree_map(_from_args, argnums_v)
if has_aux.value:
# case: has_aux = True
# NOTE: Currently speculate_subgraph allows body_r to be
# a Tensor or a Tuple/List of Tensors.
# Since `grad` expects the output with has_aux=True
# to be (output, aux), the only valid output currently is
# (output, some_tensor).
body_r_proxy = body_r.as_proxy()
aux = body_r_proxy[1].node.meta["example_value"]
example_value = (example_value, aux)
fx_proxy = wrap_fx_proxy(tx=tx, proxy=grad_output, example_value=example_value)
# Call contiguous on all the computed grads.
if not has_aux.value:
if isinstance(argnums_v, int):
return fx_proxy.call_method(tx, "contiguous", (), {})
else:
grads = fx_proxy
items = []
for idx in range(len(argnums_v)):
proxy = grads.call_method(
tx, "__getitem__", (ConstantVariable.create(idx),), {}
).call_method(tx, "contiguous", (), {})
items.append(proxy)
return TupleVariable(items)
else: # case: has_aux.value = True
# fx_proxy -> Tuple(grads, aux)
grads = fx_proxy.call_method(
tx, "__getitem__", (ConstantVariable.create(0),), {}
)
aux = fx_proxy.call_method(
tx, "__getitem__", (ConstantVariable.create(1),), {}
)
if isinstance(argnums_v, int):
return TupleVariable([grads.call_method(tx, "contiguous", (), {}), aux])
else:
items = []
for idx in range(len(argnums_v)):
proxy = grads.call_method(
tx, "__getitem__", (ConstantVariable.create(idx),), {}
).call_method(tx, "contiguous", (), {})
items.append(proxy)
return TupleVariable([TupleVariable(items), aux])
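# A minimal, hedged sketch of the pattern captured above (it only applies when
# torch._dynamo.config.capture_func_transforms=True, as checked at the top of
# call_function):
#
#   def f(x):
#       return (x ** 2).sum()
#
#   @torch.compile
#   def g(x):
#       return torch.func.grad(f)(x)
#
#   g(torch.randn(4))
#
# The traced graph contains a grad_body_* submodule for `f`, a proxy call to
# torch.func.grad, and a .contiguous() call on each returned gradient.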
class FunctorchVmapHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import ConstantVariable, TensorVariable
from .builder import wrap_fx_proxy