diff --git a/tinker_cookbook/rl/metric_util.py b/tinker_cookbook/rl/metric_util.py
index e301c585..b71d0714 100644
--- a/tinker_cookbook/rl/metric_util.py
+++ b/tinker_cookbook/rl/metric_util.py
@@ -11,6 +11,7 @@
 from tinker_cookbook.rl.types import EnvGroupBuilder, RLDataset, TrajectoryGroup
 from tinker_cookbook.utils.misc_utils import all_same, dict_mean
 from tinker_cookbook.utils import logtree
+from tinker_cookbook.completers import TokenCompleter
 
 
 def _compute_by_group_metrics(trajectory_groups_P: List[TrajectoryGroup], good_thresh: float = 0.5):
@@ -107,7 +108,7 @@ def __init__(
         self,
         dataset: RLDataset,
         max_tokens: int,
-        name: str | None = None,
+        name: str = "test",
         num_groups_to_log: int = 4,
     ):
         self.env_group_builders_P = dataset_to_env_group_builders(dataset)
@@ -115,9 +116,7 @@ def __init__(
         self.name = name
         self.num_groups_to_log = num_groups_to_log
 
-    async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:
-        policy = TinkerTokenCompleter(sampling_client, max_tokens=self.max_tokens)
-
+    async def eval_token_completer(self, policy: TokenCompleter) -> dict[str, float]:
         async def run_group_rollout(builder, i):
             enable_logging = i < self.num_groups_to_log
             with logtree.optional_enable_logging(enable=enable_logging):
@@ -129,6 +128,9 @@ async def run_group_rollout(builder, i):
 
         taglist_P = [builder.logging_tags() for builder in self.env_group_builders_P]
         metrics = compute_trajectory_metrics(trajectory_groups_P, taglist_P)
-        if self.name is not None:
-            metrics = {f"{self.name}/{k}": v for k, v in metrics.items()}
+        metrics = {f"{self.name}/{k}": v for k, v in metrics.items()}
         return metrics
+
+    async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:
+        policy = TinkerTokenCompleter(sampling_client, max_tokens=self.max_tokens)
+        return await self.eval_token_completer(policy)
diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py
index 33c9f8ed..4e01fddc 100644
--- a/tinker_cookbook/rl/train.py
+++ b/tinker_cookbook/rl/train.py
@@ -273,7 +273,7 @@ async def run_single_evaluation(evaluator, cfg, i_batch, sampling_client):
         scope_name=f"Running evaluation {ev_name} {i_batch}",
     ):
         eval_metrics = await evaluator(sampling_client)
-        return {f"test/{k}": v for k, v in eval_metrics.items()}
+        return eval_metrics
 
 
 @scope
diff --git a/tinker_cookbook/supervised/nll_evaluator.py b/tinker_cookbook/supervised/nll_evaluator.py
index e2dd7d08..71d705f6 100644
--- a/tinker_cookbook/supervised/nll_evaluator.py
+++ b/tinker_cookbook/supervised/nll_evaluator.py
@@ -7,7 +7,8 @@
 
 
 class NLLEvaluator(TrainingClientEvaluator):
-    def __init__(self, data: list[tinker.Datum]):
+    def __init__(self, data: list[tinker.Datum], name: str = "test"):
+        self.name = name
         self.data = data
 
     async def __call__(self, training_client: tinker.TrainingClient) -> dict[str, float]:
@@ -16,9 +17,10 @@ async def __call__(self, training_client: tinker.TrainingClient) -> dict[str, fl
         logprobs = [x["logprobs"] for x in result.loss_fn_outputs]
         weights = [datum.loss_fn_inputs["weights"] for datum in self.data]
         nll = compute_mean_nll(logprobs, weights)
-        return {"nll": nll}
+        key = f"{self.name}/nll"
+        return {key: nll}
 
     @classmethod
-    def from_dataset(cls, dataset: SupervisedDataset) -> "NLLEvaluator":
+    def from_dataset(cls, dataset: SupervisedDataset, name: str = "test") -> "NLLEvaluator":
         all_data = list(itertools.chain(*[dataset.get_batch(i) for i in range(len(dataset))]))
-        return cls(all_data)
+        return cls(all_data, name=name)
diff --git a/tinker_cookbook/supervised/train.py b/tinker_cookbook/supervised/train.py
index 761fa86b..6bcaee76 100644
--- a/tinker_cookbook/supervised/train.py
+++ b/tinker_cookbook/supervised/train.py
@@ -120,7 +120,6 @@ async def run_evals(
             raise ValueError(f"Unknown evaluator type: {type(evaluator)}")
 
-        # Add test/ prefix to all metrics
-        metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
+        metrics.update(eval_metrics)
     return metrics
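
Taken together, these hunks move metric-name prefixing out of the training loops and into the evaluators themselves (default prefix "test", overridable via the new `name` argument), and split the RL evaluator so any TokenCompleter can be scored directly via `eval_token_completer`. The sketch below is illustrative only, not part of the patch: the RL evaluator's class name is not visible in the hunks, and names such as `rl_evaluator`, `val_dataset`, `sampling_client`, and `training_client` are assumptions.

    # Illustrative sketch of the post-change call path (assumed names, see note above).
    from tinker_cookbook.supervised.nll_evaluator import NLLEvaluator

    async def run_example_evals(rl_evaluator, sampling_client, training_client, val_dataset):
        # RL path: __call__ wraps the sampling client in a TinkerTokenCompleter and
        # delegates to eval_token_completer, so the evaluator now namespaces its own
        # metrics (keys look like "test/...") instead of rl/train.py adding the prefix.
        rl_metrics = await rl_evaluator(sampling_client)

        # Supervised path: NLLEvaluator now returns f"{name}/nll" rather than relying
        # on supervised/train.py to prepend "test/".
        nll_evaluator = NLLEvaluator.from_dataset(val_dataset, name="val")
        nll_metrics = await nll_evaluator(training_client)  # e.g. {"val/nll": ...}

        return {**rl_metrics, **nll_metrics}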