TPU Pod support with PjRt #3813
Conversation
Simple test case to sanity-check that collectives work as expected:
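A minimal sketch of this kind of check, using the xm and pjrt APIs that appear later in this thread (not the exact snippet referenced here), might look like the following:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def check_all_reduce():
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  ordinal = xm.get_ordinal()

  # Each worker contributes its ordinal; the all-reduced sum should equal
  # 0 + 1 + ... + (world_size - 1) on every worker.
  t = torch.tensor([float(ordinal)], device=device)
  reduced = xm.all_reduce(xm.REDUCE_SUM, t)
  xm.mark_step()

  expected = world_size * (world_size - 1) / 2
  print(f'worker {ordinal} of {world_size}: sum={reduced.item()}, '
        f'expected={expected}', flush=True)


if __name__ == '__main__':
  pjrt.run_multiprocess(check_all_reduce)
```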
On a v4-32, there are 16 total workers, numbered 0 through 15. Also confirmed that our ResNet50 example works on v4-16 and v4-32 and gets about the same per-chip performance as on v4-8 with fake data.
Really looking forward to this PR!
It would be great to also resolve the remaining issues with the other collective ops. Currently, for example, reduce_scatter seems to do the "reduce" part correctly, but the "scatter" part has a rank mismatching the rank from xm.get_ordinal, as shown in the case below when running on a v3-8 TPU VM:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def test_pjrt_collective_ops():
  rank = xm.get_ordinal()
  world_size = xm.xrt_world_size()
  device = xm.xla_device()

  t = torch.arange(16, dtype=torch.float32).view(8, 2)
  t = t.to(device)
  xm.mark_step()

  pin_layout = False
  reduce_scatter_out = xm.reduce_scatter(
      xm.REDUCE_SUM, t, scale=1.0 / world_size, scatter_dim=0,
      shard_count=world_size, pin_layout=pin_layout)
  xm.mark_step()
  print(f"rank {rank} of {world_size}:\nreduce_scatter_out (pin_layout={pin_layout}): {reduce_scatter_out}\n", end="", flush=True)
  xm.rendezvous(f"reduce_scatter_out pin_layout={pin_layout}")

  pin_layout = True
  reduce_scatter_out = xm.reduce_scatter(
      xm.REDUCE_SUM, t, scale=1.0 / world_size, scatter_dim=0,
      shard_count=world_size, pin_layout=pin_layout)
  xm.mark_step()
  print(f"rank {rank} of {world_size}:\nreduce_scatter_out (pin_layout={pin_layout}): {reduce_scatter_out}\n", end="", flush=True)
  xm.rendezvous(f"reduce_scatter_out pin_layout={pin_layout}")

  return 0.


if __name__ == '__main__':
  pjrt.run_multiprocess(test_pjrt_collective_ops)
```

which prints
One can see that the reduce-scatter outputs have scatter results that don't match the rank from xm.get_ordinal. The other collective ops show related mismatches.
@ronghanghu Thanks for flagging the issue with the other collectives. While working on the tests for this PR, I found that the TPU chips' device IDs don't actually match the device indices, which may be the underlying cause. Can you file an issue assigned to me to look into the other collectives?
Thanks @will-cromar, I'll submit a new issue for all-gather, all-reduce, and all-to-all under PJRT.
Yeah, I think this is the underlying cause of the all_gather and reduce_scatter difference. It would be great to make them consistent (otherwise, many existing programs that rely on the ordinal matching the device order could break).
I think this is because of the logic in torch_xla/core/xla_model.py, lines 673 to 678 (at commit ff2e0ea).
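For a sense of why that consistency matters, here is a hypothetical sketch of the kind of program that assumes shard i of a reduce-scatter lands on the process with xm.get_ordinal() == i; the helper and its arguments are illustrative, not from any existing codebase:

```python
import torch
import torch_xla.core.xla_model as xm


def sharded_grad_step(flat_grads: torch.Tensor, my_param_shard: torch.Tensor,
                      lr: float = 0.01):
  world_size = xm.xrt_world_size()
  # Each process expects to receive the shard of the summed gradients that
  # corresponds to its own ordinal...
  my_grad_shard = xm.reduce_scatter(
      xm.REDUCE_SUM, flat_grads, scale=1.0 / world_size,
      scatter_dim=0, shard_count=world_size)
  # ...so this update is only correct if the scatter order matches
  # xm.get_ordinal(); otherwise each process applies someone else's gradients.
  my_param_shard.add_(my_grad_shard, alpha=-lr)
```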
This is good to know! I guess this is probably a bit harder to do with threads on TPU v3, though.
There's an example of this in the docs.
@will-cromar I created issue #3824 with a simple test example for all-gather, reduce-scatter, and all-to-all (but I cannot assign the issue to you since I don't have edit access to this repo 😅).
@ronghanghu I will give you write access 😄
Great, thank you!
Force-pushed 4b9e2e5 to 215c9c5.
Real devices are not deterministic.
Thanks @will-cromar, mostly LGTM.
torch_xla/experimental/pjrt.py (outdated)

```python
  Args:
    rank: rank of current process
    processes: number of processes on this host
    local_process: rank of current process within this host
```
nit: local_process_rank maybe?
Changed to local_rank to be consistent with torchrun.
torch_xla/experimental/pjrt.py (outdated)

```diff
   """
   if device_type() == 'TPU':
-    processes = num_visible_tpu_chips()
+    processes = tpu.num_local_processes()
```
nit: maybe num_processes?
Done
```python
def num_local_processes(local_chips: int = 4) -> int:
  """Returns number of processes to create on this host."""
  # Don't create more processes than local chips
  return min(local_chips, process_bounds_size(default=local_chips))
```
Should we assert here if process_bounds_size(default=local_chips) > local_chips? Slightly worried that we hide the problem here but run into a different issue later.
Process bounds include processes across all hosts. For example, the bounds for a v4-16 would be 2,2,2 (world size 8), but we can't create more than 4 processes on each host.

My logic here is:

- If the process bounds are less than the number of local chips (1,1,1, for example), then only create that number of processes locally.
- If the process bounds are larger than the number of local chips, spawn at most one process per chip.
- If the process bounds are unset, default to one process per chip.

This falls through in the case where the user has some unusual configuration (say, two chips per process), but in that case, they should set the TPU environment variables themselves to explicitly assign visible chips to processes rather than have us try to infer their topology. We have a bug filed internally to add a num_processes flag to pjrt.run_multiprocess if a user needs to explicitly select some other number of local processes.
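As a worked illustration of this capping logic, here is a self-contained sketch; process_bounds_size below is a stand-in that parses a "2,2,2"-style bounds string, not the actual helper in tpu.py:

```python
import math
from typing import Optional


def process_bounds_size(bounds: Optional[str], default: int) -> int:
  """Stand-in: product of the process bounds, e.g. '2,2,2' -> 8."""
  if not bounds:
    return default
  return math.prod(int(dim) for dim in bounds.split(','))


def num_local_processes(bounds: Optional[str], local_chips: int = 4) -> int:
  # Never create more processes than there are chips on this host.
  return min(local_chips, process_bounds_size(bounds, default=local_chips))


# v4-16: bounds '2,2,2' span all hosts (world size 8), but each host only has
# 4 chips, so each host spawns 4 processes.
assert num_local_processes('2,2,2') == 4
# Single-process topology: bounds '1,1,1' -> 1 local process.
assert num_local_processes('1,1,1') == 1
# Unset bounds: default to one process per chip.
assert num_local_processes(None) == 4
```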
```python
process_endpoints = [
    ','.join(f'{ip}:{port}' for port in ports) for ip in worker_ips
]
os.environ.setdefault(xenv.TPU_PROCESS_ADDRESSES, ','.join(process_endpoints))
```
do we need this for v4-8?
Yes. All of the addresses will be on localhost, but each process needs to know the port of each other process. You can only skip this environment variable when there is one process per host.
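To make the single-host (e.g. v4-8) case concrete, here is a rough illustration of the resulting value; the port numbers are made up, and the literal env var name is used where the diff above uses xenv.TPU_PROCESS_ADDRESSES:

```python
import os

# One host, so every address is on localhost; the variable still lists one
# endpoint per local process so each process can find the others' ports.
worker_ips = ['localhost']
ports = [8476, 8477, 8478, 8479]  # hypothetical, one port per local process

process_endpoints = [
    ','.join(f'{ip}:{port}' for port in ports) for ip in worker_ips
]
os.environ.setdefault('TPU_PROCESS_ADDRESSES', ','.join(process_endpoints))

# -> localhost:8476,localhost:8477,localhost:8478,localhost:8479
print(os.environ['TPU_PROCESS_ADDRESSES'])
```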
Thanks @will-cromar! Can you add a follow-up README under the pjrt directory to show how to run a pod workload? Would be nice if we also include how to run a donut workload using pjrt in that README too.
Thanks @JackCaoG. I'll work on a README showing how to port from XRT to PjRt and how to run models without ...
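A rough sketch of the porting pattern such a README might describe, assuming PJRT_DEVICE=TPU selects the experimental runtime and using pjrt.run_multiprocess as the entry point in place of xmp.spawn; the training body is a placeholder:

```python
import os

import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def train_fn():
  device = xm.xla_device()
  # ... build the model, wrap the loader, run the training loop ...
  print(f'ordinal {xm.get_ordinal()} of {xm.xrt_world_size()} on {device}')


if __name__ == '__main__':
  # Select the PjRt runtime (hypothetically set here; it can also be set in
  # the shell when launching the script on each TPU VM worker).
  os.environ.setdefault('PJRT_DEVICE', 'TPU')
  pjrt.run_multiprocess(train_fn)
```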
Summary of changes in this PR:

- Moved logic from pjrt.py to tpu.py. All of this logic is broadly applicable to all TPU VMs and isn't strictly related to PjRt.
- Updated configure_topology to support multiple hosts.
- Added unit tests for tpu.py (test_experimental_tpu.py). Everything that requires a TPU is mocked out, so this can run on CPU (sketched below).
- Added tests for PjRt on TPU (test_experimental_pjrt_tpu.py). This one initializes the TPU runtime in each test, so it must run on a TPU. Added configs for v3-8 and v4-8. Tested manually on both.

I know the naming of these two new tests is confusing. Let me know if you have alternate suggestions.
PJRT doesn't make a distinction between processes on different hosts, and we already support multiple processes on one host, so no low-level changes were necessary. This PR mainly deals with automatically configuring the TPU topology variables based on the environment.
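As a rough illustration of the mocked-out testing approach mentioned above (assuming tpu.py lands next to pjrt.py under torch_xla/experimental; the test name and bounds value here are made up):

```python
# Hypothetical unit test in the spirit of test_experimental_tpu.py: patch out
# process_bounds_size (which would normally read the TPU environment) so the
# capping logic in num_local_processes can run on CPU.
from unittest import mock

from torch_xla.experimental import tpu


def test_num_local_processes_capped_at_local_chips():
  # Simulate v4-16-style bounds: 8 processes across the pod, 4 chips per host.
  with mock.patch.object(tpu, 'process_bounds_size', return_value=8):
    assert tpu.num_local_processes() == 4


if __name__ == '__main__':
  test_num_local_processes_capped_at_local_chips()
```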