From a5fb12d168125fd228c4b36ef08a3ad4904ac457 Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Fri, 4 Dec 2020 11:31:36 -0800
Subject: [PATCH] RRef proxy support for ScriptModule methods (#48339)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48339

Closes https://github.com/pytorch/pytorch/issues/48294

https://github.com/pytorch/pytorch/pull/48293 added creation and transfer
of ScriptModule over RPC in Python, but the RRef proxy helpers
(`rref.rpc_sync()`, `rref.rpc_async()`, `rref.remote()`) still did not
work with ScriptModule methods. This PR makes those proxies work with
ScriptModule, as per a discussion with mrshenli:

1) We remove the `hasattr()` check and just let `getattr` raise the
   `AttributeError` that Python would normally throw when the function is
   not an attribute of the type.
2) We condition on `issubclass(rref_type, ScriptModule)` when checking
   whether the function is wrapped as an async function, because
   `ScriptModule` does not implement `getattr` for its methods (a
   ScriptModule forward/method is not a Python function; it is a
   TorchScript-specific function):
```
torch/jit/_script.py", line 229, in __get__
    return self.__getattr__("forward") # type: ignore
AttributeError: '_CachedForward' object has no attribute '__getattr__'
```
ghstack-source-id: 117631795

Test Plan: Modified unit tests

Reviewed By: wanchaol

Differential Revision: D25134423

fbshipit-source-id: 918ca88891c7b0531325f046b61f28947575cff0
---
 torch/distributed/rpc/rref_proxy.py           | 12 +++++------
 .../_internal/distributed/rpc/jit/rpc_test.py | 20 +++++++++----------
 .../_internal/distributed/rpc/rpc_test.py     |  8 ++++----
 3 files changed, 19 insertions(+), 21 deletions(-)

diff --git a/torch/distributed/rpc/rref_proxy.py b/torch/distributed/rpc/rref_proxy.py
index 17ce9da643b9..f087514d92a8 100644
--- a/torch/distributed/rpc/rref_proxy.py
+++ b/torch/distributed/rpc/rref_proxy.py
@@ -15,13 +15,11 @@ def _invoke_rpc(rref, rpc_api, func_name, *args, **kwargs):
     rref_type = rref._get_type()
 
     _invoke_func = _local_invoke
-    if rref_type is not torch._C.ScriptModule:
-        if not hasattr(rref_type, func_name):
-            raise ValueError(
-                f"Function {func_name} is not an attribute of type {rref_type} "
-                f"referenced by RRef {rref}."
-            )
-
+    # Bypass ScriptModules when checking for async function attribute.
+    bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
+        rref_type, torch._C.ScriptModule
+    )
+    if not bypass_type:
         func = getattr(rref_type, func_name)
         if hasattr(func, "_wrapped_async_rpc_function"):
             _invoke_func = _local_invoke_async_execution
diff --git a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
index 95d2ca860afd..2a0b114f2b8a 100644
--- a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
@@ -21,9 +21,6 @@
     RpcAgentTestFixture,
 )
 
-def run(rref, func_name, args, kwargs):
-    return getattr(rref.local_value(), func_name)(*args, **kwargs)
-
 def rref_isinstance(rref, cls_to_check):
     return isinstance(rref.local_value(), cls_to_check)
 
@@ -362,6 +359,10 @@ def __init__(self, rank):
     def forward(self) -> Tensor:
         return self.a
 
+    @torch.jit.script_method
+    def custom_func(self) -> Tensor:
+        return self.a
+
 def owner_create_rref_my_script_class(a):
     return rpc.RRef(MyScriptClass(a))
 
@@ -973,20 +974,19 @@ def test_create_script_module_on_remote(self):
         )
         self.assertTrue(remote_end_is_script)
         # Run forward pass remotely.
-        # TODO: make RRef helper work with ScriptModule.
-        remote_forward_output = rpc.rpc_sync(
-            remote_script_module.owner(),
-            run,
-            args=(remote_script_module, "forward", (), {}),
-        )
+        remote_forward_output = remote_script_module.rpc_sync().forward()
         self.assertEqual(remote_forward_output, torch.ones(self.rank))
+        # Run function defined on ScriptModule remotely.
+        remote_func_output = remote_script_module.rpc_sync().custom_func()
+        self.assertEqual(remote_func_output, torch.ones(self.rank))
         # Ensure we can transfer ScriptModule RRef to this rank and run
         # forward pass.
         local_script_module = remote_script_module.to_here()
         self.assertTrue(isinstance(local_script_module, torch.jit.ScriptModule))
         rank_ones_tensor = local_script_module()
         self.assertEqual(rank_ones_tensor, torch.ones(self.rank))
-
+        local_script_func_output = local_script_module.custom_func()
+        self.assertEqual(local_script_func_output, torch.ones(self.rank))
 
     @dist_init
     def test_load_script_module_with_pickled_rref(self):
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index ba7f0d650b22..46dbacc3c2eb 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -606,14 +606,14 @@ def test_self_remote_rref_as_remote_arg(self):
     def test_rref_proxy_non_exist(self):
         dst = worker_name((self.rank + 1) % self.world_size)
         rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))
-        msg = "non_exist is not an attribute of type"
-        with self.assertRaisesRegex(ValueError, msg):
+        msg = "has no attribute \'non_exist\'"
+        with self.assertRaisesRegex(AttributeError, msg):
             rref.rpc_sync().non_exist()
-        with self.assertRaisesRegex(ValueError, msg):
+        with self.assertRaisesRegex(AttributeError, msg):
             rref.rpc_async().non_exist()
-        with self.assertRaisesRegex(ValueError, msg):
+        with self.assertRaisesRegex(AttributeError, msg):
             rref.remote().non_exist()
 
     def _test_rref_proxy_tensor(self, dst):
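
For reference, the end-to-end pattern the updated tests exercise looks roughly like the sketch below. This is a minimal illustration, not code from the PR: `MyScriptModule`, `create_module`, and the worker name are hypothetical, and it assumes each participating process has already initialized the RPC framework with `rpc.init_rpc`.
```
import torch
import torch.distributed.rpc as rpc
from torch import Tensor


class MyScriptModule(torch.jit.ScriptModule):
    """Hypothetical stand-in for the ScriptModule used in the tests."""

    def __init__(self, rank):
        super().__init__()
        self.a = torch.ones(rank)

    @torch.jit.script_method
    def forward(self) -> Tensor:
        return self.a


def create_module(rank):
    # Executed on the owner worker; the returned module is held by an RRef.
    return MyScriptModule(rank)


# On the caller, assuming rpc.init_rpc(...) has run and "worker1" is a peer:
# rref = rpc.remote("worker1", create_module, args=(2,))
#
# With this patch, the RRef proxies resolve ScriptModule methods directly:
# out = rref.rpc_sync().forward()    # blocking call executed on the owner
# fut = rref.rpc_async().forward()   # returns a Future
# out2 = fut.wait()
#
# A missing method now surfaces as the plain AttributeError from getattr:
# rref.rpc_sync().non_exist()        # raises AttributeError
```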