## RAG-Driven SQL Query Generation from Natural Language Inputs

### Initiliaze Libraries

In [None]:
from vanna.vannadb import VannaDB_VectorStore
import pandas as pd
from vanna.hf import Hf

### VannaDB vector database and LLM initialization 

In [None]:
class Sales_Data_SQL(VannaDB_VectorStore, Hf):
    def __init__(self, config=None):
        MY_VANNA_MODEL = 'YOUR_VANNA_MODEL'
        VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key='YOUR_VANNA_API_KEY', config=config)
        Hf.__init__(self, config=config)

In [4]:
vn = Sales_Data_SQL(config={'model_name_or_path': 'sakshimahadik/nlp-model-mistral7b-instruct-16bit'}) #'quantization_config': {'quant_method': 'quanto'}, 'pre-quantized': True})

Loading checkpoint shards: 100%|██████████| 3/3 [00:08<00:00,  2.95s/it]
Some parameters are on the meta device because they were offloaded to the disk.


### SQL Database setup

In [9]:
%env DATABASE=DATABASE_NAME
%env USER=USER_NAME
%env HOST=HOST_NAME
%env PASSWORD=PASSWORD

env: DATABASE=DATABASE_NAME
env: USER=USER_NAME
env: HOST=HOST_NAME
env: PASSWORD=PASSWORD


In [None]:
vn.connect_to_mysql(port='MYSQL_PORT')

### Adding SQL Schema to RAG

In [39]:
vn.train(ddl="""
        CREATE TABLE sales(
        invoice_id VARCHAR(30) NOT NULL PRIMARY KEY,
        branch VARCHAR(5) NOT NULL,
        city VARCHAR(30) NOT NULL,
        customer_type VARCHAR(30) NOT NULL,
        gender VARCHAR(10) NOT NULL,
        product_line VARCHAR(100) NOT NULL,
        unit_price DECIMAL(10,2) NOT NULL,
        quantity INT(20) NOT NULL,
        vat FLOAT(6,4) NOT NULL,
        total DECIMAL(12, 4) NOT NULL,
        date DATETIME NOT NULL,
        time TIME NOT NULL,
        payment VARCHAR(15) NOT NULL,
        cogs DECIMAL(10,2) NOT NULL,
        gross_margin_pct FLOAT(11,9),
        gross_income DECIMAL(12, 4),
        rating FLOAT(2, 1)
        );
    """)

Adding ddl: 
        CREATE TABLE sales(
        invoice_id VARCHAR(30) NOT NULL PRIMARY KEY,
        branch VARCHAR(5) NOT NULL,
        city VARCHAR(30) NOT NULL,
        customer_type VARCHAR(30) NOT NULL,
        gender VARCHAR(10) NOT NULL,
        product_line VARCHAR(100) NOT NULL,
        unit_price DECIMAL(10,2) NOT NULL,
        quantity INT(20) NOT NULL,
        vat FLOAT(6,4) NOT NULL,
        total DECIMAL(12, 4) NOT NULL,
        date DATETIME NOT NULL,
        time TIME NOT NULL,
        payment VARCHAR(15) NOT NULL,
        cogs DECIMAL(10,2) NOT NULL,
        gross_margin_pct FLOAT(11,9),
        gross_income DECIMAL(12, 4),
        rating FLOAT(2, 1)
        );
    


'364424-ddl'

In [40]:
vn.train(documentation="We are analysing Walmart's sales data to identify high-performing branches and products, analyze the sales patterns of various products, and understand customer behavior. The primary objective is to enhance and optimize sales strategies.")

Adding documentation....


'2526847-doc'

In [41]:
import json

json_file = 'question_sql_rag.json'

# Read the JSON file and save it as a dictionary
with open(json_file, 'r') as file:
    data_dict = json.load(file)

# Now `data_dict` is a dictionary containing the JSON data
print(data_dict)

# Extract only the question and SQL pairs into a new dictionary
question_sql_pairs = []

for category, pairs in data_dict.items():
    for item in pairs:
        question_sql_pairs.append([item["question"], item["sql"]])

# The result is a list of dictionaries with only question-SQL pairs
print(question_sql_pairs)

