This notebook aims at at using langchain AI agent with a postgreSQL database

In [1]:
from sqlalchemy import create_engine
from langchain_community.utilities.sql_database import SQLDatabase
from protected import pwd_db, grocq_api_key
import psycopg2

connection_params = {
    'dbname': 'pagila',
    'user': 'postgres',
    'password': pwd_db,
    'host': 'localhost',
    'port': 5432
    }
# Connect to the database
conn = psycopg2.connect(**connection_params)
cursor = conn.cursor()

conn = psycopg2.connect(**connection_params)
print("Direct psycopg2 connection successful.")

# Wrap psycopg2 connection with SQLAlchemy engine
engine = create_engine('postgresql+psycopg2://', creator=lambda: conn)
print("SQLAlchemy engine created successfully.")

# Use the engine in SQLDatabase
sql_database = SQLDatabase(engine)
print("SQLDatabase created successfully.")

# Example usage
tables = sql_database.get_table_names()
print("Tables in the database:", tables)

Direct psycopg2 connection successful.
SQLAlchemy engine created successfully.
SQLDatabase created successfully.
Tables in the database: ['actor', 'address', 'category', 'city', 'country', 'customer', 'film', 'film_actor', 'film_category', 'generated_views', 'inventory', 'language', 'payment', 'payment_p2007_01', 'payment_p2007_02', 'payment_p2007_03', 'payment_p2007_04', 'payment_p2007_05', 'payment_p2007_06', 'queries_log', 'rental', 'staff', 'store']


  tables = sql_database.get_table_names()


In [2]:
def get_logging_context(
):
    """
    Connect to PostgreSQL and retrieve table structures and some rows
    from queries_log and generated_views for an LLM 'logging context'.
    """
    connection_params = {
    'dbname': 'pagila',
    'user': 'postgres',
    'password': pwd_db,
    'host': 'localhost',
    'port': 5432
    }
    # Connect to the database
    conn = psycopg2.connect(**connection_params)
    cursor = conn.cursor()

    # Helper: fetch columns from information_schema
    def fetch_table_structure(table_name):
        cursor.execute("""
            SELECT column_name, data_type, is_nullable
            FROM information_schema.columns
            WHERE table_name = %s
            ORDER BY ordinal_position
        """, (table_name,))
        columns = cursor.fetchall()
        # Build a small text block describing each column
        lines = []
        for col_name, data_type, is_nullable in columns:
            lines.append(f"  - {col_name} {data_type} {'(nullable)' if is_nullable == 'YES' else '(not null)'}")
        structure_str = "Columns:\n" + "\n".join(lines)
        return structure_str

    # Helper: fetch sample rows from each table
    def fetch_sample_rows(table_name):
        cursor.execute(f"SELECT * FROM {table_name} ORDER BY created_at")
        rows = cursor.fetchall()
        # If you want column names, re-fetch from cursor.description:
        col_names = [desc.name for desc in cursor.description]

        lines = []
        for row in rows:
            # Zip column names with the row values
            row_data = ", ".join(f"{col}: {val}" for col, val in zip(col_names, row))
            lines.append("    " + row_data)
        if not lines:
            lines = ["    No rows found."]
        return "\n".join(lines)

    # Build the logging context
    logging_context_lines = ["-- LOGGING SCHEMA METADATA --"]

    for table_name in ["queries_log", "generated_views"]:
        logging_context_lines.append(f"Table: {table_name}")
        # structure
        table_structure = fetch_table_structure(table_name)
        logging_context_lines.append(table_structure)
        # sample rows
        sample_rows = fetch_sample_rows(table_name)
        logging_context_lines.append("Sample rows (up to 5):\n" + sample_rows)
        logging_context_lines.append("")  # blank line

    # Combine into a single text block
    logging_context = "\n".join(logging_context_lines)

    # Clean up
    cursor.close()
    conn.close()

    return logging_context

In [11]:
def prepare_agent_prompt(input_text, logging_context):
    agent_prompt = f"""
    Query the database using PostgreSQL syntax.
    
    1) Generate a PostgreSQL query using the input: {input_text}.     
    2) a SQL statements to log this query into the tables queries_log and generated_views based on the logging schema:
           - An INSERT INTO queries_log, with:
               * question: the exact user question\n"
               * description: a short description of what the query does (NOT NULL)
               * sql_text: the exact SQL query from (1)
               * created_at: use NOW() or CURRENT_TIMESTAMP
             (Do not set them to NULL. Provide real values.)
           - an INSERT INTO generated_views referencing that new view. Be aware of the columns and dependencies of those 2 tables.

    Do not run any query.
    Answer needs to be in the format of a JSON object.
    There must be as many keys as there are sql queries. The value of each key must be a runnable PostgreSQL query.

    Here is the logging context for the log queries :
    {logging_context}
    """

    return agent_prompt

In [12]:
import os
from langchain_community.agent_toolkits import create_sql_agent
from utils.groq_class import GroqLLM, GroqOutputParser
from langchain_ollama import OllamaLLM

# os.environ["GROQ_API_KEY"] = grocq_api_key
# # Set up your Groq API key and model
# api_key = os.environ.get("GROQ_API_KEY")
# model_name = "llama3-groq-8b-8192-tool-use-preview"

