In [35]:
import json
from openai import OpenAI
from tenacity import retry, wait_random_exponential, stop_after_attempt
from termcolor import colored  

GPT_MODEL = "gpt-3.5-turbo-0613"

In [38]:
%set_env OPENAI_API_KEY=

env: OPENAI_API_KEY=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


In [39]:
client = OpenAI()

In [18]:
import sqlite3
conn = sqlite3.connect('Chinook_Sqlite.sqlite')

In [20]:
cur = conn.cursor()

In [23]:
cur.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()

[('Album',),
 ('Artist',),
 ('Customer',),
 ('Employee',),
 ('Genre',),
 ('Invoice',),
 ('InvoiceLine',),
 ('MediaType',),
 ('Playlist',),
 ('PlaylistTrack',),
 ('Track',)]

In [24]:
cur.execute("select * from Album").fetchall()

[(1, 'For Those About To Rock We Salute You', 1),
 (2, 'Balls to the Wall', 2),
 (3, 'Restless and Wild', 2),
 (4, 'Let There Be Rock', 1),
 (5, 'Big Ones', 3),
 (6, 'Jagged Little Pill', 4),
 (7, 'Facelift', 5),
 (8, 'Warner 25 Anos', 6),
 (9, 'Plays Metallica By Four Cellos', 7),
 (10, 'Audioslave', 8),
 (11, 'Out Of Exile', 8),
 (12, 'BackBeat Soundtrack', 9),
 (13, 'The Best Of Billy Cobham', 10),
 (14, 'Alcohol Fueled Brewtality Live! [Disc 1]', 11),
 (15, 'Alcohol Fueled Brewtality Live! [Disc 2]', 11),
 (16, 'Black Sabbath', 12),
 (17, 'Black Sabbath Vol. 4 (Remaster)', 12),
 (18, 'Body Count', 13),
 (19, 'Chemical Wedding', 14),
 (20, 'The Best Of Buddy Guy - The Millenium Collection', 15),
 (21, 'Prenda Minha', 16),
 (22, 'Sozinho Remix Ao Vivo', 16),
 (23, 'Minha Historia', 17),
 (24, 'Afrociberdelia', 18),
 (25, 'Da Lama Ao Caos', 18),
 (26, 'Acústico MTV [Live]', 19),
 (27, 'Cidade Negra - Hits', 19),
 (28, 'Na Pista', 20),
 (29, 'Axé Bahia 2001', 21),
 (30, 'BBC Sessions [

In [25]:
def get_table_names(conn):
    """Return a list of table names."""
    table_names = []
    tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
    for table in tables.fetchall():
        table_names.append(table[0])
    return table_names


def get_column_names(conn, table_name):
    """Return a list of column names."""
    column_names = []
    columns = conn.execute(f"PRAGMA table_info('{table_name}');").fetchall()
    for col in columns:
        column_names.append(col[1])
    return column_names


def get_database_info(conn):
    """Return a list of dicts containing the table name and columns for each table in the database."""
    table_dicts = []
    for table_name in get_table_names(conn):
        columns_names = get_column_names(conn, table_name)
        table_dicts.append({"table_name": table_name, "column_names": columns_names})
    return table_dicts


In [26]:
database_schema_dict = get_database_info(conn)
database_schema_string = "\n".join(
    [
        f"Table: {table['table_name']}\nColumns: {', '.join(table['column_names'])}"
        for table in database_schema_dict
    ]
)

In [28]:
print(database_schema_string)

Table: Album
Columns: AlbumId, Title, ArtistId
Table: Artist
Columns: ArtistId, Name
Table: Customer
Columns: CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId
Table: Employee
Columns: EmployeeId, LastName, FirstName, Title, ReportsTo, BirthDate, HireDate, Address, City, State, Country, PostalCode, Phone, Fax, Email
Table: Genre
Columns: GenreId, Name
Table: Invoice
Columns: InvoiceId, CustomerId, InvoiceDate, BillingAddress, BillingCity, BillingState, BillingCountry, BillingPostalCode, Total
Table: InvoiceLine
Columns: InvoiceLineId, InvoiceId, TrackId, UnitPrice, Quantity
Table: MediaType
Columns: MediaTypeId, Name
Table: Playlist
Columns: PlaylistId, Name
Table: PlaylistTrack
Columns: PlaylistId, TrackId
Table: Track
Columns: TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, UnitPrice


In [29]:
tools = [
    {
        "type" : "function",
        "function":{
            "name":"ask_database",
            "description": "Use this function to answer user questions about music. Input should be a fully formed SQL query.",
            "parameters":{
                "type":"object",
                "properties":{
                    "query": {
                        "type": "string",
                        "description": f"""
                                SQL query extracting info to answer the user's question.
                                SQL should be written using this database schema:
                                {database_schema_string}
                                The query should be returned in plain text, not in JSON.
                                """,
                    }
                },
                "required":["query"]
            }
        }
    }
]

In [58]:
def ask_database(conn,query):
    """Function to query SQLite database with a provided SQL query."""
    try:
        results = str(cur.execute(query).fetchall())
    except Exception as e:
        results = f"query failed with result {e}"
    return results

In [31]:
def execute_function_call(message):
    if message.tool_calls[0].function.name == "ask_database":
        query = json.loads(message.tool_calls[0].function.arguments)["query"]
        results = ask_database(conn, query)
    else:
        results = f"Error: function {message.tool_calls[0].function.name} does not exist"
    return results

In [32]:
messages = []
messages.append({"role":"system","content": "Answer user questions by generating SQL queries against the Chinook Music Database."})
messages.append({"role": "user", "content": "Hi, who are the top 5 artists by number of tracks?"})


In [40]:
@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request(messages, tools=None, tool_choice=None, model=GPT_MODEL):
    try:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            tools=tools,
            tool_choice=tool_choice,
        )
        return response
    except Exception as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        return e


In [41]:
response = chat_completion_request(messages,tools)

In [42]:
response

ChatCompletion(id='chatcmpl-9I3ixKrQNUuVJ0NUsiTQqS1O941PY', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_fNHhzq0E8t8VJ4Wa2YUw1S8F', function=Function(arguments='{\n  "query": "SELECT Artist.Name, COUNT(Track.TrackId) as TrackCount FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Artist.ArtistId ORDER BY TrackCount DESC LIMIT 5"\n}', name='ask_database'), type='function')]))], created=1714091247, model='gpt-3.5-turbo-0613', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=67, prompt_tokens=377, total_tokens=444))

In [54]:
response.choices[0].message.tool_calls[0].function.arguments

'{\n  "query": "SELECT Artist.Name, COUNT(Track.TrackId) as TrackCount FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Artist.ArtistId ORDER BY TrackCount DESC LIMIT 5"\n}'

In [56]:
assistant_message = response.choices[0].message
assistant_message.content = str(assistant_message.tool_calls[0].function)
messages.append({"role": assistant_message.role, "content": assistant_message.content})

In [59]:
results = execute_function_call(assistant_message)

In [60]:
results

"[('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Lost', 92)]"

In [67]:
messages.append({"role": "function", "tool_call_id": assistant_message.tool_calls[0].id, 
                 "name": assistant_message.tool_calls[0].function.name, "content": results})

In [68]:
messages

[{'role': 'system',
  'content': 'Answer user questions by generating SQL queries against the Chinook Music Database.'},
 {'role': 'user',
  'content': 'Hi, who are the top 5 artists by number of tracks?'},
 {'role': 'assistant',
  'content': 'Function(arguments=\'{\\n  "query": "SELECT Artist.Name, COUNT(Track.TrackId) as TrackCount FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Artist.ArtistId ORDER BY TrackCount DESC LIMIT 5"\\n}\', name=\'ask_database\')'},
 {'role': 'function',
  'tool_call_id': 'call_fNHhzq0E8t8VJ4Wa2YUw1S8F',
  'name': 'ask_database',
  'content': "[('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Lost', 92)]"}]

In [None]:
messages.append({"role": "user", "content": "What is the name of the album with the most tracks?"})
chat_response = chat_completion_request(messages, tools)
assistant_message = chat_response.choices[0].message
assistant_message.content = str(assistant_message.tool_calls[0].function)
messages.append({"role": assistant_message.role, "content": assistant_message.content})
# if assistant_message.tool_calls:
#     results = execute_function_call(assistant_message)
#     messages.append({"role": "function", "tool_call_id": assistant_message.tool_calls[0].id, "name": assistant_message.tool_calls[0].function.name, "content": results})
# pretty_print_conversation(messages)

msg = [{"role":"function","content":f"{assistant_message}"}]