{'Generic Questions': [{'question': 'How many unique branches are present in the dataset?', 'sql': 'SELECT COUNT(DISTINCT Branch) FROM sales;'}, {'question': 'How many unique products were sold in each branch?', 'sql': 'SELECT Branch, COUNT(DISTINCT Product_line) FROM sales GROUP BY Branch;'}, {'question': 'What is the total quantity of products sold across all branches?', 'sql': 'SELECT SUM(Quantity) AS total_quantity FROM sales;'}], 'Product Analysis': [{'question': 'What is the highest price per unit for any product?', 'sql': 'SELECT MAX(Unit_price) FROM sales;'}, {'question': 'What is the lowest price per unit for any product?', 'sql': 'SELECT MIN(Unit_price) FROM sales;'}, {'question': 'Which product line has the most transactions?', 'sql': 'SELECT Product_line, COUNT(Invoice_ID) AS transaction_count FROM sales GROUP BY Product_line ORDER BY transaction_count DESC LIMIT 1;'}], 'Sales Analysis': [{'question': 'What is the distribution of sales by time of day?', 'sql': 'SELECT HOUR(

In [42]:
for i in question_sql_pairs:
    vn.train(question=i[0], sql=i[1])

#### Verifying if SQL data is added to RAG

In [6]:
training_data = vn.get_training_data()
training_data.head()

Unnamed: 0,id,training_data_type,question,content
0,542372-sql,sql,What is the highest price per unit for any pro...,SELECT MAX(Unit_price) FROM sales;
1,542375-sql,sql,What is the distribution of sales by time of day?,"SELECT HOUR(Time) AS hour, COUNT(Invoice_ID) A..."
2,542381-sql,sql,What is the most common customer type for high...,"SELECT Customer_type, COUNT(*) AS customer_cou..."
3,542379-sql,sql,On which day of the month are the most transac...,"SELECT DAY(Date) AS day_of_month, COUNT(Invoic..."
4,542378-sql,sql,How many male and female customers are there i...,"SELECT Branch, Gender, COUNT(*) AS customer_co..."


In [7]:
# The information schema query may need some tweaking depending on your database. This is a good starting point.
df_sales = vn.run_sql("DESC sales;")
df_sales.head()

Unnamed: 0,Field,Type,Null,Key,Default,Extra
0,invoice_id,varchar(30),NO,PRI,,
1,branch,varchar(5),NO,,,
2,city,varchar(30),NO,,,
3,customer_type,varchar(30),NO,,,
4,gender,varchar(10),NO,,,


### Generating and Executing SQL Statement based on Query

In [11]:
response_sql = vn.generate_sql("'How many distinct cities are present in the dataset?'", allow_llm_to_see_data=False)

SQL Prompt: [{'role': 'system', 'content': "You are a SQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \n\n        CREATE TABLE sales(\n        invoice_id VARCHAR(30) NOT NULL PRIMARY KEY,\n        branch VARCHAR(5) NOT NULL,\n        city VARCHAR(30) NOT NULL,\n        customer_type VARCHAR(30) NOT NULL,\n        gender VARCHAR(10) NOT NULL,\n        product_line VARCHAR(100) NOT NULL,\n        unit_price DECIMAL(10,2) NOT NULL,\n        quantity INT(20) NOT NULL,\n        vat FLOAT(6,4) NOT NULL,\n        total DECIMAL(12, 4) NOT NULL,\n        date DATETIME NOT NULL,\n        time TIME NOT NULL,\n        payment VARCHAR(15) NOT NULL,\n        cogs DECIMAL(10,2) NOT NULL,\n        gross_margin_pct FLOAT(11,9),\n        gross_income DECIMAL(12, 4),\n        rating FLOAT(2, 1)\n        );\n    \n\n\n===Additional Context \n\nWe are analy

In [12]:
print(response_sql)

SELECT COUNT(DISTINCT City) AS city_count FROM sales;


In [8]:
result_sql = vn.run_sql(response_sql)
result_sql

Unnamed: 0,city_count
0,3


### Evaluation

For Evaluation we are considering these metrics:
- Exact Matching of generated output and ground truth output.
- Similarity between generated SQL and ground truth SQL.
- Syntax Errors

SQLQueryEvaluator Class:
- Loads the evaluation data from json file.
- It evaluates the data based on the above metrics and adds results to the results array.
- Summarizes the results


In [10]:
from difflib import SequenceMatcher

def calculate_similarity(sql1, sql2):
    """
    Calculate similarity between two SQL queries using SequenceMatcher.
    """
    return SequenceMatcher(None, sql1.lower(), sql2.lower()).ratio()

def df_comparison_metric(df1, df2):
    if df1.shape != df2.shape:
        return "Not Same"
    df1_set = set([tuple(sorted(map(str, row))) for row in df1.values])
    df2_set = set([tuple(sorted(map(str, row))) for row in df2.values])
    return "Same" if df1_set == df2_set else "Not Same"

def classify_error(generated_output, ground_truth_output, gen_status, truth_status):
    """
    Classify the type of error based on outputs.
    """
    
    evaluation = df_comparison_metric(generated_output, ground_truth_output)

    if evaluation == "Same":  # Set similarity threshold
        return "Correct"
    elif gen_status == "Syntax Error":
        return "Syntax Error"
    elif truth_status == "Error":
        return "Ground Truth Error"
    else:
        return "Wrong Answer"

In [11]:
import pandas as pd
import json

class SQLQueryEvaluator:
    def __init__(self, vanna_instance, eval_file):
        self.vn = vanna_instance
        self.eval_file = eval_file
        self.results = []

    def load_eval_data(self):
        with open(self.eval_file, 'r') as file:
            self.eval_data = json.load(file)

    def evaluate(self):
        for category, questions in self.eval_data.items():
            for item in questions:
                question = item["question"]
                ground_truth_sql = item["sql"]
                gen_status = ""
                error_classification = "No Error"

                # Generate SQL using the model
                try:
                    generated_sql = self.vn.generate_sql(question, allow_llm_to_see_data=False)
                    try:
                        # Execute generated SQL
                        generated_output = self.vn.run_sql(generated_sql)
                        gen_status = "Success"
                    except Exception as e:
                        generated_output = str(e)
                        gen_status = "Syntax Error"

                    # Execute ground truth SQL
                    try:
                        ground_truth_output = self.vn.run_sql(ground_truth_sql)
                        truth_status = "Success"
                    except Exception as e:
                        ground_truth_output = str(e)
                        truth_status = "Error"

                    # Classify Error
                    error_classification = classify_error(
                        generated_output, ground_truth_output, gen_status, truth_status
                    )
                    
                    # Similarity between ground truth and generated sql queries
                    similarity = calculate_similarity(ground_truth_sql, generated_sql)
                    analysis = ""
                    if similarity >= 0.8:
                        analysis = "Almost Correct"
                        
                    else:
                        analysis = "Wrong Answer"

                    # Append results
                    self.results.append({
                        "Question": question,
                        "Generated SQL": generated_sql,
                        "Ground Truth SQL": ground_truth_sql,
                        "Generated SQL Output": generated_output,
                        "Error Classification": error_classification,
                        "Similarity Analysis": analysis
                    })

                except Exception as e:
                    self.results.append({
                        "Question": question,
                        "Generated SQL": None,
                        "Ground Truth SQL": ground_truth_sql,
                        "Generated SQL Output": str(e),
                        "Error Classification": "Unknown Error",
                        "Similarity Analysis": "Unknown Error"
                    })

    def save_results(self, output_file):
        df = pd.DataFrame(self.results)
        df.to_csv(output_file, index=False)
        print(f"Results saved to {output_file}")

    def summarize_results(self):
        df = pd.DataFrame(self.results)
        summary = df["Error Classification"].value_counts()
        total = len(df)
        accuracy = summary.get("Correct", 0) / total * 100
        syntax_error_rate = summary.get("Syntax Error", 0) / total * 100
        analysis = df["Similarity Analysis"].value_counts()
        almost_correct = analysis.get("Almost Correct", 0) / total * 100

        print(f"Evaluation Summary:")
        print(f"Total Questions Evaluated: {total}")
        print(f"Correct Answers: {summary.get('Correct', 0)}")
        print(f"Syntax Errors: {summary.get('Syntax Error', 0)}")
        print(f"Wrong Answers: {summary.get('Wrong Answer', 0)}")
        print(f"Almost Correct Queries: {analysis.get('Almost Correct', 0)}")
        print(f"Overall Accuracy: {accuracy:.2f}%")
        print(f"Syntax Error Rate: {syntax_error_rate:.2f}%")
        print(f"Almost Correct Queries Rate: {almost_correct:.2f}%")

In [12]:
# Evaluation file
json_file = 'question_sql_eval.json'
output_file = 'sql_evaluation_results_final.csv'

# Evaluation of Model
evaluator = SQLQueryEvaluator(vn, json_file)
evaluator.load_eval_data()
evaluator.evaluate()
evaluator.save_results(output_file)

SQL Prompt: [{'role': 'system', 'content': "You are a SQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \n\n        CREATE TABLE sales(\n        invoice_id VARCHAR(30) NOT NULL PRIMARY KEY,\n        branch VARCHAR(5) NOT NULL,\n        city VARCHAR(30) NOT NULL,\n        customer_type VARCHAR(30) NOT NULL,\n        gender VARCHAR(10) NOT NULL,\n        product_line VARCHAR(100) NOT NULL,\n        unit_price DECIMAL(10,2) NOT NULL,\n        quantity INT(20) NOT NULL,\n        vat FLOAT(6,4) NOT NULL,\n        total DECIMAL(12, 4) NOT NULL,\n        date DATETIME NOT NULL,\n        time TIME NOT NULL,\n        payment VARCHAR(15) NOT NULL,\n        cogs DECIMAL(10,2) NOT NULL,\n        gross_margin_pct FLOAT(11,9),\n        gross_income DECIMAL(12, 4),\n        rating FLOAT(2, 1)\n        );\n    \n\n\n===Additional Context \n\nWe are analy

In [13]:
evaluator.summarize_results()

Evaluation Summary:
Total Questions Evaluated: 27
Correct Answers: 16
Syntax Errors: 0
Wrong Answers: 11
Almost Correct Queries: 19
Overall Accuracy: 59.26%
Syntax Error Rate: 0.00%
Almost Correct Queries Rate: 70.37%


### Feedback

In [14]:
import torch
import gc

del vn
gc.collect()

0

#### Initializing the LLM textgeneration pipeline and filtering data for clasification as 'Wrong Answer'

In [15]:
from transformers import pipeline
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feedback_llm = pipeline("text-generation", model="sakshimahadik/nlp-model-mistral7b-instruct-16bit")


Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.65s/it]
Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [16]:
# Load the CSV file
file_path = 'sql_evaluation_results_final.csv'
data = pd.read_csv(file_path)

# Filter the data for rows where 'Error Classification' == 'Wrong Answer'
filtered_data = data[data['Error Classification'] == 'Wrong Answer']

In [17]:
filtered_data.head()

Unnamed: 0,Question,Generated SQL,Ground Truth SQL,Generated SQL Output,Error Classification,Similarity Analysis
1,Which branch made the highest number of sales?,"SELECT Branch, COUNT(Invoice_ID) AS sales_coun...","SELECT Branch, SUM(quantity) AS sales_count FR...",Branch sales_count\n0 A 340,Wrong Answer,Almost Correct
4,Which gender purchases the most products?,"SELECT Gender, COUNT(Invoice_ID) AS purchases ...","SELECT Gender, SUM(Quantity) AS total_quantity...",Gender purchases\n0 Female 501\n1 ...,Wrong Answer,Wrong Answer
6,What is the most commonly purchased product li...,"SELECT City, Product_line, COUNT(*) AS product...","WITH ranked_sales AS ( SELECT city, product_li...",City Product_line product...,Wrong Answer,Wrong Answer
8,Calculate the revenue per unit for each produc...,"SELECT Product_line, AVG(Total) AS revenue_per...","SELECT Product_line, SUM(Total)/SUM(Quantity) ...",Product_line revenue_per_unit\n0 ...,Wrong Answer,Wrong Answer
12,Which time of day records the highest average ...,"SELECT HOUR(Time) AS hour, AVG(Rating) AS avg_...","SELECT TIME_FORMAT(Time, '%H:00') AS hour, AVG...",hour avg_rating\n0 12 7.3,Wrong Answer,Almost Correct


#### Generating feedback to correct wrong SQL queries

In [18]:
ddl_prompt = """
        ===Tables 

        CREATE TABLE sales(
        invoice_id VARCHAR(30) NOT NULL PRIMARY KEY,
        branch VARCHAR(5) NOT NULL,
        city VARCHAR(30) NOT NULL,
        customer_type VARCHAR(30) NOT NULL,
        gender VARCHAR(10) NOT NULL,
        product_line VARCHAR(100) NOT NULL,
        unit_price DECIMAL(10,2) NOT NULL,
        quantity INT(20) NOT NULL,
        vat FLOAT(6,4) NOT NULL,
        total DECIMAL(12, 4) NOT NULL,
        date DATETIME NOT NULL,
        time TIME NOT NULL,
        payment VARCHAR(15) NOT NULL,
        cogs DECIMAL(10,2) NOT NULL,
        gross_margin_pct FLOAT(11,9),
        gross_income DECIMAL(12, 4),
        rating FLOAT(2, 1)
        );
    """

In [19]:
# Initialize a list to store results
results = []
feedbacks = []

# Iterate over the filtered rows
for _, row in filtered_data.iterrows():
    # Extract the necessary columns
    question = row['Question']
    sql_query = row['Generated SQL']
    sql_output = row['Generated SQL Output']
    
    messages = [
        {
            "role": "system", 
            "content": f"""
                        You are a SQL advisor and expert.
                        Your task is to read the question properly and analyze a provided SQL query, identify any mistakes, and offer detailed feedback to improve the query.
                        The SQL query is almost correct but needs to be changed slightly. Follow the instructions in Question.
                        
                        You will be given:
                        1. A question describing the intended SQL task.
                        2. The corresponding SQL query that produces incorrect or unexpected results.
                        3. The output of the incorrect SQL query.
                        
                        Output Requirements:
                        1. Clearly explain why the SQL query is incorrect.
                        2. Provide a step-by-step explanation of the issues in the query.
                        3. Offer actionable suggestions to fix the query to meet the requirements of the question.
                        4. Do not give any SQL in the output.
                        
                        Refer this schema:
                        {ddl_prompt}
                    """
        },
        
        {"role": "user", "content": f"Question: {question}, SQL Query: {sql_query}, Output of wrong sql: {sql_output}"},
    ]
    output = feedback_llm(messages, max_new_tokens=300, temperature=0.00001, do_sample=True, top_p=5)
    feedback = output[0]['generated_text'][2]['content']
    
    # Append feedbacks to the list
    feedbacks.append({
        'question': question,
        'sql_query': sql_query,
        'sql_output': sql_output,
        'Feedback': feedback,
    })

In [20]:
# Convert feedbacks into a DataFrame
feedback_df = pd.DataFrame(feedbacks)
feedback_df.to_csv('feedback.csv', index=False)

In [23]:
feedbacks[1]['Feedback']

" The SQL query is almost correct, but it is selecting the gender column instead of the total number of products purchased by each gender. Here's a step-by-step explanation of the issues in the query:\n\n1. The SELECT statement should include the column that represents the total number of products purchased by each customer, which is not the gender column.\n2. To get the total number of products purchased by each gender, we need to aggregate the quantity column in the sales table.\n\nTo fix the query, we can modify it as follows:\n\nSELECT Gender, SUM(quantity) AS total_products FROM sales GROUP BY Gender ORDER BY total_products DESC;\n\nThis query will correctly select the gender and total number of products purchased by each gender, group them, and order the results in descending order based on the total number of products."

#### Passing Feedback to the original LLM to correct the SQL queries based on Feedback

In [27]:
# Iterate over the feedback rows
for _, row in feedback_df.iterrows():
    # Extract the necessary columns
    question = row['question']
    sql_query = row['sql_query']
    feedback = row['Feedback']
    
    ddl_list = vn.get_related_ddl(question)
    ddl_prompt = vn.add_ddl_to_prompt('', ddl_list, max_tokens=300)
    
    message_log = []
    message_log.append(vn.user_message(question))
    message_log.append(vn.assistant_message(sql_query))
    message_log.append(vn.user_message(f"The returned SQL Query is wrong. Feedback: {feedback}. Use the feedback and return corrected query. Refer this before writing the sql {ddl_prompt}"))
    
    response = vn.submit_prompt(message_log)
    
    # Replace "\_" with "_"
    response = response.replace("\\_", "_")

    response = response.replace("\\", "")

    extracted_sql = vn.extract_sql_query(response)
    
    # Append results to the list
    results.append({
        'question': question,
        'sql_query': sql_query,
        'ddl_list': ddl_prompt,
        'Feedback': feedback,
        'Extracted SQL': extracted_sql
    })

Info: SELECT Branch, COUNT(Invoice_ID) AS sales_count FROM sales GROUP BY Branch ORDER BY sales_count DESC LIMIT 1;

This query selects the Branch and the number of sales (count of Invoice_ID) for each branch, orders them in descending order by sales count, and returns the branch with the highest number of sales (limit 1). The corrected SQL query uses the appropriate syntax to find the branch with the highest number of sales.
Info: SELECT Gender, SUM(quantity) AS total_products FROM sales GROUP BY Gender ORDER BY total_products DESC;

This query will correctly select the gender and total number of products purchased by each gender, group them, and order the results in descending order based on the total number of products purchased.
Info: SELECT City, Product_line, COUNT(Product_line) AS product_count FROM sales GROUP BY City, Product_line ORDER BY product_count DESC;
Info: ```
SELECT product_line, AVG(Total) AS revenue_per_unit FROM sales GROUP BY product_line ORDER BY revenue_per_uni

In [29]:
print(results[1])

{'question': 'Which gender purchases the most products?', 'sql_query': 'SELECT Gender, COUNT(Invoice_ID) AS purchases FROM sales GROUP BY Gender ORDER BY purchases DESC;', 'ddl_list': '\n===Tables \n\n        CREATE TABLE sales(\n        invoice_id VARCHAR(30) NOT NULL PRIMARY KEY,\n        branch VARCHAR(5) NOT NULL,\n        city VARCHAR(30) NOT NULL,\n        customer_type VARCHAR(30) NOT NULL,\n        gender VARCHAR(10) NOT NULL,\n        product_line VARCHAR(100) NOT NULL,\n        unit_price DECIMAL(10,2) NOT NULL,\n        quantity INT(20) NOT NULL,\n        vat FLOAT(6,4) NOT NULL,\n        total DECIMAL(12, 4) NOT NULL,\n        date DATETIME NOT NULL,\n        time TIME NOT NULL,\n        payment VARCHAR(15) NOT NULL,\n        cogs DECIMAL(10,2) NOT NULL,\n        gross_margin_pct FLOAT(11,9),\n        gross_income DECIMAL(12, 4),\n        rating FLOAT(2, 1)\n        );\n    \n\n', 'Feedback': " The SQL query is almost correct, but it is selecting the gender column instead o

In [None]:
# Convert feedbacks into a DataFrame
results_df = pd.DataFrame(results)

In [31]:
results_df.head()

Unnamed: 0,question,sql_query,ddl_list,Feedback,Extracted SQL
0,Which branch made the highest number of sales?,"SELECT Branch, COUNT(Invoice_ID) AS sales_coun...",\n===Tables \n\n CREATE TABLE sales(\n ...,"The SQL query provided is almost correct, but...","SELECT Branch, COUNT(Invoice_ID) AS sales_coun..."
1,Which gender purchases the most products?,"SELECT Gender, COUNT(Invoice_ID) AS purchases ...",\n===Tables \n\n CREATE TABLE sales(\n ...,"The SQL query is almost correct, but it is se...","SELECT Gender, SUM(quantity) AS total_products..."
2,What is the most commonly purchased product li...,"SELECT City, Product_line, COUNT(*) AS product...",\n===Tables \n\n CREATE TABLE sales(\n ...,"The SQL query provided is almost correct, but...","SELECT City, Product_line, COUNT(Product_line)..."
3,Calculate the revenue per unit for each produc...,"SELECT Product_line, AVG(Total) AS revenue_per...",\n===Tables \n\n CREATE TABLE sales(\n ...,"The SQL query is almost correct, but there is...","SELECT product_line, AVG(Total) AS revenue_per..."
4,Which time of day records the highest average ...,"SELECT HOUR(Time) AS hour, AVG(Rating) AS avg_...",\n===Tables \n\n CREATE TABLE sales(\n ...,"The SQL query is almost correct, but it is in...","SELECT HOUR(Time) AS hour, AVG(Rating) AS avg_..."


In [33]:
combined_df = results_df.merge(data, left_on='question', right_on='Question', how='left')
combined_df = combined_df.drop(columns=['Question', 'Generated SQL', 'Generated SQL Output', 'Error Classification', 'Similarity Analysis', 'ddl_list'])
combined_df.head()

Unnamed: 0,question,sql_query,Feedback,Extracted SQL,Ground Truth SQL
0,Which branch made the highest number of sales?,"SELECT Branch, COUNT(Invoice_ID) AS sales_coun...","The SQL query provided is almost correct, but...","SELECT Branch, COUNT(Invoice_ID) AS sales_coun...","SELECT Branch, SUM(quantity) AS sales_count FR..."
1,Which gender purchases the most products?,"SELECT Gender, COUNT(Invoice_ID) AS purchases ...","The SQL query is almost correct, but it is se...","SELECT Gender, SUM(quantity) AS total_products...","SELECT Gender, SUM(Quantity) AS total_quantity..."
2,What is the most commonly purchased product li...,"SELECT City, Product_line, COUNT(*) AS product...","The SQL query provided is almost correct, but...","SELECT City, Product_line, COUNT(Product_line)...","WITH ranked_sales AS ( SELECT city, product_li..."
3,Calculate the revenue per unit for each produc...,"SELECT Product_line, AVG(Total) AS revenue_per...","The SQL query is almost correct, but there is...","SELECT product_line, AVG(Total) AS revenue_per...","SELECT Product_line, SUM(Total)/SUM(Quantity) ..."
4,Which time of day records the highest average ...,"SELECT HOUR(Time) AS hour, AVG(Rating) AS avg_...","The SQL query is almost correct, but it is in...","SELECT HOUR(Time) AS hour, AVG(Rating) AS avg_...","SELECT TIME_FORMAT(Time, '%H:00') AS hour, AVG..."


In [36]:
feedback_results = []

for _, row in combined_df.iterrows():
    generated_sql = row['Extracted SQL']
    ground_truth_sql = row['Ground Truth SQL']
    gen_status = ""
    truth_status = ""
    
    try:
        generated_output = vn.run_sql(generated_sql)

    except:
        gen_status = "Syntax Error"
        
    try:
        ground_truth_output = vn.run_sql(ground_truth_sql)
        
    except:
        truth_status = "Error"
    
    try:   
        error_classification = classify_error(generated_output, ground_truth_output, gen_status, truth_status)
        score = calculate_similarity(ground_truth_sql, generated_sql)
        
        if score >= 0.8:
            analysis = "Almost Correct"
        else:
            analysis = "Wrong Answer"
        
        feedback_results.append({
            'Question': row['question'],
            'Generated SQL': generated_sql,
            'Ground Truth SQL': ground_truth_sql,
            'Generated Output': generated_output,
            'Ground Truth Output': ground_truth_output,
            'Similarity Analysis': analysis
        })
        
    except:
        feedback_results.append({
            'Question': row['question'],
            'Generated SQL': generated_sql,
            'Ground Truth SQL': ground_truth_sql,
            'Generated Output': None,
            'Ground Truth Output': None,
            'Similarity Analysis': "Unknown Error"
        })

#### Saving results after feedback

In [37]:
feedback_results_df = pd.DataFrame(feedback_results)
feedback_results_df.to_csv("results_after_feedback.csv", index=False)

In [45]:
analysis = feedback_results_df["Similarity Analysis"].value_counts()
total = len(feedback_results_df)

print(f"Total Almost Correct Queries: {analysis.get('Almost Correct', 0)}")
print(f"Total Queries: {total}")

Total Almost Correct Queries: 6
Total Queries: 11


In [24]:
import torch
import gc

del feedback_llm
gc.collect()

12034

### Prompt Engineering

We have initialized three types of prompts:
1. Prompt1: Paraphrased original system prompt.
2. Prompt2: Paraphrased original system prompt + Additional Instructions
3. Prompt3: Paraphrased original system prompt + Additional Instructions

In [46]:
import torch
import gc

del vn
gc.collect()

0

#### Prompts Initilization

In [4]:

prompt1 = """You are a SQL expert. 
Generate a SQL query to answer the question accurately. 
Base your response solely on the provided context and strictly adhere to the given response guidelines and format instructions."""

prompt2 = """You are a SQL expert. 
Your goal is to generate an efficient and accurate SQL query to answer the given question. 
Firstly, understand the problem. Analyze the requirements of the question thoroughly. 
Break the task into logical components using grouping, filtering, ranking etc. 
Secondly, choose the right approach. 
Decide whether to use a Common Table Expression (CTE), subquery, or inline logic based on the complexity of the problem. 
Use a CTE for multi-step logic or when intermediate results are reused. 
Use a subquery for one-off intermediate calculations.
Use direct aggregation or filtering for simple problems. 
At last, write the SQL query. Ensure the query addresses all requirements of the problem."""

prompt3 = """You are a SQL expert. 
Your task is to generate an efficient and accurate SQL query to solve the given problem. 
Begin by analyzing the requirements, breaking the task into logical components like grouping, filtering, ranking, or aggregating. 
Decide on the best approach based on complexity: use a Common Table Expression (CTE) for multi-step logic or reusable intermediate results, a subquery for one-time calculations, or direct aggregation and filtering for simple tasks. 
Justify your choice of approach to ensure correctness and efficiency. 
Write the query to meet all requirements, optimize for performance and readability, and handle edge cases like missing data, ties, or ensuring all groups are included.
"""

#### Prompt 1:

You are a SQL expert. 
Generate a SQL query to answer the question accurately. 
Base your response solely on the provided context and strictly adhere to the given response guidelines and format instructions.

In [48]:
vn_prompt1 = Sales_Data_SQL(config={'model_name_or_path': 'sakshimahadik/nlp-model-mistral7b-instruct-16bit', 'initial_prompt': prompt1}) #'quantization_config': {'quant_method': 'quanto'}, 'pre-quantized': True})
vn_prompt1.connect_to_mysql(port=25038)

Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.31s/it]


In [49]:
# Evaluation file
json_file = 'question_sql_eval.json'
output_file = 'sql_evaluation_results_prompt1.csv'

# Evaluation of Model
evaluator1 = SQLQueryEvaluator(vn_prompt1, json_file)
evaluator1.load_eval_data()
evaluator1.evaluate()
evaluator1.save_results(output_file)

SQL Prompt: [{'role': 'system', 'content': "You are a SQL expert. \nGenerate a SQL query to answer the question accurately. \nBase your response solely on the provided context and strictly adhere to the given response guidelines and format instructions.\n===Tables \n\n        CREATE TABLE sales(\n        invoice_id VARCHAR(30) NOT NULL PRIMARY KEY,\n        branch VARCHAR(5) NOT NULL,\n        city VARCHAR(30) NOT NULL,\n        customer_type VARCHAR(30) NOT NULL,\n        gender VARCHAR(10) NOT NULL,\n        product_line VARCHAR(100) NOT NULL,\n        unit_price DECIMAL(10,2) NOT NULL,\n        quantity INT(20) NOT NULL,\n        vat FLOAT(6,4) NOT NULL,\n        total DECIMAL(12, 4) NOT NULL,\n        date DATETIME NOT NULL,\n        time TIME NOT NULL,\n        payment VARCHAR(15) NOT NULL,\n        cogs DECIMAL(10,2) NOT NULL,\n        gross_margin_pct FLOAT(11,9),\n        gross_income DECIMAL(12, 4),\n        rating FLOAT(2, 1)\n        );\n    \n\n\n===Additional Context \n\nW

In [50]:
evaluator1.summarize_results()

Evaluation Summary:
Total Questions Evaluated: 27
Correct Answers: 17
Syntax Errors: 0
Wrong Answers: 10
Almost Correct Queries: 21
Overall Accuracy: 62.96%
Syntax Error Rate: 0.00%
Almost Correct Queries Rate: 77.78%


#### Prompt 2:

You are a SQL expert. 
Your goal is to generate an efficient and accurate SQL query to answer the given question. 
Firstly, understand the problem. Analyze the requirements of the question thoroughly. 
Break the task into logical components using grouping, filtering, ranking etc. 
Secondly, choose the right approach. 
Decide whether to use a Common Table Expression (CTE), subquery, or inline logic based on the complexity of the problem. 
Use a CTE for multi-step logic or when intermediate results are reused. 
Use a subquery for one-off intermediate calculations.
Use direct aggregation or filtering for simple problems. 
At last, write the SQL query. Ensure the query addresses all requirements of the problem.

In [52]:
vn_prompt2 = Sales_Data_SQL(config={'model_name_or_path': 'sakshimahadik/nlp-model-mistral7b-instruct-16bit', 'initial_prompt': prompt2}) #'quantization_config': {'quant_method': 'quanto'}, 'pre-quantized': True})
vn_prompt2.connect_to_mysql(port=25038)

Loading checkpoint shards: 100%|██████████| 3/3 [00:09<00:00,  3.25s/it]


In [53]:
# Evaluation file
json_file = 'question_sql_eval.json'
output_file = 'sql_evaluation_results_prompt2.csv'

# Evaluation of Model
evaluator2 = SQLQueryEvaluator(vn_prompt2, json_file)
evaluator2.load_eval_data()
evaluator2.evaluate()
evaluator2.save_results(output_file)

SQL Prompt: [{'role': 'system', 'content': "You are a SQL expert. \nYour goal is to generate an efficient and accurate SQL query to answer the given question. \nFirstly, understand the problem. Analyze the requirements of the question thoroughly. \nBreak the task into logical components using grouping, filtering, ranking etc. \nSecondly, choose the right approach. \nDecide whether to use a Common Table Expression (CTE), subquery, or inline logic based on the complexity of the problem. \nUse a CTE for multi-step logic or when intermediate results are reused. \nUse a subquery for one-off intermediate calculations.\nUse direct aggregation or filtering for simple problems. \nAt last, write the SQL query. Ensure the query addresses all requirements of the problem.\n===Tables \n\n        CREATE TABLE sales(\n        invoice_id VARCHAR(30) NOT NULL PRIMARY KEY,\n        branch VARCHAR(5) NOT NULL,\n        city VARCHAR(30) NOT NULL,\n        customer_type VARCHAR(30) NOT NULL,\n        gender

In [54]:
evaluator2.summarize_results()

Evaluation Summary:
Total Questions Evaluated: 27
Correct Answers: 18
Syntax Errors: 0
Wrong Answers: 8
Almost Correct Queries: 21
Overall Accuracy: 66.67%
Syntax Error Rate: 0.00%
Almost Correct Queries Rate: 77.78%


#### Prompt 3:

You are a SQL expert. 
Your task is to generate an efficient and accurate SQL query to solve the given problem. 
Begin by analyzing the requirements, breaking the task into logical components like grouping, filtering, ranking, or aggregating. 
Decide on the best approach based on complexity: use a Common Table Expression (CTE) for multi-step logic or reusable intermediate results, a subquery for one-time calculations, or direct aggregation and filtering for simple tasks. 
Justify your choice of approach to ensure correctness and efficiency. 
Write the query to meet all requirements, optimize for performance and readability, and handle edge cases like missing data, ties, or ensuring all groups are included.

In [8]:
vn_prompt3 = Sales_Data_SQL(config={'model_name_or_path': 'sakshimahadik/nlp-model-mistral7b-instruct-16bit', 'initial_prompt': prompt3}) #'quantization_config': {'quant_method': 'quanto'}, 'pre-quantized': True})
vn_prompt3.connect_to_mysql(port=25038)

Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.46s/it]
Some parameters are on the meta device because they were offloaded to the cpu.


In [12]:
# Evaluation file
json_file = 'question_sql_eval.json'
output_file = 'sql_evaluation_results_prompt3.csv'

# Evaluation of Model
evaluator3 = SQLQueryEvaluator(vn_prompt3, json_file)
evaluator3.load_eval_data()
evaluator3.evaluate()
evaluator3.save_results(output_file)

SQL Prompt: [{'role': 'system', 'content': "You are a SQL expert. \nYour task is to generate an efficient and accurate SQL query to solve the given problem. \nBegin by analyzing the requirements, breaking the task into logical components like grouping, filtering, ranking, or aggregating. \nDecide on the best approach based on complexity: use a Common Table Expression (CTE) for multi-step logic or reusable intermediate results, a subquery for one-time calculations, or direct aggregation and filtering for simple tasks. \nJustify your choice of approach to ensure correctness and efficiency. \nWrite the query to meet all requirements, optimize for performance and readability, and handle edge cases like missing data, ties, or ensuring all groups are included.\n\n===Tables \n\n        CREATE TABLE sales(\n        invoice_id VARCHAR(30) NOT NULL PRIMARY KEY,\n        branch VARCHAR(5) NOT NULL,\n        city VARCHAR(30) NOT NULL,\n        customer_type VARCHAR(30) NOT NULL,\n        gender VA

In [13]:
evaluator3.summarize_results()

Evaluation Summary:
Total Questions Evaluated: 27
Correct Answers: 17
Syntax Errors: 0
Wrong Answers: 10
Almost Correct Queries: 21
Overall Accuracy: 62.96%
Syntax Error Rate: 0.00%
Almost Correct Queries Rate: 77.78%


In [7]:
import torch
import gc

del vn_prompt3
gc.collect()

794