🐛 Describe the bug
Bug Report: Composable FSDP (`fully_shard`) Loses Autocast Context During Checkpoint Recomputation
Description
When using the composable FSDP API (`fully_shard`) with `torch.utils.checkpoint.checkpoint`, the autocast context is lost during the recomputation phase of the backward pass. As a result, inputs have different dtypes during recomputation than in the original forward pass, leading to tensor metadata mismatch errors.
The classic FSDP API (the `FullyShardedDataParallel` wrapper class) does not have this issue: it correctly preserves the autocast context during checkpoint recomputation.
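For reference, plain non-reentrant checkpointing with no FSDP involved does preserve the autocast state across recomputation. A minimal baseline sketch (single GPU assumed; prints are illustrative):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Baseline sketch (no FSDP): non-reentrant checkpoint re-enters the autocast
# state that was active during the original forward, so this print fires twice
# (forward + recompute) and should report torch.bfloat16 both times.
lin = nn.Linear(8, 8).cuda()

def region(x):
    y = lin(x)
    print("dtype inside checkpointed region:", y.dtype)
    return y

x = torch.randn(4, 8, device="cuda", requires_grad=True)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = checkpoint(region, x, use_reentrant=False)
out.sum().backward()  # backward is outside autocast, yet recompute should still run in bfloat16
```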
Steps to Reproduce
The following minimal reproducible example demonstrates the issue by comparing both FSDP implementations with checkpointing under autocast:
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.fsdp import MixedPrecision


class CheckpointedModule(nn.Module):
    def __init__(self, wrap_type: str = "FSDP"):
        super().__init__()
        self.wrap_type = wrap_type
        self.linear1 = nn.Linear(2048, 2048)

    def forward(self, x):
        print(f"wrap_type: {self.wrap_type} forward input dtype {x.dtype}, shape {x.shape}")
        x = nn.functional.gelu(x)
        print(f"wrap_type: {self.wrap_type} gelu output dtype {x.dtype}, shape {x.shape}")
        x = self.linear1(x)
        print(f"wrap_type: {self.wrap_type} linear1 output dtype {x.dtype}, shape {x.shape}")
        return x


def init_process_group():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())


def reproduce_fsdp_checkpoint_dtype_bug():
    # Initialize process group (single process for simplicity)
    init_process_group()

    # FSDP1 mixed precision
    mp_policy_fsdp = MixedPrecision(
        param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16
    )
    # FSDP2 mixed precision
    mp_policy_fsdp2 = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)

    # FSDP1
    fsdp_model = FSDP(
        CheckpointedModule(wrap_type="FSDP"),
        mixed_precision=mp_policy_fsdp,
        device_id=dist.get_rank(),
    )
    # FSDP2
    fsdp2_model = CheckpointedModule(wrap_type="fully_shard")
    fully_shard(fsdp2_model, mp_policy=mp_policy_fsdp2)

    # Create input
    x = torch.randn(256, 2048).cuda()
    x.requires_grad = True

    # Run with autocast, FSDP
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = checkpoint(fsdp_model, x, use_reentrant=False)
        # Backward pass to trigger recomputation
        loss = out.sum()
        loss.backward()

    # Run with autocast, fully_shard
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = checkpoint(fsdp2_model, x, use_reentrant=False)
        # Backward pass to trigger recomputation
        loss = out.sum()
        loss.backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    reproduce_fsdp_checkpoint_dtype_bug()
```
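The repro is meant to be launched as a single-process job, e.g. with `torchrun --nproc_per_node=1 repro.py` (the filename is an assumption), since `init_process_group(backend="nccl")` relies on the rendezvous environment variables that `torchrun` sets.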
Expected Behavior
Both FSDP implementations should maintain consistent dtypes between the original forward pass and the recomputation during backward. The autocast context should be preserved during checkpoint recomputation.
Actual Behavior
- With classic FSDP (`FullyShardedDataParallel`), the dtype remains consistent between the forward pass and the recomputation.
- With composable FSDP (`fully_shard`), the autocast context is lost during recomputation, so inputs have different dtypes in the recompute than they did in the original forward pass, which triggers a tensor metadata mismatch error.

During recomputation, you can observe that inputs to the `fully_shard` model revert to float32 instead of staying in bfloat16, which leads to the tensor metadata mismatch error.
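A possible mitigation (an untested sketch, not a verified fix for the root cause) is to pin the recompute to the same autocast state explicitly via the `context_fn` argument of `checkpoint`, which takes a callable returning a pair of context managers, one for the original forward and one for recomputation:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Untested workaround sketch: explicitly re-enter autocast during recompute.
# context_fn must return two context managers: (forward_ctx, recompute_ctx),
# and is only supported with use_reentrant=False.
def autocast_ctx_fn():
    make_ctx = lambda: torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    return make_ctx(), make_ctx()

# Reusing fsdp2_model and x from the repro above:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = checkpoint(fsdp2_model, x, use_reentrant=False, context_fn=autocast_ctx_fn)
out.sum().backward()
```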
Environment
- PyTorch version: 2.5
- CUDA version: 12.4
- GPU type: NVIDIA H100
- OS: Linux Ubuntu 22
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @zhaojuanmao @mrshenli @rohan-varma @chauhang @mori360 @kwen2501 @c-p-i-o