In [1]:
from langchain_community.vectorstores.pgvector import PGVector
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.document_loaders.text import TextLoader
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from dotenv import load_dotenv
import os

# Load environment variables (e.g. the OpenAI API key) from app/.env.
app_dir = os.path.join(os.getcwd(), "app")
load_dotenv(os.path.join(app_dir, ".env"))


# NOTE(review): local-dev credentials are hardcoded here; move them into
# app/.env (or another secret store) before sharing this notebook.
DATABASE_URL = "postgresql+psycopg://admin:admin@localhost:5432/vectordb"

embeddings = OpenAIEmbeddings()

# pre_delete_collection=True drops the "vectordb" collection before indexing,
# so re-running this cell does not insert every chunk a second time (duplicate
# chunks would otherwise crowd the retriever's top-k results).
store = PGVector(
    collection_name="vectordb",
    connection_string=DATABASE_URL,
    embedding_function=embeddings,
    pre_delete_collection=True,
)
restaurant_loader = TextLoader("./data/restaurant.txt")
founder_loader = TextLoader("./data/founder.txt")

# Same ordering as before: founder documents first, then restaurant documents.
docs = founder_loader.load() + restaurant_loader.load()

splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=20)
chunks = splitter.split_documents(docs)
store.add_documents(chunks)
retriever = store.as_retriever()

  store = PGVector(
  store = PGVector(


In [2]:
from operator import itemgetter

# Dedicated names so the SQL/response cells below don't silently overwrite
# this chain's prompt objects.
rag_template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
rag_prompt = ChatPromptTemplate.from_template(rag_template)
model = ChatOpenAI(model="gpt-4o-mini")


def _format_docs(documents):
    """Join the retrieved documents' text into a single context string.

    Without this step the prompt receives ``str(list_of_documents)``, i.e. the
    Python repr of every ``Document`` including its metadata, instead of the text.
    """
    return "\n\n".join(document.page_content for document in documents)


rag_chain = (
    {
        "context": itemgetter("question") | retriever | _format_docs,
        "question": itemgetter("question"),
    }
    | rag_prompt
    | model
    | StrOutputParser()
)

In [3]:
# RAG demo: the answer must come only from the indexed restaurant/founder text files.
rag_chain.invoke({"question": "Who is the owner of the restaurant?"})

'The owner of the restaurant is Chef Amico.'

In [43]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities.sql_database import SQLDatabase

# Read-only role: if the model ever generates a destructive statement, the
# database itself is expected to refuse it (see "What can go wrong?" below).
CONNECTION_STRING = "postgresql+psycopg://readonlyuser:readonlypassword@localhost:5432/vectordb"

# Prompt for the SQL-generation step; {schema} and {question} are filled in by the chain.
template = """Based on the table schema below, write a SQL query that would answer the user's question:

Rules (must follow ALL):
- Output ONLY the SQL query text.
- No Markdown, no triple backticks, no labels, no explanations.
- Single line, no line breaks.
- Must start with SELECT or WITH and end with a semicolon.

Schema: {schema}
Question: {question}
SQL Query:

Return ONLY the SQL query."""
prompt = ChatPromptTemplate.from_template(template)

# SQLDatabase wrapper used both for schema introspection and for running queries.
db = SQLDatabase.from_uri(CONNECTION_STRING)


def get_schema(_):
    """Return the table DDL plus sample rows that SQLDatabase reports for ``db``.

    The single positional argument is ignored; it exists so the function can be
    used as a runnable step that receives the chain input.
    """
    return db.get_table_info()


def run_query(query):
    """Execute LLM-generated SQL against the read-only database.

    The model is prompted to return bare SQL, but chat models still sometimes
    wrap it in a Markdown fence (```sql ... ```), which the database rejects as
    a syntax error. Surrounding whitespace and such a fence are stripped first.

    Args:
        query: SQL text produced by the model.

    Returns:
        Whatever ``db.run`` returns for the cleaned query.
    """
    import re

    # Remove an optional leading ``` / ```sql fence and a trailing ``` fence.
    cleaned_query = re.sub(
        r"^```(?:sql)?\s*|\s*```$", "", query.strip(), flags=re.IGNORECASE
    )
    return db.run(cleaned_query)

In [35]:
# Inspect the schema text from the db.get_table_info()-based get_schema: it also
# exposes the vector-store tables, which is why a products-only version follows.
print(get_schema("_"))


CREATE TABLE docstore (
	key VARCHAR NOT NULL, 
	value JSONB, 
	CONSTRAINT docstore_pkey PRIMARY KEY (key)
)

/*
3 rows from docstore table:
key	value

*/


CREATE TABLE langchain_pg_collection (
	name VARCHAR, 
	cmetadata JSON, 
	uuid UUID NOT NULL, 
	CONSTRAINT langchain_pg_collection_pkey PRIMARY KEY (uuid)
)

/*
3 rows from langchain_pg_collection table:
name	cmetadata	uuid
vectordb	None	3d45f74b-bc32-4d37-8c11-42ceb699fc0f
*/


CREATE TABLE langchain_pg_embedding (
	collection_id UUID, 
	embedding VECTOR, 
	document VARCHAR, 
	cmetadata JSON, 
	custom_id VARCHAR, 
	uuid UUID NOT NULL, 
	CONSTRAINT langchain_pg_embedding_pkey PRIMARY KEY (uuid), 
	CONSTRAINT langchain_pg_embedding_collection_id_fkey FOREIGN KEY(collection_id) REFERENCES langchain_pg_collection (uuid) ON DELETE CASCADE
)

/*
3 rows from langchain_pg_embedding table:
collection_id	embedding	document	cmetadata	custom_id	uuid
3d45f74b-bc32-4d37-8c11-42ceb699fc0f	[-0.00013965 -0.0140749   0.01334101 ... -0.00639858 -0.

In [36]:
from sqlalchemy import create_engine, inspect
from tabulate import tabulate


def get_schema(_):
    """Describe the ``products`` table as a text grid for the SQL-generation prompt.

    Args:
        _: Ignored. Present because the function is used as a runnable step
           (``RunnablePassthrough.assign(schema=get_schema)``) and receives the
           chain input.

    Returns:
        A header line followed by a ``tabulate`` grid with one row per column:
        name, type, nullability, default and autoincrement flag.
    """
    engine = create_engine(CONNECTION_STRING)
    try:
        columns = inspect(engine).get_columns("products")
    finally:
        # A fresh engine (with its own connection pool) is built on every call,
        # and sql_chain calls this twice per question, so release it right away
        # instead of leaving open pools to the garbage collector.
        engine.dispose()

    column_data = [
        {
            "Column Name": col["name"],
            "Data Type": str(col["type"]),
            "Nullable": "Yes" if col["nullable"] else "No",
            "Default": col["default"] if col["default"] else "None",
            # "autoincrement" is optional in reflected column info; missing -> "No".
            "Autoincrement": "Yes" if col.get("autoincrement") else "No",
        }
        for col in columns
    ]
    schema_output = tabulate(column_data, headers="keys", tablefmt="grid")
    return f"Schema for 'PRODUCTS' table:\n{schema_output}"

In [37]:
# Check the products-only schema produced by the redefined get_schema.
print(get_schema("_"))

Schema for 'PRODUCTS' table:
+---------------+----------------+------------+--------------------------------------+-----------------+
| Column Name   | Data Type      | Nullable   | Default                              | Autoincrement   |
| id            | INTEGER        | No         | nextval('products_id_seq'::regclass) | Yes             |
+---------------+----------------+------------+--------------------------------------+-----------------+
| name          | VARCHAR(100)   | Yes        | None                                 | No              |
+---------------+----------------+------------+--------------------------------------+-----------------+
| price         | NUMERIC(10, 2) | Yes        | None                                 | No              |
+---------------+----------------+------------+--------------------------------------+-----------------+
| description   | TEXT           | Yes        | None                                 | No              |
+---------------+---------

In [38]:
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-4o-mini")

# Text-to-SQL generator: add the schema to the input dict, fill the SQL prompt,
# ask the model (generation stops if it starts a "\nSQLResult:" section) and
# return the raw model text as the query.
sql_response = RunnablePassthrough.assign(schema=get_schema).pipe(
    prompt,
    model.bind(stop=["\nSQLResult:"]),
    StrOutputParser(),
)

sql_response.invoke({"question": "Whats the most expensive dessert you offer?"})

"SELECT name, price FROM PRODUCTS WHERE category = 'dessert' ORDER BY price DESC LIMIT 1;"

In [40]:
from langchain_core.runnables import RunnableLambda

# Second-stage prompt: turns the schema, the user question, the generated SQL and
# the raw result returned by run_query into a natural-language answer.
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)


def debug(input):
    """Print the generated SQL and hand the chain state on unchanged.

    Used as a pass-through ``RunnableLambda`` step so the query produced by the
    model is visible before the final answer is generated.
    """
    sql_text = input["query"]
    print("SQL Output: ", sql_text)
    return input


# Full text-to-SQL pipeline:
#   1. generate the query with sql_response,
#   2. fetch the schema and execute the query (both assigned in one step),
#   3. print the SQL, then phrase the result as natural language.
sql_chain = (
    RunnablePassthrough.assign(query=sql_response)
    .assign(
        schema=get_schema,
        response=lambda state: run_query(state["query"]),
    )
    .pipe(RunnableLambda(debug), prompt_response, model, StrOutputParser())
)

In [41]:
# End-to-end text-to-SQL: generate SQL, run it, and phrase the result (debug prints the SQL).
sql_chain.invoke({"question": "Whats the most expensive dessert you offer?"})

SQL Output:  SELECT name, price FROM PRODUCTS WHERE category = 'dessert' ORDER BY price DESC LIMIT 1;


'The most expensive dessert we offer is the panettone, which is priced at $15.00.'

### What can go wrong? Users could run potentially malicious queries

In [44]:
# Expected to fail: the connection uses a read-only user, so the database rejects
# the destructive statement with InsufficientPrivilege (handled in the next cell).
sql_chain.invoke({"question": "Drop all products from the products table"})

ProgrammingError: (psycopg.errors.InsufficientPrivilege) permission denied for table products
[SQL: DELETE FROM PRODUCTS;]
(Background on this error at: https://sqlalche.me/e/20/f405)

In [49]:
from sqlalchemy.exc import DBAPIError
from psycopg.errors import InsufficientPrivilege

# Repeat the destructive request, but turn a permission failure into a friendly
# message instead of a traceback.
try:
    result = sql_chain.invoke({"question": "Drop all products from the products table"})
except DBAPIError as err:
    # err.orig is the driver-level exception wrapped by SQLAlchemy.
    result = (
        "Haha nice try! Got ya!"
        if isinstance(err.orig, InsufficientPrivilege)
        else "An unexpected error occurred"
    )
except Exception:
    result = "An unexpected error occurred"

print(result)

Haha nice try! Got ya!


### Routing

In [50]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate


# Router prompt. `route` matches the returned label case-insensitively by
# substring: "database" -> sql_chain, "chat" -> rag_chain, anything else -> refusal.
classification_template = PromptTemplate.from_template(
    """You are good at classifying a question.
    Given the user question below, classify it as either being about `Database`, `Chat` or `Offtopic`.

    <If the question is about products of the restaurant OR ordering food, classify the question as 'Database'>
    <If the question is about restaurant related topics like opening hours and similar topics, classify it as 'Chat'>
    <If the question is about weather, football or anything not related to the restaurant or
    products, classify the question as 'Offtopic'>

    <question>
    {question}
    </question>

    Classification:"""
)

classification_chain = classification_template | ChatOpenAI(model="gpt-4o-mini") | StrOutputParser()

In [51]:
# An off-topic question; the recorded output shows it is labelled 'Offtopic'.
classification_chain.invoke({"question": "How is the wheather?"})

'Offtopic'

In [52]:
def route(info):
    """Pick the downstream handler for a classified question.

    Args:
        info: dict with a "topic" label (classifier output) and the "question".

    Returns:
        ``sql_chain`` when the label mentions "database", ``rag_chain`` when it
        mentions "chat" (checked in that order, case-insensitively), otherwise a
        fixed refusal string.
    """
    topic = info["topic"].lower()
    if "database" in topic:
        return sql_chain
    if "chat" in topic:
        return rag_chain
    return "I am sorry, I am not allowed to answer about this topic."

In [53]:
from langchain_core.runnables import RunnableLambda, RunnableParallel

# Classify the question while carrying the original text alongside the label,
# then let `route` choose the sub-chain (or the canned refusal) for that label.
full_chain = RunnableParallel(
    topic=classification_chain,
    question=lambda inputs: inputs["question"],
).pipe(RunnableLambda(route))

In [54]:
# Database route: a product/price question goes to sql_chain (note the "SQL Output" debug print).
full_chain.invoke({"question": "Whats the most expensive dessert you offer?"})

SQL Output:  SELECT name, price FROM PRODUCTS WHERE category = 'dessert' ORDER BY price DESC LIMIT 1;


'The most expensive dessert we offer is [Dessert Name], priced at [Price].'

In [55]:
# Off-topic route: route() returns the canned refusal string instead of a chain.
full_chain.invoke({"question": "How will the wheater be tomorrow?"})

'I am sorry, I am not allowed to answer about this topic.'

In [56]:
# Chat route: answered by rag_chain from the indexed text files.
full_chain.invoke({"question": "Who is the owner of the restaurant?"})

'The owner of the restaurant is Chef Amico.'