Commit b9a32d3

[DDP] Add a test case to test a larger model (#4085)
Summary: This commit adds a test case that exercises a larger model, one large enough to trigger multiple all_reduces instead of a single one. It also fixes an xmp issue that occurs when the launcher is reused to run consecutive multi-process (mp) experiments.

Test Plan:
XRT: MASTER_ADDR=localhost MASTER_PORT=6000 python test/test_ddp.py TestXrtDistributedDataParallel.test_ddp_correctness_large_net
PJRT: PJRT_DEVICE=TPU python test/pjrt/test_ddp.py TestPjRtDistributedDataParallel.test_ddp_correctness_large_net
1 parent 4a267b4 commit b9a32d3
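
Why a larger model matters here: DDP groups gradients into buckets of roughly bucket_cap_mb megabytes and launches one all_reduce per bucket, so the original nn.Linear(10, 10) test model always fit into a single bucket and exercised only one all_reduce. The sketch below is not part of the commit (approx_bucket_count is a hypothetical helper); it only estimates how many 1 MB buckets a model's gradients would span, since real bucket assignment also depends on parameter registration order and on DDP's single-bucket first iteration (pytorch#73732).

import torch.nn as nn

def approx_bucket_count(model, bucket_cap_mb=1):
  # Total gradient payload in bytes; gradients match the parameters' dtype and shape.
  grad_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
  return max(1, -(-grad_bytes // (bucket_cap_mb * 1024**2)))  # ceiling division

small = nn.Linear(10, 10)  # the model the old test used: well under 1 MB
large = nn.Sequential(     # roughly the shape of the new LargeNet: ~2M params, ~7.7 MB
    nn.Linear(10, 1000), nn.Linear(1000, 1000), nn.Linear(1000, 1000),
    nn.Linear(1000, 10))
print(approx_bucket_count(small), approx_bucket_count(large))  # 1 bucket vs. roughly 8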

5 files changed: +124, -20 lines changed

test/args_parse.py

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ def parse_common_options(datadir=None,
   parser.add_argument('--tidy', action='store_true')
   parser.add_argument('--metrics_debug', action='store_true')
   parser.add_argument('--async_closures', action='store_true')
+  parser.add_argument('--debug', action='store_true')
   if opts:
     for name, aopts in opts:
       parser.add_argument(name, **aopts)

test/distributed_util.py

Lines changed: 80 additions & 11 deletions
@@ -1,4 +1,3 @@
-from absl import logging
 import copy
 import torch
 import torch.distributed as dist
@@ -7,7 +6,61 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_backend
-from torch_xla.experimental import pjrt
+
+
+# The following are helpers useful for debugging purposes.
+def comp_hook(state: object,
+              bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
+  """
+  Debug util. Please refer to DistributedDataParallel.register_comm_hook to learn
+  how to use it.
+  """
+  print("comp_hook called.")
+  fut = torch.futures.Future()
+  fut.set_result(bucket.buffer())
+  return fut
+
+
+def calculate_model_size(model):
+  """
+  Debug util. Calculate the given model's size in MB.
+  """
+  param_size = 0
+  for param in model.parameters():
+    param_size += param.nelement() * param.element_size()
+  buffer_size = 0
+  for buffer in model.buffers():
+    buffer_size += buffer.nelement() * buffer.element_size()
+
+  size_all_mb = (param_size + buffer_size) / 1024**2
+  print('model size: {:.3f}MB'.format(size_all_mb))
+
+
+class LargeNet(nn.Module):
+
+  def __init__(self):
+    super(LargeNet, self).__init__()
+    self.net1 = nn.Linear(10, 1000)
+    self.net2 = nn.Linear(1000, 1000)
+    self.net3 = nn.Linear(1000, 1000)
+    self.relu = nn.ReLU()
+    self.net4 = nn.Linear(1000, 10)
+
+  def forward(self, x):
+    output1 = self.relu(self.net1(x))
+    output2 = self.relu(self.net2(output1))
+    output3 = self.relu(self.net3(output2))
+    return self.net4(output3)
+
+
+class SmallNet(nn.Module):
+
+  def __init__(self):
+    super(SmallNet, self).__init__()
+    self.net = nn.Linear(10, 10)
+
+  def forward(self, x):
+    return self.net(x)
 
 
 def init_xla_backend(init_file: str):
@@ -40,17 +93,32 @@ def train_step(model, inputs, labels, optimizer, loss_fn):
   return loss
 
 
-def ddp_correctness(init_file: str):
+def ddp_correctness(init_file: str,
+                    *,
+                    use_large_net: bool = False,
+                    debug: bool = False):
   rank, world_size = init_xla_backend(init_file)
 
   device = xm.xla_device()
 
   # To make nn.Linear init same parameters across devices.
   torch.manual_seed(2022)
-  cpu_model = nn.Linear(10, 10)
+  # A lower range probably makes sense too. Anyway, stick to 100 as in the original PoC.
+  steps = 100
+  cpu_model = SmallNet()
+  if use_large_net:
+    steps = 5  # To save test time.
+    cpu_model = LargeNet()
+
   # TODO(@alanwaketan): Investigate whether we can omit the gradient_as_bucket_view option.
+  # bucket_cap_mb is set to 1 MB such that we can still have multiple all_reduces while avoiding
+  # models that are too large (25 MB).
+  # Note that DDP currently uses one bucket for the first iteration. See pytorch#73732.
   ddp_model = DDP(
-      copy.deepcopy(cpu_model).to(device), gradient_as_bucket_view=True)
+      copy.deepcopy(cpu_model).to(device),
+      gradient_as_bucket_view=True,
+      bucket_cap_mb=1)
+  # ddp_model.register_comm_hook(state=None, hook=comp_hook)
 
   cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=1e-100)
   ddp_optimizer = optim.SGD(ddp_model.parameters(), lr=1e-100)
@@ -59,8 +127,7 @@ def ddp_correctness(init_file: str):
   local_batch_size = 2
   global_batch_size = local_batch_size * world_size
   offset = rank * local_batch_size
-  # Lower range probably makes sense too. Anyway, stick to 100 as the original PoC.
-  for step in range(100):
+  for step in range(steps):
     # To make torch.randn produce same results across devices.
     torch.manual_seed(2022 + step)
 
@@ -82,7 +149,9 @@ def ddp_correctness(init_file: str):
     # TODO(@alanwaketan): Investigate why the atol here is this low.
     assert torch.allclose(cpu_loss, ddp_loss, atol=1e-02)
     assert_all_close(cpu_model.parameters(), ddp_model.parameters())
-    # To display the below messages, set '--verbosity=1'.
-    logging.debug(
-        "iteration %d: cpu_loss = %f, ddp_loss = %f, cpu_model.parameters() ~= ddp_model.parameters()",
-        step, cpu_loss, ddp_loss)
+    # To display the messages below, pass '--debug'.
+    # We don't use FLAGS.debug here because this function often runs in different processes than the launcher.
+    if debug:
+      print(
+          f"iteration {step}: cpu_loss = {cpu_loss}, ddp_loss = {ddp_loss}, cpu_model.parameters() ~= ddp_model.parameters()"
+      )
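
The commented-out register_comm_hook call above can be enabled to observe the bucketing directly. Below is a hedged, self-contained sketch (not from the commit; counting_hook is a stand-in shaped like comp_hook) that runs the same idea on CPU with the gloo backend and a single process. Note that a DDP comm hook replaces the built-in allreduce, so this is only useful for counting how many buckets (and hence all_reduces) fire per backward pass, not for correctness runs.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def counting_hook(state, bucket):
  # Same shape as comp_hook in the diff: observe the bucket, then hand its buffer back.
  print("bucket reduced, numel =", bucket.buffer().numel())
  fut = torch.futures.Future()
  fut.set_result(bucket.buffer())
  return fut

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12355")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Sequential(nn.Linear(10, 1000), nn.Linear(1000, 1000), nn.Linear(1000, 10))
ddp_model = DDP(model, bucket_cap_mb=1)
ddp_model.register_comm_hook(state=None, hook=counting_hook)

for step in range(2):  # DDP uses a single bucket on the first iteration (pytorch#73732)
  ddp_model(torch.randn(2, 10)).sum().backward()

dist.destroy_process_group()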

test/pjrt/test_ddp.py

Lines changed: 14 additions & 2 deletions
@@ -10,8 +10,11 @@
 xla_test_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
 sys.path.append(xla_test_folder)
 
+import args_parse
 import distributed_util as util
 
+FLAGS = args_parse.parse_common_options()
+
 
 class TestPjRtDistributedDataParallel(parameterized.TestCase):
 
@@ -27,8 +30,17 @@ def test_ddp_init(self):
     pjrt._run_multiprocess(self._ddp_init, self.create_tempfile().full_path)
 
   def test_ddp_correctness(self):
-    pjrt._run_multiprocess(util.ddp_correctness,
-                           self.create_tempfile().full_path)
+    pjrt._run_multiprocess(
+        util.ddp_correctness,
+        self.create_tempfile().full_path,
+        debug=FLAGS.debug)
+
+  def test_ddp_correctness_large_net(self):
+    pjrt._run_multiprocess(
+        util.ddp_correctness,
+        self.create_tempfile().full_path,
+        use_large_net=True,
+        debug=FLAGS.debug)
 
 
 if __name__ == "__main__":
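
The debug value is forwarded as an explicit keyword argument because ddp_correctness executes in spawned worker processes, where the launcher's module-level FLAGS is not necessarily available. A minimal, generic illustration of that plumbing pattern (hypothetical names, not from the commit):

import argparse

def worker(init_file, *, use_large_net=False, debug=False):
  # In the real tests this would be util.ddp_correctness running inside a child process.
  if debug:
    print(f"worker: init_file={init_file}, use_large_net={use_large_net}")

if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--debug", action="store_true")
  flags, _ = parser.parse_known_args()
  # Pass the parsed value along as an explicit argument rather than relying on a global.
  worker("/tmp/ddp_init", use_large_net=True, debug=flags.debug)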

test/test_ddp.py

Lines changed: 9 additions & 3 deletions
@@ -3,13 +3,16 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_multiprocessing as xmp
 
+import args_parse
 import distributed_util as util
 
+FLAGS = args_parse.parse_common_options()
+
 
 class TestXrtDistributedDataParallel(parameterized.TestCase):
 
   @staticmethod
-  def _ddp_correctness(rank):
+  def _ddp_correctness(rank, use_large_net: bool, debug: bool):
     # We cannot run this guard before XMP,
     # see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing.
     device = xm.xla_device()
@@ -19,10 +22,13 @@ def _ddp_correctness(rank):
           'Default device {} is not a TPU device'.format(device),
           file=sys.stderr)
       return
-    util.ddp_correctness(None)
+    util.ddp_correctness(None, use_large_net=use_large_net, debug=debug)
 
   def test_ddp_correctness(self):
-    xmp.spawn(self._ddp_correctness, args=())
+    xmp.spawn(self._ddp_correctness, args=(False, FLAGS.debug))
+
+  def test_ddp_correctness_large_net(self):
+    xmp.spawn(self._ddp_correctness, args=(True, FLAGS.debug))
 
 
 if __name__ == "__main__":
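
For the XRT tests, xmp.spawn prepends the process index when invoking the target, so args=(True, FLAGS.debug) reaches the static method as _ddp_correctness(rank, True, FLAGS.debug). A small stand-in showing the same calling convention (hypothetical function, assumes an environment with XLA devices configured; not from the commit):

import torch_xla.distributed.xla_multiprocessing as xmp

def _stand_in(index, use_large_net, debug):
  # Mirrors _ddp_correctness above: index is supplied by xmp.spawn, the rest come from args.
  print(f"process {index}: use_large_net={use_large_net}, debug={debug}")

if __name__ == "__main__":
  xmp.spawn(_stand_in, args=(True, False))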

torch_xla/distributed/xla_multiprocessing.py

Lines changed: 20 additions & 4 deletions
@@ -32,6 +32,10 @@ def _is_xla_config():
   return False
 
 
+# TODO: Some usages of this function are to calculate the number of hosts (a TPU concept),
+# and some are to calculate the number of processes within a world (which can span multiple
+# hosts). The latter is really what this function is supposed to do; it's confusing and we
+# should improve it.
 def _get_world_size():
   # We cannot use the xla_model.py API here, as the features used in that module
   # needs the setup provided by this one.
@@ -145,10 +149,12 @@ def _get_mp_device_ordinal(index, gindex):
   return index if xenv.HOST_ORDINAL in os.environ else gindex
 
 
-def _setup_workers(num_devices):
+# TODO: Consolidate this with _setup_gpu_worker.
+def _setup_gpu_workers(num_devices):
   world_size = _get_world_size()
   workers_env = os.environ.get(xenv.WORKERS, None)
   workers = []
+  # TODO: Is this path actually being used? It seems to support multi-host GPUs (is that a thing at all?).
   if workers_env is not None:
     wcfg = _parse_workers_config(workers_env)
     assert world_size == len(
@@ -209,7 +215,7 @@ def _pre_fork_setup(num_devices):
                                        socket.getfqdn(),
                                        xu.get_free_tcp_ports()[0])
   if dev_kind == 'GPU':
-    _setup_workers(num_devices)
+    _setup_gpu_workers(num_devices)
     _create_gpu_devices(num_devices)
   elif dev_kind == 'CPU':
     _pre_fork_cpu_setup(num_devices)
@@ -377,7 +383,7 @@ def spawn(fn,
   Returns:
     The same object returned by the `torch.multiprocessing.spawn` API. If
     `nprocs` is 1 the `fn` function will be called directly, and the API will
-    not return.
+    return None.
   """
   if pjrt.using_pjrt():
     return pjrt.spawn(fn, args)
@@ -387,17 +393,27 @@ def spawn(fn,
     return _run_direct(fn, args, nprocs, join, daemon, start_method)
 
   pf_cfg = _pre_fork_setup(nprocs)
+  result = None
   if pf_cfg.num_devices == 1:
     _start_fn(0, pf_cfg, fn, args)
   else:
-    return torch.multiprocessing.start_processes(
+    result = torch.multiprocessing.start_processes(
         _mp_start_fn,
         args=(pf_cfg, fn, args),
        nprocs=pf_cfg.num_devices,
        join=join,
        daemon=daemon,
        start_method=start_method)
 
+  # For GPU, xenv.WORKERS is set in the launcher and then carried over to the children.
+  # However, if the launcher is reused for another multi-process experiment, _setup_gpu_workers
+  # would mistake the stale xenv.WORKERS for a multi-host configuration, with each worker
+  # representing a host. Therefore, reset it after launching all children.
+  if pf_cfg.dev_kind == 'GPU':
+    os.environ.pop(xenv.WORKERS)
+
+  return result
+
 
 class MpModelWrapper(object):
   """Wraps a model to minimize host memory usage when `fork` method is used.
