Commit 90a560f

Do not flatten states when use_orig_params is True and sharding is NO_SHARD
[ghstack-poisoned]
zhaojuanmao committed Apr 27, 2023
1 parent 2eab5ab commit 90a560f
Showing 2 changed files with 32 additions and 1 deletion.
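
For context, the configuration this commit targets is FSDP with sharding_strategy=ShardingStrategy.NO_SHARD combined with use_orig_params=True, followed by a save/load round trip of the optimizer state through FSDP.optim_state_dict and FSDP.optim_state_dict_to_load. The sketch below mirrors the new test added in this commit; it is illustrative only and assumes a distributed process group and at least one CUDA device are already initialized (e.g. via torchrun). TinyModel is a placeholder module introduced here for illustration.

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy

# Sketch only: assumes torch.distributed is already initialized (e.g. via
# torchrun) and that a CUDA device is available.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 8)

    def forward(self, x):
        return self.net(x).sum()

model = FSDP(
    TinyModel().cuda(),
    sharding_strategy=ShardingStrategy.NO_SHARD,  # every rank holds the full params
    use_orig_params=True,                         # expose the original, unflattened params
)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

# Run one step so the optimizer has state (e.g. Adam's exp_avg / exp_avg_sq).
model(torch.randn(4, 8, device="cuda")).backward()
optim.step()

# With this change, positive-dim state tensors keep their original shapes here
# instead of being flattened and sliced as they are for the sharded strategies.
osd = FSDP.optim_state_dict(model, optim)
osd_to_load = FSDP.optim_state_dict_to_load(model, optim, osd)
optim.load_state_dict(osd_to_load)
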
27 changes: 27 additions & 0 deletions test/distributed/fsdp/test_fsdp_optim_state.py
@@ -15,6 +15,7 @@
 )
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._shard_utils import _gather_state_dict
+from torch.distributed.fsdp.api import ShardingStrategy
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullOptimStateDictConfig,
     FullStateDictConfig,
@@ -1882,6 +1883,32 @@ def step():
 
         self.run_subtests({"use_orig_params": [False, True]}, _run_test)
 
+    @skip_if_lt_x_gpu(2)
+    def test_use_orig_param_with_no_shard(self):
+        model = FSDP(
+            TestDummyModel().cuda(),
+            sharding_strategy=ShardingStrategy.NO_SHARD,
+            use_orig_params=True,
+        )
+        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
+
+        def step():
+            loss = model(model.get_input())
+            loss.backward(loss)
+            optim.step()
+
+        step()
+
+        original_osd = deepcopy(optim.state_dict())
+
+        osd = FSDP.optim_state_dict(model, optim)
+        osd_to_load = FSDP.optim_state_dict_to_load(model, optim, osd)
+        optim.load_state_dict(osd_to_load)
+
+        new_osd = optim.state_dict()
+
+        self.assertEqual(original_osd, new_osd)
+
 
 instantiate_parametrized_tests(TestFSDPOptimState)
 
6 changes: 5 additions & 1 deletion torch/distributed/fsdp/_optim_utils.py
@@ -366,7 +366,11 @@ def _shard_orig_param_state(
     intra_param_start_idx = shard_param_info.intra_param_start_idx
     intra_param_end_idx = shard_param_info.intra_param_end_idx
     for state_name, value in optim_state.items():
-        if torch.is_tensor(value) and value.dim() > 0:
+        if (
+            torch.is_tensor(value)
+            and value.dim() > 0
+            and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
+        ):
             value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1]
         new_optim_state[state_name] = value
     return new_optim_state
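
To see why the NO_SHARD guard matters, consider what the flatten-and-slice path does to a positive-dim state tensor such as Adam's exp_avg. The snippet below is a standalone illustration with made-up index values and tensor shapes, not the FSDP code path itself: under a sharded strategy each rank keeps only its flat slice of the state, but under NO_SHARD every rank already owns the full parameter, so the state should keep its original shape for the optimizer to load it back.

import torch

# Standalone illustration with made-up numbers, not the FSDP code path.
exp_avg = torch.randn(4, 4)  # e.g. Adam's exp_avg for a 4x4 weight

# Sharded strategies: this rank owns elements [0, 8) of the flattened parameter,
# so the state is flattened and sliced to match the local shard.
intra_param_start_idx, intra_param_end_idx = 0, 7
local_shard = exp_avg.flatten()[intra_param_start_idx : intra_param_end_idx + 1]
print(local_shard.shape)  # torch.Size([8])

# NO_SHARD: the guard added above skips the flatten/slice, so the state keeps
# its original shape and matches the unflattened parameter on every rank.
print(exp_avg.shape)  # torch.Size([4, 4])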
