In [53]:
import json
import os

In [54]:
def create_table(db, start, end, id):
  table_name = db['table_names_original'][id]
  table_description = db['table_names'][id]
  all_columns = db['column_names_original']
  all_columns_description = db['column_names']
  database_name = db['db_id']
  table_columns = list(map(lambda column: column[1], all_columns[start:end+1]))
  table_column_descriptions = list(map(lambda column: column[1], all_columns_description[start:end+1]))

  table_column_types = db['column_types'][start:end+1]
  # print(db['primary_keys'][id], db['db_id'], id)
  primary_key = all_columns[int(db['primary_keys'][id])][1] if len(db['primary_keys']) > id else None
  return Table(id, database_name, table_name, table_description, table_columns, \
               table_column_descriptions, table_column_types, primary_key)


def build_database(db):
  columns = db['column_names_original']
  tables = []
  id,start = 0,1
  for index, column in enumerate(columns[1:]):
    #accumlate all the columns related to a particular table
    if column[0] != id:
      table = create_table(db, start, index, id)
      tables.append(table)
      start = index+1
      id += 1

  tables.append(create_table(db, start, len(columns)-1, id))
  return Database(db['db_id'], tables, columns, db['foreign_keys'])

In [55]:
class Database:
  def __init__(self, name, tables, columns, foreign_keys):
    self.name = name
    self.tables = tables
    self.columns = columns
    self.foreign_keys = foreign_keys

  def __str__(self):
    """Returns a string representation of the Database object."""
    return f"Database(name='{self.name}', tables={self.tables}), columns={self.columns}), foreign_keys={self.foreign_keys})"


class Table:
  def __init__(self, id, database_name ,name, description, columns, columns_description,column_types, primary_key):
    self.id = id
    self.database_name = database_name
    self.name = name
    self.description = description
    self.columns = columns
    self.columns_description = columns_description
    self.column_types = column_types
    self.primary_key = primary_key

  def __str__(self):
    """Returns a string representation of the Table object."""
    return f"Table(id={self.id}, name='{self.name}', description='{self.description}', columns={self.columns}, columns_description={self.columns_description}, column_types={self.column_types}, primary_key='{self.primary_key}')"


In [56]:
len(databases)

166

In [63]:
from elasticsearch import Elasticsearch

es = Elasticsearch('https://localhost:9200', http_auth=('elastic', 'ZXJxoo-rx_YWk1kbMs*b'), verify_certs=False)

  es = Elasticsearch('https://localhost:9200', http_auth=('elastic', 'ZXJxoo-rx_YWk1kbMs*b'), verify_certs=False)


In [74]:
# Create an index (use a different index name if needed)
index_name = "database_catalogue_index"
es.indices.create(index=index_name)



ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'database_catalogue_index'})

In [59]:
class DatabaseCatalogueEncoder(json.JSONEncoder):
  def default(self, obj):
    if isinstance(obj, Table):
      # Override default behavior for Book objects
      return {
        "name" : obj.name,
        "description" : obj.description,
        # "columns" : obj.columns,
        "columns_description" : obj.columns_description,
        "db_name": obj.database_name,
        # "column_types" : obj.column_types,
        # "primary_key" : obj.primary_key
      }
    return super().default(obj)

In [156]:
tables_file_path = 'Documents/spider/tables.json'
index_name= "database"
SEPARATOR_STR = " "
def build_document_content(table):
    content = table.database_name + SEPARATOR_STR + table.description
    for column in table.columns_description:
        content += SEPARATOR_STR + column
    return content
# file = open('./catalogue.txt', 'w')   
with open(tables_file_path) as tables_file:
  db_list = json.load(tables_file)
  databases = []
  # with open("catalogue.txt", "w") as file:
  for ind,db in enumerate(db_list):
      database = build_database(db)
      # Index the document
      for index, table in enumerate(database.tables):
        document_body = {
            "content": build_document_content(table)
        }
          # if index == 0:
        # file.write(json.dumps(table, cls=DatabaseCatalogueEncoder))
        # file.write("\n")
                  # print(json.dumps(table, cls=DatabaseCatalogueEncoder)) json.dumps(table, cls=DatabaseCatalogueEncoder)
        es.index(index=index_name, id=database.name +'_'+ str(index), body=document_body)
      databases.append(database)
# file.close()
# print(databases[0])



In [142]:
class Prompt:

    def __init__(self):
        self.prompt_template = \
        f"""Given the following table schema:
{{}}

Write a SQL query for the english query:
Dont provide any explaination just give the response query as a string:
{{}}
"""

    def generate_prompt(self, table_schema, question):
        return self.prompt_template.format(table_schema, question)

    def generate_default_prompt(self, question):
        return f"""Write a SQL query for the english query:
Dont provide any explaination just give the response query as a string in a single line:
{question}"""
        

