add hsdp sync_module_state unit test
lxg2015 committed Sep 1, 2023
1 parent d96446b commit e928cc4
Showing 1 changed file with 42 additions and 0 deletions.
test/distributed/fsdp/test_fsdp_hybrid_shard.py

@@ -205,6 +205,48 @@ def test_hsdp_save_load_state_dict(self):
        FSDP.optim_state_dict_to_load(load_model, load_optim, osd)
        load_optim.load_state_dict(osd)

    @skip_if_lt_x_gpu(4)
    def test_hsdp_sync_module_state(self):
        model = MyModel().cuda()
        num_node_devices = torch.cuda.device_count()
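        # Partition the visible devices into two equal shard groups, e.g.
        # ranks [0, n/2) and [n/2, n) on an n-GPU node.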
        shard_rank_lists = list(range(0, num_node_devices // 2)), list(
            range(num_node_devices // 2, num_node_devices)
        )
        shard_groups = (
            dist.new_group(shard_rank_lists[0]),
            dist.new_group(shard_rank_lists[1]),
        )
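        # Each rank keeps a handle to the shard group it belongs to; HSDP
        # shards parameters within this group.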
        my_shard_group = (
            shard_groups[0] if self.rank in shard_rank_lists[0] else shard_groups[1]
        )
        my_replicate_group = None
        my_rank = self.rank
        # Create groups like (0, 4), (1, 5), (2, 6) etc. and assign appropriately
        shard_factor = len(shard_rank_lists[0])
        for i in range(num_node_devices // 2):
            replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
            replicate_group = dist.new_group(replicate_group_ranks)
            if my_rank in replicate_group_ranks:
                my_replicate_group = replicate_group

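        # Initialize every weight to this process's rank so that a successful
        # broadcast from rank 0 (whose weights are all zero) is detectable.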
        nn.init.constant_(model.lin1.weight, self.rank)
        nn.init.constant_(model.lin2.weight, self.rank)
        nn.init.constant_(model.lin3.weight, self.rank)

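        # HYBRID_SHARD accepts process_group as a (shard group, replicate group)
        # tuple; sync_module_states=True broadcasts rank 0's parameters and
        # buffers to all other ranks at wrapping time.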
        fsdp_ctor = partial(
            FSDP,
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            use_orig_params=True,
            sync_module_states=True,
            process_group=(my_shard_group, my_replicate_group),
        )
        model = fsdp_ctor(model)

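        # All ranks should now see rank 0's initialization, i.e. all-zero weights.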
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
            assert (model.lin1.weight == 0).all()
            assert (model.lin2.weight == 0).all()
            assert (model.lin3.weight == 0).all()

    @skip_if_lt_x_gpu(2)
    def test_invalid_pg_specification_raises(self):
        pol = ModuleWrapPolicy({nn.Linear})