In [None]:
import re
import uuid
from contextlib import closing
from typing import List, Optional

from pydantic import BaseModel, Field


# Structured-output schema for `identify_relevant_tables`: the model returns the
# names of the tables it considers necessary for the user's question.
# NOTE: no class docstring on purpose — pydantic would put it into the model's
# JSON schema, which presumably reaches the LLM via with_structured_output and
# would silently change the prompt.
class RelevantTables(BaseModel):
    tables: List[str] = Field(
        ...,  # `...` marks the field as required
        description="List of table names deemed relevant for answering the user's question.",
    )


# Structured-output schema for `generate_sql_query`: a single SQLite query string
# that is then wrapped into an `execute_sql_query` tool call.
# NOTE: no class docstring on purpose — it would become part of the JSON schema
# (and presumably the prompt) that with_structured_output sends to the model.
class SQLQuery(BaseModel):
    sql_query: str = Field(
        ...,  # required
        description="A syntactically correct SQLite query to address the user's question based on the provided schema.",
    )


# Structured-output schema for the first pass of the `manager` node. The prompt
# asks the model to fill exactly one field: `instruction` for database questions,
# `response` for greetings / off-topic input. `manager` falls back to a canned
# Arabic apology when neither field comes back populated.
# NOTE: no class docstring on purpose — pydantic would add it to the JSON schema
# that presumably reaches the LLM via with_structured_output.
class UserIntent(BaseModel):
    instruction: Optional[str] = Field(
        default=None,
        description="Concise English instruction for a database query agent, generated if the user input is a database query.",
    )
    response: Optional[str] = Field(
        default=None,
        description="Polite response in Arabic, generated if the user input is a greeting or not a database query.",
    )

In [2]:
import logging
from typing import Optional

from langchain_ollama.chat_models import ChatOllama
from langgraph.graph import MessagesState

# --- Runtime configuration -------------------------------------------------
DB_PATH = "lms.db"  # SQLite file, resolved relative to the kernel's working directory
LLM_MODEL = "qwen3:4b"  # Ollama model tag used by every LLM call
MAX_QUERY_RETRIES = 2  # extra SQL-generation attempts allowed after a failed execution
LOG_LEVEL = logging.INFO

# --- Logging ---------------------------------------------------------------
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=LOG_LEVEL,
)
# httpx logs each HTTP request at INFO (presumably the client behind the Ollama
# calls); raise its threshold so the notebook output stays readable.
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)


class State(MessagesState):
    # Shared LangGraph state. `messages` comes from MessagesState; the fields
    # below are written by individual nodes as the run progresses.
    # NOTE(review): MessagesState is presumably a TypedDict, in which case these
    # `= ...` defaults are NOT applied at runtime — several nodes therefore read
    # optional keys with state.get(key, default). Confirm for the LangGraph
    # version in use before relying on the defaults.
    instruction: Optional[str] = None  # English instruction set by `manager` for database questions
    response: Optional[str] = None  # not written by any node visible here; `manager` stores replies in final_answer
    schema: Optional[str] = None  # schema + sample-row text (or an error message) from `identify_relevant_tables`
    error_count: int = 0  # number of SQL generations so far, incremented by `generate_sql_query`
    last_query: Optional[str] = None  # most recently generated SQL, quoted back in correction prompts
    final_answer: Optional[str] = None  # user-facing reply; once set, `route_after_manager` goes to END


# One chat model shared by every node. Temperature 0 keeps intent
# classification, table selection and SQL generation as repeatable as the
# model allows.
LLM = ChatOllama(temperature=0.0, model=LLM_MODEL)

  from pydantic.v1.fields import FieldInfo as FieldInfoV1


In [3]:
import sqlite3
from typing import Dict, List, Tuple


