
[FSDP2 Related] torch.split_with_sizes_copy on the GPU does not update the version counter of out correctly. #132014

@medivh-xp

🐛 Describe the bug

While reading the source code of FSDP2, I noticed that the unsharded param is materialized via split_with_sizes_copy when needed. Since this operation writes into the param's memory, I expected it to bump the version counter, which would in turn cause the version-counter check during the backward pass to fail. In practice, however, it works normally on the GPU.
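
For context, here is a minimal sketch (not FSDP2 code; it uses exp(), which saves its output for backward) of the version-counter check I expected such an in-place write to trip:

import torch

a = torch.ones(4, 2, requires_grad=True)
b = a.exp()                     # autograd saves b for exp's backward
print(b._version)               # 0
with torch.no_grad():
    b.copy_(torch.zeros(4, 2))  # in-place write bumps b's version counter
print(b._version)               # 1
try:
    b.sum().backward()
except RuntimeError as e:
    # "one of the variables needed for gradient computation has been
    #  modified by an inplace operation"
    print(e)

The version bump happens even though the copy runs under no_grad, which is why I expected FSDP2's copy into the unsharded param to invalidate saved tensors in the same way.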

This confused me, so I tested split_with_sizes_copy in isolation and found that its handling of the version counter is inconsistent between CPU and GPU: on the GPU, the version counters of the out tensors are not updated.

import torch

# CPU: the version counters of the out tensors are bumped as expected.
x = torch.ones(4, 2)
y = torch.ones(4, 1)
z = torch.ones(4, 1)
torch.split_with_sizes_copy(x, [1, 1], dim=1, out=[y, z])
print(y._version, z._version)  # 1 1

# GPU: the same call leaves the version counters untouched.
x = torch.ones(4, 2).cuda()
y = torch.ones(4, 1).cuda()
z = torch.ones(4, 1).cuda()
torch.split_with_sizes_copy(x, [1, 1], dim=1, out=[y, z])
print(y._version, z._version)  # 0 0
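
For what it's worth, the copy itself does land on the GPU (presumably why FSDP2 still works); only the version-counter bookkeeping is skipped. A quick check, assuming a CUDA device is available:

import torch

x = torch.arange(8., device="cuda").reshape(4, 2)
y = torch.zeros(4, 1, device="cuda")
z = torch.zeros(4, 1, device="cuda")
torch.split_with_sizes_copy(x, [1, 1], dim=1, out=[y, z])
print(torch.equal(y, x[:, :1]), torch.equal(z, x[:, 1:]))  # True True
print(y._version, z._version)  # 0 0 -- data written, counters not bumped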

Versions

I tested torch versions from release 2.3.1 up to the nightly (2.5.0.dev20240709+cu121), and this problem exists in all of them.

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @zhaojuanmao @mrshenli @rohan-varma @chauhang

Labels

module: fsdp, oncall: distributed, triaged
