FSDP2 and autocast compatibility issue #148831

@yjxiong

🐛 Describe the bug

Bug Report: Composable FSDP (fully_shard) Loses Autocast Context During Checkpoint Recomputation

Description

When using the composable FSDP API (fully_shard) with torch.utils.checkpoint.checkpoint, the autocast context is lost during the recomputation phase of the backward pass. This causes inputs to have different dtypes during the recomputation compared to the original forward pass, leading to tensor metadata mismatch errors.

The classic FSDP API (FullyShardedDataParallel class) does not have this issue and correctly preserves the autocast context during checkpoint recomputation.

Steps to Reproduce

The following minimal reproducible example compares both FSDP implementations with activation checkpointing under autocast:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.fsdp import MixedPrecision

class CheckpointedModule(nn.Module):
    def __init__(self, wrap_type: str = "FSDP"):
        super().__init__()
        self.wrap_type = wrap_type
        self.linear1 = nn.Linear(2048, 2048)
    
    def forward(self, x):
        print(f"wrap_type: {self.wrap_type} forward input dtype {x.dtype}, shape {x.shape}")
        x = nn.functional.gelu(x)
        print(f"wrap_type: {self.wrap_type} gelu output dtype {x.dtype}, shape {x.shape}")
        x = self.linear1(x)
        print(f"wrap_type: {self.wrap_type} linear1 output dtype {x.dtype}, shape {x.shape}")
        return x

def init_process_group():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())

def reproduce_fsdp_checkpoint_dtype_bug():
    # Initialize process group (single process for simplicity)
    init_process_group()
    
    # FSDP1 mixed precision
    mp_policy_fsdp = MixedPrecision(
        param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16
    )
    
    # FSDP2 mixed precision
    mp_policy_fsdp2 = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)
    
    # FSDP1
    fsdp_model = FSDP(CheckpointedModule(wrap_type="FSDP"), mixed_precision=mp_policy_fsdp, device_id=dist.get_rank())
    
    # FSDP2
    fsdp2_model = CheckpointedModule(wrap_type="fully_shard")
    fully_shard(fsdp2_model, mp_policy=mp_policy_fsdp2)
    
    # Create input
    x = torch.randn(256, 2048).cuda()
    x.requires_grad = True
    
    # Run with autocast, FSDP
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        # Forward pass
        out = checkpoint(fsdp_model, x, use_reentrant=False)
        
        # Backward pass to trigger recomputation
        loss = out.sum()
        loss.backward()
    
    # Run with autocast, fully_shard
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        # Forward pass
        out = checkpoint(fsdp2_model, x, use_reentrant=False)
        
        # Backward pass
        loss = out.sum()
        loss.backward()
    
    dist.destroy_process_group()

if __name__ == "__main__":
    reproduce_fsdp_checkpoint_dtype_bug()
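
The script is written for a single-rank run. Assuming it is saved as repro.py, it can be launched with torchrun --nproc_per_node=1 repro.py, which provides the environment variables that dist.init_process_group with the NCCL backend expects; at least one CUDA device is required.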

Expected Behavior

Both FSDP implementations should maintain consistent dtypes between the original forward pass and the recomputation during backward. The autocast context should be preserved during checkpoint recomputation.

Actual Behavior

  1. With classic FSDP (FullyShardedDataParallel), the dtypes remain consistent between the forward pass and the recomputation.
  2. With composable FSDP (fully_shard), the autocast context is lost during recomputation, so inputs have different dtypes in the recompute than they did in the original forward pass.

During recomputation, the inputs to the fully_shard model revert to float32 instead of staying in bfloat16, which triggers a tensor metadata mismatch error.
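
As a point of reference (not a fix for the underlying issue), torch.utils.checkpoint.checkpoint accepts a context_fn argument when use_reentrant=False: it must return two context managers, where the first wraps the original forward and the second wraps the recomputation. Below is a minimal sketch that re-enters autocast during recomputation; the helper name make_autocast_context_fn is hypothetical, and whether this avoids the fully_shard metadata mismatch has not been verified here.

from contextlib import nullcontext
import torch
from torch.utils.checkpoint import checkpoint

def make_autocast_context_fn(device_type="cuda", dtype=torch.bfloat16):
    # Hypothetical helper: checkpoint(..., context_fn=...) expects a callable returning
    # two context managers; the first wraps the original forward, the second wraps the
    # recomputation during backward.
    def context_fn():
        return (
            nullcontext(),  # the forward already runs inside the outer autocast region
            torch.autocast(device_type=device_type, dtype=dtype),  # re-enter autocast for recompute
        )
    return context_fn

# Usage in the repro above (inside the autocast block for the fully_shard model):
# out = checkpoint(fsdp2_model, x, use_reentrant=False, context_fn=make_autocast_context_fn())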

Environment

  • PyTorch version: 2.5
  • CUDA version: 12.4
  • GPU type: NVIDIA H100
  • OS: Linux Ubuntu 22


cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @zhaojuanmao @mrshenli @rohan-varma @chauhang @mori360 @kwen2501 @c-p-i-o

Labels

module: fsdp, oncall: distributed, triaged
