@@ -1,4 +1,2 @@
-# mypy: ignore-errors
-
 """
 This module provides common utilities and base classes for TorchDynamo backends.
@@ -21,6 +19,9 @@
 import contextlib
 import functools
 import logging
+from collections.abc import Iterable
+from typing import Any, Callable
+from typing_extensions import ParamSpec, TypeVar
 from unittest.mock import patch

 import torch
@@ -36,13 +37,18 @@

 log = logging.getLogger(__name__)

+P = ParamSpec("P")
+R = TypeVar("R")
+

 class AotAutograd:
-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: Any) -> None:
         self.__name__ = "compiler_fn"
         self.kwargs = kwargs

-    def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
+    def __call__(
+        self, gm: torch.fx.GraphModule, example_inputs: Iterable[Any], **kwargs: Any
+    ) -> Callable[..., Any]:
         if kwargs:
             log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)

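For context (an illustration, not part of the patch): the new `__call__` annotations spell out the usual TorchDynamo backend contract, in which a backend receives an FX `GraphModule` plus example inputs and returns a callable. A minimal sketch, with `trivial_backend` as a made-up name:

    import torch

    def trivial_backend(gm: torch.fx.GraphModule, example_inputs):
        # A backend gets the traced graph and example inputs, and must
        # return a callable with the same calling convention as gm.
        return gm.forward

    compiled = torch.compile(lambda x: x + 1, backend=trivial_backend)
    compiled(torch.randn(3))  # dispatches through trivial_backend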
@@ -66,16 +72,16 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
             counters["aot_autograd"]["not_ok"] += 1
             return gm

-        def wrap_bw_compiler(bw_compiler_fn):
-            def _wrapped_bw_compiler(*args, **kwargs):
+        def wrap_bw_compiler(bw_compiler_fn: Callable[P, R]) -> Callable[..., R]:
+            def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R:
                 # Note [Wrapping bw_compiler in disable]
                 # The two disables here:
                 # - stop TorchDynamo from trying to compile the bw_compiler function itself
                 # - stop TorchDynamo from trying to compile the generated backwards pass bw_compiler produces
                 return disable(
                     disable(
                         bw_compiler_fn, reason="do not trace backward compiler function"
-                    )(*args, **kwargs),
+                    )(*args, **kwargs),  # type: ignore[misc]
                     reason="do not trace generated backwards pass",
                 )
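The `Callable[P, R]` / `*args: P.args` / `**kwargs: P.kwargs` annotations above use the standard ParamSpec pattern for typing wrappers without erasing the wrapped function's signature. A self-contained sketch of the same pattern (illustration only):

    from typing import Callable, TypeVar
    from typing_extensions import ParamSpec

    P = ParamSpec("P")
    R = TypeVar("R")

    def passthrough(fn: Callable[P, R]) -> Callable[P, R]:
        # inner's parameters are tied to fn's real signature, so callers
        # of the wrapper still get full argument type checking.
        def inner(*args: P.args, **kwargs: P.kwargs) -> R:
            return fn(*args, **kwargs)

        return inner

`wrap_bw_compiler` itself returns `Callable[..., R]` rather than `Callable[P, R]`, presumably because the `disable`-wrapped callable no longer advertises the original parameter types.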
@@ -99,7 +105,9 @@ def _wrapped_bw_compiler(*args, **kwargs):
         # debug asserts slow down compile time noticeably,
         # so only default them on when the aot_eager backend is used.
         if self.kwargs.get("fw_compiler", None) == nop:
-            patch_config = patch("functorch.compile.config.debug_assert", True)
+            patch_config: contextlib.AbstractContextManager[Any] = patch(
+                "functorch.compile.config.debug_assert", True
+            )
         else:
             patch_config = contextlib.nullcontext()
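The explicit `contextlib.AbstractContextManager[Any]` annotation is needed because the two branches assign context managers of different concrete types. A minimal reproduction of the pattern (the `patch` target here is an arbitrary stand-in, not from the patch):

    import contextlib
    from typing import Any
    from unittest.mock import patch

    debug = True
    # patch(...) and nullcontext() are unrelated concrete types; annotating
    # with their shared protocol lets both assignments type-check.
    ctx: contextlib.AbstractContextManager[Any] = (
        patch("os.sep", "/") if debug else contextlib.nullcontext()
    )
    with ctx:
        pass  # os.sep is patched only while debug is True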
@@ -116,11 +124,11 @@ def _wrapped_bw_compiler(*args, **kwargs):
             raise


-def aot_autograd(**kwargs) -> AotAutograd:
+def aot_autograd(**kwargs: Any) -> AotAutograd:
     return AotAutograd(**kwargs)

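For orientation (assuming this module is `torch._dynamo.backends.common`, which these helpers match): `aot_autograd` is a factory for backend objects, and the built-in `aot_eager` backend is essentially `aot_autograd(fw_compiler=nop)`. A hypothetical custom backend:

    import torch
    from torch._dynamo.backends.common import aot_autograd

    def my_fw_compiler(gm, example_inputs):
        # Inspect or transform the forward graph, then return a callable.
        return gm.forward

    my_backend = aot_autograd(fw_compiler=my_fw_compiler)
    compiled = torch.compile(torch.sin, backend=my_backend)
    compiled(torch.randn(4))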
-def mem_efficient_fusion_kwargs(use_decomps):
+def mem_efficient_fusion_kwargs(use_decomps: bool) -> dict[str, Any]:
     from functorch.compile import (
         default_decompositions,
         min_cut_rematerialization_partition,
@@ -140,28 +148,30 @@ def mem_efficient_fusion_kwargs(use_decomps):
     return kwargs


-def fake_tensor_unsupported(fn):
+def fake_tensor_unsupported(fn: Callable[[Any, list[Any], Any], R]) -> Any:
     """
     Decorator for backends that need real inputs. We swap out fake
     tensors for zero tensors.
     """

     @functools.wraps(fn)
-    def wrapper(model, inputs, **kwargs):
+    def wrapper(model: Any, inputs: Any, **kwargs: Any) -> Any:
         with _disable_current_modes():
             inputs = list(map(defake, inputs))
-            return fn(model, inputs, **kwargs)
+            return fn(model, inputs, **kwargs)  # type: ignore[call-arg]

     return wrapper

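A sketch of how a backend might opt in to real inputs via this decorator (`my_real_input_backend` is a made-up name):

    from torch._dynamo.backends.common import fake_tensor_unsupported

    @fake_tensor_unsupported
    def my_real_input_backend(gm, example_inputs):
        # By the time this body runs, any FakeTensors in example_inputs
        # have been swapped for real zero tensors of matching shape.
        return gm.forward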
-def device_from_inputs(example_inputs) -> torch.device:
+def device_from_inputs(example_inputs: Iterable[Any]) -> torch.device:
     for x in example_inputs:
         if hasattr(x, "device"):
             return x.device
+    return torch.device("cpu")  # Default fallback


-def dtype_from_inputs(example_inputs) -> torch.dtype:
+def dtype_from_inputs(example_inputs: Iterable[Any]) -> torch.dtype:
     for x in example_inputs:
         if hasattr(x, "dtype"):
             return x.dtype
+    return torch.float32  # Default fallback
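A quick illustration of both helpers, including the new fallbacks for inputs that carry no device or dtype:

    import torch

    example_inputs = [torch.randn(2, 2)]
    device_from_inputs(example_inputs)  # device(type='cpu')
    dtype_from_inputs(example_inputs)   # torch.float32
    device_from_inputs([])              # now returns device(type='cpu') instead of falling off the end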