-
Notifications
You must be signed in to change notification settings - Fork 21.3k
/
remote_module_test.py
435 lines (375 loc) · 15.7 KB
/
remote_module_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
#!/usr/bin/python3
import enum
from typing import Tuple
import torch
import torch.distributed.rpc as rpc
import torch.testing._internal.dist_utils as dist_utils
from torch import Tensor, nn
from torch._jit_internal import Future
from torch.distributed.nn import RemoteModule
from torch.distributed.nn.api.remote_module import _RemoteModule
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
# Shared parameter assigned to ``MyModule.param1``; ``test_remote_parameters``
# compares the parameter fetched from the remote worker against this value.
_PARAM_VAL = torch.nn.Parameter(torch.ones(1))
# RPC handler for querying the device on the destination worker.
def remote_device(module_rref):
    """Return the device of the first parameter of the referenced module.

    Args:
        module_rref: Object whose ``local_value()`` returns a module that
            exposes ``parameters()``.

    Returns:
        The ``device`` attribute of the first parameter yielded by
        ``parameters()``, or ``None`` when no parameter is yielded.
    """
    # A private sentinel distinguishes "no parameters" from any yielded value.
    nothing = object()
    first_param = next(iter(module_rref.local_value().parameters()), nothing)
    if first_param is nothing:
        return None
    return first_param.device
class ModuleCreationMode(enum.Enum):
    """Ways in which the tests construct a remote module."""

    # ``_RemoteModule`` built from ``create_scripted_module`` with
    # ``MyModuleInterface`` as the interface class, then scripted.
    MODULE_CTOR_WITH_INTERFACE = "module_ctor_with_interface"
    # ``RemoteModule`` built directly from the ``MyModule`` constructor.
    MODULE_CTOR = "module_ctor"
# TorchScript interface whose ``forward`` signature matches ``MyModule.forward``.
# It is passed as ``_module_interface_cls`` when the scripted remote module is
# created, and it is the argument type of the scripted ``run_forward`` helper
# in ``test_forward_sync_script``.
@torch.jit.interface
class MyModuleInterface:
    def forward(
        self, tensor: Tensor, number: int, word: str = "default"
    ) -> Tuple[str, int, Tensor]:
        pass
# TorchScript interface declaring both ``forward`` and ``forward_async``; the
# latter has the same parameters but returns a ``Future`` of the result tuple.
# It is the argument type of the scripted ``run_forward_async`` helper in
# ``test_forward_async_script``.
@torch.jit.interface
class RemoteMyModuleInterface:
    def forward(
        self, tensor: Tensor, number: int, word: str = "default"
    ) -> Tuple[str, int, Tensor]:
        pass

    def forward_async(
        self, tensor: Tensor, number: int, word: str = "default"
    ) -> Future[Tuple[str, int, Tensor]]:
        pass
class MyModule(nn.Module):
    """Module under test: ``forward`` hands its arguments back in reverse order.

    The constructor arguments are accepted but not stored; the only state is
    ``param1``, which is set to the module-level ``_PARAM_VAL``.
    """

    def __init__(self, first_arg, first_kwarg=-1):
        super().__init__()
        self.param1 = _PARAM_VAL

    # Comments are used instead of a docstring here because this method is
    # compiled by ``torch.jit.script`` in ``create_scripted_module``.
    def forward(
        self,
        tensor: Tensor,
        number: int,
        word: str = "default",
    ) -> Tuple[str, int, Tensor]:
        return (word, number, tensor)
class BadModule:
    """Plain class that is not an ``nn.Module``.

    ``test_bad_module`` passes it to ``RemoteModule`` and expects a
    ``ValueError``.
    """

    def __init__(self, first_arg, first_kwarg=-1):
        """Accept the same arguments as ``MyModule`` and ignore them."""
def create_scripted_module(first_arg, first_kwarg=-1):
    """Build a ``MyModule`` and return it compiled with ``torch.jit.script``.

    Args:
        first_arg: Forwarded positionally to ``MyModule``.
        first_kwarg: Forwarded to ``MyModule`` as ``first_kwarg``.

    Returns:
        The scripted module.
    """
    return torch.jit.script(MyModule(first_arg, first_kwarg=first_kwarg))
