remove act checkpoint tag
snarayan21 committed Apr 23, 2024
1 parent 7706cd7 commit f7dc14e
Showing 1 changed file with 4 additions and 3 deletions.
torch/distributed/checkpoint/state_dict.py (7 changes: 4 additions & 3 deletions)
@@ -152,8 +152,11 @@ def _get_fqns(
     Returns:
         The canonical FQNs based on the model traversal.
     """
+
+    # Remove the checkpoint prefix, if it exists.
+    name = name.replace(_CHECKPOINT_PREFIX, "")
     if "." not in name:
-        return {name.replace(_CHECKPOINT_PREFIX, "")}
+        return {name}
 
     obj_names = name.split(".")
     fqn_obj_names = []
@@ -170,8 +173,6 @@ def _get_fqns(
                 flat_param = getattr(curr_obj, FLAT_PARAM)
                 if prefix:
                     prefix = f"{prefix}."
-                # FSDP already handles removal of checkpoint prefix, so we can return
-                # directly
                 return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
             curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
             if curr_obj_name != FSDP_WRAPPED_MODULE:
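For context, here is a minimal sketch (not part of the commit) of the naming behavior the prefix strip addresses. Wrapping a submodule with activation checkpointing nests it under a "_checkpoint_wrapped_module" attribute, so its parameter names carry the _CHECKPOINT_PREFIX segment that _get_fqns now strips up front. Only checkpoint_wrapper and _CHECKPOINT_PREFIX below come from torch.distributed.algorithms._checkpoint.checkpoint_wrapper; the Block and Model classes are hypothetical stand-ins for a real model.

    # Minimal sketch, assuming a toy model; Block and Model are hypothetical names.
    import torch
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        _CHECKPOINT_PREFIX,
        checkpoint_wrapper,
    )

    class Block(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # checkpoint_wrapper nests Block under an extra
            # "_checkpoint_wrapped_module" attribute.
            self.block = checkpoint_wrapper(Block())

    model = Model()
    for name, _ in model.named_parameters():
        # Wrapped name, e.g. "block._checkpoint_wrapped_module.linear.weight"
        print(name)
        # After the strip now done at the top of _get_fqns: "block.linear.weight"
        print(name.replace(_CHECKPOINT_PREFIX, ""))

With the strip moved to the top of the function, every path through _get_fqns, not just the no-dot early return, works on the prefix-free name, which keeps the returned FQNs consistent with the unwrapped module's named_parameters().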
