In [1]:
import datetime
import os
import sqlite3

from smolagents import tool, CodeAgent, LiteLLMModel

# LLM backend handed to the agent. The API key is read from the environment
# so it never appears in the notebook itself.
model = LiteLLMModel(
    model_id='gpt-4o',
    api_key=os.getenv("OPENAI_API_KEY"),
    timeout=60,  # author's note: the account is still on tier 1
)

# Date stamp in YYYYMMDD form, used further down to name the output folder.
today = datetime.datetime.today().strftime('%Y%m%d')

# Single connection shared by the schema helpers and the SQL tool.
conn = sqlite3.connect("Chinook.db")

In [2]:
def get_table_names(conn):
    """Return a list of table names.

    Args:
        conn: An open sqlite3 connection (any object whose ``execute(sql)``
            returns a cursor with ``fetchall()`` works).

    Returns:
        A list with the ``name`` of every row in ``sqlite_master`` whose
        type is ``'table'``, in the order the query yields them.
    """
    # sqlite_master has one row per schema object; the WHERE clause keeps
    # tables only (indexes, views and triggers are filtered out).
    rows = conn.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
    return [row[0] for row in rows]

def get_column_metadata(conn, table_name):
    """Return a list of dictionaries containing metadata for each column.

    Args:
        conn: An open sqlite3 connection.
        table_name: Name of the table to inspect.

    Returns:
        One dict per column, in ``PRAGMA table_info`` order, with the keys
        ``name``, ``type``, ``is_primary_key``, ``is_foreign_key``,
        ``references_table`` and ``references_column``. The two
        ``references_*`` values stay ``None`` for columns that are not
        foreign keys. An unknown table yields an empty list.
    """
    # PRAGMA statements do not take bound parameters, so the table name is
    # embedded as a SQL string literal with any single quote doubled.
    quoted_table = "'" + str(table_name).replace("'", "''") + "'"

    # table_info row layout: (cid, name, type, notnull, dflt_value, pk)
    column_metadata = [
        {
            "name": col[1],
            "type": col[2],
            # pk is 0 for non-key columns and the 1-based position inside the
            # key otherwise, so every positive value marks a key column. A
            # test against exactly 1 would miss the later columns of a
            # composite primary key.
            "is_primary_key": col[5] > 0,
            "is_foreign_key": False,
            "references_table": None,
            "references_column": None,
        }
        for col in conn.execute(f"PRAGMA table_info({quoted_table});").fetchall()
    ]

    # foreign_key_list row layout: (id, seq, table, from, to, ...)
    by_name = {meta["name"]: meta for meta in column_metadata}
    for fk in conn.execute(f"PRAGMA foreign_key_list({quoted_table});").fetchall():
        meta = by_name.get(fk[3])  # fk[3] is the column in the current table
        if meta is not None:
            meta["is_foreign_key"] = True
            meta["references_table"] = fk[2]
            meta["references_column"] = fk[4]

    return column_metadata

def get_database_info(conn):
    """Describe every table reachable through ``conn``.

    Returns a list with one dict per table: the table's name is stored under
    ``"table_name"`` and the column metadata produced by
    ``get_column_metadata`` under ``"columns"``.
    """
    return [
        {"table_name": name, "columns": get_column_metadata(conn, name)}
        for name in get_table_names(conn)
    ]
    

def _describe_column(col):
    """Render one column-metadata dict as a single line of the schema text."""
    reference = ""
    if col["is_foreign_key"]:
        reference = "-> " + col["references_table"]
        # The referenced column can be None when a foreign key is declared
        # without naming a parent column (bare `REFERENCES parent`);
        # concatenating None would raise TypeError, so only the table is shown.
        if col["references_column"] is not None:
            reference += "." + col["references_column"]
    return (
        f"  - Column: {col['name']} ({col['type']}), "
        f"Primary Key: {col['is_primary_key']}, "
        f"Foreign Key: {col['is_foreign_key']} "
        f"({reference})"
    )


# Structured schema (list of table dicts) and its plain-text rendering, one
# "Table:" header per table followed by one line per column.
database_schema_dict = get_database_info(conn)
database_schema_string = "\n".join(
    f"Table: {table['table_name']}\n"
    + "\n".join(_describe_column(col) for col in table["columns"])
    for table in database_schema_dict
)

In [6]:
# NOTE(review): smolagents' @tool is documented to build the tool description
# from the signature and docstring, so the docstring wording is left untouched.
@tool
def run_sql(query: str) -> str:
    """Allows you to perform SQL queries on the table. 
    Returns a string representation of the result.
    
    Args:
        query: A query to perform. This should be correct SQL.    
    """
    try:
        rows = conn.execute(query).fetchall()
        return str(rows)
    except Exception as e:
        # Hand the failure back as text instead of raising, so the caller
        # receives the error message as the tool's result.
        return f"query failed with error: {e}"

# Build the agent around the LiteLLM model with `run_sql` as its only tool.
# "matplotlib" is passed as an additional authorized import and the run is
# capped at 20 steps.
agent = CodeAgent(
    model=model,
    tools=[run_sql],
    additional_authorized_imports=["matplotlib"],
    max_steps=20,
)

# Per-day output folder; create it on first use and just report it otherwise.
save_path = f'./output_{today}'
if os.path.exists(save_path):
    print(f"path {save_path} already exists")
else:
    os.makedirs(save_path)

# Text placed in front of the agent's existing system prompt: the rendered
# database schema followed by blank lines.
prepend_sys_prompt = f"""Available SQL Database information including tables and schemas with data type, primary key and foreign key:
{database_schema_string}
\n\n"""

# Text placed after the existing system prompt: where to save artifacts and
# what to return once the task is done. The trailing backslashes join the
# source lines so the string itself contains no line breaks at those points.
append_sys_prompt = f"""\nPerform all tasks as instructed. Save all codes and outputs under '{save_path}'. When you are finished \
with the task and generated all the necessary response and there is nothing else left for final_answer, \
then simply return 'Task completed'."""

# Wrap the agent's current system prompt template with the two pieces above.
# `agent` is re-created earlier in this same cell, so re-running the cell does
# not stack the additions.
agent.prompt_templates['system_prompt'] = prepend_sys_prompt + agent.prompt_templates['system_prompt'] + append_sys_prompt

In [4]:
# Natural-language instruction handed to the agent.
# NOTE(review): 'saves_vs_genre.png' looks like a typo for 'sales_vs_genre.png' — confirm before changing.
task = "Find the total sales amount for each music genre and create a bar chart and save it as 'saves_vs_genre.png'"

In [5]:
# Run the agent on the task and keep whatever `agent.run` returns.
result = agent.run(task)


Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.

