From 46fd4503be080f0fa22dcb6519878ff35274ef36 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sat, 27 Apr 2024 11:38:32 -0700 Subject: [PATCH] [dynamo][eval_frame] Create a dynamic wrapper fn to avoid cache collisions ghstack-source-id: bd9fe3a96748bba33cc066909b414b433087bd80 Pull Request resolved: https://github.com/pytorch/pytorch/pull/124881 --- test/dynamo/test_functions.py | 14 ++++++++++++++ torch/_dynamo/external_utils.py | 23 ++++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 3b20c7294eeca..e3be5a1c2137c 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -2761,6 +2761,20 @@ def fn(x, ys, zs): with self.assertRaisesRegex(ValueError, "zip()"): opt_fn(x, ys[:1], zs) + def test_external_utils_wrapper(self): + def fn1(x): + return torch.sin(x) + + def fn2(x): + return torch.cos(x) + + opt_fn1 = torch.compile(torch._dynamo.external_utils.wrap_inline(fn1)) + opt_fn2 = torch.compile(torch._dynamo.external_utils.wrap_inline(fn2)) + + opt_fn1(torch.randn(4)) + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + opt_fn2(torch.randn(4)) + instantiate_parametrized_tests(FunctionTests) diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index 87e1e44fbe1aa..6620013881662 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -1,6 +1,7 @@ # This module contains functions that *will be allowed* by dynamo import functools +import types import torch import torch.utils._pytree as pytree @@ -26,6 +27,24 @@ def is_compiling() -> bool: return torch.compiler.is_compiling() +def create_new_fn(fn): + from .bytecode_transformation import transform_code_object + + def nothing(*args): + pass + + new_code = transform_code_object(fn.__code__, nothing) + new_fn = types.FunctionType( + new_code, + fn.__globals__, + fn.__name__, + fn.__defaults__, + fn.__closure__, + ) + new_fn.__kwdefaults__ = fn.__kwdefaults__ + return new_fn + + def wrap_inline(fn): """ Create an extra frame around fn that is not in skipfiles @@ -35,7 +54,9 @@ def wrap_inline(fn): def inner(*args, **kwargs): return fn(*args, **kwargs) - return inner + # Create a new function dynamically to avoid Dynamo cache collisions on the + # same fn.__code__ object. + return create_new_fn(inner) def call_hook(hook, *args):