Support for mocking functions in test YMLs (#4)
simedw committed Jul 19, 2023
1 parent 3933993 commit 679c8f5
Showing 11 changed files with 254 additions and 5 deletions.
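For orientation: a test YAML can now declare, under a `calls` key, functions that should be swapped out for canned return values while the test runs (see examples/weather_functions/rainy.yml further down). Expressed through the pydantic models this commit adds, that YAML corresponds roughly to the following sketch (constructed here only for illustration, mirroring rainy.yml):

from benchllm.data_types import Test, TestCall

test = Test(
    input="I'm going to San Francisco today, do I need an umbrella?",
    expected=["yes"],
    calls=[
        TestCall(
            # fully qualified "module.function" to patch while the test runs
            name="forecast.get_n_day_weather_forecast",
            # keyword arguments the function is expected to be called with
            arguments={"location": "San Francisco", "num_days": 1},
            # canned value returned instead of executing the real function
            returns="It's rainy in San Francisco today.",
        )
    ],
)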
40 changes: 39 additions & 1 deletion benchllm/cli/listener.py
@@ -10,9 +10,17 @@
from rich.table import Table

from benchllm.cache import MemoryCache
from benchllm.data_types import Evaluation, FunctionID, Prediction, Test, TestFunction
from benchllm.data_types import (
CallErrorType,
Evaluation,
FunctionID,
Prediction,
Test,
TestFunction,
)
from benchllm.evaluator import Evaluator
from benchllm.listener import EvaluatorListener, TesterListener
from benchllm.utils import collect_call_errors


class ReportListener(TesterListener, EvaluatorListener):
@@ -107,7 +115,37 @@ def evaluate_prediction_ended(self, evaluation: Evaluation) -> None:
else:
typer.secho("F", fg=typer.colors.RED, bold=True, nl=False)

def handle_call_error(self, evaluations: list[Evaluation]) -> None:
predictions_with_calls = [
evaluation.prediction for evaluation in evaluations if evaluation.prediction.test.calls
]
if not predictions_with_calls:
return

print_centered(" Call Warnings ")

for prediction in predictions_with_calls:
errors = collect_call_errors(prediction)
if not errors:
continue
relative_path = prediction.function_id.relative_str(self.root_dir)
print_centered(f" [yellow]{relative_path}[/yellow] :: [yellow]{prediction.test.file_path}[/yellow] ", "-")

for error in errors:
if error.error_type == CallErrorType.MISSING_ARGUMENT:
print(
f'[blue][bold]{error.function_name}[/bold][/blue] was never called with [blue][bold]"{error.argument_name}"[/bold][/blue]'
)
elif error.error_type == CallErrorType.MISSING_FUNCTION:
print(f"[blue][bold]{error.function_name}[/bold][/blue] was never declared")
elif error.error_type == CallErrorType.VALUE_MISMATCH:
print(
f'[blue][bold]{error.function_name}[/bold][/blue] was called with "{error.argument_name}=[red][bold]{error.actual_value}[/bold][/red]", expected "[green][bold]{error.expected_value}[/bold][/green]"'
)

def evaluate_ended(self, evaluations: list[Evaluation]) -> None:
self.handle_call_error(evaluations)

failed = [evaluation for evaluation in evaluations if not evaluation.passed]
total_test_time = (
0.0 if self._eval_only else sum(evaluation.prediction.time_elapsed for evaluation in evaluations) or 0.0
24 changes: 24 additions & 0 deletions benchllm/data_types.py
@@ -1,16 +1,25 @@
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Generic, Optional, TypeVar
from uuid import uuid4

from pydantic import BaseModel, Field


class TestCall(BaseModel):
__test__ = False
name: str
arguments: dict[str, Any]
returns: Any


class Test(BaseModel):
__test__ = False
id: str = Field(default_factory=lambda: str(uuid4()))
input: Any
expected: list[str]
file_path: Optional[Path] = None
calls: Optional[list[TestCall]] = None


class FunctionID(BaseModel):
@@ -45,6 +54,21 @@ class Prediction(BaseModel):
output: str
time_elapsed: float
function_id: FunctionID
calls: dict[str, list[dict[str, Any]]] = {}


class CallErrorType(str, Enum):
MISSING_FUNCTION = "Missing function"
MISSING_ARGUMENT = "Missing argument"
VALUE_MISMATCH = "Value mismatch"


class CallError(BaseModel):
function_name: str
argument_name: Optional[str] = None
expected_value: Optional[Any] = None
actual_value: Optional[Any] = None
error_type: CallErrorType


class Evaluation(BaseModel):
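The new Prediction.calls field above stores what the mocked functions were actually invoked with, keyed by the same "module.function" names used in TestCall.name. A hypothetical recorded value for the rainy.yml test (shape inferred from setup_mocks in tester.py below) would look like:

recorded_calls = {
    "forecast.get_n_day_weather_forecast": [
        {"location": "San Francisco", "num_days": 1},  # one dict of kwargs per invocation
    ],
}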
1 change: 1 addition & 0 deletions benchllm/evaluator/evaluator.py
@@ -60,6 +60,7 @@ def _run_evaluation(self, prediction: Prediction) -> Evaluation:
start = timer()
match = self.evaluate_prediction(prediction)
end = timer()

evaluation = Evaluation(
prediction=prediction, passed=isinstance(match, Evaluator.Match), eval_time_elapsed=end - start
)
56 changes: 53 additions & 3 deletions benchllm/tester.py
@@ -1,11 +1,13 @@
import importlib.util
import inspect
import json
import sys
import uuid
from contextlib import contextmanager
from pathlib import Path
from timeit import default_timer as timer
from types import ModuleType
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Iterator, Optional, Union

import yaml
from pydantic import ValidationError, parse_obj_as
@@ -111,10 +113,19 @@ def run(self) -> list[Prediction]:

self._broadcast_test_started(test)
start = timer()
output = test_function.function(input)

# set up mock functions for the test calls
calls_made: dict[str, Any] = {}
with setup_mocks(test, calls_made):
output = test_function.function(input)

end = timer()
prediction = Prediction(
test=test, output=output, time_elapsed=end - start, function_id=test_function.function_id
test=test,
output=output,
time_elapsed=end - start,
function_id=test_function.function_id,
calls=calls_made,
)
self._predictions.append(prediction)
self._broadcast_test_ended(prediction)
@@ -226,8 +237,47 @@ def import_module_from_file(file_path: Path) -> ModuleType:
if not module:
raise Exception(f"Failed to load module from {file_path}")

# Temporarily add the directory of the file to the system path so that the module can import other modules.
file_module = file_path.resolve().parent
old_sys_path = sys.path.copy()
sys.path.append(str(file_module))

# Execute the module.
spec.loader.exec_module(module)

# Restore the system path
sys.path = old_sys_path

# Return the module.
return module


@contextmanager
def setup_mocks(test: Test, calls_made: dict[str, Any]) -> Iterator[None]:
"""Sets up mock functions for the test calls"""
old_functions = []
for call in test.calls or []:
mock_name = call.name
module_name, function_name = mock_name.rsplit(".", 1)
# we need to import the module before we can mock the function
module = importlib.import_module(module_name)
old_functions.append((module, function_name, getattr(module, function_name)))

# bind the loop variables as defaults so each mock keeps its own name and canned return value
def mock_function(*args: tuple, _name: str = mock_name, _call: Any = call, **kwargs: dict[str, Any]) -> Any:
assert not args, "Positional arguments are not supported"
calls_made.setdefault(_name, []).append(kwargs)
return _call.returns

try:
setattr(module, function_name, mock_function)
except AttributeError:
print(f"Function {function_name} doesn't exist in module {module_name}")

try:
yield
finally:
# restore the old function
for old_function in old_functions:
setattr(old_function[0], old_function[1], old_function[2])
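setup_mocks is only used internally by Tester.run, but its effect can be seen in isolation. A minimal sketch, assuming the examples/weather_functions directory is on sys.path so that `forecast` (and its `openai` import) resolves:

from benchllm.data_types import Test, TestCall
from benchllm.tester import setup_mocks

test = Test(
    input="I live in London, can I expect rain today?",
    expected=["no"],
    calls=[
        TestCall(
            name="forecast.get_n_day_weather_forecast",
            arguments={"location": "London", "num_days": 1},
            returns="It's sunny in London today.",
        )
    ],
)

calls_made: dict = {}
with setup_mocks(test, calls_made):
    import forecast

    # the real implementation is patched out; the canned value comes back
    print(forecast.get_n_day_weather_forecast(location="London", num_days=1))
    # -> It's sunny in London today.

# every set of keyword arguments the mock saw is recorded for later checking
print(calls_made)
# -> {'forecast.get_n_day_weather_forecast': [{'location': 'London', 'num_days': 1}]}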
39 changes: 38 additions & 1 deletion benchllm/utils.py
@@ -4,7 +4,7 @@

import yaml

from benchllm.data_types import Prediction
from benchllm.data_types import CallError, CallErrorType, Prediction


class DecoratorFinder(ast.NodeVisitor):
@@ -82,3 +82,40 @@ def load_prediction_files(paths: list[Path]) -> list[Prediction]:
data = yaml.safe_load(file)
predictions.append(Prediction(**data))
return predictions


def collect_call_errors(prediction: Prediction) -> list[CallError]:
"""Assert that the calls in the prediction match the expected calls."""
if prediction.test.calls is None:
return []
errors = []
lookup = {call.name: call for call in prediction.test.calls}

for function_name, invocations in prediction.calls.items():
call = lookup.get(function_name)
if call is None:
errors.append(CallError(function_name=function_name, error_type=CallErrorType.MISSING_FUNCTION))
continue

for arguments in invocations:
for argument_name, argument_value in call.arguments.items():
if argument_name not in arguments:
errors.append(
CallError(
function_name=function_name,
argument_name=argument_name,
error_type=CallErrorType.MISSING_ARGUMENT,
)
)
for argument_name, argument_value in arguments.items():
if argument_name in call.arguments and argument_value != call.arguments[argument_name]:
errors.append(
CallError(
function_name=function_name,
argument_name=argument_name,
expected_value=call.arguments[argument_name],
actual_value=argument_value,
error_type=CallErrorType.VALUE_MISMATCH,
)
)
return errors
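Applied to the rainy.yml test below: if the mocked function had been called with num_days=3 instead of 1, collect_call_errors would report a single value mismatch, roughly the following (a sketch of the expected result, not taken from a real run):

from benchllm.data_types import CallError, CallErrorType

CallError(
    function_name="forecast.get_n_day_weather_forecast",
    argument_name="num_days",
    expected_value=1,  # from the test's `arguments`
    actual_value=3,    # what the mock recorded
    error_type=CallErrorType.VALUE_MISMATCH,
)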
4 changes: 4 additions & 0 deletions examples/weather_functions/default.yml
@@ -0,0 +1,4 @@
expected:
- "Yes"
id: d80b8d57-9c50-44fe-b492-c3cab549d9f3
input: Will it rain tomorrow?
8 changes: 8 additions & 0 deletions examples/weather_functions/eval.py
@@ -0,0 +1,8 @@
from forecast import run

import benchllm


@benchllm.test()
def eval(input: str):
return run(input)
58 changes: 58 additions & 0 deletions examples/weather_functions/forecast.py
@@ -0,0 +1,58 @@
import json

import openai


def get_n_day_weather_forecast(location: str, num_days: int):
return f"The weather in {location} will be rainy for the next {num_days} days."


def chain(prompt: list[dict], functions):
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo-0613", messages=prompt, temperature=0.0, functions=functions
)

choice = response["choices"][0]
if choice.get("finish_reason") == "function_call":
function_call = choice["message"]["function_call"]
function_name = function_call["name"]
function_args = json.loads(function_call["arguments"])
fn = globals()[function_name]
output = fn(**function_args)
prompt.append({"role": "function", "name": function_name, "content": output})
return chain(prompt, functions)
else:
return response.choices[0].message.content.strip()


def run(question: str):
messages = [
{
"role": "user",
"content": "Only answer questions with 'yes', 'no' or 'unknown', you must not reply with anything else",
},
{"role": "system", "content": "Use the get_n_day_weather_forecast function for weather questions"},
{"role": "user", "content": question},
]

functions = [
{
"name": "get_n_day_weather_forecast",
"description": "Get an N-day weather forecast",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"num_days": {
"type": "integer",
"description": "The number of days to forecast. E.g. 1 for today, 2 for tomorrow etc",
},
},
"required": ["location", "format", "num_days"],
},
},
]
return chain(messages, functions)
10 changes: 10 additions & 0 deletions examples/weather_functions/rainy.yml
@@ -0,0 +1,10 @@
expected:
- "yes"
id: 4b717b39-da96-4ca3-8aaf-070c7f11449c
input: I'm going to San Francisco today, do I need an umbrella?
calls:
- name: forecast.get_n_day_weather_forecast
returns: It's rainy in San Francisco today.
arguments:
location: San Francisco
num_days: 1
10 changes: 10 additions & 0 deletions examples/weather_functions/sunny.yml
@@ -0,0 +1,10 @@
expected:
- "no"
id: 4b717b39-da96-4ca3-8aaf-070c7f11449c
input: I live in London, can I expect rain today?
calls:
- name: forecast.get_n_day_weather_forecast
returns: It's sunny in London today.
arguments:
location: London
num_days: 1
9 changes: 9 additions & 0 deletions examples/weather_functions/tomorrow.yml
@@ -0,0 +1,9 @@
expected:
- "yes"
id: 4b717b39-da96-4ca3-8aaf-070c7f11449c
input: Do I need a sun hat tomorrow?
calls:
- name: forecast.get_n_day_weather_forecast
returns: It's sunny in your location tomorrow.
arguments:
num_days: 2
