In [None]:
from pyprojroot import here
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from operator import itemgetter
import os
from dotenv import load_dotenv
load_dotenv()

True

**Set the environment variables and load the LLM**

In [None]:
# Environment variables are always strings (or None when unset), so convert
# the temperature explicitly rather than relying on the model class to coerce
# it. An unset/empty TEMPERATURE falls back to 0.0 (deterministic output).
llm = ChatGoogleGenerativeAI(
    model=os.getenv("MODEL_NAME"),
    temperature=float(os.getenv("TEMPERATURE") or 0),
    google_api_key=os.getenv("GEMINI_API_KEY")
)

**Load and test the SQLite database**

In [None]:
# Locate the SQLite file with pyprojroot's `here` and open it through LangChain.
sqldb_directory = here("data/travel.sqlite")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")

# Sanity checks: dialect, visible tables, and a small sample of rows.
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM aircrafts_data LIMIT 10;")

sqlite
['aircrafts_data', 'airports_data', 'boarding_passes', 'bookings', 'car_rentals', 'flights', 'hotels', 'seats', 'ticket_flights', 'tickets', 'trip_recommendations']


"[('773', 'Boeing 777-300', 11100), ('763', 'Boeing 767-300', 7900), ('SU9', 'Sukhoi Superjet-100', 3000), ('320', 'Airbus A320-200', 5700), ('321', 'Airbus A321-200', 5600), ('319', 'Airbus A319-100', 6700), ('733', 'Boeing 737-300', 4200), ('CN1', 'Cessna 208 Caravan', 1200), ('CR2', 'Bombardier CRJ-200', 2700)]"

**Create the SQL agent chain and run a test query**

In [4]:
import re

def extract_sql_query(text: str) -> str:
    """
    Extract a clean SQL statement from raw LLM output.

    The input may wrap the query in a Markdown code fence, surround it with
    prose, or prefix it with a label such as ``SQLQuery:``.

    Processing steps:
        1. If a fenced code block is present anywhere in the text, keep only
           the contents of the first one.
        2. Drop a leading ``SQLQuery:`` / ``Query:`` label.
        3. Return everything from the first SQL statement keyword onward,
           without trailing semicolons, backticks or whitespace. A CTE
           header (``WITH [RECURSIVE] name [(cols)] AS (``) counts as a
           statement start so the query is not cut at the inner SELECT.
        4. If no statement keyword is found, return the cleaned text as is.

    Args:
        text (str): Raw model output.

    Returns:
        str: The extracted SQL statement, or the stripped/cleaned input when
        no SQL keyword could be located.
    """
    text = text.strip()

    # Step 1: Prefer the body of the first fenced code block, wherever it
    # appears, so prose before or after the fence never reaches Step 3.
    # The opening fence may carry a language tag but nothing else on its line.
    fenced = re.search(r"```[\w+-]*[ \t]*\r?\n(.*?)```", text, flags=re.DOTALL)
    if fenced:
        text = fenced.group(1).strip()

    # Step 2: Remove known prefixes like "SQLQuery:"
    text = re.sub(r"^(SQLQuery:|Query:)\s*", "", text, flags=re.IGNORECASE)

    # Step 3: Extract only the actual SQL statement (any type).
    # \b stops keywords from matching inside longer words; the WITH branch
    # demands the full CTE header so the English word "with" is not mistaken
    # for SQL.
    match = re.search(
        r"\b(?:"
        r"WITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^)]*\)\s*)?AS\s*\("
        r"|(?:SELECT|INSERT|UPDATE|DELETE|PRAGMA|CREATE|DROP|ALTER|DESCRIBE|SHOW)\s"
        r").+",
        text,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if match:
        # Strip any mix of trailing semicolons, backticks and whitespace.
        return re.sub(r"[;`\s]+\Z", "", match.group(0))

    # Step 4: If nothing matched, return original (but cleaned)
    return text


In [None]:
# Prompt that turns (question, SQL query, SQL result) into a final answer.
system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
    Question: {question}\n
    SQL Query: {query}\n
    SQL Result: {result}\n
    Answer:
    """

# Building blocks of the pipeline: run SQL, write SQL, clean the written SQL.
execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)
clean_query_output = RunnableLambda(extract_sql_query)

# Final stage: fill the prompt, call the LLM, and return plain text.
answer_prompt = PromptTemplate.from_template(system_role)
answer = answer_prompt | llm | StrOutputParser()

# Each assign() adds one key to the running dict:
#   raw_query -> query -> result, then everything is handed to `answer`.
chain = (
    RunnablePassthrough
    .assign(raw_query=write_query)
    .assign(query=itemgetter("raw_query") | clean_query_output)
    .assign(result=itemgetter("query") | execute_query)
    | answer
)

  execute_query = QuerySQLDataBaseTool(db=db)


In [6]:
# Smoke-test the full chain with a schema-level question.
message = "How many tables do I have in the database? and what are their names?"
response = chain.invoke(
    {"question": message}
)
response

'You have 11 tables in the database. Their names are: aircrafts_data, airports_data, boarding_passes, bookings, flights, seats, ticket_flights, tickets, car_rentals, hotels, and trip_recommendations.'

**Travel SQL-agent Tool Design**

In [None]:
from langchain_core.tools import tool
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from operator import itemgetter
from .extract_sql_query import extract_sql_query
from agent_graph.load_tools_config import LoadToolsConfig

# Tool configuration object; query_travel_sqldb below reads the SQL agent's
# model name, database path, temperature and API key from it.
TOOLS_CFG = LoadToolsConfig()


class TravelSQLAgentTool:
    """
    Natural-language interface to a travel SQLite database.

    On construction the tool wires together a LangChain pipeline that
    (1) asks the LLM to write a SQL query for the user's question,
    (2) cleans that query with ``extract_sql_query``,
    (3) executes it against the database, and
    (4) asks the LLM to phrase the result as an answer.

    Attributes:
        sql_agent_llm (ChatGoogleGenerativeAI): Model used both to write SQL and to answer.
        system_role (str): Prompt template for the final answering step.
        db (SQLDatabase): Handle to the SQLite database.
        chain (Runnable): The assembled pipeline; invoke with ``{"question": ...}``.
    """

    def __init__(self, llm: str, sqldb_directory: str, llm_temerature: float, llm_api_key: str) -> None:
        """
        Build the model, open the database, and assemble the pipeline.

        Args:
            llm (str): Model name passed to ChatGoogleGenerativeAI.
            sqldb_directory (str): Path to the SQLite database file.
            llm_temerature (float): Sampling temperature for the model.
            llm_api_key (str): Google API key for the model.
        """
        self.sql_agent_llm = ChatGoogleGenerativeAI(
            model=llm, temperature=llm_temerature, google_api_key=llm_api_key
        )
        self.system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
            Question: {question}\n
            SQL Query: {query}\n
            SQL Result: {result}\n
            Answer:
            """
        self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
        # Echo the visible tables on construction.
        print(self.db.get_usable_table_names())

        # Individual pipeline stages.
        sql_runner = QuerySQLDataBaseTool(db=self.db)
        sql_writer = create_sql_query_chain(self.sql_agent_llm, self.db)
        sql_cleaner = RunnableLambda(extract_sql_query)
        responder = (
            PromptTemplate.from_template(self.system_role)
            | self.sql_agent_llm
            | StrOutputParser()
        )

        # Accumulate raw_query -> query -> result, then produce the answer.
        with_raw_query = RunnablePassthrough.assign(raw_query=sql_writer)
        with_query = with_raw_query.assign(query=itemgetter("raw_query") | sql_cleaner)
        with_result = with_query.assign(result=itemgetter("query") | sql_runner)
        self.chain = with_result | responder


@tool
def query_travel_sqldb(query: str) -> str:
    """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
    # NOTE: the docstring above is the tool description at runtime; keep it
    # unchanged. A new agent is constructed from TOOLS_CFG on every call.
    sql_agent = TravelSQLAgentTool(
        llm=TOOLS_CFG.travel_sqlagent_llm,
        sqldb_directory=TOOLS_CFG.travel_sqldb_directory,
        llm_temerature=TOOLS_CFG.travel_sqlagent_llm_temperature,
        llm_api_key=TOOLS_CFG.travel_sqlagent_api_key,
    )
    return sql_agent.chain.invoke({"question": query})


In [None]:
from agent_graph.load_tools_config import LoadToolsConfig

# Tool configuration object read by query_travel_sqldb in this cell.
TOOLS_CFG = LoadToolsConfig()

@tool
def query_travel_sqldb(query: str) -> str:
    """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
    # NOTE: the docstring above is the tool description at runtime; keep it
    # unchanged.
    # TravelSQLAgentTool.__init__ requires llm_api_key; omitting it raises
    # TypeError on every call, so pass it from the config as well.
    agent = TravelSQLAgentTool(
        llm=TOOLS_CFG.travel_sqlagent_llm,
        sqldb_directory=TOOLS_CFG.travel_sqldb_directory,
        llm_temerature=TOOLS_CFG.travel_sqlagent_llm_temperature,
        llm_api_key=TOOLS_CFG.travel_sqlagent_api_key
    )
    response = agent.chain.invoke({"question": query})
    return response