Skip to content

Commit 4d21ccc

Browse files
committed
fix log
1 parent 07519c8 commit 4d21ccc

File tree

1 file changed

+89
-89
lines changed

1 file changed

+89
-89
lines changed

paddlenlp/experimental/transformers/inference_model.py

Lines changed: 89 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -82,38 +82,6 @@ def _build_model(self):
8282
# print("gaoziyuan test load from config:", self.model.state_dict())
8383
return self.model
8484

85-
def clear_parameters(self, pid=0) -> None:
86-
"""Clear all model parameters."""
87-
88-
if self.verify_parameters_cleared():
89-
logger.info("Parameters already cleared!")
90-
array = np.zeros([self.nranks],dtype=np.int32)
91-
shm = SharedMemory(create=False, size=array.nbytes, name=f"model_weights_status.{pid}")
92-
value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
93-
value[self.rank] = -2
94-
return
95-
96-
start_time = time.time()
97-
paddle.device.cuda.empty_cache()
98-
self.check_memory_usage("start clear parameters")
99-
for name, param in self.model.state_dict().items():
100-
logger.info(f"Clearing model parameter: {name}")
101-
param._clear_data()
102-
clear_time = time.time() - start_time
103-
logger.info(f"Parameter clearing completed in {clear_time:.2f} seconds")
104-
105-
self.verify_parameters_cleared()
106-
logger.info("Model parameters cleared successfully")
107-
108-
array = np.zeros([self.nranks],dtype=np.int32)
109-
shm = SharedMemory(create=False, size=array.nbytes, name=f"model_weights_status.{pid}")
110-
value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
111-
value[self.rank] = -2
112-
paddle.device.cuda.empty_cache()
113-
self.check_memory_usage("clear parameters end")
114-
logger.info("send clear sigal !")
115-
116-
11785
def get_model(self) -> paddle.nn.Layer:
11886
"""Get the underlying model instance."""
11987
return self.model
@@ -142,103 +110,131 @@ def load_tensor_from_ipc_meta(ipc_state_dict: Dict[str, Any]) -> Dict[str, paddl
142110

143111
return result
144112

145-
def check_memory_usage(self, msg=""):
146-
""" check_memory_usage """
147-
max_memory_allocated_size = paddle.device.cuda.max_memory_allocated()/(1024*1024*1024)
148-
max_memory_reserved_size = paddle.device.cuda.max_memory_reserved()/(1024*1024*1024)
149-
memory_allocated_size = paddle.device.cuda.memory_allocated()/(1024*1024*1024)
150-
memory_reserved_size = paddle.device.cuda.memory_reserved()/(1024*1024*1024)
151-
logger.info(msg)
152-
logger.warning(f"checking gpu memory usage {msg}:\nmax_memory_allocated_size: {max_memory_allocated_size}GB\nmax_memory_reserved_size: {max_memory_reserved_size}GB\nmemory_allocated_size: {memory_allocated_size}GB\nmemory_reserved_size: {memory_reserved_size}GB")
113+
def _log_memory_usage(self, context: str = "") -> None:
114+
"""Log current GPU memory usage."""
115+
max_alloc = paddle.device.cuda.max_memory_allocated() / (1024 ** 3)
116+
max_reserved = paddle.device.cuda.max_memory_reserved() / (1024 ** 3)
117+
curr_alloc = paddle.device.cuda.memory_allocated() / (1024 ** 3)
118+
curr_reserved = paddle.device.cuda.memory_reserved() / (1024 ** 3)
119+
120+
logger.info(f"GPU memory usage {context}:")
121+
logger.warning(
122+
f"max_allocated: {max_alloc:.2f}GB\n"
123+
f"max_reserved: {max_reserved:.2f}GB\n"
124+
f"current_allocated: {curr_alloc:.2f}GB\n"
125+
f"current_reserved: {curr_reserved:.2f}GB"
126+
)
153127

154128
def generate(self, **kwargs):
155129
self.model.generate(**kwargs)
130+
131+
def _update_shared_status(self, pid: int, status: int) -> None:
132+
"""Update shared memory status flag."""
133+
array = np.zeros([self.nranks], dtype=np.int32)
134+
shm = SharedMemory(create=False, size=array.nbytes, name=f"model_weights_status.{pid}")
135+
value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
136+
value[self.rank] = status
156137

157-
def update_parameters(
158-
self,
159-
pid=0,
160-
) -> None:
161-
"""
162-
Update model parameters from IPC state dictionary.
163-
164-
Args:
165-
ipc_state_dict: Dictionary containing new parameters in IPC format
166-
"""
167-
if self.verify_parameters_updated() and not self.first_load:
138+
def update_parameters(self, pid: int = 0) -> None:
139+
"""Update model parameters from IPC state dictionary."""
140+
if self.verify_parameters_updated(False) and not self.first_load:
168141
logger.info("Parameters already updated.")
169-
array = np.zeros([self.nranks],dtype=np.int32)
170-
shm = SharedMemory(create=False, size=array.nbytes, name=f"model_weights_status.{pid}")
171-
value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
172-
value[self.rank] = 2
142+
self._update_shared_status(pid, 2)
173143
return
174144

175145
paddle.device.cuda.empty_cache()
176-
self.check_memory_usage("start update parameters")
146+
self._log_memory_usage("start update parameters")
147+
177148
if self.local_test:
178149
current_device_id = int(os.getenv("FLAGS_selected_gpus"))
179150
model_path = f"/shared_ipc_meta/model_state.tp0{current_device_id}.pdparams"
180-
print("model_apth : ", model_path)
181-
state_dict = paddle.load(model_path)
151+
logger.info(f"Loading model from: {model_path}")
152+
182153
set_start = time.time()
183-
self.model.set_state_dict(state_dict)
154+
self.model.set_state_dict(paddle.load(model_path))
184155
logger.info(f"set_state_dict completed in {time.time() - set_start:.2f} seconds")
156+
185157
self.verify_parameters_updated()
186-
self.check_memory_usage("update parameters end")
158+
self._log_memory_usage("update parameters end")
159+
187160
if not self.first_load:
188161
logger.info("send update signal")
189-
array = np.zeros([self.nranks],dtype=np.int32)
190-
shm = SharedMemory(create=False, size=array.nbytes, name=f"model_weights_status.{pid}")
191-
value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
192-
value[self.rank] = 2
162+
self._update_shared_status(pid, 2)
193163
self.first_load = False
194164
return
195165

196166
start_time = time.time()
197167
logger.info("Starting parameter update process...")
198-
model_path = "/shared_ipc_meta"
168+
199169
current_device_id = int(os.getenv("FLAGS_selected_gpus"))
200-
ipc_state_dict_path = os.path.join(model_path, f"ipc_metas_{current_device_id}")
201-
logger.info(f"ipc_state_dict_path is {ipc_state_dict_path}")
202-
ipc_state_dict = paddle.load(ipc_state_dict_path)
170+
ipc_state_dict_path = f"/shared_ipc_meta/ipc_metas_{current_device_id}"
171+
logger.info(f"Loading IPC state dict from: {ipc_state_dict_path}")
172+
203173
convert_start = time.time()
204-
state_dict = self.load_tensor_from_ipc_meta(ipc_state_dict)
174+
state_dict = self.load_tensor_from_ipc_meta(paddle.load(ipc_state_dict_path))
205175
logger.info(f"IPC meta converted to tensors in {time.time() - convert_start:.2f} seconds")
176+
206177
if not self.shared_buffer_to:
207178
logger.info("Updating parameters via set_state_dict...")
208179
set_start = time.time()
209180
self.model.set_state_dict(state_dict)
210181
logger.info(f"set_state_dict completed in {time.time() - set_start:.2f} seconds")
211-
self.verify_parameters_updated()
212182
else:
183+
logger.info("Updating parameters via shared_buffer_to...")
213184
share_start = time.time()
214-
logger.info("通过shared_buffer_to更新参数")
215185
infer_model_state_dict = self.model.state_dict()
186+
216187
for name, param in state_dict.items():
217188
if name in infer_model_state_dict:
218189
logger.info(f"Updating model parameter: {name}")
219190
update_param = infer_model_state_dict[name]
220-
assert (
221-
update_param.dtype == param.dtype
222-
), f"Type mismatch for {name}: {param.dtype} vs {update_param.dtype}"
223-
assert (
224-
update_param.shape == param.shape
225-
), f"Shape mismatch for {name}: training {param.shape} vs infer {update_param.shape}"
191+
192+
if update_param.dtype != param.dtype:
193+
raise TypeError(f"Type mismatch for {name}: {param.dtype} vs {update_param.dtype}")
194+
if update_param.shape != param.shape:
195+
raise ValueError(f"Shape mismatch for {name}: {param.shape} vs {update_param.shape}")
196+
226197
param._share_buffer_to(update_param)
227-
logger.info(f"Parameter sharing completed in {time.time() - share_start:.2f} seconds")
198+
199+
logger.info(f"Parameter sharing completed in {time.time() - share_start:.2f} seconds")
228200

229201
if not self.first_load:
230202
logger.info("send update signal")
231-
array = np.zeros([self.nranks],dtype=np.int32)
232-
shm = SharedMemory(create=False, size=array.nbytes, name=f"model_weights_status.{pid}")
233-
value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
234-
value[self.rank] = 2
203+
self._update_shared_status(pid, 2)
204+
235205
self.first_load = False
206+
self.verify_parameters_updated()
207+
paddle.device.cuda.empty_cache()
208+
self._log_memory_usage("update parameters end")
209+
210+
def clear_parameters(self, pid: int = 0) -> None:
211+
"""Clear all model parameters."""
212+
if self.verify_parameters_cleared(False):
213+
logger.info("Parameters already cleared!")
214+
self._update_shared_status(pid, -2)
215+
return
216+
217+
start_time = time.time()
218+
paddle.device.cuda.empty_cache()
219+
self._log_memory_usage("start clear parameters")
220+
221+
for name, param in self.model.state_dict().items():
222+
logger.info(f"Clearing model parameter: {name}")
223+
param._clear_data()
224+
225+
clear_time = time.time() - start_time
226+
logger.info(f"Parameter clearing completed in {clear_time:.2f} seconds")
227+
228+
self.verify_parameters_cleared()
229+
logger.info("Model parameters cleared successfully")
236230

231+
self._update_shared_status(pid, -2)
237232
paddle.device.cuda.empty_cache()
238-
self.check_memory_usage("update parameters end")
233+
self._log_memory_usage("clear parameters end")
234+
logger.info("send clear signal!")
239235

240236

241-
def verify_parameters_cleared(self) -> bool:
237+
def verify_parameters_cleared(self, erro_log:bool = True) -> bool:
242238
"""
243239
Verify that all model parameters have been cleared.
244240
@@ -249,17 +245,19 @@ def verify_parameters_cleared(self) -> bool:
249245
all_cleared = True
250246
for name, param in self.model.state_dict().items():
251247
if param._is_initialized():
252-
logger.error(f"Parameter {name} was not properly cleared!")
248+
if erro_log:
249+
logger.error(f"Parameter {name} was not properly cleared!")
253250
all_cleared = False
254251

255252
if all_cleared:
256253
logger.info("All parameters verified as cleared successfully")
257254
else:
258-
logger.error("Some parameters were not properly cleared!")
255+
if erro_log:
256+
logger.error("Some parameters were not properly cleared!")
259257

260258
return all_cleared
261259

262-
def verify_parameters_updated(self) -> bool:
260+
def verify_parameters_updated(self, erro_log:bool = True) -> bool:
263261
"""
264262
Verify that model parameters match the source state dictionary.
265263
@@ -273,12 +271,14 @@ def verify_parameters_updated(self) -> bool:
273271
all_update = True
274272
for name, param in self.model.state_dict().items():
275273
if not param._is_initialized():
276-
logger.error(f"Parameter {name} was not properly cleared!")
274+
if erro_log:
275+
logger.error(f"Parameter {name} was not properly updated!")
277276
all_update = False
278277

279278
if all_update:
280279
logger.info("All parameters verified as updated successfully")
281280
else:
282-
logger.error("Some parameters were not properly updated!")
281+
if erro_log:
282+
logger.error("Some parameters were not properly updated!")
283283

284284
return all_update

0 commit comments

Comments
 (0)