## Local LLM SQL Agent for AdventureWorksLT 2022 Database on SQL Server Express

### Import Dependencies

In [None]:
from typing_extensions import TypedDict
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from sqlalchemy import text, inspect
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.engine import URL
from langgraph.graph import StateGraph, END
from langchain_community.llms import ollama
from langchain_core.runnables import RunnableConfig
from langchain_ollama import ChatOllama
import gradio as gr


  from .autonotebook import tqdm as notebook_tqdm


### Database Connection

In [2]:
connection_string = (
    r"Driver=ODBC Driver 18 for SQL Server;"
    r"Server={DATABASE SERVER};"
    r"Database=AdventureWorksLT2022;"
    r"Trusted_Connection=yes;"
    r"TrustServerCertificate=yes;"
)

DATABASE_URL = URL.create(
    "mssql+pyodbc", 
    query={"odbc_connect": connection_string}
)

engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

### Getting Database Schema

In [74]:
class AgentState(TypedDict):
    question: str
    sql_query: str
    query_result: str
    query_rows: list
    current_user: str
    attempts: int
    relevance: str
    sql_error: bool



In [80]:

def get_database_schema(engine):    
    inspector2 = inspect(engine)
    DBschema = ""

    tables = inspector2.get_table_names()

    for table in tables:
        DBschema += f"\nTable: {table}\n"

        columns = inspector2.get_columns(table)
        DBschema += "Columns"
        for col in columns:
            DBschema += f"- {col['name']}: {col['type']}, Nullable: {col['nullable']}\n"
        
        p_k = inspector2.get_pk_constraint(table)
        primary_keys = p_k.get('constrained_columns', [])
        if primary_keys:
            DBschema += f"Primary Keys: {', '.join(primary_keys)}\n"

        
        foreign_keys = inspector2.get_foreign_keys(table)

        DBschema += "Foreign Keys\n"
        for fk in foreign_keys:
            DBschema += f"- {fk['constrained_columns']}  ->  {fk['referred_table']}.{fk['referred_columns'][0]}\n"


    return DBschema





### Test get_database_schema() and the DBSchema returned

In [82]:
print(get_database_schema(engine))


Table: Address
Columns- AddressID: INTEGER, Nullable: False
- AddressLine1: NVARCHAR(60) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: False
- AddressLine2: NVARCHAR(60) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: True
- City: NVARCHAR(30) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: False
- StateProvince: NVARCHAR(50) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: False
- CountryRegion: NVARCHAR(50) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: False
- PostalCode: NVARCHAR(15) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: False
- rowguid: UNIQUEIDENTIFIER, Nullable: False
- ModifiedDate: DATETIME, Nullable: False
Foreign Keys

Table: BuildVersion
Columns- SystemInformationID: TINYINT, Nullable: False
- Database Version: NVARCHAR(25) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: False
- VersionDate: DATETIME, Nullable: False
- ModifiedDate: DATETIME, Nullable: False
Foreign Keys

Table: Customer
Columns- CustomerID: INTEGER, Nullable: False
- NameStyle: B

In [None]:
inspector2 = inspect(engine)
DBschema = ""

tables = inspector2.get_table_names(schema="dbo")

for table in tables:
    DBschema += f"\nTable: {table}\n"

    columns = inspector2.get_columns(table, schema="dbo")
    DBschema += "Columns"
    for col in columns:
        DBschema += f"- {col['name']}: {col['type']}, Nullable: {col['nullable']}\n"
    
    p_k = inspector2.get_pk_constraint(table, schema="dbo")
    primary_keys = p_k.get('constrained_columns', [])
    if primary_keys:
        DBschema += f"Primary Keys: {', '.join(primary_keys)}\n"

    
    foreign_keys = inspector2.get_foreign_keys(table, schema="dbo")

    DBschema += "Foreign Keys\n"
    for fk in foreign_keys:
        DBschema += f"- {fk['constrained_columns']}  ->  {fk['referred_table']}.{fk['referred_columns'][0]}\n"


print(DBschema)





