Skip to content

Commit b9c19b4

Browse files
authored
add test for dynamo + traceable collectives (#7745)
1 parent 0b99085 commit b9c19b4

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import torch
2+
import torch_xla
3+
from torch_xla import runtime as xr
4+
import torch_xla.core.xla_model as xm
5+
import torch_xla.debug.metrics as met
6+
7+
8+
def dummy_collective_fn(input):
9+
res_tensor = xm.all_reduce(xm.REDUCE_SUM, input)
10+
res_tensor += 3.0
11+
res_tensor = xm.all_gather(res_tensor, dim=0)
12+
return res_tensor
13+
14+
15+
def _mp_fn(index):
16+
device = xm.xla_device()
17+
world_size = xr.world_size()
18+
if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'):
19+
print(f'skip this test for hw {xm.xla_device_hw(device)}')
20+
return
21+
ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
22+
for dynamic in [True, False]:
23+
met.clear_all()
24+
compiled_collective = torch.compile(
25+
dummy_collective_fn, backend="openxla", dynamic=dynamic)
26+
res_tensor = compiled_collective(ordinal_tensor)
27+
expected_tensor = torch.tensor(
28+
[world_size * world_size / 2] * world_size, dtype=torch.float) + 3.0
29+
torch_xla.sync()
30+
torch.allclose(res_tensor.cpu(), expected_tensor)
31+
assert met.metric_data("ExecuteTime")[0] == 1
32+
33+
34+
if __name__ == '__main__':
35+
torch_xla.launch(_mp_fn, args=())

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ function run_mp_op_tests {
296296
run_test "$CDIR/test_mp_save.py"
297297
run_test "$CDIR/test_mp_mesh_reduce.py"
298298
run_test "$CDIR/test_mp_sync_batch_norm.py"
299+
run_test "$CDIR/dynamo/test_traceable_collectives.py"
299300
run_test "$CDIR/test_fsdp_auto_wrap.py"
300301
# run_torchrun "$CDIR/test_mp_early_exit.py"
301302
run_pt_xla_debug "$CDIR/debug_tool/test_mp_pt_xla_debug.py"

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ python3 test/test_autocast.py
1818
python3 test/test_grad_checkpoint.py
1919
python3 test/dynamo/test_dynamo.py
2020
python3 test/dynamo/test_dynamo_dynamic_shape.py
21+
python3 test/dynamo/test_traceable_collectives.py
2122
python3 test/spmd/test_spmd_debugging.py
2223
XLA_PARAMETER_WRAPPING_THREADSHOLD=1 python test/spmd/test_spmd_parameter_wrapping.py
2324
python3 test/pjrt/test_dtypes.py

0 commit comments

Comments
 (0)