In [1]:
import yaml
from Agent import *
from load import Flight, load_flights_dataset
import os

class EvaluationResult:
    """Outcome of evaluating the agent on a single benchmark.

    Attributes:
        accuracy: the score passed in by the caller, stored unchanged.
        conversation: the conversation passed in by the caller, stored unchanged.
    """

    def __init__(self, accuracy, conversation):
        # Plain value holder: no validation or copying is performed.
        self.accuracy, self.conversation = accuracy, conversation
        
# Load the flights dataset once; the same object is passed to every eval_agent() call.
flights = load_flights_dataset()

In [2]:
def eval_agent(benchmark_file: str, flights: List[Flight]) -> float:
    """
    Evaluate the agent on the given benchmark YAML file.
    """
    agent = Agent(flights)
    with open(benchmark_file, "r") as file:
        steps = yaml.safe_load(file)
    for n, step in enumerate(steps):
        response = agent.say(step["prompt"])
        match step["expected_type"]:
            case "text":
                if not isinstance(response, TextResponse):
                    return EvaluationResult(n / len(steps), agent.conversation)
            case "find-flights":
                if not isinstance(response, FindFlightsResponse):
                    return EvaluationResult(n / len(steps), agent.conversation)
                if set(response.available_flights) != set(step["expected_result"]):
                    return EvaluationResult(n / len(steps), agent.conversation)
            case "book-flight":
                if not isinstance(response, BookFlightResponse):
                    return EvaluationResult(n / len(steps), agent.conversation)
                if response.booked_flight != step["expected_result"]:
                    return EvaluationResult(n / len(steps), agent.conversation)
    return EvaluationResult(1.0, agent.conversation)       


In [3]:
# Run every benchmark YAML file in the 'benchmarks' directory and print its accuracy.
# os.listdir() returns entries in arbitrary order, so sort for reproducible output,
# and skip anything that is not a YAML file (e.g. editor or checkpoint artefacts).
for file in sorted(os.listdir('benchmarks')):
    if not file.endswith(('.yml', '.yaml')):
        continue
    print(f"FILE: {file}")
    result = eval_agent(benchmark_file=os.path.join('benchmarks', file), flights=flights)
    print(f"ACCURACY: {result.accuracy}")
    print()

FILE: benchmark8.yml
ACCURACY: 0.4

FILE: benchmark9.yml
ACCURACY: 1.0

FILE: benchmark7.yml
ACCURACY: 1.0

FILE: benchmark6.yml
ACCURACY: 1.0

FILE: benchmark4.yml
ACCURACY: 1.0

FILE: benchmark5.yml
ACCURACY: 1.0

FILE: benchmark1.yml
ACCURACY: 1.0

FILE: benchmark2.yml
ACCURACY: 0.0

FILE: benchmark3.yml
ACCURACY: 1.0

FILE: benchmark10.yml
ACCURACY: 1.0

