In [1]:
import os
import pandas as pd
import sqlite3
import google.generativeai as genai

In [2]:
# Configure the Gemini client.
# The key is taken from the GOOGLE_API_KEY environment variable when it is set,
# so a real credential never has to be saved inside the notebook. The original
# placeholder string is kept as the fallback, which means that without the
# environment variable this cell behaves exactly as before.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", "Insert Key here"))
# Model identifier reused by the query-generation cells below.
model_name = "gemini-1.5-flash-latest"
model = genai.GenerativeModel(model_name)

In [3]:
def read_csvs_from_directory(directory_path):
    """Load every ``.csv`` file found directly inside a directory.

    The scan is not recursive. Only regular files whose name ends with
    ``.csv`` are read; everything else is ignored.

    Args:
        directory_path: Path of the directory to scan.

    Returns:
        dict: Maps each file's base name (file name without its extension)
        to the ``pandas.DataFrame`` read from it. Entries are inserted in
        alphabetical file-name order. Empty when no CSV file is present.

    Raises:
        OSError: If the directory cannot be listed (propagated from
            ``os.listdir``). Errors raised by ``pd.read_csv`` also propagate.
    """
    dataframes = {}
    # os.listdir returns entries in arbitrary order; sorting keeps the dict
    # order stable across runs and machines for anything that iterates it.
    for filename in sorted(os.listdir(directory_path)):
        if not filename.endswith(".csv"):
            continue
        path = os.path.join(directory_path, filename)
        # A directory that happens to be named "*.csv" cannot be parsed.
        if not os.path.isfile(path):
            continue
        df_name = os.path.splitext(filename)[0]
        dataframes[df_name] = pd.read_csv(path)
    return dataframes

In [4]:
def create_sqlite_db_from_dfs(dataframes):
    """Materialise a set of DataFrames as tables of an in-memory SQLite database.

    Args:
        dataframes: Mapping of table name to ``pandas.DataFrame``.

    Returns:
        sqlite3.Connection: Open connection to a fresh ``:memory:`` database
        holding one table per mapping entry. The DataFrame index is not
        written. The caller owns the connection.
    """
    connection = sqlite3.connect(":memory:")
    for name in dataframes:
        frame = dataframes[name]
        # 'replace' drops a same-named table before writing this one.
        frame.to_sql(name, connection, if_exists="replace", index=False)
    return connection

In [5]:
def get_metadata_from_dfs(dataframes):
    """Render a plain-text schema summary for a set of DataFrames.

    Each table produces a section of the form::

        Table: <name>
        Columns:
        - <column> (<dtype>)

    Args:
        dataframes: Mapping of table name to ``pandas.DataFrame``.

    Returns:
        str: The per-table sections, in mapping order, separated by one
        blank line.
    """
    sections = []
    for table_name, df in dataframes.items():
        header = f"Table: {table_name}\nColumns:\n"
        column_lines = "\n".join(
            f"- {col} ({dtype!s})" for col, dtype in df.dtypes.items()
        )
        sections.append(header + column_lines)
    return "\n\n".join(sections)

In [6]:
def _strip_code_fence(text):
    """Return ``text`` without a surrounding Markdown code fence, if it has one."""
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
        first_line, newline, rest = cleaned.partition("\n")
        tag = first_line.strip()
        # Drop the opening fence line when it is empty or just a language tag
        # such as "sql" / "sqlite"; otherwise the SQL starts on that line.
        if newline and (not tag or (tag.isalnum() and "sql" in tag.lower())):
            cleaned = rest
        cleaned = cleaned.rstrip()
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
    return cleaned.strip()


