Commit
[FSDP] Fix for optim state dict
Pull Request resolved: #102901

Fix for HSDP + use_orig_params, where we need to pass in the process group (PG), which might not be the default.
ghstack-source-id: 191337799

Differential Revision: [D46417327](https://our.internmc.facebook.com/intern/diff/D46417327/)
rohan-varma committed Jun 6, 2023
1 parent 6b8e68c commit e3143d5
Showing 2 changed files with 59 additions and 2 deletions.
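
Under HYBRID_SHARD (HSDP), each FSDP state shards parameters across an intra-node process group and replicates them across an inter-node group, so fsdp_state.process_group is generally not the default (global) group. Below is a minimal sketch, not taken from this commit, of why the optimizer-state all-gather has to name that shard group explicitly; the 8-rank / two-shard-group layout and the helper name are hypothetical.

import torch.distributed as dist

def gather_optim_state_sketch(processed_state):
    # Hypothetical layout: 8 ranks split into two intra-node shard groups of 4.
    # Every rank must call new_group for every group, even ones it is not in.
    group_a = dist.new_group(list(range(0, 4)))
    group_b = dist.new_group(list(range(4, 8)))
    shard_pg = group_a if dist.get_rank() < 4 else group_b

    # Size the output list to the shard group, mirroring how
    # _all_gather_optim_state sizes object_list with fsdp_state.world_size.
    object_list = [None for _ in range(dist.get_world_size(shard_pg))]

    # Passing group= keeps the collective within the shard group; omitting it
    # would run the gather over the default (global) group instead, which is
    # the mismatch this commit fixes for HSDP + use_orig_params.
    dist.all_gather_object(object_list, processed_state, group=shard_pg)
    return object_list
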
59 changes: 58 additions & 1 deletion test/distributed/fsdp/test_fsdp_hybrid_shard.py
@@ -12,7 +12,11 @@
 import torch.nn as nn

 from torch.distributed.distributed_c10d import _rank_not_in_group
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+    ShardingStrategy,
+    StateDictType,
+)
 from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
@@ -76,6 +80,9 @@ def __init__(self):
         self.lin2 = nn.Linear(10, 10)
         self.lin3 = nn.Linear(10, 10)

+    def forward(self, x):
+        return self.lin3(self.lin2(self.lin1(x)))
+

 class ShardingStrategyMode(Enum):
     ALL_HYBRID_SHARD = auto()
@@ -144,6 +151,56 @@ def test_hybrid_shard_pg_mismatch_raises(self):
         ):
             model(inp)

+    @skip_if_lt_x_gpu(4)
+    def test_hsdp_save_load_state_dict(self):
+        model = MyModel().cuda()
+        num_node_devices = torch.cuda.device_count()
+        shard_rank_lists = list(range(0, num_node_devices // 2)), list(
+            range(num_node_devices // 2, num_node_devices)
+        )
+        shard_groups = (
+            dist.new_group(shard_rank_lists[0]),
+            dist.new_group(shard_rank_lists[1]),
+        )
+        my_shard_group = (
+            shard_groups[0] if self.rank in shard_rank_lists[0] else shard_groups[1]
+        )
+        my_replicate_group = None
+        my_rank = self.rank
+        # Create groups like (0, 4), (1, 5), (2, 6) etc and assign appropriately
+        shard_factor = len(shard_rank_lists[0])
+        for i in range(num_node_devices // 2):
+            replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
+            replicate_group = dist.new_group(replicate_group_ranks)
+            if my_rank in replicate_group_ranks:
+                my_replicate_group = replicate_group
+
+        fsdp_ctor = partial(
+            FSDP,
+            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
+            use_orig_params=True,
+            process_group=(my_shard_group, my_replicate_group),
+        )
+        model = fsdp_ctor(model)
+        optim = torch.optim.AdamW(model.parameters())
+        # Initialize optimizer states
+        model(torch.randn(2, 10)).sum().backward()
+        optim.step()
+        shard_g = model.process_group
+        replicate_g = model._inter_node_state.process_group
+        assert shard_g == my_shard_group
+        assert replicate_g == my_replicate_group
+        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+            msd = model.state_dict()
+            osd = FSDP.optim_state_dict(model, optim)
+
+        load_model = fsdp_ctor(MyModel().cuda())
+        load_optim = torch.optim.AdamW(load_model.parameters())
+        with FSDP.state_dict_type(load_model, StateDictType.SHARDED_STATE_DICT):
+            load_model.load_state_dict(msd)
+            FSDP.optim_state_dict_to_load(load_model, load_optim, osd)
+            load_optim.load_state_dict(osd)
+
     @skip_if_lt_x_gpu(2)
     def test_invalid_pg_specification_raises(self):
         pol = ModuleWrapPolicy({nn.Linear})
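In short, the new test_hsdp_save_load_state_dict builds the two intra-node shard groups and the matching inter-node replicate groups by hand, wraps MyModel with ShardingStrategy.HYBRID_SHARD and use_orig_params=True while passing process_group=(shard, replicate) explicitly, and round-trips a SHARDED_STATE_DICT model state dict plus FSDP.optim_state_dict through a freshly wrapped copy, exercising the all-gather below on a non-default group.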
2 changes: 1 addition & 1 deletion torch/distributed/fsdp/_optim_utils.py
@@ -1497,7 +1497,7 @@ def _all_gather_optim_state(
     object_list: List[StateInfo] = [
         processed_state for _ in range(fsdp_state.world_size)
     ]
-    dist.all_gather_object(object_list, processed_state)
+    dist.all_gather_object(object_list, processed_state, group=fsdp_state.process_group)

     # Convert the gathered, pre-processed state of each rank to the original one.
     gathered_state: Dict[str, Any] = {}
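The fix itself is the one-line change above: _all_gather_optim_state builds object_list with fsdp_state.world_size entries, which under HSDP is the size of the intra-node shard group, yet the previous call let dist.all_gather_object fall back to the default (global) process group. Passing group=fsdp_state.process_group keeps the gather on the same group the parameters are sharded over.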
