In [1]:
from mistralai import Mistral
from getpass import getpass

# To interract with the SQL database
from langchain_community.utilities import SQLDatabase

In [2]:
# To evaluate text2SQL performances
from deepeval.models import DeepEvalBaseLLM
from deepeval.test_case import LLMTestCase, LLMTestCaseParams
from deepeval.dataset import EvaluationDataset
from deepeval.metrics import GEval

import json

In [13]:
import os 
import psycopg2
import psycopg2.extras
from dotenv import load_dotenv
load_dotenv()

api_key = os.getenv("MISTRAL_API_KEY")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT")

client = Mistral(api_key=api_key)

In [30]:
# create a function run_sql_query that takes a sql_code as input and returns the result of the query 
# (using postgresql and psycopg2)

def run_sql_query(sql_code):
    try:
        # Connect to the PostgreSQL database
        conn = psycopg2.connect(
            dbname="chinbook",
            user=os.getenv("DB_USER")   ,
            password=os.getenv("DB_PASSWORD"),
            host=os.getenv("DB_HOST"),
            port=os.getenv("DB_PORT")
        )
        cursor = conn.cursor()
        cursor.execute(sql_code)
        result = cursor.fetchall()
        conn.commit()
        return result
    except Exception as e:
        print(f"Error: {e}")
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()

def run_sql_query(sql_code):
    """Execute SQL and return results as a string suitable for Mistral tool."""
    try:
        conn = psycopg2.connect(
            dbname="chinbook",
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            host=os.getenv("DB_HOST"),
            port=os.getenv("DB_PORT")
        )
        cursor = conn.cursor()
        cursor.execute(sql_code)
        try:
            result = cursor.fetchall()  # returns list of tuples
            # Convert to readable string
            result_str = "\n".join([", ".join(map(str, row)) for row in result])
        except psycopg2.ProgrammingError:
            result_str = ""  # no results to fetch

        # Commit if modifying
        if sql_code.strip().lower().startswith(
            ("insert", "update", "delete", "create", "drop", "alter")
        ):
            conn.commit()

        return result_str

    except Exception as e:
        return f"Error: {e}"
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()



In [31]:
run_sql_query("SELECT * FROM Artist LIMIT 10;")

'1, AC/DC\n2, Accept\n3, Aerosmith\n4, Alanis Morissette\n5, Alice In Chains\n6, Antônio Carlos Jobim\n7, Apocalyptica\n8, Audioslave\n9, BackBeat\n10, Billy Cobham'

