[FSDP][state_dict] Add a summary log when finishing state_dict
Pull Request resolved: #103784

Add a summary log when finishing state_dict
ghstack-source-id: 192732157

Differential Revision: [D46807103](https://our.internmc.facebook.com/intern/diff/D46807103/)
fegin committed Jun 21, 2023
1 parent d531c86 commit 362e377
Showing 1 changed file with 31 additions and 0 deletions.
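
The summary is emitted at INFO level through the module logger added in this change (logging.getLogger(__name__) resolves to torch.distributed.fsdp._state_dict_utils). A minimal sketch of surfacing those messages; the logging configuration below is an assumption for illustration, not part of this commit:

    import logging

    # Assumption: no handlers are configured yet, so basicConfig attaches a stream handler.
    logging.basicConfig(level=logging.INFO)
    # Explicitly enable the FSDP state_dict utils logger at INFO.
    logging.getLogger("torch.distributed.fsdp._state_dict_utils").setLevel(logging.INFO)
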
torch/distributed/fsdp/_state_dict_utils.py (31 additions, 0 deletions)
@@ -1,4 +1,5 @@
 import contextlib
+import logging
 import math
 import warnings
 from typing import (
@@ -24,6 +25,7 @@
     Shard,
     ShardedTensor,
 )
+from torch.distributed._tensor import DTensor

 from torch.distributed.distributed_c10d import _get_pg_default_device
 from torch.distributed.fsdp._common_utils import (
@@ -53,6 +55,9 @@
 from ._unshard_param_utils import _unshard_fsdp_state_params, FLAT_PARAM


+logger = logging.getLogger(__name__)
+
+
 def _convert_to_wrapped_module_name(module_name: str) -> str:
     module_name = module_name.replace(f"{FSDP_PREFIX}", "")
     module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "")
@@ -265,6 +270,9 @@ def _common_unshard_post_state_dict_hook(
             )
         for buffer, clean_fqn in zip(buffers, buffer_clean_fqns):
             fqn = f"{prefix}{clean_fqn}"
+            logger.info(
+                "FSDP is casting the dtype of %s to %s", fqn, buffer.dtype
+            )
             state_dict[fqn] = buffer.clone()
     return state_dict

@@ -676,6 +684,29 @@ def _post_state_dict_hook(
     processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
         module, fsdp_state, state_dict, prefix
     )
+
+    if fsdp_state._is_root:
+        logger.info("FSDP finished processing state_dict(), prefix=%s", prefix)
+        for key, tensor in sorted(processed_state_dict.items()):
+            if key.startswith(prefix) and isinstance(tensor, torch.Tensor):
+                local_shape = tensor.shape
+                if isinstance(tensor, ShardedTensor):
+                    local_shape = None
+                    shards = tensor.local_shards()
+                    if shards:
+                        local_shape = shards[0].tensor.shape
+                elif isinstance(tensor, DTensor):
+                    local_shape = tensor.to_local().shape
+                logger.info(
+                    "FQN=%s: type=%s, shape=%s, local_shape=%s, dtype=%s, device=%s",
+                    key,
+                    type(tensor),
+                    tensor.shape,
+                    local_shape,
+                    tensor.dtype,
+                    tensor.device,
+                )
+
     return processed_state_dict


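For context, the new block only runs on the root FSDP instance, so each rank logs one line per FQN when state_dict() finishes. A rough sketch of triggering it; the process-group setup and toy model below are illustrative assumptions, not taken from this commit:

    import logging
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    logging.basicConfig(level=logging.INFO)

    # Assumption: launched with torchrun so rank/world-size env vars are set.
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = FSDP(torch.nn.Linear(8, 8).cuda())
    # _post_state_dict_hook fires here; the root instance emits the per-FQN summary.
    state = model.state_dict()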
