# Sampling Similar NL-to-SQL Pairs for Enhanced Prompting

In [31]:
from langchain.sql_database import SQLDatabase

from langchain_openai import ChatOpenAI
from langchain.chains import create_sql_query_chain
from urllib.parse import quote  
from langchain.callbacks import get_openai_callback
import time
from dotenv import load_dotenv
import os
import sys
import json
load_dotenv()

import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning)

experiment_path = '..\..'
path = os.path.abspath('')
module_path = os.path.join(path, experiment_path)
if module_path not in sys.path:
    sys.path.append(module_path+"\\functions")

from sqldatabase_langchain_utils import SQLDatabaseLangchainUtils


## Schema

In [32]:
SCHEMA = 'mondial_gpt'
PREFIX = 'mondial'

FILE_NAME_RESULT = f"sql_queries_chatgpt_few_shot_{SCHEMA}.json"

def save_queries(queries):
    data = {"queries":queries}
    with open(FILE_NAME_RESULT, "w") as arquivo_json:
        json.dump(data, arquivo_json, indent=4) 

def read_queries():
    with open(FILE_NAME_RESULT, encoding='utf-8', errors='ignore') as json_data:
        data = json.load(json_data, strict=False)
    queries = data["queries"]
    return queries

## Connection

In [33]:
json_file_path = f"../../datasets/{SCHEMA}_db_connection.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    db_connection = json.load(json_data, strict=False)

## Database Information

In [34]:
from sqlalchemy.types import UserDefinedType

class GeoCoordType(UserDefinedType):
    def get_col_spec(self):
        return "GEOCOORD"

    def bind_processor(self, dialect):
        def process(value):
            return value
        return process

    def result_processor(self, dialect, coltype):
        def process(value):
            return value
        return process

from sqlalchemy.dialects.oracle import base as oracle_base
oracle_base.ischema_names["GEOCOORD"] = GeoCoordType

In [35]:
db = SQLDatabaseLangchainUtils(db_connection=db_connection)

exclude = [
    f"{SCHEMA}_tmdp",
    f"{SCHEMA}_tmdpmap",
    f"{SCHEMA}_tmds",
    f"{SCHEMA}_tmjmap",
    f"{SCHEMA}_tpv",
    f"{SCHEMA}_tmdc",
    f"{SCHEMA}_tmdcmap",
    f"{SCHEMA}_tmdej",
    f"{SCHEMA}_log_action",
    f"{SCHEMA}_log_error",
    f"{SCHEMA}_favorite_item", 
    f"{SCHEMA}_favorite_query",
    f"{SCHEMA}_favorite_tag",
    f"{SCHEMA}_favorite_tag_item",
    f"{SCHEMA}_favorite_visualization",
    f"{SCHEMA}_dashboard",
    f"{SCHEMA}_history",
    "teste_cliente",
    "teste_fornecedor",
    "teste_funcionario"
]

include_tables = [s for s in db.get_table_names() if not s.startswith(PREFIX) and s not in exclude]
db = SQLDatabaseLangchainUtils(db_connection=db_connection, include_tables=include_tables)

## Importing Dataset

In [36]:
json_file_path = f"../../datasets/{PREFIX}/{PREFIX}_dataset.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    dataset = json.load(json_data, strict=False)
dataset = dataset['dataset']
filtered_dataset = [{"id": entry["id"], "question": entry["question"], "query": entry["query"]} for entry in dataset]
filtered_dataset

