# acc_tracer.py
import ast
import builtins
import copy
import inspect
import logging
import operator
import textwrap
import warnings
from types import FunctionType
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Iterable,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

import torch
import torch.jit as jit
import torch.nn as nn
from torch._sources import normalize_source_lines
from torch.fx import Graph, Tracer
from torch.fx.experimental.normalize import NormalizeArgs
from torch.fx.node import Argument, Node, Target
from torch.fx.passes import shape_prop
from . import acc_normalizer, acc_ops, acc_shape_prop, acc_utils  # noqa: F401

_LOGGER = logging.getLogger(__name__)


def _get_exception_wrapper_attr_name(exc_type: Type[Exception]) -> str:
    return f"_conditional_exception_wrapper_{exc_type.__name__}"


class Acc_Rewriter(ast.NodeTransformer):
    """
    Take a FunctionType object representing a `forward` method, then
    perform an AST rewrite to swap out nodes that are not symbolically
    traceable with a callsite to the FX alternative.

    To support swapping out an AST node, define a new `visit` method on
    that node. For more details, see:
    https://docs.python.org/3/library/ast.html#ast.NodeTransformer
    """

    def __init__(self):
        super().__init__()
        self.exceptions_rewritten: Set[Type[Exception]] = set()
        self.exceptions_bool_rewritten: Set[Type[Exception]] = set()

    def rewrite(
        self, fn: FunctionType
    ) -> Tuple[FunctionType, Set[Type[Exception]], Set[Type[Exception]]]:
        # Normalize the source lines
        sourcelines, _ = inspect.getsourcelines(fn)
        sourcelines = normalize_source_lines(sourcelines)
        source = "".join(sourcelines)
        normalized_str = textwrap.dedent(source)

        # Rewrite the original AST
        source_ast = ast.parse(normalized_str)
        dest_ast = ast.fix_missing_locations(self.visit(source_ast))

        # Pull out the compiled function from the newly-created Module
        code = compile(dest_ast, "", "exec")
        globals_dict = copy.copy(fn.__globals__)
        keys_before = set(globals_dict.keys())
        exec(code, globals_dict)  # noqa P204
        new_keys = list(set(globals_dict.keys()) - keys_before)
        assert len(new_keys) <= 1
        fn_compiled = globals_dict[fn.__name__]

        # Return the correct FunctionType object and the Exceptions that were
        # rewritten during visit_If.
        return fn_compiled, self.exceptions_rewritten, self.exceptions_bool_rewritten

    def visit_Assert(self, node: ast.Assert):
        """
        Swap out the Assert node (Python's `assert`) with a callsite to the
        symbolically-traceable torch._assert function.
        """
        # Create the Call node
        n = ast.parse("torch._assert()", mode="eval")
        assert isinstance(n, ast.Expression)
        call_node = n.body
        assert isinstance(call_node, ast.Call)
        msg = node.msg if node.msg else ast.Constant(value="", kind=None)
        call_node.args = [node.test, msg]

        # Ensure that the new node conforms to the Python AST grammar
        expr_wrapper = ast.Expr(value=call_node)

        # Return the new Call node to signify that we want to use it as
        # a replacement for the original assert node
        return ast.copy_location(expr_wrapper, node)

    def visit_If(self, if_node: ast.If):
        """
        Swap out the pattern `If(x): Raise(y)` with a ConditionalExceptionWrapper
        specialized for the specific exception y. The specialized
        ConditionalExceptionWrapper module will be added in the RewrittenModule.
        Only works with builtin Exceptions, as we assume the signature of the
        init for the Exception is a string.
        """
        raise_node = if_node.body[0]
        if not isinstance(raise_node, ast.Raise):
            return if_node

        # Don't handle orelse for now.
        # TODO: Move orelse to the body after calling ConditionalExceptionWrapper.
        if len(if_node.orelse) != 0:
            return if_node

        def _reuse_loc(node):
            return ast.copy_location(node, if_node)

        # If the exception has a message then we expect the raise's exc to be a
        # Call w/ a msg. Else if it's an exc Name then there's no msg to use.
        node_for_exc = raise_node.exc
        if isinstance(node_for_exc, ast.Name):
            # E.g. `raise AssertionError`, i.e. without an exc_msg.
            name_node_of_exc = node_for_exc
            exc_msg = _reuse_loc(ast.Constant(None))
        elif isinstance(node_for_exc, ast.Call):
            # E.g. `raise AssertionError("error message")`
            name_node_of_exc = node_for_exc.func  # type: ignore[assignment]
            if not isinstance(name_node_of_exc, ast.Name):
                return if_node
            # Most assertions just take a single string arg, but some may not; skip
            # handling such assertions for now.
            if len(node_for_exc.args) != 1:
                return if_node
            exc_msg = node_for_exc.args[0]
        else:
            return if_node

        # Convert what we expect is the name of the exception into its
        # associated python class.
        name_of_exc = name_node_of_exc.id
        try:
            exc_type = eval(name_of_exc)  # noqa P204
        except Exception:
            return if_node

        # Check that we actually have a builtin exception.
        if (
            not issubclass(exc_type, Exception)
            or getattr(getattr(exc_type, "__class__", None), "__module__", None)
            != "builtins"
        ):
            return if_node

        # We need a ConditionalExceptionWrapper specialized for every kind of
        # exception, so add it to exceptions_rewritten to remember for later to
        # add a specialized attr with it.
        self.exceptions_rewritten.add(exc_type)

        # From here we definitely should be able to do the replacement. Create a
        # Call node to the ConditionalExceptionWrapper module we're replacing
        # the If with, with args set as the If's condition and the string of the
        # exception. The call to the self._conditional_exception_wrapper_*Error
        # module is safe because the RewrittenModule will add it as an attr
        # based on the returned exceptions_rewritten, and we assume we are
        # currently modifying the AST of a method from a RewrittenModule.
        exc_wrapper_node = ast.parse(
            f"self.{_get_exception_wrapper_attr_name(exc_type)}()", mode="eval"
        )
        assert isinstance(exc_wrapper_node, ast.Expression)
        exc_wrapper_call_node = exc_wrapper_node.body
        assert isinstance(exc_wrapper_call_node, ast.Call)
        if isinstance(if_node.test, ast.BoolOp) and isinstance(
            if_node.test.op, ast.And
        ):
            self.exceptions_bool_rewritten.add(exc_type)
            bool_wrapper_node = ast.parse(
                f"self.{_get_exception_wrapper_attr_name(exc_type)}_bool()", mode="eval"
            )
            assert isinstance(bool_wrapper_node, ast.Expression)
            bool_wrapper_call_node = bool_wrapper_node.body
            assert isinstance(bool_wrapper_call_node, ast.Call)
            bool_wrapper_call_node.args = if_node.test.values
            exc_wrapper_call_node.args = [
                _reuse_loc(bool_wrapper_call_node),
                exc_msg,
            ]
        else:
            exc_wrapper_call_node.args = [if_node.test, exc_msg]

        # Ensure that the new node conforms to the Python AST grammar
        expr_wrapper = _reuse_loc(ast.Expr(_reuse_loc(exc_wrapper_call_node)))

        # Return the new node to signify that we want to use it as a replacement
        # for the original `If x: Raise y` pattern.
        return expr_wrapper
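

# Illustrative sketch (not part of the original file): a quick demonstration of
# what Acc_Rewriter.rewrite produces. `_demo_acc_rewriter` and `example_fn` are
# hypothetical names used only for this example; nothing here calls them.
def _demo_acc_rewriter() -> None:
    def example_fn(self, x):
        # After rewriting, this assert becomes torch._assert(x is not None, ...).
        assert x is not None, "x must not be None"
        # After rewriting, this if/raise becomes a call to
        # self._conditional_exception_wrapper_ValueError(x is None, "bad x").
        if x is None:
            raise ValueError("bad x")
        return x

    new_fn, excs, bool_excs = Acc_Rewriter().rewrite(example_fn)
    # The caller (see RewrittenModule below) registers a matching
    # ConditionalExceptionWrapper attr for each rewritten exception type.
    assert excs == {ValueError} and not bool_excs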


class ConditionalExceptionWrapper(nn.Module):
    """
    This wrapper class is used to wrap conditional raising of exceptions during
    rewriting. For example:

    .. code-block:: python

        if self.name != "x":
            raise AssertionError(f"Name was not x: {self.name}")

    is rewritten into

    .. code-block:: python

        self._conditional_exception_wrapper_AssertionError(
            self.name != "x", f"Name was not x: {self.name}"
        )

    Note that __init__ takes the Exception class that it is wrapping, while
    forward takes the condition to check and the message for the exception.
    """

    # Mark as impure so that calls to it will not be removed during DCE.
    _is_impure = True

    def __init__(self, exc: Type[Exception]):
        super().__init__()
        self.exc = exc

    def forward(self, cond: bool, msg: str):
        if cond:
            raise self.exc if msg is None else self.exc(msg)


class ConditionalExceptionBoolCondWrapper(nn.Module):
    """
    This is a wrapper class for boolean ops used inside conditionals that raise
    exceptions. It currently only handles the binary-input case for the `and`
    operator at one level of depth. For example:

    .. code-block:: python

        if self.name != "x" and self.name != "y":
            raise AssertionError(f"Name was not x: {self.name}")

    rewrites the `self.name != "x" and self.name != "y"` condition using a
    `_conditional_exception_wrapper_AssertionError_bool` module as follows:

    .. code-block:: python

        self._conditional_exception_wrapper_AssertionError(
            self._conditional_exception_wrapper_AssertionError_bool(
                self.name != "x", self.name != "y"
            ),
            f"Name was not x: {self.name}",
        )
    """

    # Mark as impure so that calls to it will not be removed during DCE.
    _is_impure = True

    def __init__(self, op):
        super().__init__()

    def forward(self, *conds: Iterable):
        return all(conds)
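

# Illustrative sketch (not part of the original file): direct calls showing the
# runtime behavior of the two wrapper modules. `_demo_exception_wrappers` is a
# hypothetical helper that is never invoked by the tracer itself.
def _demo_exception_wrappers() -> None:
    wrapper = ConditionalExceptionWrapper(AssertionError)
    wrapper(False, "never raised")  # condition is False, so nothing happens
    try:
        wrapper(True, "cond was true")
    except AssertionError as e:
        assert str(e) == "cond was true"

    # The bool wrapper simply and-reduces whatever conditions it is given; its
    # `op` constructor arg is currently unused.
    bool_wrapper = ConditionalExceptionBoolCondWrapper(None)
    assert bool_wrapper(1 == 1, 2 == 2) is True
    assert bool_wrapper(1 == 1, 2 == 3) is False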


# Custom tracer that traces to the functional level and rewrites asserts and
# exceptions.
class AccRewritingTracer(Tracer):
    # Add an explicit check for mutable operations, which break symbolic tracing.
    check_mutable_operations = True

    # Disable proxying buffers, which currently breaks some quantization code
    proxy_buffer_attributes = False

    # Note: Treat ConditionalExceptionWrapper as a leaf so that we don't
    # trace into it, because it contains control flow and raises an exception.
    DEFAULT_LEAF_MODULE_LIST = {
        ConditionalExceptionBoolCondWrapper,
        ConditionalExceptionWrapper,
        torch.nn.quantized.Linear,
        torch.nn.quantized.Conv2d,
        torch.nn.intrinsic.quantized.ConvReLU2d,
        jit.ScriptModule,
        jit.RecursiveScriptModule,
        torch.nn.modules.activation.MultiheadAttention,
    }

    def is_leaf_module(self, m: nn.Module, mod_qual_name: str) -> bool:
        return getattr(m, "_base_class_origin", type(m)) in self.leaf_module_list

    def trace(
        self,
        root: nn.Module,
        concrete_args: Optional[Dict[str, Any]] = None,
        ast_rewriter_allow_list: Optional[Set] = None,
        leaf_module_list: Optional[Set] = None,
    ) -> Tuple[Graph, nn.Module]:
        self.leaf_module_list = self.DEFAULT_LEAF_MODULE_LIST
        if leaf_module_list:
            self.leaf_module_list.update(leaf_module_list)
        rewritten = _rewrite(root, ast_rewriter_allow_list, self.leaf_module_list)
        return super().trace(rewritten, concrete_args), rewritten

    # Override TracerBase's method.
    def create_node(
        self,
        kind: str,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
    ) -> Node:
        """
        Inserts a graph node given target, args, kwargs, and name.

        This method can be overridden to do extra checking, validation, or
        modification of values used in node creation. For example, one might
        want to disallow in-place operations from being recorded.
        """
        # Heuristic for detecting in-place ops: a trailing (but not leading)
        # underscore in the target name, e.g. `add_`.
        if type(target) != str:
            name_target = target.__name__
        else:
            name_target = target

        allow_list = ["and_", "or_"]  # python operator.and_, operator.or_
        if (
            name_target[-1] == "_"
            and name_target[0] != "_"
            and name_target not in allow_list
            and kind != "placeholder"
        ):
            raise RuntimeError(
                f"Tried to trace mutable operation {name_target}. FX only supports functional code"
            )

        return self.graph.create_node(kind, target, args, kwargs, name, type_expr)
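

# Illustrative sketch (not part of the original file): shows the create_node
# mutable-op check rejecting an in-place method during tracing.
# `_demo_mutable_op_check` and `_InPlaceModule` are hypothetical names used
# only for this example.
def _demo_mutable_op_check() -> None:
    class _InPlaceModule(nn.Module):
        def forward(self, x):
            return x.add_(1)  # in-place: trailing (non-leading) underscore

    try:
        AccRewritingTracer().trace(_InPlaceModule())
        raise AssertionError("expected the tracer to reject the in-place op")
    except RuntimeError as e:
        assert "mutable operation add_" in str(e)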


# List of modules that need rewriting to be supported for tracing.
DEFAULT_REWRITE_ALLOW_LIST = {
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
}


def _rewrite(
    mod_to_rewrite: nn.Module,
    allow_list: Optional[Set] = None,
    leaf_module_list: Optional[Set] = None,
) -> nn.Module:
    if allow_list is None:
        allow_list = DEFAULT_REWRITE_ALLOW_LIST
    else:
        allow_list = allow_list.union(DEFAULT_REWRITE_ALLOW_LIST)

    if not leaf_module_list:
        leaf_module_list = set()

    # Rewrite this module's functions as well as all recursive modules'
    # functions that are attrs of this module. Return the new, rewritten module
    # hierarchy.
    def rewrite_module(m: nn.Module):
        if isinstance(m, jit.ScriptModule):
            # ScriptModule cannot be rewritten, so bypass it. The issue is that
            # it requires explicitly calling its `__init__()`; calling
            # `nn.Module.__init__()` in the derived `RewrittenModule` is not
            # enough. And even if we init it we can't do much with it.
            return m

        # If m is an already-rewritten RewrittenModule, then use the original base class.
        base_class: Type[nn.Module] = getattr(m, "_base_class_origin", type(m))

        # Keep track of all the ConditionalExceptionWrappers that the
        # Acc_Rewriter calls into in this module so we can add them in init
        # below.
        all_added_wrappers: Set[Type[Exception]] = set()
        all_added_bool_wrappers: Set[Type[Exception]] = set()

        # Note: Make this a subclass of our base class.
        class RewrittenModule(base_class):  # type: ignore[valid-type, misc]
            # Keep track of the base_class so that symbolic tracing can
            # determine what kind of module this originally was later on.
            _base_class_origin = base_class

            # Add a suffix to the qualname so it's easier to debug the origin
            # of this module.
            __qualname__ = f"{base_class.__qualname__}__AccRewrittenModule"

            # Write all of the non-dunder or special methods from base_class
            # into RewrittenModule.
            for method_name in dir(base_class):
                method = getattr(base_class, method_name, None)
                if method is None and method_name not in {"__doc__"}:
                    _LOGGER.warning(
                        f"{__qualname__} does not have attribute {method_name}"
                    )

                if builtins.type(method) is not FunctionType:
                    continue

                # Always skip rewriting dunder methods, as they haven't (yet) been
                # problematic, and modifying them has caused issues previously.
                if method_name.startswith("__") and method_name.endswith("__"):
                    continue

                # Only rewrite those Modules explicitly in the allow_list.
                assert allow_list is not None
                if base_class not in allow_list:
                    vars()[method_name] = method
                else:
                    (
                        vars()[method_name],
                        added_wrappers,
                        added_bool_wrappers,
                    ) = Acc_Rewriter().rewrite(method)
                    all_added_wrappers.update(added_wrappers)
                    all_added_bool_wrappers.update(added_bool_wrappers)

            def __init__(self, orig):
                nn.Module.__init__(self)

                # Iterate over all added exception wrappers and add
                # ConditionalExceptionWrapper attrs for each.
                for exc_type in all_added_wrappers:
                    wrapper_name = _get_exception_wrapper_attr_name(exc_type)
                    assert not hasattr(self, wrapper_name)
                    setattr(
                        self,
                        wrapper_name,
                        ConditionalExceptionWrapper(exc_type),
                    )

                for exc_type in all_added_bool_wrappers:
                    wrapper_name = f"{_get_exception_wrapper_attr_name(exc_type)}_bool"
                    assert not hasattr(self, wrapper_name)
                    setattr(
                        self,
                        wrapper_name,
                        ConditionalExceptionBoolCondWrapper(exc_type),
                    )

                # Recursively rewrite and copy all module attrs of this module.
                for k, v in orig.__dict__.items():
                    if k == "_modules":
                        for mod_k, mod_v in v.items():
                            if getattr(mod_v, "_base_class_origin", type(mod_v)) in leaf_module_list:  # type: ignore[operator]
                                _LOGGER.info(
                                    f"Skip rewriting leaf module {type(mod_v)}"
                                )
                                self._modules[mod_k] = mod_v
                            else:
                                self._modules[mod_k] = rewrite_module(mod_v)
                    else:
                        self.__dict__[k] = v

        # Add a suffix to the name so it's easier to debug the origin of this module.
        RewrittenModule.__name__ = f"{base_class.__name__}__AccRewrittenModule"
        return RewrittenModule(m)

    return rewrite_module(mod_to_rewrite)
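

# Illustrative sketch (not part of the original file): _rewrite wraps each
# module in a RewrittenModule subclass whose name records the original class,
# which helps when debugging rewritten hierarchies. `_demo_rewrite_naming` is a
# hypothetical helper.
def _demo_rewrite_naming() -> None:
    rewritten = _rewrite(nn.BatchNorm2d(4))
    assert type(rewritten).__name__ == "BatchNorm2d__AccRewrittenModule"
    # The original class is preserved so is_leaf_module and repeated rewrites
    # can recover what the module originally was.
    assert rewritten._base_class_origin is nn.BatchNorm2d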


def _remove_assertions(gm: torch.fx.GraphModule) -> bool:
    """
    Unconditionally removes all assertions found in GraphModule gm.
    Returns whether the graph is modified.
    """
    changed = False
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch._assert:
            gm.graph.erase_node(node)
            changed = True
    return changed


def _remove_exceptions(gm: torch.fx.GraphModule) -> bool:
    """
    Unconditionally removes all call_modules to ConditionalExceptionWrappers
    found in GraphModule gm. Returns whether the graph is modified.
    """
    changed = False
    for node in reversed(gm.graph.nodes):
        if node.op == "call_module" and (
            isinstance(gm.get_submodule(node.target), ConditionalExceptionWrapper)
            or isinstance(
                gm.get_submodule(node.target), ConditionalExceptionBoolCondWrapper
            )
        ):
            gm.graph.erase_node(node)
            changed = True
    return changed
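

# Illustrative sketch (not part of the original file): assertion nodes traced
# via torch._assert can be stripped wholesale. `_demo_remove_assertions` and
# `_AssertModel` are hypothetical names used only for this example.
def _demo_remove_assertions() -> None:
    class _AssertModel(nn.Module):
        def forward(self, x):
            torch._assert(x.sum() > 0, "expected positive sum")
            return x + 1

    gm = torch.fx.symbolic_trace(_AssertModel())
    assert _remove_assertions(gm)  # the torch._assert node is found and erased
    gm.graph.eliminate_dead_code()  # drops the now-dead sum/gt nodes
    gm.recompile()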


def _replace_tensor_meta_with_rank(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        if node.op != "output" and "tensor_meta" in node.meta:
            node.meta["tensor_rank"] = acc_utils.map_tensor_metadata(
                node.meta["tensor_meta"], lambda x: len(x.shape)
            )
            del node.meta["tensor_meta"]


def _replace_transpose_last_dims_impl(
    transpose_node: torch.fx.Node,
) -> bool:
    transpose_input_node = transpose_node.args[0]
    dim0 = cast(int, transpose_node.args[1])
    dim1 = cast(int, transpose_node.args[2])
    changed = False

    def _calculate_dim(
        transpose_dim: Union[torch.fx.Node, int]
    ) -> Union[torch.fx.Node, int]:
        nonlocal transpose_input_node
        nonlocal changed
        if isinstance(transpose_dim, torch.fx.Node):
            # The transpose dim is itself a sub node.
            if not (
                transpose_dim.op == "call_function"
                and transpose_dim.target == operator.sub
                and len(transpose_dim.args) == 2
            ):
                return transpose_dim
            # Check the validity of the len() call and the subtracted int.
            len_node = transpose_dim.args[0]
            sub_value = transpose_dim.args[1]
            if not (
                isinstance(len_node, torch.fx.Node)
                and len_node.target == len
                and isinstance(sub_value, int)
            ):
                return transpose_dim
            getattr_node = len_node.args[0]
            # Check that the nodes correspond to `input.shape`.
            if not (
                isinstance(getattr_node, torch.fx.Node)
                and getattr_node.target == getattr
                and len(getattr_node.args) == 2
                and getattr_node.args[0] == transpose_input_node
                and getattr_node.args[1] == "shape"
            ):
                return transpose_dim
            changed = True
            rank = transpose_input_node.meta["tensor_rank"]
            return rank - sub_value
        return transpose_dim

    dim0 = _calculate_dim(dim0)
    dim1 = _calculate_dim(dim1)
    if changed:
        with transpose_node.graph.inserting_before(transpose_node):
            new_transpose_node = transpose_node.graph.call_method(
                "transpose", (transpose_input_node, dim0, dim1)
            )
            new_transpose_node.meta = transpose_node.meta.copy()
            transpose_node.replace_all_uses_with(new_transpose_node)
    return changed


# Allows mapping for transpose in the case where inputs are of the form
# x.transpose(a, b), where a and b are expressions of the form
# len(x.shape) - n for some int n. In this case the inputs to transpose are
# nodes rather than ints, so this pass replaces those nodes with their directly
# calculated integer values.
def _replace_transpose_last_dims(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        if node.op == "call_method" and node.target == "transpose":
            if len(node.args) != 3:
                continue
            changed = _replace_transpose_last_dims_impl(node)
            if changed:
                gm.graph.eliminate_dead_code()
                gm.graph.lint()
                gm.recompile()
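

# Illustrative sketch (not part of the original file): the source pattern this
# pass targets. `_LastDimsTranspose` is a hypothetical example module; it is
# only defined here, not traced.
class _LastDimsTranspose(nn.Module):
    def forward(self, x):
        # Under symbolic tracing, both dim args become sub-graph nodes of the
        # form len(x.shape) - n rather than plain ints. For a rank-3 input,
        # _replace_transpose_last_dims folds them back into the ints 1 and 2,
        # leaving a plain x.transpose(1, 2) call in the graph.
        return x.transpose(len(x.shape) - 2, len(x.shape) - 1)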


def rewriter_base_trace(
    mod,
    ast_rewriter_allow_list,
    leaf_module_list,
    concrete_args: Optional[Dict[str, Any]] = None,
):
    rewritten_graph, rewritten_mod = AccRewritingTracer().trace(
        mod,
        concrete_args,
        ast_rewriter_allow_list=ast_rewriter_allow_list,
        leaf_module_list=leaf_module_list,
    )

    assert isinstance(rewritten_mod, nn.Module)
    # Note: use the rewritten_mod here as the root. This is necessary because
    # RewrittenModule includes a new module for the ConditionalExceptionWrapper.
    return torch.fx.GraphModule(rewritten_mod, rewritten_graph)


def trace(
    mod: nn.Module,
    sample_inputs: Sequence[Any],
    remove_assertions: bool = True,
    remove_exceptions: bool = True,
    use_acc_normalization: bool = True,
    ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None,
    leaf_module_list: Optional[Set[Type[nn.Module]]] = None,
    acc_normalization_block_list: Optional[
        Set[Tuple[str, Union[str, Callable]]]
    ] = None,
    dont_retrace_gm: bool = False,
    concrete_args: Optional[Dict[str, Any]] = None,
) -> torch.fx.GraphModule:
    """
    Performs tracing and arg normalization specialized for accelerator lowering.

    It first rewrites the AST of the module's methods (and all attr methods
    recursively) to transform untraceable parts of the module into traceable
    ones.

    It then traces to the functional level so that optimizations and backend
    accelerator importers have the ability to see and/or change inputs to each
    op.

    It then removes assertions and exception wrappers found during symbolic
    tracing if requested based on remove_assertions and remove_exceptions.

    Dead code is then eliminated, which will e.g. remove any nodes that were
    only used by assertions or exceptions if they were removed.

    It then performs normalization on args/kwargs, moving any arg that can be
    a kwarg to kwarg form, and then making default values explicit.

    Args:
        mod (Module): The module to transform and trace.
        sample_inputs (Sequence[Any]):
            Sample inputs with which to run shape prop.
        remove_assertions (bool): Whether to remove assertion nodes from
            the graph after symbolic tracing.
        remove_exceptions (bool): Whether to remove exception wrapper nodes
            from the graph after symbolic tracing.
        use_acc_normalization (bool): Whether to apply acc-specific
            normalization to all acc_ops.
        ast_rewriter_allow_list (Optional[Set[nn.Module]]): Optional allow list of
            modules that need AST rewriting.
        leaf_module_list (Optional[Set[nn.Module]]): Optional leaf module list where
            modules will not be traced into.
        acc_normalization_block_list (Optional[Set[Tuple[str, Union[str, Callable]]]]):
            Optional set of (op, target) pairs to not apply acc
            normalization to. Just like the register_acc_op decorators,
            the target can either be a string (e.g. for op == "call_method")
            or a callable (e.g. for op == "call_function").
        dont_retrace_gm (bool): Optional bool for whether to re-trace the provided
            module if it's a graph module already.
        concrete_args (Optional[Dict[str, Any]]): Concrete arguments that should
            not be treated as Proxies, forwarded to the underlying FX tracer.
    """
    if mod.training:
        warnings.warn(
            "acc_tracer does not currently support models in training mode."
            " Calling eval on the model before tracing."
        )
        mod.eval()

    assert isinstance(sample_inputs, (list, tuple))

    # Rewrite the module to make it symbolically traceable, and then trace it.
    if dont_retrace_gm and isinstance(mod, torch.fx.GraphModule):
        traced = mod
    else:
        traced = rewriter_base_trace(
            mod, ast_rewriter_allow_list, leaf_module_list, concrete_args
        )

    # Now remove all assertions and exceptions if requested.
    if remove_assertions:
        _remove_assertions(traced)
    if remove_exceptions:
        _remove_exceptions(traced)

    # Clean up any dead code from the original module as well as resulting dead
    # nodes after removing assertions and exceptions.
    traced.graph.eliminate_dead_code()
    traced.recompile()

    # Run shape prop to add node.meta["type"] to nodes, needed for NormalizeArgs.
    acc_shape_prop.AccShapeProp(traced).propagate(*sample_inputs)

    # Swap out tensor_meta for tensor_rank, because we don't actually want to rely
    # on tensor_meta yet for normalization/lowering, though rank shouldn't change.
    _replace_tensor_meta_with_rank(traced)

    # Replace occurrences of x.transpose(len(x.shape) - a, len(x.shape) - b),
    # where a and b are integers, with their directly calculated dimensions.
    _replace_transpose_last_dims(traced)

    # Now normalize args/kwargs to make default values visible. Leave args/kwargs as
    # they were, since all-kwarg normalization is broken, and we don't need it anyway.
    traced = NormalizeArgs(traced, normalize_to_only_use_kwargs=False).transform()

    # Normalize to acc-specialized wrappers for consistency across op naming and
    # to ensure all-kwarg usage.
    if use_acc_normalization:
        acc_normalizer.normalize(
            traced, acc_normalization_block_list=acc_normalization_block_list
        )

    traced.recompile()

    # Run shape prop again to populate tensor_meta after normalization.
    acc_shape_prop.AccShapeProp(traced).propagate(*sample_inputs)

    return traced
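

# Illustrative usage sketch (not part of the original file): end-to-end tracing
# of a toy module. `_ExampleModel` is a hypothetical name; the allow-list entry
# is needed so its forward's assert gets AST-rewritten before symbolic tracing.
# Run as a module (python -m ...) so the relative imports above resolve.
if __name__ == "__main__":
    class _ExampleModel(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            assert x.dim() == 2, "expected a 2D input"
            return torch.relu(x) + 1

    model = _ExampleModel().eval()
    gm = trace(
        model,
        sample_inputs=[torch.randn(3, 4)],
        ast_rewriter_allow_list={_ExampleModel},
    )
    # With remove_assertions=True (the default), the torch._assert node is gone
    # and only the functional ops remain, normalized to acc_ops.
    print(gm.graph)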