Tweak dynamic=False behavior (#105715)
Previously, dynamic=False was a no-op, and dynamic=True preemptively
turned on dynamic shapes everywhere.

Now, dynamic=False *disables* automatic dynamic, and an unset dynamic
defaults to dynamic=None (which uses automatic dynamic). This
seems more intuitive per
#105634 (comment)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #105715
Approved by: https://github.com/voznesenskym
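For context when reading the diff below, here is a minimal sketch of what the three settings mean for torch.compile after this change; the function and sizes are illustrative and not taken from the PR:

import torch

def f(x):
    return x * 2

# dynamic=None (the new default): compile statically first, then recompile
# with a more dynamic kernel once a size change is observed (automatic dynamic).
f_auto = torch.compile(f, dynamic=None)

# dynamic=True: up-front try to compile a kernel that is as dynamic as possible.
f_dynamic = torch.compile(f, dynamic=True)

# dynamic=False: never generate dynamic kernels; every new size specializes.
f_static = torch.compile(f, dynamic=False)

for size in (2, 3, 4):
    f_auto(torch.randn(size))
    f_dynamic(torch.randn(size))
    f_static(torch.randn(size))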
ezyang authored and pytorchmergebot committed Jul 24, 2023
1 parent 0ab7404 commit 3045e84
Showing 3 changed files with 45 additions and 21 deletions.
24 changes: 21 additions & 3 deletions test/dynamo/test_recompiles.py
@@ -13,12 +13,12 @@ def test_automatic_dynamic_reduce_recompiles(self):
         def foo(x, y):
             return x * y

-        def run_foo_6_times_and_count_recompiles():
+        def run_foo_6_times_and_count_recompiles(dynamic=None):
             cnt = torch._dynamo.testing.CompileCounter()

             x = torch.randn([2])
             y = torch.randn([2])
-            opt = torch._dynamo.optimize(cnt)(foo)
+            opt = torch._dynamo.optimize(cnt, dynamic=dynamic)(foo)
             opt(x, y)
             x = torch.randn([3])
             y = torch.randn([3])
@@ -51,9 +51,21 @@ def run_with_automatic():
         self.assertEqual(without.frame_count, 5)
         self.assertEqual(without.op_count, 5)
         torch._dynamo.reset()
+        without = run_foo_6_times_and_count_recompiles(dynamic=False)
+        self.assertEqual(without.frame_count, 5)
+        self.assertEqual(without.op_count, 5)
+        torch._dynamo.reset()
         with_automatic = run_with_automatic()
         self.assertEqual(with_automatic.frame_count, 2)
         self.assertEqual(with_automatic.op_count, 2)
+        torch._dynamo.reset()
+        with_automatic = run_foo_6_times_and_count_recompiles(dynamic=None)
+        self.assertEqual(with_automatic.frame_count, 2)
+        self.assertEqual(with_automatic.op_count, 2)
+        torch._dynamo.reset()
+        with_dynamic = run_foo_6_times_and_count_recompiles(dynamic=True)
+        self.assertEqual(with_dynamic.frame_count, 1)
+        self.assertEqual(with_dynamic.op_count, 1)

     @patch.object(torch._dynamo.config, "assume_static_by_default", True)
     def test_recompiles_true_false_flop(self):
@@ -98,7 +110,7 @@ def run_with_automatic():
             return run_foo_6_times_and_count_recompiles()

         without = run_without_automatic()
-        self.assertEqual(without.frame_count, 2)
+        self.assertEqual(without.frame_count, 5)
         self.assertEqual(without.op_count, 5)
         torch._dynamo.reset()
         with_automatic = run_with_automatic()
@@ -210,3 +222,9 @@ def foo(a):
         self.assertEqual(cmp_result, eager_result)
         # Recompile, alias changed
         self.assertEqual(cnt.frame_count, 2)
+
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()
9 changes: 5 additions & 4 deletions torch/__init__.py
@@ -1588,7 +1588,7 @@ def __call__(self, model_, inputs_):

 def compile(model: Optional[Callable] = None, *,
             fullgraph: builtins.bool = False,
-            dynamic: builtins.bool = False,
+            dynamic: Optional[builtins.bool] = None,
             backend: Union[str, Callable] = "inductor",
             mode: Union[str, None] = None,
             options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None,
@@ -1610,12 +1610,13 @@ def compile(model: Optional[Callable] = None, *,
     Args:
        model (Callable): Module/function to optimize
        fullgraph (bool): Whether it is ok to break model into several subgraphs
-       dynamic (bool): Use dynamic shape tracing. When this is True, we will up-front attempt
+       dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt
         to generate a kernel that is as dynamic as possible to avoid recompilations when
         sizes change. This may not always work as some operations/optimizations will
         force specialization; use TORCH_LOGS=dynamic to debug overspecialization.
-        In particular, if you use "reduce-overhead", this will force sizes to be static
-        even with dynamic=True.
+        When this is False, we will NEVER generate dynamic kernels, we will always specialize.
+        By default (None), we automatically detect if dynamism has occurred and compile a more
+        dynamic kernel upon recompile.
        backend (str or Callable): backend to be used
        - "inductor" is the default backend, which is a good balance between performance and overhead
        - Non experimental in-tree backends can be seen with `torch._dynamo.list_backends()`
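Since the docstring points at TORCH_LOGS=dynamic for debugging overspecialization, here is a hedged sketch of how one might use it; the module and shapes are illustrative only:

# Run with the dynamic-shapes logging artifact enabled, e.g.:
#   TORCH_LOGS=dynamic python this_script.py
# to see which guards force a size to specialize even under dynamic=True.
import torch

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, dynamic=True)

compiled(torch.randn(4, 8))
compiled(torch.randn(16, 8))  # if this triggers a recompile, the logs explain why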
33 changes: 19 additions & 14 deletions torch/_dynamo/eval_frame.py
@@ -185,15 +185,18 @@ def innermost_fn(fn):


 @contextlib.contextmanager
-def enable_dynamic(enable: bool = True, export: bool = False):
-    if not enable:
-        yield
-        return
-    # dynamic=True used to mean fully dynamic. However, with automatic dynamic, the default flipped to
-    # deriving dynamism. For back compat, and forward compat for when dynamic=True is default, we take
-    # dynamic=True here to mean "fully dynamic from the start".
-    with config.patch(assume_static_by_default=False):
-        yield
+def enable_dynamic(enable: Optional[bool] = None, export: bool = False):
+    if enable is None:
+        yield
+    elif enable:
+        # Assume everything is dynamic by deafult
+        with config.patch(assume_static_by_default=False):
+            yield
+    else:
+        with config.patch(
+            automatic_dynamic_shapes=False, assume_static_by_default=True
+        ):
+            yield


 class _TorchDynamoContext:
@@ -206,7 +209,7 @@ def __init__(
         first_ctx=False,
         *,
         export=False,
-        dynamic=False,
+        dynamic=None,
         compiler_config=None,
     ):
         super().__init__()
@@ -379,7 +382,7 @@ def __init__(
         first_ctx=False,
         *,
         export=False,
-        dynamic=False,
+        dynamic=None,
         compiler_config=None,
     ):
         def on_enter():
@@ -475,7 +478,7 @@ def _optimize_catch_errors(
     hooks: Hooks,
     backend_ctx_ctor=null_context,
     export=False,
-    dynamic=False,
+    dynamic=None,
     compiler_config=None,
 ):
     return OptimizeContext(
@@ -529,7 +532,7 @@ def optimize(
     guard_export_fn=None,
     guard_fail_fn=None,
     disable=False,
-    dynamic=False,
+    dynamic=None,
 ):
     """
     The main entrypoint of TorchDynamo. Do graph capture and call
@@ -547,7 +550,9 @@ def optimize(
         nopython: If True, graph breaks will be errors and there will
             be a single whole-program graph.
         disable: If True, turn this decorator into a no-op
-        dynamic: If True, turn on dynamic shapes support
+        dynamic: If True, upfront compile as dynamic a kernel as possible. If False,
+            disable all dynamic shapes support (always specialize). If None, automatically
+            detect when sizes vary and generate dynamic kernels upon recompile.
     Example Usage::
@@ -1169,7 +1174,7 @@ def optimize_assert(
     hooks=Hooks(None, None),
     export=False,
     export_constraints=None,
-    dynamic=False,
+    dynamic=None,
 ):
     """
     The same as `torch._dynamo.optimize(backend, nopython=True)`
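As a usage note, the enable_dynamic context manager above suggests which torch._dynamo.config flags each setting maps to; the following is a minimal sketch, assuming these flags are patchable exactly as shown in the diff:

import torch._dynamo as dynamo

# Roughly what dynamic=True selects: assume shapes are dynamic from the start.
with dynamo.config.patch(assume_static_by_default=False):
    pass  # compile and run the model here

# Roughly what dynamic=False selects: disable automatic dynamic, always specialize.
with dynamo.config.patch(automatic_dynamic_shapes=False, assume_static_by_default=True):
    pass  # compile and run the model here

# dynamic=None leaves both flags at their defaults (automatic dynamic).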