def get_db_connection(db_path: Optional[str] = None) -> sqlite3.Connection:
    """Open a new SQLite connection.

    Args:
        db_path: Database file (or ``":memory:"``) to open. Defaults to the
            module-level ``DB_PATH``, so existing zero-argument callers behave
            exactly as before.

    Returns:
        A fresh ``sqlite3.Connection`` owned by the caller. ``with connection:``
        only commits/rolls back and does NOT close it — wrap it in
        ``contextlib.closing`` to release the handle.

    Raises:
        sqlite3.Error: if the database cannot be opened (logged before re-raising).
    """
    path = DB_PATH if db_path is None else db_path
    try:
        return sqlite3.connect(path)
    except sqlite3.Error:
        # The old handler re-raised without a trace; record which path failed.
        logger.exception("Could not open SQLite database at %r", path)
        raise


def get_table_and_column_names() -> Dict[str, List[str]]:
    """Map every application table in the database to its column names.

    SQLite reserves the ``sqlite_`` prefix for its own bookkeeping tables (for
    example ``sqlite_sequence``); those are skipped so they are never offered to
    the table-selection prompt.

    Returns:
        ``{table_name: [column_name, ...]}``. A table whose columns cannot be read
        is logged and omitted. If the table list itself cannot be read, the error
        is logged and ``{}`` is returned (callers treat an empty dict as failure).
    """
    tables_columns: Dict[str, List[str]] = {}
    try:
        # closing() really releases the connection; `with connection:` alone
        # only manages the transaction.
        with closing(get_db_connection()) as connection:
            cursor = connection.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = [
                row[0]
                for row in cursor.fetchall()
                if not row[0].lower().startswith("sqlite_")
            ]
            for table in tables:
                # Double embedded backticks so odd names cannot break the quoting.
                quoted = "`" + table.replace("`", "``") + "`"
                try:
                    cursor.execute(f"PRAGMA table_info({quoted});")
                    tables_columns[table] = [col[1] for col in cursor.fetchall()]
                except sqlite3.Error as e:
                    logger.warning(f"Could not get columns for table '{table}': {e}")
    except sqlite3.Error as e:
        # Previously swallowed silently; keep returning {} but say why.
        logger.error(f"Could not list tables in the database: {e}")
        return {}
    return tables_columns


# Statement must *begin* with SELECT or WITH, optionally preceded by whitespace
# and SQL comments (`-- ...` or `/* ... */`).
_READ_QUERY_PREFIX = re.compile(
    r"\s*(?:(?:--[^\n]*(?:\n|$)|/\*.*?\*/)\s*)*(?:SELECT|WITH)\b",
    re.IGNORECASE | re.DOTALL,
)


def _execute_query(query: str) -> Tuple[List[str], List[tuple]]:
    """Run one read-only SQL statement and return its column names and rows.

    Args:
        query: SQL text, expected to be a single SELECT (optionally introduced by
            a ``WITH`` clause and/or leading comments).

    Returns:
        ``(column_names, rows)``; ``column_names`` is empty when the statement
        produces no result columns.

    Raises:
        PermissionError: if the statement does not start with SELECT/WITH.
        sqlite3.Error: if SQLite rejects or fails the query (including any write
            attempt); the message is prefixed with ``"Error: Query failed."``.
    """
    # The old `"SELECT" in query` check accepted e.g.
    # "DELETE FROM t WHERE id IN (SELECT ...)" or "UPDATE t SET x = 'selected'".
    if not _READ_QUERY_PREFIX.match(query):
        raise PermissionError("Error: Only SELECT queries are permitted for execution.")

    try:
        with closing(get_db_connection()) as connection:
            # Defence in depth: WITH can prefix INSERT/UPDATE/DELETE, so make
            # SQLite itself refuse any write on this connection.
            connection.execute("PRAGMA query_only = ON;")
            cursor = connection.cursor()
            cursor.execute(query)
            results = cursor.fetchall()
            column_names = (
                [desc[0] for desc in cursor.description] if cursor.description else []
            )
            return column_names, results
    except sqlite3.Error as e:
        raise sqlite3.Error(f"Error: Query failed. Details: {e}") from e

In [4]:
import sqlite3
from typing import List

from langchain.tools import tool


@tool("get_tables_info")
def get_tables_info(tables: List[str]) -> str:
    """
    Retrieve the schema (columns, types, constraints, FKs) and the first 3 rows
    for each of the specified tables.

    Args:
        tables: A list of table names to retrieve information for.
    """
    # The docstring above is also the tool description the LLM sees, so
    # implementation notes live here instead. Output: one text section per table
    # (schema, foreign keys if any, up to 3 sample rows) joined by blank lines;
    # per-table problems are reported inline rather than aborting the call.
    if not tables:
        return "No tables specified."

    def quote_identifier(name: str) -> str:
        # Names come from an LLM; doubling embedded backticks keeps them from
        # terminating the quoted identifier early.
        return "`" + name.replace("`", "``") + "`"

    tables_info_parts = []
    try:
        # closing() releases the connection; `with connection:` alone does not.
        with closing(get_db_connection()) as connection:
            cursor = connection.cursor()
            # dict.fromkeys drops repeated names while keeping their order.
            for table in dict.fromkeys(tables):
                quoted = quote_identifier(table)
                table_info_str = f"Table: `{table}`\n"
                try:
                    cursor.execute(f"PRAGMA table_info({quoted});")
                    schema_info = cursor.fetchall()
                    if not schema_info:
                        tables_info_parts.append(
                            f"Warning: Table '{table}' not found or has no columns."
                        )
                        continue

                    # table_info rows: (cid, name, type, notnull, dflt_value, pk)
                    schema_details = [
                        {
                            "name": col[1],
                            "type": col[2],
                            "notnull": bool(col[3]),
                            "pk": bool(col[5]),
                        }
                        for col in schema_info
                    ]
                    table_info_str += f"Schema: {schema_details}\n"

                    # foreign_key_list rows: (id, seq, table, from, to, ...)
                    cursor.execute(f"PRAGMA foreign_key_list({quoted});")
                    fk_info = cursor.fetchall()
                    if fk_info:
                        fk_details = [
                            {
                                "from_column": fk[3],
                                "to_table": fk[2],
                                "to_column": fk[4],
                            }
                            for fk in fk_info
                        ]
                        table_info_str += f"Foreign Keys: {fk_details}\n"

                    cursor.execute(f"SELECT * FROM {quoted} LIMIT 3;")
                    rows = cursor.fetchall()
                    column_names = [desc[0] for desc in cursor.description]
                    table_info_str += f"Sample Rows (Columns: {column_names}):\n{rows}"

                    tables_info_parts.append(table_info_str)

                except sqlite3.Error as e:
                    tables_info_parts.append(
                        f"Error retrieving info for table '{table}': {e}"
                    )

        return (
            "\n\n".join(tables_info_parts)
            if tables_info_parts
            else "Could not retrieve info for any specified table."
        )

    except sqlite3.Error as e:
        return f"Error: Could not connect to the database. {e}"

In [5]:
import sqlite3

from langchain.tools import tool


@tool("execute_sql_query")
def db_query_tool(query: str) -> str:
    """
    Executes a **SELECT** SQL query against the SQLite database and returns the result
    including column names.
    Input should be a valid SQLite SELECT query string.
    If the query is incorrect or disallowed (not SELECT), an error message is returned.
    If an error occurs, check the schema, rewrite the query, and try again.
    **Security Note**: Only SELECT queries are permitted.
    """
    # The docstring above is the tool description shown to the LLM, so it is
    # left untouched. Failures are returned as text, not raised; as
    # `_execute_query` currently words them, that text starts with "Error:",
    # which is what the retry router looks for.
    try:
        column_names, result_rows = _execute_query(query)
    except (sqlite3.Error, PermissionError) as exc:
        return f"{exc}"

    if result_rows:
        return f"Query executed successfully.\nColumns: {column_names}\nResults:\n{str(result_rows)}"
    return f"Query executed successfully, but returned no results.\nColumns: {column_names}"

