🐛 Describe the bug
I'm reading the source code of FSDP2. I noticed that the unsharded param is materialized via split_with_sizes_copy when needed. Since this operation writes into the param's memory, I expected it to bump the version counter, which would in turn make the version-counter check fail during the backward pass (a minimal sketch of that check follows). However, in practice it works normally on the GPU.
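
For reference, here is a minimal sketch (not FSDP2 code) of the backward-pass check I expected to trip: an in-place write bumps the version counter of a tensor that autograd saved, and backward then raises.

import torch

a = torch.ones(4, 2, requires_grad=True)
b = a * 2
c = b * b           # autograd saves b (at version 0) for backward
b.add_(1)           # in-place write bumps b._version to 1
c.sum().backward()  # RuntimeError: ... is at version 1; expected version 0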
This behavior confused me, so I tested split_with_sizes_copy in isolation and found that its version-counter handling is inconsistent between CPU and GPU: on the GPU, the version counters of the out tensors are not updated.
import torch

# CPU: the out= tensors' version counters are bumped as expected
x = torch.ones(4, 2)
y = torch.ones(4, 1)
z = torch.ones(4, 1)
torch.split_with_sizes_copy(x, [1, 1], dim=1, out=[y, z])
print(y._version, z._version)  # 1 1

# GPU: the out= tensors' version counters are NOT bumped
x = torch.ones(4, 2).cuda()
y = torch.ones(4, 1).cuda()
z = torch.ones(4, 1).cuda()
torch.split_with_sizes_copy(x, [1, 1], dim=1, out=[y, z])
print(y._version, z._version)  # 0 0
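
To illustrate why the stale counter matters, here is a hypothetical sketch (assuming a CUDA device; not FSDP2 code): because the GPU path leaves the version counter untouched, a write into a tensor that autograd has saved for backward goes undetected, and backward silently computes gradients from the overwritten values instead of raising.

import torch

y = torch.ones(4, 1, device="cuda", requires_grad=True)
z = torch.ones(4, 1, device="cuda")
out = y * y  # autograd saves y (at version 0) for backward
# Overwrite y's storage through a detached alias (shares the version counter)
src = torch.full((4, 2), 3.0, device="cuda")
torch.split_with_sizes_copy(src, [1, 1], dim=1, out=[y.detach(), z])
print(y._version)    # still 0 on GPU, so the backward check cannot catch this
out.sum().backward()
print(y.grad)        # 6.0 everywhere (2 * overwritten y), not the correct 2.0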
Versions
I tested torch versions from release 2.3.1 through the nightly (2.5.0.dev20240709+cu121), and the problem exists in all of them.
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @zhaojuanmao @mrshenli @rohan-varma @chauhang