Commit f82c7ee

Lucaskabela authored and pytorchmergebot committed
Typing for common.py (#160362)
Pull Request resolved: #160362
Approved by: https://github.com/Skylion007
1 parent 25ccc47 commit f82c7ee

1 file changed (+25, -15)


torch/_dynamo/backends/common.py

Lines changed: 25 additions & 15 deletions
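
Note on the typing pattern: the diff below drops the file-wide "# mypy: ignore-errors" suppression and annotates the nested wrap_bw_compiler decorator with ParamSpec/TypeVar so the wrapped compiler function's parameter types stay visible to the type checker. A minimal, standalone sketch of that pattern follows; the log_calls decorator and add function are illustrative only and are not part of the commit:

from typing import Callable

from typing_extensions import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def log_calls(fn: Callable[P, R]) -> Callable[P, R]:
    # Forwarding *args/**kwargs as P.args/P.kwargs lets mypy keep fn's signature.
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)

    return wrapper


@log_calls
def add(x: int, y: int) -> int:
    return x + y


add(1, 2)  # type-checks as (int, int) -> int; add("1", 2) would be flagged by mypy
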
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module provides common utilities and base classes for TorchDynamo backends.
 
@@ -21,6 +19,9 @@
 import contextlib
 import functools
 import logging
+from collections.abc import Iterable
+from typing import Any, Callable
+from typing_extensions import ParamSpec, TypeVar
 from unittest.mock import patch
 
 import torch
@@ -36,13 +37,18 @@
 
 log = logging.getLogger(__name__)
 
+P = ParamSpec("P")
+R = TypeVar("R")
+
 
 class AotAutograd:
-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: Any) -> None:
         self.__name__ = "compiler_fn"
         self.kwargs = kwargs
 
-    def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
+    def __call__(
+        self, gm: torch.fx.GraphModule, example_inputs: Iterable[Any], **kwargs: Any
+    ) -> Callable[..., Any]:
         if kwargs:
             log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)
 
@@ -66,16 +72,16 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
             counters["aot_autograd"]["not_ok"] += 1
             return gm
 
-        def wrap_bw_compiler(bw_compiler_fn):
-            def _wrapped_bw_compiler(*args, **kwargs):
+        def wrap_bw_compiler(bw_compiler_fn: Callable[P, R]) -> Callable[..., R]:
+            def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R:
                 # Note [Wrapping bw_compiler in disable]
                 # The two disables here:
                 # - stop TorchDynamo from trying to compile the bw_compiler function itself
                 # - stop TorchDynamo from trying to compile our the generated backwards pass bw_compiler produces
                 return disable(
                     disable(
                         bw_compiler_fn, reason="do not trace backward compiler function"
-                    )(*args, **kwargs),
+                    )(*args, **kwargs),  # type: ignore[misc]
                     reason="do not trace generated backwards pass",
                 )
 
@@ -99,7 +105,9 @@ def _wrapped_bw_compiler(*args, **kwargs):
         # debug asserts slow down compile time noticeably,
         # So only default them on when the aot_eager backend is used.
         if self.kwargs.get("fw_compiler", None) == nop:
-            patch_config = patch("functorch.compile.config.debug_assert", True)
+            patch_config: contextlib.AbstractContextManager[Any] = patch(
+                "functorch.compile.config.debug_assert", True
+            )
         else:
             patch_config = contextlib.nullcontext()
 
@@ -116,11 +124,11 @@ def _wrapped_bw_compiler(*args, **kwargs):
             raise
 
 
-def aot_autograd(**kwargs) -> AotAutograd:
+def aot_autograd(**kwargs: Any) -> AotAutograd:
     return AotAutograd(**kwargs)
 
 
-def mem_efficient_fusion_kwargs(use_decomps):
+def mem_efficient_fusion_kwargs(use_decomps: bool) -> dict[str, Any]:
     from functorch.compile import (
         default_decompositions,
         min_cut_rematerialization_partition,
@@ -140,28 +148,30 @@ def mem_efficient_fusion_kwargs(use_decomps):
     return kwargs
 
 
-def fake_tensor_unsupported(fn):
+def fake_tensor_unsupported(fn: Callable[[Any, list[Any], Any], R]) -> Any:
     """
     Decorator for backends that need real inputs. We swap out fake
     tensors for zero tensors.
     """
 
     @functools.wraps(fn)
-    def wrapper(model, inputs, **kwargs):
+    def wrapper(model: Any, inputs: Any, **kwargs: Any) -> Any:
         with _disable_current_modes():
             inputs = list(map(defake, inputs))
-            return fn(model, inputs, **kwargs)
+            return fn(model, inputs, **kwargs)  # type: ignore[call-arg]
 
     return wrapper
 
 
-def device_from_inputs(example_inputs) -> torch.device:
+def device_from_inputs(example_inputs: Iterable[Any]) -> torch.device:
     for x in example_inputs:
         if hasattr(x, "device"):
             return x.device
+    return torch.device("cpu")  # Default fallback
 
 
-def dtype_from_inputs(example_inputs) -> torch.dtype:
+def dtype_from_inputs(example_inputs: Iterable[Any]) -> torch.dtype:
     for x in example_inputs:
         if hasattr(x, "dtype"):
             return x.dtype
+    return torch.float32  # Default fallback
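
Worth noting from the last hunk: device_from_inputs and dtype_from_inputs previously fell off the end of their loops (implicitly returning None) when no input carried a device or dtype attribute; with the annotated return types they now fall back to torch.device("cpu") and torch.float32. A small usage sketch, assuming the module path shown in the file header:

import torch

from torch._dynamo.backends.common import device_from_inputs, dtype_from_inputs

# The first input with a .device / .dtype attribute still determines the result.
example_inputs = [torch.ones(2, 3)]
assert device_from_inputs(example_inputs) == torch.device("cpu")
assert dtype_from_inputs(example_inputs) == torch.float32

# With no tensor-like inputs, the new defaults apply instead of an implicit None.
assert device_from_inputs([1, 2, 3]) == torch.device("cpu")
assert dtype_from_inputs([1, 2, 3]) == torch.float32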