Table: Address
Columns- AddressID: INTEGER, Nullable: False
- AddressLine1: NVARCHAR(60) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: False
- AddressLine2: NVARCHAR(60) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: True
- City: NVARCHAR(30) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: False
- StateProvince: NVARCHAR(50) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: False
- CountryRegion: NVARCHAR(50) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: False
- PostalCode: NVARCHAR(15) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: False
- rowguid: UNIQUEIDENTIFIER, Nullable: False
- ModifiedDate: DATETIME, Nullable: False
Foreign Keys

Table: BuildVersion
Columns- SystemInformationID: TINYINT, Nullable: False
- Database Version: NVARCHAR(25) COLLATE "SQL_Latin1_General_CP1_CI_AS", Nullable: False
- VersionDate: DATETIME, Nullable: False
- ModifiedDate: DATETIME, Nullable: False
Foreign Keys

Table: Customer
Columns- CustomerID: INTEGER, Nullable: False
- NameStyle: B

### Get current user

In [7]:
class GetCurrentUser(BaseModel):
    current_user: str = Field(
        description="The name of the current user based on the provided user ID."
    )

def get_current_user(state: AgentState, config: RunnableConfig):
    print("Retrieving the current user based on user ID.")
    user_id = config["configurable"].get("current_user_id", None)
    if not user_id:
        state["current_user"] = "User not found"
        print("No user ID provided in the configuration.")
        return state

    session = SessionLocal()
    try:
        user = session.query(User).filter(User.id == int(user_id)).first()
        if user:
            state["current_user"] = user.name
            print(f"Current user set to: {state['current_user']}")
        else:
            state["current_user"] = "User not found"
            print("User not found in the database.")
    except Exception as e:
        state["current_user"] = "Error retrieving user"
        print(f"Error retrieving user: {str(e)}")
    finally:
        session.close()
    return state


### Check question relevance

In [57]:
class CheckRelevance(BaseModel):
    relevance: str = Field(
        description="Indicates whether the question is related to the database schema. 'relevant' or 'not_relevant'."
    )

def check_relevance(state: AgentState, config: RunnableConfig):
    question = state["question"]
    schema = get_database_schema(engine)
    print(f"Checking relevance of the question: {question}")
    system = """You are an assistant that determines whether a given question is related to the following database schema.

Schema:
{schema}

Respond with only "relevant" or "not_relevant".

For example, questions that mention sales, customers, products or orders are relevant.
""".format(schema=schema)
    human = f"Question: {question}"
    check_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", human),
        ]
    )
    llm = ChatOllama(model="qwen2.5-coder:3b", temperature=0)
    structured_llm = llm.with_structured_output(CheckRelevance)
    relevance_checker = check_prompt | structured_llm
    relevance = relevance_checker.invoke({})
    state["relevance"] = relevance.relevance
    print(f"Relevance determined: {state['relevance']}")
    return state


#### Test for CheckRelevance class

In [54]:
def test_check_relevance():
    # Sample question to test the relevance checker
    question = "What is the city and state of the address with Address ID of 297?"

    # Create an AgentState instance
    state = AgentState()
    state["question"] = question

    # Run the check_relevance method
    config = RunnableConfig()  # You might need to define or configure a RunnableConfig if needed
    new_state = check_relevance(state, config)

    # Check the result
    relevance = new_state["relevance"]
    print(f"Relevance of question: {question} is {relevance}")

    # Expected output: Relevance of question: What is the average order amount per customer? is relevant

test_check_relevance()


Checking relevance of the question: What is the city and state of the address with Address ID of 297?
Relevance determined: relevant
Relevance of question: What is the city and state of the address with Address ID of 297? is relevant


### Convert Question to SQL

In [72]:
class ConvertToSQL(BaseModel):
    sql_query: str = Field(
        description="The SQL query corresponding to the user's natural language question."
    )

def convert_nl_to_sql(state: AgentState, config: RunnableConfig):
    question = state["question"]
    #current_user = state["current_user"]
    schema = get_database_schema(engine)
    print(f"Converting question to SQL for user : {question}")
    system = """You are an assistant that converts natural language questions into SQL queries based on the following schema:

{schema}

Provide only the SQL query without any explanations. Alias columns appropriately to match the expected keys in the result. 



""".format(schema=schema)
    convert_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )
    llm = ChatOllama(model="qwen2.5-coder:3b", temperature=0)
    structured_llm = llm.with_structured_output(ConvertToSQL)
    sql_generator = convert_prompt | structured_llm
    result = sql_generator.invoke({"question": question})
    state["sql_query"] = result.sql_query
    print(f"Generated SQL query: {state['sql_query']}")
    return state

