From 7872e17c6565fb2bae33527ad3b73f26351dcf81 Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Mon, 13 Jan 2025 15:43:51 -0800 Subject: [PATCH] Train without profiling --- examples/text_to_image/train_text_to_image_xla.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_xla.py b/examples/text_to_image/train_text_to_image_xla.py index 7ef74cf46..37883fe93 100644 --- a/examples/text_to_image/train_text_to_image_xla.py +++ b/examples/text_to_image/train_text_to_image_xla.py @@ -77,10 +77,11 @@ def start_training(self): dataloader_exception = True print(e) break - if step == measure_start_step and PROFILE_DIR is not None: - xm.wait_device_ops() - xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration) - last_time = time.time() + if step == measure_start_step: + last_time = time.time() + if PROFILE_DIR is not None: + xm.wait_device_ops() + xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration) loss = self.step_fn(batch["pixel_values"], batch["input_ids"]) self.global_step += 1 xm.mark_step()