In [24]:
def get_sql_schema_of_table(table_name):
    """
    Return the CREATE TABLE statement of a table including PK and FK,
    or a message if the table does not exist.
    """
    try:
        conn = psycopg2.connect(
            dbname="chinbook",
            user=DB_USER,
            password=DB_PASSWORD,
            host=DB_HOST,
            port=DB_PORT
        )
        cur = conn.cursor()

        table_name_lower = table_name.lower()

        # Step 1: Check if table exists
        cur.execute("""
            SELECT EXISTS (
                SELECT 1
                FROM information_schema.tables
                WHERE table_schema='public'
                  AND table_name=%s
            );
        """, (table_name_lower,))
        exists = cur.fetchone()[0]

        if not exists:
            return f"The table '{table_name}' is not in the database."

        # Step 2: Get column definitions
        cur.execute("""
            SELECT 
                column_name, 
                data_type, 
                character_maximum_length, 
                is_nullable
            FROM information_schema.columns
            WHERE table_name = %s
            ORDER BY ordinal_position;
        """, (table_name_lower,))
        columns = cur.fetchall()

        # Step 3: Get primary key constraints
        cur.execute("""
            SELECT kcu.column_name
            FROM information_schema.table_constraints tc
            JOIN information_schema.key_column_usage kcu
              ON tc.constraint_name = kcu.constraint_name
            WHERE tc.table_name=%s
              AND tc.constraint_type='PRIMARY KEY';
        """, (table_name_lower,))
        pk_columns = [row[0] for row in cur.fetchall()]

        # Step 4: Get foreign key constraints
        cur.execute("""
            SELECT 
                kcu.column_name,
                ccu.table_name AS foreign_table_name,
                ccu.column_name AS foreign_column_name,
                rc.update_rule,
                rc.delete_rule
            FROM information_schema.table_constraints tc
            JOIN information_schema.key_column_usage kcu
              ON tc.constraint_name = kcu.constraint_name
            JOIN information_schema.referential_constraints rc
              ON tc.constraint_name = rc.constraint_name
            JOIN information_schema.constraint_column_usage ccu
              ON rc.unique_constraint_name = ccu.constraint_name
            WHERE tc.table_name = %s
              AND tc.constraint_type='FOREIGN KEY';
        """, (table_name_lower,))
        fks = cur.fetchall()

        # Step 5: Build CREATE TABLE statement
        lines = []
        for col_name, data_type, char_len, is_nullable in columns:
            if data_type == "character varying" and char_len:
                col_type = f"NVARCHAR({char_len})"
            elif data_type == "integer":
                col_type = "INTEGER"
            else:
                col_type = data_type.upper()

            null_str = "" if is_nullable == "NO" else "NULL"
            lines.append(f"    [{col_name}] {col_type} {null_str}".strip())

        # Add primary key
        if pk_columns:
            lines.append(f"    CONSTRAINT [PK_{table_name}] PRIMARY KEY ({', '.join(f'[{c}]' for c in pk_columns)})")

        # Add foreign keys
        for col, ref_table, ref_col, update_rule, delete_rule in fks:
            lines.append(f"    FOREIGN KEY ([{col}]) REFERENCES [{ref_table}] ([{ref_col}])"
                         f" ON DELETE {delete_rule} ON UPDATE {update_rule}")

        create_stmt = f"CREATE TABLE [{table_name}]\n(\n" + ",\n".join(lines) + "\n);"

        return f"The table {table_name} was created with the following code :\n\n{create_stmt}"

    except Exception as e:
        return f"Error: {e}"
    finally:
        if cur:
            cur.close()
        if conn:
            conn.close()

In [25]:
# Example usage
print(get_sql_schema_of_table("Album"))
print(get_sql_schema_of_table("Artist"))
print(get_sql_schema_of_table("NonExistingTable"))

The table Album was created with the following code :

CREATE TABLE [Album]
(
[album_id] INTEGER,
[title] NVARCHAR(160),
[artist_id] INTEGER,
    CONSTRAINT [PK_Album] PRIMARY KEY ([album_id]),
    FOREIGN KEY ([artist_id]) REFERENCES [artist] ([artist_id]) ON DELETE NO ACTION ON UPDATE NO ACTION
);
The table Artist was created with the following code :

CREATE TABLE [Artist]
(
[artist_id] INTEGER,
[name] NVARCHAR(120) NULL,
    CONSTRAINT [PK_Artist] PRIMARY KEY ([artist_id])
);
The table 'NonExistingTable' is not in the database.


# Build agent

