@@ -82,38 +82,6 @@ def _build_model(self):
82
82
# print("gaoziyuan test load from config:", self.model.state_dict())
83
83
return self .model
84
84
85
- def clear_parameters (self , pid = 0 ) -> None :
86
- """Clear all model parameters."""
87
-
88
- if self .verify_parameters_cleared ():
89
- logger .info ("Parameters already cleared!" )
90
- array = np .zeros ([self .nranks ],dtype = np .int32 )
91
- shm = SharedMemory (create = False , size = array .nbytes , name = f"model_weights_status.{ pid } " )
92
- value = np .ndarray (array .shape , dtype = array .dtype , buffer = shm .buf )
93
- value [self .rank ] = - 2
94
- return
95
-
96
- start_time = time .time ()
97
- paddle .device .cuda .empty_cache ()
98
- self .check_memory_usage ("start clear parameters" )
99
- for name , param in self .model .state_dict ().items ():
100
- logger .info (f"Clearing model parameter: { name } " )
101
- param ._clear_data ()
102
- clear_time = time .time () - start_time
103
- logger .info (f"Parameter clearing completed in { clear_time :.2f} seconds" )
104
-
105
- self .verify_parameters_cleared ()
106
- logger .info ("Model parameters cleared successfully" )
107
-
108
- array = np .zeros ([self .nranks ],dtype = np .int32 )
109
- shm = SharedMemory (create = False , size = array .nbytes , name = f"model_weights_status.{ pid } " )
110
- value = np .ndarray (array .shape , dtype = array .dtype , buffer = shm .buf )
111
- value [self .rank ] = - 2
112
- paddle .device .cuda .empty_cache ()
113
- self .check_memory_usage ("clear parameters end" )
114
- logger .info ("send clear sigal !" )
115
-
116
-
117
85
def get_model (self ) -> paddle .nn .Layer :
118
86
"""Get the underlying model instance."""
119
87
return self .model
@@ -142,103 +110,131 @@ def load_tensor_from_ipc_meta(ipc_state_dict: Dict[str, Any]) -> Dict[str, paddl
142
110
143
111
return result
144
112
145
    def _log_memory_usage(self, context: str = "") -> None:
        """Log current GPU memory usage (all values converted to GiB)."""
        # bytes -> GiB
        max_alloc = paddle.device.cuda.max_memory_allocated() / (1024 ** 3)
        max_reserved = paddle.device.cuda.max_memory_reserved() / (1024 ** 3)
        curr_alloc = paddle.device.cuda.memory_allocated() / (1024 ** 3)
        curr_reserved = paddle.device.cuda.memory_reserved() / (1024 ** 3)

        logger.info(f"GPU memory usage {context}:")
        # NOTE(review): the detailed numbers are emitted at WARNING level,
        # presumably so they stand out in mixed logs — confirm intent.
        logger.warning(
            f"max_allocated: {max_alloc:.2f} GB\n"
            f"max_reserved: {max_reserved:.2f} GB\n"
            f"current_allocated: {curr_alloc:.2f} GB\n"
            f"current_reserved: {curr_reserved:.2f} GB"
        )
153
127
154
128
def generate (self , ** kwargs ):
155
129
self .model .generate (** kwargs )
130
+
131
+ def _update_shared_status (self , pid : int , status : int ) -> None :
132
+ """Update shared memory status flag."""
133
+ array = np .zeros ([self .nranks ], dtype = np .int32 )
134
+ shm = SharedMemory (create = False , size = array .nbytes , name = f"model_weights_status.{ pid } " )
135
+ value = np .ndarray (array .shape , dtype = array .dtype , buffer = shm .buf )
136
+ value [self .rank ] = status
156
137
157
- def update_parameters (
158
- self ,
159
- pid = 0 ,
160
- ) -> None :
161
- """
162
- Update model parameters from IPC state dictionary.
163
-
164
- Args:
165
- ipc_state_dict: Dictionary containing new parameters in IPC format
166
- """
167
- if self .verify_parameters_updated () and not self .first_load :
138
+ def update_parameters (self , pid : int = 0 ) -> None :
139
+ """Update model parameters from IPC state dictionary."""
140
+ if self .verify_parameters_updated (False ) and not self .first_load :
168
141
logger .info ("Parameters already updated." )
169
- array = np .zeros ([self .nranks ],dtype = np .int32 )
170
- shm = SharedMemory (create = False , size = array .nbytes , name = f"model_weights_status.{ pid } " )
171
- value = np .ndarray (array .shape , dtype = array .dtype , buffer = shm .buf )
172
- value [self .rank ] = 2
142
+ self ._update_shared_status (pid , 2 )
173
143
return
174
144
175
145
paddle .device .cuda .empty_cache ()
176
- self .check_memory_usage ("start update parameters" )
146
+ self ._log_memory_usage ("start update parameters" )
147
+
177
148
if self .local_test :
178
149
current_device_id = int (os .getenv ("FLAGS_selected_gpus" ))
179
150
model_path = f"/shared_ipc_meta/model_state.tp0{ current_device_id } .pdparams"
180
- print ( "model_apth : " , model_path )
181
- state_dict = paddle . load ( model_path )
151
+ logger . info ( f"Loading model from: { model_path } " )
152
+
182
153
set_start = time .time ()
183
- self .model .set_state_dict (state_dict )
154
+ self .model .set_state_dict (paddle . load ( model_path ) )
184
155
logger .info (f"set_state_dict completed in { time .time () - set_start :.2f} seconds" )
156
+
185
157
self .verify_parameters_updated ()
186
- self .check_memory_usage ("update parameters end" )
158
+ self ._log_memory_usage ("update parameters end" )
159
+
187
160
if not self .first_load :
188
161
logger .info ("send update signal" )
189
- array = np .zeros ([self .nranks ],dtype = np .int32 )
190
- shm = SharedMemory (create = False , size = array .nbytes , name = f"model_weights_status.{ pid } " )
191
- value = np .ndarray (array .shape , dtype = array .dtype , buffer = shm .buf )
192
- value [self .rank ] = 2
162
+ self ._update_shared_status (pid , 2 )
193
163
self .first_load = False
194
164
return
195
165
196
166
start_time = time .time ()
197
167
logger .info ("Starting parameter update process..." )
198
- model_path = "/shared_ipc_meta"
168
+
199
169
current_device_id = int (os .getenv ("FLAGS_selected_gpus" ))
200
- ipc_state_dict_path = os . path . join ( model_path , f" ipc_metas_{ current_device_id } ")
201
- logger .info (f"ipc_state_dict_path is { ipc_state_dict_path } " )
202
- ipc_state_dict = paddle . load ( ipc_state_dict_path )
170
+ ipc_state_dict_path = f"/shared_ipc_meta/ ipc_metas_{ current_device_id } "
171
+ logger .info (f"Loading IPC state dict from: { ipc_state_dict_path } " )
172
+
203
173
convert_start = time .time ()
204
- state_dict = self .load_tensor_from_ipc_meta (ipc_state_dict )
174
+ state_dict = self .load_tensor_from_ipc_meta (paddle . load ( ipc_state_dict_path ) )
205
175
logger .info (f"IPC meta converted to tensors in { time .time () - convert_start :.2f} seconds" )
176
+
206
177
if not self .shared_buffer_to :
207
178
logger .info ("Updating parameters via set_state_dict..." )
208
179
set_start = time .time ()
209
180
self .model .set_state_dict (state_dict )
210
181
logger .info (f"set_state_dict completed in { time .time () - set_start :.2f} seconds" )
211
- self .verify_parameters_updated ()
212
182
else :
183
+ logger .info ("Updating parameters via shared_buffer_to..." )
213
184
share_start = time .time ()
214
- logger .info ("通过shared_buffer_to更新参数" )
215
185
infer_model_state_dict = self .model .state_dict ()
186
+
216
187
for name , param in state_dict .items ():
217
188
if name in infer_model_state_dict :
218
189
logger .info (f"Updating model parameter: { name } " )
219
190
update_param = infer_model_state_dict [name ]
220
- assert (
221
- update_param .dtype == param .dtype
222
- ), f"Type mismatch for { name } : { param .dtype } vs { update_param .dtype } "
223
- assert (
224
- update_param .shape == param .shape
225
- ), f"Shape mismatch for { name } : training { param . shape } vs infer { update_param . shape } "
191
+
192
+ if update_param .dtype != param .dtype :
193
+ raise TypeError ( f"Type mismatch for { name } : { param .dtype } vs { update_param .dtype } " )
194
+ if update_param . shape != param . shape :
195
+ raise ValueError ( f"Shape mismatch for { name } : { param .shape } vs { update_param .shape } " )
196
+
226
197
param ._share_buffer_to (update_param )
227
- logger .info (f"Parameter sharing completed in { time .time () - share_start :.2f} seconds" )
198
+
199
+ logger .info (f"Parameter sharing completed in { time .time () - share_start :.2f} seconds" )
228
200
229
201
if not self .first_load :
230
202
logger .info ("send update signal" )
231
- array = np .zeros ([self .nranks ],dtype = np .int32 )
232
- shm = SharedMemory (create = False , size = array .nbytes , name = f"model_weights_status.{ pid } " )
233
- value = np .ndarray (array .shape , dtype = array .dtype , buffer = shm .buf )
234
- value [self .rank ] = 2
203
+ self ._update_shared_status (pid , 2 )
204
+
235
205
self .first_load = False
206
+ self .verify_parameters_updated ()
207
+ paddle .device .cuda .empty_cache ()
208
+ self ._log_memory_usage ("update parameters end" )
209
+
210
    def clear_parameters(self, pid: int = 0) -> None:
        """Clear all model parameters and publish the "cleared" signal.

        Args:
            pid: Suffix identifying the shared-memory status segment.
        """
        # Fast path: nothing is initialized — just (re-)publish the signal.
        if self.verify_parameters_cleared(False):
            logger.info("Parameters already cleared!")
            self._update_shared_status(pid, -2)
            return

        start_time = time.time()
        paddle.device.cuda.empty_cache()
        self._log_memory_usage("start clear parameters")

        # Release the device memory backing every parameter in place.
        for name, param in self.model.state_dict().items():
            logger.info(f"Clearing model parameter: {name}")
            param._clear_data()

        clear_time = time.time() - start_time
        logger.info(f"Parameter clearing completed in {clear_time:.2f} seconds")

        self.verify_parameters_cleared()
        logger.info("Model parameters cleared successfully")

        # -2 signals "weights cleared" to the coordinating process.
        self._update_shared_status(pid, -2)
        paddle.device.cuda.empty_cache()
        self._log_memory_usage("clear parameters end")
        logger.info("send clear signal!")
239
235
240
236
241
- def verify_parameters_cleared (self ) -> bool :
237
+ def verify_parameters_cleared (self , erro_log : bool = True ) -> bool :
242
238
"""
243
239
Verify that all model parameters have been cleared.
244
240
@@ -249,17 +245,19 @@ def verify_parameters_cleared(self) -> bool:
249
245
all_cleared = True
250
246
for name , param in self .model .state_dict ().items ():
251
247
if param ._is_initialized ():
252
- logger .error (f"Parameter { name } was not properly cleared!" )
248
+ if erro_log :
249
+ logger .error (f"Parameter { name } was not properly cleared!" )
253
250
all_cleared = False
254
251
255
252
if all_cleared :
256
253
logger .info ("All parameters verified as cleared successfully" )
257
254
else :
258
- logger .error ("Some parameters were not properly cleared!" )
255
+ if erro_log :
256
+ logger .error ("Some parameters were not properly cleared!" )
259
257
260
258
return all_cleared
261
259
262
- def verify_parameters_updated (self ) -> bool :
260
+ def verify_parameters_updated (self , erro_log : bool = True ) -> bool :
263
261
"""
264
262
Verify that model parameters match the source state dictionary.
265
263
@@ -273,12 +271,14 @@ def verify_parameters_updated(self) -> bool:
273
271
all_update = True
274
272
for name , param in self .model .state_dict ().items ():
275
273
if not param ._is_initialized ():
276
- logger .error (f"Parameter { name } was not properly cleared!" )
274
+ if erro_log :
275
+ logger .error (f"Parameter { name } was not properly cleared!" )
277
276
all_update = False
278
277
279
278
if all_update :
280
279
logger .info ("All parameters verified as updated successfully" )
281
280
else :
282
- logger .error ("Some parameters were not properly updated!" )
281
+ if erro_log
282
+ logger .error ("Some parameters were not properly updated!" )
283
283
284
284
return all_update
0 commit comments