From ca33081570855a55bd01ccf0b7580308d1fc2891 Mon Sep 17 00:00:00 2001
From: Kenny Yu
Date: Tue, 11 Nov 2025 23:56:14 +0000
Subject: [PATCH] [tinker-cookbook] supervised: add tracing annotations

This adds tracing annotations to make it easier to identify where time
is spent in supervised fine-tuning (SFT) workloads.
---
 tinker_cookbook/checkpoint_utils.py |  9 +++++++++
 tinker_cookbook/supervised/train.py | 44 +++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/tinker_cookbook/checkpoint_utils.py b/tinker_cookbook/checkpoint_utils.py
index 589d436f..091e8f7b 100644
--- a/tinker_cookbook/checkpoint_utils.py
+++ b/tinker_cookbook/checkpoint_utils.py
@@ -7,12 +7,14 @@
 import tinker
 
 from tinker_cookbook.utils.file_utils import read_jsonl
+from tinker_cookbook.utils.trace import get_scope_context, scope
 
 CHECKPOINTS_BASE_NAME = "checkpoints.jsonl"
 
 logger = logging.getLogger(__name__)
 
 
+@scope
 def load_checkpoints_file(log_dir: str) -> list[dict[str, Any]]:
     checkpoint_path = os.path.join(log_dir, CHECKPOINTS_BASE_NAME)
     if not os.path.exists(checkpoint_path):
@@ -20,9 +22,12 @@ def load_checkpoints_file(log_dir: str) -> list[dict[str, Any]]:
         return []
 
     logger.info(f"Reading checkpoints from {checkpoint_path}")
+    context = get_scope_context()
+    context.attributes["checkpoint_path"] = checkpoint_path
     return read_jsonl(checkpoint_path)
 
 
+@scope
 def get_last_checkpoint(log_dir: str, required_key: str = "state_path") -> dict[str, Any] | None:
     """
     Get the last checkpoint from the checkpoints.jsonl file in the specified log directory.
@@ -49,6 +54,7 @@ def get_last_checkpoint(log_dir: str, required_key: str = "state_path") -> dict[
     return None
 
 
+@scope
 async def save_checkpoint_async(
     training_client: tinker.TrainingClient,
     name: str,
@@ -72,6 +78,8 @@ async def save_checkpoint_async(
 
     results = {k: await v.result_async() for k, v in futures.items()}
     paths = {k + "_path": v.path for k, v in results.items()}
+    context = get_scope_context()
+    context.attributes.update(paths)
     logger.info(f"Saved checkpoints: {paths}")
     full_dict = {"name": name, **loop_state, **paths}
     with open(os.path.join(log_path, "checkpoints.jsonl"), "a") as f:
@@ -80,6 +88,7 @@ async def save_checkpoint_async(
     return paths
 
 
+@scope
 def save_checkpoint(
     training_client: tinker.TrainingClient,
     name: str,
diff --git a/tinker_cookbook/supervised/train.py b/tinker_cookbook/supervised/train.py
index 761fa86b..f6fb273c 100644
--- a/tinker_cookbook/supervised/train.py
+++ b/tinker_cookbook/supervised/train.py
@@ -16,6 +16,7 @@
 import chz
 import tinker
 from tinker.lib.public_interfaces import APIFuture
+
 from tinker_cookbook import checkpoint_utils
 from tinker_cookbook.display import colorize_example
 from tinker_cookbook.eval.evaluators import (
@@ -31,6 +32,7 @@
 from tinker_cookbook.utils import ml_log
 from tinker_cookbook.utils.lr_scheduling import compute_schedule_lr_multiplier
 from tinker_cookbook.utils.misc_utils import timed
+from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init
 
 logger = logging.getLogger(__name__)
 
@@ -72,6 +74,8 @@ class Config:
     wandb_project: str | None = None
     wandb_name: str | None = None
 
+    enable_trace: bool = False
+
 
 @dataclass
 class SubmittedBatch:
@@ -87,6 +91,7 @@ class SubmittedBatch:
     infrequent_eval_metrics: dict[str, float] | None = None
 
 
+@scope
 async def run_evals(
     evaluators: list[Evaluator],
     training_client: tinker.TrainingClient,
@@ -102,29 +107,42 @@ async def run_evals(
     checkpoint.
     Returned metrics are prefixed with ``test/`` so they can be logged
     next to the same-step training metrics.
     """
+    context = get_scope_context()
+    context.attributes["step"] = step
+
     metrics = {}
     sampling_client = None
-    for evaluator in evaluators:
+    @scope
+    async def run_evaluator(evaluator: Evaluator) -> dict[str, float]:
+        context = get_scope_context()
+        context.attributes["step"] = step
+        context.attributes["evaluator_name"] = type(evaluator).__name__
         if isinstance(evaluator, TrainingClientEvaluator):
-            eval_metrics = await evaluator(training_client)
+            context.attributes["evaluator_type"] = "TrainingClientEvaluator"
+            return await evaluator(training_client)
         elif isinstance(evaluator, SamplingClientEvaluator):
+            context.attributes["evaluator_type"] = "SamplingClientEvaluator"
             # Create sampling client lazily, only when needed
+            nonlocal sampling_client
             if sampling_client is None:
                 # Snapshot the current pre-step weights and create a new sampling client.
                 sampling_client = await training_client.save_weights_and_get_sampling_client_async(
                     f"evals_step_{step}"
                 )
-            eval_metrics = await evaluator(sampling_client)
+            return await evaluator(sampling_client)
         else:
             raise ValueError(f"Unknown evaluator type: {type(evaluator)}")
+
+    for evaluator in evaluators:
+        eval_metrics = await run_evaluator(evaluator)
         # Add test/ prefix to all metrics
         metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
     return metrics
 
 
+@scope
 async def main(config: Config):
     """Run the standard supervised learning loop used by the supervised recipes.
 
@@ -156,6 +174,18 @@ async def main(config: Config):
         config=config,
         do_configure_logging_module=True,
     )
+    if config.enable_trace:
+        # Get and rename the current (main) task
+        current_task = asyncio.current_task()
+        if current_task is not None:
+            current_task.set_name("main")
+        trace_events_path = os.path.join(config.log_path, "trace_events.jsonl")
+        logger.info(f"Tracing is enabled. Trace events will be saved to {trace_events_path}")
+        logger.info(
+            f"Run `python tinker_cookbook/utils/trace.py {trace_events_path} trace.json` and visualize in chrome://tracing or https://ui.perfetto.dev/"
+        )
+        trace_init(output_file=trace_events_path)
+
     service_client = tinker.ServiceClient(base_url=config.base_url)
     load_state_path: str | None = (
         resume_info["state_path"] if resume_info else config.load_checkpoint_path
@@ -192,8 +222,12 @@ async def main(config: Config):
         f"Training for {n_batches} batches x {config.num_epochs} epochs = {n_batches * config.num_epochs} steps"
     )
 
+    @scope
     async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
         step = epoch_idx * n_batches + batch_idx
+        context = get_scope_context()
+        context.attributes["step"] = step
+
         batch_start_time = time.time()
         metrics: dict[str, int | float | str] = {"epoch": epoch_idx}
         metrics["progress"] = step / progress_denominator
@@ -250,7 +284,11 @@ async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
             infrequent_eval_metrics=infrequent_eval_metrics,
         )
 
+    @scope
     async def finish_batch(submitted: SubmittedBatch):
+        context = get_scope_context()
+        context.attributes["step"] = submitted.step
+
         metrics = submitted.metrics
         metrics["progress"] = min((submitted.step + 1) / progress_denominator, 1.0)
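
Usage sketch for reviewers: a minimal, self-contained example of how the
annotations above compose. It assumes tinker-cookbook is installed and uses
only the API surface visible in this diff (trace_init, the @scope decorator,
and get_scope_context); the load_batch function and the output path are
hypothetical stand-ins, not part of this change.

    import asyncio
    import os
    import tempfile

    from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init


    @scope
    async def load_batch(batch_idx: int) -> None:
        # Attach attributes so this span is searchable in the trace viewer,
        # mirroring how train.py records "step" on each scope.
        context = get_scope_context()
        context.attributes["batch_idx"] = batch_idx
        await asyncio.sleep(0.01)  # stand-in for real work


    async def main() -> None:
        # Mirror the setup in train.py: name the root task, then initialize
        # tracing with a JSONL output file.
        task = asyncio.current_task()
        if task is not None:
            task.set_name("main")
        trace_init(output_file=os.path.join(tempfile.gettempdir(), "trace_events.jsonl"))
        for i in range(3):
            await load_batch(i)


    asyncio.run(main())

As the log messages added in main() describe, the resulting
trace_events.jsonl can then be converted with
`python tinker_cookbook/utils/trace.py trace_events.jsonl trace.json` and
opened in chrome://tracing or https://ui.perfetto.dev/.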