From 679c8f52aedb9b4cdf5d456a50d1d932d1cb9aac Mon Sep 17 00:00:00 2001
From: Simon Edwardsson
Date: Wed, 19 Jul 2023 21:52:44 +0100
Subject: [PATCH] Support for mocking functions in test YMLs (#4)
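
Tests can now declare, under a new `calls` key, the functions that the
code under test is expected to call. While a test runs, each declared
function is monkey-patched with a mock that records the keyword
arguments it receives and returns the canned `returns` value; the
original functions are restored afterwards. After evaluation,
collect_call_errors() reports declared arguments that were never
passed, calls that were never declared, and argument values that
differ from the expected ones.

As an illustration, a minimal test YML using a mocked function (this
mirrors examples/weather_functions/rainy.yml added below):

    expected:
      - "yes"
    input: I'm going to San Francisco today, do I need an umbrella?
    calls:
      - name: forecast.get_n_day_weather_forecast
        returns: It's rainy in San Francisco today.
        arguments:
          location: San Francisco
          num_days: 1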
---
 benchllm/cli/listener.py                | 40 ++++++++++++++++-
 benchllm/data_types.py                  | 24 ++++++++++
 benchllm/evaluator/evaluator.py         |  1 +
 benchllm/tester.py                      | 56 ++++++++++++++++++++++--
 benchllm/utils.py                       | 39 ++++++++++++++++-
 examples/weather_functions/default.yml  |  4 ++
 examples/weather_functions/eval.py      |  8 ++++
 examples/weather_functions/forecast.py  | 58 +++++++++++++++++++++++++
 examples/weather_functions/rainy.yml    | 10 +++++
 examples/weather_functions/sunny.yml    | 10 +++++
 examples/weather_functions/tomorrow.yml |  9 ++++
 11 files changed, 254 insertions(+), 5 deletions(-)
 create mode 100644 examples/weather_functions/default.yml
 create mode 100644 examples/weather_functions/eval.py
 create mode 100644 examples/weather_functions/forecast.py
 create mode 100644 examples/weather_functions/rainy.yml
 create mode 100644 examples/weather_functions/sunny.yml
 create mode 100644 examples/weather_functions/tomorrow.yml

diff --git a/benchllm/cli/listener.py b/benchllm/cli/listener.py
index f1f6135..026c9d9 100644
--- a/benchllm/cli/listener.py
+++ b/benchllm/cli/listener.py
@@ -10,9 +10,17 @@
 from rich.table import Table
 
 from benchllm.cache import MemoryCache
-from benchllm.data_types import Evaluation, FunctionID, Prediction, Test, TestFunction
+from benchllm.data_types import (
+    CallErrorType,
+    Evaluation,
+    FunctionID,
+    Prediction,
+    Test,
+    TestFunction,
+)
 from benchllm.evaluator import Evaluator
 from benchllm.listener import EvaluatorListener, TesterListener
+from benchllm.utils import collect_call_errors
 
 
 class ReportListener(TesterListener, EvaluatorListener):
@@ -107,7 +115,37 @@ def evaluate_prediction_ended(self, evaluation: Evaluation) -> None:
         else:
             typer.secho("F", fg=typer.colors.RED, bold=True, nl=False)
 
+    def handle_call_error(self, evaluations: list[Evaluation]) -> None:
+        predictions_with_calls = [
+            evaluation.prediction for evaluation in evaluations if evaluation.prediction.test.calls
+        ]
+        if not predictions_with_calls:
+            return
+
+        print_centered(" Call Warnings ")
+
+        for prediction in predictions_with_calls:
+            errors = collect_call_errors(prediction)
+            if not errors:
+                continue
+            relative_path = prediction.function_id.relative_str(self.root_dir)
+            print_centered(f" [yellow]{relative_path}[/yellow] :: [yellow]{prediction.test.file_path}[/yellow] ", "-")
+
+            for error in errors:
+                if error.error_type == CallErrorType.MISSING_ARGUMENT:
+                    print(
+                        f'[blue][bold]{error.function_name}[/bold][/blue] was never called with [blue][bold]"{error.argument_name}"[/bold][/blue]'
+                    )
+                elif error.error_type == CallErrorType.MISSING_FUNCTION:
+                    print(f"[blue][bold]{error.function_name}[/bold][/blue] was never declared")
+                elif error.error_type == CallErrorType.VALUE_MISMATCH:
+                    print(
+                        f'[blue][bold]{error.function_name}[/bold][/blue] was called with "{error.argument_name}=[red][bold]{error.actual_value}[/bold][/red]", expected "[green][bold]{error.expected_value}[/bold][/green]"'
+                    )
+
     def evaluate_ended(self, evaluations: list[Evaluation]) -> None:
+        self.handle_call_error(evaluations)
+
         failed = [evaluation for evaluation in evaluations if not evaluation.passed]
         total_test_time = (
             0.0 if self._eval_only else sum(evaluation.prediction.time_elapsed for evaluation in evaluations) or 0.0
diff --git a/benchllm/data_types.py b/benchllm/data_types.py
index 68aaa78..4f6fbb9 100644
--- a/benchllm/data_types.py
+++ b/benchllm/data_types.py
@@ -1,3 +1,4 @@
+from enum import Enum
 from pathlib import Path
 from typing import Any, Callable, Generic, Optional, TypeVar
 from uuid import uuid4
@@ -5,12 +6,20 @@
 from pydantic import BaseModel, Field
 
 
+class TestCall(BaseModel):
+    __test__ = False
+    name: str
+    arguments: dict[str, Any]
+    returns: Any
+
+
 class Test(BaseModel):
     __test__ = False
     id: str = Field(default_factory=lambda: str(uuid4()))
     input: Any
     expected: list[str]
     file_path: Optional[Path] = None
+    calls: Optional[list[TestCall]] = None
 
 
 class FunctionID(BaseModel):
@@ -45,6 +54,21 @@ class Prediction(BaseModel):
     output: str
     time_elapsed: float
     function_id: FunctionID
+    calls: dict[str, list[dict[str, Any]]] = {}
+
+
+class CallErrorType(str, Enum):
+    MISSING_FUNCTION = "Missing function"
+    MISSING_ARGUMENT = "Missing argument"
+    VALUE_MISMATCH = "Value mismatch"
+
+
+class CallError(BaseModel):
+    function_name: str
+    argument_name: Optional[str] = None
+    expected_value: Optional[Any] = None
+    actual_value: Optional[Any] = None
+    error_type: CallErrorType
 
 
 class Evaluation(BaseModel):
diff --git a/benchllm/evaluator/evaluator.py b/benchllm/evaluator/evaluator.py
index 87affb6..bfc214c 100644
--- a/benchllm/evaluator/evaluator.py
+++ b/benchllm/evaluator/evaluator.py
@@ -60,6 +60,7 @@ def _run_evaluation(self, prediction: Prediction) -> Evaluation:
         start = timer()
         match = self.evaluate_prediction(prediction)
         end = timer()
+
         evaluation = Evaluation(
             prediction=prediction, passed=isinstance(match, Evaluator.Match), eval_time_elapsed=end - start
         )
diff --git a/benchllm/tester.py b/benchllm/tester.py
index d3e48d3..5242f35 100644
--- a/benchllm/tester.py
+++ b/benchllm/tester.py
@@ -1,11 +1,13 @@
 import importlib.util
 import inspect
 import json
+import sys
 import uuid
+from contextlib import contextmanager
 from pathlib import Path
 from timeit import default_timer as timer
 from types import ModuleType
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Iterator, Optional, Union
 
 import yaml
 from pydantic import ValidationError, parse_obj_as
@@ -111,10 +113,19 @@ def run(self) -> list[Prediction]:
             self._broadcast_test_started(test)
 
             start = timer()
-            output = test_function.function(input)
+
+            # Set up mock functions for the test's declared calls.
+            calls_made: dict[str, Any] = {}
+            with setup_mocks(test, calls_made):
+                output = test_function.function(input)
+
             end = timer()
             prediction = Prediction(
-                test=test, output=output, time_elapsed=end - start, function_id=test_function.function_id
+                test=test,
+                output=output,
+                time_elapsed=end - start,
+                function_id=test_function.function_id,
+                calls=calls_made,
             )
             self._predictions.append(prediction)
             self._broadcast_test_ended(prediction)
@@ -226,8 +237,47 @@ def import_module_from_file(file_path: Path) -> ModuleType:
     if not module:
         raise Exception(f"Failed to load module from {file_path}")
 
+    # Temporarily add the file's directory to the system path so that the
+    # module can import other modules from the same directory.
+    file_module = file_path.resolve().parent
+    old_sys_path = sys.path.copy()
+    sys.path.append(str(file_module))
+
     # Execute the module.
     spec.loader.exec_module(module)
+
+    # Restore the system path.
+    sys.path = old_sys_path
+
     # Return the module.
     return module
+
+
+@contextmanager
+def setup_mocks(test: Test, calls_made: dict[str, Any]) -> Iterator[None]:
+    """Set up mock functions for the test calls."""
+    old_functions = []
+    for call in test.calls or []:
+        module_name, function_name = call.name.rsplit(".", 1)
+        # We need to import the module before we can mock the function.
+        module = importlib.import_module(module_name)
+        if not hasattr(module, function_name):
+            print(f"Function {function_name} doesn't exist in module {module_name}")
+            continue
+        old_functions.append((module, function_name, getattr(module, function_name)))
+
+        # Bind the loop variable as a default argument so that each mock keeps
+        # its own call; a plain closure would only ever see the last loop value.
+        def mock_function(*args: Any, _call=call, **kwargs: Any) -> Any:
+            assert not args, "Positional arguments are not supported"
+            calls_made.setdefault(_call.name, []).append(kwargs)
+            return _call.returns
+
+        setattr(module, function_name, mock_function)
+
+    try:
+        yield
+    finally:
+        # Restore the original functions.
+        for module, function_name, old_function in old_functions:
+            setattr(module, function_name, old_function)
diff --git a/benchllm/utils.py b/benchllm/utils.py
index b4f6f74..9e556d8 100644
--- a/benchllm/utils.py
+++ b/benchllm/utils.py
@@ -4,7 +4,7 @@
 
 import yaml
 
-from benchllm.data_types import Prediction
+from benchllm.data_types import CallError, CallErrorType, Prediction
 
 
 class DecoratorFinder(ast.NodeVisitor):
@@ -82,3 +82,40 @@ def load_prediction_files(paths: list[Path]) -> list[Prediction]:
             data = yaml.safe_load(file)
             predictions.append(Prediction(**data))
     return predictions
+
+
+def collect_call_errors(prediction: Prediction) -> list[CallError]:
+    """Collect mismatches between the calls a prediction made and the calls its test expected."""
+    if prediction.test.calls is None:
+        return []
+    errors = []
+    lookup = {call.name: call for call in prediction.test.calls}
+
+    for function_name, invocations in prediction.calls.items():
+        call = lookup.get(function_name)
+        if not call:
+            errors.append(CallError(function_name=function_name, error_type=CallErrorType.MISSING_FUNCTION))
+            continue
+
+        for arguments in invocations:
+            for argument_name in call.arguments:
+                if argument_name not in arguments:
+                    errors.append(
+                        CallError(
+                            function_name=function_name,
+                            argument_name=argument_name,
+                            error_type=CallErrorType.MISSING_ARGUMENT,
+                        )
+                    )
+            for argument_name, argument_value in arguments.items():
+                if argument_name in call.arguments and argument_value != call.arguments[argument_name]:
+                    errors.append(
+                        CallError(
+                            function_name=function_name,
+                            argument_name=argument_name,
+                            expected_value=call.arguments[argument_name],
+                            actual_value=argument_value,
+                            error_type=CallErrorType.VALUE_MISMATCH,
+                        )
+                    )
+    return errors
diff --git a/examples/weather_functions/default.yml b/examples/weather_functions/default.yml
new file mode 100644
index 0000000..8359aca
--- /dev/null
+++ b/examples/weather_functions/default.yml
@@ -0,0 +1,4 @@
+expected:
+  - "Yes"
+id: d80b8d57-9c50-44fe-b492-c3cab549d9f3
+input: Will it rain tomorrow?
diff --git a/examples/weather_functions/eval.py b/examples/weather_functions/eval.py
new file mode 100644
index 0000000..a59c1d2
--- /dev/null
+++ b/examples/weather_functions/eval.py
@@ -0,0 +1,8 @@
+from forecast import run
+
+import benchllm
+
+
+@benchllm.test()
+def eval(input: str):
+    return run(input)
diff --git a/examples/weather_functions/forecast.py b/examples/weather_functions/forecast.py
new file mode 100644
index 0000000..c2ad63a
--- /dev/null
+++ b/examples/weather_functions/forecast.py
@@ -0,0 +1,58 @@
+import json
+
+import openai
+
+
+def get_n_day_weather_forecast(location: str, num_days: int):
+    return f"The weather in {location} will be rainy for the next {num_days} days."
+
+
+def chain(prompt: list[dict], functions):
+    response = openai.ChatCompletion.create(
+        model="gpt-3.5-turbo-0613", messages=prompt, temperature=0.0, functions=functions
+    )
+
+    choice = response["choices"][0]
+    if choice.get("finish_reason") == "function_call":
+        function_call = choice["message"]["function_call"]
+        function_name = function_call["name"]
+        function_args = json.loads(function_call["arguments"])
+        fn = globals()[function_name]
+        output = fn(**function_args)
+        prompt.append({"role": "function", "name": function_name, "content": output})
+        return chain(prompt, functions)
+    else:
+        return choice["message"]["content"].strip()
+
+
+def run(question: str):
+    messages = [
+        {
+            "role": "user",
+            "content": "Only answer questions with 'yes', 'no' or 'unknown', you must not reply with anything else",
+        },
+        {"role": "system", "content": "Use the get_n_day_weather_forecast function for weather questions"},
+        {"role": "user", "content": question},
+    ]
+
+    functions = [
+        {
+            "name": "get_n_day_weather_forecast",
+            "description": "Get an N-day weather forecast",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    },
+                    "num_days": {
+                        "type": "integer",
+                        "description": "The number of days to forecast. E.g. 1 for today, 2 for tomorrow etc",
+                    },
+                },
+                "required": ["location", "num_days"],
+            },
+        },
+    ]
+    return chain(messages, functions)
diff --git a/examples/weather_functions/rainy.yml b/examples/weather_functions/rainy.yml
new file mode 100644
index 0000000..b0bbdd3
--- /dev/null
+++ b/examples/weather_functions/rainy.yml
@@ -0,0 +1,10 @@
+expected:
+  - "yes"
+id: 4b717b39-da96-4ca3-8aaf-070c7f11449c
+input: I'm going to San Francisco today, do I need an umbrella?
+calls:
+  - name: forecast.get_n_day_weather_forecast
+    returns: It's rainy in San Francisco today.
+    arguments:
+      location: San Francisco
+      num_days: 1
diff --git a/examples/weather_functions/sunny.yml b/examples/weather_functions/sunny.yml
new file mode 100644
index 0000000..c7f0534
--- /dev/null
+++ b/examples/weather_functions/sunny.yml
@@ -0,0 +1,10 @@
+expected:
+  - "no"
+id: 4b717b39-da96-4ca3-8aaf-070c7f11449c
+input: I live in London, can I expect rain today?
+calls:
+  - name: forecast.get_n_day_weather_forecast
+    returns: It's sunny in London today.
+    arguments:
+      location: London
+      num_days: 1
diff --git a/examples/weather_functions/tomorrow.yml b/examples/weather_functions/tomorrow.yml
new file mode 100644
index 0000000..e236ae9
--- /dev/null
+++ b/examples/weather_functions/tomorrow.yml
@@ -0,0 +1,9 @@
+expected:
+  - "yes"
+id: 4b717b39-da96-4ca3-8aaf-070c7f11449c
+input: Do I need a sun hat tomorrow?
+calls:
+  - name: forecast.get_n_day_weather_forecast
+    returns: It's sunny in your location tomorrow.
+    arguments:
+      num_days: 2