In [89]:
def build_table_schema(database, table, include_reference_tables):
    column_types = table.column_types
    query = f"CREATE TABLE {table.name} (\n"
    for index, column in enumerate(table.columns):
        primary_key = " PRIMARY KEY" if column == table.primary_key else ""
        query = query + f"{column} {column_types[index]}{primary_key},\n"
    all_columns = database.columns
    reference_tables = []
    for fk in database.foreign_keys:
        if all_columns[fk[0]][0] == table.id:
            query += f"FOREIGN KEY({all_columns[fk[0]][1]}) REFERENCES "
            reference_table_ind = all_columns[fk[1]][0]
            reference_tables.append(reference_table_ind)
            reference_table = database.tables[reference_table_ind]
            query += f"{reference_table.name}({all_columns[fk[1]][1]})\n"

    #end of the table
    query += ");\n"

    #add the reference tables to the query string 
    if include_reference_tables and reference_tables:
        for table_ind in reference_tables:
            query += "\n\n"
            query += build_table_schema(database,  database.tables[table_ind], False)
    return query

In [157]:

# print(search_results)
def generate_prompt(eng_query):  
    search_results = es.search(index="database", body={"query": {"match": {"content": eng_query}}})

    hits = search_results['hits']['hits']
    if len(hits) == 0:
        return Prompt().generate_default_prompt(eng_query)
    
    id = hits[0]['_id']

    words = id.split('_')
    db_id, table_id = "_".join(words[:-1]), int(words[-1])
    table_schema  = ""
    for database in databases:
        if database.name == db_id:
            table = database.tables[table_id]
            table_schema = build_table_schema(database, table, True)
            break
    
    basicPrompt = Prompt()
    prompt = basicPrompt.generate_prompt(table_schema, eng_query)
    print(prompt)
    return prompt

In [117]:
for hit in search_results['hits']['hits']:
    print(f"Title: {hit['_id']}")
    print("-" * 20)

Title: geo_3
--------------------
Title: party_host_2
--------------------
Title: college_3_7
--------------------
Title: gas_company_2
--------------------
Title: concert_singer_0
--------------------
Title: world_1_3
--------------------
Title: wedding_0
--------------------
Title: employee_hire_evaluation_2
--------------------
Title: company_employee_1
--------------------
Title: concert_singer_1
--------------------


geo_6


In [159]:
response = es.count(index="database")



In [161]:
response

