# Sampling Similar NL-to-SQL Pairs for Enhanced Prompting

In [14]:
from langchain.sql_database import SQLDatabase

from langchain_openai import ChatOpenAI
from langchain.chains import create_sql_query_chain
from urllib.parse import quote  
from langchain.callbacks import get_openai_callback
import time
from dotenv import load_dotenv
import os
import sys
import json
load_dotenv()

import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning)

# Make the repository's shared `functions` directory importable.
# The relative path is built with os.path.join/os.pardir instead of a literal
# '..\..' so there is no invalid "\." escape sequence and the notebook also
# works on platforms whose path separator is not a backslash.
experiment_path = os.path.join(os.pardir, os.pardir)
path = os.path.abspath('')
module_path = os.path.join(path, experiment_path)
functions_path = os.path.join(module_path, "functions")
# Test for the exact entry that gets appended; otherwise re-running the cell
# keeps adding duplicates to sys.path.
if functions_path not in sys.path:
    sys.path.append(functions_path)

from sqldatabase_langchain_utils import SQLDatabaseLangchainUtils


## Schema

In [15]:
# Name used to locate the connection file (`{SCHEMA}_db_connection.json`) and
# to build the names of the tables excluded from the prompt.
SCHEMA = 'mondial_gpt'
# Dataset folder / file prefix; tables whose name starts with it are also
# filtered out when the database wrapper is built.
PREFIX = 'mondial'

# JSON file (relative to the working directory) that receives the generated queries.
FILE_NAME_RESULT = f"sql_queries_chatgpt_few_shot_50_{SCHEMA}.json"

def save_queries(queries):
    """Write `queries` to FILE_NAME_RESULT as JSON.

    The list is stored under the top-level key "queries" and pretty-printed
    with a four-space indent. An existing file is overwritten.
    """
    with open(FILE_NAME_RESULT, "w") as out_file:
        json.dump({"queries": queries}, out_file, indent=4)

def read_queries():
    """Return the list stored under "queries" in FILE_NAME_RESULT.

    The file is decoded as UTF-8 with undecodable bytes dropped, and parsed
    with strict=False so control characters inside JSON strings are accepted.
    """
    with open(FILE_NAME_RESULT, encoding='utf-8', errors='ignore') as source:
        return json.load(source, strict=False)["queries"]

## Connection

In [16]:
# Connection settings are read from a JSON file named after SCHEMA.
json_file_path = f"../../datasets/{SCHEMA}_db_connection.json"
# errors='ignore' drops undecodable bytes; strict=False lets json accept
# control characters inside strings.
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    db_connection = json.load(json_data, strict=False)

## Database Information

In [17]:
from sqlalchemy.types import UserDefinedType

class GeoCoordType(UserDefinedType):
    """Pass-through SQLAlchemy type whose column specification is ``GEOCOORD``.

    Values are handed to and received from the driver unchanged; the class
    only gives SQLAlchemy a type object to associate with such columns.
    """

    # The type takes no arguments and keeps no state, so it can safely take
    # part in SQLAlchemy's statement cache key. Without this flag
    # SQLAlchemy 1.4+ emits a warning and disables caching for statements
    # that involve the type.
    cache_ok = True

    def get_col_spec(self):
        """Return the type name used in DDL for this column type."""
        return "GEOCOORD"

    def bind_processor(self, dialect):
        """Return a converter for bound parameters; values pass through as-is."""
        def process(value):
            return value
        return process

    def result_processor(self, dialect, coltype):
        """Return a converter for result values; values pass through as-is."""
        def process(value):
            return value
        return process

# Register GeoCoordType under the name "GEOCOORD" in the Oracle dialect's
# ischema_names mapping.
# NOTE(review): presumably this lets table reflection resolve GEOCOORD columns
# instead of reporting an unrecognized type — confirm against the SQLAlchemy
# version in use.
from sqlalchemy.dialects.oracle import base as oracle_base
oracle_base.ischema_names["GEOCOORD"] = GeoCoordType

In [18]:
# First pass: connect without a table filter so every table name can be listed.
db = SQLDatabaseLangchainUtils(db_connection=db_connection)

