In [None]:

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
import os
import pandas as pd
import json
from typing import Dict

# Connection settings, read from the process environment (None when unset).
# NOTE(review): USER, HOST and PORT are generic variable names that shells and
# hosting platforms often define themselves -- confirm they hold the database
# values in the environment this notebook runs in.
PASS, DATABASE, USER, HOST, DATABASE_CLIENT, PORT = (
    os.environ.get(setting_name)
    for setting_name in ("PASS", "DATABASE", "USER", "HOST", "DATABASE_CLIENT", "PORT")
)

In [2]:
# Database connection
def create_db_connection(password):
    """Create a SQLAlchemy engine and open a session bound to it.

    The connection URL is assembled from the module-level ``DATABASE_CLIENT``,
    ``USER``, ``HOST``, ``PORT`` and ``DATABASE`` settings plus ``password``.

    Args:
        password: Database password in plain (not URL-encoded) form.

    Returns:
        A ``(engine, session)`` tuple: the engine and a new session produced by
        a ``sessionmaker`` bound to that engine.
    """
    from urllib.parse import quote

    # A raw password containing "@", ":", "/" or "%" would be read as URL
    # syntax, so it is percent-encoded before being embedded. safe="" also
    # encodes "/" and spaces become "%20" rather than "+".
    escaped_password = quote(str(password), safe="")
    DATABASE_URL = f"{DATABASE_CLIENT}://{USER}:{escaped_password}@{HOST}:{PORT}/{DATABASE}"
    engine = create_engine(DATABASE_URL)
    Session = sessionmaker(bind=engine)
    return engine, Session()


In [3]:
# Open the database connection; `engine` is the one `execute_query` reads.
engine, session = create_db_connection(PASS)

In [5]:
def parse_nan_values(dataframe: pd.DataFrame) -> pd.DataFrame:
    """Return a new frame in which every missing value is an empty string."""
    blank = ""
    return dataframe.fillna(value=blank)


