In [2]:
import os
import pandas as pd

from getpass import getpass
from langchain.utilities import SQLDatabase
from langchain.agents import create_sql_agent, AgentType
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.chat_models import ChatOpenAI
from langchain.evaluation.qa import QAEvalChain
from langchain.callbacks import get_openai_callback

from utils.evaluation_prompts import GRADING_PROMPT
from utils.utils import CustomDatabase

# Prompt for the OpenAI key only when the environment does not already provide
# one, so the key is never written into the notebook itself.
if os.environ.get("OPENAI_API_KEY") is None:
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API Key: ")

## Load evaluation dataset

In [4]:
# Each evaluation database is described by its name plus two files that share
# one folder and one naming scheme: `<name>_eval_dataset.json` (evaluation set)
# and `<name>.sqlite` (the SQLite database itself).
databases = [
    {
        'name': db_name,
        'eval_set_path': f'datasets/pampa_dataset/{db_name}_eval_dataset.json',
        'db_path': f'datasets/pampa_dataset/{db_name}.sqlite',
    }
    for db_name in (
        'architecture',
        'products_for_hire',
        'restaurant_1',
        'riding_club',
    )
]

# Bundle, for every configured database, its SQLite connection and its
# evaluation set (read from JSON into a DataFrame) into one CustomDatabase.
# The resulting list keeps the order of `databases`.
eval_databases = [
    CustomDatabase(
        name=entry['name'],
        database=SQLDatabase.from_uri(f"sqlite:///{entry['db_path']}"),
        evaluation_dataset=pd.read_json(entry['eval_set_path']),
    )
    for entry in databases
]

## Create your custom Agent
We need to create one agent per database, because each agent's SQL toolkit is bound to a single database.

In [5]:
# One chat model (temperature 0) is shared by every agent and toolkit.
llm = ChatOpenAI(temperature=0)

for eval_database in eval_databases:
    print(f"Creating Agent for {eval_database.name}...")
    # The toolkit is built around this database's connection, which is why a
    # separate agent is created (and stored on the object) per database.
    sql_toolkit = SQLDatabaseToolkit(db=eval_database.database, llm=llm)
    eval_database.agent = create_sql_agent(
        llm=llm,
        toolkit=sql_toolkit,
        agent_type=AgentType.OPENAI_FUNCTIONS,
        verbose=False,
    )

Creating Agent for architecture...
Creating Agent for products_for_hire...
Creating Agent for restaurant_1...
Creating Agent for riding_club...


## Evaluate the Agent

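First we **run** the agent on the evaluation questions. The cell below is restricted to the first database; remove the `[:1]` slice to run it over all databases: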

In [19]:
for eval_database in eval_databases[:1]: # remove slice on list to run on whole list
    # Count the OpenAI tokens consumed while the agent runs on this database.
    with get_openai_callback() as agent_cb:
        # NOTE(review): presumably this fills the `agent_results` column that
        # the evaluation cell reads later — confirm in utils.utils.
        eval_database.run_agent()
        agent_tokens = agent_cb.total_tokens
    # Message wording matches the evaluation cell ("... for database: <name>")
    # and the output recorded in the notebook for this cell.
    print(f"Finished running agent for database: {eval_database.name}")
    print(f"Agent run token cost: {agent_tokens}")

Finished running agent for database: architecture
Agent run token cost: 50824


Now we evaluate the agent results:

In [14]:
# Grading chain: the same LLM, driven by GRADING_PROMPT, is used below to
# compare each agent answer with the reference answer.
evaluate_chain = QAEvalChain.from_llm(
    llm=llm,
    prompt=GRADING_PROMPT,
)

In [22]:
for eval_database in eval_databases[:1]: # remove slice on list to run on whole list

    # Columns used: the question, the reference natural-language answer
    # (`nl_result`) and the answer produced by the agent (`agent_results`).
    df = eval_database.evaluation_dataset
    questions = list(df['question'])
    targets = list(df['nl_result'])
    agent_results = list(df['agent_results'])

    # Build the two parallel lists passed to `evaluate_chain.evaluate`:
    # one dict per question holding the agent answer, and one holding the
    # reference answer.
    predictions = [
        {'question': question, 'result': agent_result}
        for question, agent_result in zip(questions, agent_results)
    ]
    target_results = [
        {'question': question, 'answer': target}
        for question, target in zip(questions, targets)
    ]

    # Only the grading call itself needs to run under the token counter.
    with get_openai_callback() as eval_cb:
        res = evaluate_chain.evaluate(target_results, predictions, question_key="question", prediction_key="result")

    grades = [r['results'] for r in res]
    # Grades are free text produced by the LLM, so compare them ignoring case
    # and surrounding whitespace ("Incorrect" still never matches "correct").
    n_correct = sum(str(grade).strip().lower() == 'correct' for grade in grades)
    # An empty evaluation set has no defined accuracy; report NaN instead of
    # failing with ZeroDivisionError.
    accuracy = n_correct / len(grades) if grades else float('nan')
    # Store the raw grades next to the questions for later inspection.
    eval_database.evaluation_dataset['Evaluation'] = grades

    print("---------------------------------------------")
    print(f"Finished evaluation for database: {eval_database.name}")
    print(f"Accuracy: {accuracy}")
    print(f"Evaluation token cost: {eval_cb.total_tokens}")
    print("---------------------------------------------\n\n")

---------------------------------------------
Finished evaluation for database: architecture
Accuracy: 0.7058823529411765
Evaluation token cost: 3263
---------------------------------------------




We can explore results for a specific dataset:

In [24]:
# Display the evaluation table of the first database. Once the cells above have
# run, it also holds the `agent_results` and `Evaluation` columns.
eval_databases[0].evaluation_dataset

Unnamed: 0,db_id,question,query,query_result,nl_result,agent_results,Evaluation
6945,architecture,How many architects are female?,SELECT count(*) FROM architect WHERE gender =...,"[(1,)]",There is 1 female architect.,There is 1 female architect in the database.,Correct
6946,architecture,"List the name, nationality and id of all male ...","SELECT name , nationality , id FROM architec...","[('Frank Gehry', 'Canadian', '2'), ('Frank Llo...",The list of male architects ordered by their n...,"The names, nationalities, and IDs of all male ...",Correct
6947,architecture,What is the maximum length in meters for the b...,"SELECT max(T1.length_meters) , T2.name FROM b...","[(121.0, 'Frank Lloyd Wright')]",The maximum length for the bridges is 121 mete...,The maximum length of the bridges in meters is...,Correct
6948,architecture,What is the average length in feet of the brid...,SELECT avg(length_feet) FROM bridge,"[(244.64,)]",The average length of the bridges is approxima...,The average length of the bridges in feet is a...,Correct
6949,architecture,What are the names and year of construction fo...,"SELECT name , built_year FROM mill WHERE TYPE...","[('Le Vieux Molen', 1840), ('Moulin Bertrand',...",The mills of 'Grondzeiler' type are named 'Le ...,The names and year of construction for the mil...,Correct
6950,architecture,What are the distinct names and nationalities ...,"SELECT DISTINCT T1.name , T1.nationality FROM...","[('Frank Lloyd Wright', 'American'), ('Frank G...",The distinct names and nationalities of the ar...,The distinct names and nationalities of the ar...,Correct
6951,architecture,What are the names of the mills which are not ...,SELECT name FROM mill WHERE LOCATION != 'Donceel',"[('Le Vieux Molen',), ('Moulin de Fexhe',), ('...",The names of the mills which are not located i...,"Based on the schema of the `mill` table, I can...",Correct
6952,architecture,What are the distinct types of mills that are ...,SELECT DISTINCT T1.type FROM mill AS T1 JOIN a...,"[('Grondzeiler',)]",The distinct type of mill that is built by Ame...,The distinct type of mills that are built by A...,Incorrect
6953,architecture,What are the ids and names of the architects w...,"SELECT T1.id , T1.name FROM architect AS T1 J...","[('1', 'Frank Lloyd Wright'), ('2', 'Frank Geh...",The architects who built at least 3 bridges ar...,"Based on the schema of the tables, I can see t...",Correct
6954,architecture,"What is the id, name and nationality of the ar...","SELECT T1.id , T1.name , T1.nationality FROM...","[('1', 'Frank Lloyd Wright', 'American')]",The architect who built the most mills is Fran...,The architect who built the most mills is Fran...,Incorrect
