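In this notebook, I explore a retry mechanism for when the LLM generates an invalid SQL query: the failing query and the database error it produced are sent back to the LLM to obtain a corrected query, which is then executed again.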

In [None]:
import json
import re
from llm import GenerativeModelWrapper
from src.config import config
from src.db.database import create_db_connection
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

In [None]:
# Sample inputs reproducing a real failure: the generated query uses
# `CAST(... AS DOUBLE)`, and the database (psycopg2 error below) has no type "double".
user_query = (
    "calculate the month-over-month percent change in the number of flights for 2015"
)
error_message = '500: Error executing query: (psycopg2.errors.UndefinedObject) type "double" does not exist\nLINE 1: ...hts ) SELECT  flight_month,  (CAST(num_flights AS DOUBLE) - ...\n                                                             ^\n\n[SQL: WITH MonthlyFlights AS (  SELECT  CAST(flights.month AS INT) AS flight_month,  COUNT(*) AS num_flights  FROM flights  WHERE  flights.year = 2015  GROUP BY  flights.month ), LaggedFlights AS (  SELECT  flight_month,  num_flights,  LAG(num_flights, 1, 0) OVER (ORDER BY flight_month) AS previous_month_flights  FROM MonthlyFlights ) SELECT  flight_month,  (CAST(num_flights AS DOUBLE) - CAST(previous_month_flights AS DOUBLE)) / CAST(previous_month_flights AS DOUBLE) AS month_over_month_change FROM LaggedFlights]\n(Background on this error at: https://sqlalche.me/e/20/f405)'
# The failing query itself, extracted from the "[SQL: ...]" section of the error
# message so the two cannot drift apart and no "SQL: " label leaks into the query.
# NOTE: LAG(num_flights, 1, 0) makes the first month's denominator 0, so even a
# type-corrected query will likely fail with a division-by-zero error.
generated_sql = error_message.split("[SQL: ", 1)[1].split("]\n", 1)[0]

In [None]:
# Prompt template asking the LLM to repair a failing SQL query.
# Placeholders: {user_query} (the user's natural-language question),
# {generated_sql} (the query that failed) and {error_message} (the database error).
# Literal JSON braces are doubled ({{ }}) so str.format leaves them intact.
sql_correction_prompt = """
### TASK ###
You are an ANSI SQL expert with exceptional logical thinking and debugging skills.

You are given a syntactically incorrect ANSI SQL query and the related error message. Generate a syntactically correct SQL query without changing the original semantics. The corrected query must run on the same database that produced the error message, so use only types and functions that database supports.

### QUESTION ###
Question asked by user: {user_query}
GENERATED SQL: {generated_sql}
ERROR Message: {error_message}

### FINAL ANSWER FORMAT ###
The final answer must be a corrected SQL query in JSON format. Respond with only the JSON object and no other text:

{{
    "sql": "<CORRECTED_SQL_QUERY_STRING>"
}}
"""


In [None]:
# Fill the correction template with the failing query, its error and the question.
prompt = sql_correction_prompt.format(
    generated_sql=generated_sql,
    error_message=error_message,
    user_query=user_query,
)

In [None]:
# Model client used to generate (and regenerate) SQL; its `generate_sql` is awaited below.
llm = GenerativeModelWrapper()

In [None]:
# First correction attempt: send the filled-in correction prompt to the model.
res = await llm.generate_sql(prompt=prompt)
res  # display the raw model output (may still carry code fences; cleaned below)

In [None]:
def clean_generation_result(result: str) -> str:
    """Strip LLM output formatting so the payload can be parsed or executed.

    Removes Markdown code fences (```sql, ```json, ```), triple-quote
    delimiters, literal backslash-n escape sequences and semicolons, then
    collapses every run of whitespace to a single space and trims both ends.

    Args:
        result: Raw text returned by the model.

    Returns:
        The cleaned text, e.g. a bare JSON object or SQL statement.
    """
    # Order matters: "```sql" / "```json" must be removed before the bare "```".
    for token, replacement in (
        ("\\n", " "),  # escaped newlines inside JSON strings become spaces
        ("```sql", ""),
        ("```json", ""),
        ('"""', ""),
        ("'''", ""),
        ("```", ""),
        (";", ""),  # NOTE: also removes ';' inside SQL string literals
    ):
        result = result.replace(token, replacement)
    # Normalise whitespace last so the removals above leave no stray or doubled spaces.
    return re.sub(r"\s+", " ", result).strip()

In [None]:
# Strip fences/whitespace from the model output, then parse its JSON answer.
cleaned_generation = clean_generation_result(res)
sql_query = json.loads(cleaned_generation)
print(sql_query)

In [None]:
# Pull the corrected SQL out of the parsed answer ("" when the "sql" key is missing).
sql = sql_query.get("sql", "")
sql  # display the corrected query

In [None]:
from typing import Dict
import pandas as pd
from sqlalchemy import engine
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker


def parse_nan_values(dataframe: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *dataframe* in which every missing value becomes ``""``."""
    empty_string = ""
    return dataframe.fillna(value=empty_string)


# Configuration read from the process environment; each name is None when its
# variable is unset.
# NOTE(review): on Unix shells USER is normally the OS login name, so the database
# user falls back to it unless USER is overridden for this process — confirm intended.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
PASS = os.environ.get("PASS")
DATABASE = os.environ.get("DATABASE")
USER = os.environ.get("USER")
HOST = os.environ.get("HOST")
DATABASE_CLIENT = os.environ.get("DATABASE_CLIENT")
PORT = os.environ.get("PORT")


# Database connection
def create_db_connection(password):
    """Create a SQLAlchemy engine and an open ORM session for the configured database.

    Connection details come from the module-level DATABASE_CLIENT, USER, HOST,
    PORT and DATABASE settings; only the password is passed in.

    Args:
        password: Database password. It may contain URL-reserved characters
            such as ``@``, ``/`` or ``:``; ``URL.create`` escapes them.

    Returns:
        ``(engine, session)`` — the engine and a new ``Session`` bound to it.
        The caller owns the session and should close it when done.
    """
    from sqlalchemy.engine import URL

    # Build the URL structurally rather than with an f-string: interpolating raw
    # credentials corrupts the URL when they contain reserved characters, and a
    # missing setting would be rendered as the literal text "None".
    database_url = URL.create(
        drivername=DATABASE_CLIENT,
        username=USER,
        password=password,
        host=HOST,
        # URL.create expects an int (or None) port; environment values are strings.
        port=int(PORT) if PORT else None,
        database=DATABASE,
    )
    db_engine = create_engine(database_url)
    Session = sessionmaker(bind=db_engine)
    return db_engine, Session()


In [None]:
# Connect using the password from the PASS environment variable. `engine` is the
# module-level handle execute_query() reads; `session` is not used in this notebook.
engine, session = create_db_connection(PASS)


def execute_query(query: str) -> list:
    """Run *query* on the module-level ``engine`` and return the rows as dicts.

    Missing values are replaced with ``""`` via ``parse_nan_values``.

    Args:
        query: SQL text to execute. In this notebook it is LLM-generated, so run
            it only under a database role whose privileges you are willing to
            hand to the model (ideally read-only).

    Returns:
        One ``{column: value}`` dict per result row (``orient="records"``).

    Raises:
        Any exception from ``pandas.read_sql_query`` / the database driver
        propagates unchanged, so callers can feed the error back to the LLM.
    """
    df_result = pd.read_sql_query(query, engine)
    return parse_nan_values(df_result).to_dict(orient="records")

In [None]:
retry = 0
while retry < 3:
    try:
        res = execute_query(sql)
    except Exception as e:
        res = await llm.generate_sql(prompt=prompt)
        sql_query = json.loads(clean_generation_result(res))
        sql = sql_query.get("sql", "")
        res = execute_query(sql)
        retry += 1
        print("Exception Occured:", e)


In [None]:
res = await execute_with_retries(
    user_query=user_query,
    initial_generated_sql=sql_query,
    initial_error_message=error_message,
    engine=engine,
)
