-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
remote_module.py
387 lines (312 loc) · 14.9 KB
/
remote_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
#!/usr/bin/python3
import types
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)
import torch
import torch.distributed.rpc as rpc
from torch import Tensor, device, dtype, nn
from torch.distributed.nn.jit import instantiator
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle
# Gradient type used in the ``register_backward_hook`` hook signature below:
# either a single Tensor or a tuple of Tensors.
_grad_t = Union[Tuple[Tensor, ...], Tensor]
# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
# the type of the subclass, not the looser type of `Module`.
# The bound names ``nn.Module`` directly: this file only imports ``nn``, so a bare
# "Module" forward reference could not be resolved by a type checker.
T = TypeVar("T", bound=nn.Module)
# Template used by every RemoteModule created without a TorchScript interface
# (``_module_interface_cls is None``). It is produced once, when this file is
# imported; its ``_generated_methods`` are installed on each such instance.
_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = instantiator.instantiate_non_scriptable_remote_module_template()
# RPC handler: invoked remotely (via ``rpc.rpc_async``) so the scriptable template
# for ``module_interface_cls`` exists on the worker that will own the module.
def _instantiate_template(module_interface_cls):
    """Instantiate the scriptable RemoteModule template on the current worker.

    The generated template object is deliberately discarded; only the side
    effect of instantiating it on this worker matters to the caller, which
    waits on the RPC future before creating the module here.
    """
    _ = instantiator.instantiate_scriptable_remote_module_template(module_interface_cls)
    return None
def _create_module(module_cls, args, kwargs, device="cpu", module_interface_cls=None):
    """Build ``module_cls(*args, **kwargs)`` on this worker and return an RRef to it.

    Invoked on the destination worker through ``rpc.rpc_sync``.

    Args:
        module_cls: callable (typically an ``nn.Module`` subclass) producing the module.
        args: positional arguments forwarded to ``module_cls``.
        kwargs: keyword arguments forwarded to ``module_cls``.
        device: passed to ``Module.to`` to place the created module.
        module_interface_cls: optional TorchScript interface type. When given, the
            module is compiled with ``torch.jit.script`` and the RRef is typed with it.

    Returns:
        An ``rpc.RRef`` holding the (possibly scripted) module.

    Raises:
        ValueError: if ``module_cls(*args, **kwargs)`` is not an ``nn.Module``.
    """
    created = module_cls(*args, **kwargs)
    if isinstance(created, nn.Module):
        if module_interface_cls is not None:
            created = torch.jit.script(created)
        created.to(device)
        return rpc.RRef(created, module_interface_cls)
    raise ValueError(
        "Expect `module_cls(*args, **kwargs)` returns an instance of <class nn.Module>, "
        f"but it returns an instance of {type(created)}."
    )
def _param_rrefs(module_rref, recurse):
    """Wrap each parameter of the locally owned module in an ``rpc.RRef``.

    Invoked through ``rpc.rpc_sync`` on the worker that owns ``module_rref``;
    ``RRef.local_value()`` is only valid on the owner.

    Args:
        module_rref: RRef to the module, owned by the current worker.
        recurse: forwarded to ``nn.Module.parameters``; if True, parameters of
            all submodules are included as well.

    Returns:
        A list with one ``rpc.RRef`` per parameter, in ``parameters()`` order.
    """
    return [rpc.RRef(param) for param in module_rref.local_value().parameters(recurse)]
def _raise_not_supported(name):
    """Raise ``ValueError`` saying that method ``name`` is unavailable on RemoteModule.

    Shared by the ``nn.Module`` overrides of ``_RemoteModule`` that are disabled.
    """
    message = f"Method ``{name}`` not supported for RemoteModule"
    raise ValueError(message)
