
Conversation

@will-cromar (Collaborator) commented Aug 1, 2022

  • Move TPU-specific logic from pjrt.py to tpu.py. All of this logic is broadly applicable to all TPU VMs and isn't strictly related to PjRt.
  • Update configure_topology to support multiple hosts.
  • Create unit tests for tpu.py (test_experimental_tpu.py). Everything that requires a TPU is mocked out, so this can run on CPU.
  • Create short integration test for TPU (test_experimental_pjrt_tpu.py). This one initializes the TPU runtime in each test, so it must run on a TPU. Added configs for v3-8 and v4-8. Tested manually on both.

I know the naming of these two new tests is confusing. Let me know if you have alternate suggestions.

PJRT doesn't make a distinction between processes on different hosts, and we already support multiple processes on one host, so no low-level changes were necessary. This PR mainly deals with automatically configuring the TPU topology variables based on the environment.

@will-cromar (Collaborator, Author) commented:

Simple test case to sanity-check that collectives work as expected:

$ gcloud compute tpus tpu-vm ssh --project=tpu-pytorch --zone=us-central2-b wcromar-v4-32 --internal-ip --worker=all --command 'PJRT_DEVICE=TPU python3 -c "
import torch_xla.core.xla_model as xm
import torch_xla.experimental.pjrt as pjrt
import torch
def f():
  ix = torch.ones([5], device=xm.xla_device()) * xm.get_ordinal()
  return xm.all_reduce(xm.REDUCE_SUM, ix).cpu().numpy()

print(pjrt.run_multiprocess(f))
"'
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
{3: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 0: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 2: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 1: {0: array([120., 120., 120., 120., 120.], dtype=float32)}}
{0: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 3: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 1: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 2: {0: array([120., 120., 120., 120., 120.], dtype=float32)}}
{1: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 2: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 3: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 0: {0: array([120., 120., 120., 120., 120.], dtype=float32)}}
{2: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 0: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 3: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 1: {0: array([120., 120., 120., 120., 120.], dtype=float32)}}

On a v4-32, there are 16 total workers numbered [0, 16), so the expected result is sum(range(16)) = 120.
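For reference, the expected value can be reproduced locally (plain NumPy, no TPU needed):

import numpy as np

# Each of the 16 workers contributes a length-5 vector filled with its ordinal,
# so the summed all_reduce is sum(range(16)) = 120 in every position.
expected = np.full(5, float(sum(range(16))))
print(expected)  # [120. 120. 120. 120. 120.]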

Also confirmed that our ResNet50 example works on v4-16 and v4-32 and gets about the same per-chip performance as on v4-8 with fake data.

will-cromar marked this pull request as ready for review on August 2, 2022, 17:50.
@ronghanghu (Collaborator) commented Aug 3, 2022

Really looking forward to this PR!

> Simple test case to sanity-check that collectives work as expected:

It would be great to also resolve the reduce_scatter, all_gather, and all_to_all collective ops in PJRT 😃

Currently, all_reduce works well on v3-8 after #3704, but reduce_scatter still doesn't, and all_gather under pin_layout=False doesn't work well either (all_gather under pin_layout=True actually uses all_reduce).


For example, reduce_scatter seems to do the "reduce" part correctly, but the "scatter" part uses a rank that doesn't match the rank from xm.get_ordinal, as shown in the case below when running on a v3-8 TPU VM:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def test_pjrt_collective_ops():
  rank = xm.get_ordinal()
  world_size = xm.xrt_world_size()
  device = xm.xla_device()

  t = torch.arange(16, dtype=torch.float32).view(8, 2)
  t = t.to(device)
  xm.mark_step()

  pin_layout = False
  reduce_scatter_out = xm.reduce_scatter(
    xm.REDUCE_SUM, t, scale=1.0 / world_size, scatter_dim=0, shard_count=world_size, pin_layout=pin_layout)
  xm.mark_step()
  print(f"rank {rank} of {world_size}:\nreduce_scatter_out (pin_layout={pin_layout}: {reduce_scatter_out}\n", end="", flush=True)
  xm.rendezvous(f"reduce_scatter_out pin_layout={pin_layout}")

  pin_layout = True
  reduce_scatter_out = xm.reduce_scatter(
    xm.REDUCE_SUM, t, scale=1.0 / world_size, scatter_dim=0, shard_count=world_size, pin_layout=pin_layout)
  xm.mark_step()
  print(f"rank {rank} of {world_size}:\nreduce_scatter_out (pin_layout={pin_layout}: {reduce_scatter_out}\n", end="", flush=True)
  xm.rendezvous(f"reduce_scatter_out pin_layout={pin_layout}")

  return 0.


if __name__ == '__main__':
  pjrt.run_multiprocess(test_pjrt_collective_ops)

which prints


rank 0 of 8:
reduce_scatter_out (pin_layout=False: tensor([[0., 1.]], device='xla:0')
rank 1 of 8:
reduce_scatter_out (pin_layout=False: tensor([[2., 3.]], device='xla:1')
rank 2 of 8:
reduce_scatter_out (pin_layout=False: tensor([[12., 13.]], device='xla:0')
rank 3 of 8:
reduce_scatter_out (pin_layout=False: tensor([[14., 15.]], device='xla:1')
rank 4 of 8:
reduce_scatter_out (pin_layout=False: tensor([[4., 5.]], device='xla:0')
rank 5 of 8:
reduce_scatter_out (pin_layout=False: tensor([[6., 7.]], device='xla:1')
rank 6 of 8:
reduce_scatter_out (pin_layout=False: tensor([[8., 9.]], device='xla:0')
rank 7 of 8:
reduce_scatter_out (pin_layout=False: tensor([[10., 11.]], device='xla:1')

rank 0 of 8:
reduce_scatter_out (pin_layout=True: tensor([[0., 1.]], device='xla:0')
rank 1 of 8:
reduce_scatter_out (pin_layout=True: tensor([[2., 3.]], device='xla:1')
rank 2 of 8:
reduce_scatter_out (pin_layout=True: tensor([[12., 13.]], device='xla:0')
rank 3 of 8:
reduce_scatter_out (pin_layout=True: tensor([[14., 15.]], device='xla:1')
rank 4 of 8:
reduce_scatter_out (pin_layout=True: tensor([[4., 5.]], device='xla:0')
rank 5 of 8:
reduce_scatter_out (pin_layout=True: tensor([[6., 7.]], device='xla:1')
rank 6 of 8:
reduce_scatter_out (pin_layout=True: tensor([[8., 9.]], device='xla:0')
rank 7 of 8:
reduce_scatter_out (pin_layout=True: tensor([[10., 11.]], device='xla:1')

One can see that the reduce-scatter outputs are scattered to ranks that don't match the ranks from xm.get_ordinal. Similarly, all_gather under pin_layout=False has the same problem.
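For comparison, here is a host-side sketch (plain PyTorch, no XLA) of what each rank should receive from that reduce_scatter call: every rank contributes the same 8x2 tensor, so the scaled sum equals the input, and scattering along dim 0 should hand row r to rank r.

import torch

world_size = 8
t = torch.arange(16, dtype=torch.float32).view(8, 2)
# Summing 8 identical inputs and scaling by 1/8 reproduces the input itself.
reduced = sum(t for _ in range(world_size)) / world_size
# Scattering along dim 0 into 8 shards: rank r should get row r, i.e. [2r, 2r+1].
for rank, shard in enumerate(reduced.chunk(world_size, dim=0)):
  print(f"rank {rank} expects {shard.tolist()}")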

(Besides, xm.rendezvous is not working under PJRT yet -- it doesn't actually introduce a barrier across all ranks.)

@will-cromar (Collaborator, Author) commented:

@ronghanghu Thanks for flagging the issue with the other collectives. I did check all_gather as well, but I didn't think to try with pin_layout=False. This snippet gives the expected results:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def _mp_fn():
  device = xm.xla_device()

  ones = torch.ones((3), device=device) * xm.get_ordinal()

  res = xm.all_gather(ones, pin_layout=True)
  xm.mark_step()

  print(xm.get_ordinal(), res)


pjrt.run_multiprocess(_mp_fn)
$ PJRT_DEVICE=TPU python all_gather.py
0 tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.], device='xla:0')
2 tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.], device='xla:0')
3 tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.], device='xla:0')
1 tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.], device='xla:0')

But with pin_layout=False, I get this:

$ PJRT_DEVICE=TPU python all_gather.py
3 tensor([0., 0., 0., 3., 3., 3., 2., 2., 2., 1., 1., 1.], device='xla:0')
1 tensor([0., 0., 0., 3., 3., 3., 2., 2., 2., 1., 1., 1.], device='xla:0')
0 tensor([0., 0., 0., 3., 3., 3., 2., 2., 2., 1., 1., 1.], device='xla:0')
2 tensor([0., 0., 0., 3., 3., 3., 2., 2., 2., 1., 1., 1.], device='xla:0')

While working on the tests for this PR, I found that the TPU chips' device IDs don't actually match the device indices in TPU_VISIBLE_DEVICES (which do correspond to xm.get_ordinal). For example, on a v4-8, I found that the PjRt device IDs are ordered ['TPU:0', 'TPU:2', 'TPU:3', 'TPU:1']. I bet that is related. Although, I'm not entirely sure why pin_layout would affect the results like this.
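A small probe along those lines (only using calls already shown in this thread) to compare each process's ordinal against its XLA device string and TPU_VISIBLE_DEVICES:

import os

import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def _probe():
  # Report how this process sees itself: logical ordinal, XLA device string,
  # and the chips made visible to it via TPU_VISIBLE_DEVICES.
  return {
      'ordinal': xm.get_ordinal(),
      'device': str(xm.xla_device()),
      'visible_devices': os.environ.get('TPU_VISIBLE_DEVICES'),
  }


print(pjrt.run_multiprocess(_probe))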

Can you file an issue assigned to me to look into reduce_scatter, all_gather, and all_to_all?

@will-cromar (Collaborator, Author) commented:

Also, xm.rendezvous doesn't work yet, but another early tester told us that they were able to work around it by creating a gloo process group and using dist.barrier.
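A minimal sketch of that workaround, assuming the launcher (or the user) provides MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE for the gloo rendezvous:

import torch.distributed as dist

def gloo_barrier():
  # Initialize a gloo process group purely to act as a host-side barrier;
  # init_method='env://' reads MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE from the
  # environment.
  if not dist.is_initialized():
    dist.init_process_group(backend='gloo', init_method='env://')
  dist.barrier()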

@ronghanghu (Collaborator) commented Aug 3, 2022

Thanks @will-cromar, I'll submit a new issue for all-gather, reduce-scatter, and all-to-all under PJRT.

> I found that the PjRt device IDs are ordered ['TPU:0', 'TPU:2', 'TPU:3', 'TPU:1']. I bet that is related.

Yeah, I think this is the underlying cause of the all_gather and reduce_scatter difference. It would be great to make them consistent (otherwise, many existing programs that rely on xm.reduce_scatter and other collective ops cannot work as expected).

> Although, I'm not entirely sure why pin_layout would affect the results like this.

I think this is because xm.all_gather only performs an actual all-gather under pin_layout=False. Under pin_layout=True, the API xm.all_gather doesn't actually perform an all-gather; instead, it first pads the inputs with zeros and then performs an all_reduce, as in:

if pin_layout and xla_device_hw(
    value.device) in ('TPU', 'GPU') and output == None:
  # There is not an easy way to pin the all_gather layout on TPU and GPU, use
  # all_reduce based all_gather for this purpose.
  return _all_gather_using_all_reduce(
      value, dim=dim, groups=groups, pin_layout=True)

This was introduced in #3568 and motivated by previous issues such as #3511.
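In other words (a conceptual sketch, not the torch_xla implementation): each rank writes its slice into a zero buffer at an offset determined by its ordinal, and a summed all-reduce then concatenates the slices in ordinal order, which is why the pin_layout=True path is insensitive to the physical device-ID ordering.

import torch

def all_gather_via_all_reduce(value, ordinal, world_size, all_reduce_sum):
  # Place this rank's flattened slice at offset ordinal * n in a zero buffer.
  n = value.numel()
  padded = torch.zeros(n * world_size, dtype=value.dtype)
  padded[ordinal * n:(ordinal + 1) * n] = value.flatten()
  # Summing the padded buffers across ranks yields the concatenation, ordered
  # by ordinal rather than by device ID.
  return all_reduce_sum(padded)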

@ronghanghu (Collaborator) commented Aug 3, 2022

> Also, xm.rendezvous doesn't work yet, but another early tester told us that they were able to work around it by creating a gloo process group and using dist.barrier.

This is good to know! I guess this is probably a bit harder to do with threads on TPU v3, though.

@will-cromar (Collaborator, Author) commented Aug 3, 2022

barrier will almost certainly not work with threads if you use the global default process group (i.e. use init_process_group), because each thread will use the same PG. It might work if you call dist.new_group in each thread and pass that group into barrier directly.

There's an example in the docs of using nccl as the default process group and manually initializing a separate gloo process group to use with barrier: https://pytorch.org/docs/stable/distributed.html#monitored-barrier
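A rough sketch of that pattern (it assumes a default process group has already been initialized by some launcher; monitored_barrier requires the gloo backend):

import torch.distributed as dist

# Create a dedicated gloo group alongside the default process group and use it
# only for barriers.
gloo_group = dist.new_group(backend='gloo')
dist.monitored_barrier(group=gloo_group)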

@ronghanghu (Collaborator) commented Aug 3, 2022

@will-cromar I created an issue in #3824 with a simple test example for all-gather, reduce-scatter, and all-to-all (but I cannot assign the issue to you since I don't have edit access to this repo 😅 )

@JackCaoG (Collaborator) commented Aug 5, 2022

@ronghanghu I will give you write access 😄

@ronghanghu (Collaborator) commented:

> @ronghanghu I will give you write access 😄

Great, thank you!

@JackCaoG (Collaborator) left a review:

Thanks @will-cromar, mostly LGTM.

Args:
  rank: rank of current process
  processes: number of processes on this host
  local_process: rank of current process within this host

Collaborator: nit, local_process_rank maybe?

Collaborator (Author): Changed to local_rank to be consistent with torchrun.

"""
if device_type() == 'TPU':
processes = num_visible_tpu_chips()
processes = tpu.num_local_processes()

Collaborator: nit, maybe num_processes?

Collaborator (Author): Done.

def num_local_processes(local_chips: int = 4) -> int:
  """Returns number of processes to create on this host."""
  # Don't create more processes than local chips
  return min(local_chips, process_bounds_size(default=local_chips))

Collaborator: Should we assert here if process_bounds_size(default=local_chips) > local_chips? Slightly worried that we hide the problem here but run into a different issue later.

Collaborator (Author): Process bounds include processes across all hosts. For example, the bounds for a v4-16 would be 2,2,2 (world size 8), but we can't create more than 4 processes on each host.

My logic here is:

  • If the process bounds are smaller than the number of local chips (e.g., 1,1,1), then only create that number of processes locally
  • If the process bounds are larger than the number of local chips, only spawn at most one process per chip
  • If the process bounds are unset, default to one process per chip

This logic breaks down in the case where the user has some unusual configuration (say, two chips per process), but in that case, they should set the TPU environment variables themselves to explicitly assign visible chips to processes rather than have us try to infer their topology. We have a bug filed internally to add a num_processes flag to pjrt.run_multiprocess if a user needs to explicitly select some other number of local processes.
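A worked illustration of that rule, using the values from this thread (the helper below is a hypothetical stand-in for the num_local_processes/process_bounds_size pair quoted above):

def local_process_count(process_bounds_size, local_chips=4):
  # Never exceed one process per local chip, and treat unset process bounds as
  # "one process per chip".
  if process_bounds_size is None:
    process_bounds_size = local_chips
  return min(local_chips, process_bounds_size)


print(local_process_count(8))     # v4-16: bounds 2,2,2 -> world size 8, but 4 processes per host
print(local_process_count(1))     # bounds 1,1,1 -> a single local process
print(local_process_count(None))  # bounds unset -> one process per chip (4)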

process_endpoints = [
    ','.join(f'{ip}:{port}' for port in ports) for ip in worker_ips
]
os.environ.setdefault(xenv.TPU_PROCESS_ADDRESSES, ','.join(process_endpoints))

Collaborator: Do we need this for v4-8?

Collaborator (Author): Yes. All of the addresses will be on localhost, but each process needs to know the port of every other process. You can only skip this environment variable when there is one process per host.
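For illustration, this is roughly what that variable ends up looking like on a single-host v4-8, following the snippet quoted above (the port numbers here are made up):

import os

worker_ips = ['localhost']          # v4-8: all four processes live on one host
ports = [8476, 8477, 8478, 8479]    # hypothetical per-process ports
process_endpoints = [
    ','.join(f'{ip}:{port}' for port in ports) for ip in worker_ips
]
os.environ.setdefault('TPU_PROCESS_ADDRESSES', ','.join(process_endpoints))
print(os.environ['TPU_PROCESS_ADDRESSES'])
# localhost:8476,localhost:8477,localhost:8478,localhost:8479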

@JackCaoG (Collaborator) left a review:

Thanks @will-cromar! Can you add a follow-up README under pjrt to show how to run a pod workload? It would be nice to also include how to run a donut workload using pjrt in that README.

@will-cromar (Collaborator, Author) commented:

Thanks @JackCaoG. I'll work on a README showing how to port from XRT to PjRt and how to run models without xla_dist.

will-cromar merged commit 63c67df into master on Aug 17, 2022.