In [1]:
from google import genai
from dotenv import load_dotenv
import os
# Load variables from a local .env file (if any) into the process environment.
load_dotenv()

# Secrets are read from the environment; each is None when the variable is unset.
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
DB_PASS = os.environ.get('DB_PASS')

# API client used for the model calls further down.
client = genai.Client(api_key=GOOGLE_API_KEY)

In [18]:
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

def create_db_connection(password):
    """
    Create a SQLAlchemy engine and a session for the local PostgreSQL database.

    Connects as user "postgres" to database "postgres" on localhost:5432.

    Parameters:
    password: Password for the "postgres" user. It is converted with str() and
        percent-encoded before being placed in the connection URL, so it must
        be supplied in plain (not already URL-encoded) form.

    Returns:
    tuple: (engine, session) where engine is the object returned by
        create_engine and session is a new session bound to that engine.
    """
    from urllib.parse import quote

    # Percent-encode every reserved character (safe="") so that "@", ":", "/",
    # "#", "%" etc. inside the password cannot be mistaken for URL delimiters.
    encoded_password = quote(str(password), safe="")

    # Create connection string
    DATABASE_URL = f"postgresql://postgres:{encoded_password}@localhost:5432/postgres"

    # Create engine
    engine = create_engine(DATABASE_URL)

    # Create session factory
    Session = sessionmaker(bind=engine)

    return engine, Session()

In [20]:
from sqlalchemy import inspect, Table, Column, Integer, String, MetaData,DateTime

# Open the database connection the rest of the notebook relies on.
# `engine` and `session` are used by this and later cells but were never
# assigned anywhere in the notebook, so a fresh "Restart & Run All" stopped
# here with a NameError.
engine, session = create_db_connection(DB_PASS)

# Create MetaData instance
metadata = MetaData()

# Get inspector to check existing tables
inspector = inspect(engine)

In [21]:
# Names of the tables currently present in the connected database
existing_tables = inspector.get_table_names()

if 'users' in existing_tables:
    # Table is already there: load its definition from the database
    users = Table('users', metadata, autoload_with=engine)
    print("Users table already exists")
else:
    # Table is missing: declare the sample schema, then create it
    user_columns = [
        Column('id', Integer, primary_key=True),
        Column('name', String(50)),
        Column('email', String(120)),
        Column('department', String(50)),
        Column('salary', Integer),
        Column('date_of_joining', DateTime),
    ]
    users = Table('users', metadata, *user_columns, extend_existing=True)
    metadata.create_all(engine)
    print("Users table created")

# Print all available tables

Users table created


In [22]:
# Bind `users` to the table definition loaded from the live database.
users = Table(
    'users',
    metadata,
    autoload_with=engine,
)

In [23]:
from datetime import datetime, timedelta
import random

