In [None]:
import json
import os
import re

import pandas as pd
from dotenv import load_dotenv
import google.generativeai as genai
from sqlalchemy import create_engine, inspect
from sqlalchemy.orm import sessionmaker


# Load settings from a local .env file. override=True makes the .env values win over
# variables that already exist in the process environment: generic names such as USER,
# HOST and PORT are often pre-set by the OS or shell, and with the default
# (override=False) those pre-existing values would silently be used for the database
# connection instead of the ones written in .env.
load_dotenv(override=True)

# Google API key read from the environment (None when not defined).
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# Database connection settings; each one is None when the variable is not defined.
PASS = os.getenv("PASS")
DATABASE = os.getenv("DATABASE")
USER = os.getenv("USER")
HOST = os.getenv("HOST")
DATABASE_CLIENT = os.getenv("DATABASE_CLIENT")
PORT = os.getenv("PORT")

In [2]:
def create_db_connection(password):
    """
    Build a SQLAlchemy engine and open a session bound to it.

    The connection URL is assembled from the module-level DATABASE_CLIENT, USER,
    HOST, PORT and DATABASE settings plus the given password.

    Parameters:
    password (str): Database password. It is percent-encoded before being placed
        in the URL so that characters such as '@', ':' or '/' are not mistaken
        for URL delimiters.

    Returns:
    tuple: (engine, session) - the created engine and a new session bound to it
    """
    from urllib.parse import quote_plus

    # str() keeps the previous f-string behaviour for non-string values.
    encoded_password = quote_plus(str(password))

    # Create connection string
    database_url = f"{DATABASE_CLIENT}://{USER}:{encoded_password}@{HOST}:{PORT}/{DATABASE}"
    engine = create_engine(database_url)

    # Create session factory and hand back one session from it
    session_factory = sessionmaker(bind=engine)

    return engine, session_factory()

In [3]:
# Create the engine and session shared by the rest of the notebook, using the
# password read from the environment.
engine, session = create_db_connection(PASS)

In [4]:
# Inspector object for reading schema metadata through the engine.
inspector = inspect(engine)

In [5]:
# List the table names reported by the inspector; the bare name on the last
# line displays the list as the cell output.
existing_tables = inspector.get_table_names()
existing_tables

['flights', 'airlines', 'airports']

In [6]:
from prompts import (
  background_for_context,
  airlines_prompt,
  airports_prompt,
  flights_prompt,
  TEXT_TO_SQL_RULES
)

In [7]:
# Prompt text for text-to-SQL generation. It interpolates the dataset background,
# the three per-table descriptions and the rule set imported from `prompts`, then
# states the required JSON answer format.
# The doubled braces ({{ and }}) near the end are f-string escapes: they render as
# literal { and } in the final prompt.
sql_generation_system_prompt = f"""
You are a helpful assistant that converts natural language queries into ANSI SQL queries.

###Your Background###

{background_for_context}

###Database Schema###

The database contains three tables: airlines, airports and flights. The detail are as follows:

{airlines_prompt}

{airports_prompt}

{flights_prompt}


Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step.
Also be aware when using where clause to values they might be case sensitive as well.

{TEXT_TO_SQL_RULES}

### FINAL ANSWER FORMAT ###
The final answer must be a ANSI SQL query in JSON format. Strictly provide the output, no explanation details or summarizing.

{{
    "sql": <SQL_QUERY_STRING>
}}
"""


In [8]:
# Print the rendered prompt to check how the imported fragments were interpolated.
print(sql_generation_system_prompt)


You are a helpful assistant that converts natural language queries into ANSI SQL queries.

###Your Background###


The U.S. Department of Transportation's (DOT) Bureau of Transportation Statistics tracks the on-time performance of domestic flights operated by large air carriers. 
Summary information on the number of on-time, delayed, canceled, and diverted flights is published in DOT's monthly Air Travel Consumer Report and in this dataset of 2015 flight delays and cancellations.


###Database Schema###

The database contains three tables: airlines, airports and flights. The detail are as follows:


airlines table: This table provides a information about IATA codes and their full names.
Columns:
iata_code: Two-letter IATA code assigned to the airline.
airline: Full name of the airline.



airports table: This table contains information about various airports, including their locations.
Columns:
iata_code: Three-letter IATA airport code.
airport: Name of the airport.
city: City where t

In [None]:
from typing import TypedDict

# Typed dictionary with a single string field, "results".
# NOTE(review): the response schema used for generation names its field "sql",
# not "results" - confirm which key this type is meant to describe.
SQL = TypedDict("SQL", {"results": str})

In [None]:
# Natural-language question to translate into SQL.
user_question = "I want all of the information about airlines."

# Attach the prompt built earlier as the system instruction so the model receives
# the schema description and text-to-SQL rules; without it the request carries
# only the bare question.
model = genai.GenerativeModel(
    model_name="models/gemini-2.0-flash",
    system_instruction=sql_generation_system_prompt,
)

result = model.generate_content(
    user_question,
    generation_config=genai.GenerationConfig(
        response_mime_type="application/json",
        # Constrain the reply to a JSON object with one required string field, "sql".
        response_schema={
            "type": "object",
            "properties": {
                "sql": {"type": "string"}
            },
            "required": ["sql"],
        },
    ),
    # Request timeout (presumably in seconds - confirm against the client docs).
    request_options={"timeout": 600},
)


In [23]:
# Display the raw text of the model response (JSON was requested via response_mime_type).
result.text


'{\n  "sql": "SELECT * FROM airlines"\n}'

In [24]:
def execute_query(query):
    """
    Run a SQL statement through the notebook's engine.

    Parameters:
    query (str): SQL text to run

    Returns:
    pandas.DataFrame: rows produced by the statement, or None when execution
    raises (the error message is printed instead of propagating)
    """
    try:
        rows = pd.read_sql_query(query, engine)
    except Exception as exc:
        # Best-effort: report the failure and signal it to the caller with None.
        print(f"Error executing query: {str(exc)}")
        return None
    return rows


In [25]:
def clean_generation_result(result: str) -> str:
    """
    Strip formatting noise from raw model output.

    Literal backslash-n sequences become spaces, Markdown code fences
    (```sql, ```json, ```), triple quotes and semicolons are removed, and
    finally every run of whitespace is collapsed to one space and the ends
    are trimmed.

    Whitespace is normalized last so that the removals above cannot leave
    leading, trailing or doubled spaces in the returned text.

    Parameters:
    result (str): raw text returned by the model

    Returns:
    str: the cleaned, single-line text
    """
    # NOTE(review): ";" and the triple-quote sequences are removed everywhere,
    # including inside SQL string literals - confirm that is acceptable.
    removable_tokens = ("```sql", "```json", '"""', "'''", "```", ";")

    cleaned = result.replace("\\n", " ")
    for token in removable_tokens:
        cleaned = cleaned.replace(token, "")

    return re.sub(r"\s+", " ", cleaned).strip()


In [27]:
# Strip formatting noise from the raw model text, then decode the JSON payload.
cleaned_response_text = clean_generation_result(result.text)
sql_query = json.loads(cleaned_response_text)
sql_query


{'sql': 'SELECT * FROM airlines'}

In [28]:
# Get the SQL query from the response
generated_sql = sql_query.get("sql")
print("Generated SQL Query:")
print(generated_sql)

# Execute the query. The DataFrame gets its own name instead of rebinding `result`:
# `result` still holds the model response that the parsing cell reads through
# `result.text`, and overwriting it made that cell fail when re-run.
query_result_df = execute_query(generated_sql)
query_result_df


Generated SQL Query:
SELECT * FROM airlines


Unnamed: 0,iata_code,airline
0,UA,United Air Lines Inc.
1,AA,American Airlines Inc.
2,US,US Airways Inc.
3,F9,Frontier Airlines Inc.
4,B6,JetBlue Airways
5,OO,Skywest Airlines Inc.
6,AS,Alaska Airlines Inc.
7,NK,Spirit Air Lines
8,WN,Southwest Airlines Co.
9,DL,Delta Air Lines Inc.


In [3]:
# Scratch check: print an example query so its "\n" escapes render as line breaks.
print(
    "SELECT *\n"
    "FROM users\n"
    "WHERE CAST(date_of_joining AS TIMESTAMP WITH TIME ZONE) >= CAST('2023-01-01 00:00:00' AS TIMESTAMP WITH TIME ZONE)\n"
    "  AND CAST(date_of_joining AS TIMESTAMP WITH TIME ZONE) < CAST('2024-01-01 00:00:00' AS TIMESTAMP WITH TIME ZONE);"
)

SELECT *
FROM users
WHERE CAST(date_of_joining AS TIMESTAMP WITH TIME ZONE) >= CAST('2023-01-01 00:00:00' AS TIMESTAMP WITH TIME ZONE)
  AND CAST(date_of_joining AS TIMESTAMP WITH TIME ZONE) < CAST('2024-01-01 00:00:00' AS TIMESTAMP WITH TIME ZONE);