# Schema-prefixed tables that are left out of the table list given to the chain.
_schema_table_suffixes = (
    "tmdp",
    "tmdpmap",
    "tmds",
    "tmjmap",
    "tpv",
    "tmdc",
    "tmdcmap",
    "tmdej",
    "log_action",
    "log_error",
    "favorite_item",
    "favorite_query",
    "favorite_tag",
    "favorite_tag_item",
    "favorite_visualization",
    "dashboard",
    "history",
)
exclude = [f"{SCHEMA}_{suffix}" for suffix in _schema_table_suffixes]
exclude += ["teste_cliente", "teste_fornecedor", "teste_funcionario"]

# Keep a table only if its name neither starts with PREFIX nor appears in `exclude`.
include_tables = []
for table_name in db.get_table_names():
    if table_name.startswith(PREFIX) or table_name in exclude:
        continue
    include_tables.append(table_name)

# Second pass: rebuild the wrapper restricted to the selected tables.
db = SQLDatabaseLangchainUtils(db_connection=db_connection, include_tables=include_tables)

## Importing Dataset

In [19]:
# Load the question/query dataset and keep only the fields used for few-shot prompting.
json_file_path = f"../../datasets/{PREFIX}/{PREFIX}_dataset.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    dataset = json.load(json_data, strict=False)['dataset']

_kept_fields = ("id", "question", "query")
filtered_dataset = [
    {field: entry[field] for field in _kept_fields}
    for entry in dataset
]
filtered_dataset

