In [None]:
%pip install sqlalchemy
%pip install python-dotenv
%pip install databricks-sql-connector 

In [7]:
import os
import requests
from IPython.display import Markdown, display
from dotenv import load_dotenv
from sqlalchemy import create_engine, Table, Column, Integer, String, MetaData, insert , select
from sqlalchemy.schema import CreateTable
from sqlalchemy import text

In [8]:
# Load variables from a local .env file into os.environ so os.getenv can see them.
load_dotenv()

# Azure OpenAI settings. Despite the GPT4V_ prefix, the API version comes from the `..._4o` variable.
# NOTE(review): OPENAI_API_VERSION is not referenced by the request cell below, which posts straight
# to GPT4V_ENDPOINT; presumably the endpoint URL already carries `api-version` — confirm.
GPT4V_KEY, GPT4V_ENDPOINT, OPENAI_API_VERSION = (
    os.getenv(env_var)
    for env_var in ("AOAI_API_KEY", "AOAI_ENDPOINT", "AOAI_MODEL_API_VERSION_4o")
)

## Using Azure Databricks

In [13]:
from databricks import sql

# Connection details for the Databricks SQL endpoint, read from the environment (.env).
server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME")
http_path = os.getenv("DATABRICKS_HTTP_PATH")
access_token = os.getenv("DATABRICKS_ACCESS_TOKEN")

# Fail fast with a readable message instead of handing None values to sql.connect.
_missing_settings = [
    env_name
    for env_name, env_value in (
        ("DATABRICKS_SERVER_HOSTNAME", server_hostname),
        ("DATABRICKS_HTTP_PATH", http_path),
        ("DATABRICKS_ACCESS_TOKEN", access_token),
    )
    if not env_value
]
if _missing_settings:
    raise RuntimeError(f"Missing Databricks settings: {', '.join(_missing_settings)}")

# Establish connection
connection = sql.connect(
    server_hostname=server_hostname,
    http_path=http_path,
    access_token=access_token,
)

In [14]:
# Create a cursor object on the Databricks connection.
cursor = connection.cursor()

# CREATE TABLE statement; the same string is reused later as the schema part of the LLM prompt.
# Delta Lake cannot enforce a primary key constraint here (that would require Unity Catalog),
# so no key is declared.
create_table_sql = """  
CREATE TABLE IF NOT EXISTS city_stats (  
    city_name STRING,  
    population INT,  
    country STRING  
) USING DELTA;  
"""

# Create the table only if it does not exist yet.
cursor.execute(create_table_sql)

# Load the sample data. INSERT OVERWRITE replaces the table contents, so re-running this cell
# does not append duplicate rows (which would inflate the SUM computed later in the notebook).
cursor.execute("""
INSERT OVERWRITE city_stats VALUES
('Toronto', 2930000, 'Canada'),
('Tokyo', 13960000, 'Japan'),
('Chicago', 2679000, 'United States'),
('Seoul', 9776000, 'South Korea')
""")

# DB-API commit, kept for symmetry with other drivers.
# NOTE(review): believed to be a no-op in databricks-sql-connector (no multi-statement transactions) — confirm.
connection.commit()

# Read the table back to verify the load.
cursor.execute("SELECT * FROM city_stats")
result = cursor.fetchall()

for row in result:
    print(row)

Row(city_name='Toronto', population=2930000, country='Canada')
Row(city_name='Tokyo', population=13960000, country='Japan')
Row(city_name='Chicago', population=2679000, country='United States')
Row(city_name='Seoul', population=9776000, country='South Korea')


## Azure OpenAI text-to-SQL on Azure Databricks

In [None]:
# The prompt format follows DAIL-SQL ("Text-to-SQL Empowered by Large Language Models: A Benchmark Evaluation"): https://arxiv.org/abs/2308.15363v2

In [17]:
# Schema section of the prompt: a comment header followed by the Databricks CREATE TABLE DDL.
schema_info = f"/* Given the following database schema : */\n{create_table_sql}"

In [23]:
# Target question in the DAIL-SQL comment format, ending with a "SELECT" primer for the answer.
# (The previous text contained a stray "↳" glyph copied from the paper's PDF and ungrammatical wording.)
text_info = """/* Answer the following: What is the total population of all cities whose name contains the letter 'o'? */
SELECT"""

In [24]:
# Few-shot section of the prompt: question/SQL pairs written in the same comment-then-query
# format as the target question. They reference tables (authors, farm) that do not exist in
# this database — presumably they only illustrate the expected answer format.
examples = """  
/* Some example questions and corresponding SQL queries   
are provided based on similar problems: */  
/* Answer the following: How many authors are there? */  
SELECT count(*) FROM authors  
  
/* Answer the following: How many farms are there? */  
SELECT count(*) FROM farm  
"""  


In [25]:
# Final user prompt: schema, then the target question, then the few-shot examples, space-separated.
# NOTE(review): DAIL-SQL places the examples before the target question so the prompt ends with
# the "SELECT" primer; here they come after it — confirm this ordering is intentional.
text_input = " ".join([schema_info, text_info, examples])

In [26]:
# Ask the Azure OpenAI chat-completions endpoint to translate `text_input` into Databricks SQL.
# `requests` is already imported in the imports cell.

# Fail fast if the Azure OpenAI settings were not loaded; otherwise the request fails later
# with a less helpful error (missing URL or an authentication failure).
if not GPT4V_KEY or not GPT4V_ENDPOINT:
    raise SystemExit("AOAI_API_KEY and AOAI_ENDPOINT must be set (e.g. in a .env file).")

# Seconds to wait for the service; without a timeout requests.post can block indefinitely.
REQUEST_TIMEOUT_S = 60

