85 changes: 57 additions & 28 deletions tinker_cookbook/rl/train.py
@@ -46,6 +46,14 @@
 logger = logging.getLogger(__name__)


+def _get_evaluator_name(evaluator: SamplingClientEvaluator) -> str:
+    return (
+        evaluator.name
+        if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
+        else ""
+    )
+
+
 @contextmanager
 def _get_logtree_scope(
     log_path: str | None, num_groups_to_log: int, f_name: str, scope_name: str
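Editorial note, not part of the diff: the helper above returns a non-empty name only for `RLTestSetEvaluator` instances that set one; every other evaluator falls back to an empty string. A minimal standalone sketch of how that fallback shows up in the logtree file names and task names built later in this diff (the evaluator names and batch index below are invented for illustration):

```python
# Illustrative only -- evaluator names and the batch index are made up.
i_batch = 7
for i, ev_name in enumerate(["math_test_set", ""]):
    print(f"f_name: eval_{ev_name}_iteration_{i_batch:06d}")
    print(f"task:   eval_{ev_name or i}_iteration_{i_batch:06d}")
# f_name: eval_math_test_set_iteration_000007
# task:   eval_math_test_set_iteration_000007
# f_name: eval__iteration_000007
# task:   eval_1_iteration_000007
```

The `ev_name or i` fallback in the task name keeps unnamed evaluators distinguishable when inspecting running asyncio tasks.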
@@ -254,6 +262,47 @@ class Config:
     num_groups_to_log: int = 4  # Number of groups to log per iteration (0 = disable logging)


+@scope
+async def run_evaluations_parallel(
+    evaluators: list[SamplingClientEvaluator],
+    sampling_client: tinker.SamplingClient,
+    cfg: Config,
+    i_batch: int,
+) -> dict[str, Any]:
+    """Run all evaluators in parallel and return aggregated metrics."""
+
+    async def run_single_evaluation(evaluator, cfg, i_batch, sampling_client):
+        ev_name = _get_evaluator_name(evaluator)
+        with _get_logtree_scope(
+            log_path=cfg.log_path,
+            num_groups_to_log=cfg.num_groups_to_log,
+            f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
+            scope_name=f"Running evaluation {ev_name} {i_batch}",
+        ):
+            eval_metrics = await evaluator(sampling_client)
+            return {f"test/{k}": v for k, v in eval_metrics.items()}
+
+    # Create tasks for all evaluators with names for better traceability
+    tasks = []
+    for i, evaluator in enumerate(evaluators):
+        ev_name = _get_evaluator_name(evaluator)
+        task = asyncio.create_task(
+            run_single_evaluation(evaluator, cfg, i_batch, sampling_client),
+            name=f"eval_{ev_name or i}_iteration_{i_batch:06d}",
+        )
+        tasks.append(task)
+
+    # Wait for all to complete
+    results = await asyncio.gather(*tasks)
+
+    # Merge all metrics
+    metrics = {}
+    for result in results:
+        metrics.update(result)
+
+    return metrics
+
+
 @scope
 async def do_sync_training_with_stream_minibatch(
     start_batch: int,
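Editorial note, not part of the diff: the new helper replaces the sequential per-evaluator loop (removed in the two hunks below) with `asyncio.create_task` plus `asyncio.gather`, so evaluators run concurrently and their metrics are merged into one dict. A self-contained sketch of that pattern with stand-in evaluators; the names, delays, and metric keys are invented and this is not the repo's API:

```python
import asyncio
import time


async def fake_evaluator(name: str, delay: float) -> dict[str, float]:
    # Stand-in for `await evaluator(sampling_client)`.
    await asyncio.sleep(delay)
    return {f"{name}/accuracy": 0.5}


async def main() -> None:
    start = time.perf_counter()
    specs = [("math", 0.3), ("code", 0.5)]
    tasks = [
        asyncio.create_task(fake_evaluator(name, delay), name=f"eval_{name}")
        for name, delay in specs
    ]
    # gather returns results in input order and propagates the first exception it sees.
    results = await asyncio.gather(*tasks)
    metrics: dict[str, float] = {}
    for result in results:
        metrics.update({f"test/{k}": v for k, v in result.items()})
    # Wall-clock time tracks the slowest evaluator (~0.5s), not the sum (~0.8s).
    print(metrics, f"{time.perf_counter() - start:.2f}s")


if __name__ == "__main__":
    asyncio.run(main())
```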
@@ -289,20 +338,10 @@ async def do_sync_training_with_stream_minibatch(
         # Run evaluations
         if (cfg.eval_every > 0 and i_batch % cfg.eval_every == 0) or i_batch == end_batch - 1:
             with timed("run_evals", metrics):
-                for evaluator in evaluators:
-                    ev_name = (
-                        evaluator.name
-                        if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
-                        else ""
-                    )
-                    with _get_logtree_scope(
-                        log_path=cfg.log_path,
-                        num_groups_to_log=cfg.num_groups_to_log,
-                        f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
-                        scope_name=f"Running evaluation {ev_name} {i_batch}",
-                    ):
-                        eval_metrics = await evaluator(sampling_client)
-                        metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
+                eval_metrics = await run_evaluations_parallel(
+                    evaluators, sampling_client, cfg, i_batch
+                )
+                metrics.update(eval_metrics)

         with _get_logtree_scope(
             cfg.log_path,
@@ -924,20 +963,10 @@ async def do_sync_training(
         # Run evaluations
         if cfg.eval_every > 0 and i_batch % cfg.eval_every == 0:
             with timed("run_evals", metrics):
-                for evaluator in evaluators:
-                    ev_name = (
-                        evaluator.name
-                        if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None
-                        else ""
-                    )
-                    with _get_logtree_scope(
-                        log_path=cfg.log_path,
-                        num_groups_to_log=cfg.num_groups_to_log,
-                        f_name=f"eval_{ev_name}_iteration_{i_batch:06d}",
-                        scope_name=f"Running evaluation {ev_name} {i_batch}",
-                    ):
-                        eval_metrics = await evaluator(sampling_client)
-                        metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
+                eval_metrics = await run_evaluations_parallel(
+                    evaluators, sampling_client, cfg, i_batch
+                )
+                metrics.update(eval_metrics)

         # Get batch and sample trajectories
         env_group_builders_P = dataset.get_batch(i_batch)
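Editorial note, not part of the diff: a tiny worked example of the two eval-scheduling conditions kept as context in the last two hunks; the `eval_every` and `end_batch` values are made up. The streaming-minibatch path also evaluates on the final batch, while `do_sync_training` only evaluates every `eval_every` batches.

```python
# Made-up values, for illustration only.
eval_every, end_batch = 4, 10

stream_minibatch = [
    i for i in range(end_batch)
    if (eval_every > 0 and i % eval_every == 0) or i == end_batch - 1
]
plain_sync = [i for i in range(end_batch) if eval_every > 0 and i % eval_every == 0]

print(stream_minibatch)  # [0, 4, 8, 9] -- includes the final batch
print(plain_sync)        # [0, 4, 8]
```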