In [36]:
def get_response(question, verbose=True):
    """
    Answer question about the Chinook database.

    Args:
        question (str): The question asked by the user.
        verbose (bool): If True, prints intermediate steps and results.

    Returns:
        str: The response to the user's question.
    """

    # Define the tools available for the AI assistant
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_sql_schema_of_table",
                "description": "Get the schema of a table in the Chinook database",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "table": {
                            "type": "string",
                            "enum": ["Album", "Artist", "Customer", "Employee", "Genre", "Invoice", "InvoiceLine", "MediaType", "Playlist", "PlaylistTrack", "Track"],
                            "description": "The question asked by the user",
                        },
                    },
                    "required": ["table"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "run_sql_query",
                "description": "Run an SQL query on the Chinook database",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "sql_code": {
                            "type": "string",
                            "description": "SQL code to be run",
                        },
                    },
                    "required": ["sql_code"],
                },
            },
        }
    ]

    # System prompt for the AI assistant
    system_prompt = """
    You are an AI assistant.
    Your job is to reply to questions related to the Chinook database.
    The Chinook data model represents a digital media store, including tables for artists, albums, media tracks, invoices, and customers.

    To answer user questions, you have two tools at your disposal.

    Firstly, a function called "get_sql_schema_of_table" which has a single parameter named "table" whose value is an element
    of the following list: ["Album", "Artist", "Customer", "Employee", "Genre", "Invoice", "InvoiceLine", "MediaType", "Playlist", "PlaylistTrack", "Track"].

    Secondly, a function called "run_sql_query" which has a single parameter named "sql_code".
    It will run SQL code on the Chinook database. The SQL dialect is SQLite.

    For a given question, your job is to:
    1. Get the schemas of the tables that might help you answer the question using the "get_sql_schema_of_table" function.
    2. Run a SQLite query on the relevant tables using the "run_sql_query" function.
    3. Answer the user based on the result of the SQL query.

    You will always remain factual, you will not hallucinate, and you will say that you don't know if you don't know.
    You will politely ask the user to shoot another question if the question is not related to the Chinook database.
    """

    # Initialize chat history with system prompt and user question
    chat_history = [
        {"role": "system","content": system_prompt},
        {"role": "user","content": question}
    ]

    if verbose:
        print(f"User: {question}\n")

    used_run_sql = False
    used_get_sql_schema_of_table = False

    # Function to determine tool choice based on usage
    def tool_choice(used_run_sql, used_get_sql_schema_of_table):
        # If the question is out of topic the agent is not expected to run a tool call
        if not used_get_sql_schema_of_table:
            return "auto"
        # The agent is expected to run "used_run_sql" after getting the specifications of the tables of interest
        if used_get_sql_schema_of_table and not used_run_sql:
            return "any"
        # The agent is not expected to run a tool call after querying the SQL table
        if used_run_sql and used_get_sql_schema_of_table:
            return "none"
        return "auto"

    iteration = 0
    max_iteration = 10

    # Main loop to process the question
    while iteration < max_iteration:
        inference = client.chat.complete(
            model="mistral-tiny",
            temperature=0.3,
            messages=chat_history,
            tools=tools,
            tool_choice=tool_choice(used_run_sql, used_get_sql_schema_of_table)
        )

        chat_history.append(inference.choices[0].message)

        tool_calls = inference.choices[0].message.tool_calls

        if not tool_calls:
            if verbose:
                print(f"Assistant: {inference.choices[0].message.content}\n")
            return inference.choices[0].message.content

        for tool_call in tool_calls:
            function_name = tool_call.function.name
            function_params = json.loads(tool_call.function.arguments)

            if function_name == "get_sql_schema_of_table":
                function_result = get_sql_schema_of_table(function_params['table'])
                if verbose:
                    print(f"Tool: Getting SQL schema of table {function_params['table']}\n")
                used_get_sql_schema_of_table = True

            if function_name == "run_sql_query":
                function_result = run_sql_query(function_params['sql_code'])
                if verbose:
                    print(f"Tool: Running code {function_params['sql_code']}\n")
                used_run_sql = True

            chat_history.append({"role": "tool", "name": function_name, "content": function_result, "tool_call_id": tool_call.id})

        iteration += 1
    return

# Test the agent

Let's test the agent and ask a few random questions of increasing complexity

In [32]:
# Lets start by checking how the model reacts with out of topic questions!
response = get_response('What is the oldest player in the NBA?')

User: What is the oldest player in the NBA?

Assistant: The Chinook database does not contain any information about the NBA or its players. Could you please ask a question related to the Chinook database?



In [39]:
response = get_response('What are the names of the genres in the store?')

User: What are the names of the genres in the store?

Assistant: To find the names of the genres in the store, I will use the "get_sql_schema_of_table" function to get the schema of the "Genre" table and then run a SQL query to extract the genre names.

```
[{"name": "get_sql_schema_of_table", "arguments": {"table": "Genre"}}, {"name": "run_sql_query", "arguments": {"sql_code": "SELECT name FROM Genre"}}]
```

After running the above code, the result will be a list of genre names in the store. I will then present the list to you.



In [40]:
response = get_response('What are the albums of the rock band U2?')

User: What are the albums of the rock band U2?

Assistant: To answer your question, I will first get the schema of the "Album" table and the "Artist" table. Then, I will run a SQL query to find the albums of the rock band U2.

