From f4a754951efce713f82cfd4a5469db64e42fcfe3 Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Fri, 31 Oct 2025 04:03:27 +0000 Subject: [PATCH 1/5] parallelize eval --- tinker_cookbook/rl/train.py | 72 ++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py index 1f67fa7..acd5ffd 100644 --- a/tinker_cookbook/rl/train.py +++ b/tinker_cookbook/rl/train.py @@ -254,6 +254,46 @@ class Config: num_groups_to_log: int = 4 # Number of groups to log per iteration (0 = disable logging) +@scope +async def run_evaluations_parallel( + evaluators: list[SamplingClientEvaluator], + sampling_client: tinker.SamplingClient, + cfg: Config, + i_batch: int, +) -> dict[str, Any]: + """Run all evaluators in parallel and return aggregated metrics.""" + async def run_single_evaluation(evaluator, cfg, i_batch, sampling_client): + ev_name = ( + evaluator.name + if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None + else "" + ) + with _get_logtree_scope( + log_path=cfg.log_path, + num_groups_to_log=cfg.num_groups_to_log, + f_name=f"eval_{ev_name}_iteration_{i_batch:06d}", + scope_name=f"Running evaluation {ev_name} {i_batch}", + ): + eval_metrics = await evaluator(sampling_client) + return {f"test/{k}": v for k, v in eval_metrics.items()} + + # Create tasks for all evaluators + tasks = [ + run_single_evaluation(evaluator, cfg, i_batch, sampling_client) + for evaluator in evaluators + ] + + # Wait for all to complete + results = await asyncio.gather(*tasks) + + # Merge all metrics + metrics = {} + for result in results: + metrics.update(result) + + return metrics + + @scope async def do_sync_training_with_stream_minibatch( start_batch: int, @@ -289,20 +329,8 @@ async def do_sync_training_with_stream_minibatch( # Run evaluations if (cfg.eval_every > 0 and i_batch % cfg.eval_every == 0) or i_batch == end_batch - 1: with timed("run_evals", metrics): - for evaluator in evaluators: - ev_name = ( - evaluator.name - if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None - else "" - ) - with _get_logtree_scope( - log_path=cfg.log_path, - num_groups_to_log=cfg.num_groups_to_log, - f_name=f"eval_{ev_name}_iteration_{i_batch:06d}", - scope_name=f"Running evaluation {ev_name} {i_batch}", - ): - eval_metrics = await evaluator(sampling_client) - metrics.update({f"test/{k}": v for k, v in eval_metrics.items()}) + eval_metrics = await run_evaluations_parallel(evaluators, sampling_client, cfg, i_batch) + metrics.update(eval_metrics) with _get_logtree_scope( cfg.log_path, @@ -924,20 +952,8 @@ async def do_sync_training( # Run evaluations if cfg.eval_every > 0 and i_batch % cfg.eval_every == 0: with timed("run_evals", metrics): - for evaluator in evaluators: - ev_name = ( - evaluator.name - if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None - else "" - ) - with _get_logtree_scope( - log_path=cfg.log_path, - num_groups_to_log=cfg.num_groups_to_log, - f_name=f"eval_{ev_name}_iteration_{i_batch:06d}", - scope_name=f"Running evaluation {ev_name} {i_batch}", - ): - eval_metrics = await evaluator(sampling_client) - metrics.update({f"test/{k}": v for k, v in eval_metrics.items()}) + eval_metrics = await run_evaluations_parallel(evaluators, sampling_client, cfg, i_batch) + metrics.update(eval_metrics) # Get batch and sample trajectories env_group_builders_P = dataset.get_batch(i_batch) From 667f41023bc0d52c233eec6605461f69aa2fe7de Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Fri, 31 Oct 2025 14:25:37 +0000 Subject: [PATCH 2/5] b --- tinker_cookbook/rl/train.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py index acd5ffd..f40090f 100644 --- a/tinker_cookbook/rl/train.py +++ b/tinker_cookbook/rl/train.py @@ -262,6 +262,7 @@ async def run_evaluations_parallel( i_batch: int, ) -> dict[str, Any]: """Run all evaluators in parallel and return aggregated metrics.""" + async def run_single_evaluation(evaluator, cfg, i_batch, sampling_client): ev_name = ( evaluator.name @@ -279,8 +280,7 @@ async def run_single_evaluation(evaluator, cfg, i_batch, sampling_client): # Create tasks for all evaluators tasks = [ - run_single_evaluation(evaluator, cfg, i_batch, sampling_client) - for evaluator in evaluators + run_single_evaluation(evaluator, cfg, i_batch, sampling_client) for evaluator in evaluators ] # Wait for all to complete @@ -329,7 +329,9 @@ async def do_sync_training_with_stream_minibatch( # Run evaluations if (cfg.eval_every > 0 and i_batch % cfg.eval_every == 0) or i_batch == end_batch - 1: with timed("run_evals", metrics): - eval_metrics = await run_evaluations_parallel(evaluators, sampling_client, cfg, i_batch) + eval_metrics = await run_evaluations_parallel( + evaluators, sampling_client, cfg, i_batch + ) metrics.update(eval_metrics) with _get_logtree_scope( @@ -952,7 +954,9 @@ async def do_sync_training( # Run evaluations if cfg.eval_every > 0 and i_batch % cfg.eval_every == 0: with timed("run_evals", metrics): - eval_metrics = await run_evaluations_parallel(evaluators, sampling_client, cfg, i_batch) + eval_metrics = await run_evaluations_parallel( + evaluators, sampling_client, cfg, i_batch + ) metrics.update(eval_metrics) # Get batch and sample trajectories From 56eb20126555aa941414667e584b00f801d6f634 Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Fri, 31 Oct 2025 18:00:12 +0000 Subject: [PATCH 3/5] b --- tinker_cookbook/rl/train.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py index f40090f..3d5880f 100644 --- a/tinker_cookbook/rl/train.py +++ b/tinker_cookbook/rl/train.py @@ -46,6 +46,15 @@ logger = logging.getLogger(__name__) + +def _get_evaluator_name(evaluator: SamplingClientEvaluator) -> str: + return ( + evaluator.name + if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None + else "" + ) + + @contextmanager def _get_logtree_scope( log_path: str | None, num_groups_to_log: int, f_name: str, scope_name: str @@ -264,11 +273,7 @@ async def run_evaluations_parallel( """Run all evaluators in parallel and return aggregated metrics.""" async def run_single_evaluation(evaluator, cfg, i_batch, sampling_client): - ev_name = ( - evaluator.name - if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None - else "" - ) + ev_name = _get_evaluator_name(evaluator) with _get_logtree_scope( log_path=cfg.log_path, num_groups_to_log=cfg.num_groups_to_log, @@ -278,10 +283,15 @@ async def run_single_evaluation(evaluator, cfg, i_batch, sampling_client): eval_metrics = await evaluator(sampling_client) return {f"test/{k}": v for k, v in eval_metrics.items()} - # Create tasks for all evaluators - tasks = [ - run_single_evaluation(evaluator, cfg, i_batch, sampling_client) for evaluator in evaluators - ] + # Create tasks for all evaluators with names for better traceability + tasks = [] + for i, evaluator in enumerate(evaluators): + ev_name = _get_evaluator_name(evaluator) + task = asyncio.create_task( + run_single_evaluation(evaluator, cfg, i_batch, sampling_client), + name=f"eval_task_{ev_name or i}_{i_batch}", + ) + tasks.append(task) # Wait for all to complete results = await asyncio.gather(*tasks) From a865cf7473cd98fc73ad4b9e5e6bb56847d19abf Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Fri, 31 Oct 2025 18:24:53 +0000 Subject: [PATCH 4/5] b b --- tinker_cookbook/rl/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py index 3d5880f..6b5a066 100644 --- a/tinker_cookbook/rl/train.py +++ b/tinker_cookbook/rl/train.py @@ -46,7 +46,6 @@ logger = logging.getLogger(__name__) - def _get_evaluator_name(evaluator: SamplingClientEvaluator) -> str: return ( evaluator.name From beedf782183f8ad94cc22db285a6252a7cd3d954 Mon Sep 17 00:00:00 2001 From: Ruiqi Zhong Date: Fri, 31 Oct 2025 18:28:22 +0000 Subject: [PATCH 5/5] b --- tinker_cookbook/rl/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py index 6b5a066..addc636 100644 --- a/tinker_cookbook/rl/train.py +++ b/tinker_cookbook/rl/train.py @@ -288,7 +288,7 @@ async def run_single_evaluation(evaluator, cfg, i_batch, sampling_client): ev_name = _get_evaluator_name(evaluator) task = asyncio.create_task( run_single_evaluation(evaluator, cfg, i_batch, sampling_client), - name=f"eval_task_{ev_name or i}_{i_batch}", + name=f"eval_{ev_name or i}_iteration_{i_batch:06d}", ) tasks.append(task)