## Using Text2SQL to generate queries

In [None]:
import sqlite3
from sqlite3 import Error, OperationalError
import sqlparse
import torch
import sqlparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Report GPU availability explicitly: a bare expression in the middle of a
# cell is evaluated and discarded, so the check was previously invisible.
print('CUDA available:', torch.cuda.is_available())

## Set up SQLite database
dbname = 'testdb'
# Open and immediately close a connection to f"{dbname}.sqlite"; no cursor is
# needed because nothing is executed at this point.
con = sqlite3.connect(f"{dbname}.sqlite")
con.close()

# DDL script for the demo schema: three tables (Hospitals, InsuranceProviders,
# InsurancePlans) linked by foreign keys. The statements are ';'-terminated
# because the script is later split on ';' and executed one statement at a time.
table_ddl_statements = """
CREATE TABLE Hospitals (
    hospital_id INTEGER PRIMARY KEY,
    hospital_name TEXT NOT NULL,
    address TEXT NOT NULL,
    city TEXT NOT NULL,
    state TEXT NOT NULL,
    zip_code TEXT NOT NULL,
    phone_number TEXT,
    hospital_type TEXT,
    bed_count INTEGER,
    latitude REAL,
    longitude REAL
);

CREATE TABLE InsuranceProviders (
    insurance_id INTEGER PRIMARY KEY,
    provider_name TEXT NOT NULL,
    provider_contact TEXT,
    provider_address TEXT,
    city TEXT,
    state TEXT,
    zip_code TEXT,
    hospital_id INTEGER,
    FOREIGN KEY (hospital_id) REFERENCES Hospitals(hospital_id)
);

CREATE TABLE InsurancePlans (
    plan_id INTEGER PRIMARY KEY,
    plan_name TEXT NOT NULL,
    plan_type TEXT NOT NULL,
    monthly_premium REAL,
    deductible REAL,
    coverage_percentage INTEGER,
    insurance_id INTEGER,
    FOREIGN KEY (insurance_id) REFERENCES InsuranceProviders(insurance_id)
);
"""

