[FSDP][optim_state_dict] Make the new optimizer allgather fusion work with fine-tuning models #110540

Closed
wants to merge 7 commits
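
For context, this change targets fine-tuning workloads where some parameters are frozen and therefore may have no optimizer state when the state dict is gathered. Below is a minimal sketch of that scenario; it is not code from this PR, and the model, frozen layer, and launch assumptions (torchrun with NCCL) are illustrative only.

    # Illustrative only, not part of this PR: a fine-tuning style setup where
    # one sub-module is frozen, so its parameters receive no gradients and
    # therefore have no optimizer state. Assumes a torchrun launch so the
    # default process group can be initialized.
    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8)).cuda()
    for p in model[0].parameters():
        p.requires_grad = False  # frozen layer, as in fine-tuning

    fsdp_model = FSDP(model, use_orig_params=True)
    # The optimizer still sees every parameter; frozen ones simply never get
    # gradients, so the optimizer creates no state for them.
    optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)

    loss = fsdp_model(torch.randn(4, 8, device="cuda")).sum()
    loss.backward()
    optim.step()

    # Gather a full optimizer state dict and load it back through FSDP.
    osd = FSDP.optim_state_dict(fsdp_model, optim)
    flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, osd)
    optim.load_state_dict(flattened_osd)

    dist.destroy_process_group()
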
36 changes: 35 additions & 1 deletion test/distributed/fsdp/test_fsdp_optim_state.py
@@ -288,10 +288,12 @@ def param_group1(self) -> List[torch.nn.Parameter]:
# Simple and boring model to test interface and some corner cases that do not
# require complicated wrapping strategy.
class TestDummyModel(torch.nn.Module):
def __init__(self):
def __init__(self, no_grad: bool = False):
super().__init__()
torch.manual_seed(0)
self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
self.net1[0].weight.requires_grad = not no_grad
self.net1[0].bias.requires_grad = not no_grad
self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
self.net3 = nn.Linear(32, 64)
self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))
@@ -1922,6 +1924,38 @@ def step():

self.run_subtests({"use_orig_params": [False, True]}, _run_test)

@skip_if_lt_x_gpu(2)
def test_no_grad(self):
model = TestDummyModel(no_grad=True).cuda()
fsdp_model = FSDP(deepcopy(model), use_orig_params=True)
fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)

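# Alternate freezing and unfreezing net1 to mimic fine-tuning: the frozen
# parameters only have optimizer state for the steps in which they received
# gradients.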
for i in range(5):
if i % 2 == 1:
fsdp_model.net1[0].weight.requires_grad = True
fsdp_model.net1[0].bias.requires_grad = True
else:
fsdp_model.net1[0].weight.requires_grad = False
fsdp_model.net1[0].bias.requires_grad = False
batch = fsdp_model.get_input()
loss = fsdp_model(batch).sum()
loss.backward()
fsdp_optim.step()
orig_state_dict = deepcopy(fsdp_optim.state_dict())
optim_state_dict = FSDP.optim_state_dict(fsdp_model, fsdp_optim)
FSDP.optim_state_dict_to_load(
fsdp_model,
fsdp_optim,
FSDP.optim_state_dict(fsdp_model, fsdp_optim),
load_directly=True,
)

self._check_same_state(
fsdp_optim.state_dict(),
orig_state_dict,
check_same_param_keys=True,
)


instantiate_parametrized_tests(TestFSDPOptimState)

104 changes: 79 additions & 25 deletions torch/distributed/fsdp/_optim_utils.py
@@ -58,6 +58,7 @@ class FSDPParamInfo:
state: _FSDPState
handle: FlatParamHandle
param_indices: Dict[str, int]
param_requires_grad: List[bool]


def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
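
The new param_requires_grad field is positionally aligned with param_indices: for each FQN, param_requires_grad[param_indices[fqn]] records whether that original parameter is trainable. A hypothetical illustration of the invariant (the FQNs and values are made up; real instances are populated in module_fn further down in this file):

    # Made-up contents, only to show how the two FSDPParamInfo fields line up.
    param_indices = {"net1.0.weight": 0, "net1.0.bias": 1, "net2.0.weight": 2}
    param_requires_grad = [False, False, True]  # net1 frozen, net2 trainable

    for fqn, idx in param_indices.items():
        print(f"{fqn}: requires_grad={param_requires_grad[idx]}")
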
@@ -1346,8 +1347,8 @@ def _convert_all_state_info(
{n for state in state_info for n in state.tensors.keys()}
)
empty_ranks: Set[int] = set()
# First check all the non-scalar states and get the exist status of the
# the states on each rank.
# First check all the non-scalar states and collect the information about
# the states on each rank.
for state_name in all_tensor_states:
numels = []
dtype: Optional[torch.dtype] = None
@@ -1367,13 +1368,15 @@
assert not empty_ranks or empty_ranks == _empty_ranks
empty_ranks = _empty_ranks
if state_name not in state_buffers:
state_buffers[state_name] = [None for _ in input_states]
state_buffers[state_name] = [
None for _ in fsdp_param_info.param_indices
]
local_state = input_states[fqn].get(state_name, None)
state_buffers[state_name][fsdp_param_info.param_indices[fqn]] = local_state

# Restoring the scalar and non-tensor states. If the corresponding non-scalar
# states do not exist on the rank, we also skip the scalar and non-tensor
# states on that rank
# Restoring the scalar and non-tensor states. If the corresponding
# non-scalar states do not exist on the rank, we also skip the scalar
# and non-tensor states on that rank.
for rank, object_state in enumerate(state_info):
if rank in empty_ranks:
continue
@@ -1414,13 +1417,11 @@ def _unflatten_orig_param_states(
return
flat_param = fsdp_param_info.handle.flat_param
fsdp_state = fsdp_param_info.state
numel = 0
for fqn, gathered_state in output_states.items():
value = gathered_state[state_name]

param_idx = fsdp_param_info.param_indices[fqn]
value = value.reshape(flat_param._shapes[param_idx])
numel += value.numel()
if shard_state:
osd_config = fsdp_state._optim_state_dict_config
if getattr(osd_config, "_use_dtensor", False):
@@ -1466,6 +1467,11 @@ def _allgather_orig_param_states(
fsdp_param_info, gathered_state_info, input_states, output_states
)

has_state_params: List[bool] = [
fqn in output_states for fqn in fsdp_param_info.param_indices
]

# Loop through the ``state_buffers`` and construct the flattened, concatenated,
# sharded states. The size of the constructed state will be the same size as
# flat_param (also sharded).
@@ -1486,13 +1492,20 @@
begin = fsdp_state.rank * flat_param._sharded_size.numel()
# End is inclusive.
end = begin + flat_param._sharded_size.numel() - 1
# buffer_idx corresponds to the parameter index in the FlatParameter.
mem_offset, buffer_idx = 0, 0
# param_idx corresponds to the parameter index in the FlatParameter.
mem_offset, param_idx = 0, 0
for numel, is_padding in zip(
flat_param._numels_with_padding, flat_param._is_padding_mask
):
if is_padding:
# This memory range is a padding.
frozen_and_no_state = not is_padding and (
not fsdp_param_info.param_requires_grad[param_idx]
and not has_state_params[param_idx]
)

if is_padding or frozen_and_no_state:
# This memory range is padding, or the parameter is frozen and has
# no optimizer state. In the latter case we treat the range as
# padding and append empty values to local_buffers.

padding_begin, padding_end = mem_offset, mem_offset + numel - 1
if padding_begin <= begin <= padding_end:
@@ -1519,15 +1532,24 @@
padding_len = 0
if padding_len:
local_buffers.append(empty_func(padding_len))
else:
# This memory range is a parameter in FlatParameter. As for the
# optimizer state_dict, this memory range is a state.

# We need to check if this rank owns the buffer. If this is None,
# the rank does not have any part of the corresponding parameter.
if buffers[buffer_idx] is not None:
local_buffers.append(cast(torch.Tensor, buffers[buffer_idx]))
buffer_idx += 1

if not is_padding:
# This memory range is a parameter in FlatParameter, so there
# should be a corresponding state in the optimizer unless the
# parameter is frozen, which we treat as padding above.

# We need to check if this rank owns the buffer. The buffer is None
# when either:
# 1.) the rank does not own any part of the original parameter, so
# there is no corresponding optimizer state on this rank, or
# 2.) the parameter is frozen AND has no optimizer state; a frozen
# parameter can still have optimizer state if it was trainable
# in earlier steps.
if buffers[param_idx] is not None:
local_buffers.append(cast(torch.Tensor, buffers[param_idx]))
param_idx += 1

mem_offset += numel

shard_numel_padded = flat_param._sharded_size.numel() - (
@@ -1569,7 +1591,8 @@ def _allgather_orig_param_states(
"logic."
)
for fqn, idx in fsdp_param_info.param_indices.items():
output_states[fqn][state_name] = orig_states[idx]
if fsdp_param_info.param_requires_grad[idx] or fqn in output_states:
output_states[fqn][state_name] = orig_states[idx]

_unflatten_orig_param_states(
fsdp_param_info,
@@ -1609,7 +1632,19 @@ def _gather_all_orig_param_state(
fsdp_param_info, gathered_state_info, input_states, shard_state, to_save
)
if to_save:
assert set(output_states.keys()) == set(fsdp_param_info.param_indices.keys())
for key, idx in fsdp_param_info.param_indices.items():
if key in output_states:
continue
if not fsdp_param_info.param_requires_grad[idx]:
continue

raise RuntimeError(
f"{key} is not in the output state. "
"The FSDPParamInfo has the param keys "
f"{sorted(fsdp_param_info.param_indices.keys())} while "
"the output_states has the param keys "
f"{sorted(output_states.keys())}."
)
return output_states
else:
return {}
@@ -1666,7 +1701,22 @@ def _convert_state_with_orig_params(
for _all_states in all_states.values():
fqn = next(iter(_all_states.keys()))
fsdp_param_info = fqn_to_fsdp_param_info[fqn]
assert set(fsdp_param_info.param_indices.keys()) == set(_all_states.keys())
assert len(fsdp_param_info.param_requires_grad) > 0, (
"With use_orig_params, FSDPParamInfo should have requires_grad "
"information. However, the length is zero."
)
for key, idx in fsdp_param_info.param_indices.items():
if key in _all_states:
continue
if not fsdp_param_info.param_requires_grad[idx]:
continue
raise RuntimeError(
f"{key} is not in the optimizer state. "
"The FSDPParamInfo has the param keys "
f"{sorted(fsdp_param_info.param_indices.keys())} while "
"the optimizer has the param keys "
f"{sorted(_all_states.keys())}."
)
fsdp_osd_state.update(
_gather_all_orig_param_state(
fsdp_param_info,
@@ -1904,7 +1954,7 @@ def module_fn(module, prefix, tree_level, fqn_to_param_info):
if not handle:
return
flat_param = handle.flat_param
fsdp_param_info = FSDPParamInfo(fsdp_state, handle, {})
fsdp_param_info = FSDPParamInfo(fsdp_state, handle, {}, [])
# NOTE: `idx` indexes into the data structures *without* padding
# elements
for idx, local_fqn in enumerate(flat_param._fqns):
@@ -1913,6 +1963,10 @@
assert fqn_to_param_info[fqn].handle.flat_param is flat_param, fqn
fqn_to_param_info[fqn] = fsdp_param_info
fsdp_param_info.param_indices[fqn] = idx
if flat_param._params is not None:
fsdp_param_info.param_requires_grad.append(
flat_param._params[idx].requires_grad
)

def return_fn(fqn_to_param_info):
return fqn_to_param_info
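
To summarize the allgather path touched above, here is a toy, self-contained walk-through of how a region of the flat parameter is classified: real padding and frozen parameters without optimizer state are both filled with empty values, while everything else is copied from the gathered state buffers. All sizes and flags below are made up for illustration; this is not library code.

    # Toy illustration of the classification logic used in
    # _allgather_orig_param_states; all values below are made up.
    numels_with_padding = [16, 8, 4, 16]
    is_padding_mask = [False, False, True, False]
    param_requires_grad = [False, True, True]  # first original param is frozen
    has_state_params = [False, True, True]     # ...and it has no optimizer state

    param_idx = 0
    for numel, is_padding in zip(numels_with_padding, is_padding_mask):
        frozen_and_no_state = not is_padding and (
            not param_requires_grad[param_idx] and not has_state_params[param_idx]
        )
        if is_padding or frozen_and_no_state:
            print(f"{numel} elements -> empty values (treated as padding)")
        else:
            print(f"{numel} elements -> copied from the gathered state buffer")
        if not is_padding:
            param_idx += 1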