Description
🐛 Describe the bug
When using data parallelism (DP, here via `fully_shard`) together with tensor parallelism (TP), TP may be applied to only a subset of layers, so that some layers have only DP applied. In that case, the DP-only parameters on ranks that share a DP rank but have different TP ranks slowly diverge from each other whenever there is a small amount of non-determinism. One source of such non-determinism is the non-deterministic implementation of certain operations.
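The following single-process toy model illustrates the mechanism. It is a hypothetical illustration, not FSDP code, and `simulate` is a made-up name. It assumes plain SGD on one scalar weight with a constant gradient, so nothing pulls the two copies back together. Two copies of a DP-only weight (same DP rank, different TP ranks) see the same gradient plus independent noise of the same size as in the repro (`eps = 1e-5`). Averaging the gradient over the "TP group" before the update keeps the copies identical.

```python
import random


def simulate(steps: int, sync_grads: bool, lr: float = 1e-2, noise: float = 1e-5, seed: int = 0) -> float:
    """Return |w_a - w_b| after `steps` SGD updates of two replicas of one scalar weight.

    Hypothetical toy: both replicas start from the same value and share the same
    "true" gradient, but each adds its own tiny noise term, standing in for
    non-deterministic kernels on two TP ranks that share a DP rank.
    """
    rng_a = random.Random(seed)      # noise source on hypothetical "TP rank 0"
    rng_b = random.Random(seed + 1)  # different seed on hypothetical "TP rank 1"
    w_a = w_b = 1.0
    for _ in range(steps):
        # Constant gradient (a flat direction): nothing pulls the replicas back together.
        true_grad = 0.1
        g_a = true_grad + noise * rng_a.random()
        g_b = true_grad + noise * rng_b.random()
        if sync_grads:
            # Gradient-sync workaround: average over the "TP group" so both replicas
            # apply exactly the same update.
            g_a = g_b = (g_a + g_b) / 2
        w_a -= lr * g_a
        w_b -= lr * g_b
    return abs(w_a - w_b)


if __name__ == "__main__":
    print(simulate(10_000, sync_grads=False))  # small but nonzero; the gap performs a random walk
    print(simulate(10_000, sync_grads=True))   # 0.0
```

Without synchronization, nothing in the DP-only update path ever reconciles the copies, so the per-step differences accumulate.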
Here is a practical example of where this is an issue. At my company we are working on VLMs in which the visual encoder is small and does not necessarily benefit from TP, or implementing TP for it may be non-trivial, whereas we use TP for the language model. In that setup, we found with larger models that the visual encoder weights for the same DP rank but different TP ranks quickly diverge significantly, breaking training entirely. Training appears to run properly (the loss decreases, etc.), but when checkpoints are saved with DCP, for instance, loading them leads to very poor results.
Here are three workarounds I found, although each has significant drawbacks:
- Ensuring deterministic implementations are used throughout the model
  - Although PyTorch itself makes this fairly easy and usually not too expensive, third-party operations may or may not provide deterministic implementations, and making your own operations deterministic is non-trivial.
  - I also think that PyTorch should ideally support mixing DP and TP even in the presence of small amounts of non-determinism.
- Manually synchronizing the gradients across the TP process group at every iteration
  - I share an implementation in the reproduction code below.
  - With this implementation, the workaround can be quite expensive: for our internal use cases, it reduces overall training throughput by about 15%.
- Applying TP to all layers of the model
  - This may not make sense from a performance standpoint for small sub-models.
  - This may be non-trivial to implement correctly and to maintain.
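For the first workaround, the PyTorch-side switches fit in a few lines at the top of a training script. This is a minimal sketch that assumes a CUDA build using cuBLAS. The `CUBLAS_WORKSPACE_CONFIG` value follows the PyTorch reproducibility notes, and it has to be in place before cuBLAS is first used. These settings do nothing for third-party or custom kernels, which is the gap described in the list above.

```python
import os

# Must be set before the first cuBLAS call for deterministic cuBLAS matmuls.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import torch  # imported after setting the env var on purpose

# Use deterministic implementations where PyTorch has them, and raise an error
# for ops that only have a non-deterministic implementation.
torch.use_deterministic_algorithms(True)
# Stop the cuDNN autotuner from choosing kernels that may differ between runs.
torch.backends.cudnn.benchmark = False
```

Passing `warn_only=True` to `torch.use_deterministic_algorithms` turns those errors into warnings. This can help locate the non-deterministic ops in a large model before deciding whether this workaround is viable.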
I would appreciate being pointed towards better workarounds.
I am not sure whether this is genuinely a PyTorch bug or whether PyTorch does not aim to support this use case (TP + DP with DP-only parameters in the presence of non-deterministic operations). In the latter case, I think this should be documented.
Here is a toy example based on the PyTorch tutorial. It reproduces this behavior by simulating a very small amount of non-determinism:
```python
from argparse import ArgumentParser

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets
from torchvision.transforms import ToTensor


class NeuralNetwork(nn.Module):
    def __init__(self, simulate_non_determinism: bool):
        super().__init__()
        self.flatten = nn.Flatten()
        # dp1/dp2 only get DP (fully_shard); tp_colwise/tp_rowwise also get TP.
        self.dp1 = nn.Linear(28 * 28, 512)
        self.dp2 = nn.Linear(512, 512)
        self.tp_colwise = nn.Linear(512, 512)
        self.tp_rowwise = nn.Linear(512, 10)
        self.simulate_non_determinism = simulate_non_determinism

    def forward(self, x):
        eps = 1e-5
        x = self.flatten(x)
        x = F.relu(self.dp1(x))
        if self.simulate_non_determinism:
            # Tiny per-rank perturbation standing in for non-deterministic kernels.
            x = x + torch.rand_like(x) * eps
        x = F.relu(self.dp2(x))
        x = F.relu(self.tp_colwise(x))
        logits = self.tp_rowwise(x)
        return logits


def train_loop(
    dataloader, model: NeuralNetwork, loss_fn, optimizer, sync_grad: bool, mesh: DeviceMesh
):
    size = len(dataloader.sampler)  # number of samples seen by this DP rank
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X.to("cuda"))
        loss = loss_fn(pred, y.to("cuda"))

        # Backpropagation
        loss.backward()
        if sync_grad:
            # Workaround: synchronize gradients across the TP group.
            # Only DP-only parameters (DTensors on the 1-D "dp" mesh) are averaged here.
            for param in (p for p in model.parameters() if p.requires_grad):
                if isinstance(param.grad, DTensor) and param.grad.device_mesh == mesh["dp"]:
                    dist.all_reduce(
                        param.grad._local_tensor,
                        op=dist.ReduceOp.AVG,
                        group=mesh["tp"].get_group(),
                    )
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0 and dist.get_rank() == 0:
            loss, current = loss.detach().item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test_loop(
    dataloader,
    model,
    loss_fn,
    mesh: DeviceMesh,
):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            y = y.to("cuda")
            pred = model(X.to("cuda"))
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    correct_tensor = torch.tensor(correct, device="cuda", requires_grad=False)
    dist.all_reduce(correct_tensor, group=mesh["dp"].get_group())
    test_loss /= num_batches
    correct_tensor /= size
    if dist.get_rank() == 0:
        print(
            f"Test Error: \n Accuracy: {(100 * correct_tensor.item()):>0.1f}%, Avg loss: {test_loss:>8f} \n"
        )


def main() -> None:
    parser = ArgumentParser()
    parser.add_argument("--deterministic", action="store_true")
    parser.add_argument("--sync-grad", action="store_true")
    args = parser.parse_args()
    deterministic: bool = args.deterministic
    sync_grad: bool = args.sync_grad

    torch.manual_seed(42)
    mesh = init_device_mesh(device_type="cuda", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"))

    training_data = datasets.FashionMNIST(
        root="data", train=True, download=True, transform=ToTensor()
    )
    test_data = datasets.FashionMNIST(
        root="data", train=False, download=True, transform=ToTensor()
    )

    total_batch_size = 64
    dp_size = mesh["dp"].size()
    assert total_batch_size % dp_size == 0
    batch_size = total_batch_size // dp_size
    train_dataloader = DataLoader(
        training_data,
        batch_size=batch_size,
        sampler=DistributedSampler(
            dataset=training_data,
            num_replicas=dp_size,
            rank=mesh["dp"].get_local_rank(),
            shuffle=False,
        ),
    )
    test_dataloader = DataLoader(
        test_data,
        batch_size=batch_size,
        sampler=DistributedSampler(
            dataset=test_data,
            num_replicas=dp_size,
            rank=mesh["dp"].get_local_rank(),
            shuffle=False,
        ),
    )

    model = NeuralNetwork(not deterministic)
    # TP on the last two layers only; dp1/dp2 stay DP-only.
    model = parallelize_module(
        model,
        parallelize_plan={"tp_colwise": ColwiseParallel(), "tp_rowwise": RowwiseParallel()},
        device_mesh=mesh["tp"],
    )
    model: NeuralNetwork = fully_shard(model, mesh=mesh["dp"])
    model.to("cuda")

    learning_rate = 1e-3
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # Purposefully using different seeds to simulate non-determinism.
    torch.manual_seed(mesh.get_rank())
    for epoch in range(10):
        train_loop(train_dataloader, model, loss_fn, optimizer, sync_grad, mesh)
        # Check that every DP-only parameter shard is identical across the TP group.
        for name, param in model.named_parameters():
            assert isinstance(param.data, DTensor)
            if param.data.device_mesh != mesh["dp"]:
                continue  # TP parameters live on the 2-D (dp, tp) mesh
            param_data = param.data.to_local()
            gather_list = [torch.empty_like(param_data) for _ in range(mesh["tp"].size())]
            dist.all_gather(gather_list, param_data, group=mesh["tp"].get_group())
            if mesh["tp"].get_local_rank() == 0:
                expected = gather_list[0]
                for actual in gather_list[1:]:
                    torch.testing.assert_close(
                        actual, expected, msg=lambda text: f"{mesh.get_coordinate()}: {name}\n{text}"
                    )
        test_loop(test_dataloader, model, loss_fn, mesh)
    if mesh.get_rank() == 0:
        print("Done!")


if __name__ == "__main__":
    main()
```
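The `sync_grad` branch of `train_loop` issues one `all_reduce` per DP-only parameter. One hypothetical variant packs all DP-only local gradients into a single flat buffer and reduces it with one collective. This may lower the per-collective overhead, but any speedup is an untested assumption. `average_grads_coalesced` is a made-up helper, and the sketch assumes all gradients share a dtype and device. It uses `SUM` followed by a division instead of `ReduceOp.AVG` so that it also runs on the gloo backend.

```python
import torch
import torch.distributed as dist


def average_grads_coalesced(grads: list[torch.Tensor], group=None) -> None:
    """Hypothetical helper: average `grads` in place across `group` with one all_reduce.

    Assumes every tensor has the same dtype and device, e.g. the local shards of
    the DP-only gradients on one rank.
    """
    if not grads:
        return
    # Pack all gradients into one contiguous 1-D buffer (this copies them,
    # so peak memory grows by the size of the DP-only gradients).
    flat = torch.cat([g.reshape(-1) for g in grads])
    dist.all_reduce(flat, op=dist.ReduceOp.SUM, group=group)
    flat.div_(dist.get_world_size(group))
    # Unpack the averaged values back into the original gradient tensors.
    offset = 0
    for g in grads:
        n = g.numel()
        g.copy_(flat[offset : offset + n].view_as(g))
        offset += n
```

In the repro, this would replace the per-parameter loop. You would call it with the same list of `p.grad._local_tensor` for DP-only parameters and `group=mesh["tp"].get_group()`.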
On a machine with 4 GPUs, running the reproduction script with `torchrun --standalone --nnodes=1 --nproc-per-node=4 tp_with_dp_only.py` fails with the error below. When `--deterministic` or `--sync-grad` is specified, the error does not occur.
```
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/user02107/vlm-train/tp_with_dp_only.py", line 194, in <module>
[rank0]:     main()
[rank0]:   File "/home/user02107/vlm-train/tp_with_dp_only.py", line 183, in main
[rank0]:     torch.testing.assert_close(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/testing/_comparison.py", line 1519, in assert_close
[rank0]:     raise error_metas[0].to_error(msg)
[rank0]: AssertionError: [0, 0]: dp2.weight
[rank0]: Tensor-likes are not close!
[rank0]: Mismatched elements: 15 / 131072 (0.0%)
[rank0]: Greatest absolute difference: 1.3850629329681396e-05 at index (201, 475) (up to 1e-05 allowed)
[rank0]: Greatest relative difference: 0.012554704211652279 at index (189, 475) (up to 1.3e-06 allowed)
[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/user02107/vlm-train/tp_with_dp_only.py", line 194, in <module>
[rank2]:     main()
[rank2]:   File "/home/user02107/vlm-train/tp_with_dp_only.py", line 183, in main
[rank2]:     torch.testing.assert_close(
[rank2]:   File "/usr/local/lib/python3.12/dist-packages/torch/testing/_comparison.py", line 1519, in assert_close
[rank2]:     raise error_metas[0].to_error(msg)
[rank2]: AssertionError: [1, 0]: dp1.weight
[rank2]: Tensor-likes are not close!
[rank2]: Mismatched elements: 138 / 200704 (0.1%)
[rank2]: Greatest absolute difference: 1.9358471035957336e-05 at index (219, 602) (up to 1e-05 allowed)
[rank2]: Greatest relative difference: 0.07610764354467392 at index (219, 523) (up to 1.3e-06 allowed)
[rank0]:[W808 14:24:16.420236322 ProcessGroupNCCL.cpp:1505] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W808 14:24:16.636097906 ProcessGroupNCCL.cpp:1505] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W0808 14:24:18.108000 226334 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 226342 closing signal SIGTERM
W0808 14:24:18.109000 226334 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 226344 closing signal SIGTERM
E0808 14:24:18.727000 226334 torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: 1) local_rank: 0 (pid: 226341) of binary: /usr/bin/python
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 892, in main
    run(args)
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 883, in run
    elastic_launch(
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 139, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 270, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
tp_with_dp_only.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2025-08-08_14:24:18
  host      : srdgx00344.cm.cluster
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 226343)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-08-08_14:24:18
  host      : srdgx00344.cm.cluster
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 226341)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
```
I think this issue may also be caused by this, though I cannot be sure.
Versions
```
$ python collect_env.py
Collecting environment information...
PyTorch version: 2.8.0a0+5228986c39.nv25.06
Is debug build: False
CUDA used to build PyTorch: 12.9
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.2 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 3.31.6
Libc version: glibc-2.39
Python version: 3.12.3 (main, Feb 4 2025, 14:48:35) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1053-nvidia-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.9.86
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
Nvidia driver version: 535.161.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.10.2
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 224
On-line CPU(s) list: 0-223
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8480C
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 56
Socket(s): 2
Stepping: 8
CPU(s) scaling MHz: 96%
CPU max MHz: 3800.0000
CPU min MHz: 800.0000
BogoMIPS: 4000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 5.3 MiB (112 instances)
L1i cache: 3.5 MiB (112 instances)
L2 cache: 224 MiB (112 instances)
L3 cache: 210 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-55,112-167
NUMA node1 CPU(s): 56-111,168-223
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] intel-openmp==2021.4.0
[pip3] mkl==2021.1.1
[pip3] mkl-devel==2021.1.1
[pip3] mkl-include==2021.1.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==1.26.4
[pip3] nvidia-cudnn-frontend==1.12.0
[pip3] nvtx==0.2.11
[pip3] onnx==1.17.0
[pip3] open-clip-torch==2.24.0
[pip3] optree==0.16.0
[pip3] pynvjitlink==0.3.0
[pip3] pytorch-triton==3.3.0+git96316ce52.nvinternal
[pip3] tbb==2021.13.1
[pip3] torch==2.8.0a0+5228986c39.nv25.6
[pip3] torch_tensorrt==2.8.0a0
[pip3] torchao==0.11.0+git
[pip3] torchinfo==1.8.0
[pip3] torchprofile==0.0.4
[pip3] torchvision==0.22.0a0+95f10a4e
[conda] Could not collect
```
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @zhaojuanmao @mrshenli @rohan-varma @chauhang @mori360