-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
rref_proxy.py
40 lines (30 loc) · 1.17 KB
/
rref_proxy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from functools import partial
from . import functions
import torch
def _local_invoke(rref, func_name, args, kwargs):
return getattr(rref.local_value(), func_name)(*args, **kwargs)
@functions.async_execution
def _local_invoke_async_execution(rref, func_name, args, kwargs):
    """Call ``func_name`` on ``rref``'s local object, via ``async_execution``.

    The body matches ``_local_invoke``. ``_invoke_rpc`` picks this variant
    when the target method carries a ``_wrapped_async_rpc_function``
    attribute. Presumably the ``functions.async_execution`` decorator makes
    the owner treat the returned value as a future to complete
    asynchronously. TODO confirm against ``functions.async_execution``.

    Args:
        rref: an RRef whose ``local_value()`` returns the target object.
        func_name: name of the attribute to look up on that object.
        args: positional arguments for the call (unpacked with ``*``).
        kwargs: keyword arguments for the call (unpacked with ``**``).
    """
    owned_obj = rref.local_value()
    method = getattr(owned_obj, func_name)
    return method(*args, **kwargs)
def _invoke_rpc(rref, rpc_api, func_name, *args, **kwargs):
rref_type = rref._get_type()
_invoke_func = _local_invoke
# Bypass ScriptModules when checking for async function attribute.
bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
rref_type, torch._C.ScriptModule
)
if not bypass_type:
func = getattr(rref_type, func_name)
if hasattr(func, "_wrapped_async_rpc_function"):
_invoke_func = _local_invoke_async_execution
return rpc_api(
rref.owner(),
_invoke_func,
args=(rref, func_name, args, kwargs)
)
class RRefProxy:
    """Turn attribute access into RPC calls on the object behind an RRef.

    ``proxy.some_method(*args, **kwargs)`` runs
    ``_invoke_rpc(rref, rpc_api, "some_method", *args, **kwargs)``, which
    sends the call to the RRef's owner through ``rpc_api``.

    Args:
        rref: the RRef whose referenced object receives the calls.
        rpc_api: callable that ``_invoke_rpc`` uses to issue the RPC.
            Presumably an entry point such as ``rpc_sync``/``rpc_async``/
            ``remote``; confirm at the construction site.
    """

    def __init__(self, rref, rpc_api):
        self.rref = rref
        self.rpc_api = rpc_api

    def __getattr__(self, func_name):
        """Return a callable that invokes ``func_name`` remotely.

        Python calls this only for names that normal lookup does not find.
        On an initialized proxy, ``rref`` and ``rpc_api`` therefore come from
        the instance dict and never reach this method.
        """
        # Reaching here for the proxy's own attributes means __init__ never
        # ran, e.g. the instance was created with RRefProxy.__new__, as
        # copy/pickle reconstruction does. Without this guard, reading
        # self.rref below would re-enter __getattr__ endlessly and end in
        # RecursionError instead of a normal AttributeError.
        if func_name in ("rref", "rpc_api"):
            raise AttributeError(
                "'{}' object has no attribute '{}'".format(
                    type(self).__name__, func_name
                )
            )
        return partial(_invoke_rpc, self.rref, self.rpc_api, func_name)