In [14]:
import sys
from pathlib import Path

# Make the parent of the current working directory (the notebook's folder when
# Jupyter is launched normally) importable, so project packages such as
# `sql_alchemy` resolve. Use an absolute path so a later os.chdir() cannot break
# imports, and skip the append on re-runs to avoid piling up duplicate entries.
PROJECT_ROOT = str(Path("..").resolve())
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [15]:
from langchain_groq import ChatGroq
from dotenv import load_dotenv
import os

# Pull variables from a local .env file (if present) into os.environ.
load_dotenv()

GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    # Fail fast here instead of on the first LLM call deep inside an agent run.
    raise RuntimeError("GROQ_API_KEY is not set; add it to the environment or to a .env file.")

# temperature=0 to make SQL generation as repeatable as the model allows.
llm = ChatGroq(
    groq_api_key=GROQ_API_KEY,
    model_name="openai/gpt-oss-120b",
    temperature=0,
    # ChatGroq has no dedicated `top_p` field: passed as a top-level kwarg it is
    # moved into model_kwargs with a warning, so pass it there explicitly.
    model_kwargs={"top_p": 1},
)

                    top_p was transferred to model_kwargs.
                    Please confirm that top_p is what you intended.
  validated_self = self.__pydantic_validator__.validate_python(data, self_instance=self)


In [16]:
from langchain_community.utilities import SQLDatabase

from sql_alchemy.engine import engine

# LangChain wrapper around the project's SQLAlchemy engine; both SQL agents
# below reach the database only through this object (via `toolkit`).
db = SQLDatabase(engine=engine)

# Show the SQL dialect name; it is also interpolated into the Way 2 system prompt.
db.dialect

Connection successful!


  self._metadata.reflect(


'postgresql'

In [17]:
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

# Bundle the database handle and the LLM into LangChain's SQL toolset; both
# agents below take their tools from this single `toolkit`.
toolkit = SQLDatabaseToolkit(llm=llm, db=db)

## Way 1: legacy shortcut (`create_sql_agent`)

In [6]:
from langchain_community.agent_toolkits.sql.base import create_sql_agent

# Legacy one-call constructor: wraps the toolkit's tools in an agent executor
# (the "SQL Agent Executor" chain visible in the verbose trace).
sql_agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
)

# Dedicated name on purpose: the RAG section defines its own `question`, which
# the Way 2 agent consumes. Reusing `question` here would let a re-run of this
# cell silently swap the question sent to Way 2 while its RAG context still
# matches the other question.
way1_question = "How many tshirts do we have left for nike in extra small size and white color?"

response = sql_agent.invoke({"input": way1_question})

print("\nFinal Answer:")
print(response["output"])



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m<think>
Okay, let's tackle this question. The user is asking how many t-shirts they have left for Nike in extra small size and white color. 

First, I need to figure out which tables in the database are relevant. Since they mentioned t-shirts, Nike, size, and color, the most likely table is something like a products or inventory table. But to be sure, I should list all the tables in the database.

I'll start by using the sql_db_list_tables tool. That will give me a list of all tables. Let's see what tables exist. Maybe there's a table named 'products', 'inventory', 'tshirts', or similar.

Once I have the list of tables, I should check the schema of the most relevant one. For example, if there's a 'products' table, I'll look at its columns to see if it includes brand, size, color, and quantity. The user needs to filter by brand (Nike), size (extra small), color (white), and sum the quantity.

If the schema shows that the 

# Embeddings: store business rules in pgvector for RAG

In [1]:
from sentence_transformers import SentenceTransformer

EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Sanity check: dimensionality of a single embedding (384 per the output).
# Presumably this must match the width of the `embedding` vector column in
# rag_documents — TODO confirm against the table definition.
model.encode("test").shape[0]

  from .autonotebook import tqdm as notebook_tqdm


384