ObjectApiResponse({'count': 876, '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0}})

In [162]:
es.indices.delete(index="database_index")



ObjectApiResponse({'acknowledged': True})

In [73]:
for index in es.indices.get(index='*'):
  print(index)



In [82]:
id = search_results['hits']['hits'][0]['_id']
words = id.split('_')
db_id, table_id = "_".join(words[:-1]), int(words[-1])
print(db_id, table_id)

geo 6


In [108]:
%pip install openai

Collecting openai
  Downloading openai-1.23.2-py3-none-any.whl.metadata (21 kB)
Collecting pydantic<3,>=1.9.0 (from openai)
  Downloading pydantic-2.7.0-py3-none-any.whl.metadata (103 kB)
     ---------------------------------------- 0.0/103.4 kB ? eta -:--:--
     ----------- --------------------------- 30.7/103.4 kB 1.3 MB/s eta 0:00:01
     -------------------------------------- 103.4/103.4 kB 1.5 MB/s eta 0:00:00
Collecting annotated-types>=0.4.0 (from pydantic<3,>=1.9.0->openai)
  Downloading annotated_types-0.6.0-py3-none-any.whl.metadata (12 kB)
Collecting pydantic-core==2.18.1 (from pydantic<3,>=1.9.0->openai)
Note: you may need to restart the kernel to use updated packages.  Downloading pydantic_core-2.18.1-cp312-none-win_amd64.whl.metadata (6.7 kB)

Downloading openai-1.23.2-py3-none-any.whl (311 kB)
   ---------------------------------------- 0.0/311.2 kB ? eta -:--:--
   -------------------------- ------------- 204.8/311.2 kB 4.1 MB/s eta 0:00:01
   ------------------------

In [147]:
from openai import OpenAI

client = OpenAI(
  base_url = "https://integrate.api.nvidia.com/v1",
  api_key = "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"
)

def generate_sql_query(eng_query):
    completion = client.chat.completions.create(
      model="mistralai/mixtral-8x22b-instruct-v0.1",
      messages=[{"role":"user","content": generate_prompt(eng_query)}],
      temperature=0.5,
      top_p=1,
      max_tokens=1024,
      stream=False
    )

    return completion.choices[0].message.content

In [166]:
with open('./Documents/spider/dev.json') as f:
    train_data = json.load(f)
    actual_query_file = open("actual_query.txt", "w")
    generated_query_file = open("generated_query.txt", "w")

    count = 0 
    for query in train_data:
        count += 1
        actual_sql_query = query['query']
        eng_query = query['question']
        generated_sql_query = generate_sql_query(eng_query)
        
        #truncate entire query into single line
        generated_sql_query = generated_sql_query.replace("\n", "")
        actual_query_file.write(actual_sql_query+'\n')
        generated_query_file.write(generated_sql_query+'\n')

        if count == 3:
            break;
        
    actual_query_file.close()
    generated_query_file.close()
       



Given the following table schema:
CREATE TABLE Tourist_Attractions (
Tourist_Attraction_ID number PRIMARY KEY,
Attraction_Type_Code text,
Location_ID number,
How_to_Get_There text,
Name text,
Description text,
Opening_Hours text,
Other_Details text,
FOREIGN KEY(Attraction_Type_Code) REFERENCES Ref_Attraction_Types(Attraction_Type_Code)
FOREIGN KEY(Location_ID) REFERENCES Locations(Location_ID)


CREATE TABLE Ref_Attraction_Types (
Attraction_Type_Code text PRIMARY KEY,
Attraction_Type_Description text,


CREATE TABLE Locations (
Location_ID number PRIMARY KEY,
Location_Name text,
Address text,
Other_Details text,


Write a SQL query for the english query:
Dont provide any explaination just give the response query as a string:
How many singers do we have?





Given the following table schema:
CREATE TABLE station_company (
Station_ID number PRIMARY KEY,
Company_ID number,
Rank_of_the_Year number,
FOREIGN KEY(Company_ID) REFERENCES company(Company_ID)
FOREIGN KEY(Station_ID) REFERENCES gas_station(Station_ID)


CREATE TABLE company (
Company_ID number PRIMARY KEY,
Rank number,
Company text,
Headquarters text,
Main_Industry text,
Sales_billion number,
Profits_billion number,
Assets_billion number,
Market_Value number,


CREATE TABLE gas_station (
Station_ID number PRIMARY KEY,
Open_Year number,
Location text,
Manager_Name text,
Vice_Manager_Name text,
Representative_Name text,


Write a SQL query for the english query:
Dont provide any explaination just give the response query as a string:
What is the total number of singers?





Given the following table schema:
CREATE TABLE station_company (
Station_ID number PRIMARY KEY,
Company_ID number,
Rank_of_the_Year number,
FOREIGN KEY(Company_ID) REFERENCES company(Company_ID)
FOREIGN KEY(Station_ID) REFERENCES gas_station(Station_ID)


CREATE TABLE company (
Company_ID number PRIMARY KEY,
Rank number,
Company text,
Headquarters text,
Main_Industry text,
Sales_billion number,
Profits_billion number,
Assets_billion number,
Market_Value number,


CREATE TABLE gas_station (
Station_ID number PRIMARY KEY,
Open_Year number,
Location text,
Manager_Name text,
Vice_Manager_Name text,
Representative_Name text,


Write a SQL query for the english query:
Dont provide any explaination just give the response query as a string:
Show name, country, age for all singers ordered by age from the oldest to the youngest.



In [126]:
generate_prompt(eng_query)



'Given the following table schema:\nCREATE TABLE party_host (\nParty_ID number PRIMARY KEY,\nHost_ID number,\nIs_Main_in_Charge others,\nFOREIGN KEY(Party_ID) REFERENCES party(Party_ID)\nFOREIGN KEY(Host_ID) REFERENCES host(Host_ID)\n\n\nCREATE TABLE party (\nParty_ID number PRIMARY KEY,\nParty_Theme text,\nLocation text,\nFirst_year text,\nLast_year text,\nNumber_of_hosts number,\n\n\nCREATE TABLE host (\nHost_ID number PRIMARY KEY,\nName text,\nNationality text,\nAge text,\n\n\nWrite a SQL query for the english query:\nDont provide any explaination just give the response query as a string:\nwhat is the lowest point in iowa\n'

In [165]:
##find the accuracy metrics  


However, if you have a table named 'singers' with columns 'name', 'country', and 'age', the SQL query would look like this:```sqlSELECT name, country, ageFROM singersORDER BY age DESC;```
However, if you have a table named 'singers' with columns 'name', 'country', and 'age', the SQL query would look like this:

```sql
SELECT name, country, age
FROM singers
ORDER BY age DESC;
```