# # Initialize the custom LLM and parser
# llm = GroqLLM(api_key=api_key, model=model_name)
# parser = GroqOutputParser()

llm = OllamaLLM(model="llama3")

format_prompt = """
                Answer needs to be in the format of a JSON object.
                There must be as many keys as there are sql queries. The value of each key must be a runnable PostgreSQL query
                """



# Create the SQL agent
agent_executor = create_sql_agent(llm, db=sql_database, agent_type="zero-shot-react-description", verbose=True
                                #   , format_instructions  = format_prompt
                                  , agent_executor_kwargs={"handle_parsing_errors": True}
                                  )

question = "Get me all films by category"

logging_context = get_logging_context()
agent_prompt = prepare_agent_prompt(question, logging_context)

# Invoke the agent
response = agent_executor.invoke(agent_prompt)
print(response)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mI'll start by getting a list of tables and then looking at the schema for relevant tables.

Action: sql_db_list_tables
Action Input: empty string[0m[38;5;200m[1;3mactor, address, category, city, country, customer, film, film_actor, film_category, generated_views, inventory, language, payment, payment_p2007_01, payment_p2007_02, payment_p2007_03, payment_p2007_04, payment_p2007_05, payment_p2007_06, queries_log, rental, staff, store[0m[32;1m[1;3mThought: Now that I have a list of tables, I can focus on the schema for relevant tables. Since the question is about getting all films by category, I'll look at the schema for the film and category tables.

Action: sql_db_schema
Action Input: film, category[0m[33;1m[1;3m
CREATE TABLE category (
	category_id SERIAL NOT NULL, 
	name VARCHAR(25) NOT NULL, 
	last_update TIMESTAMP WITHOUT TIME ZONE DEFAULT now() NOT NULL, 
	CONSTRAINT category_pkey PRIMARY KEY (category_id)
)

  metadata_table_names = [tbl.name for tbl in self._metadata.sorted_tables]
  for tbl in self._metadata.sorted_tables


[32;1m[1;3mBased on the input question "Create a query to get all films by category", I will create a SQL query that joins the film and category tables to retrieve the desired information.

Here is the SQL query:
```sql
SELECT f.title, c.name 
FROM film f 
JOIN film_category fc ON f.film_id = fc.film_id 
JOIN category c ON fc.category_id = c.category_id;
```
This query joins the film table with the film_category table on the film_id column, and then joins the film_category table with the category table on the category_id column. This will return a list of films along with their corresponding categories.

To log this interaction in the queries_log table:
```sql
INSERT INTO queries_log (query_text) VALUES ('SELECT f.title, c.name FROM film f JOIN film_category fc ON f.film_id = fc.film_id JOIN category c ON fc.category_id = c.category_id;');
```
And to create a view for this query in the generated_views table:
```sql
INSERT INTO generated_views (view_name, query_text) VALUES ('films_by

In [None]:
# response

NameError: name 'response' is not defined

### SQLDatabaseChain

In [9]:
from langchain_experimental.sql import SQLDatabaseChain

# Initialize the LLM
llm = OllamaLLM(model="llama3")

# Create the SQL Database Chain
db_chain = SQLDatabaseChain.from_llm(
    llm=llm,
    db=sql_database,
    verbose=True,
    use_query_checker=True,
    return_intermediate_steps=True
)

In [10]:
question = "Get me all films by category"

logging_context = get_logging_context()
agent_prompt = prepare_agent_prompt(question, logging_context)

result = db_chain.invoke(agent_prompt)



[1m> Entering new SQLDatabaseChain chain...[0m

    Query the database using PostgreSQL syntax.
    
    1) Generate a PostgreSQL query using the input: Get me all films by category.     
    2) a SQL statements to log this query into the tables queries_log and generated_views based on the logging schema: -- LOGGING SCHEMA METADATA --
Table: queries_log
Columns:
  - query_id integer (not null)
  - question text (not null)
  - description text (nullable)
  - sql_text text (not null)
  - created_at timestamp without time zone (not null)
Sample rows (up to 5):
    query_id: 3, question: Create a list of all the actorsÃ¢â‚¬â„¢ first name and last name. Display the first and last name of each actor in a single column in upper case letters. Name the column Actor Name., description: Retrieves the names of all actors, formatting them as "Actor Name"., sql_text: SELECT UPPER(CONCAT(staff.first_name, ' , staff.last_name)) AS 'Actor Name' FROM staff;, created_at: 2024-12-24 00:45:03.197281
  

  metadata_table_names = [tbl.name for tbl in self._metadata.sorted_tables]
  for tbl in self._metadata.sorted_tables


[32;1m[1;3m```
SELECT UPPER(CONCAT(staff.first_name, ' , staff.last_name)) AS "Actor Name" FROM staff;
```[0m

ProgrammingError: (psycopg2.errors.SyntaxError) ERREUR:  erreur de syntaxe sur ou près de « ``` »
LINE 1: ```
        ^

[SQL: ```
SELECT UPPER(CONCAT(staff.first_name, ' , staff.last_name)) AS "Actor Name" FROM staff;
```]
(Background on this error at: https://sqlalche.me/e/20/f405)