In [2]:
# Hand-written business knowledge for the RAG store. Each (content, category)
# pair becomes one document dict, and every document is tagged source "manual".
rag_docs = [
    {"content": content, "category": category, "source": "manual"}
    for content, category in (
        (
            "Revenue is calculated as price multiplied by stock quantity.",
            "definition",
        ),
        (
            "Discounts are optional. When calculating revenue, products without discounts should be included at full price using a LEFT JOIN.",
            "business_rule",
        ),
        (
            "Post-discount revenue means applying the percentage discount per t_shirt_id before summing total revenue.",
            "business_rule",
        ),
        (
            "Stock quantity represents current inventory and should be summed when calculating total stock.",
            "definition",
        ),
    )
]


In [3]:
# Embed only the document text; category/source travel alongside at insert time.
texts = [entry["content"] for entry in rag_docs]
embeddings = model.encode(texts)

# One row per document, one column per embedding dimension.
n_docs, embedding_dim = embeddings.shape
print(n_docs, embedding_dim)


4 384


In [4]:
import psycopg2
import os

# psycopg2.connect keyword -> environment variable holding its value.
_SUPABASE_ENV = {
    "host": "SUPABASE_DB_HOST",
    "dbname": "SUPABASE_DB_NAME",
    "user": "SUPABASE_DB_USER",
    "password": "SUPABASE_DB_PASSWORD",
    "port": "SUPABASE_DB_PORT",
}

# Fail fast with a clear message: a missing variable would otherwise be passed
# as None, psycopg2/libpq would fall back to their defaults (e.g. a local
# socket), and the resulting connection error would not mention the cause.
_missing = [var for var in _SUPABASE_ENV.values() if not os.getenv(var)]
if _missing:
    raise RuntimeError(f"Missing database settings: {', '.join(_missing)}")

conn = psycopg2.connect(
    **{param: os.getenv(var) for param, var in _SUPABASE_ENV.items()}
)

In [5]:
insert_sql = """
INSERT INTO rag_documents (content, category, source, embedding)
VALUES (%s, %s, %s, %s);
"""

# Contents already stored; used to keep this cell idempotent.
existing_sql = """
SELECT content
FROM rag_documents
WHERE content = ANY(%s);
"""

# `with conn` commits when the block succeeds and rolls back if anything raises,
# so a failed insert cannot leave `conn` stuck in an aborted transaction for
# later cells. It does NOT close the connection.
with conn, conn.cursor() as cur:
    # Skip documents stored by a previous run of this cell: duplicates would
    # crowd the top-k nearest-neighbour results used as RAG context.
    cur.execute(existing_sql, ([doc["content"] for doc in rag_docs],))
    already_stored = {content for (content,) in cur.fetchall()}

    new_rows = [
        (
            doc["content"],
            doc["category"],
            doc["source"],
            emb.tolist(),  # IMPORTANT: psycopg2 cannot adapt numpy arrays; send a plain list
        )
        for doc, emb in zip(rag_docs, embeddings)
        if doc["content"] not in already_stored
    ]
    cur.executemany(insert_sql, new_rows)

print(f"Inserted {len(new_rows)} new document(s); {len(rag_docs) - len(new_rows)} already present.")


In [9]:
question = "If we sell all Levi’s T-shirts today with discounts applied, how much revenue will we generate?"
question_embedding = model.encode(question).tolist()

# Nearest-neighbour search over the stored embeddings with pgvector's `<->`
# distance operator. The query embedding arrives as a Python list (sent as a
# Postgres array) and is cast to `vector` so it can be compared with the column.
retrieve_sql = """
SELECT content, category
FROM rag_documents
ORDER BY embedding <-> (%s)::vector
LIMIT %s;
"""

RAG_TOP_K = 3  # number of business-context snippets fed into the system prompt

# Roll back whatever transaction is open on `conn` (e.g. one left aborted by an
# earlier failed statement) so the SELECT below can run. Note: this discards any
# uncommitted work on this connection.
conn.rollback()

with conn.cursor() as cur:
    cur.execute(retrieve_sql, (question_embedding, RAG_TOP_K))
    retrieved_docs = cur.fetchall()

print(retrieved_docs)


[('Post-discount revenue means applying the percentage discount per t_shirt_id before summing total revenue.', 'business_rule'), ('Revenue is calculated as price multiplied by stock quantity.', 'definition'), ('Discounts are optional. When calculating revenue, products without discounts should be included at full price using a LEFT JOIN.', 'business_rule')]