```
[{"name": "get_sql_schema_of_table", "arguments": {"table": "Album"}}, {"name": "get_sql_schema_of_table", "arguments": {"table": "Artist"}}, {"name": "run_sql_query", "arguments": {"sql_code": "SELECT Album.AlbumId, Album.Title FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'U2'"}}]
```

The result of the SQL query will provide the albums of the rock band U2.



In [41]:
response = get_response('What is the shortest song that the rock band U2 ever composed?')

User: What is the shortest song that the rock band U2 ever composed?

Assistant: To answer this question, I will first need to get the schema of the "Track" table and then find the song composed by U2 with the shortest duration. Here's how I'll do it:

```
    [{"name": "get_sql_schema_of_table", "arguments": {"table": "Track"}}]
```

After getting the schema, I'll write an SQL query to find the shortest song by U2. Here's the SQL code:

```
    [{"name": "run_sql_query", "arguments": {"sql_code": "SELECT duration, track_id FROM Track, Artist WHERE Track.Artist_id = Artist.Artist_id AND Artist.Name = 'U2' ORDER BY duration ASC LIMIT 1"}}]
```

This query will return the ID of the shortest song by U2. To get the song title, I'll need to run another query:

```
    [{"name": "run_sql_query", "arguments": {"sql_code": "SELECT Title FROM Track WHERE track_id = ..."}}]
```

Replace the dots with the ID returned from the previous query.

Finally, I'll return the shortest song by U2 along wit

In [42]:
response = get_response('Which track from U2 is the most sold?')

User: Which track from U2 is the most sold?

Assistant: To answer this question, I will need to find the track with the highest sales count from the "Track" table, joined with the "InvoiceLine" and "Invoice" tables. Here's the SQL query I will run:

```sql
SELECT Track.Name, Track.Milliseconds, Track.UnitPrice
FROM Track
JOIN InvoiceLine ON Track.TrackId = InvoiceLine.TrackId
JOIN Invoice ON InvoiceLine.InvoiceId = Invoice.InvoiceId
WHERE Track.Name IN (
    SELECT Name FROM Album
    JOIN Artist ON Album.ArtistId = Artist.ArtistId
    WHERE Artist.Name = 'U2'
)
GROUP BY Track.Name
ORDER BY SUM(InvoiceLine.Quantity) DESC
LIMIT 1;
```

This query first finds the album(s) by the artist 'U2', then it finds the tracks on those albums, and finally, it groups the tracks by their name, sums the quantity of each track sold, and orders the results in descending order. The top result will be the most sold track from U2.

Now, I will run this query using the "run_sql_query" function:

