
Commit 9c41196

add hot start
1 parent a3b7a01 commit 9c41196

File tree: 2 files changed, +120 −39 lines

paddlenlp/experimental/transformers/inference_model.py

Lines changed: 112 additions & 36 deletions
@@ -17,15 +17,18 @@
 
 import os
 from typing import Any, Dict
-
+import time
+# assumed extra imports: the hot-start status flag added below needs numpy and
+# (most plausibly) the standard-library multiprocessing SharedMemory
+import numpy as np
+from multiprocessing.shared_memory import SharedMemory
 import paddle
 
 from paddlenlp.transformers import AutoConfig, AutoInferenceModelForCausalLM
 from paddlenlp.utils.log import logger
 
 
 class InferenceModel:
-    def __init__(self, predictor_args, model_args, nranks=1, rank=0, load_model_from_ipc=False, cold_start=False):
+    def __init__(self, predictor_args, model_args, nranks=1, rank=0, load_model_from_ipc=False, hot_start=False):
         """
         Initialize the Causal Language Model Loader.
 
@@ -43,10 +43,11 @@ def __init__(self, predictor_args, model_args, nranks=1, rank=0, load_model_from
         self.load_model_from_ipc = load_model_from_ipc
         self.model = self._build_model()
         self.shared_buffer_to = False
-        self.local_test = False
+        self.local_test = True
+        self.hot_start = hot_start
 
         # (TODO: gaoziyuan) For now parameters are loaded right after the service starts; a real hot start will follow
-        if load_model_from_ipc and not cold_start:
+        if load_model_from_ipc and not hot_start:
             self.update_parameters()
 
     def _setup_environment(self):
@@ -76,16 +77,26 @@ def _build_model(self):
             tensor_parallel_rank=self.rank,
             load_model_from_ipc=self.load_model_from_ipc,
         )
-        print("gaoziyuan test load from config:", self.model.state_dict())
+        # print("gaoziyuan test load from config:", self.model.state_dict())
         return self.model
 
-    def clear_parameters(self) -> None:
+    def clear_parameters(self, pid=0) -> None:
         """Clear all model parameters."""
+        start_time = time.time()
         for name, param in self.model.state_dict().items():
             logger.info(f"Clearing model parameter: {name}")
             param._clear_data()
+        clear_time = time.time() - start_time
+        logger.info(f"Parameter clearing completed in {clear_time:.2f} seconds")
 
+        self.verify_parameters_cleared()
         logger.info("Model parameters cleared successfully")
+        if self.hot_start:
+            # signal "weights cleared" (-2) via the per-process status flag
+            array = np.zeros([1], dtype=np.int32)
+            shm = SharedMemory(create=False, size=array.nbytes, name=f"model_weights_status.{pid}")
+            value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
+            value[0] = -2
+
 
     def get_model(self) -> paddle.nn.Layer:
         """Get the underlying model instance."""
@@ -117,6 +128,7 @@ def load_tensor_from_ipc_meta(ipc_state_dict: Dict[str, Any]) -> Dict[str, paddl
 
     def update_parameters(
         self,
+        pid=0,
     ) -> None:
         """
         Update model parameters from IPC state dictionary.
@@ -125,34 +137,98 @@ def update_parameters(
             ipc_state_dict: Dictionary containing new parameters in IPC format
         """
         if self.local_test:
-            state_dict = paddle.load("/root/paddlejob/workspace/env_run/output/model_local_qwen")
-            if not self.shared_buffer_to:
-                print("Updating parameters via set_state_dict")
-                self.model.set_state_dict(state_dict)
-        else:
-            model_path = "/shared_ipc_meta"
             current_device_id = int(os.getenv("FLAGS_selected_gpus"))
-            ipc_state_dict_path = os.path.join(model_path, f"ipc_metas_{current_device_id}")
-            ipc_state_dict = paddle.load(ipc_state_dict_path)
-            state_dict = self.load_tensor_from_ipc_meta(ipc_state_dict)
-            if not self.shared_buffer_to:
-                print("Updating parameters via set_state_dict")
-                self.model.set_state_dict(state_dict)
-                return
-
-        infer_model_state_dict = self.model.state_dict()
-
-        print("Updating parameters via shared_buffer_to")
-        for name, param in state_dict.items():
-            if name in infer_model_state_dict:
-                logger.info(f"Updating model parameter: {name}")
-                update_param = infer_model_state_dict[name]
-                assert (
-                    update_param.dtype == param.dtype
-                ), f"Type mismatch for {name}: {param.dtype} vs {update_param.dtype}"
-                assert (
-                    update_param.shape == param.shape
-                ), f"Shape mismatch for {name}: {param.shape} vs {update_param.shape}"
-                param._share_buffer_to(update_param)
-
-        logger.info("Model parameters updated successfully")
+            model_path = f"/shared_ipc_meta/model_state.tp0{current_device_id}.pdparams"
+            print("model_path : ", model_path)
+            state_dict = paddle.load(model_path)
+            set_start = time.time()
+            self.model.set_state_dict(state_dict)
+            logger.info(f"set_state_dict completed in {time.time() - set_start:.2f} seconds")
+            self.verify_parameters_updated()
+            return
+
+        start_time = time.time()
+        logger.info("Starting parameter update process...")
+
+        model_path = "/shared_ipc_meta"
+        current_device_id = int(os.getenv("FLAGS_selected_gpus"))
+        ipc_state_dict_path = os.path.join(model_path, f"ipc_metas_{current_device_id}")
+        logger.info(f"ipc_state_dict_path is {ipc_state_dict_path}")
+        ipc_state_dict = paddle.load(ipc_state_dict_path)
+        convert_start = time.time()
+        state_dict = self.load_tensor_from_ipc_meta(ipc_state_dict)
+        logger.info(f"IPC meta converted to tensors in {time.time() - convert_start:.2f} seconds")
+        if not self.shared_buffer_to:
+            logger.info("Updating parameters via set_state_dict...")
+            set_start = time.time()
+            self.model.set_state_dict(state_dict)
+            logger.info(f"set_state_dict completed in {time.time() - set_start:.2f} seconds")
+            self.verify_parameters_updated()
+        else:
+            share_start = time.time()
+            logger.info("Updating parameters via shared_buffer_to")
+            infer_model_state_dict = self.model.state_dict()
+            for name, param in state_dict.items():
+                if name in infer_model_state_dict:
+                    logger.info(f"Updating model parameter: {name}")
+                    update_param = infer_model_state_dict[name]
+                    assert (
+                        update_param.dtype == param.dtype
+                    ), f"Type mismatch for {name}: {param.dtype} vs {update_param.dtype}"
+                    assert (
+                        update_param.shape == param.shape
+                    ), f"Shape mismatch for {name}: training {param.shape} vs infer {update_param.shape}"
+                    param._share_buffer_to(update_param)
+            logger.info(f"Parameter sharing completed in {time.time() - share_start:.2f} seconds")
+
+        if self.hot_start:
+            # signal "weights updated" (2) via the per-process status flag
+            array = np.zeros([1], dtype=np.int32)
+            shm = SharedMemory(create=False, size=array.nbytes, name=f"model_weights_status.{pid}")
+            value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
+            value[0] = 2
+
+    def verify_parameters_cleared(self) -> bool:
+        """
+        Verify that all model parameters have been cleared.
+
+        Returns:
+            bool: True if all parameters are cleared, False otherwise
+        """
+        logger.info("Verifying parameters are cleared...")
+        all_cleared = True
+        for name, param in self.model.state_dict().items():
+            if param._is_initialized():
+                logger.error(f"Parameter {name} was not properly cleared!")
+                all_cleared = False
+
+        if all_cleared:
+            logger.info("All parameters verified as cleared successfully")
+        else:
+            logger.error("Some parameters were not properly cleared!")
+
+        return all_cleared
+
+    def verify_parameters_updated(self) -> bool:
+        """
+        Verify that all model parameters are initialized after an update.
+
+        Returns:
+            bool: True if all parameters are initialized, False otherwise
+        """
+        logger.info("Verifying parameters are updated...")
+        all_updated = True
+        for name, param in self.model.state_dict().items():
+            if not param._is_initialized():
+                logger.error(f"Parameter {name} was not properly updated!")
+                all_updated = False
+
+        if all_updated:
+            logger.info("All parameters verified as updated successfully")
+        else:
+            logger.error("Some parameters were not properly updated!")
+
+        return all_updated
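
Note that both new writers attach to the status flag with create=False: clear_parameters (writes -2) and update_parameters (writes 2) assume some coordinating process has already created the one-int32 shared-memory block model_weights_status.{pid} and is polling it. Below is a minimal sketch of that coordinator side, assuming the standard-library multiprocessing.shared_memory module and the status codes used in this commit; create_status_flag and read_status_flag are hypothetical helper names, not part of the diff.

import numpy as np
from multiprocessing.shared_memory import SharedMemory

def create_status_flag(pid: int) -> SharedMemory:
    """Create the one-element int32 status block that workers attach to."""
    array = np.zeros([1], dtype=np.int32)
    shm = SharedMemory(create=True, size=array.nbytes, name=f"model_weights_status.{pid}")
    np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)[0] = 0  # 0 = idle
    return shm

def read_status_flag(pid: int) -> int:
    """Poll the code a worker wrote: -2 = weights cleared, 2 = weights updated."""
    shm = SharedMemory(create=False, name=f"model_weights_status.{pid}")
    status = int(np.ndarray([1], dtype=np.int32, buffer=shm.buf)[0])
    shm.close()
    return status

The creator should hold on to the returned block and unlink() it when the serving process shuts down; otherwise the segment outlives both sides.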

paddlenlp/experimental/transformers/test_inference_model.py

Lines changed: 8 additions & 3 deletions
@@ -13,8 +13,7 @@
 # limitations under the License.
 
 # Will be removed later; kept only as an example
-from inference_utils import ModelArgument, PredictorArgument
-
+from paddlenlp.experimental.transformers.inference_utils import ModelArgument, PredictorArgument
 from paddlenlp.experimental.transformers.inference_model import InferenceModel
 
 predictor_args = PredictorArgument()
@@ -27,9 +26,15 @@
 # If needed:
 # predictor_args.quant_type = "weight_only_int8"
 
-inference_model = InferenceModel(predictor_args, model_args, load_model_from_ipc=True, cold_start=False)
+inference_model = InferenceModel(predictor_args, model_args, load_model_from_ipc=True, hot_start=True)
 
 model = inference_model.model
+
+print(inference_model.verify_parameters_cleared())
+
+print(inference_model.verify_parameters_updated())
+
+
 # print(model.get_name_mappings_to_training())
 
 # Get the inference model's keys, shapes, and dtypes
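
Taken together with the changes in inference_model.py, a hot-start weight refresh would look roughly like the sketch below. This is an assumed end-to-end flow, not code from the commit; the republish step in the middle stands for whatever the training side does to rewrite /shared_ipc_meta.

import os

pid = os.getpid()
inference_model = InferenceModel(predictor_args, model_args, load_model_from_ipc=True, hot_start=True)

inference_model.clear_parameters(pid=pid)   # frees the weights, writes -2 to model_weights_status.{pid}
# ... training side republishes ipc_metas_<device_id> under /shared_ipc_meta ...
inference_model.update_parameters(pid=pid)  # reloads the weights, writes 2 to model_weights_status.{pid}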
