[fx] make fx.wrap idempotent #104838

Closed · wants to merge 1 commit
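For context, a minimal sketch of the call pattern this change targets. The example module below is hypothetical; torch.fx.wrap and torch.fx.symbolic_trace are the real APIs, and the test added in this PR exercises exactly this pattern with the builtin len.

import torch
import torch.fx

# Module-level registration; this line may execute more than once, for
# example if the containing module is re-imported or re-executed.
torch.fx.wrap("len")
torch.fx.wrap("len")  # with this change, the second call is a no-op rather than a duplicate entry

class ReturnsLen(torch.nn.Module):  # hypothetical module for illustration
    def forward(self, x):
        return len(x)  # recorded as a call to the wrapped builtin during tracing

gm = torch.fx.symbolic_trace(ReturnsLen())
print(gm.code)  # the generated forward still calls len(x) at runtime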
24 changes: 24 additions & 0 deletions test/package/test_package_fx.py
@@ -18,6 +18,9 @@
# Support the case where we run this file directly.
from common import PackageTestCase

torch.fx.wrap("len")
# Do it twice to make sure it doesn't affect anything
torch.fx.wrap("len")

class TestPackageFX(PackageTestCase):
"""Tests for compatibility with FX."""
@@ -162,6 +165,27 @@ def __init__(self, root, graph, info):
        input_x = torch.randn(3)
        self.assertEqual(loaded_gm(input_x), gm(input_x))

    def test_package_fx_wrap(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a):
                return len(a)

        traced = torch.fx.symbolic_trace(TestModule())

        f = BytesIO()
        with torch.package.PackageExporter(f) as pe:
            pe.save_pickle("model", "model.pkl", traced)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_traced = pi.load_pickle("model", "model.pkl")
        input = torch.rand(2, 3)
        self.assertEqual(loaded_traced(input), traced(input))



if __name__ == "__main__":
    run_tests()
12 changes: 7 additions & 5 deletions torch/fx/_symbolic_trace.py
@@ -839,9 +839,11 @@ def __deepcopy__(self, memo):
        return new_tracer


-# List of pairs of (global dict, function name) functions
-# to patch for the purposes of the wrap() API.
-_wrapped_fns_to_patch: List[Tuple[dict, str]] = []
+# Dictionary of (id(globals dict), function name) => globals_dict to patch for
+# the purposes of the wrap() API.
+# We key by the globals dict id and function name to ensure we're wrapping a given
+# function only once.
+_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {}

# List of methods on classes to wrap (class type, function name)
# this currently only works for Tensor.* methods that aren't traced properly
@@ -1002,7 +1004,7 @@ def _patch_wrapped_functions(patcher: _Patcher):
    Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap
    the listed global functions in the `_create_wrapped_func` wrapper.
    """
-    for frame_dict, name in _wrapped_fns_to_patch:
+    for (_, name), frame_dict in _wrapped_fns_to_patch.items():
        if name not in frame_dict and hasattr(builtins, name):
            orig_fn = getattr(builtins, name)
        else:
@@ -1088,7 +1090,7 @@ def my_custom_function(x, y):

    # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search
    # semantics would be slightly different, but would add support `from x import wrapped_function`
-    _wrapped_fns_to_patch.append((f.f_globals, fn_name))
+    _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals
    return fn_or_name


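To make the new keying scheme concrete, here is a standalone sketch; the names below are illustrative and not code from the PR. Because the registry key is (id(globals_dict), function_name), wrapping the same name from the same module repeatedly overwrites a single entry, while wrapping the same name from a different module still adds a second entry.

# Standalone sketch of the (id(globals), name) keying used by the new
# _wrapped_fns_to_patch dict; registry, register, and the module dicts are hypothetical.
from typing import Dict, Tuple

registry: Dict[Tuple[int, str], dict] = {}

def register(frame_globals: dict, fn_name: str) -> None:
    # Same module + same name => same key, so repeated calls overwrite in place.
    registry[(id(frame_globals), fn_name)] = frame_globals

mod_a_globals = {"__name__": "mod_a"}
mod_b_globals = {"__name__": "mod_b"}

register(mod_a_globals, "len")
register(mod_a_globals, "len")   # idempotent: still one entry for mod_a
register(mod_b_globals, "len")   # distinct module => distinct entry

assert len(registry) == 2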