17
17
18
18
import os
import time
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Dict

import numpy as np
import paddle

from paddlenlp.transformers import AutoConfig, AutoInferenceModelForCausalLM
from paddlenlp.utils.log import logger
25
25
26
26
27
27
class InferenceModel :
28
- def __init__ (self , predictor_args , model_args , nranks = 1 , rank = 0 , load_model_from_ipc = False , cold_start = False ):
28
+ def __init__ (self , predictor_args , model_args , nranks = 1 , rank = 0 , load_model_from_ipc = False , hot_start = False ):
29
29
"""
30
30
Initialize the Causal Language Model Loader.
31
31
@@ -43,10 +43,11 @@ def __init__(self, predictor_args, model_args, nranks=1, rank=0, load_model_from
43
43
self .load_model_from_ipc = load_model_from_ipc
44
44
self .model = self ._build_model ()
45
45
self .shared_buffer_to = False
46
- self .local_test = False
46
+ self .local_test = True
47
+ self .hot_start = hot_start
47
48
48
49
# (TODO:gaoziyuan)当前启动服务后直接加载参数,后续进行热启动
49
- if load_model_from_ipc and not cold_start :
50
+ if load_model_from_ipc and not hot_start :
50
51
self .update_parameters ()
51
52
52
53
def _setup_environment (self ):
@@ -76,16 +77,26 @@ def _build_model(self):
76
77
tensor_parallel_rank = self .rank ,
77
78
load_model_from_ipc = self .load_model_from_ipc ,
78
79
)
79
- print ("gaoziyuan test load from config:" , self .model .state_dict ())
80
+ # print("gaoziyuan test load from config:", self.model.state_dict())
80
81
return self .model
81
82
82
- def clear_parameters (self ) -> None :
83
+ def clear_parameters (self , pid = 0 ) -> None :
83
84
"""Clear all model parameters."""
85
+ start_time = time .time ()
84
86
for name , param in self .model .state_dict ().items ():
85
87
logger .info (f"Clearing model parameter: { name } " )
86
88
param ._clear_data ()
89
+ clear_time = time .time () - start_time
90
+ logger .info (f"Parameter clearing completed in { clear_time :.2f} seconds" )
87
91
92
+ self .verify_parameters_cleared ()
88
93
logger .info ("Model parameters cleared successfully" )
94
+ if self .hot_start :
95
+ array = np .zeros ([1 ],dtype = np .int32 )
96
+ shm = SharedMemory (create = False , size = array .nbytes , name = f"model_weights_status.{ pid } " )
97
+ value = np .ndarray (array .shape , dtype = array .dtype , buffer = shm .buf )
98
+ value [0 ] = - 2
99
+
89
100
90
101
def get_model (self ) -> paddle .nn .Layer :
91
102
"""Get the underlying model instance."""
@@ -117,6 +128,7 @@ def load_tensor_from_ipc_meta(ipc_state_dict: Dict[str, Any]) -> Dict[str, paddl
117
128
118
129
def update_parameters (
119
130
self ,
131
+ pid = 0 ,
120
132
) -> None :
121
133
"""
122
134
Update model parameters from IPC state dictionary.
@@ -125,34 +137,98 @@ def update_parameters(
125
137
ipc_state_dict: Dictionary containing new parameters in IPC format
126
138
"""
127
139
if self .local_test :
128
- state_dict = paddle .load ("/root/paddlejob/workspace/env_run/output/model_local_qwen" )
129
- if not self .shared_buffer_to :
130
- print ("通过set_state_dict更新参数" )
131
- self .model .set_state_dict (state_dict )
132
- else :
133
- model_path = "/shared_ipc_meta"
134
140
current_device_id = int (os .getenv ("FLAGS_selected_gpus" ))
135
- ipc_state_dict_path = os .path .join (model_path , f"ipc_metas_{ current_device_id } " )
136
- ipc_state_dict = paddle .load (ipc_state_dict_path )
137
- state_dict = self .load_tensor_from_ipc_meta (ipc_state_dict )
138
- if not self .shared_buffer_to :
139
- print ("通过set_state_dict更新参数" )
140
- self .model .set_state_dict (state_dict )
141
- return
142
-
143
- infer_model_state_dict = self .model .state_dict ()
144
-
145
- print ("通过shared_buffer_to更新参数" )
146
- for name , param in state_dict .items ():
147
- if name in infer_model_state_dict :
148
- logger .info (f"Updating model parameter: { name } " )
149
- update_param = infer_model_state_dict [name ]
150
- assert (
151
- update_param .dtype == param .dtype
152
- ), f"Type mismatch for { name } : { param .dtype } vs { update_param .dtype } "
153
- assert (
154
- update_param .shape == param .shape
155
- ), f"Shape mismatch for { name } : { param .shape } vs { update_param .shape } "
156
- param ._share_buffer_to (update_param )
157
-
158
- logger .info ("Model parameters updated successfully" )
141
+ model_path = f"/shared_ipc_meta/model_state.tp0{ current_device_id } .pdparams"
142
+ print ("model_apth : " , model_path )
143
+ state_dict = paddle .load (model_path )
144
+ set_start = time .time ()
145
+ self .model .set_state_dict (state_dict )
146
+ logger .info (f"set_state_dict completed in { time .time () - set_start :.2f} seconds" )
147
+ self .verify_parameters_updated ()
148
+ return
149
+
150
+ start_time = time .time ()
151
+ logger .info ("Starting parameter update process..." )
152
+
153
+ model_path = "/shared_ipc_meta"
154
+ current_device_id = int (os .getenv ("FLAGS_selected_gpus" ))
155
+ ipc_state_dict_path = os .path .join (model_path , f"ipc_metas_{ current_device_id } " )
156
+ logger .info (f"ipc_state_dict_path is { ipc_state_dict_path } " )
157
+ ipc_state_dict = paddle .load (ipc_state_dict_path )
158
+ convert_start = time .time ()
159
+ state_dict = self .load_tensor_from_ipc_meta (ipc_state_dict )
160
+ logger .info (f"IPC meta converted to tensors in { time .time () - convert_start :.2f} seconds" )
161
+ if not self .shared_buffer_to :
162
+ logger .info ("Updating parameters via set_state_dict..." )
163
+ set_start = time .time ()
164
+ self .model .set_state_dict (state_dict )
165
+ logger .info (f"set_state_dict completed in { time .time () - set_start :.2f} seconds" )
166
+ self .verify_parameters_updated ()
167
+ else :
168
+ share_start = time .time ()
169
+ logger .info ("通过shared_buffer_to更新参数" )
170
+ infer_model_state_dict = self .model .state_dict ()
171
+ for name , param in state_dict .items ():
172
+ if name in infer_model_state_dict :
173
+ logger .info (f"Updating model parameter: { name } " )
174
+ update_param = infer_model_state_dict [name ]
175
+ assert (
176
+ update_param .dtype == param .dtype
177
+ ), f"Type mismatch for { name } : { param .dtype } vs { update_param .dtype } "
178
+ assert (
179
+ update_param .shape == param .shape
180
+ ), f"Shape mismatch for { name } : training { param .shape } vs infer { update_param .shape } "
181
+ param ._share_buffer_to (update_param )
182
+ logger .info (f"Parameter sharing completed in { time .time () - share_start :.2f} seconds" )
183
+
184
+ if self .hot_start :
185
+ array = np .zeros ([1 ],dtype = np .int32 )
186
+ shm = SharedMemory (create = False , size = array .nbytes , name = f"model_weights_status.{ pid } " )
187
+ value = np .ndarray (array .shape , dtype = array .dtype , buffer = shm .buf )
188
+ value [0 ] = 2
189
+
190
+
191
+ def verify_parameters_cleared (self ) -> bool :
192
+ """
193
+ Verify that all model parameters have been cleared.
194
+
195
+ Returns:
196
+ bool: True if all parameters are cleared, False otherwise
197
+ """
198
+ logger .info ("Verifying parameters are cleared..." )
199
+ all_cleared = True
200
+ for name , param in self .model .state_dict ().items ():
201
+ if param ._is_initialized ():
202
+ logger .error (f"Parameter { name } was not properly cleared!" )
203
+ all_cleared = False
204
+
205
+ if all_cleared :
206
+ logger .info ("All parameters verified as cleared successfully" )
207
+ else :
208
+ logger .error ("Some parameters were not properly cleared!" )
209
+
210
+ return all_cleared
211
+
212
+ def verify_parameters_updated (self ) -> bool :
213
+ """
214
+ Verify that model parameters match the source state dictionary.
215
+
216
+ Args:
217
+ source_state_dict: Dictionary containing the expected parameters
218
+
219
+ Returns:
220
+ bool: True if all parameters match, False otherwise
221
+ """
222
+ logger .info ("Verifying parameters are cleared..." )
223
+ all_update = True
224
+ for name , param in self .model .state_dict ().items ():
225
+ if not param ._is_initialized ():
226
+ logger .error (f"Parameter { name } was not properly cleared!" )
227
+ all_update = False
228
+
229
+ if all_update :
230
+ logger .info ("All parameters verified as updated successfully" )
231
+ else :
232
+ logger .error ("Some parameters were not properly updated!" )
233
+
234
+ return all_update
0 commit comments