# Sample data
sample_departments = ['Engineering', 'Marketing', 'Sales', 'HR', 'Finance']
sample_data = [
    {'name': 'John Smith', 'email': 'john.smith@company.com', 'department': 'Engineering', 'salary': 85000},
    {'name': 'Emma Wilson', 'email': 'emma.wilson@company.com', 'department': 'Marketing', 'salary': 75000},
    {'name': 'Michael Brown', 'email': 'michael.brown@company.com', 'department': 'Sales', 'salary': 90000},
    {'name': 'Sarah Davis', 'email': 'sarah.davis@company.com', 'department': 'HR', 'salary': 65000},
    {'name': 'James Johnson', 'email': 'james.johnson@company.com', 'department': 'Finance', 'salary': 95000},
    {'name': 'Lisa Anderson', 'email': 'lisa.anderson@company.com', 'department': 'Engineering', 'salary': 88000},
    {'name': 'David Martinez', 'email': 'david.martinez@company.com', 'department': 'Sales', 'salary': 82000},
    {'name': 'Jennifer Taylor', 'email': 'jennifer.taylor@company.com', 'department': 'Marketing', 'salary': 72000},
    {'name': 'Robert Wilson', 'email': 'robert.wilson@company.com', 'department': 'Engineering', 'salary': 91000},
    {'name': 'Emily White', 'email': 'emily.white@company.com', 'department': 'HR', 'salary': 68000},
    {'name': 'Daniel Lee', 'email': 'daniel.lee@company.com', 'department': 'Finance', 'salary': 92000},
    {'name': 'Maria Garcia', 'email': 'maria.garcia@company.com', 'department': 'Sales', 'salary': 78000},
    {'name': 'William Turner', 'email': 'william.turner@company.com', 'department': 'Engineering', 'salary': 86000},
    {'name': 'Jessica Brown', 'email': 'jessica.brown@company.com', 'department': 'Marketing', 'salary': 71000},
    {'name': 'Thomas Clark', 'email': 'thomas.clark@company.com', 'department': 'Finance', 'salary': 89000},
    {'name': 'Amanda Rodriguez', 'email': 'amanda.rodriguez@company.com', 'department': 'HR', 'salary': 67000},
    {'name': 'Kevin Miller', 'email': 'kevin.miller@company.com', 'department': 'Engineering', 'salary': 84000},
    {'name': 'Michelle Lee', 'email': 'michelle.lee@company.com', 'department': 'Sales', 'salary': 79000},
    {'name': 'Christopher Davis', 'email': 'chris.davis@company.com', 'department': 'Marketing', 'salary': 73000},
    {'name': 'Rachel Green', 'email': 'rachel.green@company.com', 'department': 'Finance', 'salary': 93000}
]

# Add random join dates between 2020 and 2023.
# A dedicated, seeded generator makes the dates identical on every run without
# touching the global `random` state.
start_date = datetime(2020, 1, 1)
end_date = datetime(2023, 12, 31)
days_range = (end_date - start_date).days  # same for every entry, so computed once
join_date_rng = random.Random(42)

for entry in sample_data:
    random_days = join_date_rng.randint(0, days_range)
    entry['date_of_joining'] = start_date + timedelta(days=random_days)

# First, drop existing data, then insert the new rows in the same transaction.
# Without the delete, every re-run of this cell appended another copy of the
# sample rows. NOTE: this removes ALL rows currently in `users`.
insert_stmt = users.insert().values(sample_data)
try:
    session.execute(users.delete())
    session.execute(insert_stmt)
    session.commit()
except Exception:
    # Leave the session usable for later cells, then surface the original error.
    session.rollback()
    raise


In [8]:
import pandas as pd
def execute_query(query):
    """
    Execute a read-only SQL query and return the result as a DataFrame.

    The SQL handed to this function is generated by a language model, so a
    string query is only executed when it starts with SELECT or WITH. This is
    a coarse first-keyword check, not a SQL parser: it blocks plain
    INSERT/UPDATE/DELETE/DDL statements but cannot detect data-modifying
    statements hidden inside a WITH clause. Non-string queries are passed
    through to pandas unchanged.

    Parameters:
    query (str): SQL query to execute

    Returns:
    pandas.DataFrame: Query results as DataFrame, or None when the query is
        rejected or its execution raises (the error is printed).
    """
    import re

    # Optional whitespace / opening parentheses, then SELECT or WITH as a whole word.
    if isinstance(query, str) and not re.match(r"[\s(]*(?:select|with)\b", query, re.IGNORECASE):
        print("Error executing query: only SELECT statements are allowed")
        return None

    try:
        # Execute query and load results into DataFrame
        df_result = pd.read_sql_query(query, engine)
        return df_result

    except Exception as e:
        print(f"Error executing query: {str(e)}")
        return None

In [9]:
TEXT_TO_SQL_RULES = """
#### SQL RULES ####
- ONLY USE SELECT statements, NO DELETE, UPDATE OR INSERT etc. statements that might change the data in the database.
- ONLY USE the tables and columns mentioned in the database schema.
- ONLY USE "*" if the user query asks for all the columns of a table.
- ONLY CHOOSE columns belong to the tables mentioned in the database schema.
- YOU MUST USE "JOIN" if you choose columns from multiple tables!
- ALWAYS QUALIFY column names with their table name or table alias to avoid ambiguity (e.g., orders.OrderId, o.OrderId)
- YOU MUST USE "lower(<table_name>.<column_name>) like lower(<value>)" function or "lower(<table_name>.<column_name>) = lower(<value>)" function for case-insensitive comparison!
    - Use "lower(<table_name>.<column_name>) LIKE lower(<value>)" when:
        - The user requests a pattern or partial match.
        - The value is not specific enough to be a single, exact value.
        - Wildcards (%) are needed to capture the pattern.
    - Use "lower(<table_name>.<column_name>) = lower(<value>)" when:
        - The user requests an exact, specific value.
        - There is no ambiguity or pattern in the value.
- ALWAYS CAST the date/time related field to "TIMESTAMP WITH TIME ZONE" type when using them in the query
    - example 1: CAST(properties_closedate AS TIMESTAMP WITH TIME ZONE)
    - example 2: CAST('2024-11-09 00:00:00' AS TIMESTAMP WITH TIME ZONE)
    - example 3: CAST(DATE_TRUNC('month', CURRENT_DATE - INTERVAL '1 month') AS TIMESTAMP WITH TIME ZONE)
- If the user asks for a specific date, please give the date range in SQL query
    - example: "What is the total revenue for the month of 2024-11-01?"
    - answer: "SELECT SUM(r.PriceSum) FROM Revenue r WHERE CAST(r.PurchaseTimestamp AS TIMESTAMP WITH TIME ZONE) >= CAST('2024-11-01 00:00:00' AS TIMESTAMP WITH TIME ZONE) AND CAST(r.PurchaseTimestamp AS TIMESTAMP WITH TIME ZONE) < CAST('2024-11-02 00:00:00' AS TIMESTAMP WITH TIME ZONE)"
- USE THE VIEW TO SIMPLIFY THE QUERY.
- DON'T MISUSE THE VIEW NAME. THE ACTUAL NAME IS FOLLOWING THE CREATE VIEW STATEMENT.
- MUST USE the value of alias from the comment section of the corresponding table or column in the DATABASE SCHEMA section for the column/table alias.
  - EXAMPLE
    DATABASE SCHEMA
    /* {"displayName":"_orders","description":"A model representing the orders data."} */
    CREATE TABLE orders (
      -- {"description":"A column that represents the timestamp when the order was approved.","alias":"_timestamp"}
      ApprovedTimestamp TIMESTAMP
    )

    SQL
    SELECT _orders.ApprovedTimestamp AS _timestamp FROM orders AS _orders;
- DON'T USE '.' in column/table alias, replace '.' with '_' in column/table alias.
- DON'T USE "FILTER(WHERE <expression>)" clause in the generated SQL query.
- DON'T USE "EXTRACT(EPOCH FROM <expression>)" clause in the generated SQL query.
- DON'T USE INTERVAL or generate INTERVAL-like expression in the generated SQL query.
- ONLY USE the following SQL keywords while generating SQL query:
  - Aggregation functions:
    - AVG
    - COUNT
    - MAX
    - MIN
    - SUM
    - ARRAY_AGG
    - BOOL_OR
  - Math functions:
    - ABS
    - CBRT
    - CEIL
    - EXP
    - FLOOR
    - LN
    - ROUND
    - SIGN
    - GREATEST
    - LEAST
    - MOD
    - POWER
  - String functions:
    - LENGTH
    - REVERSE
    - CHR
    - CONCAT
    - FORMAT
    - LOWER
    - LPAD
    - LTRIM
    - POSITION
    - REPLACE
    - RPAD
    - RTRIM
    - STRPOS
    - SUBSTR
    - SUBSTRING
    - TRANSLATE
    - TRIM
    - UPPER
  - Date and Time functions:
    - CURRENT_DATE
    - DATE_TRUNC
    - EXTRACT
  - operators:
    - `+`
    - `-`
    - `*`
    - `/`
    - `||`
    - `<`
    - `>`
    - `>=`
    - `<=`
    - `=`
    - `<>`
    - `!=`
- ONLY USE JSON_QUERY for querying fields if "json_type":"JSON" is identified in the columns comment, NOT the deprecated JSON_EXTRACT_SCALAR function.
    - DON'T USE CAST for JSON fields, ONLY USE the following functions:
      - LAX_BOOL for boolean fields
      - LAX_FLOAT64 for double and float fields
      - LAX_INT64 for bigint fields
      - LAX_STRING for varchar fields
    - For Example:
      DATA SCHEMA:
        `/* {"displayName":"users","description":"A model representing the users data."} */
        CREATE TABLE users (
            -- {"alias":"address","description":"A JSON object that represents address information of this user.","json_type":"JSON","json_fields":{"json_type":"JSON","address.json.city":{"name":"city","type":"varchar","path":"$.city","properties":{"displayName":"city","description":"City Name."}},"address.json.state":{"name":"state","type":"varchar","path":"$.state","properties":{"displayName":"state","description":"ISO code or name of the state, province or district."}},"address.json.postcode":{"name":"postcode","type":"varchar","path":"$.postcode","properties":{"displayName":"postcode","description":"Postal code."}},"address.json.country":{"name":"country","type":"varchar","path":"$.country","properties":{"displayName":"country","description":"ISO code of the country."}}}}
            address JSON
        )`
      To get the city of address in user table use SQL:
      `SELECT LAX_STRING(JSON_QUERY(u.address, '$.city')) FROM user as u`
- ONLY USE JSON_QUERY_ARRAY for querying "json_type":"JSON_ARRAY" is identified in the comment of the column, NOT the deprecated JSON_EXTRACT_ARRAY.
    - USE UNNEST to analyze each item individually in the ARRAY. YOU MUST SELECT FROM the parent table ahead of the UNNEST ARRAY.
    - The alias of the UNNEST(ARRAY) should be in the format `unnest_table_alias(individual_item_alias)`
      - For Example: `SELECT item FROM UNNEST(ARRAY[1,2,3]) as my_unnested_table(item)`
    - If the items in the ARRAY are JSON objects, use JSON_QUERY to query the fields inside each JSON item.
      - For Example:
      DATA SCHEMA
        `/* {"displayName":"my_table","description":"A test my_table"} */
        CREATE TABLE my_table (
            -- {"alias":"elements","description":"elements column","json_type":"JSON_ARRAY","json_fields":{"json_type":"JSON_ARRAY","elements.json_array.id":{"name":"id","type":"bigint","path":"$.id","properties":{"displayName":"id","description":"data ID."}},"elements.json_array.key":{"name":"key","type":"varchar","path":"$.key","properties":{"displayName":"key","description":"data Key."}},"elements.json_array.value":{"name":"value","type":"varchar","path":"$.value","properties":{"displayName":"value","description":"data Value."}}}}
            elements JSON
        )`
        To get the number of elements in my_table table use SQL:
        `SELECT LAX_INT64(JSON_QUERY(element, '$.number')) FROM my_table as t, UNNEST(JSON_QUERY_ARRAY(elements)) AS my_unnested_table(element) WHERE LAX_FLOAT64(JSON_QUERY(element, '$.value')) > 3.5`
    - To JOIN ON the fields inside UNNEST(ARRAY), YOU MUST SELECT FROM the parent table ahead of the UNNEST syntax, and the alias of the UNNEST(ARRAY) SHOULD BE IN THE FORMAT unnest_table_alias(individual_item_alias)
      - For Example: `SELECT p.column_1, j.column_2 FROM parent_table AS p, join_table AS j JOIN UNNEST(p.array_column) AS unnested(array_item) ON j.id = array_item.id`
- DON'T USE JSON_QUERY and JSON_QUERY_ARRAY when "json_type":"".
- DON'T USE LAX_BOOL, LAX_FLOAT64, LAX_INT64, LAX_STRING when "json_type":"".
"""