def parse_numeric_values(dataframe: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``dataframe`` with numeric columns rounded to 2 decimals.

    Columns are chosen with ``select_dtypes(include=["number"])``; all other
    columns are carried over unchanged.

    Args:
        dataframe: Frame to round. It is not modified.

    Returns:
        A new DataFrame with the rounded values.
    """
    # Work on a copy so the caller's frame is not silently rewritten.
    rounded = dataframe.copy()
    for column in rounded.select_dtypes(include=["number"]).columns:
        rounded[column] = rounded[column].round(2)
    return rounded


In [6]:
def execute_query(query: str) -> list:
    """Run ``query`` on the module-level ``engine`` and return its rows.

    Numeric columns are rounded to two decimals and missing values are
    replaced by empty strings before the frame is converted.

    Args:
        query: SQL text handed to ``pandas.read_sql_query``.

    Returns:
        A list with one ``{column: value}`` dict per result row
        (``DataFrame.to_dict(orient="records")``).

    Raises:
        Any exception raised by pandas or the database driver propagates
        unchanged.
    """
    df_result = pd.read_sql_query(query, engine)
    # Round BEFORE blanking NaN: filling a numeric column with "" turns it into
    # an object column, which the numeric-dtype selection then skips.
    df_result = parse_numeric_values(df_result)
    df_result = parse_nan_values(df_result)
    return df_result.to_dict(orient="records")


In [None]:
def dump_to_json(data: Dict) -> str:
    """Serialise ``data`` to a JSON string pretty-printed with a 4-space indent."""
    indent_width = 4
    serialised = json.dumps(data, indent=indent_width)
    return serialised

In [35]:
# Per-airline flight count and average distance for 2015, ordered by airline name.
res = execute_query("""SELECT a.airline, COUNT(*) AS total_flights, AVG(distance) AS avg_distance
FROM flights f
JOIN airlines a ON f.airline = a.iata_code
WHERE f.year = 2015
GROUP BY a.airline
ORDER BY a.airline;


""")

In [None]:
# Serialise the query rows to an indented JSON string and display it.
data = dump_to_json(res)
data

'[\n    {\n        "airline": "Alaska Airlines Inc.",\n        "total_flights": 167,\n        "avg_distance": 1258.01\n    },\n    {\n        "airline": "American Airlines Inc.",\n        "total_flights": 632,\n        "avg_distance": 1035.42\n    },\n    {\n        "airline": "American Eagle Airlines Inc.",\n        "total_flights": 260,\n        "avg_distance": 433.22\n    },\n    {\n        "airline": "Atlantic Southeast Airlines",\n        "total_flights": 512,\n        "avg_distance": 453.05\n    },\n    {\n        "airline": "Delta Air Lines Inc.",\n        "total_flights": 747,\n        "avg_distance": 852.72\n    },\n    {\n        "airline": "Frontier Airlines Inc.",\n        "total_flights": 66,\n        "avg_distance": 979.48\n    },\n    {\n        "airline": "Hawaiian Airlines Inc.",\n        "total_flights": 64,\n        "avg_distance": 664.92\n    },\n    {\n        "airline": "JetBlue Airways",\n        "total_flights": 235,\n        "avg_distance": 1041.18\n    },\n   

In [4]:
# Load the evaluation set (path is relative to the notebook's working directory).
eval_df = pd.read_csv("../data/eval.csv")

In [5]:
eval_df.head()

Unnamed: 0,sql_query,result,description
0,"SELECT a.airline, COUNT(*) AS total_flights\nF...","'[\n {\n ""airline"": ""Southwest Airli...",Which airline operated the most flights in 2015?
1,SELECT COUNT(*) AS total_flights\nFROM flights...,"'[\n {\n ""total_flights"": 191\n }\n]'",How many flights departed from Dallas/Fort Wor...
2,SELECT *\nFROM airports\nWHERE state = 'CA';,"'[\n {\n ""iata_code"": ""ACV"",\n ""airport"": ""Arc...",List all airports located in California.
3,SELECT AVG(departure_delay) AS avg_departure_d...,"'[\n {\n ""avg_departure_delay"": 10.61\n }\n]'",What is the average departure delay for flight...
4,SELECT COUNT(*) AS cancelled_flights\nFROM fli...,"'[\n {\n ""cancelled_flights"": 63\n }\n]'",How many flights were cancelled in 2015?


In [None]:
def parse_to_json(data: str) -> dict:
    """Parse a stored result string and return its first JSON element.

    Literal backslash-n sequences are removed, the text is wrapped in ``[...]``
    so a comma-separated run of objects parses as a list, and the first
    element of that list is returned.

    Args:
        data: Result text, e.g. ``'{ "total_flights": 191 }'``.

    Returns:
        The first parsed element (a dict for the object-shaped inputs above).

    Raises:
        json.JSONDecodeError: If the text is not parseable even after the
            single-quote fallback.
        IndexError: If the text contains no elements.
    """
    # '\\n' is the two-character sequence backslash + "n", not a real newline.
    flattened = data.replace('\\n', '')
    wrapped = f"[{flattened}]"
    try:
        # Try the text as-is first so apostrophes inside values
        # (e.g. "Chicago O'Hare") are not rewritten into stray double quotes.
        records = json.loads(wrapped)
    except json.JSONDecodeError:
        # Fallback for single-quoted payloads such as {'a': 1}.
        records = json.loads(wrapped.replace("'", "\""))
    return records[0]

# Parse each stored result string; the parsed objects go into a new column.
eval_df['parsed_result'] = eval_df['result'].apply(parse_to_json)


                                           sql_query  \
0  SELECT a.airline, COUNT(*) AS total_flights\nF...   
1  SELECT COUNT(*) AS total_flights\nFROM flights...   
2       SELECT *\nFROM airports\nWHERE state = 'CA';   
3  SELECT AVG(departure_delay) AS avg_departure_d...   
4  SELECT COUNT(*) AS cancelled_flights\nFROM fli...   

                                              result  \
0      {        "airline": "Southwest Airlines Co...   
1                           { "total_flights": 191 }   
2   { "iata_code": "ACV", "airport": "Arcata Airp...   
3                   { "avg_departure_delay": 10.61 }   
4                        { "cancelled_flights": 63 }   

                                         description  \
0   Which airline operated the most flights in 2015?   
1  How many flights departed from Dallas/Fort Wor...   
2           List all airports located in California.   
3  What is the average departure delay for flight...   
4           How many flights were cancelled in

In [24]:
eval_df.parsed_result[0]

{'airline': 'Southwest Airlines Co.', 'total_flights': 1026}

In [25]:
# Expected results as a plain Python list, one entry per evaluation row.
listed_eval_sets = eval_df.parsed_result.to_list()

In [27]:
listed_eval_sets

[{'airline': 'Southwest Airlines Co.', 'total_flights': 1026},
 {'total_flights': 191},
 {'iata_code': 'ACV',
  'airport': 'Arcata Airport',
  'city': 'Arcata/Eureka',
  'state': 'CA',
  'country': 'USA',
  'latitude': 40.98,
  'longitude': -124.11},
 {'avg_departure_delay': 10.61},
 {'cancelled_flights': 63},
 {'airline': 'Hawaiian Airlines Inc.', 'on_time_percentage': 90.63},
 {'destination_airport': 'SFO', 'flight_count': 10},
 {'airline': 'Frontier Airlines Inc.', 'avg_arrival_delay': 15.56},
 {'diverted_flights': 8},
 {'origin_airport': 'ATL', 'departure_count': 280},
 {'airline': 'Atlantic Southeast Airlines', 'cancellations': 18},
 {'airline': 'Alaska Airlines Inc.',
  'total_flights': 167,
  'avg_distance': 1258.01}]

In [28]:
# Placeholder only: swap in a function that really sends the query to the LLM
# and returns its answer.
def send_query_to_llm(query: str):
    """Dummy stand-in for an LLM call; echoes ``query`` behind a fixed prefix."""
    prefix = "LLM Response for: "
    return prefix + query

# Send every description to the (placeholder) LLM call, one response per row.
# To send the SQL instead, iterate over eval_df['sql_query'].
llm_responses = [
    send_query_to_llm(description) for description in eval_df['description']
]

eval_df['llm_response'] = llm_responses

In [83]:

eval_df.parsed_result

0     {'airline': 'Southwest Airlines Co.', 'total_f...
1                                {'total_flights': 191}
2     {'iata_code': 'ACV', 'airport': 'Arcata Airpor...
3                        {'avg_departure_delay': 10.61}
4                             {'cancelled_flights': 63}
5     {'airline': 'Hawaiian Airlines Inc.', 'on_time...
6     {'destination_airport': 'SFO', 'flight_count':...
7     {'airline': 'Frontier Airlines Inc.', 'avg_arr...
8                               {'diverted_flights': 8}
9     {'origin_airport': 'ATL', 'departure_count': 280}
10    {'airline': 'Atlantic Southeast Airlines', 'ca...
11    {'airline': 'Alaska Airlines Inc.', 'total_fli...
Name: parsed_result, dtype: object

In [44]:
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..")))

from src.utils.database_utils import DatabaseConnector
from src.utils.common_utils import  clean_generation_result
from src.core.llm import GenerativeModelWrapper

In [45]:
# Project LLM wrapper (from src.core.llm); its generate_sql coroutine is awaited below.
llm = GenerativeModelWrapper()


In [46]:
# Natural-language questions from the evaluation set, in row order.
description_set = eval_df.description.to_list()

In [47]:
description_set

['Which airline operated the most flights in 2015?',
 'How many flights departed from Dallas/Fort Worth International Airport (DFW) in 2015?',
 'List all airports located in California.',
 'What is the average departure delay for flights departing from Los Angeles International Airport (LAX) in 2015?',
 'How many flights were cancelled in 2015?',
 'Which airline had the highest on-time departure rate in 2015 (considering flights with a departure delay of 15 minutes or less as on time)?',
 'Which destination airport received the most flights from John F. Kennedy International Airport (JFK) in 2015?',
 'List the top 5 airlines with the highest average arrival delay in 2015.',
 'How many flights were diverted in 2015?',
 'Which airport had the most departures in 2015?',
 'Which airline had the most cancellations in 2015?',
 'For each airline, what was the total number of flights and the average distance flown in 2015?']

In [63]:
generated_llm_response = {}
for set in description_set:
    res = await llm.generate_sql(set)
    cleaned_response = clean_generation_result(res)
    parsed = json.loads(cleaned_response)
    generated_llm_response[set] = parsed
    
    

In [50]:
generated_llm_response

{'Which airline operated the most flights in 2015?': {'sql': 'SELECT a.airline, COUNT(f.flight_number) AS num_flights FROM flights AS f JOIN airlines AS a ON f.airline = a.iata_code WHERE f.year = 2015 GROUP BY a.airline ORDER BY num_flights DESC LIMIT 1'},
 'How many flights departed from Dallas/Fort Worth International Airport (DFW) in 2015?': {'sql': "SELECT count(*) AS num_flights FROM flights AS f JOIN airports AS a ON lower(f.origin_airport) = lower(a.iata_code) WHERE lower(a.airport) = lower('Dallas/Fort Worth International Airport') AND f.year = 2015"},
 'List all airports located in California.': {'sql': "SELECT iata_code, airport FROM airports WHERE lower(state) = lower('California')"},
 'What is the average departure delay for flights departing from Los Angeles International Airport (LAX) in 2015?': {'sql': "SELECT AVG(flights.departure_delay) AS average_departure_delay FROM flights JOIN airports ON flights.origin_airport = airports.iata_code WHERE lower(airports.airport) = 

In [64]:
# Project database helper (from src.utils.database_utils) used to run the generated SQL.
db = DatabaseConnector()

In [65]:
# Sanity check on the first generated query: size of the first element of its result
# (assumes execute_query returns a sequence of row mappings -- TODO confirm).
len(db.execute_query(list(generated_llm_response.values())[0].get('sql'))[0])


2

In [66]:
# Run each generated SQL statement and keep the generation next to what it returned.
eval_generated_llm_respone = {}

for question, generation in generated_llm_response.items():
    query_rows = db.execute_query(generation.get('sql'))
    eval_generated_llm_respone[question] = (generation, query_rows)


In [67]:
eval_generated_llm_respone

{'Which airline operated the most flights in 2015?': ({'sql': 'SELECT a.airline, count(*) AS total_flights FROM flights AS f JOIN airlines AS a ON f.airline = a.iata_code WHERE f.year = 2015 GROUP BY a.airline ORDER BY total_flights DESC LIMIT 1'},
  [{'airline': 'Southwest Airlines Co.', 'total_flights': 1026}]),
 'How many flights departed from Dallas/Fort Worth International Airport (DFW) in 2015?': ({'sql': "SELECT count(*) AS number_of_flights FROM flights AS f JOIN airports AS a ON lower(f.origin_airport) = lower(a.iata_code) WHERE lower(a.airport) = lower('Dallas/Fort Worth International Airport') AND f.year = 2015"},
  [{'number_of_flights': 191}]),
 'List all airports located in California.': ({'sql': "SELECT airport, city FROM airports WHERE lower(state) = lower('California')"},
  []),
 'What is the average departure delay for flights departing from Los Angeles International Airport (LAX) in 2015?': ({'sql': "SELECT AVG(flights.departure_delay) AS average_departure_delay FROM

In [79]:
# One correctness flag per evaluation question.
# NOTE(review): presumably labelled by hand, in description_set order, by comparing
# the generated query results with the expected ones -- confirm.
accuracy = [
    True, True, False, True, True, True,
    False, True, True, True, True, True,
]

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Predictions are the correctness flags; the reference labels are all True.
y_pred = accuracy
y_true = [True for _ in y_pred]

# NOTE(review): with an all-True y_true there are no negatives, hence no false
# positives -- precision is 1.0 whenever any prediction is True and recall
# equals accuracy, so only the accuracy figure carries information here.
accuracy_val = accuracy_score(y_true, y_pred)
precision_val = precision_score(y_true, y_pred)
recall_val = recall_score(y_true, y_pred)
f1_val = f1_score(y_true, y_pred)

# Report the four scores, one per line.
print(f"Accuracy: {accuracy_val}")
print(f"Precision: {precision_val}")
print(f"Recall: {recall_val}")
print(f"F1 Score: {f1_val}")

Accuracy: 0.8333333333333334
Precision: 1.0
Recall: 0.8333333333333334
F1 Score: 0.9090909090909091


  from ..externals._packaging.version import parse as parse_version
