In [162]:
from langchain_core.prompts import ChatPromptTemplate,MessagesPlaceholder
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from dotenv import load_dotenv
import sqlite3
from pprint import pprint
from typing import TypedDict, List, Annotated, Union
from langchain_core.messages import BaseMessage,AIMessage,HumanMessage,SystemMessage
from pydantic import BaseModel, Field
import json
from langgraph.graph import StateGraph, END, add_messages
import operator
load_dotenv()

True

In [36]:
### generating a schema description ##
# Builds `schema_description`: a plain-text listing of every table in the SQLite
# file, each followed by its columns as " - <name> (<declared type>)".
# NOTE(review): the database path is an absolute, machine-specific location.
conn = sqlite3.connect('/home/shivargha/langGraph-agentic-playground/InsightQuery_Agent/Chinook_Sqlite.sqlite')
cursor = conn.cursor()
## Select all tables ##
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = [row[0] for row in cursor.fetchall()]
# Collect the pieces in a list and join once instead of growing a string in a loop.
schema_parts = []
for table in tables:
    # Quote the identifier (doubling embedded quotes) so table names containing
    # spaces, punctuation or reserved words do not break the PRAGMA statement.
    quoted_table = '"' + table.replace('"', '""') + '"'
    cursor.execute(f"PRAGMA table_info({quoted_table})")
    columns = cursor.fetchall()
    schema_parts.append(f"\nTable: {table}\n")
    # Index 1 of each PRAGMA row is the column name, index 2 its declared type.
    schema_parts.extend(f" - {col[1]} ({col[2]})\n" for col in columns)
schema_description = "".join(schema_parts)

In [37]:
# Pretty-print the generated schema text to eyeball it before it is placed in prompts.
pprint(schema_description)