### Executing the SQL query

In [None]:
def execute_sql(state: AgentState):
    sql_query = state["sql_query"].strip()
    session = SessionLocal()
    print(f"Executing SQL query: {sql_query}")
    try:
        result = session.execute(text(sql_query))
        if sql_query.lower().startswith("select"):
            rows = result.fetchall()
            columns = result.keys()
            if rows:
                header = ", ".join(columns)
                state["query_rows"] = [dict(zip(columns, row)) for row in rows]
                print(f"Raw SQL Query Result: {state['query_rows']}")

                # Format the result for readability
                data_lines = []
                for row in state["query_rows"]:
                    line_data = ", ".join(
                        f"{key}: {value}" if isinstance(value, (str, int)) else value  
                        for key, value in row.items()
                    )
                    print(f"Line Data: {line_data}")  # Debugging line
                    data_lines.append(line_data)

                formatted_result = f"{header}\n" + "\n".join(data_lines)
                print(f"Formatted Result: {formatted_result}") # Debugging line
                print(f"state[query_rows]: {state['query_rows']}") # Debugging line
            else:
                state["query_rows"] = []
                formatted_result = "No results found."
            state["query_result"] = formatted_result
            state["sql_error"] = False
            print("SQL SELECT query executed successfully.")
        else:
            session.commit()
            state["query_result"] = "The action has been successfully completed."
            state["sql_error"] = False
            print("SQL command executed successfully.")
    except Exception as e:
        state["query_result"] = f"Error executing SQL query: {str(e)}"
        state["sql_error"] = True
        print(f"Error executing SQL query: {str(e)}")
    finally:
        session.close()
    return state

### Convert Retrieved Database data to Natural Language

In [None]:
def generate_human_readable_answer(state: AgentState):
    sql = state["sql_query"]
    result = state["query_result"]
    current_user = state["current_user"]
    query_rows = state.get("query_rows", [])
    sql_error = state.get("sql_error", False)
    print("Generating a human-readable answer.")
    system = """You are an assistant that converts SQL query results into clear, natural language responses without including any identifiers like order IDs. Start the response with a friendly greeting that includes the user's name.
    """
    if sql_error:
        # Directly relay the error message
        generate_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system),
                (
                    "human",
                    f"""SQL Query:
{sql}

Result:
{result}

Formulate a clear and understandable error message in a single sentence, starting with 'Hello {current_user},' informing them about the issue."""
                ),
            ]
        )
    elif sql.lower().startswith("select"):
        if not query_rows:
            # Handle cases with no records
            generate_prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", system),
                    (
                        "human",
                        f"""SQL Query:
{sql}

Result:
{result}

Formulate a clear and understandable answer to the original question in a single sentence, starting with 'Hello,' and mention that there are no records found."""
                    ),
                ]
            )
        else:
            # Handle displaying records
            generate_prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", system),
                    (
                        "human",
                        f"""SQL Query:
{sql}

Result:
{result}

Formulate a clear and understandable answer to the original question in a single sentence, starting with 'Hello,' and list each record found"""
                    ),
                ]
            )
    else:
        # Handle non-select queries
        generate_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system),
                (
                    "human",
                    f"""SQL Query:
{sql}

Result:
{result}

Formulate a clear and understandable confirmation message in a single sentence, starting with 'Hello,' confirming that your request has been successfully processed."""
                ),
            ]
        )

    llm = ChatOllama(model="qwen2.5-coder:3b", temperature=0)
    human_response = generate_prompt | llm | StrOutputParser()
    answer = human_response.invoke({})
    state["query_result"] = answer
    print("Generated human-readable answer.")
    return state

class RewrittenQuestion(BaseModel):
    question: str = Field(description="The rewritten question.")

