In [None]:
ROOT_PATH = 'your/root/path'

In [2]:
def _load_password_and_api_key(key_file_path):
    """
    """
    with open(key_file_path, 'r') as f:
        api_key = f.read().strip()
    return api_key

In [3]:
import certifi
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi

username = 's322796'
password = _load_password_and_api_key(ROOT_PATH + 'Data/Auth/mongodb.atlas.clusters/cluster0.key')

uri = f"mongodb+srv://{username}:{password}@cluster0.ixr2wwl.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"

# Create a new client and connect to the server
client = MongoClient(uri, server_api=ServerApi('1'), tlsCAFile=certifi.where())

# Send a ping to confirm a successful connection
try:
    client.admin.command('ping')
    print("Pinged your deployment. You successfully connected to MongoDB!")
except Exception as e:
    print(e)

Pinged your deployment. You successfully connected to MongoDB!


In [4]:
import os
import logging
from tqdm import tqdm
from typing import Dict, List
import json
import random
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob

# API keys for models
# anthropic_api_key = load_api_key(ROOT_PATH + 'Data/anthropic.api.key/text2sql.key')
# together_api_key = load_api_key(ROOT_PATH + 'Data/together.ai.api.key/API.key')

In [5]:
# Split data and load training data into RAG
data_dirs = [ROOT_PATH + 'DataSampling/data/enriched_dataset/entriched_full_dataset_1/']

def load_data(data_dirs: List[str]) -> pd.DataFrame:

    data = {}
    for data_dir in data_dirs:
        data_list = glob.glob(data_dir + '**/instance_*.json', recursive=True)

        for data_file in data_list:
            
            with open(data_file, 'r') as f:
                data_json = json.load(f)
            
                if data_json['id'] not in data:
                    data[data_json['id']] = []
                data[data_json['id']].append(data_json)
    
    data_list = []
    for ket,values in data.items():

        base_instance = values[0]
        new_instance = base_instance.copy()
        new_instance['inference_results'] = []

        for instance in values:
            new_instance['inference_results'].append(instance['inference_results'])
        
        data_list.append(new_instance)

    df = pd.DataFrame(data_list)
    # df = df.set_index('id')
    return df

In [6]:
df = load_data(data_dirs)

df.head()

Unnamed: 0,id,dataset,database,schemas,question,sql,evidence,difficulty,question_analysis,sql_analysis,inference_results
0,976,spider,"{'name': 'dog_kennels', 'path': ['spider_strat...","{'name': 'dog_kennels', 'path': ['spider_strat...",How much does the most recent treatment cost?,SELECT cost_of_treatment FROM Treatments ORDER...,,simple,"{'char_length': 45, 'word_length': 8, 'entitie...","{'char_length': 80, 'tables_count': 1, 'tables...","[{'has_prediction': True, 'model': {'model_nam..."
1,833,spider,"{'name': 'orchestra', 'path': ['spider_stratif...","{'name': 'orchestra', 'path': ['spider_stratif...",Return the maximum and minimum shares for perf...,"SELECT max(SHARE) , min(SHARE) FROM performan...",,simple,"{'char_length': 94, 'word_length': 16, 'entiti...","{'char_length': 75, 'tables_count': 1, 'tables...","[{'has_prediction': True, 'model': {'model_nam..."
2,130,spider,"{'name': 'car_1', 'path': ['spider_stratified_...","{'name': 'car_1', 'path': ['spider_stratified_...",What are the names of all European countries w...,SELECT T1.CountryName FROM COUNTRIES AS T1 JOI...,,simple,"{'char_length': 75, 'word_length': 13, 'entiti...","{'char_length': 227, 'tables_count': 3, 'table...","[{'has_prediction': True, 'model': {'model_nam..."
3,649,spider,"{'name': 'poker_player', 'path': ['spider_stra...","{'name': 'poker_player', 'path': ['spider_stra...",List the earnings of poker players in descendi...,SELECT Earnings FROM poker_player ORDER BY Ear...,,simple,"{'char_length': 55, 'word_length': 9, 'entitie...","{'char_length': 56, 'tables_count': 1, 'tables...","[{'has_prediction': True, 'model': {'model_nam..."
4,188,bird,"{'name': 'financial', 'path': ['stratified_out...","{'name': 'financial', 'path': ['stratified_out...",Among the accounts who have loan validity more...,SELECT T1.account_id FROM loan AS T1 INNER JOI...,,moderate,"{'char_length': 164, 'word_length': 28, 'entit...","{'char_length': 185, 'tables_count': 2, 'table...","[{'has_prediction': True, 'model': {'model_nam..."


In [7]:
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

len(train_df), len(test_df)

(385, 97)

In [None]:
from vector_storage import Text2SQLVectorDB

