diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index a8f8216057a82..5215eb436590f 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -152,8 +152,11 @@ def _get_fqns( Returns: The canonical FQNs based on the model traversal. """ + + # Remove the checkpoint prefix, if it exists. + name = name.replace(_CHECKPOINT_PREFIX, "") if "." not in name: - return {name.replace(_CHECKPOINT_PREFIX, "")} + return {name} obj_names = name.split(".") fqn_obj_names = [] @@ -170,8 +173,6 @@ def _get_fqns( flat_param = getattr(curr_obj, FLAT_PARAM) if prefix: prefix = f"{prefix}." - # FSDP already handles removal of checkpoint prefix, so we can return - # directly return {f"{prefix}{fqn}" for fqn in flat_param._fqns} curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE) if curr_obj_name != FSDP_WRAPPED_MODULE: