Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HSDP] add sync_module_state unit test #108392

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions test/distributed/fsdp/test_fsdp_hybrid_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,48 @@ def test_hsdp_save_load_state_dict(self):
FSDP.optim_state_dict_to_load(load_model, load_optim, osd)
load_optim.load_state_dict(osd)

@skip_if_lt_x_gpu(4)
def test_hsdp_sync_module_state(self):
model = MyModel().cuda()
num_node_devices = torch.cuda.device_count()
shard_rank_lists = list(range(0, num_node_devices // 2)), list(
range(num_node_devices // 2, num_node_devices)
)
shard_groups = (
dist.new_group(shard_rank_lists[0]),
dist.new_group(shard_rank_lists[1]),
)
my_shard_group = (
shard_groups[0] if self.rank in shard_rank_lists[0] else shard_groups[1]
)
my_replicate_group = None
my_rank = self.rank
# Create groups like (0, 4), (1, 5), (2, 6) etc and assign appropriately
shard_factor = len(shard_rank_lists[0])
for i in range(num_node_devices // 2):
replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
replicate_group = dist.new_group(replicate_group_ranks)
if my_rank in replicate_group_ranks:
my_replicate_group = replicate_group

nn.init.constant_(model.lin1.weight, self.rank)
nn.init.constant_(model.lin2.weight, self.rank)
nn.init.constant_(model.lin3.weight, self.rank)

fsdp_ctor = partial(
FSDP,
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
use_orig_params=True,
sync_module_states=True,
process_group=(my_shard_group, my_replicate_group),
)
model = fsdp_ctor(model)

with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
assert (model.lin1.weight == 0).all()
assert (model.lin2.weight == 0).all()
assert (model.lin3.weight == 0).all()

@skip_if_lt_x_gpu(2)
def test_invalid_pg_specification_raises(self):
pol = ModuleWrapPolicy({nn.Linear})
Expand Down
Loading