vec_db = Text2SQLVectorDB(
    mongodb_uri=uri,
    db_name="text2sql_vectordb",
    collection_name="enriched_instances_1",
    root_data_dir= ROOT_PATH + 'DataSampling/data/datasets'
)

  from .autonotebook import tqdm as notebook_tqdm


Collection 'enriched_instances_1' created.
New search index named vector_index is building.
Polling to check if the index is ready. This may take up to 60 seconds.
Index 'vector_index' is ready for querying.


In [9]:
for index, row in tqdm(train_df.iterrows(), total=len(train_df)):
    instance = row.to_dict()
    try:
        vec_db.store_instance(instance)
    except Exception as e:
        print(f"Error storing instance {instance['id']}: {e}")
        continue

100%|██████████| 385/385 [00:29<00:00, 13.10it/s]


In [18]:
sample_index = 10

sample_data = test_df.iloc[sample_index].to_dict()

sample_input_text = vec_db._create_instance_text(sample_data)

print("The sample instance id : ",sample_data['id'])

print("The sample dataset name : ",sample_data['dataset'])

print("The sample database name : ",sample_data['database']['name'])

print("The sample test difficulty level : ",sample_data['difficulty'])

print(f"\n {100 * '-'}\n The sample input prompt to be embedded : \n {100 * '-'}\n",sample_input_text)

The sample instance id :  1495
The sample dataset name :  bird
The sample database name :  debit_card_specializing
The sample test difficulty level :  simple

 ----------------------------------------------------------------------------------------------------
 The sample input prompt to be embedded : 
 ----------------------------------------------------------------------------------------------------
 Question: Which client ID consumed the most in September 2013?
Evidence: September 2013 refers to yearmonth.date = '201309'
Table: customers Description: nan DDL: CREATE TABLE customers
(
    CustomerID INTEGER UNIQUE     not null
        primary key,
    Segment    TEXT null,
    Currency   TEXT null
); Table: gasstations Description: nan DDL: CREATE TABLE gasstations
(
    GasStationID INTEGER    UNIQUE   not null
        primary key,
    ChainID      INTEGER          null,
    Country      TEXT null,
    Segment      TEXT null
); Table: products Description: nan DDL: CREATE TABLE pro

In [32]:
limit = 50

similar_examples = pd.DataFrame(vec_db.find_similar_instances(query_text=sample_input_text, limit=limit))

similar_examples

Unnamed: 0,original_instance,score
0,"{'id': 1474, 'dataset': 'bird', 'database': {'...",0.825282
1,"{'id': 1489, 'dataset': 'bird', 'database': {'...",0.79283
2,"{'id': 1479, 'dataset': 'bird', 'database': {'...",0.784701
3,"{'id': 1472, 'dataset': 'bird', 'database': {'...",0.777476
4,"{'id': 1527, 'dataset': 'bird', 'database': {'...",0.7771
5,"{'id': 1531, 'dataset': 'bird', 'database': {'...",0.738905
6,"{'id': 1513, 'dataset': 'bird', 'database': {'...",0.728154
7,"{'id': 1491, 'dataset': 'bird', 'database': {'...",0.720641
8,"{'id': 1516, 'dataset': 'bird', 'database': {'...",0.719535
9,"{'id': 1518, 'dataset': 'bird', 'database': {'...",0.690281


In [38]:
k = 15

score = similar_examples.iloc[k]['score']
the_similar = similar_examples.iloc[k]['original_instance']

print("The Similarity score : ",score)

print(f"The Similar {k} instance id : ",the_similar['id'])

print(f"The Similar {k} dataset name : ",the_similar['dataset'])

print(f"The Similar {k} database name : ",the_similar['database']['name'])

print(f"The Similar {k} test difficulty level : ",the_similar['difficulty'])

print(f"\n {100 * '-'}\n The Similar {k} input prompt to be embedded : \n {100 * '-'}\n",vec_db._create_instance_text(the_similar))


The Similarity score :  0.6543048024177551
The Similar 15 instance id :  1481
The Similar 15 dataset name :  bird
The Similar 15 database name :  debit_card_specializing
The Similar 15 test difficulty level :  challenging

 ----------------------------------------------------------------------------------------------------
 The Similar 15 input prompt to be embedded : 
 ----------------------------------------------------------------------------------------------------
 Question: What is the difference in the annual average consumption of the customers with the least amount of consumption paid in CZK for 2013 between SME and LAM, LAM and KAM, and KAM and SME?
Evidence: annual average consumption of customer with the lowest consumption in each segment = total consumption per year / the number of customer with lowest consumption in each segment; Difference in annual average = SME's annual average - LAM's annual average; Difference in annual average = LAM's annual average - KAM's annual a