[FSDP2] allow meta tensors during loading state dict and cpu offloading #126267

Closed · wants to merge 2 commits
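Before the diff, a minimal sketch of the usage pattern this PR enables, modeled on the new test below. It assumes a multi-GPU run with the default process group already initialized (the test harness handles that), and the layer sizes and variable names are purely illustrative:

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed._tensor import distribute_tensor

# Build the model on the meta device so no real parameter memory is allocated yet.
with torch.device("meta"):
    model = nn.Linear(8, 8, bias=False)
fully_shard(model, offload_policy=CPUOffloadPolicy(pin_memory=True))

# Materialize parameters piecewise by assigning sharded DTensors; between the
# partial loads, some parameters are still on the meta device.
for name, param in model.named_parameters():
    full = torch.randn(param.size())
    shard = distribute_tensor(full, param.device_mesh, param.placements)
    model.load_state_dict({name: shard}, assign=True, strict=False)

# Lazy init runs without error, and the sharded parameters end up on CPU
# because of the offload policy.
out = model(torch.rand(8, 8, device="cuda"))
```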
40 changes: 38 additions & 2 deletions test/distributed/_composable/fsdp/test_fully_shard_state_dict.py
@@ -8,8 +8,8 @@

import torch
import torch.nn as nn
-from torch.distributed._composable.fsdp import fully_shard
-from torch.distributed._tensor import DTensor
+from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
+from torch.distributed._tensor import distribute_tensor, DTensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
@@ -68,6 +68,42 @@ def _test_1d_state_dict_save_load(self, mlp_dim: int):
        for key, value in ref_sharded_sd.items():
            self.assertEqual(value, sharded_sd[key])

+    @skip_if_lt_x_gpu(2)
+    def test_1d_state_dict_cpu_offload(self):
+        mlp_dim = 4
+        offload_policy = CPUOffloadPolicy(pin_memory=True)
+        torch.manual_seed(42)
+        with torch.device("meta"):
+            model = nn.Sequential(
+                nn.Linear(mlp_dim, mlp_dim, bias=False),
+                nn.Linear(mlp_dim, mlp_dim, bias=False),
+            )
+        for module in model:
+            fully_shard(module, offload_policy=offload_policy)
+        fully_shard(model, offload_policy=offload_policy)
+
+        # split full sd into multiple pieces
+        # to test loading with `strict=False`
+        state_dicts = []
+        for name, dtensor in model.named_parameters():
+            full_tensor = torch.randn(dtensor.size())
+            sharded_tensor = distribute_tensor(
+                full_tensor, dtensor.device_mesh, dtensor.placements
+            )
+            state_dicts.append({name: sharded_tensor})
+
+        # check that we can load with some parameters still on meta device
+        for sd in state_dicts:
+            model.load_state_dict(sd, assign=True, strict=False)
+
+        # lazy init without error
+        inp = torch.rand((mlp_dim, mlp_dim), device="cuda")
+        model(inp)
+
+        state_dict = model.state_dict()
+        for name, dtensor in state_dict.items():
+            self.assertEqual(dtensor.device.type, "cpu")

    @skip_if_lt_x_gpu(2)
    def test_2d_state_dict_save_load(self):
        dp_size = 2
4 changes: 3 additions & 1 deletion torch/distributed/_composable/fsdp/_fsdp_param.py
@@ -245,7 +245,7 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device):
        self.padded_sharded_param_size = padded_sharded_param.size()
        if sharded_param.numel() > 0:
            padded_sharded_param[: sharded_param.size(0)].copy_(sharded_param)
-        if self.offload_to_cpu:
+        if self.offload_to_cpu and not torch.empty(0).is_meta:
Reviewer comment (Contributor): suggested replacing `if self.offload_to_cpu and not torch.empty(0).is_meta:` with `if self.offload_to_cpu and not padded_sharded_param.is_meta:`.

Author reply (Contributor): resolving in #126305

            padded_sharded_param = padded_sharded_param.cpu()
            if self.pin_memory:
                padded_sharded_param = padded_sharded_param.pin_memory()
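The suggestion above hinges on what each expression actually checks: `torch.empty(0).is_meta` is only true when the ambient default device is currently meta (for example inside a `with torch.device("meta")` block), whereas `padded_sharded_param.is_meta` asks whether that specific tensor lives on the meta device. A small illustrative snippet of the difference (not part of the PR):

```python
import torch

meta_tensor = torch.empty(4, device="meta")  # created on meta, stays meta

print(torch.empty(0).is_meta)      # False: the default device here is CPU
print(meta_tensor.is_meta)         # True: this particular tensor is meta

with torch.device("meta"):
    # Inside the context, factory calls allocate on meta, so the ambient
    # check flips even though no existing tensor changed device.
    print(torch.empty(0).is_meta)  # True
```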
@@ -584,6 +584,8 @@ def reset_sharded_param(self):
                )
            self.sharded_param = new_param
        local_tensor = new_param._local_tensor
+        if local_tensor.is_meta:
+            return
        padded_sharded_size = self.padded_sharded_param_size
        if local_tensor.size() != padded_sharded_size:
            padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
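The two added lines skip the re-padding path when a parameter's local shard is still a meta tensor, which is exactly the state created between the piecewise `load_state_dict(..., assign=True, strict=False)` calls in the new test. A hedged sketch of the same condition from the caller's side, using the public `DTensor.to_local()` rather than the private `_local_tensor` attribute (the helper name is made up for illustration):

```python
import torch
from torch.distributed._tensor import DTensor

def local_shard_is_meta(param: torch.Tensor) -> bool:
    # Mirrors the early return above: a parameter whose local shard is still
    # on the meta device has no real storage to pad or copy yet.
    local = param.to_local() if isinstance(param, DTensor) else param
    return local.is_meta
```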