# Seed data: one multi-row INSERT per table, ten rows each. Primary keys are
# given explicitly, and the foreign-key columns (hospital_id, insurance_id)
# refer to ids inserted by the earlier statements in this script.
table_insert_statements = """
INSERT INTO Hospitals (hospital_id, hospital_name, address, city, state, zip_code, phone_number, hospital_type, bed_count, latitude, longitude)
VALUES
(1, 'Cedars-Sinai Medical Center', '8700 Beverly Blvd', 'Los Angeles', 'CA', '90048', '555-9876', 'General', 886, 34.0755, -118.3802),
(2, 'Mayo Clinic', '200 First St SW', 'Rochester', 'MN', '55905', '555-4321', 'Specialized', 1265, 44.0221, -92.4668),
(3, 'Mount Sinai Hospital', '1468 Madison Ave', 'New York', 'NY', '10029', '555-3456', 'General', 1185, 40.7892, -73.9525),
(4, 'Houston Methodist Hospital', '6565 Fannin St', 'Houston', 'TX', '77030', '555-7891', 'General', 933, 29.7071, -95.3985),
(5, 'Phoenix Children’s Hospital', '1919 E Thomas Rd', 'Phoenix', 'AZ', '85016', '555-8765', 'Children', 750, 33.4811, -112.0426),
(6, 'Baptist Health South Florida', '8900 N Kendall Dr', 'Miami', 'FL', '33176', '555-5432', 'General', 680, 25.6876, -80.3349),
(7, 'Parkland Health', '5201 Harry Hines Blvd', 'Dallas', 'TX', '75235', '555-2345', 'General', 992, 32.8177, -96.8437),
(8, 'UCSF Medical Center', '505 Parnassus Ave', 'San Francisco', 'CA', '94143', '555-6541', 'Specialized', 745, 37.7631, -122.4586),
(9, 'UW Medical Center', '1959 NE Pacific St', 'Seattle', 'WA', '98195', '555-8763', 'General', 912, 47.6490, -122.3094),
(10, 'Temple University Hospital', '3401 N Broad St', 'Philadelphia', 'PA', '19140', '555-0987', 'Veterans', 728, 39.9932, -75.1555);

INSERT INTO InsuranceProviders (insurance_id, provider_name, provider_contact, provider_address, city, state, zip_code, hospital_id)
VALUES
(1, 'Blue Cross Blue Shield', '555-0001', '500 Main St', 'Chicago', 'IL', '60601', 1),
(2, 'UnitedHealthcare', '555-0002', '600 Market St', 'Dallas', 'TX', '75001', 2),
(3, 'Cigna', '555-0003', '700 Commerce St', 'Philadelphia', 'PA', '19101', 3),
(4, 'Aetna', '555-0004', '800 Park Ave', 'New York', 'NY', '10001', 4),
(5, 'Kaiser Permanente', '555-0005', '900 State St', 'Los Angeles', 'CA', '90001', 1),
(6, 'Humana', '555-0006', '1000 Broadway St', 'Miami', 'FL', '33101', 6),
(7, 'Anthem', '555-0007', '1100 4th Ave', 'Houston', 'TX', '77002', 4),
(8, 'Molina Healthcare', '555-0008', '1200 Elm St', 'Phoenix', 'AZ', '85001', 5),
(9, 'Centene', '555-0009', '1300 Oak St', 'San Francisco', 'CA', '94101', 8),
(10, 'WellCare', '555-0010', '1400 Pine St', 'Seattle', 'WA', '98101', 9);

INSERT INTO InsurancePlans (plan_id, plan_name, plan_type, monthly_premium, deductible, coverage_percentage, insurance_id)
VALUES
(1, 'Blue Shield Platinum Plan', 'PPO', 450.00, 500.00, 90, 1),
(2, 'UnitedHealthcare Silver Plan', 'HMO', 300.00, 1000.00, 80, 2),
(3, 'Cigna Gold Plan', 'EPO', 400.00, 750.00, 85, 3),
(4, 'Aetna Bronze Plan', 'PPO', 200.00, 1500.00, 70, 4),
(5, 'Kaiser Family Health Plan', 'HMO', 350.00, 900.00, 80, 5),
(6, 'Humana Premier Plan', 'POS', 375.00, 800.00, 85, 6),
(7, 'Anthem Essential Plan', 'EPO', 325.00, 1200.00, 75, 7),
(8, 'Molina Basic Plan', 'HMO', 250.00, 1300.00, 70, 8),
(9, 'Centene Advantage Plan', 'PPO', 425.00, 600.00, 90, 9),
(10, 'WellCare Senior Plan', 'POS', 290.00, 1100.00, 80, 10);
"""

# Join hints handed to the model, one per line. Only relationships between
# tables that the schema actually defines are listed: hints about tables that
# are never created (sales, products, Patients) would point the model at
# columns that do not exist.
user_input_entity_relationships = """
InsuranceProviders.hospital_id can be joined with Hospitals.hospital_id
InsurancePlans.insurance_id can be joined with InsuranceProviders.insurance_id
"""

## Creating tables and inserting data
delimiter = ';'
# Split the DDL script into single statements and drop whitespace-only chunks
# (such as the tail after the final ';') so that only real statements run.
table_ddl_statements = [
  statement for statement in table_ddl_statements.split(delimiter) if statement.strip()
]
con = sqlite3.connect(f"{dbname}.sqlite")
try:
  cur = con.cursor()
  for ddl in table_ddl_statements:
    try:
      cur.execute(ddl + delimiter)
    except Error as e:
      # An "already exists" error is tolerated so the cell can be re-run
      # against an existing database; any other error is a real failure.
      if 'already exists' not in str(e):
        raise OperationalError('Something else has gone wrong:', e) from e
finally:
  # Close the connection even when a statement fails.
  con.close()

