📚 Documentation
[Feature Request / Documentation Improvement] Improve PyTorch/XLA Documentation and Clarify SPMD Usage
Hello PyTorch/XLA team,
During my TPU grant I ran into many undocumented pitfalls and unclear behaviors, which made the setup process very time-consuming and confusing.
I'd like to ask for clarification and improvements on several key points that cost me significant time.
The documentation may well seem clear to experienced users, but on a first read it relies on many implicit assumptions and leaves out important explanations.
General Request
Please improve the documentation — make it more explicit and practical, especially for multi-host and SPMD setups.
For example, while it’s indeed mentioned in the Running on TPU Pods section that the code must be launched on all hosts, this information is buried too deep and is not referenced in other critical sections like “Troubleshooting Basics.”
It would be much clearer if you placed a visible note near the top of the documentation saying something like:
⚠️ For multi-host TPU setups, you must launch the code on all hosts simultaneously.
See Running on TPU Pods (multi-host) for details.
This would help avoid confusion, since right now it’s easy to miss and leads to situations where the code just hangs with no clear reason.
Specific Questions and Issues
- What is recommended to use — `.launch` or `spmd`?
- Should SPMD be started on all hosts as well?
- In SPMD, is the batch size global or per-host?
- How is data distributed if each process sees all devices and I have 4 hosts with 4 devices each?
- If the batch size is global, what is the purpose of having multiple hosts? Only for data loading?
- How does XLA decide what data goes to which device — does it shard across all devices globally or only locally per host?
- How to correctly use `scan`/`scan_layers` if the transformer block takes multiple arguments and one of them is of type `torch.bool`?
- `assume_pure` seems to break if the model contains `nn.Parameter`. Is it even correct to use it like that?
- Can I reuse "params and buffers" between steps, or should I retrieve them every time before a training pass?
- `syncfree.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0)` seems to trigger recompilation around step ~323 (possibly due to `beta2`, not sure).
- In SPMD, how to correctly get the process ID? `world_size` and `global_ordinal` don't work. Should I use `process_index`? `is_master_ordinal(local=False)` also doesn't work. (See the logging sketch after this list.)
- Please add a note to the docs: when logging, it's better to use `flush=True`, otherwise logs might not appear (which is confusing). Also, wrap training code in `try/except`, since exceptions sometimes don't log either.
- How can I perform sampling and logging in SPMD mode if I want only one host to handle these tasks (not all hosts)? (Also covered by the logging sketch after this list.)
- Please provide fully explicit examples — with comments, no abstractions, step-by-step explanations of what each part does and how it can be modified.
- Compilation caching seems broken — when trying to load, it says "not implemented."
- Can I pass only one `input_sharding=xs.ShardingSpec(mesh, ('fsdp', None))` to `MpDeviceLoader` if my dataset returns a tuple of 10 tensors with different shapes? (See the loader sketch after this list.)
- `xm.rendezvous` seems to do nothing in SPMD mode (at least before the training loop).
- How to verify that all hosts are actually training one shared model, and not each training separately?
- In the docs, `HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data', 'fsdp', 'tensor'))` is shown, but in practice it only works if you pass named arguments like `ici_mesh_shape=ici_mesh_shape`, otherwise it errors out. (See the mesh sketch after this list.)
- How to correctly do gradient checkpointing per layer with FSDP?
- How to correctly do gradient clipping? (See the clipping sketch after this list.)
- If model weights are expected to remain in FP32 when using `autocast`, please explicitly state that in the training docs — it would help avoid second-guessing.
- What is a reasonable compilation time during training? Mine can take 20–30 minutes.
- What are the actual intended purposes of `torch_xla.step()` and `torch_xla.compile()`? Since PyTorch/XLA already compiles and executes lazily, it's unclear when and why these should be used explicitly.
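Code Sketches (for Context)
To make some of the questions above more concrete, here are minimal sketches of what I am currently doing. They are not meant as correct reference code; the names and shapes are placeholders and several calls are guesses on my part. First, how I try to get a process ID and restrict logging/sampling to a single host, assuming `torch_xla.runtime.process_index()` is the intended API under SPMD:

```python
import torch_xla.runtime as xr

def is_logging_host() -> bool:
    # Assumption on my side: process_index() is 0 on exactly one host,
    # so it can play the role of a "master ordinal" under SPMD.
    return xr.process_index() == 0

def log(msg: str) -> None:
    if is_logging_host():
        # flush=True, otherwise the output sometimes never shows up
        print(msg, flush=True)
```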
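Next, the data-loading setup behind the `MpDeviceLoader` question. The dataset, batch size, and mesh are placeholders; in my real code each batch is a tuple of 10 tensors with different ranks, and I don't know whether a single `ShardingSpec` is meant to cover all of them (the sketch below may not even run when the ranks differ, which is exactly the question):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data import DataLoader, TensorDataset

xr.use_spmd()
device = xm.xla_device()

# Placeholder 1-D mesh over all global devices, sharded along 'fsdp'.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('fsdp',))

# Placeholder dataset: in reality each item is a tuple of 10 tensors
# with different shapes (including 1-D tensors).
dataset = TensorDataset(torch.randn(1024, 128), torch.randint(0, 2, (1024,)))
loader = DataLoader(dataset, batch_size=32)

# A single spec for the whole batch tuple -- is this the intended usage
# when the tuple elements have different ranks?
train_loader = pl.MpDeviceLoader(
    loader,
    device,
    input_sharding=xs.ShardingSpec(mesh, ('fsdp', None)),
)
```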
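The `HybridMesh` call that actually works for me uses keyword arguments; the positional form from the docs errors out. The mesh shapes below are my guess for a single v4-32 slice and may well be wrong (and I'm not 100% sure `axis_names` is the right keyword for the third argument), which is part of the question:

```python
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# Works for me only with keyword arguments.
mesh = xs.HybridMesh(
    ici_mesh_shape=(1, 16, 1),  # ('data', 'fsdp', 'tensor') within the slice -- placeholder shape
    dcn_mesh_shape=(1, 1, 1),   # single slice, so no DCN-level parallelism -- placeholder shape
    axis_names=('data', 'fsdp', 'tensor'),
)

# The positional form shown in the docs fails for me:
# mesh = xs.HybridMesh((1, 16, 1), (1, 1, 1), ('data', 'fsdp', 'tensor'))
```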
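Finally, the training step where I am unsure about gradient clipping. This assumes the stock `torch.nn.utils.clip_grad_norm_` is the right tool on XLA/SPMD tensors, which is exactly what I would like the docs to confirm or correct:

```python
import torch
import torch_xla.core.xla_model as xm

def train_step(model, optimizer, batch, max_grad_norm=1.0):
    inputs, labels = batch
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    loss.backward()
    # Assumption: the stock utility works on sharded XLA tensors as-is.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    xm.mark_step()  # flush the lazily traced graph for this step
    return loss
```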
All of this was tested on a v4-32 TPU.
Maybe some of it is covered somewhere in the docs and I just missed it, but I hope you can clarify and improve the documentation.
Thank you for your time and support.