Add agent evaluation sketch
pantonante committed Feb 20, 2024
1 parent 7eebb2c commit 2661074
Showing 8 changed files with 106 additions and 37 deletions.
continuous_eval/eval/__init__.py (1 addition, 1 deletion)
@@ -1,3 +1,3 @@
from continuous_eval.eval.modules import Tool, AgentModule, Module, Metric, Test
from continuous_eval.eval.pipeline import Pipeline, ModuleOutput
from continuous_eval.eval.pipeline import Pipeline, ModuleOutput, CalledTools
from continuous_eval.eval.dataset import Dataset
continuous_eval/eval/dataset.py (2 additions, 1 deletion)
@@ -5,10 +5,11 @@

import yaml

UUID = str
from continuous_eval.eval.types import UUID, ToolCall

_SAFE_DICT = {k: v for k, v in typing.__dict__.items() if not k.startswith("__")}
_SAFE_DICT["UUID"] = UUID
_SAFE_DICT["ToolCall"] = ToolCall


@dataclass(frozen=True)
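Registering ToolCall in _SAFE_DICT presumably lets dataset field type strings reference the new type (e.g. List[ToolCall] for reference tool calls). A rough sketch of that idea, using a hypothetical subset of _SAFE_DICT, since the actual resolution call site is outside this diff:

```python
from typing import List, get_args

from continuous_eval.eval.types import ToolCall  # added in this commit

# Hypothetical subset of _SAFE_DICT; the real dict is built from typing.__dict__.
_SAFE_DICT = {"List": List, "ToolCall": ToolCall}

# A type string coming from a dataset definition can then be resolved safely.
field_type = eval("List[ToolCall]", {"__builtins__": {}}, _SAFE_DICT)
assert get_args(field_type) == (ToolCall,)
```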
continuous_eval/eval/manager.py (40 additions, 23 deletions)
@@ -1,13 +1,17 @@
import json
import warnings
from pathlib import Path
from typing import Dict, List, Optional, get_origin
from enum import Enum
from typing import Any, List, Optional, get_origin

from loguru import logger

from continuous_eval.eval.dataset import Dataset, DatasetField
from continuous_eval.eval.pipeline import ModuleOutput, Pipeline
from continuous_eval.eval.result_types import EvaluationResults, MetricsResults, TestResults
from continuous_eval.eval.pipeline import CalledTools, ModuleOutput, Pipeline
from continuous_eval.eval.result_types import TOOL_PREFIX, EvaluationResults, MetricsResults, TestResults


class LogMode(Enum):
APPEND = 0
REPLACE = 1


class EvaluationManager:
@@ -87,22 +91,40 @@ def next_sample(self):

# Logging results

def log(self, key, value):
def log(
self,
module: str,
value: Any,
mode: LogMode = LogMode.REPLACE,
**kwargs,
):
# Make sure everything looks good
if self._pipeline is None:
raise ValueError("Pipeline not set")
assert type(value) == get_origin(self._pipeline.module_by_name(key).output) or isinstance(
value, self._pipeline.module_by_name(key).output
), f"Value {value} does not match expected type in the pipeline"
# assert type(value) == get_origin(
# self._pipeline.module_by_name(module).output
# ) or isinstance(
# value, self._pipeline.module_by_name(module).output
# ), f"Value {value} does not match expected type in the pipeline"
if not self._is_running:
raise ValueError("Cannot log when not running")
if key not in self._eval_results.results[self._idx]:
raise ValueError(f"Key {key} not found, review your pipeline")
if module not in self._eval_results.results[self._idx]:
raise ValueError(f"module {module} not found, review your pipeline")

if isinstance(self._eval_results.results[self._idx][key], dict):
self._eval_results.results[self._idx][key].update(value)
if kwargs and "tool_args" in kwargs:
key = f"{TOOL_PREFIX}{module}"
self._eval_results.results[self._idx][key].append({"name": value, "kwargs": kwargs["tool_args"]})
else:
self._eval_results.results[self._idx][key] = value
if mode == LogMode.REPLACE:
self._eval_results.results[self._idx][module] = value
elif mode == LogMode.APPEND:
if not isinstance(self._eval_results.results[self._idx][module], list):
if isinstance(value, list):
self._eval_results.results[self._idx][module].extend(value)
else:
self._eval_results.results[self._idx][module].append(value)
else:
self._eval_results.results[self._idx][module].add(value)

# Evaluate

@@ -115,6 +137,10 @@ def _prepare(self, module, metric):
elif isinstance(val, ModuleOutput):
module_name = module.name if val.module is None else val.module.name
kwargs[key] = [val(x[module_name]) for x in self._eval_results.results]
elif isinstance(val, CalledTools):
module_name = module.name if val.module is None else val.module.name
val_key = f"{TOOL_PREFIX}{module_name}"
kwargs[key] = [val(x[val_key]) for x in self._eval_results.results]
else:
raise ValueError(f"Invalid promised parameter {key}={val}")
return kwargs
@@ -138,15 +164,6 @@ def aggregate_eval_results(self):
assert all(
[len(module_res) == len(self.dataset.data) for module_res in self._metrics_results.results.values()]
), "Evaluation is not complete"
# aggregated_results = dict()
# for module in self._pipeline.modules:
# if module.eval is not None:
# for metric in module.eval:
# aggregated_results
# if metric.aggregate is not None:
# self._eval_results[module.name][metric.name] = metric.aggregate(
# self._eval_results[module.name][metric.name]
# )
return self._metrics_results.results

# Tests
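A minimal usage sketch of the reworked log API; the manager instance and the module names ("retriever" as a plain module, "planner" as an agent module) are assumptions for illustration, not part of this commit:

```python
from continuous_eval.eval.manager import LogMode  # new in this commit

# Assumed: `eval_manager` is an EvaluationManager already configured with a
# pipeline and currently iterating samples; that setup is outside this diff.

# REPLACE (the default mode) overwrites the module's output slot for the sample.
eval_manager.log("retriever", ["chunk-1", "chunk-2"], mode=LogMode.REPLACE)

# Passing tool_args records a tool call: the entry lands in the hidden
# "_tool__planner" slot as {"name": "search", "kwargs": {...}}.
eval_manager.log("planner", "search", tool_args={"query": "weather in Paris"})
```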
continuous_eval/eval/modules.py (0 additions, 2 deletions)
@@ -36,5 +36,3 @@ def __post_init__(self):
@dataclass(frozen=True, eq=True)
class AgentModule(Module):
tools: Optional[List[Tool]] = field(default=None)
reference_tool_calls: Optional[DatasetField] = field(default=None)
is_recursive: bool = field(default=False)
continuous_eval/eval/pipeline.py (15 additions, 10 deletions)
@@ -1,13 +1,23 @@
from dataclasses import dataclass, field
from typing import Any, List, Optional, Set, Tuple, get_origin
from typing import Any, Callable, List, Optional, Set, Tuple

from continuous_eval.eval.dataset import Dataset, DatasetField
from continuous_eval.eval.modules import Module
from continuous_eval.eval.utils import type_hint_to_str


@dataclass
class ModuleOutput:
selector: callable = field(default=lambda x: x)
selector: Callable = field(default=lambda x: x)
module: Optional[Module] = None

def __call__(self, *args: Any) -> Any:
return self.selector(*args)


@dataclass
class CalledTools:
selector: Callable = field(default=lambda x: x)
module: Optional[Module] = None

def __call__(self, *args: Any) -> Any:
@@ -43,6 +53,8 @@ def module_by_name(self, name: str) -> Module:

def get_metric(self, module_name: str, metric_name: str):
module = self.module_by_name(module_name)
if module.eval is None:
raise ValueError(f"Module {module_name} has no metrics")
try:
metric = [m for m in module.eval if m.name == metric_name][0]
except IndexError:
@@ -54,17 +66,10 @@ def _validate_modules(self):
for module in self._modules:
if module.name in names:
raise ValueError(f"Module {module.name} already exists")
if module.reference is not None:
assert (
module.reference is not None and module.reference in self._dataset.fields
), f"Field {module.reference.name} not found"
names.add(module.name)
if self._dataset is not None and module.expected_output is not None:
if get_origin(module.output) != get_origin(self._dataset.getattr(self, module.expected_output).type):
raise ValueError(f"Field {module.output} does not match expected type in the dataset.")

def _build_graph(self):
nodes = {m.name: m for m in self._modules}
nodes = {m.name for m in self._modules}
edges = set()
dataset_edges = set()
for module in self._modules:
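A small sketch of CalledTools acting as a selector over a logged tool-call list, assuming its __call__ applies the selector the same way ModuleOutput does (the body is collapsed in this diff); the tool names are illustrative:

```python
from continuous_eval.eval.pipeline import CalledTools

tool_log = [
    {"name": "search", "kwargs": {"query": "weather in Paris"}},
    {"name": "calculator", "kwargs": {"expression": "2 + 2"}},
]

# The default selector is the identity, so a metric wired to CalledTools()
# would receive the raw ToolCall list for that module.
assert CalledTools()(tool_log) == tool_log

# A custom selector can narrow what the metric sees, e.g. only the tool names.
names_only = CalledTools(selector=lambda calls: [c["name"] for c in calls])
assert names_only(tool_log) == ["search", "calculator"]
```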
continuous_eval/eval/result_types.py (5 additions, 0 deletions)
@@ -4,9 +4,12 @@
from pathlib import Path
from typing import Dict, List, Optional

from continuous_eval.eval.modules import AgentModule
from continuous_eval.eval.pipeline import Pipeline
from continuous_eval.eval.utils import instantiate_type

TOOL_PREFIX = "_tool__"


class EvaluationResults:
def __init__(self, pipeline: Optional[Pipeline] = None) -> None:
@@ -28,6 +31,8 @@ def _build_empty_samples(self, pipeline: Pipeline):
empty_samples = dict()
for module in pipeline.modules:
empty_samples[module.name] = instantiate_type(module.output)
if isinstance(module, AgentModule):
empty_samples[f"{TOOL_PREFIX}{module.name}"] = list()
return empty_samples

def save(self, filepath: Path):
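With this change an AgentModule gets two slots per sample: one for its own output and one, keyed with TOOL_PREFIX, that the manager appends tool calls to. An illustrative shape for a hypothetical agent module named "planner" whose declared output type is str (so instantiate_type is assumed to yield an empty string):

```python
empty_sample = {
    "planner": "",         # instantiate_type(module.output) for the module itself
    "_tool__planner": [],  # TOOL_PREFIX + module name, later filled with ToolCall dicts
}
```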
continuous_eval/eval/types.py (8 additions, 0 deletions)
@@ -0,0 +1,8 @@
from typing import Any, Dict, TypedDict

UUID = str


class ToolCall(TypedDict):
name: str
kwargs: Dict[str, Any]
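ToolCall is a plain TypedDict, so a logged tool call is just a dict; a minimal illustrative instance:

```python
from continuous_eval.eval.types import ToolCall

call: ToolCall = {"name": "search", "kwargs": {"query": "weather in Paris"}}
```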
continuous_eval/metrics/tools/match.py (35 additions, 0 deletions)
@@ -0,0 +1,35 @@
from typing import List

from continuous_eval.eval.types import ToolCall
from continuous_eval.metrics.base import Metric


class ToolDeterministicMatch(Metric):
def __init__(self, order_sensitive: bool = False) -> None:
super().__init__()
self._order_sensitive = order_sensitive

def __call__(self, tools: List[ToolCall], ground_truths: List[ToolCall], **kwargs):
if self._order_sensitive:
# When order matters, compare tool executions directly in sequence.
num_correct = sum(
1
for i, tool in enumerate(tools)
if i < len(ground_truths)
and tool["name"] == ground_truths[i]["name"]
and tool["kwargs"] == ground_truths[i]["kwargs"]
)
else:
# Convert ground_truth to a format that's easy to check for "contains"
ground_truth_set = {
frozenset(tool.items()) for tool in [{"name": tool["name"], **tool["kwargs"]} for tool in ground_truths]
}
# Score
num_correct, matched_executions = 0, set()
for tool in tools:
tool_set = frozenset({"name": tool["name"], **tool["kwargs"]}.items())
if tool_set in ground_truth_set and tool_set not in matched_executions:
num_correct += 1
matched_executions.add(tool_set)

return {"num_correct": num_correct, "score": num_correct / len(ground_truths)}
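A usage sketch for the new metric; the tool names and arguments are made up:

```python
from continuous_eval.metrics.tools.match import ToolDeterministicMatch

metric = ToolDeterministicMatch()  # order-insensitive by default

tools = [
    {"name": "search", "kwargs": {"query": "weather in Paris"}},
    {"name": "calculator", "kwargs": {"expression": "2 + 2"}},
]
ground_truths = [{"name": "search", "kwargs": {"query": "weather in Paris"}}]

print(metric(tools=tools, ground_truths=ground_truths))
# -> {'num_correct': 1, 'score': 1.0}
```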