('\n'
 'Table: Album\n'
 ' - AlbumId (INTEGER)\n'
 ' - Title (NVARCHAR(160))\n'
 ' - ArtistId (INTEGER)\n'
 '\n'
 'Table: Artist\n'
 ' - ArtistId (INTEGER)\n'
 ' - Name (NVARCHAR(120))\n'
 '\n'
 'Table: Customer\n'
 ' - CustomerId (INTEGER)\n'
 ' - FirstName (NVARCHAR(40))\n'
 ' - LastName (NVARCHAR(20))\n'
 ' - Company (NVARCHAR(80))\n'
 ' - Address (NVARCHAR(70))\n'
 ' - City (NVARCHAR(40))\n'
 ' - State (NVARCHAR(40))\n'
 ' - Country (NVARCHAR(40))\n'
 ' - PostalCode (NVARCHAR(10))\n'
 ' - Phone (NVARCHAR(24))\n'
 ' - Fax (NVARCHAR(24))\n'
 ' - Email (NVARCHAR(60))\n'
 ' - SupportRepId (INTEGER)\n'
 '\n'
 'Table: Employee\n'
 ' - EmployeeId (INTEGER)\n'
 ' - LastName (NVARCHAR(20))\n'
 ' - FirstName (NVARCHAR(20))\n'
 ' - Title (NVARCHAR(30))\n'
 ' - ReportsTo (INTEGER)\n'
 ' - BirthDate (DATETIME)\n'
 ' - HireDate (DATETIME)\n'
 ' - Address (NVARCHAR(70))\n'
 ' - City (NVARCHAR(40))\n'
 ' - State (NVARCHAR(40))\n'
 ' - Country (NVARCHAR(40))\n'
 ' - PostalCode (NVARCHAR(10))\n'
 ' 

In [74]:
# Chat model shared by the node functions in this notebook (Gemini, up to 2 retries).
llm_model = ChatGoogleGenerativeAI(model = 'gemini-2.0-flash',max_retries=2)
# Prompt for SQL generation. It has two placeholders that are filled at invoke time:
#   {schema_database} - the schema text built earlier in the notebook
#   {user_query}      - the natural-language question
template = '''
You are an expert SQL data analyst, you convert natural language questions into correct and optimised SQL queries.
You are working with the following database schema:
{schema_database}

Description about the database:
The Chinook database is a sample SQL database that simulates a digital music store. \
It contains tables for artists, albums, tracks, customers, invoices, and employees 

Here is the user question:
{user_query}

Your task outline is:
1. Understand the user's query and intent
2. Identify the relevant tables and columns
3. Join tables correctly if needed
4. Filter and aggregate results appropriately
5. Return the SQL query only
'''
sql_prompt_template = ChatPromptTemplate.from_template(template)
# The prompt is piped into a structured-output model inside SQLGenerator, not here.

In [209]:
class AgentState(TypedDict):
    """Dictionary-shaped state that the node functions in this notebook read and update."""
    # Conversation history; the list is annotated with langgraph's add_messages.
    messages: Annotated[List[BaseMessage],add_messages]
    # The user's question; the nodes read its `.content`.
    question: HumanMessage
    # Label written by on_topic_classifier(); the classifier is asked to answer 'Yes' or 'No'.
    on_topic_classifier: str
    # Name of the next step; SQLGenerator sets it to 'sqlexecutor', which SQLExecutor checks for.
    next_tool_selection: Union[str,None]
    # SQL text produced by SQLGenerator.
    sql_query: Union[str,None]
    # Column names the model reported for the generated query.
    sql_query_columns: List[str]
    # JSON-encoded result sets appended by SQLExecutor; annotated with operator.concat.
    sql_result: Annotated[List[str],operator.concat]

In [210]:
# Structured-output schema for the relevance check: one string field whose
# description asks the model to answer "Yes" or "No".
# NOTE(review): the backslash continuation inside the description literal keeps the
# next line's leading spaces as part of the string - confirm that is intended.
class ClassifyQuestion(BaseModel):
    on_topic_label:str = Field(description = 'Is the question based on the schema described and can be\
                                            converted into a SQL Query?If yes -> "Yes" if not -> "No"')
    
def on_topic_classifier(state:AgentState):
    """Label the question in ``state`` as on-topic or off-topic for the Chinook database.

    Sends a system prompt (database description plus ``schema_description``) and the
    user's question to ``llm_model`` with ``ClassifyQuestion`` as the structured
    output, then stores the stripped label under ``state['on_topic_classifier']``.

    Args:
        state: Agent state; only ``state['question'].content`` is read.

    Returns:
        The same ``state`` object, mutated in place.
    """
    print("Inside On Topic Classifier, at present the state is:",state)
    question_text = state['question'].content
    # The schema text is substituted with str.format before the message is built.
    system_prompt = """ You are classifier that determine's if the user's question is about the following database:
            The Chinook database is a sample SQL database that simulates a digital music store. \
            It contains tables for artists, albums, tracks, customers, invoices, and employees.
            The following is the schema description of the database:
            {schema_database}
            Use the database description and schema description to understand,\
            if the question is relevant and is in the bounds of the above database, respond with a 'Yes'.Otherwise respond with a 'No'
                                    """.format(schema_database=schema_description)
    prompt_messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"User Question: {question_text}"),
    ]
    chain = ChatPromptTemplate.from_messages(prompt_messages) | llm_model.with_structured_output(ClassifyQuestion)
    # Both messages are already fully formed, so the chain is invoked with no variables.
    verdict = chain.invoke({})
    state['on_topic_classifier'] = str(verdict.on_topic_label.strip())
    return state

In [211]:
## Test the on-topic node on its own, passing a minimal state that holds only 'question'. ##
on_topic_classifier({'question':HumanMessage(content="What are the table names in chinook database?")})

Inside On Topic Classifier, at present the state is: {'question': HumanMessage(content='What are the table names in chinook database?', additional_kwargs={}, response_metadata={})}


{'question': HumanMessage(content='What are the table names in chinook database?', additional_kwargs={}, response_metadata={}),
 'on_topic_classifier': 'Yes'}

In [212]:
# Structured-output schema for SQL generation: the query text plus the list of
# column names the model reports for the query's result. Both fields are required.
class SQLOutput(BaseModel):
    sql_query:str = Field(...,description="SQL Query generated")
    sql_column_names:List[str] = Field(...,description = "List of column names after data extraction")