In [10]:
# System prompt for the SQL-generation model call.
# This is an f-string: {TEXT_TO_SQL_RULES} is interpolated when this cell runs,
# and the doubled braces ({{ and }}) come out as literal { and } in the JSON
# answer template at the end of the prompt.
sql_generation_system_prompt = f"""
You are a helpful assistant that converts natural language queries into ANSI SQL queries.

Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step.
Also be aware when using where clause to values they might be case sensitive as well.

{TEXT_TO_SQL_RULES}

### FINAL ANSWER FORMAT ###
The final answer must be a ANSI SQL query in JSON format.

{{
    "sql": <SQL_QUERY_STRING>
}}
"""

In [11]:
from google.genai import types

# Typed request configuration carrying the system prompt for generation.
config = types.GenerateContentConfig(system_instruction=sql_generation_system_prompt)

In [12]:
client = genai.Client(api_key=GOOGLE_API_KEY)

# Natural-language request, including a description of the table's columns.
user_question = "I have a table name called users which has schema id, name, email, department, salary, and date_of_joining. I want check people from sales department whose salary is greater than 80000"

# Single-shot generation using the system prompt held in `config`.
response = client.models.generate_content(
    contents=[user_question],
    config=config,
    model='gemini-2.0-flash',
)

In [13]:
# Show the raw text of the model reply, before any cleaning is applied.
response.text