[{"name": "

In [43]:
response = get_response('Which consumer bought the biggest amound of U2 songs?')

User: Which consumer bought the biggest amound of U2 songs?

Assistant: To find the consumer who bought the most U2 songs, I would first need to know the table names and the columns that contain the consumer information and the song information.

Let's assume that the "InvoiceLine" table contains the information about the songs and the "Customer" table contains the consumer information.

The "InvoiceLine" table schema might look like this:

```
InvoiceLine (InvoiceLineID, InvoiceID, TrackID, Quantity, UnitPrice, Discount)
```

The "Customer" table schema might look like this:

```
Customer (CustomerID, FirstName, LastName, Company, ContactName, ContactTitle, Address, City, Region, PostalCode, Country, Phone, Fax)
```

To find the consumer who bought the most U2 songs, I would first need to find the TrackID(s) for U2 songs. Then, I would find the InvoiceLine records for those TrackID(s) and count the number of records for each customer.

Here's the SQL code to find the TrackID(s) for U2

In [44]:
response = get_response('List all artist that have a color in their name')

User: List all artist that have a color in their name

Assistant: To find all artists with a color in their name, we will first need to know the column that contains the artist name in the Artist table. Let's get the schema of the Artist table.

```javascript
const artistSchema = get_sql_schema_of_table("Artist");
```

From the schema, we can see that the column containing the artist name is "Name". Now, we can write an SQL query to find artists with a color in their name.

```javascript
const sqlCode = `SELECT Name FROM Artist WHERE Name LIKE '%[A-Za-z]%[A-Za-z0-9]%[ ,.]%[Rr]%[Ee]%[Dd]%[Ss]%[Hh]%[Ii]%[Pp]%[Oo]%[Yy]%[Cc]%[Kk]%[Ll]%[Ee]%[Aa]%[Tt]%[Aa]%[Nn]' OR Name LIKE '%[A-Za-z]%[A-Za-z0-9]%[ ,.]%[Bb]%[Ll]%[Uu]%[Ee]%[Rr]%[Oo]%[Ww]%[Ee]%[Dd]%[Ss]%[Hh]%[Ii]%[Pp]%[Oo]%[Yy]%[Cc]%[Kk]%[Ll]%[Ee]%[Aa]%[Tt]%[Aa]%[Nn]' OR Name LIKE '%[A-Za-z]%[A-Za-z0-9]%[ ,.]%[Yy]%[Ee]%[Ll]%[Ll]%[Oo]%[Ww]%[Oo]%[Rr]%[Ee]%[Dd]%[Ss]%[Hh]%[Ii]%[Pp]%[Oo]%[Yy]%[Cc]%[Kk]%[Ll]%[Ee]%[Aa]%[Tt]%[Aa]%[Nn]';`;

run_sql_qu

In [45]:
response = get_response('Who are our top Customers according to Invoices?')

User: Who are our top Customers according to Invoices?

Assistant: To find the top customers according to invoices, I will first get the schema of the Customer and Invoice tables. Then, I will join these tables on the CustomerID column and group the results by CustomerID to count the number of invoices for each customer. Finally, I will order the results in descending order by the count of invoices to find the top customers.

Here is the SQL query I will run:

```sql
SELECT Customer.CustomerName, COUNT(Invoice.InvoiceID) as InvoiceCount
FROM Customer
JOIN Invoice ON Customer.CustomerID = Invoice.CustomerID
GROUP BY Customer.CustomerID
ORDER BY InvoiceCount DESC
LIMIT 10;
```

I will use the "run_sql_query" function to run this query on the Chinook database:

```
[{"name": "run_sql_query", "arguments": {"sql_code": "SELECT Customer.CustomerName, COUNT(Invoice.InvoiceID) as InvoiceCount FROM Customer JOIN Invoice ON Customer.CustomerID = Invoice.CustomerID GROUP BY Customer.CustomerID OR

In [46]:
class CustomMistralLarge(DeepEvalBaseLLM):
    def __init__(self, api_key):
        self.client = Mistral(api_key=api_key)
        self.model_name = "mistral-large"

    def get_model_name(self):
        return "Mistral-tiny"

    def load_model(self):
        # Since we are using the Mistral API, we don't need to load a model object.
        return self.client

    def generate(self, prompt: str) -> str:
        chat_response = self.client.chat.complete(
            model=self.model_name,
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                },
            ]
        )
        return chat_response.choices[0].message.content

    async def a_generate(self, prompt: str) -> str:
        # Reusing the synchronous generate method for simplicity.
        return self.generate(prompt)

In [1]:
# Specify questions
questions = [
    "Which Employee has the Highest Total Number of Customers?",
    "Who are our top Customers according to Invoices?",
    "How many Rock music listeners are there?",
    "What artists have written most rock music songs?",
    "Which artist has earned the most according to the Invoice Lines? How much is it?",
    "How many tracks have a song length greater than the average song length?",
    "What is the most popular genre for Australia?"
]

# Specify associated answers
golden_answers = [
    "Peacock Jane has the most customers (she has 21 customers)",
    "Helena Holy, Richard Cunningham, Luis Rojas, Ladislav Kovacs, and Hugh O’Reilly are the top five customers who have spent the highest amount of money according to the invoice",
    "We found out that all 59 customers in the database have listened to Rock Music.",
    "Led Zeppelin tops the list of Artists who have written the most Rock Music with 114 songs followed Closely by U2 with 112 music.",
    "The Artist who has earned the most according to the invoice lines is Iron Maiden with a total of $138.6.",
    "Out of 3503 songs in the database, we found out that 494 of these songs have length more than the average music length of 393,599.21 milliseconds.",
    "Rock is the most popular song for Australia"

]