def regenerate_query(state: AgentState):
    question = state["question"]
    print("Regenerating the SQL query by rewriting the question.")
    system = """You are an assistant that reformulates an original question to enable more precise SQL queries. Ensure that all necessary details, such as table joins, are preserved to retrieve complete and accurate data.
    """
    rewrite_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            (
                "human",
                f"Original Question: {question}\nReformulate the question to enable more precise SQL queries, ensuring all necessary details are preserved.",
            ),
        ]
    )
    llm = ChatOllama(model="qwen2.5-coder:3b", temperature=0)
    structured_llm = llm.with_structured_output(RewrittenQuestion)
    rewriter = rewrite_prompt | structured_llm
    rewritten = rewriter.invoke({})
    state["question"] = rewritten.question
    state["attempts"] += 1
    print(f"Rewritten question: {state['question']}")
    return state

def generate_funny_response(state: AgentState):
    print("Generating a funny response for an unrelated question.")
    system = """You are a charming and funny assistant who responds in a playful manner.
    """
    human_message = "I can not help with that, but doesn't asking questions make you hungry? You can always order something delicious."
    funny_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", human_message),
        ]
    )
    llm = ChatOllama(model="qwen2.5-coder:3b", temperature=0.7)
    funny_response = funny_prompt | llm | StrOutputParser()
    message = funny_response.invoke({})
    state["query_result"] = message
    print("Generated funny response.")
    return state

def end_max_iterations(state: AgentState):
    state["query_result"] = "Please try again."
    print("Maximum attempts reached. Ending the workflow.")
    return state

def relevance_router(state: AgentState):
    if state["relevance"].lower() == "relevant":
        return "convert_to_sql"
    else:
        return "generate_funny_response"

def check_attempts_router(state: AgentState):
    if state["attempts"] < 3:
        return "convert_to_sql"
    else:
        return "end_max_iterations"

def execute_sql_router(state: AgentState):
    if not state.get("sql_error", False):
        return "generate_human_readable_answer"
    else:
        return "regenerate_query"

workflow = StateGraph(AgentState)

workflow.add_node("get_current_user", get_current_user)
workflow.add_node("check_relevance", check_relevance)
workflow.add_node("convert_to_sql", convert_nl_to_sql)
workflow.add_node("execute_sql", execute_sql)
workflow.add_node("generate_human_readable_answer", generate_human_readable_answer)
workflow.add_node("regenerate_query", regenerate_query)
workflow.add_node("generate_funny_response", generate_funny_response)
workflow.add_node("end_max_iterations", end_max_iterations)

workflow.add_edge("get_current_user", "check_relevance")

workflow.add_conditional_edges(
    "check_relevance",
    relevance_router,
    {
        "convert_to_sql": "convert_to_sql",
        "generate_funny_response": "generate_funny_response",
    },
)

workflow.add_edge("convert_to_sql", "execute_sql")

workflow.add_conditional_edges(
    "execute_sql",
    execute_sql_router,
    {
        "generate_human_readable_answer": "generate_human_readable_answer",
        "regenerate_query": "regenerate_query",
    },
)

workflow.add_conditional_edges(
    "regenerate_query",
    check_attempts_router,
    {
        "convert_to_sql": "convert_to_sql",
        "max_iterations": "end_max_iterations",
    },
)

workflow.add_edge("generate_human_readable_answer", END)
workflow.add_edge("generate_funny_response", END)
workflow.add_edge("end_max_iterations", END)

workflow.set_entry_point("get_current_user")

app = workflow.compile()

In [None]:
from IPython.display import Image, display

try:
    display(Image(app.get_graph(xray=True).draw_mermaid_png()))
except:
    pass

## Testing with dummy user

In [None]:
fake_config = {"configurable": {"current_user_id": "2"}}
user_question_1 = "What is the city of the address with AddressID of 297?"
#user_question_1 = "What is the first name of the customer with the CustomerID of 30?"
result_1 = app.invoke({"question": user_question_1, "attempts": 0}, config=fake_config)
print("Result:", result_1["query_result"])

