# Install dependencies

In [1]:
%pip install transformers torch "sqlglot[rs]" duckdb google-generativeai

Note: you may need to restart the kernel to use updated packages.


# Import Libraries

In [2]:
import json
import os
import random
import re
import sqlite3
import time
from pathlib import Path

import google.generativeai as genai
import pandas as pd
import sqlglot

In [3]:
# Configure Gemini API Key
# Read the key from the GOOGLE_API_KEY environment variable so the real secret
# never has to be saved inside the notebook; the original 'API_KEY' placeholder
# is kept as the fallback so the cell behaves as before when the variable is unset.
genai.configure(api_key=os.environ.get('GOOGLE_API_KEY', 'API_KEY'))

# Load and parse the train.json file
# The encoding is explicit because open() otherwise uses the platform default
# (not UTF-8 on many Windows setups), which mis-decodes non-ASCII questions.
with open('D:/Downloads/Text-to-SQL/train/train.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Extract schema from sqlite

In [4]:
# extract schema from the .sqlite file
def connect_to_database(db_path):
    """Open a SQLite connection for ``db_path``.

    Returns the ``sqlite3.Connection`` on success. If ``sqlite3`` raises an
    error while connecting, the error is printed and ``None`` is returned.
    """
    connection = None
    try:
        connection = sqlite3.connect(db_path)
    except sqlite3.Error as exc:
        print(f"Error connecting to database: {exc}")
    return connection

def get_table_names(conn):
    """List the names of every table recorded in ``sqlite_master``.

    Returns a list of name strings. If ``sqlite3`` raises an error, the error
    is printed and an empty list is returned.
    """
    query = "SELECT name FROM sqlite_master WHERE type='table';"
    try:
        cursor = conn.cursor()
        cursor.execute(query)
        names = []
        # Each fetched row is a 1-tuple holding the table name.
        for row in cursor.fetchall():
            names.append(row[0])
        return names
    except sqlite3.Error as exc:
        print(f"Error retrieving table names: {exc}")
        return []

def get_column_names(conn, table_name):
    """Retrieve the column names for a given table.

    Args:
        conn: Open ``sqlite3`` connection.
        table_name: Name of the table to inspect.

    Returns:
        List of column names in ``PRAGMA table_info`` order. An unknown table
        yields an empty list; if ``sqlite3`` raises an error, the error is
        printed and an empty list is returned.
    """
    try:
        cur = conn.cursor()
        # PRAGMA arguments cannot be bound as parameters, so the name is quoted
        # as an identifier. Embedded double quotes are doubled; otherwise a
        # table whose name contains '"' would produce a syntax error and be
        # reported as having no columns.
        quoted_name = '"' + str(table_name).replace('"', '""') + '"'
        cur.execute(f"PRAGMA table_info({quoted_name});")
        schema = cur.fetchall()
        return [column[1] for column in schema]  # index 1 of each row is the column name
    except sqlite3.Error as e:
        print(f"Error retrieving columns for table {table_name}: {e}")
        return []

def extract_schema(conn):
    """Build a ``{table_name: [column, ...]}`` mapping for the whole database.

    Table names come from ``get_table_names`` and each table's columns from
    ``get_column_names``.
    """
    return {
        name: get_column_names(conn, name)
        for name in get_table_names(conn)
    }


## Validate the SQL syntax (sqlglot)

In [5]:
def validate_sql(sql_query):
    """Check whether ``sql_query`` parses with sqlglot.

    Args:
        sql_query: SQL text to check.

    Returns:
        ``True`` when ``sqlglot.parse_one`` accepts the text. ``False`` when
        sqlglot reports a tokenizing or parsing error, in which case the error
        is printed.
    """
    try:
        sqlglot.parse_one(sql_query)
        return True
    # TokenError (e.g. an unterminated string literal) is raised by the
    # tokenizer and is not a ParseError, so it must be caught explicitly or a
    # malformed model reply aborts the whole run.
    except (sqlglot.errors.ParseError, sqlglot.errors.TokenError) as e:
        print(f"SQL syntax error: {e}")
        return False

## Extract data from SQL query

In [6]:
def execute_sql_query(sql_query, db_path, max_rows=5):
    """Run one SQL statement against a SQLite file and return its first rows.

    Args:
        sql_query: SQL text to execute. It may be model-generated, so it is
            treated as untrusted.
        db_path: Path passed to ``sqlite3.connect``.
        max_rows: Maximum number of rows to fetch (default 5, the previously
            hard-coded value).

    Returns:
        A list of up to ``max_rows`` result tuples, or ``None`` if connecting
        or executing failed (the error is printed).
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Untrusted SQL must not modify the dataset: with query_only enabled,
        # statements that would change the database file fail instead.
        conn.execute("PRAGMA query_only = ON")
        cursor = conn.cursor()
        cursor.execute(sql_query)
        return cursor.fetchmany(max_rows)
    # On Python <= 3.11 a multi-statement string raises sqlite3.Warning, which
    # is not a subclass of sqlite3.Error, so it is caught here as well.
    except (sqlite3.Error, sqlite3.Warning) as e:
        print(f"Database error: {e}")
        return None
    finally:
        if conn is not None:
            conn.close()


# Prompt Template

In [7]:
# prompt_template = """
# "Database schema will be given in dict format where key is the table name and value is the list of columns in the table
# {schema}
# Example: dict 'tablename': ['col1', 'col2', 'col3']

# Here are a few examples, "Q" represents the question and "A" represents the corresponding SQL-query :
# Q: List out the account numbers of female clients who are oldest and has lowest average salary, calculate the gap between this lowest average salary with the highest average salary?
# A: SELECT T1.account_id , ( SELECT MAX(A11) - MIN(A11) FROM district ) FROM account AS T1 INNER JOIN district AS T2 ON T1.district_id = T2.district_id WHERE T2.district_id = ( SELECT district_id FROM client WHERE gender = 'F' ORDER BY birth_date ASC LIMIT 1 ) ORDER BY T2.A11 DESC LIMIT 1

# Q: For the branch which located in the south Bohemia with biggest number of inhabitants, what is the percentage of the male clients?
# A: SELECT CAST(SUM(T1.gender = 'M') AS REAL) * 100 / COUNT(T1.client_id) FROM client AS T1 INNER JOIN district AS T2 ON T1.district_id = T2.district_id WHERE T2.A3 = 'south Bohemia' GROUP BY T2.A4 ORDER BY T2.A4 DESC LIMIT 1

# Q: "For the client who first applied the loan in 1993/7/5, what is the increase rate of his/her account balance from 1993/3/22 to 1998/12/27?
# A: SELECT CAST((SUM(IIF(T3.date = '1998-12-27', T3.balance, 0)) - SUM(IIF(T3.date = '1993-03-22', T3.balance, 0))) AS REAL) * 100 / SUM(IIF(T3.date = '1993-03-22', T3.balance, 0)) FROM loan AS T1 INNER JOIN account AS T2 ON T1.account_id = T2.account_id INNER JOIN trans AS T3 ON T3.account_id = T2.account_id WHERE T1.date = '1993-07-05'

# Using valid SQL, answer the following question based on the tables provided above.
# It is important to use qualified column names in the SQL-query, meaning the form "SELECT table_name.column_name FROM table_name;


# Hint helps you to write the correct SQL query.
# Question: {question}
# Hint: {evidence}
# DO NOT return anything else except the SQL query. Make sure the SQL query is compatible with SQLite syntax. No need to mentio ```sql"
# """

# def get_gemini_response(question, evidence, schema):
#     """Generate SQL query using Gemini Pro."""
   
#     prompt = prompt_template.format(schema=schema, question=question, evidence=evidence)
#     model = genai.GenerativeModel('gemini-pro')
#     response = model.generate_content([prompt])
#     return response.text.strip()

In [8]:
# NOTE: the triple-quoted string directly below is a bare expression statement.
# It is not assigned to a name and is not part of `prompt_template`, so these
# instructions are never sent to the model.
""" Pay close attention on which column is in which table. if context contains more than one tables then create a query by performing JOIN operation only using the column unitid for the tables.\
Follow these Instructions for creating syntactically correct SQL query:\
- Be sure not to query for columns that do not exist in the tables and use alias only where required.\
- Always use the column 'instnm' associated with the 'unitid' in the generated query.\
- Whenever asked for Institute Names, return the institute names using column 'instnm' associated with the 'unitid' in the generated query.\
- Likewise, when asked about the average (AVG function) or ratio, ensure the appropriate aggregation function is used.\
- Pay close attention to the filtering criteria mentioned in the question and incorporate them using the WHERE clause in your SQL query.\
- If the question involves multiple conditions, use logical operators such as AND, OR to combine them effectively.\
- When dealing with date or timestamp columns, use appropriate date functions (e.g., DATE_PART, EXTRACT) for extracting specific parts of the date or performing date arithmetic.\
- If the question involves grouping of data (e.g., finding totals or averages for different categories), use the GROUP BY clause along with appropriate aggregate functions.\
- Consider using aliases for tables and columns to improve readability of the query, especially in case of complex joins or subqueries.\
- If necessary, use subqueries or common table expressions (CTEs) to break down the problem into smaller, more manageable parts.
""" 

# Few-shot prompt for the first (generation) pass. `get_gemini_response` fills
# the {schema}, {question} and {evidence} placeholders with str.format, so any
# literal brace added to this text must be doubled ("{{" / "}}").
prompt_template = """
"Database schema will be given in dict format where key is the table name and value is the list of columns in the table
{schema}
Example: dict 'tablename': ['col1', 'col2', 'col3']

Here are a few examples, you can use them to generate the SQL query:

"question": "Rank schools by their average score in Writing where the score is greater than 499, showing their charter numbers.",
"evidence": "Valid charter number means the number is not null",
"SQL": "SELECT CharterNum, AvgScrWrite, RANK() OVER (ORDER BY AvgScrWrite DESC) AS WritingScoreRank FROM schools AS T1  INNER JOIN satscores AS T2 ON T1.CDSCode = T2.cds WHERE T2.AvgScrWrite > 499 AND CharterNum is not null",
   
"question": "Among the weekly issuance accounts, how many have a loan of under 200000?",
"evidence": "frequency = 'POPLATEK TYDNE' stands for weekly issuance",
"SQL": "SELECT COUNT(T1.account_id) FROM loan AS T1 INNER JOIN account AS T2 ON T1.account_id = T2.account_id WHERE T2.frequency = 'POPLATEK TYDNE' AND T1.amount < 200000",

Evidence helps you to write the correct SQL query.
Question: {question}
Evidence: {evidence}
DO NOT return anything else except the SQL query. Make sure the SQL query is compatible with SQLite syntax. No need to mentio ```sql"
"""

def get_gemini_response(question, evidence, schema, model_name='gemini-pro'):
    """Ask a Gemini model to write a SQL query for ``question``.

    Args:
        question: Natural-language question, substituted for ``{question}``.
        evidence: Hint text, substituted for ``{evidence}``.
        schema: Schema description, substituted for ``{schema}`` via
            ``str.format`` (a dict is rendered with its ``str()`` form).
        model_name: Gemini model identifier. Defaults to ``'gemini-pro'``, the
            value that used to be hard-coded, so existing callers are unaffected.

    Returns:
        The text of the model's reply with surrounding whitespace stripped.
    """
    prompt = prompt_template.format(schema=schema, question=question, evidence=evidence)
    model = genai.GenerativeModel(model_name)
    response = model.generate_content([prompt])
    return response.text.strip()


In [9]:
# Prompt for the second (review) pass. `get_gemini_response2` fills the
# {schema}, {question}, {evidence} and {query} placeholders with str.format, so
# any literal brace added to this text must be doubled ("{{" / "}}").
prompt_template2 = """
Given the question, evidence, schema, and SQL query, you need to verify and correct the SQL query if necessary. 

1. Check if the SQL query correctly matches the question based on the provided schema.
2. Ensure that all table names and column names in the query are correct according to the schema.
3. If the query is correct, return it as is. If not, modify the query to correct any errors.

Do not include any additional information or explanations—return only the corrected SQL query.

The database schema is provided in dict format, where each key is a table name and each value is a list of columns for that table:
{schema}
Example format: dict 'table_name': ['column1', 'column2', 'column3']

Question: {question}
Hint: {evidence}
Query: {query}

Only return the SQL query—nothing else. Make sure the SQL query is compatible with SQLite syntax.
"""

def get_gemini_response2(question, evidence, schema, query, model_name='gemini-pro'):
    """Ask a Gemini model to verify, and if needed correct, an existing SQL query.

    Args:
        question: Natural-language question, substituted for ``{question}``.
        evidence: Hint text, substituted for ``{evidence}``.
        schema: Schema description, substituted for ``{schema}`` via
            ``str.format`` (a dict is rendered with its ``str()`` form).
        query: Candidate SQL to review, substituted for ``{query}``.
        model_name: Gemini model identifier. Defaults to ``'gemini-pro'``, the
            value that used to be hard-coded, so existing callers are unaffected.

    Returns:
        The text of the model's reply with surrounding whitespace stripped.
    """
    prompt = prompt_template2.format(schema=schema, question=question, evidence=evidence, query=query)
    model = genai.GenerativeModel(model_name)
    response = model.generate_content([prompt])
    return response.text.strip()

**Description**:

1. The provided code is designed to process multiple instances of data, where each instance involves generating and refining SQL queries based on a given question, evidence, and database schema. 
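2. Each refined query is syntax-checked with sqlglot; it and the reference (ground-truth) SQL query are then both executed against the SQLite database, and the two queries, their first result rows, and the validation flag are stored in a CSV file.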
3. The process introduces a delay between each instance to control the processing rate.

In [17]:
def _strip_code_fence(text):
    """Return the SQL inside a Markdown code fence, or ``text`` itself if unfenced.

    Handles an opening fence with or without a ``sql``/``sqlite`` language tag
    and a missing closing fence. Only the fence and its tag are removed;
    occurrences of "sql" inside the query (``sqlite_master``, ``mysql_id``, a
    column named ``sql``) are left intact.
    """
    fenced = re.search(
        r"```(?:(?:sqlite|sql)\b)?\s*(.*?)\s*(?:```|$)",
        text,
        flags=re.DOTALL | re.IGNORECASE,
    )
    if fenced:
        return fenced.group(1).strip()
    return text.strip()


def generate_sql_query(question, evidence, schema_dict):
    """Generate a SQL query from the question, evidence, and schema.

    Calls ``get_gemini_response`` and removes any Markdown code fence from the
    reply. Returns the cleaned SQL text.
    """
    generated_sql = get_gemini_response(question=question, evidence=evidence, schema=schema_dict)
    # The previous chained .replace('sql', "") deleted every "sql" substring,
    # corrupting identifiers; strip only the fence instead.
    return _strip_code_fence(generated_sql)


def refine_sql_query(question, evidence, schema_dict, initial_query):
    """Have the model review ``initial_query`` and return the cleaned result.

    Calls ``get_gemini_response2`` and removes any Markdown code fence from the
    reply. Returns the cleaned SQL text.
    """
    new_query = get_gemini_response2(question=question, evidence=evidence, schema=schema_dict, query=initial_query)
    return _strip_code_fence(new_query)

def process_instance(instance, db_base_path):
    """Process one instance: generate SQL, refine it, validate it, and run it.

    Args:
        instance: Mapping with the keys ``'db_id'``, ``'question'``,
            ``'evidence'`` and ``'SQL'`` (the reference query).
        db_base_path: Directory holding one ``<db_id>/<db_id>.sqlite`` file per
            database.

    Returns:
        Dict with the keys ``gen_query``, ``actual_query``, ``gen_output``,
        ``actual_output`` and ``validation`` (``"yes"``/``"no"``). If the
        database cannot be opened, the query and output fields other than
        ``actual_query`` are ``None`` and ``validation`` is ``"no"``.
    """
    db_id = instance['db_id']
    question = instance['question']
    evidence = instance['evidence']
    SQL_query = instance["SQL"]

    sql_schema_path = f'{db_base_path}/{db_id}/{db_id}.sqlite'

    conn = connect_to_database(sql_schema_path)
    if conn is None:
        # Without a connection there is no schema to prompt with. Previously
        # this path fell through and crashed with UnboundLocalError on
        # `schema_dict`; report a failed row instead so a batch run can go on.
        return {
            "gen_query": None,
            "actual_query": SQL_query,
            "gen_output": None,
            "actual_output": None,
            "validation": "no"
        }

    try:
        schema_dict = extract_schema(conn)
    finally:
        # The connection is only needed for the schema; close it so file
        # handles do not accumulate across instances.
        conn.close()

    # Generate and refine the SQL query
    generated_sql = generate_sql_query(question, evidence, schema_dict)
    refined_sql = refine_sql_query(question, evidence, schema_dict, generated_sql)

    # Validate the refined SQL
    validation_successful = validate_sql(refined_sql)

    # Execute the actual and generated SQL queries
    results_actual = execute_sql_query(SQL_query, sql_schema_path)
    results_gen = execute_sql_query(refined_sql, sql_schema_path)

    # Return the results as a dictionary
    return {
        "gen_query": refined_sql,
        "actual_query": SQL_query,
        "gen_output": results_gen,
        "actual_output": results_actual,
        "validation": "yes" if validation_successful else "no"
    }

def process_multiple_instances(data, n=10, db_base_path='D:/Downloads/Text-to-SQL/train/train_databases', delay=4,
                               output_path='submit4.csv', seed=None):
    """Process a random sample of instances and save the results as a CSV file.

    Args:
        data: Sequence of instance dicts to sample from.
        n: Number of instances to sample; capped at ``len(data)``.
        db_base_path: Directory containing the per-database folders.
        delay: Seconds to sleep between consecutive instances.
        output_path: CSV file to write (default ``'submit4.csv'``, the
            previously hard-coded name).
        seed: Optional seed for a reproducible sample. When ``None`` the
            module-level ``random`` state is used, as before.

    Results collected so far are written even if an instance raises part-way
    through; the exception still propagates.
    """
    # random.sample raises ValueError when asked for more items than exist.
    sample_size = min(n, len(data))
    rng = random if seed is None else random.Random(seed)
    selected_instances = rng.sample(data, sample_size)
    results_list = []

    finished = False
    try:
        for idx, instance in enumerate(selected_instances):
            print(f"Processing Instance {idx + 1}")
            result = process_instance(instance, db_base_path)
            results_list.append(result)
            print('-' * 50)

            # Pause between instances only; nothing follows the last one.
            if idx < len(selected_instances) - 1:
                time.sleep(delay)
        finished = True
    finally:
        # Save partial results on failure so completed work is not lost, but do
        # not overwrite an existing file with an empty one if the very first
        # instance failed.
        if finished or results_list:
            df = pd.DataFrame(results_list)
            df.to_csv(output_path, index=False)


In [18]:
# Run the generate -> refine -> validate -> execute pipeline on a random sample
# of `data`, using the function's default arguments.
process_multiple_instances(data)

Processing Instance 1
Database error: near "transaction": syntax error
--------------------------------------------------
Processing Instance 2
--------------------------------------------------
Processing Instance 3
--------------------------------------------------
Processing Instance 4
--------------------------------------------------
Processing Instance 5
Database error: aggregate functions are not allowed in the GROUP BY clause
--------------------------------------------------
Processing Instance 6


**Description**:

1. The code provided is a script designed to process multiple CSV files that contain SQL query results, validate the correctness of the generated SQL queries against actual SQL queries, and analyze the results. 
2. The script reads and concatenates the CSV files, removes duplicate rows, flags each row whose generated output equals the reference output, and counts the flagged rows. The final results are printed.

In [15]:

def read_and_concatenate_csv_files(base_path, file_prefix, num_files):
    """Read ``<file_prefix>1.csv`` .. ``<file_prefix>N.csv`` and combine them.

    Args:
        base_path: Directory containing the CSV files.
        file_prefix: File name prefix placed before the 1-based file number.
        num_files: Number of files to read (N).

    Returns:
        A single DataFrame with the rows of all files and exact duplicate rows
        dropped. When ``num_files`` is less than 1 an empty DataFrame is
        returned instead of failing on ``None``.
    """
    frames = [
        pd.read_csv(os.path.join(base_path, f'{file_prefix}{i}.csv'))
        for i in range(1, num_files + 1)
    ]
    if not frames:
        return pd.DataFrame()
    # One concat over the whole list avoids re-copying the accumulated frame
    # on every loop iteration.
    return pd.concat(frames, ignore_index=True).drop_duplicates()

def validate_outputs(df):
    """Add a column to the DataFrame indicating whether the generated output matches the actual output."""
    df['correct_output'] = df.apply(lambda row: 1 if row['gen_output'] == row['actual_output'] else 0, axis=1)
    return df

def process_csv_files(base_path, file_prefix='submit', num_files=20):
    """Load the result CSVs, score each row, and tally the scores.

    Returns a ``(scored_frame, counts)`` pair, where ``counts`` is the
    ``value_counts()`` of the ``correct_output`` column.
    """
    combined = read_and_concatenate_csv_files(base_path, file_prefix, num_files)
    scored = validate_outputs(combined)
    return scored, scored['correct_output'].value_counts()

# Define the base path where the CSV files are located
base_path = 'D:/Downloads/Text-to-SQL/'

# Call the main processing function
df_final, correct_output_counts = process_csv_files(base_path, file_prefix='submit',num_files=3)

# Print the results
print("Final DataFrame shape:", df_final.shape)
print("Correct output counts:\n", correct_output_counts)


Final DataFrame shape: (30, 6)
Correct output counts:
 0    22
1     8
Name: correct_output, dtype: int64


In [16]:
# Success rate: fraction of rows whose generated output matched the reference.
# value_counts() has no entry for 1 when nothing matched, so use .get(1, 0)
# rather than [1] (KeyError), and guard against an empty frame (division by zero).
n_total = df_final.shape[0]
n_correct = correct_output_counts.get(1, 0)
print("Success rate:", n_correct / n_total if n_total else 0.0)

Success rate: 0.26666666666666666