'```json\n{\n    "sql": "SELECT * FROM users WHERE lower(department) = lower(\'sales\') AND salary > 80000"\n}\n```'

In [14]:
import re
def clean_generation_result(result: str) -> str:
    """
    Strip formatting noise from a model reply.

    Every run of whitespace is first collapsed to a single space and the ends
    are trimmed. After that, literal backslash-n sequences become spaces, and
    code-fence markers, triple quotes and semicolons are removed. The removals
    happen after trimming, so the result may keep a leading or trailing space.
    """
    cleaned = re.sub(r"\s+", " ", result).strip()

    # Applied in order: the language-tagged fences must go before the bare "```".
    substitutions = (
        ("\\n", " "),
        ("```sql", ""),
        ("```json", ""),
        ('"""', ""),
        ("'''", ""),
        ("```", ""),
        (";", ""),
    )
    for target, replacement in substitutions:
        cleaned = cleaned.replace(target, replacement)

    return cleaned

In [15]:
import json
# Remove fences/whitespace noise first, then decode the remaining JSON text.
cleaned_response = clean_generation_result(response.text)
sql_query = json.loads(cleaned_response)
sql_query

{'sql': "SELECT * FROM users WHERE lower(department) = lower('sales') AND salary > 80000"}

In [16]:
# Display the value stored under the 'sql' key of the parsed reply.
sql_query.get('sql')

"SELECT * FROM users WHERE lower(department) = lower('sales') AND salary > 80000"

In [24]:
# SQL text produced by the model
generated_sql = sql_query.get('sql')
print("Generated SQL Query:", generated_sql, sep="\n")

# Run it against the database and display the resulting frame
result = execute_query(generated_sql)
result

Generated SQL Query:
SELECT * FROM users WHERE lower(department) = lower('sales') AND salary > 80000


Unnamed: 0,id,name,email,department,salary,date_of_joining
0,3,Michael Brown,michael.brown@company.com,Sales,90000,2023-12-29
1,7,David Martinez,david.martinez@company.com,Sales,82000,2022-09-09
