From f7dc14e4832851d982d0d4a22a37f266903c7632 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Mon, 22 Apr 2024 20:15:19 -0700 Subject: [PATCH] remove act checkpoint tag --- torch/distributed/checkpoint/state_dict.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index a8f8216057a82..5215eb436590f 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -152,8 +152,11 @@ def _get_fqns( Returns: The canonical FQNs based on the model traversal. """ + + # Remove the checkpoint prefix, if it exists. + name = name.replace(_CHECKPOINT_PREFIX, "") if "." not in name: - return {name.replace(_CHECKPOINT_PREFIX, "")} + return {name} obj_names = name.split(".") fqn_obj_names = [] @@ -170,8 +173,6 @@ def _get_fqns( flat_param = getattr(curr_obj, FLAT_PARAM) if prefix: prefix = f"{prefix}." - # FSDP already handles removal of checkpoint prefix, so we can return - # directly return {f"{prefix}{fqn}" for fqn in flat_param._fqns} curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE) if curr_obj_name != FSDP_WRAPPED_MODULE: