In [None]:
import os
import re

from dotenv import load_dotenv
from langchain.chains import create_sql_query_chain
from langchain_community.utilities import SQLDatabase
from langchain_google_genai import ChatGoogleGenerativeAI
from pyprojroot import here

# Load variables from a .env file into the process environment, leaving any
# variables that are already set untouched. The LLM cell below reads
# MODEL_NAME, TEMPERATURE and GEMINI_API_KEY via os.getenv.
load_dotenv(override=False)

True

**Read the configuration from environment variables and load the LLM**

In [None]:
# Configure the Gemini chat model from environment variables (loaded from .env above).
# os.getenv returns a string, or None when the variable is unset, so the temperature
# is parsed explicitly. An unset or empty TEMPERATURE falls back to 0.0, which keeps
# SQL generation as deterministic as the model allows.
llm = ChatGoogleGenerativeAI(
    model=os.getenv("MODEL_NAME"),
    temperature=float(os.getenv("TEMPERATURE") or 0.0),
    google_api_key=os.getenv("GEMINI_API_KEY"),
)

**Load and test the SQLite database**

In [4]:
# Resolve the SQLite file relative to the project root. Despite its name,
# `sqldb_directory` holds the path to the database file itself.
sqldb_directory = here("data/travel.sqlite")
# Fail fast if the file is missing. SQLite creates a new, empty database when asked
# to open a path that does not exist, and that would otherwise surface only as an
# empty table list or a "no such table" error below.
if not os.path.isfile(sqldb_directory):
    raise FileNotFoundError(f"SQLite database not found: {sqldb_directory}")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
# Sanity checks: SQL dialect, usable tables, and a sample of one table.
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM aircrafts_data LIMIT 10;")

sqlite
['aircrafts_data', 'airports_data', 'boarding_passes', 'bookings', 'car_rentals', 'flights', 'hotels', 'seats', 'ticket_flights', 'tickets', 'trip_recommendations']


"[('773', 'Boeing 777-300', 11100), ('763', 'Boeing 767-300', 7900), ('SU9', 'Sukhoi Superjet-100', 3000), ('320', 'Airbus A320-200', 5700), ('321', 'Airbus A321-200', 5600), ('319', 'Airbus A319-100', 6700), ('733', 'Boeing 737-300', 4200), ('CN1', 'Cessna 208 Caravan', 1200), ('CR2', 'Bombardier CRJ-200', 2700)]"

**Create the SQL query chain and run a test query**

In [5]:
# Build a chain that asks `llm` to write a SQL query against `db` for a
# natural-language question (invoked below with {"question": ...}).
chain = create_sql_query_chain(llm, db)

def extract_sql_query(text: str) -> str:
    """
    Extract a clean SQL query from a possibly formatted or prefixed LLM output.

    Handles, in order:
      1. Markdown code fences anywhere in the text (```sql ... ``` or a
         single-line ```SELECT 1```). Only the fenced body is kept.
      2. A leading 'SQLQuery:' / 'Query:' label (case-insensitive).
      3. Leading prose. A statement that starts at the beginning of a line
         (optionally after a 'SQLQuery:' / 'Query:' label) is preferred over a
         keyword that merely appears inside a sentence. CTEs (WITH ...) are
         recognised only in this line-anchored form, because "with" is a common
         English word.
      4. Trailing 'SQLResult:' / 'Answer:' sections, for outputs that follow a
         "SQLQuery: ... SQLResult: ... Answer: ..." layout.
      5. Trailing semicolons and stray backticks.

    Args:
        text: Raw string returned by the LLM or the SQL query chain.

    Returns:
        The extracted SQL statement. If no SQL keyword is found, returns the input
        with surrounding whitespace, code fences and a leading label removed.
    """
    text = text.strip()

    # Step 1: keep only the body of a fenced code block, wherever it appears.
    # The optional language tag ("sql", "sqlite", ...) is consumed only when a
    # newline follows it, so a single-line fence such as ```SELECT 1``` keeps its SQL.
    fence = re.search(r"```(?:[\w+-]*[ \t]*\n)?(.*?)```", text, flags=re.DOTALL)
    if fence:
        text = fence.group(1).strip()

    # Step 2: remove a leading "SQLQuery:" / "Query:" label.
    text = re.sub(r"^(SQLQuery:|Query:)\s*", "", text, flags=re.IGNORECASE)

    # Step 3a: prefer a statement that starts a line (optionally after a label),
    # so that keywords inside explanatory prose are not mistaken for the query.
    match = re.search(
        r"^[ \t]*(?:(?:SQLQuery|Query):[ \t]*)?"
        r"((?:SELECT|WITH|INSERT|UPDATE|DELETE|PRAGMA|CREATE|DROP|ALTER|DESCRIBE|SHOW)\b.*)",
        text,
        flags=re.IGNORECASE | re.DOTALL | re.MULTILINE,
    )
    # Step 3b: otherwise accept a keyword anywhere, e.g. "The query is: SELECT ...".
    if match is None:
        match = re.search(
            r"\b((?:SELECT|INSERT|UPDATE|DELETE|PRAGMA|CREATE|DROP|ALTER|DESCRIBE|SHOW)\s.+)",
            text,
            flags=re.IGNORECASE | re.DOTALL,
        )
    if match is None:
        # Nothing SQL-like found: return the cleaned text as-is.
        return text

    # Step 4: drop trailing result/answer sections that follow the query.
    query = re.split(
        r"^[ \t]*(?:SQLResult|Answer):",
        match.group(1),
        maxsplit=1,
        flags=re.IGNORECASE | re.MULTILINE,
    )[0]

    # Step 5: remove trailing semicolons / backticks.
    return query.strip().rstrip(";`").strip()


# Ask the chain a test question, then strip any formatting the LLM wrapped around the SQL.
response = chain.invoke(
    {"question": "How many rows are there in the aircrafts_data table?"}
)
sql_query = extract_sql_query(text=response)
print(sql_query)

SELECT COUNT(*) FROM aircrafts_data


In [6]:
# The SQL was written by an LLM, and extract_sql_query also accepts write statements
# (INSERT/UPDATE/DELETE/DROP/...). Only run queries that start with SELECT or WITH.
# This is a lightweight guard, not a sandbox: SQLite also allows WITH ... DELETE, so
# open the database read-only if a stronger guarantee is needed.
if not re.match(r"\s*(?:SELECT|WITH)\b", sql_query, flags=re.IGNORECASE):
    raise ValueError(f"Refusing to run a non-read-only query: {sql_query!r}")
db.run(sql_query)

'[(9,)]'