def SQLGenerator(state:AgentState):
    """Turn the question in ``state`` into a SQL query using the LLM.

    Pipes ``sql_prompt_template`` into ``llm_model`` with ``SQLOutput`` as the
    structured output, invokes it with ``schema_description`` and the question
    text, and records the outcome on ``state``.

    Args:
        state: Agent state; reads ``state['question'].content`` and appends to
            ``state['messages']``.

    Returns:
        The same ``state`` object with ``sql_query``, ``sql_query_columns`` and
        ``next_tool_selection`` (set to ``'sqlexecutor'``) filled in, and the
        question and generated SQL appended to ``messages``.
    """
    print("inside the SQLGenerator agent; state at the moment:",state)
    question_text = state['question'].content
    generation_chain = sql_prompt_template | llm_model.with_structured_output(SQLOutput)
    # The question is logged to the history before the model is called.
    state['messages'].append(HumanMessage(content=question_text))
    generated = generation_chain.invoke(
        {"schema_database":schema_description,"user_query":question_text}
    )
    print(generated)
    state.update(
        {
            'sql_query': generated.sql_query,
            'sql_query_columns': generated.sql_column_names,
            'next_tool_selection': 'sqlexecutor',
        }
    )
    state['messages'].append(AIMessage(content=generated.sql_query))
    return state

In [213]:
# Exercise SQLGenerator on its own with an empty history and a sample question.
state_after_sqlgen = SQLGenerator(
    {
        "messages": [],
        'sql_result': [],
        "question": HumanMessage(content="Revenue contributed by each customer."),
    }
)
state_after_sqlgen

inside the SQLGenerator agent; state at the moment: {'messages': [], 'sql_result': [], 'question': HumanMessage(content='Revenue contributed by each customer.', additional_kwargs={}, response_metadata={})}
sql_query='SELECT c.FirstName || " " || c.LastName AS customer_name, SUM(i.Total) AS total_revenue FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY total_revenue DESC' sql_column_names=['customer_name', 'total_revenue']


{'messages': [HumanMessage(content='Revenue contributed by each customer.', additional_kwargs={}, response_metadata={}),
  AIMessage(content='SELECT c.FirstName || " " || c.LastName AS customer_name, SUM(i.Total) AS total_revenue FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY total_revenue DESC', additional_kwargs={}, response_metadata={})],
 'sql_result': [],
 'question': HumanMessage(content='Revenue contributed by each customer.', additional_kwargs={}, response_metadata={}),
 'sql_query': 'SELECT c.FirstName || " " || c.LastName AS customer_name, SUM(i.Total) AS total_revenue FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY total_revenue DESC',
 'sql_query_columns': ['customer_name', 'total_revenue'],
 'next_tool_selection': 'sqlexecutor'}

In [216]:
def SQLExecutor(state:AgentState):
    """Run ``state['sql_query']`` against the Chinook database and record the outcome.

    When ``state['next_tool_selection']`` is ``'sqlexecutor'``, the query is executed
    and one JSON string is appended to ``state['sql_result']``: the fetched rows on
    success, or an object with ``error`` and ``sql_query`` keys on failure. For any
    other selection the state is returned untouched.

    Args:
        state: Agent state; reads ``next_tool_selection`` and ``sql_query`` and
            appends to ``sql_result`` (created if absent).

    Returns:
        The same ``state`` object, in every code path.
    """
    if state.get('next_tool_selection') != 'sqlexecutor':
        return state

    sql_query_code = state['sql_query']
    # NOTE(review): the query text is produced by the LLM and runs on a writable
    # connection - consider opening the database read-only.
    conn = None
    try:
        conn = sqlite3.connect("/home/shivargha/langGraph-agentic-playground/InsightQuery_Agent/Chinook_Sqlite.sqlite")
        rows = conn.execute(sql_query_code).fetchall()
        # default=str keeps values json cannot encode natively (e.g. bytes) from raising.
        payload = json.dumps(rows, default=str)
    except (sqlite3.Error, sqlite3.Warning, TypeError, ValueError) as exc:
        # Record the failure instead of touching the unbound result variable.
        payload = json.dumps({"error": str(exc), "sql_query": sql_query_code}, default=str)
    finally:
        # Always release the connection, whether or not the query succeeded.
        if conn is not None:
            conn.close()

    state.setdefault('sql_result', []).append(payload)
    return state

In [217]:
# Run the executor on the state returned by SQLGenerator; the trailing bare name
# displays the returned state as the cell output.
state_after_sqlexec = SQLExecutor(state_after_sqlgen)
state_after_sqlexec

