diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py
index 1f67fa7..addc636 100644
--- a/tinker_cookbook/rl/train.py
+++ b/tinker_cookbook/rl/train.py
@@ -46,6 +46,14 @@
 logger = logging.getLogger(__name__)
 
 
+def _get_evaluator_name(evaluator: SamplingClientEvaluator) -> str:
+    return (
+        evaluator.name
+        if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
+        else ""
+    )
+
+
 @contextmanager
 def _get_logtree_scope(
     log_path: str | None, num_groups_to_log: int, f_name: str, scope_name: str
@@ -254,6 +262,47 @@ class Config:
     num_groups_to_log: int = 4  # Number of groups to log per iteration (0 = disable logging)
 
 
+@scope
+async def run_evaluations_parallel(
+    evaluators: list[SamplingClientEvaluator],
+    sampling_client: tinker.SamplingClient,
+    cfg: Config,
+    i_batch: int,
+) -> dict[str, Any]:
+    """Run all evaluators in parallel and return aggregated metrics."""
+
+    async def run_single_evaluation(evaluator, cfg, i_batch, sampling_client):
+        ev_name = _get_evaluator_name(evaluator)
+        with _get_logtree_scope(
+            log_path=cfg.log_path,
+            num_groups_to_log=cfg.num_groups_to_log,
+            f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
+            scope_name=f"Running evaluation {ev_name} {i_batch}",
+        ):
+            eval_metrics = await evaluator(sampling_client)
+        return {f"test/{k}": v for k, v in eval_metrics.items()}
+
+    # Create tasks for all evaluators with names for better traceability
+    tasks = []
+    for i, evaluator in enumerate(evaluators):
+        ev_name = _get_evaluator_name(evaluator)
+        task = asyncio.create_task(
+            run_single_evaluation(evaluator, cfg, i_batch, sampling_client),
+            name=f"eval_{ev_name or i}_iteration_{i_batch:06d}",
+        )
+        tasks.append(task)
+
+    # Wait for all to complete
+    results = await asyncio.gather(*tasks)
+
+    # Merge all metrics
+    metrics = {}
+    for result in results:
+        metrics.update(result)
+
+    return metrics
+
+
 @scope
 async def do_sync_training_with_stream_minibatch(
     start_batch: int,
@@ -289,20 +338,10 @@ async def do_sync_training_with_stream_minibatch(
         # Run evaluations
         if (cfg.eval_every > 0 and i_batch % cfg.eval_every == 0) or i_batch == end_batch - 1:
             with timed("run_evals", metrics):
-                for evaluator in evaluators:
-                    ev_name = (
-                        evaluator.name
-                        if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
-                        else ""
-                    )
-                    with _get_logtree_scope(
-                        log_path=cfg.log_path,
-                        num_groups_to_log=cfg.num_groups_to_log,
-                        f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
-                        scope_name=f"Running evaluation {ev_name} {i_batch}",
-                    ):
-                        eval_metrics = await evaluator(sampling_client)
-                        metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
+                eval_metrics = await run_evaluations_parallel(
+                    evaluators, sampling_client, cfg, i_batch
+                )
+                metrics.update(eval_metrics)
 
         with _get_logtree_scope(
             cfg.log_path,
@@ -924,20 +963,10 @@ async def do_sync_training(
         # Run evaluations
         if cfg.eval_every > 0 and i_batch % cfg.eval_every == 0:
             with timed("run_evals", metrics):
-                for evaluator in evaluators:
-                    ev_name = (
-                        evaluator.name
-                        if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
-                        else ""
-                    )
-                    with _get_logtree_scope(
-                        log_path=cfg.log_path,
-                        num_groups_to_log=cfg.num_groups_to_log,
-                        f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
-                        scope_name=f"Running evaluation {ev_name} {i_batch}",
-                    ):
-                        eval_metrics = await evaluator(sampling_client)
-                        metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
+                eval_metrics = await run_evaluations_parallel(
+                    evaluators, sampling_client, cfg, i_batch
+                )
+                metrics.update(eval_metrics)
 
         # Get batch and sample trajectories
         env_group_builders_P = dataset.get_batch(i_batch)
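
For reference, a minimal standalone sketch of the create_task + gather fan-out that run_evaluations_parallel relies on. It is not part of the patch and assumes nothing about tinker's API: Evaluator, run_parallel, and the dummy slow/fast evaluators are hypothetical stand-ins for SamplingClientEvaluator, used only to show that the evaluations run concurrently and their metric dicts are merged.

# Illustrative sketch only; names here are hypothetical, not from the repository.
import asyncio
from typing import Any, Awaitable, Callable

Evaluator = Callable[[], Awaitable[dict[str, Any]]]


async def run_parallel(evaluators: list[Evaluator]) -> dict[str, Any]:
    # One named task per evaluator, awaited together; wall time is roughly
    # that of the slowest evaluator rather than the sum of all of them.
    tasks = [asyncio.create_task(ev(), name=f"eval_{i}") for i, ev in enumerate(evaluators)]
    results = await asyncio.gather(*tasks)
    merged: dict[str, Any] = {}
    for result in results:
        merged.update(result)  # later evaluators overwrite duplicate keys
    return merged


async def main() -> None:
    async def slow() -> dict[str, Any]:
        await asyncio.sleep(0.2)
        return {"test/accuracy": 0.9}

    async def fast() -> dict[str, Any]:
        await asyncio.sleep(0.1)
        return {"test/reward": 1.3}

    # Both dummy evaluators run concurrently; prints the merged metrics dict.
    print(await run_parallel([slow, fast]))


if __name__ == "__main__":
    asyncio.run(main())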