
[FSDP2] allow different dtypes for the model params with gradients #156784

@twoflypig

Description

🚀 The feature, motivation and pitch

Summary

Currently, torch's FSDP2 (Fully Sharded Data Parallel 2) does not support having parameters with different data types (dtypes) within the same module. This limitation restricts the flexibility and efficiency of model training, especially in scenarios where mixed-precision training with different dtypes is beneficial. This issue is a feature request to address this shortcoming and extend the capabilities of FSDP2.

Current Problem

In real-world deep learning applications, many models use multiple dtypes within a single module. In mixed-precision training, for example, some layers benefit from lower-precision dtypes such as torch.float16 or torch.bfloat16 to reduce memory usage and accelerate computation, while other layers, especially those involving numerically sensitive calculations or small gradients, need to stay in torch.float32 for numerical stability. With the current implementation of FSDP2, however, sharding a module whose parameters have different dtypes results in an error. This forces developers either to compromise on performance by using a single dtype throughout the module, or to implement complex and error-prone workarounds that split the module into smaller, dtype-homogeneous sub-modules, which can disrupt the model architecture and training workflow (a sketch of such a workaround is given after the traceback below).

Here is the minimal code to reproduce the error:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.optim.lr_scheduler import StepLR

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.layer2 = nn.Linear(20, 30)
        self.layer3 = nn.Linear(30, 10, dtype=torch.bfloat16)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.layer3(x.to(torch.bfloat16)).float()
        return x
    
def fsdp_model(model, fsdp_mesh):
    from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
    fsdp_kwargs = {}
    mixed_precision = False  # set to True to also exercise MixedPrecisionPolicy
    if mixed_precision:
        fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
        )
    fsdp_kwargs['mesh'] = fsdp_mesh

    fully_shard(model, **fsdp_kwargs)


def run_test():
    torch.distributed.init_process_group(backend="gloo")
    device_mesh = torch.distributed.device_mesh.DeviceMesh('cpu', torch.arange(dist.get_world_size()))
    torch.manual_seed(42)
    model = SimpleModel()
    
    fsdp_model(model, device_mesh)

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    

    model.train()
    for epoch in range(2):

        inputs = torch.randn(32, 10)
        targets = torch.randn(32, 10)
        
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        
def main():
    run_test()

if __name__ == "__main__":
    main()    

# torchrun --nproc-per-node 2 test_fsdp.py

Running this produces the following error:

[rank0]:   File "/home/anaconda3/envs/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1806, in inner
[rank0]:     args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
[rank0]:                          ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/anaconda3/envs/env/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 62, in fsdp_hook_wrapper
[rank0]:     return torch._dynamo.disable(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/anaconda3/envs/env/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 896, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/anaconda3/envs/env/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 237, in _pre_forward
[rank0]:     args, kwargs = self._root_pre_forward(module, args, kwargs)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/anaconda3/envs/env/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 123, in _root_pre_forward
[rank0]:     self._lazy_init()
[rank0]:   File "/home/anaconda3/envs/env/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 193, in _lazy_init
[rank0]:     state._fsdp_param_group.lazy_init()
[rank0]:   File "/home/anaconda3/envs/env/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 259, in lazy_init
[rank0]:     self._init_mp_dtypes()
[rank0]:   File "/home/anaconda3/envs/env/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 228, in _init_mp_dtypes
[rank0]:     raise AssertionError(
[rank0]: AssertionError: FSDP expects uniform original parameter dtype but got {torch.float32, torch.bfloat16}
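
The assertion comes from FSDP2 requiring a uniform original parameter dtype within a single FSDPParamGroup. One workaround today, along the lines of the sub-module splitting mentioned above, is to call fully_shard on the dtype-homogeneous sub-module first so that it gets its own parameter group; the root call then only owns the remaining fp32 parameters. A minimal sketch for the reproduction above (the exact wrapping choice here is my assumption, not part of the original report):

def fsdp_model_per_dtype(model, fsdp_mesh):
    from torch.distributed.fsdp import fully_shard
    # Shard the bf16 layer on its own so its FSDP param group has a single dtype.
    fully_shard(model.layer3, mesh=fsdp_mesh)
    # The root call now only owns the remaining fp32 parameters (layer1, layer2).
    fully_shard(model, mesh=fsdp_mesh)

This avoids the assertion, but it forces the wrapping structure to follow dtype boundaries, which is exactly the kind of workaround this feature request aims to make unnecessary.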

Implementation

I have recently prototyped an approach that splits the all-gather for a module into multiple communications, one per dtype. The reproduction above runs successfully with this change. Concretely, the implementation replaces FSDPParamGroup with a variant that supports multiple dtypes, and it requires only about 100 additional lines of code; a rough sketch of the idea follows.
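
To make that concrete, here is an illustrative sketch of the bucketing idea only, not the actual patch: sharded parameters are grouped by their original dtype and each bucket issues its own flat all-gather, instead of asserting a single uniform dtype as _init_mp_dtypes does today. The helper names below (group_params_by_dtype, all_gather_per_dtype) are made up for illustration and are not FSDP2 internals.

from collections import defaultdict

import torch
import torch.distributed as dist

def group_params_by_dtype(params):
    # Bucket parameters by their original dtype so each bucket is homogeneous.
    buckets = defaultdict(list)
    for p in params:
        buckets[p.dtype].append(p)
    return buckets

def all_gather_per_dtype(params, group=None):
    # Issue one flat all-gather per dtype bucket instead of a single
    # mixed-dtype all-gather over the whole param group.
    world_size = dist.get_world_size(group)
    gathered = {}
    for dtype, bucket in group_params_by_dtype(params).items():
        flat = torch.cat([p.detach().reshape(-1) for p in bucket])
        out = torch.empty(world_size * flat.numel(), dtype=dtype, device=flat.device)
        dist.all_gather_into_tensor(out, flat, group=group)
        gathered[dtype] = out
    return gathered

Inside a real FSDPParamGroup this would mean tracking per-dtype all-gather inputs and outputs rather than a single flat buffer of one dtype.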

Alternatives

No response

Additional context

No response

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @zhaojuanmao @mrshenli @rohan-varma @chauhang @mori360