headers = {
    "Content-Type": "application/json",
    "api-key": GPT4V_KEY,
}

# Payload for the request. Low temperature keeps the generated SQL close to deterministic.
payload = {
    "messages": [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": (
                        "You are an AI assistant that translates natural-language questions into SQL statements. "
                        "Reply with the raw SQL only, with no Markdown code fences (```sql ... ```) and no explanation. "
                        "Use the Databricks SQL dialect."
                    ),
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": text_input,
                }
            ],
        },
    ],
    "temperature": 0.1,
    "top_p": 0.95,
    "max_tokens": 800,
}

# Send request; raise_for_status turns non-2xx responses into requests.HTTPError.
try:
    response = requests.post(GPT4V_ENDPOINT, headers=headers, json=payload, timeout=REQUEST_TIMEOUT_S)
    response.raise_for_status()
except requests.RequestException as e:
    raise SystemExit(f"Failed to make the request. Error: {e}") from e

# Extract the generated SQL from the first choice; content can be null (e.g. when filtered).
response_data = response.json()
first_choice = response_data['choices'][0]
chat_output = (first_choice.get('message') or {}).get('content')
if not chat_output:
    raise SystemExit(f"The model returned no SQL (finish_reason={first_choice.get('finish_reason')!r}).")
chat_output = chat_output.strip()
print(chat_output)

SELECT SUM(population) 
FROM city_stats 
WHERE city_name LIKE '%o%';


In [27]:
def extract_sql(raw_output: str) -> str:
    """Return the SQL statement contained in a model reply.

    Strips surrounding whitespace and, if the model wrapped its answer in a Markdown
    code fence (```sql ... ```) despite the system prompt, removes the fence and its
    language tag.
    """
    sql_text = raw_output.strip()
    if sql_text.startswith("```"):
        sql_text = sql_text.strip("`").strip()
        if sql_text[:3].lower() == "sql":
            sql_text = sql_text[3:].strip()
    return sql_text


# The SQL was written by a model, so treat it as untrusted input and only run a read-only query.
# This check is best-effort; running the notebook with read-only Databricks credentials is the
# stronger safeguard.
generated_sql = extract_sql(chat_output)
first_keyword = generated_sql.split(None, 1)[0].upper() if generated_sql else ""
if first_keyword not in ("SELECT", "WITH"):
    raise ValueError(f"Refusing to execute model output that is not a SELECT query:\n{generated_sql}")

# Query the table and fetch results
cursor.execute(generated_sql)
result = cursor.fetchall()

for row in result:
    print(row)

Row(sum(population)=29345000)


In [6]:
# Clean-up: remove the demo table from Databricks, then release the cursor and the connection
# (cursor first, then the connection it belongs to).
drop_table_sql = "DROP TABLE IF EXISTS city_stats;"
cursor.execute(drop_table_sql)

for db_handle in (cursor, connection):
    db_handle.close()

## Optional: using a local SQLite database

In [None]:
# Local alternative: the same city data in a SQLite file, managed through SQLAlchemy Core.
LOCAL_DB_URL = 'sqlite:///example.db'  # relative SQLite file path
engine = create_engine(LOCAL_DB_URL)

metadata = MetaData()

# city_name is the primary key, so SQLite rejects duplicate city names.
city_stats_table = Table('city_stats_table', metadata,
    Column('city_name', String, primary_key=True),
    Column('population', Integer),
    Column('country', String),
)

# Create the table if it does not exist yet (create_all checks first, so re-running is safe).
metadata.create_all(engine)

# Data to be inserted
rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Chicago", "population": 2679000, "country": "United States"},
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]

# Replace the table contents in a single transaction. Clearing first keeps the cell re-runnable:
# inserting the same primary keys twice would raise an IntegrityError. A dedicated name
# (local_conn) avoids shadowing the Databricks `connection` / `cursor` from earlier cells.
with engine.begin() as local_conn:
    local_conn.execute(city_stats_table.delete())
    local_conn.execute(insert(city_stats_table), rows)

In [None]:
# Render the CREATE TABLE DDL for the local table in the engine's (SQLite) dialect and keep it
# as a string so it can be dropped into a prompt.
compiled_ddl = CreateTable(city_stats_table).compile(bind=engine)
create_table_statement = f"{compiled_ddl}"
print(create_table_statement)


CREATE TABLE city_stats_table (
	city_name VARCHAR NOT NULL, 
	population INTEGER, 
	country VARCHAR, 
	PRIMARY KEY (city_name)
)




In [None]:
# view current table
stmt = select(
    city_stats_table.c.city_name,
    city_stats_table.c.population,
    city_stats_table.c.country,
).select_from(city_stats_table)

with engine.connect() as connection:
    results = connection.execute(stmt).fetchall()
    print(results)

[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]


In [None]:
# Same data via a raw SQL string wrapped in text(). ORDER BY makes the row order deterministic
# instead of depending on which index SQLite scans. A distinct name avoids overwriting the
# `rows` data list defined in the setup cell.
with engine.connect() as con:
    city_name_rows = con.execute(text("SELECT city_name FROM city_stats_table ORDER BY city_name"))
    for row in city_name_rows:
        print(row)

('Chicago',)
('Seoul',)
('Tokyo',)
('Toronto',)


In [None]:
# Re-display the DDL generated for the local SQLite table.
# NOTE(review): the saved output of this cell shows `CREATE TABLE city_stats`, which does not match the
# table defined for SQLite (`city_stats_table`); it looks stale — re-run to refresh.
print(f"{create_table_statement}")


CREATE TABLE city_stats (
	city_name VARCHAR NOT NULL, 
	population INTEGER, 
	country VARCHAR, 
	PRIMARY KEY (city_name)
)