class _RemoteModule(nn.Module):
    def __init__(
        self,
        remote_device: str,
        module_cls: Callable[..., nn.Module],
        args: Optional[Tuple] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        _module_interface_cls: Any = None,
    ):
        """
        A RemoteModule instance can only be created after RPC initialization.
        It creates a user-specified module on a specified remote node.
        It behaves like a regular ``nn.Module`` except that the ``forward`` method is
        executed on the remote node.
        It takes care of autograd recording to ensure the backward pass propagates
        gradients back to the corresponding remote module.

        The arguments of ``forward_async`` and ``forward`` are the same as
        the ``forward`` method of the module returned by the ``module_cls``.

        Apart from ``forward_async``, ``forward`` and ``remote_parameters`` (which returns
        RRefs to the remote module's parameters), no other methods are supported from
        nn.Module for now; the overridden ``nn.Module`` methods raise ``ValueError``.
        In particular, to create a hybrid model, typically the local modules should be
        created outside of remote modules, rather than as submodules of any remote module
        (by calling ``add_module``).

        Hybrid Example:
            >>> class HybridModel(nn.Module):
            >>>     def __init__(self):
            >>>         nn.Module.__init__(self)
            >>>         self.remote_embedding = RemoteModule(...)
            >>>         self.local_linear = nn.Linear(...)

        For example, if ``module_cls`` returns an instance of ``nn.Linear``,
        that has ``forward`` method signature, ``def forward(input: Tensor) -> Tensor:``,
        the generated ``RemoteModule`` will have 2 methods in signature of
        ``def forward(input: Tensor) -> Tensor:`` and
        ``def forward_async(input: Tensor) -> Future[Tensor]:``.

        Arguments:
            remote_device (str): Device on the destination worker where we'd like to place this module.
                The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
                E.g., "trainer0/cpu", "ps0/cuda:0".
            module_cls (nn.Module): Class (or other callable) that returns the ``nn.Module``
                instance to create remotely. For example,

                >>> class MyModule(nn.Module):
                >>>     def forward(self, input):
                >>>         return input + 1
                >>>
                >>> module_cls = MyModule

            args (Sequence, optional): args to be passed to ``module_cls``.
            kwargs (Dict, optional): kwargs to be passed to ``module_cls``.
            _module_interface_cls (type, optional): The TorchScript interface type for the module
                to be created. The type object should be decorated by @torch.jit.interface.
                If not provided, the generated RemoteModule is not torchscript-able.
                Warning, this is an experimental API and susceptible to frequent changes.

        Returns:
            A remote module instance which wraps the :class:`~nn.Module` created by the
            user-provided ``module_cls``, it has a blocking ``forward`` method and an
            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
            on the user-provided module on the remote side.

        Raises:
            AssertionError: if no current RPC agent is set (checked with ``assert``).
            ValueError: if ``remote_device`` is not of the form "<workername>/<device>".

        Example::
            Run the following code in two different processes:

            >>> # On worker 0:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>> from torch import nn, Tensor
            >>> from torch.distributed.nn.api.remote_module import RemoteModule
            >>>
            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
            >>> remote_linear_module = RemoteModule(
            >>>     "worker1/cpu", nn.Linear, args=(20, 30),
            >>> )
            >>> input = torch.randn(128, 20)
            >>> ret_fut = remote_linear_module.forward_async(input)
            >>> ret = ret_fut.wait()
            >>> rpc.shutdown()

            >>> # On worker 1:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>>
            >>> rpc.init_rpc("worker1", rank=1, world_size=2)
            >>> rpc.shutdown()
        """
        super().__init__()

        # Sanity checks.
        assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC."

        # Default arguments preparation.
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        # Split "<workername>/<device>". A malformed value still raises ValueError,
        # but with a descriptive message instead of a bare tuple-unpacking error.
        remote_device_parts = remote_device.split("/")
        if len(remote_device_parts) != 2:
            raise ValueError(
                f"Could not parse remote_device: {remote_device}. "
                'The valid format is "<workername>/<device>"'
            )
        self.on, self.device = remote_device_parts

        if _module_interface_cls is not None:
            # Users rely on this field to know if this generated RemoteModule is TorchScript-able.
            self.is_scriptable = True

            # Instantiate template on remote side (asynchronously, overlapping with
            # the local instantiation below).
            fut = rpc.rpc_async(
                self.on, _instantiate_template, (_module_interface_cls,)
            )

            # Instantiate template on local side.
            generated_module = (
                instantiator.instantiate_scriptable_remote_module_template(
                    _module_interface_cls
                )
            )
            generated_methods = generated_module._generated_methods

            # Ensure the template is available on the remote side before the
            # module is created there.
            fut.wait()
        else:
            self.is_scriptable = False
            generated_methods = _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods

        # Create the module on the remote side.
        self.module_rref = rpc.rpc_sync(
            self.on,
            _create_module,
            (module_cls, args, kwargs, self.device, _module_interface_cls),
        )

        # Install generated methods (e.g. ``forward`` / ``forward_async``) as bound
        # methods of this instance, marked with ``torch.jit.export``.
        for method in generated_methods:
            method_name = method.__name__
            method = torch.jit.export(method)
            setattr(self, method_name, types.MethodType(method, self))

    def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
        r"""Returns a list of RRefs of remote module parameters.

        This is typically passed to a distributed optimizer.

        Args:
            recurse (bool): if True, then returns parameters of the remote module
                and all submodules of the remote module.
                Otherwise, returns only parameters that are direct members of the remote module.

        Returns:
            A list of RRefs to remote module parameters.
        """
        return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse))

    # The standard ``nn.Module`` APIs below are not supported on this local proxy:
    # each override raises ``ValueError`` (via ``_raise_not_supported`` except for
    # ``parameters``, which points users at ``remote_parameters``).

    def register_buffer(
        self, name: str, tensor: Optional[Tensor], persistent: bool = True
    ) -> None:
        _raise_not_supported(self.register_buffer.__name__)

    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        _raise_not_supported(self.register_parameter.__name__)

    def add_module(self, name: str, module: Optional[nn.Module]) -> None:
        _raise_not_supported(self.add_module.__name__)

    def apply(self: T, fn: Callable[[nn.Module], None]) -> T:
        _raise_not_supported(self.apply.__name__)

    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
        _raise_not_supported(self.cuda.__name__)

    def cpu(self: T) -> T:
        _raise_not_supported(self.cpu.__name__)

    def type(self: T, dst_type: Union[dtype, str]) -> T:
        _raise_not_supported(self.type.__name__)

    def float(self: T) -> T:
        _raise_not_supported(self.float.__name__)

    def double(self: T) -> T:
        _raise_not_supported(self.double.__name__)

    def half(self: T) -> T:
        _raise_not_supported(self.half.__name__)

    def bfloat16(self: T) -> T:
        _raise_not_supported(self.bfloat16.__name__)

    def to(self, *args, **kwargs):
        _raise_not_supported(self.to.__name__)

    def register_backward_hook(
        self, hook: Callable[[nn.Module, _grad_t, _grad_t], Union[None, Tensor]]
    ) -> RemovableHandle:
        _raise_not_supported(self.register_backward_hook.__name__)

    def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
        _raise_not_supported(self.register_forward_pre_hook.__name__)

    def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
        _raise_not_supported(self.register_forward_hook.__name__)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        _raise_not_supported(self.state_dict.__name__)

    def load_state_dict(
        self,
        state_dict: Dict[str, Tensor],
        strict: bool = True,
    ):
        _raise_not_supported(self.load_state_dict.__name__)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        raise ValueError(
            "Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead."
        )

    def named_parameters(
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, Tensor]]:
        _raise_not_supported(self.named_parameters.__name__)

    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
        _raise_not_supported(self.buffers.__name__)

    def named_buffers(
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, Tensor]]:
        _raise_not_supported(self.named_buffers.__name__)

    def children(self) -> Iterator[nn.Module]:
        _raise_not_supported(self.children.__name__)

    def named_children(self) -> Iterator[Tuple[str, nn.Module]]:
        _raise_not_supported(self.named_children.__name__)

    def modules(self) -> Iterator[nn.Module]:
        _raise_not_supported(self.modules.__name__)

    def named_modules(self, memo: Optional[Set[nn.Module]] = None, prefix: str = ""):
        _raise_not_supported(self.named_modules.__name__)

    def train(self: T, mode: bool = True) -> T:
        _raise_not_supported(self.train.__name__)

    def eval(self: T) -> T:
        _raise_not_supported(self.eval.__name__)

    def requires_grad_(self: T, requires_grad: bool = True) -> T:
        _raise_not_supported(self.requires_grad_.__name__)

    def zero_grad(self) -> None:
        _raise_not_supported(self.zero_grad.__name__)

    def share_memory(self: T) -> T:
        _raise_not_supported(self.share_memory.__name__)

    def extra_repr(self) -> str:
        _raise_not_supported(self.extra_repr.__name__)