## Format the retrieved RAG context for the system prompt

In [10]:
def format_rag_context(docs):
    """Render retrieved RAG rows as a numbered list for the system prompt.

    Args:
        docs: iterable of ``(content, category)`` pairs, such as the rows
            returned by the pgvector retrieval query.

    Returns:
        One line per document, formatted ``"<n>. (<category>) <content>"``,
        numbered from 1 and joined with newlines. Empty input yields ``""``.
    """
    return "\n".join(
        f"{position}. ({category}) {content}"
        for position, (content, category) in enumerate(docs, start=1)
    )

In [13]:
rag_context = format_rag_context(docs=retrieved_docs)  # numbered text block injected into the system prompt

# Way 2: modern agent (`create_agent`) with RAG business context

In [20]:
from langchain.agents import create_agent

# Row cap the agent is told to apply unless the user asks for a specific count.
SQL_TOP_K = 5

# System prompt for the SQL agent. `{dialect}`, `{top_k}` and `{rag_context}` are
# filled in by str.format below; the literal JSON braces are doubled (`{{ }}`) so
# format() leaves them alone.
# NOTE(review): the prompt asks the agent to "return the answer", but the strict
# output format allows only "sql", "metric" and "assumptions", so the computed
# result is never returned (see the sample output). Decide which contract is
# intended.
SYSTEM_PROMPT_TEMPLATE = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables.

ADDITIONAL BUSINESS CONTEXT (USE IF RELEVANT):

{rag_context}

IMPORTANT:
- The above context contains business rules and definitions.
- Use it ONLY if relevant to the user question.
- Do NOT hallucinate rules not present here.

IMPORTANT DATABASE RULES:

- Column values such as brand, color, and size are ENUMs and are CASE-SENSITIVE.
- You MUST inspect the table schema before filtering on ENUM columns.
- You MUST match user input to the closest valid ENUM value from the database.
  Example: "white" → "White", "nike" → "Nike", "extra small" → "XS".
- If a query returns zero rows, you MUST re-check casing and ENUM values before answering.
- You MUST NEVER answer with brands, colors, or sizes that were not mentioned by the user.
- If the database does not contain matching rows after verification, say so explicitly.
- Do NOT guess. Do NOT generalize. Do NOT switch brands or sizes.

OUTPUT FORMAT (STRICT):

You MUST respond in valid JSON only.
Do NOT include explanations outside JSON.
Do NOT include markdown.

The JSON object MUST have exactly these keys:
- "sql": a single SQL query string (read-only)
- "metric": the primary metric being computed
- "assumptions": a list of assumptions made while forming the SQL
Also, if you cannot form a correct SQL query, return:
{{
  "sql": null,
  "metric": null,
  "assumptions": ["insufficient information"]
}}"""

system_message = SYSTEM_PROMPT_TEMPLATE.format(
    dialect=db.dialect,
    top_k=SQL_TOP_K,
    rag_context=rag_context,
)

agent = create_agent(llm, toolkit.get_tools(), system_prompt=system_message)

# invoke() runs the agent to completion and returns the final graph state, which
# is the same value the last `stream_mode="values"` event carries. This replaces
# the manual keep-the-last-event loop, which crashed on `None["messages"]` if the
# stream yielded nothing.
final_event = agent.invoke({"messages": [("user", question)]})

final_message = final_event["messages"][-1]
print(final_message.content)

{
  "sql": "SELECT SUM(t.price * t.stock_quantity * (1 - COALESCE(d.pct_discount,0)/100.0)) AS revenue\nFROM t_shirts t\nLEFT JOIN discounts d USING (t_shirt_id)\nWHERE t.brand = 'Levi';",
  "metric": "total revenue from selling all Levi's T-shirts with discounts applied",
  "assumptions": [
    "Brand enum value for Levi's is 'Levi'",
    "All T-shirts are sold at current stock quantity",
    "Discounts are applied per t_shirt_id before summing revenue",
    "Products without a discount are sold at full price"
  ]
}