# Split the seed script into single statements and drop whitespace-only chunks
# (such as the tail after the final ';') so that only real statements run.
table_insert_statements = [
  statement for statement in table_insert_statements.split(delimiter) if statement.strip()
]
con = sqlite3.connect(f"{dbname}.sqlite")
try:
  cur = con.cursor()
  for query in table_insert_statements:
    try:
      cur.execute(query + delimiter)
      # Commit per statement so earlier inserts persist if a later one fails.
      con.commit()
    except Error as e:
      # A UNIQUE violation is tolerated so the cell can be re-run against a
      # database that already holds these rows; anything else is re-raised.
      if 'UNIQUE' not in str(e):
        raise OperationalError('Something else has gone wrong:', e) from e
finally:
  # Close the connection even when a statement fails.
  con.close()

print('Database auto-modelling complete.')

## Summarizing helper functions
def summarize_database(dbname):
  """Collect the DDL of every table in the SQLite database, to use in prompting.

  Args:
    dbname: Database name without extension; the file ``{dbname}.sqlite`` is opened.

  Returns:
    A list with one formatted DDL string per table, each terminated by ``;``.
  """
  delimiter = ';'
  con = sqlite3.connect(f"{dbname}.sqlite")
  try:
    cur = con.cursor()
    # The last column of each sqlite_master row holds the statement text.
    cur.execute("SELECT * FROM sqlite_master WHERE type='table'")
    result = cur.fetchall()
  finally:
    # Release the connection on every call instead of leaking it.
    con.close()
  print(len(result), 'table(s) parsed.')
  return [sqlparse.format(row[-1]) + delimiter for row in result]
def summarize_relationships(user_input):
  """Turn newline-separated relationship hints into SQL comment lines.

  Surrounding whitespace is stripped from the whole input, then every line is
  prefixed with ``-- ``. The number of lines is printed, and the commented
  lines are returned joined by newlines.
  """
  comment_prefix = '--'
  hint_lines = user_input.strip().split("\n")
  commented = [f"{comment_prefix} {hint}" for hint in hint_lines]
  print(len(commented), 'relationship(s) parsed.')
  return '\n'.join(commented)

Database auto-modelling complete.


In [None]:
# Hand-written reference query: names of CA hospitals that have at least one
# insurance provider attached. Useful for comparing against generated SQL.
sanity_check_sql = """SELECT DISTINCT h.hospital_name
FROM Hospitals h
JOIN InsuranceProviders i ON h.hospital_id = i.hospital_id
WHERE
   h.state = 'CA'
ORDER BY h.hospital_name NULLS LAST; """

con = sqlite3.connect(f"{dbname}.sqlite")
cur = con.cursor()
results = cur.execute(sanity_check_sql).fetchall()
for row in results:
  print(row)

('Cedars-Sinai Medical Center',)
('UCSF Medical Center',)


In [None]:
## Set up SQLCoder model
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# # Verify MPS
# torch.set_default_device("mps")
# if torch.backends.mps.is_available():
#     mps_device = torch.device("mps")
#     x = torch.ones(1, device=mps_device)
#     print (x)
# else:
#     print ("MPS device not found.")

# Set flag manually: True loads the weights in float16, False loads them in 8-bit.
run_16bit_flag = True
torch.set_default_device("cuda")

if run_16bit_flag:
    # If you have at least 15GB of GPU memory, load the model in float16.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        use_cache=True,
    )