class RemoteModule(_RemoteModule):
    """
    A RemoteModule instance can only be created after RPC initialization.
    It creates a user-specified module on a specified remote node.
    It behaves like a regular ``nn.Module`` except that the ``forward`` method is
    executed on the remote node.
    It takes care of autograd recording to ensure the backward pass propagates
    gradients back to the corresponding remote module.

    The arguments of ``forward_async`` and ``forward`` are the same as
    the ``forward`` method of the module returned by the ``module_cls``.

    For example, if ``module_cls`` returns an instance of ``nn.Linear``,
    that has ``forward`` method signature, ``def forward(input: Tensor) -> Tensor:``,
    the generated ``RemoteModule`` will have 2 methods in signature of
    ``def forward(input: Tensor) -> Tensor:`` and
    ``def forward_async(input: Tensor) -> Future[Tensor]:``.

    Parameters of the remote module are exposed as RRefs through
    ``remote_parameters``; most other ``nn.Module`` methods raise ``ValueError``.

    Arguments:
        remote_device (str): Device on the destination worker where we'd like to place this module.
            The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
            E.g., "trainer0/cpu", "ps0/cuda:0".
        module_cls (nn.Module): Class (or other callable) that returns the ``nn.Module``
            instance to create remotely. For example,

            >>> class MyModule(nn.Module):
            >>>     def forward(self, input):
            >>>         return input + 1
            >>>
            >>> module_cls = MyModule

        args (Sequence, optional): args to be passed to ``module_cls``.
        kwargs (Dict, optional): kwargs to be passed to ``module_cls``.

    Returns:
        A remote module instance which wraps the :class:`~nn.Module` created by the
        user-provided ``module_cls``, it has a blocking ``forward`` method and an
        asynchronous ``forward_async`` method that returns a future of the ``forward`` call
        on the user-provided module on the remote side.

    Example::
        Run the following code in two different processes:

        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> from torch import nn, Tensor
        >>> from torch.distributed.nn.api.remote_module import RemoteModule
        >>>
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> remote_linear_module = RemoteModule(
        >>>     "worker1/cpu", nn.Linear, args=(20, 30),
        >>> )
        >>> input = torch.randn(128, 20)
        >>> ret_fut = remote_linear_module.forward_async(input)
        >>> ret = ret_fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>>
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """

    def __init__(
        self,
        remote_device: str,
        module_cls: Callable[..., nn.Module],
        args: Optional[Tuple] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Public constructor: identical to ``_RemoteModule`` except that the
        # experimental ``_module_interface_cls`` argument is not exposed.
        super().__init__(remote_device, module_cls, args, kwargs)