In [44]:
import os
from urllib.parse import quote_plus

from langchain.prompts import PromptTemplate
from langchain.sql_database import SQLDatabase
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
from sqlalchemy import create_engine

from langchain_core.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
    AIMessagePromptTemplate,
    MessagesPlaceholder,
)

# Load environment variables from a .env file if needed
from dotenv import load_dotenv
load_dotenv()

# Fail fast when no OpenAI key is configured: the value is read from the
# process environment and an unset or empty value is rejected.
openai_api_key = os.environ.get('OPENAI_API_KEY')

if openai_api_key in (None, ''):
    raise ValueError("OPENAI_API_KEY is not set in the environment variables")

# Define your connection string.
# NOTE(review): the credentials are hard-coded; consider reading them from the
# environment the same way OPENAI_API_KEY is read.
username = 'postgres'
password = '<password>'
host = 'localhost'  # or your actual host
port = '5432'  # or your actual port if different
database = 'chinook'

# Percent-encode the user name and password before embedding them in the URL:
# characters such as '@', ':', '/' or '#' would otherwise be read as URL
# delimiters and yield a wrong host, port or password.
connection_string = (
    f'postgresql+psycopg2://{quote_plus(username)}:{quote_plus(password)}'
    f'@{host}:{port}/{database}'
)
# Database side: SQLAlchemy engine built from the connection string, wrapped
# in the SQLDatabase object that is later handed to the agent.
engine = create_engine(connection_string)
sql_database = SQLDatabase(engine)

# Initialize the OpenAI LLM (temperature 0, model 'gpt-4o-mini')
llm = ChatOpenAI(
    model='gpt-4o-mini',
    temperature=0,
    api_key=openai_api_key,
)

# System instructions for the agent. "{dialect}" and "{top_k}" are template
# variables that are left unfilled here.
prefix = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
If you need to group by any time unit (month, year, day of week, etc.) please use the extract() function and not the date_trunc() function.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.
"""

# Text of the AI turn that follows the user's question in the prompt.
suffix = """I should look at the tables in the database to see what I can query.  Then I should query the schema of the most relevant tables.
"""

# Prompt layout, in order: system instructions, the user's "{input}", the AI
# turn above, then a placeholder for the "agent_scratchpad" messages.
messages = [
    template_cls.from_template(text)
    for template_cls, text in (
        (SystemMessagePromptTemplate, prefix),
        (HumanMessagePromptTemplate, "{input}"),
        (AIMessagePromptTemplate, suffix),
    )
]
messages.append(MessagesPlaceholder(variable_name="agent_scratchpad"))

prompt = ChatPromptTemplate.from_messages(messages)

# Create the SQL agent.
# top_k must be an integer: the prompt contains "at most {top_k} results", and
# passing top_k=None would make that instruction read "at most None results".
# NOTE(review): 100 is a generous cap chosen so that grouped results such as
# monthly totals are not truncated; adjust it to the row limit you want.
sql_agent = create_sql_agent(
    llm=llm,
    db=sql_database,
    prompt=prompt,
    agent_type='openai-tools',
    top_k=100,
    verbose=True,
)

# Left as a bare expression so that, in a notebook, the returned result is
# displayed as the cell output.
sql_agent.invoke({"input": "what are the total monthly sales?"})




[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3malbum, artist, customer, employee, genre, invoice, invoice_line, media_type, playlist, playlist_track, track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'invoice'}`


[0m[33;1m[1;3m
CREATE TABLE invoice (
	invoice_id INTEGER NOT NULL, 
	customer_id INTEGER NOT NULL, 
	invoice_date TIMESTAMP WITHOUT TIME ZONE NOT NULL, 
	billing_address VARCHAR(70), 
	billing_city VARCHAR(40), 
	billing_state VARCHAR(40), 
	billing_country VARCHAR(40), 
	billing_postal_code VARCHAR(10), 
	total NUMERIC(10, 2) NOT NULL, 
	CONSTRAINT invoice_pkey PRIMARY KEY (invoice_id), 
	CONSTRAINT invoice_customer_id_fkey FOREIGN KEY(customer_id) REFERENCES customer (customer_id)
)

/*
3 rows from invoice table:
invoice_id	customer_id	invoice_date	billing_address	billing_city	billing_state	billing_country	billing_postal_code	total
1	2	2021-01-01 00:00:00	Theodor-Heu

{'input': 'what are the total monthly sales?',
 'output': 'The total monthly sales are as follows:\n\n- January: $201.12\n- February: $187.20\n- March: $195.10\n- April: $198.14\n- May: $193.10\n- June: $201.10\n- July: $190.10\n- August: $198.10\n- September: $196.20\n- October: $193.10\n- November: $186.24\n- December: $189.10'}

IndexError: list index out of range