diff --git a/test/distributed/_composable/test_fully_shard.py b/test/distributed/_composable/test_fully_shard.py
index 4c61552fdefd..afb4d83c5030 100644
--- a/test/distributed/_composable/test_fully_shard.py
+++ b/test/distributed/_composable/test_fully_shard.py
@@ -1,6 +1,5 @@
 # Owner(s): ["oncall: distributed"]
 
-import unittest
 import contextlib
 import copy
 import functools
@@ -743,7 +742,6 @@ def _test_optim_state_save_load(self, model1, optim1, model2, optim2) -> None:
             for key, value in group1.items():
                 self.assertEqual(value, group2[key])
 
-    @unittest.skip("The test currently fails on CI.")
     @skip_if_lt_x_gpu(2)
     def test_optim_state_dict_save_load(self):
         orig_model = CompositeParamModel(device=torch.device("cuda"))
@@ -755,7 +753,6 @@ def test_optim_state_dict_save_load(self):
         self._test_optim_state_save_load(
            orig_model, orig_optim, composable_model, composable_optim
         )
 
-    @unittest.skip("The test currently fails on CI.")
     @skip_if_lt_x_gpu(2)
     def test_optim_state_dict_submodule_fully_shard(self):
         orig_model = CompositeParamModel(device=torch.device("cuda"))
diff --git a/test/distributed/fsdp/test_fsdp_optim_state.py b/test/distributed/fsdp/test_fsdp_optim_state.py
index ff4f84136791..249f1ff35048 100644
--- a/test/distributed/fsdp/test_fsdp_optim_state.py
+++ b/test/distributed/fsdp/test_fsdp_optim_state.py
@@ -2,7 +2,6 @@
 
 import bisect
 import sys
-import unittest
 from enum import auto, Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type
 
@@ -783,7 +782,6 @@ def test_flatten_sharded_optim_state_dict_transformer(self) -> None:
             num_iters=3,
         )
 
-    @unittest.skip("The test currently fails on CI.")
     @skip_if_lt_x_gpu(2)
     def test_use_orig_params(self) -> None:
         """Tests :meth:`optim_state_dict` for an FSDP-root nested model."""
@@ -1442,7 +1440,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             loss.backward()
             optim.step()
 
-    @unittest.skip("The test currently fails on CI.")
     @skip_if_lt_x_gpu(2)
     def test_compatible_with_named_optimizer(self):
         class TestDummyModel(torch.nn.Module):
diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py
index d01cafdc9316..a78d3e180951 100644
--- a/torch/distributed/fsdp/_optim_utils.py
+++ b/torch/distributed/fsdp/_optim_utils.py
@@ -19,6 +19,7 @@
 import torch.distributed as dist
 import torch.distributed.fsdp._traversal_utils as traversal_utils
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed.fsdp._common_utils import (
     _apply_to_modules,
@@ -1436,6 +1437,113 @@ def return_fn(fqn_to_param_info):
     )
 
 
+@dataclass
+class StateInfo:
+    tensors: Dict[str, _PosDimTensorInfo]
+    scalar_tensors: Dict[str, torch.Tensor]
+    non_tensors: Dict[str, Any]
+
+
+@dataclass
+class AllGatherInfo:
+    tensors: List[torch.Tensor]
+    numels: List[int]
+    work: Optional[dist.Work]
+
+
+def _all_gather_optim_state(
+    fsdp_state: _FSDPState, optim_state: Dict[str, Any], param_numel: int
+) -> Dict[str, Any]:
+    """
+    All-gathers the optimizer state from all ranks. This API is slow because
+    it uses ``all_gather_object``; however, the optim state_dict is not in
+    the critical path. We can fuse the communication across different states
+    if the performance becomes a problem.
+    """
+    # All-gather the scalar tensor state, non-tensor states, and tensor metadata.
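+    # The reconstruction proceeds in three phases: (1) exchange per-state
+    # metadata (tensor shapes/dtypes, scalar tensors, and non-tensor values)
+    # with ``all_gather_object``, (2) launch an async, padded ``all_gather``
+    # for every positive-dim tensor state, and (3) check that the scalar and
+    # non-tensor states agree across ranks, then wait on the all-gathers and
+    # trim the padding from each rank's shard.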
+    processed_state = StateInfo({}, {}, {})
+    for state_name, value in sorted_items(optim_state):
+        if torch.is_tensor(value):
+            if value.dim() == 0:
+                processed_state.scalar_tensors[state_name] = value
+            else:
+                processed_state.tensors[state_name] = _PosDimTensorInfo(
+                    value.shape, value.dtype
+                )
+        else:
+            processed_state.non_tensors[state_name] = value
+    object_list: List[StateInfo] = [
+        processed_state for _ in range(fsdp_state.world_size)
+    ]
+    dist.all_gather_object(object_list, processed_state)
+
+    # Convert the gathered, pre-processed state of each rank back to the
+    # original form.
+    gathered_state: Dict[str, Any] = {}
+
+    all_tensor_states = sorted(
+        {n for state in object_list for n in state.tensors.keys()}
+    )
+    for name in all_tensor_states:
+        numels = []
+        dtype = torch.float
+        max_numel = 0
+        for object_state in object_list:
+            numels.append(0)
+            info = object_state.tensors.get(name, None)
+            if info is not None:
+                numels[-1] = info.shape.numel()
+                dtype = info.dtype
+                max_numel = max(max_numel, numels[-1])
+        local_state = (
+            optim_state[name]
+            if name in optim_state
+            else torch.empty(max_numel, dtype=dtype, device=fsdp_state.compute_device)
+        )
+        # Pad the local shard so that every rank contributes exactly
+        # ``max_numel`` elements to the all-gather.
+        if max_numel > local_state.numel():
+            local_state = F.pad(local_state, [0, max_numel - local_state.numel()])
+        tensors = [
+            torch.empty(max_numel, dtype=dtype, device=fsdp_state.compute_device)
+            if rank != fsdp_state.rank
+            else local_state
+            for rank in range(len(object_list))
+        ]
+        work = dist.all_gather(
+            tensors, local_state, group=fsdp_state.process_group, async_op=True
+        )
+        gathered_state[name] = AllGatherInfo(tensors, numels, work)
+
+    for object_state in object_list:
+        for name, non_tensor_value in object_state.non_tensors.items():
+            curr_non_tensor_value = gathered_state.get(name, None)
+            assert (
+                curr_non_tensor_value is None
+                or curr_non_tensor_value == non_tensor_value
+            ), f"Different ranks have different values for {name}."
+            gathered_state[name] = non_tensor_value
+
+        for name, scalar_tensor_value in object_state.scalar_tensors.items():
+            curr_scalar_tensor_value = gathered_state.get(name, None)
+            assert curr_scalar_tensor_value is None or torch.equal(
+                scalar_tensor_value, curr_scalar_tensor_value
+            ), f"Different ranks have different values for {name}."
+            gathered_state[name] = scalar_tensor_value
+
+    # Wait on the async all-gathers, then strip each rank's padding and
+    # concatenate the shards into a single flat tensor.
+    for name, value in list(gathered_state.items()):
+        if not isinstance(value, AllGatherInfo):
+            continue
+        assert value.work is not None
+        value.work.wait()
+        gathered_state[name] = torch.cat(
+            [
+                rank_tensor[:rank_numel]
+                for rank_tensor, rank_numel in zip(value.tensors, value.numels)
+                if rank_numel > 0
+            ]
+        )
+
+    return gathered_state
+
+
 def _gather_orig_param_state(
     fsdp_param_info: FSDPParamInfo,
     fqn: str,
@@ -1458,51 +1566,18 @@ def _gather_orig_param_state(
     ):
         return optim_state
 
-    # Gathering state from all ranks. This step may be slow. However,
-    # `state_dict()` is not in the critical path. We can fuse the communication
-    # if the performance becomes a problem.
-    state_objects = {
-        state_name: value for state_name, value in sorted_items(optim_state)
-    }
-    object_list: List[Dict[str, Any]] = [{} for _ in range(fsdp_state.world_size)]
-    dist.all_gather_object(object_list, state_objects)
-    orig_state: Dict[str, Any] = {}
-    for idx, state in enumerate(object_list):
-        for state_name, value in state.items():
-            curr_value = orig_state.get(state_name, [])
-            if torch.is_tensor(value):
-                if value.dim() > 0:
-                    curr_value.append(value.to(fsdp_state.compute_device))
-                    orig_state[state_name] = curr_value
-                else:  # zero dim tensor, e.g., step.
-                    if torch.is_tensor(curr_value):
-                        assert torch.equal(curr_value, value)
-                    else:
-                        orig_state[state_name] = value
-            else:
-                assert curr_value == [] or curr_value == value
-                orig_state[state_name] = value
+    gathered_state = _all_gather_optim_state(
+        fsdp_state, optim_state, flat_param._numels[param_idx]
+    )
 
     # Unflatten state values.
-    for state_name in orig_state.keys():
-        value = orig_state[state_name]
-        if not isinstance(value, list) or not torch.is_tensor(value[0]):
+    for state_name, value in list(gathered_state.items()):
+        if not torch.is_tensor(value) or value.dim() == 0:
             continue
-        try:
-            value = torch.concat(value)[: flat_param._numels[param_idx]].reshape(
-                flat_param._shapes[param_idx]
-            )
-        except Exception as e:
-            raise Exception(
-                (
-                    flat_param._numels[param_idx],
-                    flat_param._shapes[param_idx],
-                    len(value),
-                    value[0].shape,
-                    state_name,
-                    fqn,
-                )
-            )
+
+        value = value[: flat_param._numels[param_idx]].reshape(
+            flat_param._shapes[param_idx]
+        )
         if shard_state:
             assert fsdp_state.process_group is not None
             value = _ext_chunk_tensor(
                 value,
                 fsdp_state.rank,
                 fsdp_state.world_size,
                 torch.cuda.device_count(),
                 fsdp_state.process_group,
             )
         value = value.cpu()
-        orig_state[state_name] = value
-    return orig_state
+        gathered_state[state_name] = value
+    return gathered_state
 
 
 def _shard_orig_param_state(