In [6]:
from langchain_core.messages import SystemMessage, AIMessage, ToolMessage


def generate_sql_query(state: State) -> State:
    """Ask the LLM for a SQLite query and wrap it in a synthetic tool call.

    Works in one of two modes:

    * generation - first attempt, built from the schema text and the instruction;
    * correction - when the newest message is a ToolMessage (the SQL tool just
      reported an error), the error text and the previous query are added to the
      prompt so the model can repair it.

    Returns:
        A partial state update: an AIMessage carrying one ``execute_sql_query``
        tool call, the new query as ``last_query``, and ``error_count`` incremented
        by one (it counts generation attempts, which the retry router compares
        against ``MAX_QUERY_RETRIES``).
    """
    schema = state["schema"]
    instruction = state["instruction"]
    messages = state["messages"]
    last_query = state.get("last_query")

    # The router only sends us back here after a failed execution, so a
    # ToolMessage at the end means "fix the previous query".
    last_message = messages[-1]
    correction_mode = isinstance(last_message, ToolMessage)

    # Schema context shared by both modes.
    system_prompt_parts = [
        "You are an expert SQLite database engineer.",
        f"Database Schema and Sample Data:\n{schema}\n",
    ]

    if correction_mode:
        error_details = last_message.content
        system_prompt_parts.extend(
            [
                "The previous SQLite query attempt failed.",
                f"{error_details}",
                f"The failed query was:\n```sql\n{last_query}\n```",
                "Carefully analyze the schema, the error message, and the failed query.",
                "Provide a corrected SQLite query to address the original user instruction:",
                f"User Instruction: {instruction}\n",
                "Output *only* the corrected SQLite query using the 'SQLQuery' format.",
            ]
        )
    else:
        system_prompt_parts.extend(
            [
                "Write a precise SQLite query to answer the user's question based *only* on the provided schema and sample data.",
                f"User Question (Instruction): {instruction}\n",
                "Ensure the query is valid SQLite syntax. Use backticks (`) for table/column names if necessary.",
                "Use the 'SQLQuery' format.",
            ]
        )

    # Structured output guarantees we get a plain query string back.
    system_prompt = "\n".join(system_prompt_parts)
    llm_sql = LLM.with_structured_output(SQLQuery)
    res = llm_sql.invoke([SystemMessage(content=system_prompt)])
    sql_query = res.sql_query

    # A fresh random id per call. The old `hash(sql_query)` id repeated whenever
    # the model regenerated an identical query, giving two tool calls in one
    # conversation the same id (ToolMessages are matched to calls by id).
    ai_message_with_tool_call = AIMessage(
        content="Generated SQL query, attempting execution.",
        tool_calls=[
            {
                "id": f"tool_call_{uuid.uuid4().hex}",
                # Must match the name given to @tool on the SQL execution tool.
                "name": "execute_sql_query",
                "args": {"query": sql_query},
            }
        ],
    )

    return {
        "messages": [ai_message_with_tool_call],
        "last_query": sql_query,  # quoted back to the model if this attempt fails
        "error_count": state.get("error_count", 0) + 1,  # attempts so far
    }

In [7]:
from langchain_core.messages import SystemMessage


