RRef proxy support for ScriptModule methods (#48339)
Summary:
Pull Request resolved: #48339

Closes #48294
#48293 added creation and transfer of ScriptModules over RPC in Python, but the RRef proxy helpers (`rpc_sync`, `rpc_async`, `remote`) still did not work with ScriptModule methods.

This PR makes the RRef proxy work with ScriptModule, per a discussion with mrshenli:
1) We remove the `hasattr()` check and let Python raise the exception it normally would when the function is accessed with `getattr`.
2) We condition on `issubclass(rref_type, ScriptModule)` when checking whether the function is wrapped with `async_function`, because `ScriptModule` does not implement `getattr` for this lookup: a ScriptModule's forward is not a plain Python function but a TorchScript-specific one, so class-level attribute access fails (a minimal reproduction follows the traceback below):
```
torch/jit/_script.py", line 229, in __get__
    return self.__getattr__("forward")  # type: ignore
AttributeError: '_CachedForward' object has no attribute '__getattr__'
```
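
For reference, this failure is reproducible without any RPC involved. A minimal sketch, assuming only an old-style `torch.jit.ScriptModule` with a `@torch.jit.script_method` (the class itself is illustrative):

```python
import torch

class MyScriptModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self) -> torch.Tensor:
        return torch.ones(1)

# On ScriptModule, `forward` is a _CachedForward descriptor rather than a plain
# Python function, so looking it up on the *class* (which is what the removed
# hasattr()/getattr() check in rref_proxy.py effectively did) raises instead of
# returning a function object.
try:
    getattr(MyScriptModule, "forward")
except AttributeError as e:
    print(e)  # '_CachedForward' object has no attribute '__getattr__'
```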
ghstack-source-id: 117631795

Test Plan: Modified unit tests.

Reviewed By: wanchaol

Differential Revision: D25134423

fbshipit-source-id: 918ca88891c7b0531325f046b61f28947575cff0
rohan-varma authored and facebook-github-bot committed Dec 4, 2020
1 parent fadec77 commit a5fb12d
Showing 3 changed files with 19 additions and 21 deletions.
torch/distributed/rpc/rref_proxy.py (12 changes: 5 additions & 7 deletions)
```diff
@@ -15,13 +15,11 @@ def _invoke_rpc(rref, rpc_api, func_name, *args, **kwargs):
     rref_type = rref._get_type()
 
     _invoke_func = _local_invoke
-    if rref_type is not torch._C.ScriptModule:
-        if not hasattr(rref_type, func_name):
-            raise ValueError(
-                f"Function {func_name} is not an attribute of type {rref_type} "
-                f"referenced by RRef {rref}."
-            )
-
+    # Bypass ScriptModules when checking for async function attribute.
+    bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
+        rref_type, torch._C.ScriptModule
+    )
+    if not bypass_type:
         func = getattr(rref_type, func_name)
         if hasattr(func, "_wrapped_async_rpc_function"):
             _invoke_func = _local_invoke_async_execution
```
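Since `_invoke_rpc` backs all three RRef proxy flavors, this change enables direct remote method calls on ScriptModules. A caller-side sketch, assuming `remote_script_module` is an RRef to a ScriptModule owned by another worker (the name is illustrative, mirroring the test below):

```python
# Illustrative only; `remote_script_module` is an RRef whose value lives on
# another worker.
out = remote_script_module.rpc_sync().forward()     # blocks until the result arrives
fut = remote_script_module.rpc_async().forward()    # returns a torch.futures.Future
out_async = fut.wait()
res_rref = remote_script_module.remote().forward()  # returns an RRef to the result
out_remote = res_rref.to_here()
```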
torch/testing/_internal/distributed/rpc/jit/rpc_test.py (20 changes: 10 additions & 10 deletions)
```diff
@@ -21,9 +21,6 @@
     RpcAgentTestFixture,
 )
 
-def run(rref, func_name, args, kwargs):
-    return getattr(rref.local_value(), func_name)(*args, **kwargs)
-
 def rref_isinstance(rref, cls_to_check):
     return isinstance(rref.local_value(), cls_to_check)
 
@@ -362,6 +359,10 @@ def __init__(self, rank):
     def forward(self) -> Tensor:
         return self.a
 
+    @torch.jit.script_method
+    def custom_func(self) -> Tensor:
+        return self.a
+
 
 def owner_create_rref_my_script_class(a):
     return rpc.RRef(MyScriptClass(a))
@@ -973,20 +974,19 @@ def test_create_script_module_on_remote(self):
         )
         self.assertTrue(remote_end_is_script)
         # Run forward pass remotely.
-        # TODO: make RRef helper work with ScriptModule.
-        remote_forward_output = rpc.rpc_sync(
-            remote_script_module.owner(),
-            run,
-            args=(remote_script_module, "forward", (), {}),
-        )
+        remote_forward_output = remote_script_module.rpc_sync().forward()
         self.assertEqual(remote_forward_output, torch.ones(self.rank))
-
+        # Run function defined on ScriptModule remotely.
+        remote_func_output = remote_script_module.rpc_sync().custom_func()
+        self.assertEqual(remote_func_output, torch.ones(self.rank))
         # Ensure we can transfer ScriptModule RRef to this rank and run
         # forward pass.
         local_script_module = remote_script_module.to_here()
         self.assertTrue(isinstance(local_script_module, torch.jit.ScriptModule))
         rank_ones_tensor = local_script_module()
         self.assertEqual(rank_ones_tensor, torch.ones(self.rank))
 
+        local_script_func_output = local_script_module.custom_func()
+        self.assertEqual(local_script_func_output, torch.ones(self.rank))
 
     @dist_init
     def test_load_script_module_with_pickled_rref(self):
```
torch/testing/_internal/distributed/rpc/rpc_test.py (8 changes: 4 additions & 4 deletions)
```diff
@@ -606,14 +606,14 @@ def test_self_remote_rref_as_remote_arg(self):
     def test_rref_proxy_non_exist(self):
         dst = worker_name((self.rank + 1) % self.world_size)
         rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))
-        msg = "non_exist is not an attribute of type"
-        with self.assertRaisesRegex(ValueError, msg):
+        msg = "has no attribute \'non_exist\'"
+        with self.assertRaisesRegex(AttributeError, msg):
             rref.rpc_sync().non_exist()
 
-        with self.assertRaisesRegex(ValueError, msg):
+        with self.assertRaisesRegex(AttributeError, msg):
             rref.rpc_async().non_exist()
 
-        with self.assertRaisesRegex(ValueError, msg):
+        with self.assertRaisesRegex(AttributeError, msg):
             rref.remote().non_exist()
 
     def _test_rref_proxy_tensor(self, dst):
```
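The user-visible consequence of dropping the `hasattr()` check, sketched with the same fixtures the test uses (`dst` and `my_function` come from the test above):

```python
# Sketch only; `dst` and `my_function` are the test fixtures from above.
rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))

# Before this PR: ValueError: Function non_exist is not an attribute of type ...
# After this PR:  AttributeError: ... has no attribute 'non_exist'
rref.rpc_sync().non_exist()
```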
