# Install dependencies

In [None]:
!pip install transformers torch datasets duckdb google-generativeai

Collecting datasets
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-2.21.0-py3-none-any.whl (527 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
[2

In [None]:
!pip install --upgrade bitsandbytes accelerate

Collecting bitsandbytes
  Downloading bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)
Collecting transformers
  Downloading transformers-4.44.2-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Collecting accelerate
  Downloading accelerate-0.33.0-py3-none-any.whl.metadata (18 kB)
Downloading bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl (137.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.5/137.5 MB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading transformers-4.44.2-py3-none-any.whl (9.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.5/9.5 MB[0m [31m38.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading accelerate-0.33.0-py3-none-any.whl (315 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m315.1/315.1 kB[0m [31m25.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected 

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Import Libraries

In [None]:
import sqlite3
import sqlglot
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
import torch
import bitsandbytes as bnb
import google.generativeai as genai
import pandas as pd
import json
import random

# Load PICARD model

In [None]:
# Load model
model_name = "tscholak/cxmefzzi"
tokenizer = AutoTokenizer.from_pretrained(model_name)

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

# Load the model with 4-bit quantization
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto"  # map to available devices (CPU/GPU)
)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.89k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.44k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/27.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/11.4G [00:00<?, ?B/s]

#Extract schema

**Description**:

This code extracts the schema from an SQLite database,
formatting the input according to a specific template, and then using a pre-trained model to generate the SQL query. Below is a detailed explanation of each function:

1. `get_schema`
This function extracts the schema of an SQLite database. The code snippet `f"PRAGMA table_info(\"{table_name}\");"` ensures that even table names containing special characters are handled correctly by enclosing the table name in double quotes.

2. `format_input` formats input according to the model's input

3. `translate_text_to_sql` this function uses a pre-trained model to translate the formatted input (question, database ID, and schema) into an SQL query.


In [None]:
def get_schema(cursor):
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") # provides table name
    tables = cursor.fetchall()

    schema = []
    for table_name in tables:
        table_name = table_name[0]
        cursor.execute(f"PRAGMA table_info(\"{table_name}\");")
        columns = cursor.fetchall()

        table_schema = f"{table_name} : "
        column_info = []
        for col in columns:
            col_name = col[1]
            # Formatting column name as the model takes input in specific format
            quoted_col_name = f'"{col_name}"'
            cursor.execute(f"SELECT DISTINCT {quoted_col_name} FROM {table_name} LIMIT 1;")
            contents = cursor.fetchall()
            contents_str = ', '.join([str(c[0]) for c in contents])
            column_info.append(f"{col_name} ({contents_str})")

        table_schema += " , ".join(column_info)
        schema.append(table_schema)

    return " | ".join(schema)

# required input format
#[question] | [db_id] | [table] : [column] ( [content] , [content] ) , [column] ( ... ) , [...] | [table] : ... | ...

def format_input(question, db_id, schema):
    return f"{question} | {db_id} | {schema}"

def translate_text_to_sql(formatted_input):
    inputs = tokenizer(formatted_input, return_tensors="pt")
    outputs = model.generate(**inputs)
    sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return sql_query


**Description**:

1. This code provides a set of utility functions to interact with an SQLite database.
2. The functions enable connecting to the database, retrieving the schema (tables and columns), validating SQL queries using sqlglot, and executing SQL queries with a limitation on the number of rows returned.

In [None]:
# extract schema from the .sqlite file
def connect_to_database(db_path):
    """Establish a connection to the SQLite database."""
    try:
        conn = sqlite3.connect(db_path)
        return conn
    except sqlite3.Error as e:
        print(f"Error connecting to database: {e}")
        return None

def get_table_names(conn):
    """Retrieve the names of all tables in the database."""
    try:
        cur = conn.cursor()
        cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cur.fetchall()
        return [table[0] for table in tables]  # Extract table names from tuples
    except sqlite3.Error as e:
        print(f"Error retrieving table names: {e}")
        return []

def get_column_names(conn, table_name):
    """Retrieve the column names for a given table."""
    try:
        cur = conn.cursor()
        # cur.execute(f"PRAGMA table_info({table_name});")
        cur.execute(f"PRAGMA table_info(\"{table_name}\");")
        schema = cur.fetchall()
        return [column[1] for column in schema]  # Extract column names from tuples
    except sqlite3.Error as e:
        print(f"Error retrieving columns for table {table_name}: {e}")
        return []

def extract_schema(conn):
    """Extract the schema of all tables and store it in a dictionary."""
    schema_dict = {}
    table_names = get_table_names(conn)

    for table_name in table_names:
        columns = get_column_names(conn, table_name)
        schema_dict[table_name] = columns

    return schema_dict

def validate_sql(sql_query):
    """Validate the SQL query using sqlglot."""
    try:
        sqlglot.parse_one(sql_query)
        return True
    except sqlglot.errors.ParseError as e:
        print(f"SQL syntax error: {e}")
        return False

def execute_sql_query(sql_query, db_path):
    """Execute the SQL query against the SQLite database."""

    conn = sqlite3.connect(db_path)
    # print(conn)
    cursor = conn.cursor()

    try:
        cursor.execute(sql_query)
        # rows = cursor.fetchall()
        rows = cursor.fetchmany(5)
        return rows

    except sqlite3.Error as e:
        print(f"Database error: {e}")
        return None

    finally:
        conn.close()


# Prompt template


Get your api key from https://ai.google.dev/gemini-api/docs/api-key

In [None]:
# prompt_template = """
# Here are a few examples, you can use them to correct the SQL query. Given the question and SQL query, you need to verify and correct the SQL query if necessary.

# 1. Check if the SQL query correctly matches the question. In the question there are some hints provided.
# 2. Ensure that all table names and column names in the query are correct according to the schema.
# 3. If the query is correct, return it as is. If not, modify the query to correct any errors.

# "question": "Rank schools by their average score in Writing where the score is greater than 499, showing their charter numbers. Valid charter number means the number is not null",
# "SQL": "SELECT CharterNum, AvgScrWrite, RANK() OVER (ORDER BY AvgScrWrite DESC) AS WritingScoreRank FROM schools AS T1  INNER JOIN satscores AS T2 ON T1.CDSCode = T2.cds WHERE T2.AvgScrWrite > 499 AND CharterNum is not null",

# "question": "Among the weekly issuance accounts, how many have a loan of under 200000?  "frequency = 'POPLATEK TYDNE' stands for weekly issuance",
# "SQL": "SELECT COUNT(T1.account_id) FROM loan AS T1 INNER JOIN account AS T2 ON T1.account_id = T2.account_id WHERE T2.frequency = 'POPLATEK TYDNE' AND T1.amount < 200000",

# Evidence helps you to write the correct SQL query.
# Question: {question}
# Do not include any additional information or explanations—return only the corrected SQL query. DO NOT return anything else except the SQL query. Make sure the SQL query is compatible with SQLite syntax. No need to mentio ```sql"
# """

# def get_gemini_response(question, SQL_query):
#     """Generate SQL query using Gemini Pro."""

#     prompt = prompt_template.format(question=question, SQL_query=SQL_query)
#     model = genai.GenerativeModel('gemini-pro')
#     response = model.generate_content([prompt])
#     return response.text.strip()

genai.configure(api_key='API_KEY')

In [None]:
prompt_template2 = """
Given the question, schema, and SQL query, verify and correct the SQL query if necessary.
If no SQL query is provided, generate a correct SQL query based on the question and schema.

Instructions:
1. Check if the SQL query correctly matches the question according to the provided schema.
2. Ensure that all table names and column names in the query are accurate according to the schema.
3. If the query is correct, return it as is. If incorrect, modify the query to correct any errors.

Return only the corrected SQL query—no additional information or explanations.

Database schema:
{schema}
[question] | [db_id] | [table] : [column] ( [content] , [content] ) , [column] ( ... ) , [...] | [table] : ... | ...

Question: {question}
Query: {query}

Return only the SQL query. Ensure the SQL query is compatible with SQLite syntax.
"""


def get_gemini_response2(question, schema, query):
    """Generate SQL query using Gemini Pro."""

    prompt = prompt_template2.format(schema=schema, question=question, query=query)
    model = genai.GenerativeModel('gemini-pro')
    response = model.generate_content([prompt])
    return response.text.strip()

#Text-to-SQL

**Description:**


1. Random samples are taken from `train.json` file downloaded from BIRD-Benchmark
2. Picard model will generate the SQL query from given schema, question and evidence. This query is sent to gemini pro to verify the SQL query, table names and column names provided the schema, question, evidence and query.
3. Query generated from gemini pro is validated through sqlglot library.
4. Finally, after validation, first 5 rows of the data is retrieved from the database using the final sql query (validated query)
5. All the results obtained from `generated query` and the `actual query` are stored along with their query in a csv.

(Due to the tornado/excess requests to gemini-pro an error might occur. Thus, 3 random samples are taken at one time.)

In [None]:
# loop

with open('/content/drive/MyDrive/train_databases/train.json', 'r') as file:
    data = json.load(file)

selected_instances = random.sample(data, 2)
results_list = []

for idx, instance in enumerate(selected_instances):
    print(f"Instance {idx + 1}")

    db_id = instance['db_id']
    print(db_id)
    question = instance['question']
    print(question)
    evidence = instance['evidence']
    question += evidence
    SQL_query = instance["SQL"]

    db_path = '/content/drive/MyDrive/train_databases/'+ db_id + '/' + db_id +'.sqlite'
    # print(db_path)

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    schema = get_schema(cursor)
    formatted_input = format_input(question, db_id, schema)

    try:
      sql_query = translate_text_to_sql(formatted_input)
      sql_gen = sql_query.split("| ")[1]
    except:
      sql_gen = None


    # ---------------------------------------------
    conn = connect_to_database(db_path)
    if conn:
        schema_dict = extract_schema(conn)
    new_query = get_gemini_response2(question=question, schema=formatted_input, query=sql_gen)
    converted_query = new_query
    converted_query = converted_query.replace("```sql ", "").replace("```", "").replace('sql', "")

    validation_successful = validate_sql(converted_query)

    if validation_successful:
          results_gen = execute_sql_query(converted_query, db_path)
    else:
          results_gen = None

    results_actual = execute_sql_query(SQL_query, db_path)

    # Store results
    results_list.append({
        "gen_query": sql_gen,
        "actual_query": SQL_query,
        "gen_output": results_gen,
        "actual_output": results_actual,
        "validation": "yes" if validation_successful else "no"
    })

    print('-'*50)

# Convert the results list to a DataFrame and save it as a CSV file
df = pd.DataFrame(results_list)
df.to_csv('/content/drive/MyDrive/train_databases/PICARD/picard18.csv', index=False)

Instance 1
computer_student
Calculate the percentage of high-level undergraduate course.
--------------------------------------------------
Instance 2
retail_world
Find the total payment of the orders by customers from San Francisco.
SQL syntax error: Invalid expression / Unexpected token. Line 1, Col: 7.
  `[4mSELECT[0m SUM(UnitPrice * Quantity * (1 - Discount)) FROM Orders JOIN Customers ON Orders.CustomerID = Custom
Database error: no such table: Order Details
Database error: near "`SELECT SUM(UnitPrice * Quantity * (1 - Discount)) FROM Orders JOIN Customers ON Orders.CustomerID = Customers.CustomerID WHERE Customers.City = 'San Francisco';`": syntax error
--------------------------------------------------


Merged all the csv files and calculated success rate

In [6]:
paths = '/content/drive/MyDrive/train_databases/PICARD/'

for i in range(8,19):
    path1 = paths + 'picard' + str(i) + '.csv'
    df = pd.read_csv(path1)
    if i == 8:
        final = df
    else:
        final = pd.concat([final, df])

print("before removing duplicates",final.shape)
final.drop_duplicates(inplace=True)
print("after removing duplicates",final.shape)

def correct(row):
    if row['gen_output'] == row['actual_output']:
        return 1
    else:
        return 0
final['correct_output'] = final.apply(correct, axis=1)
final['correct_output'].value_counts()
final.to_csv('picard_gemini-pro-model.csv', index=False)

before removing duplicates (31, 5)
after removing duplicates (31, 5)


Unnamed: 0_level_0,count
correct_output,Unnamed: 1_level_1
0,22
1,9


#Success rate

In [7]:
# Success rate
print("Success rate:", final['correct_output'].value_counts()[1] / final.shape[0])

Success rate: 0.2903225806451613


Although this is not the correct way of deciding the success rate, due to time-constraint and gemini-pro request limits, I took random sampling once and calculated success rate. Ideally at least the sampling should be repeated 100 times and take the average success rate to represent the model's performance.