def identify_relevant_tables(state: State) -> State:
    """Pick the tables needed for the instruction and load their schema text.

    Lists every table with its columns, asks the LLM (``RelevantTables``
    structured output) which ones are needed, keeps only names that actually
    exist, and fetches their schema / sample rows via ``get_tables_info``.

    Returns:
        ``{"schema": <text>}``. Problems are reported through the same key (as a
        message starting with "Error:" or "No relevant tables ...") instead of
        being raised, so the graph keeps moving.
    """
    user_instruction = state.get("instruction")
    if not user_instruction:
        return {"schema": "Error: Cannot identify tables without user instruction."}

    try:
        tables_with_columns = get_table_and_column_names()
        if not tables_with_columns:
            error_message = "Error: Could not retrieve table/column names from the database. Cannot proceed."
            return {"schema": error_message}

        tables_columns_str = "\n".join(
            f"- Table: `{name}`, Columns: {', '.join(cols)}"
            for name, cols in tables_with_columns.items()
        )

        system_prompt = (
            "You are an expert database analyst. Your task is to identify which tables "
            "from the available list are strictly necessary to answer the user's question.\n"
            "Consider the table names and their columns provided below.\n\n"
            f"Available tables and their columns:\n{tables_columns_str}\n\n"
            f"User Question (Instruction): {user_instruction}\n\n"
            "Based *only* on the table names, column names, and the user question, list the tables "
            "that seem most relevant. If multiple tables are needed (e.g., for joins), include them all. "
            "Use the 'RelevantTables' format."
        )

        llm_tables = LLM.with_structured_output(RelevantTables)
        relevant_tables = llm_tables.invoke([SystemMessage(content=system_prompt)])

        # The model may echo the backticks used in the prompt, change case, or
        # invent a table. SQLite table names are case-insensitive, so match on a
        # cleaned, lower-cased key and map back to the real spelling.
        known_tables = {name.lower(): name for name in tables_with_columns}
        selected_tables: List[str] = []
        for raw_name in relevant_tables.tables:
            actual = known_tables.get(raw_name.strip().strip('`"').lower())
            if actual is None:
                logger.warning(f"LLM suggested unknown table '{raw_name}'; ignoring it.")
            elif actual not in selected_tables:
                selected_tables.append(actual)

        if not selected_tables:
            return {"schema": "No relevant tables were identified for the question."}

        schema = get_tables_info.invoke({"tables": selected_tables})
        return {"schema": schema}

    except Exception as e:
        # Deliberately best-effort (the error travels in `schema`), but keep the
        # traceback in the log instead of discarding it.
        logger.exception("Failed to identify relevant tables or fetch schema")
        return {
            "schema": f"Error: Failed to identify relevant tables or fetch schema. Details: {e}"
        }

In [8]:
from langchain_core.messages import SystemMessage, HumanMessage


