import torch
import torch_xla
from torch_xla import runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


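# Test payload: all_reduce-sums the per-replica input, adds a constant, then
# all_gathers the result so every replica ends up with the reduced value.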
def dummy_collective_fn(input):
  res_tensor = xm.all_reduce(xm.REDUCE_SUM, input)
  res_tensor += 3.0
  res_tensor = xm.all_gather(res_tensor, dim=0)
  return res_tensor


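# Per-process entry point: runs the collective through torch.compile with the
# openxla backend and checks both the numerics and the recorded XLA metrics.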
def _mp_fn(index):
  device = xm.xla_device()
  world_size = xr.world_size()
  if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'):
    print(f'skip this test for hw {xm.xla_device_hw(device)}')
    return
  ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
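  # Compile the collective graph with the Dynamo openxla backend, once with
  # dynamic shapes enabled and once without.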
  for dynamic in [True, False]:
    met.clear_all()
    compiled_collective = torch.compile(
        dummy_collective_fn, backend="openxla", dynamic=dynamic)
    res_tensor = compiled_collective(ordinal_tensor)
    # Each replica contributes its ordinal, so the reduced value is the sum
    # 0 + 1 + ... + (world_size - 1), plus the 3.0 added in the function.
    expected_tensor = torch.tensor(
        [world_size * (world_size - 1) / 2] * world_size,
        dtype=torch.float) + 3.0
    torch_xla.sync()
    assert torch.allclose(res_tensor.cpu(), expected_tensor)
    # A single ExecuteTime sample means one compiled XLA graph execution.
    assert met.metric_data("ExecuteTime")[0] == 1


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())