[{'id': '1',
  'question': 'What is the area of Thailand?',
  'query': "SELECT max(area) FROM mondial_country where name  = 'Thailand'"},
 {'id': '2',
  'question': 'What are the provinces with an area greater than 10000?',
  'query': 'SELECT name FROM mondial_province where area > 10000'},
 {'id': '3',
  'question': 'What are the languages spoken in Poland?',
  'query': "SELECT mondial_language.name FROM mondial_language INNER JOIN mondial_country ON mondial_language.country = mondial_country.code WHERE mondial_country.name = 'Poland'"},
 {'id': '4',
  'question': 'How deep is Lake Kariba?',
  'query': "SELECT depth FROM mondial_lake where name LIKE '%Lake Kariba%'"},
 {'id': '5',
  'question': 'What is the total of provinces of Netherlands?',
  'query': "SELECT count(*) FROM mondial_province p INNER JOIN mondial_country c ON p.country = c.code WHERE c.name = 'Netherlands'"},
 {'id': '6',
  'question': 'What is the percentage of religious people are hindu in thailand?',
  'query': "SE

In [37]:
base_pairs = filtered_dataset[:80]
test_pairs = filtered_dataset[80:]

## Functions for Similarity

In [38]:
from sentence_transformers import SentenceTransformer, util
import torch

embedder = SentenceTransformer('all-MiniLM-L6-v2')
base_nl_queries = [pair['question'] for pair in base_pairs]
base_embeddings = embedder.encode(base_nl_queries, convert_to_tensor=True)

def get_similar_examples(new_query, base_pairs, base_embeddings, top_k=3):
    new_embedding = embedder.encode(new_query, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(new_embedding, base_embeddings)[0]
    top_results = torch.topk(cosine_scores, k=top_k)
    similar_examples = [base_pairs[idx] for idx in top_results.indices]
    return similar_examples

new_query = "What is the total area of the provinces in Canada"
similar_examples = get_similar_examples(new_query, base_pairs, base_embeddings, top_k=3)
similar_examples


[{'id': '2',
  'question': 'What are the provinces with an area greater than 10000?',
  'query': 'SELECT name FROM mondial_province where area > 10000'},
 {'id': '5',
  'question': 'What is the total of provinces of Netherlands?',
  'query': "SELECT count(*) FROM mondial_province p INNER JOIN mondial_country c ON p.country = c.code WHERE c.name = 'Netherlands'"},
 {'id': '19',
  'question': 'List all provinces in Germany',
  'query': "SELECT p.name FROM mondial_province p INNER JOIN mondial_country c ON p.country = c.code  WHERE c.name = 'Germany'"}]

In [39]:
def construct_prompt(similar_examples):
    prompt = "\n\nBelow are some examples of natural language queries and their corresponding SQL queries:\n\n"
    for ex in similar_examples:
        prompt += f"NL: {ex['question']}\nSQL: {ex['query']}\n\n"
    return prompt

## Creating the prompt

In [40]:
from langchain.prompts.prompt import PromptTemplate

f = open(f"prompt.txt", "r")
prompt_template = f.read()
f.close()

prompt_template += "\n\nBelow are some examples of natural language queries and their corresponding SQL queries:\n\n{few-shot}"

PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "top_k", "few-shot"], template=prompt_template
)

print(PROMPT)

input_variables=['few-shot', 'input', 'table_info', 'top_k'] input_types={} partial_variables={} template='You are an Oracle SQL expert. Given an input question, first create a syntactically correct Oracle SQL query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, don\'t query for at {top_k} most results or any using the FETCH FIRST n ROWS ONLY clause as per Oracle SQL. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use TRUNC(SYSDATE) function to get the current date, if the question involves "today". \n\nSome hints:

## SQLQueryChain

In [41]:
query_chain  = create_sql_query_chain(ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo-16k'), db.db, prompt=PROMPT)

### Example 1

In [42]:
with get_openai_callback() as cb:
    sql_query = query_chain.invoke({"question":"What is the total area of the provinces in Canada", "few-shot": similar_examples})
    
    print(cb.total_tokens)
    print(cb.prompt_tokens)
    print(cb.completion_tokens)
    print(cb.total_cost)
sql_query

6164
6149
15
0.018507000000000003


"SQLQuery: SELECT SUM(area) FROM province WHERE country = 'CAN'"

In [43]:
json_file_path = f"../../datasets/{PREFIX}/queries_{PREFIX}.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    queries = json.load(json_data, strict=False)
queries = queries['queries']
queries = queries[80:]
queries

[{'id': '81',
  'question': 'What is the total area of the provinces in Canada?',
  'query_string': '',
  'type': 'complex'},
 {'id': '82',
  'question': 'What is the length of the Tigris River?',
  'query_string': '',
  'type': 'simple'},
 {'id': '83',
  'question': 'List the Airports with elevation more than 1000',
  'query_string': '',
  'type': 'medium'},
 {'id': '84',
  'question': 'List airports in the United States with an elevation below 200 feet',
  'query_string': '',
  'type': 'complex'},
 {'id': '85',
  'question': 'What are the mountains with altitudes above 8000 meters?',
  'query_string': '',
  'type': 'medium'},
 {'id': '86',
  'question': 'What are the 3 airports with the largest name?',
  'query_string': '',
  'type': 'simple'},
 {'id': '87',
  'question': 'List lakes with an area of less than 5000 square kilometers.',
  'query_string': '',
  'type': 'medium'},
 {'id': '88',
  'question': 'What is the total area of the seas with a depth greater than 1000 meters?',
  '

In [None]:
import warnings
from sqlalchemy import exc

# Suppress SAWarning
warnings.filterwarnings("ignore", category=exc.SAWarning, message=".*Cannot correctly sort tables.*")

number_of_queries_to_delay = 25
count = 0
for instance in queries:
    if count == number_of_queries_to_delay:
        count = 0
        time.sleep(10)
    with get_openai_callback() as cb:
        start_time = time.time()
        similar_examples = get_similar_examples(instance["question"], base_pairs, base_embeddings, top_k=3)
        print(similar_examples)
        sql_query = query_chain.invoke({"question":instance["question"], "few-shot": similar_examples})
        end_time = time.time()
        instance["query_string"] = sql_query
        instance['total_tokens'] = cb.total_tokens
        instance['prompt_tokens'] = cb.prompt_tokens
        instance['completion_tokens'] = cb.completion_tokens
        instance['total_cost'] = cb.total_cost
        instance['time'] = end_time - start_time
        print(instance['id'], instance['question'], instance["query_string"], instance['time'], instance['total_cost'])
    save_queries(queries)
    count += 1
queries

81 What is the total area of the provinces in Canada? SQLQuery: SELECT SUM(area) FROM province WHERE country = 'CAN' 0.8570704460144043 0.018555000000000002
82 What is the length of the Tigris River? SQLQuery: SELECT length FROM river WHERE name = 'Tigris' 0.8042933940887451 0.018543
83 List the Airports with elevation more than 1000 SQLQuery: SELECT name FROM airport WHERE elevation > 1000 0.5870575904846191 0.018607
84 List airports in the United States with an elevation below 200 feet SQLQuery: 
SELECT name 
FROM airport 
WHERE country = 'USA' AND elevation < 200 0.8549323081970215 0.018649000000000002
85 What are the mountains with altitudes above 8000 meters? SQLQuery: SELECT name FROM mountain WHERE elevation > 8000 0.5952260494232178 0.018649
86 What are the 3 airports with the largest name? SQLQuery: SELECT name FROM airport ORDER BY LENGTH(name) DESC FETCH FIRST 3 ROWS ONLY; 0.7255043983459473 0.018711
87 List lakes with an area of less than 5000 square kilometers. SQLQuery: S

[{'id': '81',
  'question': 'What is the total area of the provinces in Canada?',
  'query_string': "SQLQuery: SELECT SUM(area) FROM province WHERE country = 'CAN'",
  'type': 'complex',
  'total_tokens': 6180,
  'prompt_tokens': 6165,
  'completion_tokens': 15,
  'total_cost': 0.018555000000000002,
  'time': 0.8570704460144043},
 {'id': '82',
  'question': 'What is the length of the Tigris River?',
  'query_string': "SQLQuery: SELECT length FROM river WHERE name = 'Tigris'",
  'type': 'simple',
  'total_tokens': 6176,
  'prompt_tokens': 6161,
  'completion_tokens': 15,
  'total_cost': 0.018543,
  'time': 0.8042933940887451},
 {'id': '83',
  'question': 'List the Airports with elevation more than 1000',
  'query_string': 'SQLQuery: SELECT name FROM airport WHERE elevation > 1000',
  'type': 'medium',
  'total_tokens': 6198,
  'prompt_tokens': 6185,
  'completion_tokens': 13,
  'total_cost': 0.018607,
  'time': 0.5870575904846191},
 {'id': '84',
  'question': 'List airports in the Unite

In [46]:
with open(FILE_NAME_RESULT, encoding='utf-8', errors='ignore') as json_data:
    queries = json.load(json_data, strict=False)
queries = queries['queries']
for instance in queries:
    instance['query_string'] = instance['query_string'].lstrip("SQLQuery: ").strip()
save_queries(queries)

In [47]:
with open(FILE_NAME_RESULT, encoding='utf-8', errors='ignore') as json_data:
    queries = json.load(json_data, strict=False)
queries = queries['queries']
for instance in queries:
    if instance['query_string'][0] != "S":
        instance['query_string'] = "S"+ instance['query_string']
save_queries(queries)