{'messages': [HumanMessage(content='Revenue contributed by each customer.', additional_kwargs={}, response_metadata={}),
  AIMessage(content='SELECT c.FirstName || " " || c.LastName AS customer_name, SUM(i.Total) AS total_revenue FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY total_revenue DESC', additional_kwargs={}, response_metadata={})],
 'sql_result': ['[["Helena Hol\\u00fd", 49.62], ["Richard Cunningham", 47.62], ["Luis Rojas", 46.62], ["Ladislav Kov\\u00e1cs", 45.62], ["Hugh O\'Reilly", 45.62], ["Frank Ralston", 43.62], ["Julia Barnett", 43.62], ["Fynn Zimmermann", 43.62], ["Astrid Gruber", 42.62], ["Victor Stevens", 42.62], ["Terhi H\\u00e4m\\u00e4l\\u00e4inen", 41.62], ["Franti\\u0161ek Wichterlov\\u00e1", 40.62], ["Isabelle Mercier", 40.62], ["Johannes Van der Berg", 40.62], ["Lu\\u00eds Gon\\u00e7alves", 39.62], ["Fran\\u00e7ois Tremblay", 39.62], ["Bj\\u00f8rn Hansen", 39.62], ["Jack Smith", 39.62], ["Dan Miller", 39.62], ["Heat

In [218]:
# Only the last expression of a cell is displayed, so the raw JSON string on the
# first line is evaluated and discarded; the decoded rows on the second line are shown.
state_after_sqlexec['sql_result'][-1]
json.loads(state_after_sqlexec['sql_result'][-1])

[['Helena Holý', 49.62],
 ['Richard Cunningham', 47.62],
 ['Luis Rojas', 46.62],
 ['Ladislav Kovács', 45.62],
 ["Hugh O'Reilly", 45.62],
 ['Frank Ralston', 43.62],
 ['Julia Barnett', 43.62],
 ['Fynn Zimmermann', 43.62],
 ['Astrid Gruber', 42.62],
 ['Victor Stevens', 42.62],
 ['Terhi Hämäläinen', 41.62],
 ['František Wichterlová', 40.62],
 ['Isabelle Mercier', 40.62],
 ['Johannes Van der Berg', 40.62],
 ['Luís Gonçalves', 39.62],
 ['François Tremblay', 39.62],
 ['Bjørn Hansen', 39.62],
 ['Jack Smith', 39.62],
 ['Dan Miller', 39.62],
 ['Heather Leacock', 39.62],
 ['João Fernandes', 39.62],
 ['Wyatt Girard', 39.62],
 ['Jennifer Peterson', 38.62],
 ['Tim Goyer', 38.62],
 ['Camille Bernard', 38.62],
 ['Dominique Lefebvre', 38.62],
 ['Joakim Johansson', 38.62],
 ['Manoj Pareek', 38.62],
 ['Leonie Köhler', 37.62],
 ['Daan Peeters', 37.62],
 ['Kara Nielsen', 37.62],
 ['Eduardo Martins', 37.62],
 ['Alexandre Rocha', 37.62],
 ['Roberto Almeida', 37.62],
 ['Fernanda Ramos', 37.62],
 ['Mark Philip

In [151]:
# Scratch cell: re-open the database, rebinding the module-level `conn` and `cursor`.
conn = sqlite3.connect("/home/shivargha/langGraph-agentic-playground/InsightQuery_Agent/Chinook_Sqlite.sqlite")
cursor = conn.cursor()
# Top 5 track names by summed InvoiceLine quantity; `z` holds the value returned by
# execute() and no rows are fetched in this cell.
z = cursor.execute("SELECT t.Name FROM Track t JOIN InvoiceLine il ON t.TrackId = il.TrackId GROUP BY t.Name ORDER BY SUM(il.Quantity) DESC LIMIT 5")

In [159]:
# Top 5 track names by summed InvoiceLine quantity, run through conn.execute with
# all rows fetched and displayed as the cell output.
sql = "SELECT t.Name FROM Track t JOIN InvoiceLine il ON t.TrackId = il.TrackId GROUP BY t.Name ORDER BY SUM(il.Quantity) DESC LIMIT 5"
conn.execute(sql).fetchall()

[('The Trooper',),
 ('Untitled',),
 ('The Number Of The Beast',),
 ('Sure Know Something',),
 ('Hallowed Be Thy Name',)]