else:
    # Otherwise load in 8 bits – this is a bit slower.
    print("8-bit load")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        # torch_dtype=torch.float16,
        # Must be True: with False this branch never requested 8-bit
        # quantization despite what its message says.
        load_in_8bit=True,
        device_map="auto",
        use_cache=True,
    )

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
## Inferencing using SQLCoder
def format_prompt_db_specifics(dbname, user_input_entity_relationships):
  """Build the SQLCoder prompt template for one database, leaving the question open.

  Args:
    dbname: Database name passed on to ``summarize_database``.
    user_input_entity_relationships: Newline-separated join hints passed on to
      ``summarize_relationships``.

  Returns:
    A ``str.format`` template whose only remaining field is ``{question}``.
  """

  db_schema = summarize_database(dbname)
  # The schema arrives as one DDL string per table; join them so the prompt
  # contains plain SQL text rather than the repr of a Python list.
  if not isinstance(db_schema, str):
    db_schema = '\n'.join(db_schema)
  entity_relationships_string = summarize_relationships(user_input=user_input_entity_relationships)
  print('Prompt generation complete.\n')

  prompt = """### Task
  Generate a SQL query to answer [QUESTION]{{question}}[/QUESTION]

  ### Instructions
  - If you cannot answer the question with the available database schema, return 'I do not know'
  - Use only direct string comparison

  ### Database Schema
  This query will run on a database whose schema is represented in this string:
  {db_schema_string}

  {entity_relationships}

  ### Answer
  Given the database schema, here is the SQL query that answers [QUESTION]{{question}}[/QUESTION]
  [SQL]"""

  def _escape_braces(text):
    # The returned template goes through str.format once more when the
    # question is filled in, so braces in inserted text are doubled here to
    # come out of that second pass as literal braces.
    return text.replace('{', '{{').replace('}', '}}')

  return prompt.format(
      db_schema_string=_escape_braces(db_schema),
      entity_relationships=_escape_braces(entity_relationships_string),
  )

def generate_query(dbname, user_input_entity_relationships, question):
    """Generate a SQL query that answers `question` against the given database.

    Args:
        dbname: Database name used to build the schema section of the prompt.
        user_input_entity_relationships: Newline-separated join hints for the prompt.
        question: Natural-language question to translate into SQL.

    Returns:
        The text following the last ``[SQL]`` marker in the decoded model
        output, reindented by ``sqlparse``.
    """
    print('Analyzing database to build prompt..')
    prompt = format_prompt_db_specifics(dbname, user_input_entity_relationships)
    updated_prompt = prompt.format(question=question)
    # Move the inputs to the device the model reports, rather than assuming
    # "cuda": the model is loaded with device_map="auto".
    inputs = tokenizer(updated_prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,  # greedy decoding
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Empty the CUDA cache so that repeated generations do not run out of
    # memory; particularly important on Colab. Skipped when CUDA is absent.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)

question = "Give me a list of hospitals in CA and their latitude and longitude values that have Blue Cross Blue Shield insurance."
generated_sql = generate_query(dbname, user_input_entity_relationships, question)

print(f"The question was: '{question}' \n")

print("Corresponding SQL query to execute against the underlying database:")
# Post-process the generated SQL: rewrite ILIKE as LIKE, and turn wildcard
# matches on the state code ('CA%' / '%CA') into the plain literal.
# NOTE(review): the 'CA' rewrites are specific to this example question.
generated_sql = generated_sql.replace('ilike', 'like').replace('ILIKE', 'like').replace('CA%','CA').replace('%CA','CA')
print(generated_sql, '\n')

print("Results from query:\n")

# Run the query on a fresh connection to the database the prompt was built
# from, and close it afterwards. No results are fetched before the query runs.
con = sqlite3.connect(f"{dbname}.sqlite")
try:
  cur = con.cursor()
  cur.execute(generated_sql)
  results = cur.fetchall()
finally:
  con.close()
for row in results:
  print(row)

Analyzing database to build prompt..
3 table(s) parsed.
4 relationship(s) parsed.
Prompt generation complete.

The question was: 'Give me a list of hospitals in CA and their latitude and longitude values that have Blue Cross Blue Shield insurance.' 

Corresponding SQL query to execute against the underlying database:

SELECT h.hospital_name,
       h.latitude,
       h.longitude
FROM Hospitals h
JOIN InsuranceProviders i ON h.hospital_id = i.hospital_id
WHERE i.provider_name like '%Blue%Cross%Blue%Shield%'
  AND h.state = 'CA'
ORDER BY h.hospital_name NULLS LAST; 

Results from query:

('Cedars-Sinai Medical Center', 34.0755, -118.3802)