def generate_sql_query(question, metadata, model_name, dialect="SQLite"):
    """Ask a Gemini model to translate a natural-language question into SQL.

    Args:
        question: The question to answer, in plain language.
        metadata: Text description of the available tables and columns; it is
            embedded verbatim in the prompt.
        model_name: Name passed to ``genai.GenerativeModel``.
        dialect: SQL dialect named in the prompt. Defaults to ``"SQLite"``,
            the engine ``sqlite3`` connections execute against.

    Returns:
        str: The model's reply with surrounding whitespace and an enclosing
        Markdown code fence (with or without a language tag) removed.
    """
    model = genai.GenerativeModel(model_name)
    # Adjacent string literals are concatenated with nothing in between, so
    # every instruction carries its own terminator to stay a separate sentence.
    prompt = (
        "You are an expert data analyst. Based on the following table schemas:\n\n"
        f"{metadata}\n\n"
        f"Write a syntactically correct SQL query ({dialect}) for this question:\n"
        f"{question}\n\n"
        "Give the query only, nothing else.\n"
        "When calculating ratios in SQL, multiply the numerator by 1.0 or use "
        "CAST(... AS FLOAT) to force float division and avoid integer division.\n"
        "When giving the final result, always order it in a way you see fit."
    )
    response = model.generate_content(prompt)
    # str.strip(chars) removes any of the given characters, not a prefix, and
    # would eat leading/trailing 's', 'q' or 'l' of the query itself; remove
    # the fence structurally instead.
    return _strip_code_fence(response.text)

In [7]:
def run_sql_query(query, conn):
    """Execute a SQL query and return its rows as a DataFrame.

    Args:
        query: SQL text to execute.
        conn: Database connection handed to ``pd.read_sql_query``.

    Returns:
        pandas.DataFrame on success. If anything raises, the exception is not
        propagated; a ``str`` starting with ``"Error running SQL query: "``
        followed by the exception message is returned instead, so callers
        must check the type of the result.
    """
    try:
        frame = pd.read_sql_query(query, conn)
    except Exception as exc:
        # Best-effort contract: report the failure as text rather than raising.
        return f"Error running SQL query: {exc}"
    return frame

In [8]:
# Folder scanned for .csv files; each file found there becomes one table.
directory_path = "./Datasets/ben10"
# Natural-language question that is later sent to the model.
question = "For the top 5 aliens that have fought most battles, what are their battle count, wins, win to battle ratio, strength and speed?"

# Load the CSVs, mirror them into an in-memory SQLite database, and build the
# plain-text schema description of the same tables.
dataframes = read_csvs_from_directory(directory_path)
conn = create_sqlite_db_from_dfs(dataframes)
metadata = get_metadata_from_dfs(dataframes)

# Show the schema text so it can be checked by eye.
print(metadata)

Table: ben10_aliens
Columns:
- alien_id (int64)
- alien_name (object)
- species (object)
- home_planet (object)
- strength_level (int64)
- speed_level (int64)
- intelligence (int64)

Table: ben10_enemies
Columns:
- enemy_id (int64)
- alien_name (object)
- enemy_name (object)

Table: ben10_battles
Columns:
- battle_id (int64)
- alien_name (object)
- enemy_name (object)
- battle_date (object)
- winner (object)


In [9]:
# Ask the model to turn the question into SQL, using the schema text as context.
sql_query = generate_sql_query(question, metadata, model_name)
# Print the generated statement so it can be reviewed before it is executed.
print("Generated SQL Query:\n", sql_query)

Generated SQL Query:
 SELECT
  a.alien_name,
  COUNT(b.battle_id) AS battle_count,
  SUM(CASE WHEN b.winner = a.alien_name THEN 1 ELSE 0 END) AS wins,
  (
    SUM(CASE WHEN b.winner = a.alien_name THEN 1 ELSE 0 END) * 1.0 / COUNT(b.battle_id)
  ) AS win_ratio,
  a.strength_level,
  a.speed_level
FROM ben10_aliens AS a
JOIN ben10_battles AS b
  ON a.alien_name = b.alien_name
GROUP BY
  a.alien_name
ORDER BY
  battle_count DESC
LIMIT 5;


In [10]:
# Execute the generated SQL against the in-memory database.
# run_sql_query returns a DataFrame on success or an error-message string on
# failure; print() displays either one.
result = run_sql_query(sql_query, conn)
print("Query Result:\n", result)

Query Result:
     alien_name  battle_count  wins  win_ratio  strength_level  speed_level
0    Big Chill             9     4   0.444444               9            5
1    Terraspin             8     4   0.500000              10            6
2      Upchuck             7     4   0.571429               7           10
3   Glitch Ben             7     3   0.428571               8            4
4  Astrodactyl             7     3   0.428571              10            7
