Skip to content

Commit

Permalink
[dynamo][eval_frame] Create a dynamic wrapper fn to avoid cache colli…
Browse files Browse the repository at this point in the history
…sions

ghstack-source-id: bd9fe3a96748bba33cc066909b414b433087bd80
Pull Request resolved: #124881
  • Loading branch information
anijain2305 committed Apr 27, 2024
1 parent 230983a commit 46fd450
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
14 changes: 14 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2761,6 +2761,20 @@ def fn(x, ys, zs):
with self.assertRaisesRegex(ValueError, "zip()"):
opt_fn(x, ys[:1], zs)

def test_external_utils_wrapper(self):
def fn1(x):
return torch.sin(x)

def fn2(x):
return torch.cos(x)

opt_fn1 = torch.compile(torch._dynamo.external_utils.wrap_inline(fn1))
opt_fn2 = torch.compile(torch._dynamo.external_utils.wrap_inline(fn2))

opt_fn1(torch.randn(4))
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
opt_fn2(torch.randn(4))


instantiate_parametrized_tests(FunctionTests)

Expand Down
23 changes: 22 additions & 1 deletion torch/_dynamo/external_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This module contains functions that *will be allowed* by dynamo

import functools
import types

import torch
import torch.utils._pytree as pytree
Expand All @@ -26,6 +27,24 @@ def is_compiling() -> bool:
return torch.compiler.is_compiling()


def create_new_fn(fn):
from .bytecode_transformation import transform_code_object

def nothing(*args):
pass

new_code = transform_code_object(fn.__code__, nothing)
new_fn = types.FunctionType(
new_code,
fn.__globals__,
fn.__name__,
fn.__defaults__,
fn.__closure__,
)
new_fn.__kwdefaults__ = fn.__kwdefaults__
return new_fn


def wrap_inline(fn):
"""
Create an extra frame around fn that is not in skipfiles
Expand All @@ -35,7 +54,9 @@ def wrap_inline(fn):
def inner(*args, **kwargs):
return fn(*args, **kwargs)

return inner
# Create a new function dynamically to avoid Dynamo cache collisions on the
# same fn.__code__ object.
return create_new_fn(inner)


def call_hook(hook, *args):
Expand Down

0 comments on commit 46fd450

Please sign in to comment.