def manager(state: State) -> State:
    """Entry and exit node: classify the Arabic input, or phrase the final answer.

    First pass (no ``instruction`` and no ``final_answer`` yet, first message is a
    HumanMessage): ask the LLM for a ``UserIntent``. A database question yields
    ``{"instruction": <English instruction>}``; a greeting or off-topic input
    yields ``{"final_answer": <Arabic reply>}``.

    Later pass (after SQL execution): turn the newest message's content (the SQL
    tool output) into an Arabic answer to the original question, or into an
    Arabic apology when that output starts with "Error:".

    Returns:
        A partial state update containing ``instruction`` or ``final_answer``.
    """
    # Canned reply ("Sorry, an unexpected error occurred while analysing your
    # request.") used whenever nothing better can be produced.
    fallback_reply = "عذراً، حدث خطأ غير متوقع أثناء تحليل طلبك."

    messages = state.get("messages") or []
    if not messages:
        # Nothing to classify or answer; end politely instead of an IndexError.
        return {"final_answer": fallback_reply}

    # First pass = nothing decided yet.
    is_initial_call = not state.get("instruction") and not state.get("final_answer")

    if is_initial_call and isinstance(messages[0], HumanMessage):
        # Route the request: database question -> English `instruction` for the
        # SQL pipeline; anything else -> Arabic `response` used as final answer.
        system_prompt = (
            "You are an intelligent assistant processing user input for a database query system.\n"
            "Analyze the user's Arabic input. Determine if it is a database-related query or something else (like a greeting).\n\n"
            "1.  **If the input IS a database query:** Translate it accurately into a concise English instruction suitable for a database query agent. Populate the 'instruction' field.\n"
            "2.  **If the input is a greeting (e.g., 'hi', 'hello', 'السلام عليكم', 'مرحبا') OR any other non-database query:** Generate a polite, brief response *in Arabic* acknowledging the user or explaining your purpose (database queries). Populate the 'response' field.\n\n"
            f"User's Arabic Input: '{messages[0].content}'\n\n"
            "Respond using the 'UserIntent' format. Populate *only one* field: 'instruction' (for queries) or 'response' (for non-queries/greetings)."
        )

        llm_intent = LLM.with_structured_output(UserIntent)
        res = llm_intent.invoke([SystemMessage(content=system_prompt)] + messages)

        if res.instruction:
            return {"instruction": res.instruction}
        if res.response:
            return {"final_answer": res.response}
        return {"final_answer": fallback_reply}

    # Later pass: the newest message holds the SQL tool output (rows or an error).
    original_user_query = messages[0].content
    sql_results = messages[-1].content

    if not sql_results.startswith("Error:"):
        system_prompt = (
            f"You are a helpful assistant. The user asked in Arabic: '{original_user_query}'\n\n"
            f"The following data was retrieved from the database:\n---\n{sql_results}\n---\n\n"
            "Based *strictly* on the retrieved data, provide a clear, concise, and helpful answer to the user's question in Arabic.\n"
            "Do not add any information not present in the data.\n"
            "If the data is empty ('returned no results') or doesn't directly answer the question, state that politely in Arabic.\n"
            "Answer:"
        )
    else:
        # Each instruction ends with "\n"; previously the last three sentences
        # were glued together ("...Do not suggest code.Just state...Answer:").
        system_prompt = (
            "You are a helpful assistant communicating in Arabic.\n"
            f"The user asked: '{original_user_query}'\n"
            "Unfortunately, there was a problem trying to answer the question using the database.\n"
            f"The final status was: {sql_results}\n\n"
            "Politely inform the user in Arabic that their request could not be completed due to this issue. Do not suggest code.\n"
            "Just state the problem clearly and concisely.\n"
            "Answer:"
        )

    res = LLM.invoke([SystemMessage(content=system_prompt)])

    # Never return an empty answer: a truthiness-based router would treat it as
    # "unanswered" while `instruction` is still set and re-run the SQL pipeline.
    return {"final_answer": res.content or fallback_reply}

In [9]:
from langchain_core.runnables import RunnableLambda
from langchain_core.messages import ToolMessage
from langgraph.prebuilt import ToolNode