[{'id': '1',
  'question': 'What is the area of Thailand?',
  'query': "SELECT max(area) FROM mondial_country where name  = 'Thailand'"},
 {'id': '2',
  'question': 'What are the provinces with an area greater than 10000?',
  'query': 'SELECT name FROM mondial_province where area > 10000'},
 {'id': '3',
  'question': 'What are the languages spoken in Poland?',
  'query': "SELECT mondial_language.name FROM mondial_language INNER JOIN mondial_country ON mondial_language.country = mondial_country.code WHERE mondial_country.name = 'Poland'"},
 {'id': '4',
  'question': 'How deep is Lake Kariba?',
  'query': "SELECT depth FROM mondial_lake where name LIKE '%Lake Kariba%'"},
 {'id': '5',
  'question': 'What is the total of provinces of Netherlands?',
  'query': "SELECT count(*) FROM mondial_province p INNER JOIN mondial_country c ON p.country = c.code WHERE c.name = 'Netherlands'"},
 {'id': '6',
  'question': 'What is the percentage of religious people are hindu in thailand?',
  'query': "SE

In [20]:
# The first 50 pairs form the pool that few-shot examples are retrieved from;
# the remaining pairs are set aside as test_pairs.
base_pairs, test_pairs = filtered_dataset[:50], filtered_dataset[50:]

## Functions for Similarity

In [21]:
from sentence_transformers import SentenceTransformer, util
import torch

# Sentence-embedding model used to compare questions with each other.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
# Questions of the example pool, in the same order as base_pairs so that an
# index into the embeddings is also an index into base_pairs.
base_nl_queries = [pair['question'] for pair in base_pairs]
# Pool embeddings are computed once here and reused for every lookup.
base_embeddings = embedder.encode(base_nl_queries, convert_to_tensor=True)

def get_similar_examples(new_query, base_pairs, base_embeddings, top_k=5):
    """Return the pool entries whose questions are most similar to `new_query`.

    Args:
        new_query: Natural-language question to find examples for.
        base_pairs: Sequence of example entries; entry ``i`` must correspond
            to row ``i`` of `base_embeddings`.
        base_embeddings: Tensor with the embeddings of the pool questions.
        top_k: Maximum number of examples to return.

    Returns:
        Up to `top_k` entries of `base_pairs`, ordered from most to least
        similar by cosine similarity. Fewer entries are returned when the
        pool is smaller than `top_k`; an empty list when the pool is empty
        or `top_k` is not positive.

    Uses the module-level `embedder` to encode `new_query`.
    """
    # torch.topk raises when k exceeds the number of scores, so never ask for
    # more results than the pool holds.
    k = min(top_k, len(base_pairs))
    if k <= 0:
        return []
    new_embedding = embedder.encode(new_query, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(new_embedding, base_embeddings)[0]
    top_results = torch.topk(cosine_scores, k=k)
    # Convert to plain ints before indexing the Python list.
    return [base_pairs[index] for index in top_results.indices.tolist()]

# Quick check of the retrieval: fetch the five pool entries closest to a sample question.
new_query = "What is the total area of the provinces in Canada"
similar_examples = get_similar_examples(new_query, base_pairs, base_embeddings, top_k=5)
similar_examples


[{'id': '2',
  'question': 'What are the provinces with an area greater than 10000?',
  'query': 'SELECT name FROM mondial_province where area > 10000'},
 {'id': '5',
  'question': 'What is the total of provinces of Netherlands?',
  'query': "SELECT count(*) FROM mondial_province p INNER JOIN mondial_country c ON p.country = c.code WHERE c.name = 'Netherlands'"},
 {'id': '19',
  'question': 'List all provinces in Germany',
  'query': "SELECT p.name FROM mondial_province p INNER JOIN mondial_country c ON p.country = c.code  WHERE c.name = 'Germany'"},
 {'id': '7',
  'question': 'List the number of provinces each river flows through.',
  'query': 'SELECT river.name, COUNT(DISTINCT geo_river.province) FROM river JOIN geo_river ON river.name = geo_river.river GROUP BY river.name'},
 {'id': '11',
  'question': 'How many provinces have areas greater than 1000 in Niger?',
  'query': "SELECT count(p.name) FROM mondial_province p INNER JOIN mondial_country c on p.country = c.code where c.name =

## Creating the prompt

In [22]:
from langchain.prompts.prompt import PromptTemplate

# Read the base prompt text. The context manager closes the file even if
# read() raises, which the manual open()/close() pair did not guarantee.
with open("prompt.txt", "r") as prompt_file:
    prompt_template = prompt_file.read()

# Append the section that receives the retrieved examples through the
# "few-shot" template variable.
prompt_template += "\n\nBelow are some examples of natural language queries and their corresponding SQL queries:\n\n{few-shot}"

PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "top_k", "few-shot"], template=prompt_template
)

print(PROMPT)

input_variables=['few-shot', 'input', 'table_info', 'top_k'] input_types={} partial_variables={} template='You are an Oracle SQL expert. Given an input question, first create a syntactically correct Oracle SQL query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, don\'t query for at {top_k} most results or any using the FETCH FIRST n ROWS ONLY clause as per Oracle SQL. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use TRUNC(SYSDATE) function to get the current date, if the question involves "today". \n\nSome hints:

## SQLQueryChain

In [23]:
# Chat model (temperature=0) that the chain uses to turn a question into SQL.
llm = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo-16k')
# Chain built from the model, the wrapped database object and the few-shot prompt.
query_chain = create_sql_query_chain(llm, db.db, prompt=PROMPT)

### Example 1

In [24]:
# Run the chain once on the sample question and print the usage counters
# collected by the OpenAI callback, one value per line.
with get_openai_callback() as cb:
    sql_query = query_chain.invoke({"question": "What is the total area of the provinces in Canada", "few-shot": similar_examples})

    for usage_field in ("total_tokens", "prompt_tokens", "completion_tokens", "total_cost"):
        print(getattr(cb, usage_field))
sql_query

6287
6272
15
0.018876


"SQLQuery: SELECT SUM(area) FROM province WHERE country = 'CAN'"

In [25]:
# Load the question list and drop its first 50 entries, keeping everything
# from index 50 onwards as the questions to generate SQL for.
json_file_path = f"../../datasets/{PREFIX}/queries_{PREFIX}.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    queries = json.load(json_data, strict=False)['queries'][50:]
queries

[{'id': '51',
  'question': 'List the name of capital cities for which we do not have data about the city in located',
  'query_string': '',
  'type': 'complex'},
 {'id': '52',
  'question': 'What is the capital of the provice Andalucía?',
  'query_string': '',
  'type': 'simple'},
 {'id': '53',
  'question': 'What is the abbreviation for the organization whose name is "General Confederation of Trade Unions"?',
  'query_string': '',
  'type': 'simple'},
 {'id': '54',
  'question': 'In which city is the organization with the abbreviation "UPU" ?',
  'query_string': '',
  'type': 'simple'},
 {'id': '55',
  'question': 'List the names of countries which are members of only one organization',
  'query_string': '',
  'type': 'complex'},
 {'id': '56',
  'question': 'List the name of countries which are not a member of NATO.',
  'query_string': '',
  'type': 'complex'},
 {'id': '57',
  'question': 'List the name of countries which are a member of NATO.',
  'query_string': '',
  'type': 'mediu

In [26]:
import warnings
from sqlalchemy import exc

# Silence SQLAlchemy's "Cannot correctly sort tables" SAWarning.
warnings.filterwarnings("ignore", category=exc.SAWarning, message=".*Cannot correctly sort tables.*")

# Throttling: after every `number_of_queries_to_delay` requests, wait 10 seconds.
number_of_queries_to_delay = 25
count = 0
for instance in queries:
    if count == number_of_queries_to_delay:
        count = 0
        time.sleep(10)
    with get_openai_callback() as cb:
        # The measured time covers example retrieval plus the chain invocation.
        start_time = time.time()
        similar_examples = get_similar_examples(instance["question"], base_pairs, base_embeddings, top_k=3)
        print(similar_examples)
        sql_query = query_chain.invoke({"question": instance["question"], "few-shot": similar_examples})
        end_time = time.time()
        # Record the generated query together with usage and timing figures.
        instance.update({
            "query_string": sql_query,
            "total_tokens": cb.total_tokens,
            "prompt_tokens": cb.prompt_tokens,
            "completion_tokens": cb.completion_tokens,
            "total_cost": cb.total_cost,
            "time": end_time - start_time,
        })
        print(instance['id'], instance['question'], instance["query_string"], instance['time'], instance['total_cost'])
    # Persist after every question so an interrupted run keeps its progress.
    save_queries(queries)
    count += 1
queries

[{'id': '47', 'question': 'List all the capitals of European countries.', 'query': "SELECT c.capital FROM mondial_country c\nINNER JOIN mondial_encompasses e ON e.country = c.code\nINNER JOIN mondial_continent co ON co.name = e.continent\nWHERE co.name = 'Europe'"}, {'id': '16', 'question': 'Show the cities with longitude between 0 and 6.', 'query': 'SELECT name FROM mondial_city WHERE longitude BETWEEN 0 AND 6'}, {'id': '21', 'question': 'How many cities have populations less than 1000?', 'query': 'SELECT count(*) FROM mondial_city WHERE population < 1000'}]
51 List the name of capital cities for which we do not have data about the city in located SQLQuery: 
SELECT c.capital 
FROM country c
LEFT JOIN located l ON c.capital = l.city
WHERE l.city IS NULL 1.664595127105713 0.018622
[{'id': '10', 'question': 'What is the capital of Georgia?', 'query': "SELECT capital FROM mondial_country WHERE name = 'Georgia'"}, {'id': '47', 'question': 'List all the capitals of European countries.', 'qu

[{'id': '51',
  'question': 'List the name of capital cities for which we do not have data about the city in located',
  'query_string': 'SQLQuery: \nSELECT c.capital \nFROM country c\nLEFT JOIN located l ON c.capital = l.city\nWHERE l.city IS NULL',
  'type': 'complex',
  'total_tokens': 6197,
  'prompt_tokens': 6166,
  'completion_tokens': 31,
  'total_cost': 0.018622,
  'time': 1.664595127105713},
 {'id': '52',
  'question': 'What is the capital of the provice Andalucía?',
  'query_string': "SQLQuery: SELECT capital FROM province WHERE name = 'Andalucía'",
  'type': 'simple',
  'total_tokens': 6178,
  'prompt_tokens': 6162,
  'completion_tokens': 16,
  'total_cost': 0.01855,
  'time': 0.6058299541473389},
 {'id': '53',
  'question': 'What is the abbreviation for the organization whose name is "General Confederation of Trade Unions"?',
  'query_string': "SQLQuery: SELECT abbreviation FROM organization WHERE name = 'General Confederation of Trade Unions'",
  'type': 'simple',
  'total

In [27]:
# Remove the leading "SQLQuery:" label from every saved query.
queries = read_queries()
for instance in queries:
    query_text = instance['query_string'].strip()
    # Strip the label as a prefix. str.lstrip("SQLQuery: \n") treats its
    # argument as a set of characters, so it also removed the leading "S" of
    # "SELECT" and left "ELECT ...".
    if query_text.startswith("SQLQuery:"):
        query_text = query_text[len("SQLQuery:"):]
    instance['query_string'] = query_text.strip()
save_queries(queries)

In [28]:
# Repair queries whose leading "S" was lost while the "SQLQuery:" label was stripped.
queries = read_queries()
for instance in queries:
    query_text = instance['query_string']
    # Only touch the damaged form ("ELECT ..."). Prepending "S" to everything
    # that does not start with "S" would corrupt valid queries such as
    # "WITH ...", "(SELECT ..." or lowercase "select ...", and indexing [0]
    # would fail on an empty string.
    if query_text.startswith("ELECT"):
        instance['query_string'] = "S" + query_text
save_queries(queries)

In [30]:
# Regenerate the SQL for selected question ids and store the new results.
to_fix = ['89', '98']
# Start from a fresh counter so throttling does not depend on whatever value
# an earlier cell happened to leave behind.
count = 0
for instance in queries:
    if instance["id"] not in to_fix:
        continue
    if count == number_of_queries_to_delay:
        count = 0
        time.sleep(10)
    with get_openai_callback() as cb:
        start_time = time.time()
        similar_examples = get_similar_examples(instance["question"], base_pairs, base_embeddings, top_k=3)
        print(similar_examples)
        sql_query = query_chain.invoke({"question":instance["question"], "few-shot": similar_examples})
        end_time = time.time()
        # The other entries in `queries` already had their "SQLQuery:" label
        # removed; strip it here too so the saved file stays in one format.
        sql_query = sql_query.strip()
        if sql_query.startswith("SQLQuery:"):
            sql_query = sql_query[len("SQLQuery:"):].strip()
        instance["query_string"] = sql_query
        instance['total_tokens'] = cb.total_tokens
        instance['prompt_tokens'] = cb.prompt_tokens
        instance['completion_tokens'] = cb.completion_tokens
        instance['total_cost'] = cb.total_cost
        instance['time'] = end_time - start_time
        print(instance['id'], instance['question'], instance["query_string"], instance['time'], instance['total_cost'])
    save_queries(queries)
    count += 1
queries

[{'id': '17', 'question': 'Select cities whose population is greater than 100000, altitude greater than 2500, and the country you belong to has population growth greater than 1.', 'query': 'SELECT ci.name FROM mondial_city ci\nINNER JOIN mondial_country c ON c.code = ci.country\nINNER JOIN mondial_population p ON p.country = c.code \nWHERE ci.population > 100000 AND ci.elevation > 2500 AND \np.population_growth > 1'}, {'id': '25', 'question': 'For all countries, give the sum of the population of all its neighbors', 'query': 'SELECT c.name AS country_name, SUM(c1.population) AS total_neighbor_population\nFROM mondial_Country c\nJOIN mondial_Borders b ON c.code = b.country1\nJOIN mondial_Country c1 ON b.country2 = c1.code\nGROUP BY c.name;'}, {'id': '21', 'question': 'How many cities have populations less than 1000?', 'query': 'SELECT count(*) FROM mondial_city WHERE population < 1000'}]
89 Find the countries whose name starts with the letter "B" and have a population greater than 10 mil

[{'id': '51',
  'question': 'List the name of capital cities for which we do not have data about the city in located',
  'query_string': 'SELECT c.capital \nFROM country c\nLEFT JOIN located l ON c.capital = l.city\nWHERE l.city IS NULL',
  'type': 'complex',
  'total_tokens': 6197,
  'prompt_tokens': 6166,
  'completion_tokens': 31,
  'total_cost': 0.018622,
  'time': 1.664595127105713},
 {'id': '52',
  'question': 'What is the capital of the provice Andalucía?',
  'query_string': "SELECT capital FROM province WHERE name = 'Andalucía'",
  'type': 'simple',
  'total_tokens': 6178,
  'prompt_tokens': 6162,
  'completion_tokens': 16,
  'total_cost': 0.01855,
  'time': 0.6058299541473389},
 {'id': '53',
  'question': 'What is the abbreviation for the organization whose name is "General Confederation of Trade Unions"?',
  'query_string': "SELECT abbreviation FROM organization WHERE name = 'General Confederation of Trade Unions'",
  'type': 'simple',
  'total_tokens': 6215,
  'prompt_tokens