[FSDP][optim_state_dict] Make the new optimizer allgather fusion work with fine-tuning models

Pull Request resolved: #110540

With use_orig_params=True, it is possible that some parameters belonging to the same FlatParameter are in the optimizer while other parameters are frozen. This PR makes the allgather fusion logic support this case.
ghstack-source-id: 202944435
@exported-using-ghexport

Differential Revision: [D49922028](https://our.internmc.facebook.com/intern/diff/D49922028/)
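For context, the fine-tuning scenario this change targets looks roughly like the sketch below. It is not part of the commit; the toy model, the freezing pattern, and the already-initialized process group are assumptions for illustration. Only FSDP(..., use_orig_params=True) and FSDP.optim_state_dict come from the code paths touched here.

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed has already been initialized with a NCCL backend.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8)).cuda()
# Fine-tuning style freezing: these parameters may live in the same
# FlatParameter as trainable ones after FSDP flattens the module.
for p in model[0].parameters():
    p.requires_grad = False

fsdp_model = FSDP(model, use_orig_params=True)
# Frozen parameters never receive gradients, so they end up with no
# optimizer state even though they share a FlatParameter with trained ones.
optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)

loss = fsdp_model(torch.randn(4, 8, device="cuda")).sum()
loss.backward()
optim.step()

# The fused allgather path this commit fixes for partially frozen FlatParameters.
osd = FSDP.optim_state_dict(fsdp_model, optim)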
fegin committed Oct 4, 2023
1 parent 4069d1d commit c8836ee
Showing 2 changed files with 117 additions and 27 deletions.
40 changes: 38 additions & 2 deletions test/distributed/fsdp/test_fsdp_optim_state.py
@@ -288,10 +288,12 @@ def param_group1(self) -> List[torch.nn.Parameter]:
# Simple and boring model to test interface and some corner cases that do not
# require complicated wrapping strategy.
class TestDummyModel(torch.nn.Module):
def __init__(self):
def __init__(self, no_grad: bool = False):
super().__init__()
torch.manual_seed(0)
self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
self.net1[0].weight.requires_grad = not no_grad
self.net1[0].bias.requires_grad = not no_grad
self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
self.net3 = nn.Linear(32, 64)
self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))
@@ -1525,7 +1527,9 @@ def _run_on_all_optim_state_apis(

@skip_if_lt_x_gpu(2)
@parametrize("state_dict_type", STATE_DICT_TYPES)
def test_save_load_without_0th_param_state(self, state_dict_type: StateDictType):
def test_save_load_without_0th_param_state(
self, state_dict_type: StateDictType
):
"""
Tests saving and loading an optim state dict for Adam optimizer (i.e.
any optimizer with a "step" key in its state) when the first parameter
@@ -1922,6 +1926,38 @@ def step():

self.run_subtests({"use_orig_params": [False, True]}, _run_test)

@skip_if_lt_x_gpu(2)
def test_no_grad(self):
model = TestDummyModel(no_grad=True).cuda()
fsdp_model = FSDP(deepcopy(model), use_orig_params=True)
fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)

for i in range(5):
if i % 2 == 1:
fsdp_model.net1[0].weight.requires_grad = True
fsdp_model.net1[0].bias.requires_grad = True
else:
fsdp_model.net1[0].weight.requires_grad = False
fsdp_model.net1[0].bias.requires_grad = False
batch = fsdp_model.get_input()
loss = fsdp_model(batch).sum()
loss.backward()
fsdp_optim.step()
orig_state_dict = deepcopy(fsdp_optim.state_dict())
optim_state_dict = FSDP.optim_state_dict(fsdp_model, fsdp_optim)
FSDP.optim_state_dict_to_load(
fsdp_model,
fsdp_optim,
FSDP.optim_state_dict(fsdp_model, fsdp_optim),
load_directly=True,
)

self._check_same_state(
fsdp_optim.state_dict(),
orig_state_dict,
check_same_param_keys=True,
)


instantiate_parametrized_tests(TestFSDPOptimState)

104 changes: 79 additions & 25 deletions torch/distributed/fsdp/_optim_utils.py
@@ -58,6 +58,7 @@ class FSDPParamInfo:
state: _FSDPState
handle: FlatParamHandle
param_indices: Dict[str, int]
param_requires_grad: List[bool]


def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
@@ -1346,8 +1347,8 @@ def _convert_all_state_info(
{n for state in state_info for n in state.tensors.keys()}
)
empty_ranks: Set[int] = set()
# First check all the non-scalar states and get the exist status of the
# the states on each rank.
# First check all the non-scalar states and get the information of
# states on each rank.
for state_name in all_tensor_states:
numels = []
dtype: Optional[torch.dtype] = None
@@ -1367,13 +1368,15 @@
assert not empty_ranks or empty_ranks == _empty_ranks
empty_ranks = _empty_ranks
if state_name not in state_buffers:
state_buffers[state_name] = [None for _ in input_states]
state_buffers[state_name] = [
None for _ in fsdp_param_info.param_indices
]
local_state = input_states[fqn].get(state_name, None)
state_buffers[state_name][fsdp_param_info.param_indices[fqn]] = local_state

# Restoring the scalar and non-tensor states. If the corresponding non-scalar
# states do not exist on the rank, we also skip the scalar and non-tensor
# states on that rank
# Restoring the scalar and non-tensor states. If the corresponding
# non-scalar states do not exist on the rank, we also skip the scalar
# and non-tensor states on that rank.
for rank, object_state in enumerate(state_info):
if rank in empty_ranks:
continue
@@ -1414,13 +1417,11 @@ def _unflatten_orig_param_states(
return
flat_param = fsdp_param_info.handle.flat_param
fsdp_state = fsdp_param_info.state
numel = 0
for fqn, gathered_state in output_states.items():
value = gathered_state[state_name]

param_idx = fsdp_param_info.param_indices[fqn]
value = value.reshape(flat_param._shapes[param_idx])
numel += value.numel()
if shard_state:
osd_config = fsdp_state._optim_state_dict_config
if getattr(osd_config, "_use_dtensor", False):
@@ -1466,6 +1467,11 @@ def _allgather_orig_param_states(
fsdp_param_info, gathered_state_info, input_states, output_states
)

has_state_params: List[bool] = [
True if fqn in output_states else False
for fqn, idx in fsdp_param_info.param_indices.items()
]

# Loop through the ``state_buffers`` and construct the flattened, concatenated,
# sharded states. The size of the constructed state will be the same size as
# flat_param (also sharded).
@@ -1486,13 +1492,20 @@
begin = fsdp_state.rank * flat_param._sharded_size.numel()
# End is inclusive.
end = begin + flat_param._sharded_size.numel() - 1
# buffer_idx corresponds to the parameter index in the FlatParameter.
mem_offset, buffer_idx = 0, 0
# param_idx corresponds to the parameter index in the FlatParameter.
mem_offset, param_idx = 0, 0
for numel, is_padding in zip(
flat_param._numels_with_padding, flat_param._is_padding_mask
):
if is_padding:
# This memory range is a padding.
frozen_and_no_state = not is_padding and (
not fsdp_param_info.param_requires_grad[param_idx]
and not has_state_params[param_idx]
)

if is_padding or frozen_and_no_state:
# This memory range is padding, or the parameter is frozen and has
# no optimizer state. For the latter case, we treat it as padding
# and add empty values to the local_buffers.

padding_begin, padding_end = mem_offset, mem_offset + numel - 1
if padding_begin <= begin <= padding_end:
@@ -1519,15 +1532,24 @@
padding_len = 0
if padding_len:
local_buffers.append(empty_func(padding_len))
else:
# This memory range is a parameter in FlatParameter. As for the
# optimizer state_dict, this memory range is a state.

# We need to check if this rank owns the buffer. If this is None,
# the rank does not have any part of the corresponding parameter.
if buffers[buffer_idx] is not None:
local_buffers.append(cast(torch.Tensor, buffers[buffer_idx]))
buffer_idx += 1

if not is_padding:
# This memory range is a parameter in FlatParameter, so there
# should be a corresponding state in the optimizer unless the
# parameter is frozen, in which case we treat it as padding above.

# We need to check if this rank owns the buffer. If the buffer is None:
# 1.) the rank does not own any part of the original parameter,
# so there is no corresponding optimizer state on this rank
# either.
# 2.) the parameter is frozen AND has no optimizer state. A frozen
# parameter can still have optimizer state if it was not frozen
# in earlier steps.
if buffers[param_idx] is not None:
local_buffers.append(cast(torch.Tensor, buffers[param_idx]))
param_idx += 1

mem_offset += numel

shard_numel_padded = flat_param._sharded_size.numel() - (
@@ -1569,7 +1591,8 @@ def _allgather_orig_param_states(
"logic."
)
for fqn, idx in fsdp_param_info.param_indices.items():
output_states[fqn][state_name] = orig_states[idx]
if fsdp_param_info.param_requires_grad[idx] or fqn in output_states:
output_states[fqn][state_name] = orig_states[idx]

_unflatten_orig_param_states(
fsdp_param_info,
@@ -1609,7 +1632,19 @@ def _gather_all_orig_param_state(
fsdp_param_info, gathered_state_info, input_states, shard_state, to_save
)
if to_save:
assert set(output_states.keys()) == set(fsdp_param_info.param_indices.keys())
for key, idx in fsdp_param_info.param_indices.items():
if key in output_states:
continue
if not fsdp_param_info.param_requires_grad[idx]:
continue

raise RuntimeError(
f"{key} is not in the output state. "
"The FSDPParamInfo has the param keys "
f"{sorted(fsdp_param_info.param_indices.keys())} while "
"the output_states has the param keys "
f"{sorted(output_states.keys())}."
)
return output_states
else:
return {}
@@ -1666,7 +1701,22 @@ def _convert_state_with_orig_params(
for _all_states in all_states.values():
fqn = next(iter(_all_states.keys()))
fsdp_param_info = fqn_to_fsdp_param_info[fqn]
assert set(fsdp_param_info.param_indices.keys()) == set(_all_states.keys())
assert len(fsdp_param_info.param_requires_grad) > 0, (
"With use_orig_params, FSDPParamInfo should have requires_grad "
"information. However, the length is zero."
)
for key, idx in fsdp_param_info.param_indices.items():
if key in _all_states:
continue
if not fsdp_param_info.param_requires_grad[idx]:
continue
raise RuntimeError(
f"{key} is not in the optimizer state. "
"The FSDPParamInfo has the param keys "
f"{sorted(fsdp_param_info.param_indices.keys())} while "
"the optimizer has the param keys "
f"{sorted(_all_states.keys())}."
)
fsdp_osd_state.update(
_gather_all_orig_param_state(
fsdp_param_info,
@@ -1904,7 +1954,7 @@ def module_fn(module, prefix, tree_level, fqn_to_param_info):
if not handle:
return
flat_param = handle.flat_param
fsdp_param_info = FSDPParamInfo(fsdp_state, handle, {})
fsdp_param_info = FSDPParamInfo(fsdp_state, handle, {}, [])
# NOTE: `idx` indexes into the data structures *without* padding
# elements
for idx, local_fqn in enumerate(flat_param._fqns):
@@ -1913,6 +1963,10 @@
assert fqn_to_param_info[fqn].handle.flat_param is flat_param, fqn
fqn_to_param_info[fqn] = fsdp_param_info
fsdp_param_info.param_indices[fqn] = idx
if flat_param._params is not None:
fsdp_param_info.param_requires_grad.append(
flat_param._params[idx].requires_grad
)

def return_fn(fqn_to_param_info):
return fqn_to_param_info

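As a rough illustration of the buffer-construction logic changed in _allgather_orig_param_states above, the following self-contained sketch shows how a frozen parameter with no optimizer state is treated like explicit padding when per-parameter state is concatenated into a flat buffer. The helper name build_local_state_buffer is hypothetical, zeros stand in for the empty values, and the per-rank sharding (begin/end offsets) is omitted.

import torch

def build_local_state_buffer(numels, is_padding_mask, requires_grad, state_buffers):
    # Concatenate per-parameter optimizer state into one flat buffer whose
    # layout matches the FlatParameter. Padding regions, and parameters that
    # are frozen and have no optimizer state, are filled with zeros.
    pieces = []
    param_idx = 0
    for numel, is_padding in zip(numels, is_padding_mask):
        frozen_and_no_state = not is_padding and (
            not requires_grad[param_idx] and state_buffers[param_idx] is None
        )
        if is_padding or frozen_and_no_state:
            pieces.append(torch.zeros(numel))
        else:
            pieces.append(state_buffers[param_idx].reshape(-1))
        if not is_padding:
            param_idx += 1
    return torch.cat(pieces)

# Two original parameters share one FlatParameter; the second is frozen and
# has no state, so its region is filled just like the explicit padding region.
numels = [4, 2, 3]                      # param0, padding, param1
is_padding_mask = [False, True, False]
requires_grad = [True, False]           # param1 is frozen
state_buffers = [torch.ones(4), None]   # only param0 has optimizer state
print(build_local_state_buffer(numels, is_padding_mask, requires_grad, state_buffers))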