Retrieving the current user based on user ID.
Error retrieving user: name 'User' is not defined
Checking relevance of the question: What is the ProductDescription of customer 30113's order on 2008-06-13? Use SalesOrderHeader, SalesOrderDetail and ProductDescription tables
Relevance determined: relevant
Converting question to SQL for user : What is the ProductDescription of customer 30113's order on 2008-06-13? Use SalesOrderHeader, SalesOrderDetail and ProductDescription tables
Generated SQL query: SELECT T3.ProductDescription FROM SalesOrderHeader AS T1 INNER JOIN SalesOrderDetail AS T2 ON T1.SalesOrderID = T2.SalesOrderID INNER JOIN ProductDescription AS T3 ON T2.ProductID = T3.ProductID WHERE T1.CustomerID = 30113 AND T1.OrderDate = '2008-06-13'
Executing SQL query: SELECT T3.ProductDescription FROM SalesOrderHeader AS T1 INNER JOIN SalesOrderDetail AS T2 ON T1.SalesOrderID = T2.SalesOrderID INNER JOIN ProductDescription AS T3 ON T2.ProductID = T3.ProductID WHERE T1.CustomerID = 301

KeyError: 'end_max_iterations'

Retrieving the current user based on user ID.
Error retrieving user: name 'User' is not defined
Checking relevance of the question: What is the OrderID of customer 30113's order on 2008-06-13? Use either SalesOrderHeader or SalesOrderDetail tables
Relevance determined: relevant
Converting question to SQL for user : What is the OrderID of customer 30113's order on 2008-06-13? Use either SalesOrderHeader or SalesOrderDetail tables
Generated SQL query: SELECT T1.OrderID FROM SalesOrderHeader AS T1 INNER JOIN Customer AS T2 ON T1.CustomerID = T2.CustomerID WHERE T2.CustomerID = 30113 AND T1.OrderDate = '2008-06-13'
Executing SQL query: SELECT T1.OrderID FROM SalesOrderHeader AS T1 INNER JOIN Customer AS T2 ON T1.CustomerID = T2.CustomerID WHERE T2.CustomerID = 30113 AND T1.OrderDate = '2008-06-13'
Error executing SQL query: (pyodbc.ProgrammingError) ('42S22', "[42S22] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid column name 'OrderID'. (207) (SQLExecDirectW)")
[SQL: SELECT 

### Building the user interface

In [85]:
def display_text(input_text):
    result = app.invoke({"question": input_text, "attempts": 0}, config=fake_config)
    return {"query_result": result["query_result"]}

iface = gr.Interface(fn=display_text, inputs="text", outputs="text")

iface.launch()

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




Retrieving the current user based on user ID.
Error retrieving user: name 'User' is not defined
Checking relevance of the question: What is the city of the address with AddressID of 297? Use the Address table
Relevance determined: relevant
Converting question to SQL for user : What is the city of the address with AddressID of 297? Use the Address table
Generated SQL query: SELECT City FROM Address WHERE AddressID = 297
Executing SQL query: SELECT City FROM Address WHERE AddressID = 297
Raw SQL Query Result: [{'City': 'Renton'}]
Formatted Result: City
None for $None
state[query_rows: [{'City': 'Renton'}]
SQL SELECT query executed successfully.
Generating a human-readable answer.
Generated human-readable answer.


Traceback (most recent call last):
  File "c:\Users\tyron\anaconda3\envs\tf\lib\site-packages\gradio\queueing.py", line 625, in process_events
    response = await route_utils.call_process_api(
  File "c:\Users\tyron\anaconda3\envs\tf\lib\site-packages\gradio\route_utils.py", line 322, in call_process_api
    output = await app.get_blocks().process_api(
  File "c:\Users\tyron\anaconda3\envs\tf\lib\site-packages\gradio\blocks.py", line 2098, in process_api
    result = await self.call_function(
  File "c:\Users\tyron\anaconda3\envs\tf\lib\site-packages\gradio\blocks.py", line 1645, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
  File "c:\Users\tyron\anaconda3\envs\tf\lib\site-packages\anyio\to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "c:\Users\tyron\anaconda3\envs\tf\lib\site-packages\anyio\_backends\_asyncio.py", line 2364, in run_sync_in_worker_thread
    return await future
  File "c:\U