# Fallback for the SQL ToolNode. `with_fallbacks(..., exception_key="error")`
# re-invokes this handler with the caught exception stored under state["error"].
# Each pending tool call is answered with an "Error: ..." ToolMessage so the
# router can send the conversation back for query correction (retry limit
# permitting).
def handle_tool_error(state: State) -> State:
    """Turn an exception raised by the SQL tool node into error ToolMessages.

    Assumes the newest message is the AIMessage whose ``tool_calls`` failed.

    Returns:
        ``{"messages": [...]}`` with one ToolMessage per tool call, each carrying
        the repr of the caught exception and the call's id.
    """
    caught = state.get("error")
    pending_calls = state["messages"][-1].tool_calls

    error_messages = []
    for call in pending_calls:
        error_messages.append(
            ToolMessage(
                content=f"Error: {repr(caught)}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": error_messages}


sql_execution_node = ToolNode([db_query_tool]).with_fallbacks(
    [RunnableLambda(handle_tool_error)],
    exception_key="error",
)

In [None]:
from langchain_core.messages import ToolMessage
from langgraph.graph import END


def should_regenerate_query(state: State) -> str:
    """Route after SQL execution: retry generation on a tool error, else answer.

    Returns ``"generate_sql_query"`` when the newest message is a ToolMessage
    whose content starts with "Error:" and ``error_count`` (generation attempts
    so far, default 0) is still ``<= MAX_QUERY_RETRIES``; otherwise ``"manager"``.
    """
    newest = state["messages"][-1]
    attempts = state.get("error_count", 0)

    failed = isinstance(newest, ToolMessage) and newest.content.startswith("Error:")
    if failed and attempts <= MAX_QUERY_RETRIES:
        return "generate_sql_query"
    return "manager"


def route_after_manager(state: State) -> str:
    """Decide where to go after the ``manager`` node.

    Returns:
        ``END`` once a final answer exists (even an empty one),
        ``"identify_relevant_tables"`` when only an instruction exists, and
        ``END`` as a safe default when neither is present.
    """
    # `is not None` rather than truthiness: an empty final answer must still end
    # the run. Otherwise, because `instruction` stays set, the graph would loop
    # back into the SQL pipeline until the recursion limit is hit.
    if state.get("final_answer") is not None:
        return END
    if state.get("instruction"):
        return "identify_relevant_tables"
    # Previously fell through to an implicit None, which is not a key in the
    # conditional-edge map; finish the run instead.
    return END


# NOTE: `handle_tool_error` and `sql_execution_node` are defined once, in the
# tool-execution cell above. A verbatim copy used to live here; it silently
# re-bound both names whenever this cell ran (overwriting any edit made to the
# original), so it was removed together with its now-unused imports.

In [11]:
from langgraph.graph import StateGraph, END


# Graph topology:
#   manager -> END | identify_relevant_tables
#   identify_relevant_tables -> generate_sql_query -> execute_sql
#   execute_sql -> generate_sql_query (retryable error) | manager
workflow = StateGraph(State)

workflow.add_node("manager", manager)
workflow.add_node("identify_relevant_tables", identify_relevant_tables)
workflow.add_node("generate_sql_query", generate_sql_query)
workflow.add_node("execute_sql", sql_execution_node)
workflow.set_entry_point("manager")

# Unconditional hops.
workflow.add_edge("identify_relevant_tables", "generate_sql_query")
workflow.add_edge("generate_sql_query", "execute_sql")

# Conditional hops: each router returns a node name (or END) that maps to itself.
workflow.add_conditional_edges(
    "manager",
    route_after_manager,
    {target: target for target in ("identify_relevant_tables", END)},
)
workflow.add_conditional_edges(
    "execute_sql",
    should_regenerate_query,
    {target: target for target in ("generate_sql_query", "manager")},
)

GRAPH = workflow.compile()

In [12]:
# Demo: an Arabic database question ("Who are the students in the Computer
# Science department?"). Each streamed event is a {node_name: state_update} dict.
question = "من هم الطلاب في قسم علوم الحاسب؟"
initial_state = State(messages=[HumanMessage(content=question)])

for step_update in GRAPH.stream(initial_state, {"recursion_limit": 50}):
    print(step_update)

{'manager': {'instruction': 'List all students in the Computer Science department.'}}
{'identify_relevant_tables': {'schema': "Table: `departments`\nSchema: [{'name': 'department_id', 'type': 'INTEGER', 'notnull': False, 'pk': True}, {'name': 'name', 'type': 'TEXT', 'notnull': True, 'pk': False}]\nSample Rows (Columns: ['department_id', 'name']):\n[(1, 'Computer Science'), (2, 'Mathematics'), (3, 'History')]\n\nTable: `students`\nSchema: [{'name': 'student_id', 'type': 'INTEGER', 'notnull': False, 'pk': True}, {'name': 'name', 'type': 'TEXT', 'notnull': True, 'pk': False}, {'name': 'department_id', 'type': 'INTEGER', 'notnull': False, 'pk': False}]\nForeign Keys: [{'from_column': 'department_id', 'to_table': 'departments', 'to_column': 'department_id'}]\nSample Rows (Columns: ['student_id', 'name', 'department_id']):\n[(1, 'John Doe', 1), (2, 'Jane Roe', 2), (3, 'Mark Spencer', 1)]"}}
{'generate_sql_query': {'messages': [AIMessage(content='Generated SQL query, attempting execution.', a