class RemoteModuleTest(RpcAgentTestFixture):
    """Tests for ``RemoteModule`` driven over RPC between two workers.

    In every test only rank 0 runs the assertions; other ranks return from
    the test body immediately. The destination worker is always
    ``(rank + 1) % world_size``.
    """

    @property
    def world_size(self):  # Override setting in RpcAgentTestFixture
        return 2

    @staticmethod
    def _create_remote_module_iter(dst_worker_name, device="cpu", modes=None):
        """Yield one remote module per requested creation mode.

        Args:
            dst_worker_name: Name of the worker on which the module is created.
            device: Device string passed to the remote module constructor.
            modes: Collection of ``ModuleCreationMode`` members to create;
                ``None`` means every member.

        Yields:
            For ``MODULE_CTOR``, a ``RemoteModule`` wrapping ``MyModule``.
            For ``MODULE_CTOR_WITH_INTERFACE``, a scripted ``_RemoteModule``
            built from ``create_scripted_module``. The ``MODULE_CTOR`` module,
            when requested, is always yielded first.
        """
        if modes is None:
            modes = ModuleCreationMode.__members__.values()

        args = (1,)
        kwargs = dict(first_kwarg=2)

        if ModuleCreationMode.MODULE_CTOR in modes:
            remote_module = RemoteModule(
                dst_worker_name, device, MyModule, args, kwargs
            )
            yield remote_module

        if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes:
            remote_module = _RemoteModule(
                dst_worker_name,
                device,
                create_scripted_module,
                args,
                kwargs,
                _module_interface_cls=MyModuleInterface,
            )
            scripted_remote_module = torch.jit.script(remote_module)
            yield scripted_remote_module

    @dist_utils.dist_init
    def test_bad_module(self):
        """Constructing a RemoteModule from a non-``nn.Module`` class raises ``ValueError``."""
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        args = (1,)
        kwargs = dict(first_kwarg=2)

        # The original repeated this exact check twice; one copy is enough.
        with self.assertRaisesRegex(
            ValueError,
            r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
        ):
            RemoteModule(dst_worker_name, "cpu", BadModule, args, kwargs)

    @dist_utils.dist_init
    def test_forward_async(self):
        """``forward_async`` returns a future whose result is the reversed arguments."""
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        args = (torch.ones(1), 2, "3")
        for remote_module in self._create_remote_module_iter(dst_worker_name):
            ret_fut = remote_module.forward_async(*args)
            ret = ret_fut.wait()
            self.assertEqual(ret, tuple(reversed(args)))

    @dist_utils.dist_init
    def test_forward_async_script(self):
        """``forward_async`` can be called from inside a scripted function."""
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        scripted_remote_module = next(
            self._create_remote_module_iter(
                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
            )
        )

        @torch.jit.script
        def run_forward_async(scripted_remote_module: RemoteMyModuleInterface):
            ret_fut = scripted_remote_module.forward_async(torch.ones(1), 2, "3")
            ret = ret_fut.wait()
            return ret

        ret = run_forward_async(scripted_remote_module)

        self.assertEqual(ret, ("3", 2, torch.ones(1)))

    @dist_utils.dist_init
    def test_forward_sync(self):
        """``forward`` returns the reversed arguments directly."""
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        args = (torch.ones(1), 2, "3")
        for remote_module in self._create_remote_module_iter(dst_worker_name):
            ret = remote_module.forward(*args)
            self.assertEqual(ret, tuple(reversed(args)))

    @dist_utils.dist_init
    def test_forward_sync_script(self):
        """``forward`` can be called from inside a scripted function."""
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        scripted_remote_module = next(
            self._create_remote_module_iter(
                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
            )
        )

        @torch.jit.script
        def run_forward(scripted_remote_module: MyModuleInterface):
            ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
            return ret

        ret = run_forward(scripted_remote_module)

        self.assertEqual(ret, ("3", 2, torch.ones(1)))

    @dist_utils.dist_init
    def test_forward_with_kwargs(self):
        """``forward`` and ``forward_async`` accept keyword arguments."""
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        args = (torch.ones(1), 2)
        kwargs = dict(word="3")
        expected = tuple(reversed(args + ("3",)))
        # Only test Python nn.Module, because script module methods don't support taking kwargs.
        for remote_module in self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            ret_fut = remote_module.forward_async(*args, **kwargs)
            ret = ret_fut.wait()
            self.assertEqual(ret, expected)

            ret = remote_module.forward(*args, **kwargs)
            self.assertEqual(ret, expected)

    @dist_utils.dist_init
    def test_remote_parameters(self):
        """``remote_parameters`` returns one RRef whose value equals ``_PARAM_VAL``."""
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        # Only test Python nn.Module, because script module methods don't support ``remote_parameters``.
        for remote_module in self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            param_rrefs = remote_module.remote_parameters()
            self.assertEqual(len(param_rrefs), 1)
            self.assertTrue(torch.equal(param_rrefs[0].to_here(), _PARAM_VAL))

    @skip_if_lt_x_gpu(1)
    @dist_utils.dist_init
    def test_valid_device(self):
        """A module created with ``device="cuda:0"`` reports a CUDA device with index 0."""
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        for remote_module in self._create_remote_module_iter(
            dst_worker_name, device="cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            # Ask the destination worker which device the parameters live on.
            device = rpc.rpc_sync(
                dst_worker_name, remote_device, (remote_module.module_rref,)
            )
            self.assertEqual(device.type, "cuda")
            self.assertEqual(device.index, 0)

    @skip_if_lt_x_gpu(1)
    @dist_utils.dist_init
    def test_invalid_devices(self):
        """Malformed or out-of-range device strings raise ``RuntimeError``."""
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        # ``list(...)`` drives the generator so the remote module is actually
        # constructed inside the ``assertRaisesRegex`` context.
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan"
            " device type at start of device string",
        ):
            list(
                self._create_remote_module_iter(
                    dst_worker_name,
                    device="foo",
                    modes=[ModuleCreationMode.MODULE_CTOR],
                )
            )

        with self.assertRaisesRegex(
            RuntimeError, r"CUDA error: invalid device ordinal"
        ):
            list(
                self._create_remote_module_iter(
                    dst_worker_name,
                    device="cuda:100",
                    modes=[ModuleCreationMode.MODULE_CTOR],
                )
            )

        with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"):
            list(
                self._create_remote_module_iter(
                    dst_worker_name,
                    modes=[ModuleCreationMode.MODULE_CTOR],
                    device="cpu2",
                )
            )

        with self.assertRaisesRegex(
            RuntimeError, r"CPU device index must be -1 or zero, got 2"
        ):
            list(
                self._create_remote_module_iter(
                    dst_worker_name,
                    device="cpu:2",
                    modes=[ModuleCreationMode.MODULE_CTOR],
                )
            )

    @dist_utils.dist_init
    def test_unsupported_methods(self):
        """Each listed ``nn.Module`` method raises ``ValueError`` on a RemoteModule."""
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        def hook(module, grad_input, grad_output):
            pass

        # (method name, positional args, keyword args), listed in the order
        # the calls are attempted.
        unsupported_calls = [
            ("register_buffer", ("buffer", torch.ones(5)), {}),
            ("register_parameter", ("param", torch.nn.Parameter(torch.ones(1))), {}),
            ("add_module", ("empty", None), {}),
            # A tensor stands in for the callable; the test expects the call
            # to be rejected with ValueError.
            ("apply", (torch.rand((3, 3), requires_grad=False),), {}),
            ("cuda", (), {}),
            ("cpu", (), {}),
            ("type", (torch.FloatTensor,), {}),
            ("float", (), {}),
            ("double", (), {}),
            ("bfloat16", (), {}),
            ("to", ("cpu",), {"dtype": torch.int32}),
            ("register_backward_hook", (hook,), {}),
            ("register_forward_pre_hook", (hook,), {}),
            ("register_forward_hook", (hook,), {}),
            ("state_dict", (), {}),
            ("load_state_dict", ({},), {}),
            ("parameters", (), {}),
            ("named_parameters", (), {}),
            ("buffers", (), {}),
            ("named_buffers", (), {}),
            ("children", (), {}),
            ("named_children", (), {}),
            ("modules", (), {}),
            ("named_modules", (), {}),
            ("train", (), {}),
            ("eval", (), {}),
            ("requires_grad_", (), {}),
            ("zero_grad", (), {}),
            ("share_memory", (), {}),
            ("extra_repr", (), {}),
        ]
        # Extra text expected after the common message for specific methods.
        message_suffix = {
            "parameters": r". Please use ``remote_parameters`` instead.",
        }

        for remote_module in self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            for method_name, call_args, call_kwargs in unsupported_calls:
                pattern = (
                    r"Method ``{}`` not supported for RemoteModule".format(method_name)
                    + message_suffix.get(method_name, "")
                )
                with self.assertRaisesRegex(ValueError, pattern):
                    getattr(remote_module, method_name)(*call_args, **call_kwargs)