In [None]:
!pip install langchain predictionguard lancedb html2text

Collecting predictionguard
  Downloading predictionguard-2.5.0-py2.py3-none-any.whl.metadata (872 bytes)
Collecting lancedb
  Downloading lancedb-0.15.0-cp38-abi3-manylinux_2_28_x86_64.whl.metadata (4.8 kB)
Collecting html2text
  Downloading html2text-2024.2.26.tar.gz (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.5/56.5 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting deprecation (from lancedb)
  Downloading deprecation-2.1.0-py2.py3-none-any.whl.metadata (4.6 kB)
Collecting pylance==0.19.1 (from lancedb)
  Downloading pylance-0.19.1-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (7.4 kB)
Collecting overrides>=0.7 (from lancedb)
  Downloading overrides-7.7.0-py3-none-any.whl.metadata (5.8 kB)
Downloading predictionguard-2.5.0-py2.py3-none-any.whl (18 kB)
Downloading lancedb-0.15.0-cp38-abi3-manylinux_2_28_x86_64.whl (27.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m27.1/2

# Imports and authentication

In [None]:
import time
import os
import re
import urllib
import shutil

import html2text
import sqlite3
from langchain import PromptTemplate
import lancedb
from predictionguard import PredictionGuard
import pandas as pd
from getpass import getpass
import numpy as np

In [None]:
pg_access_token = getpass('Enter your Prediction Guard access api key: ')
os.environ['PREDICTIONGUARD_API_KEY'] = pg_access_token

Enter your Prediction Guard access api key: ··········


In [None]:
client = PredictionGuard()

# Create a sqlite database

We will create a local SQLite database for this example, but a similar approach could be used with any remote Postgres, MySQL, etc. database. We will load an example [movie rental database called Sakila](https://dev.mysql.com/doc/sakila/en/sakila-structure.html). Sakila models a database for a chain of video rental stores. It contains a vast amount of information about:

- movie titles
- actors, genres, etc.
- what stores have what films in inventory
- transactions and payments
- customers
- staff

![](https://raw.githubusercontent.com/bradleygrant/sakila-sqlite3/main/sakila.png)

In [None]:
# Pull the example database
! git clone https://github.com/bradleygrant/sakila-sqlite3.git

Cloning into 'sakila-sqlite3'...
remote: Enumerating objects: 18, done.[K
remote: Counting objects: 100% (18/18), done.[K
remote: Compressing objects: 100% (16/16), done.[K
remote: Total 18 (delta 4), reused 12 (delta 2), pack-reused 0 (from 0)[K
Receiving objects: 100% (18/18), 2.39 MiB | 9.06 MiB/s, done.
Resolving deltas: 100% (4/4), done.


In [None]:
# Establish connection to the SQLite database
db_path = 'sakila-sqlite3/sakila_master.db'
conn = sqlite3.connect(db_path)

In [None]:
# Execute a SQL query passed in a string argument
def execute_sql_query(query, params=()):
  """Run `query` on the notebook-level `conn` and return every fetched row.

  Args:
    query: SQL text to execute.
    params: Optional values bound to placeholders in `query`; defaults to
      no parameters, which matches the original single-argument call.

  Returns:
    The list produced by `cursor.fetchall()`.

  Raises:
    sqlite3.Error: propagated unchanged when execution fails.
  """
  cursor = conn.cursor()
  try:
    cursor.execute(query, params)
    return cursor.fetchall()
  finally:
    # Release the cursor even when execute()/fetchall() raises.
    cursor.close()

In [None]:
# Try querying the database
results = execute_sql_query("SELECT * FROM film LIMIT 3;")
results

[(1,
  'ACADEMY DINOSAUR',
  'A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies',
  '2006',
  1,
  None,
  6,
  0.99,
  86,
  20.99,
  'PG',
  'Deleted Scenes,Behind the Scenes',
  '2020-12-23 07:12:31'),
 (2,
  'ACE GOLDFINGER',
  'A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China',
  '2006',
  1,
  None,
  3,
  4.99,
  48,
  12.99,
  'G',
  'Trailers,Deleted Scenes',
  '2020-12-23 07:12:31'),
 (3,
  'ADAPTATION HOLES',
  'A Astounding Reflection of a Lumberjack And a Car who must Sink a Lumberjack in A Baloon Factory',
  '2006',
  1,
  None,
  7,
  2.99,
  50,
  18.99,
  'NC-17',
  'Trailers,Deleted Scenes',
  '2020-12-23 07:12:31')]

In [None]:
# Try querying the database
results = execute_sql_query("SELECT * FROM customer LIMIT 3;")
results

[(1,
  1,
  'MARY',
  'SMITH',
  'MARY.SMITH@sakilacustomer.org',
  5,
  '1',
  '2006-02-14 22:04:36.000',
  '2020-12-23 07:15:11'),
 (2,
  1,
  'PATRICIA',
  'JOHNSON',
  'PATRICIA.JOHNSON@sakilacustomer.org',
  6,
  '1',
  '2006-02-14 22:04:36.000',
  '2020-12-23 07:15:11'),
 (3,
  1,
  'LINDA',
  'WILLIAMS',
  'LINDA.WILLIAMS@sakilacustomer.org',
  7,
  '1',
  '2006-02-14 22:04:36.000',
  '2020-12-23 07:15:11')]

In [None]:
# Try querying the database
results = execute_sql_query('SELECT * FROM rental LIMIT 3;')
results

[(1,
  '2005-05-24 22:53:30.000',
  367,
  130,
  '2005-05-26 22:04:30.000',
  1,
  '2020-12-23 07:15:20'),
 (2,
  '2005-05-24 22:54:33.000',
  1525,
  459,
  '2005-05-28 19:40:33.000',
  1,
  '2020-12-23 07:15:20'),
 (3,
  '2005-05-24 23:03:39.000',
  1711,
  408,
  '2005-06-01 22:12:39.000',
  1,
  '2020-12-23 07:15:20')]

# Simple approach with static schema information

Generally, our approach to SQL generation involves asking the LLM to generate a relevant SQL query and injecting the schema information in the prompt for context. The problem in this case is that the schema information is quite long. In reality, production DBs might have hundreds of tables, views, etc. All of this schema information added to the prompt creates issues with:

- Model context windows
- Model performance

As such, one "naive" thing we could try is generating a summary of the schema information that fits into the context window of the given model. This will only scale to a certain point, and it may introduce weirdness because of lack of relevant context. However, it might be enough for your use case.


## Prepare descriptive static schema information

In [None]:
# Build a one-line-per-table summary of the schema:
#   "- <table>: includes <col>, <col>, ..."
schema_description = []
# sqlite_master lists the database objects; keep only the tables.
query = "SELECT name FROM sqlite_master WHERE type='table';"

# Assuming 'conn' is your SQLite connection object and has been defined earlier
cursor = conn.cursor()
cursor.execute(query)
tables = cursor.fetchall()

# Each fetched row is a 1-tuple holding the table name.
for (table,) in tables:

    # Use double quotes around the table name to avoid syntax error with reserved keywords
    cursor.execute(f'PRAGMA table_info("{table}")')
    columns = cursor.fetchall()
    # col[1] is the field rendered as the column name in the printed summary.
    column_descriptions = ", ".join([f"{col[1]}" for col in columns])
    schema_description.append(f"- {table}: includes {column_descriptions}")

cursor.close()
# Later cells inject this string into the prompt and also parse table names back out of it.
static_schema_description = "\n".join(schema_description)
print(static_schema_description)

- actor: includes actor_id, first_name, last_name, last_update
- country: includes country_id, country, last_update
- city: includes city_id, city, country_id, last_update
- address: includes address_id, address, address2, district, city_id, postal_code, phone, last_update
- language: includes language_id, name, last_update
- category: includes category_id, name, last_update
- customer: includes customer_id, store_id, first_name, last_name, email, address_id, active, create_date, last_update
- film: includes film_id, title, description, release_year, language_id, original_language_id, rental_duration, rental_rate, length, replacement_cost, rating, special_features, last_update
- film_actor: includes actor_id, film_id, last_update
- film_category: includes film_id, category_id, last_update
- film_text: includes film_id, title, description
- inventory: includes inventory_id, film_id, store_id, last_update
- staff: includes staff_id, first_name, last_name, address_id, picture, email, stor

## Define prompt Templates

We will define two prompt templates:

(1) **Text-to-SQL** - This prompt will be used (as you might guess) to generate a SQL query based on the user input.

(2) **SQL results to natural language** - This prompt will take the raw results of a SQL query and create a natural language response for the user.

In [None]:
# Text-to-SQL prompt. Placeholders: {question} (used twice) and {schema_description}.
# NOTE(review): the section is labelled "DDL statements", but the static schema text
# built in this notebook is a bullet list of table/column names rather than DDL.
sql_template = """
Generate a SQL query to answer this question: "{question}"

DDL statements:
{schema_description}

The following SQL query best answers the question "{question}":
"""
# Wrap the template so it can be filled with .format(question=..., schema_description=...).
sql_prompt = PromptTemplate(template=sql_template, input_variables=["question", "schema_description"])

# Prompt that turns a raw SQL result into a natural-language answer.
# Placeholders: {question}, {sql_query} and {data}.
# The instructions were corrected to read "the result of executing" and
# "answers the question" (previously "a result of executed" / "answers the answer").
qa_template = """You are a data analytics assistant who answers user questions.
To answer these questions you will need the data provided, which is the result of executing the given SQL query.
Give a short and crisp response that answers the question.
Don't add any notes or any extra information after your response.

Question: {question}

SQL Query: {sql_query}

Data: {data}

Answer:
"""
qa_prompt = PromptTemplate(template=qa_template,input_variables=["question", "sql_query", "data"])

## Create some utilities to generate and refine the SQL query

In [None]:
# Ask the chat model for a SQL query that answers the question
def generate_sql_query(question, injected_schema):
  """Return the model's raw text reply to the filled text-to-SQL prompt.

  Args:
    question: The user's natural-language question.
    injected_schema: Schema text substituted into the prompt's
      `schema_description` placeholder.

  Returns:
    The `content` of the first choice's message, unprocessed.
  """
  user_message = {
      "role": "user",
      "content": sql_prompt.format(question=question, schema_description=injected_schema),
  }
  response = client.chat.completions.create(
      model="Hermes-3-Llama-3.1-8B",
      messages=[user_message],
      max_tokens=300,
      temperature=0.1
  )
  first_choice = response['choices'][0]
  return first_choice['message']['content']

# This will then allow us to clean up the generated text
def _find_sql_statement_end(text):
  """Return the index at which the SQL statement starting `text` ends.

  The statement ends at the first `;` or `:` found outside a quoted span
  ('...', "..." or `...`), at the start of a ``` code fence, or at the end
  of `text` when none of those occur.
  """
  open_quote = None
  for index, char in enumerate(text):
    if open_quote is not None:
      # Inside a literal/identifier: only the matching quote is significant.
      if char == open_quote:
        open_quote = None
    elif text.startswith("```", index):
      return index
    elif char in "'\"`":
      open_quote = char
    elif char in ";:":
      return index
  return len(text)


def extract_and_refine_sql_query(sql_query):
  """Pull the first SELECT statement out of the model's reply.

  The statement runs from the first "SELECT" to the first semicolon or
  colon that is not inside a quoted span. Colons and semicolons inside
  quotes (for example the time part of '2005-05-24 22:53:30') no longer
  truncate the query. A reply that omits the trailing semicolon is accepted:
  the statement then runs to a closing code fence or to the end of the text.

  Args:
    sql_query: Raw text returned by the model; may be None or empty.

  Returns:
    The statement with trailing whitespace removed and a single ';'
    appended, or "" when the text contains no "SELECT".
  """
  if not sql_query:
    return ""

  start = sql_query.find("SELECT")
  if start == -1:
    return ""

  statement = sql_query[start:]
  refined_query = statement[:_find_sql_statement_end(statement)].rstrip()

  # Ensure the query ends with a semicolon
  return refined_query + ';'

# Finally, we have a convenience function to generate the final results.
def get_answer_from_sql(question, injected_schema):
    """Generate a SQL query for `question`, run it, and return the rows.

    Args:
        question: The user's natural-language question.
        injected_schema: Schema text passed through to `generate_sql_query`.

    Returns:
        A `(result, sql_query)` tuple. `result` is the list of fetched rows
        on success, or the string "There was an error executing the SQL
        query." when no query could be extracted or SQLite rejected it.
        `sql_query` is the cleaned query text that was attempted.
    """
    sql_query = generate_sql_query(question, injected_schema)
    sql_query = extract_and_refine_sql_query(sql_query)

    # An empty string executes without error and returns no rows, which would
    # be indistinguishable from a genuine empty result.
    if not sql_query:
        print("Error executing SQL query: no SQL statement found in the model output")
        return "There was an error executing the SQL query.", sql_query

    try:
        cursor = conn.cursor()
        try:
            cursor.execute(sql_query)
            result = cursor.fetchall()
        finally:
            # Close the cursor on the error path too.
            cursor.close()
        return result, sql_query

    except sqlite3.Error as e:
        print(f"Error executing SQL query: {e}")
        return "There was an error executing the SQL query.", sql_query

# Create a way to return a natural language answer

In [None]:
def get_answer(question, data, sql_query):
  """Ask the chat model to phrase a SQL result as a natural-language answer.

  Args:
    question: The user's original question.
    data: Value substituted into the prompt's `data` placeholder.
    sql_query: The SQL text substituted into the prompt's `sql_query` placeholder.

  Returns:
    The `content` of the first choice's message.
  """
  qa_message = qa_prompt.format(question=question, data=data, sql_query=sql_query)

  # Single-turn chat request carrying the filled QA prompt
  response = client.chat.completions.create(
      model="Hermes-3-Llama-3.1-8B",
      messages=[{"role": "user", "content": qa_message}],
      max_tokens=500,
      temperature=0.1
  )
  return response['choices'][0]['message']['content']

## Try out the simple approach!

In [None]:
question = "What are the three most rented movies?"

In [None]:
# Show the question being asked.
print('Question:')
print('------------------------')
print(question)
print('')

# Generate SQL using the static schema summary, execute it, and show query and result.
context, sql_query = get_answer_from_sql(question, static_schema_description)
print('Generated SQL Query:')
print('------------------------')
print(sql_query)
print('')
print('SQL result:')
print('------------------------')
print(context)
print('')

# Ask the LLM for a natural-language answer; it receives the raw `context` value.
answer = get_answer(question, context, sql_query)
# NOTE(review): context_str and answer_str are computed but not printed or passed on in this cell.
context_str = ', '.join([str(item) for item in context]) if isinstance(context, list) else str(context)
answer_str = str(answer)
print('Generate NL answer:')
print('------------------------')
print(answer)

Question:
------------------------
What are the three most rented movies?

Generated SQL Query:
------------------------
SELECT f.title, COUNT(r.rental_id) AS rental_count
FROM film f
JOIN inventory i ON f.film_id = i.film_id
JOIN rental r ON i.inventory_id = r.inventory_id
GROUP BY f.title
ORDER BY rental_count DESC
LIMIT 3;

SQL result:
------------------------
[('BUCKET BROTHERHOOD', 34), ('ROCKETEER MOTHER', 33), ('SCALAWAG DUCK', 32)]

Generate NL answer:
------------------------
The three most rented movies are 'BUCKET BROTHERHOOD' with 34 rentals, 'ROCKETEER MOTHER' with 33 rentals, and 'SCALAWAG DUCK' with 32 rentals.


In [None]:
question = "How many rentals last longer than 3 days?"

In [None]:
# Show the question being asked.
print('Question:')
print('------------------------')
print(question)
print('')

# Generate SQL using the static schema summary, execute it, and show query and result.
# On a SQLite error `context` is an error-message string rather than a list of rows.
context, sql_query = get_answer_from_sql(question, static_schema_description)
print('Generated SQL Query:')
print('------------------------')
print(sql_query)
print('')
print('SQL result:')
print('------------------------')
print(context)
print('')

# Ask the LLM for a natural-language answer; it receives the raw `context` value.
answer = get_answer(question, context, sql_query)
# NOTE(review): context_str and answer_str are computed but not printed or passed on in this cell.
context_str = ', '.join([str(item) for item in context]) if isinstance(context, list) else str(context)
answer_str = str(answer)
print('Generate NL answer:')
print('------------------------')
print(answer)

Question:
------------------------
How many rentals last longer than 3 days?

Error executing SQL query: no such column: day
Generated SQL Query:
------------------------
SELECT COUNT(*) AS rentals_longer_than_3_days
FROM rental
WHERE DATEDIFF(day, rental_date, return_date) > 3;

SQL result:
------------------------
There was an error executing the SQL query.

Generate NL answer:
------------------------
The number of rentals longer than 3 days cannot be determined due to an error executing the SQL query.


In [None]:
question = "What actor is featured in the movie that has the been rented most recently by the customer with the most rentals?"

In [None]:
# Show the question being asked.
print('Question:')
print('------------------------')
print(question)
print('')

# Generate SQL using the static schema summary, execute it, and show query and result.
context, sql_query = get_answer_from_sql(question, static_schema_description)
print('Generated SQL Query:')
print('------------------------')
print(sql_query)
print('')
print('SQL result:')
print('------------------------')
print(context)
print('')

# Ask the LLM for a natural-language answer; it receives the raw `context` value.
answer = get_answer(question, context, sql_query)
# NOTE(review): context_str and answer_str are computed but not printed or passed on in this cell.
context_str = ', '.join([str(item) for item in context]) if isinstance(context, list) else str(context)
answer_str = str(answer)
print('Generate NL answer:')
print('------------------------')
print(answer)

Question:
------------------------
What actor is featured in the movie that has the been rented most recently by the customer with the most rentals?

Generated SQL Query:
------------------------
SELECT a.actor_id, a.first_name, a.last_name
FROM actor a
JOIN film_actor fa ON a.actor_id = fa.actor_id
JOIN rental r ON fa.film_id = r.inventory_id
JOIN customer c ON r.customer_id = c.customer_id
WHERE r.rental_date = (
    SELECT MAX(r2.rental_date)
    FROM rental r2
    JOIN customer c2 ON r2.customer_id = c2.customer_id
    GROUP BY c2.customer_id
    ORDER BY COUNT(r2.rental_id) DESC
    LIMIT 1
)
GROUP BY a.actor_id, a.first_name, a.last_name
ORDER BY COUNT(fa.film_id) DESC
LIMIT 1;

SQL result:
------------------------
[]

Generate NL answer:
------------------------
The actor featured in the movie that has been rented most recently by the customer with the most rentals is not available in the provided data.


# More advanced retrieval approach with dynamic schema information

Assuming that your database fits one of the following scenarios:
- has many tables
- has tables with many columns
- includes fields with "unexpected" formats for values
- includes columns with non-semantically meaningful names
- etc.

We need to go beyond the simple, naive SQL generation method. We will now integrate a vector database to store schema information along with data dictionary descriptions of tables and columns. The column description will also include example field values for extra context.

We will retrieve the relevant information to answer a question on-the-fly and inject it into the prompt. We will also include "special instructions" in the prompt to deal with database quirks.

## Get schema info

In [None]:
# Thankfully a data dictionary exists already for all the tables and columns in this database.
# We will pull these pages off the Internet and use the information.
def get_table_info(table):
  """Download and parse the Sakila data-dictionary page for one table.

  Args:
    table: Table name; it is appended to the documentation URL as-is.

  Returns:
    A `(table_info, column_descriptions)` tuple: the text of the page's
    first '####' section with blank lines collapsed, and a dict mapping each
    backticked column name found in the second '####' section to the text
    that follows it on the same line.

  Raises:
    urllib.error.URLError: when the page cannot be fetched.
    IndexError: when the page does not contain two '####' sections.
  """
  # `import urllib` alone does not guarantee that the `request` submodule is loaded.
  import urllib.request

  # Let's get the html off of a website with the data dictionary.
  url = "https://dev.mysql.com/doc/sakila/en/sakila-structure-tables-" + table + ".html"
  with urllib.request.urlopen(url) as fp:
    html = fp.read().decode("utf8")

  # And convert it to text.
  h = html2text.HTML2Text()
  h.ignore_links = True
  text = h.handle(html)

  # Pull out table info
  sections = text.split('####')
  table_info = sections[1]
  table_info = table_info.replace('\n\n', '\n').strip()

  # Get the column info
  column_descriptions = {}
  column_info = sections[2].split('HOME')[0]
  for line in column_info.split('\n'):
    if '*' not in line:
      continue
    pieces = line.split('`')
    # Skip bullet/emphasis lines that carry no backticked column name.
    if len(pieces) < 3:
      continue
    column_name = pieces[1].strip()
    # Keep everything after the column name so that descriptions which contain
    # a colon or backticks of their own (e.g. "Can be one of: ...") stay whole.
    description = '`'.join(pieces[2:]).lstrip(' :').strip()
    column_descriptions[column_name] = description

  return table_info, column_descriptions

In [None]:
# Format the information for all tables
# Result shape: {table: {"table_description": str, "column_descriptions": {column: str}}}
table_descriptions = {}
for line in static_schema_description.split('\n'):
  # Schema lines look like "- <table>: includes <col>, ..."; skip anything without a dash.
  if '-' in line:
    # The table name is the last space-separated token before the first colon.
    table_name = line.split(':')[0].split(' ')[-1]
    # One HTTP request per table (see get_table_info).
    table_info, column_descriptions = get_table_info(table_name)
    table_descriptions[table_name] = {
        "table_description": table_info,
        "column_descriptions": column_descriptions
    }

In [None]:
table_descriptions

{'actor': {'table_description': '5.1.1 The actor Table\nThe `actor` table lists information for all actors.\nThe `actor` table is joined to the `film` table by means of the `film_actor`\ntable.',
  'column_descriptions': {'actor_id': 'A surrogate primary key used to uniquely identify each actor in the table.',
   'first_name': 'The actor first name.',
   'last_name': 'The actor last name.',
   'last_update': 'When the row was created or most recently updated.'}},
 'country': {'table_description': '5.1.5 The country Table\nThe `country` table contains a list of countries.\nThe `country` table is referred to by a foreign key in the `city` table.',
  'column_descriptions': {'country_id': 'A surrogate primary key used to uniquely identify each country in the table.',
   'country': 'The name of the country.',
   'last_update': 'When the row was created or most recently updated.'}},
 'city': {'table_description': '5.1.4 The city Table\nThe `city` table contains a list of cities.\nThe `city` 

In [None]:
# Create a query that will return some example values from each column
# The query groups by the column and keeps the 3 values with the highest count.
values_query_template = """SELECT
  {column_name},
  COUNT({column_name}) AS `value_occurrence`

FROM
  {my_table}

GROUP BY
  {column_name}

ORDER BY
  `value_occurrence` DESC

LIMIT 3;"""

# PromptTemplate is used here only for string substitution; the table and column
# names are inserted into the SQL text as-is, without quoting.
values_query = PromptTemplate(template=values_query_template,
                              input_variables=["column_name", "my_table"])

In [None]:
# Add the example values to the column descriptions
example_marker = ' Example values are '
for table in table_descriptions:
  descriptions_for_table = table_descriptions[table]['column_descriptions']
  for column, description in list(descriptions_for_table.items()):
    # Skip columns that were already enriched so re-running this cell does not append twice.
    if example_marker in description:
      continue
    try:
      query = values_query.format(my_table=table, column_name=column)
      results = execute_sql_query(query)
    except sqlite3.Error as e:
      # Best effort: leave the description unchanged when the query fails,
      # but report it instead of swallowing the error silently.
      print(f"Skipping example values for {table}.{column}: {e}")
      continue
    # Append the examples once (the description used to be concatenated with itself first).
    examples = ', '.join([str(c[0]) for c in results])
    descriptions_for_table[column] = description + example_marker + examples + '.'

In [None]:
table_descriptions

{'actor': {'table_description': '5.1.1 The actor Table\nThe `actor` table lists information for all actors.\nThe `actor` table is joined to the `film` table by means of the `film_actor`\ntable.',
  'column_descriptions': {'actor_id': 'A surrogate primary key used to uniquely identify each actor in the table.A surrogate primary key used to uniquely identify each actor in the table. Example values are 1, 2, 3.',
   'first_name': 'The actor first name.The actor first name. Example values are PENELOPE, KENNETH, JULIA.',
   'last_name': 'The actor last name.The actor last name. Example values are KILMER, NOLTE, TEMPLE.',
   'last_update': 'When the row was created or most recently updated.When the row was created or most recently updated. Example values are 2020-12-23 07:12:30, 2020-12-23 07:12:29, 2020-12-23 07:12:31.'}},
 'country': {'table_description': '5.1.5 The country Table\nThe `country` table contains a list of countries.\nThe `country` table is referred to by a foreign key in the 

## Prepare the vector DB for retrieval

We will use LanceDB as our vector database, and query the database using a kind of hierarchical retrieval. That is, we will first match to the tables relevant to the query, and then we will subsequently pull in information about the columns relevant to the query. To this end, we will create two tables in the database (a "tables" table and a "columns" table).

In [None]:
# Grab the original schema
with open('sakila-sqlite3/source/sqlite-sakila-schema.sql') as f:
  original_schema = f.read()

In [None]:
shutil.rmtree('.lancedb', ignore_errors=True)

In [None]:
# Create information about the DB tables to push into the vector DB
table_data = []
for table in table_descriptions:
  table_data.append([
      table,
      table_descriptions[table]['table_description']
  ])

table_df = pd.DataFrame(table_data, columns=['table', 'text'])
table_df.head()

Unnamed: 0,table,text
0,actor,5.1.1 The actor Table\nThe `actor` table lists...
1,country,5.1.5 The country Table\nThe `country` table c...
2,city,5.1.4 The city Table\nThe `city` table contain...
3,address,5.1.2 The address Table\nThe `address` table c...
4,language,5.1.12 The language Table\nThe `language` tabl...


In [None]:
# Format information about the columns to put in the vector DB
column_data = []
for table in table_descriptions:

  # Locate this table's CREATE TABLE statement once per table.
  # The \b stops a name such as "film" from matching "film_actor".
  statement_match = re.search(
      r'CREATE TABLE ' + re.escape(table) + r'\b(.*?)(?:;|\Z)',
      original_schema,
      re.DOTALL
  )
  create_statement = statement_match.group(1) if statement_match else ''

  # Map the first token of each line (the column name on column lines) to the rest of the line.
  declared_types = {}
  for line in create_statement.split('\n'):
    tokens = line.strip().split(' ')
    declared_types[tokens[0].strip()] = ' '.join(tokens[1:])

  for column in table_descriptions[table]['column_descriptions']:

    # Get the column data type
    # Default to '' so a column missing from the schema never inherits the previous column's type.
    data_type = declared_types.get(column, '')

    column_data.append([
        column,
        table,
        data_type,
        table_descriptions[table]['column_descriptions'][column]
    ])

column_df = pd.DataFrame(column_data, columns=[
    "column",
    "table",
    "data_type",
    "text"
])
column_df.head()

Unnamed: 0,column,table,data_type,text
0,actor_id,actor,"numeric NOT NULL ,",A surrogate primary key used to uniquely ident...
1,first_name,actor,"VARCHAR(45) NOT NULL,",The actor first name.The actor first name. Exa...
2,last_name,actor,"VARCHAR(45) NOT NULL,",The actor last name.The actor last name. Examp...
3,last_update,actor,"TIMESTAMP NOT NULL,",When the row was created or most recently upda...
4,country_id,country,"SMALLINT NOT NULL,",A surrogate primary key used to uniquely ident...


In [None]:
# Function to embed the text in a DataFrame and store the vectors in a column named "vector"
def embed_text_column(df, text_column="text", batch_size=5):
    """Embed every value of `df[text_column]` and attach the vectors to `df`.

    The texts are sent to the embeddings endpoint `batch_size` at a time.
    The input frame is modified in place (a "vector" column is assigned) and
    the same frame is returned.

    Raises:
        Exception: when a response has no "data" entry.
        ValueError: when the number of embeddings collected differs from the
            number of rows in `df`.
    """
    documents = df[text_column].tolist()

    vectors = []
    for start in range(0, len(documents), batch_size):
        response = client.embeddings.create(
            model="multilingual-e5-large-instruct",
            input=list(documents[start:start + batch_size])
        )
        if "data" not in response:
            raise Exception(f"Error in embedding response: {response}")

        # Items without an "embedding" key are skipped; the length check below catches the gap.
        for item in response["data"]:
            if "embedding" in item:
                vectors.append(np.array(item["embedding"]))

    # One vector per document is required before the column can be assigned
    if len(vectors) != len(df):
        raise ValueError("The number of embeddings does not match the number of documents")

    df["vector"] = vectors
    return df

def embed(sentence):
    """Embed `sentence` and return the first embedding of the response as a NumPy array."""
    embedding_response = client.embeddings.create(
        input=sentence,
        model="multilingual-e5-large-instruct",
    )
    first_item = embedding_response["data"][0]
    return np.array(first_item["embedding"])

In [None]:
# LanceDB setup
uri = ".lancedb"
if os.path.exists(uri):
    shutil.rmtree(uri)
os.mkdir(uri)
db = lancedb.connect(uri)

In [None]:
table_vector_data=embed_text_column(table_df)
column_vector_data=embed_text_column(column_df)

In [None]:
# Ensure DataFrame structure
def ensure_correct_columns(df, required_columns):
    """Raise KeyError naming every entry of `required_columns` absent from `df.columns`.

    Returns None when nothing is missing.
    """
    missing_columns = []
    for name in required_columns:
        if name not in df.columns:
            missing_columns.append(name)

    if len(missing_columns) > 0:
        raise KeyError(f"DataFrame is missing required columns: {missing_columns}")

# Example DataFrames
required_columns = ["column", "table", "data_type", "text"]

# Ensure the DataFrame has the correct columns
ensure_correct_columns(column_df, required_columns)

In [None]:
# Create the DB tables and add the records.
db.create_table("tables", data=table_vector_data)
db.create_table("columns", data=column_vector_data)

LanceTable(connection=LanceDBConnection(/content/.lancedb), name="columns")

In [None]:
# Let's try to match a query to one of our tables.
table = db.open_table("tables")
message = "What actor is featured in the movie that has the been rented most recently by the customer with the most rentals?"
results = table.search(embed(message)).limit(5).to_df()
results.head()

  results = table.search(embed(message)).limit(5).to_df()


Unnamed: 0,table,text,vector,_distance
0,actor,5.1.1 The actor Table\nThe `actor` table lists...,"[0.06676504, 0.027252175, -0.024558537, -0.027...",0.367774
1,film_actor,5.1.8 The film_actor Table\nThe `film_actor` t...,"[0.059756048, 0.019669699, -0.05381601, -0.035...",0.375969
2,inventory,5.1.11 The inventory Table\nThe `inventory` ta...,"[0.0612023, 0.012456703, -0.018217133, -0.0070...",0.380201
3,customer,5.1.6 The customer Table\nThe `customer` table...,"[0.0472097, 0.039031174, -0.029149614, -0.0254...",0.391701
4,rental,5.1.14 The rental Table\nThe `rental` table co...,"[0.0377812, 0.037709307, -0.028794236, -0.0090...",0.392925


## Create dynamic prompt templates

In [None]:
sql_template = """You are a SQL expert and you only generate SQL queries which are executable. You provide no extra explanations.
You respond with a SQL query that answers the user question in the below instruction by querying a database with the following schema:

{schema_description}

--- DATEDIFF is not supported in this database. Do not use it.

### Instruction:
User question: \"{question}\"

### Response:
"""
sql_prompt = PromptTemplate(template=sql_template, input_variables=["question", "schema_description"])

In [None]:
# Render the matched fields as one "-column: description" line each
def format_fields(vals):
  """Return the rows of `vals` as newline-separated "-<column>: <text>" lines.

  Surrounding whitespace is stripped from the final string; an empty frame
  yields "".
  """
  lines = ["-" + row['column'] + ": " + row['text'] for _, row in vals.iterrows()]
  return "".join(line + '\n' for line in lines).strip()

def format_sql_create_statement(table_name, columns_df):
    """
    Format the column information into a SQL CREATE TABLE statement.

    Each row becomes ``  <column> <data_type>, -- <text>``. The separating
    comma is written before the ``--`` comment and omitted after the last
    column, so it is no longer swallowed by the comment. Trailing commas and
    spaces already present in ``data_type`` are removed, and whitespace runs
    in ``text`` (including newlines, which would end the comment) are
    collapsed to single spaces.

    :param table_name: The name of the table.
    :param columns_df: A DataFrame containing columns and their info for the table
        (``column``, ``data_type`` and ``text`` fields are read).
    :return: A formatted SQL CREATE TABLE statement as a string.
    """
    rows = [row for _, row in columns_df.iterrows()]
    last_position = len(rows) - 1

    column_lines = []
    for position, row in enumerate(rows):
        definition = f"  {row['column']}"

        if pd.notnull(row['data_type']):
            data_type = str(row['data_type']).strip().rstrip(',').strip()
            if data_type:
                definition += f" {data_type}"

        if position < last_position:
            definition += ","

        if pd.notnull(row['text']):
            comment = ' '.join(str(row['text']).split())
            definition += f" -- {comment}"

        column_lines.append(definition + "\n")

    return f"CREATE TABLE {table_name} (\n" + "".join(column_lines) + ");"

In [None]:
# Here is the final function that dynamically generates our prompt
def fill_prompt(question, table_limit=4, column_limit=60, columns_per_table=10):
  """Build the text-to-SQL prompt from schema context retrieved for `question`.

  Args:
    question: The user's natural-language question.
    table_limit: How many table matches to request from the "tables" table.
    column_limit: How many column matches to request from the "columns" table.
    columns_per_table: Maximum number of column rows rendered per table.

  Returns:
    `sql_prompt` filled with one CREATE TABLE statement per matched table
    and the question.
  """
  # Embed the question once and reuse the vector for both searches.
  question_vector = embed(question)

  # Get the right tables from the vector DB
  table = db.open_table("tables")
  results = table.search(question_vector).limit(table_limit).to_pandas()
  # sort_values returns a new frame, so the result has to be kept.
  # NOTE(review): the search output appears to be distance-ordered already - confirm.
  results = results.sort_values(by="_distance", ascending=True, kind="stable")
  tables = results['table'].values.tolist()

  # Get the column context from the vector DB
  table = db.open_table("columns")
  results = table.search(question_vector).limit(column_limit).to_pandas()
  results = results[results['table'].isin(tables)]
  results = results.sort_values(by="_distance", ascending=True, kind="stable")

  # Format the column info
  table_info = ""
  for table_name in tables:
      table_columns_df = results[results['table'] == table_name].head(columns_per_table)
      table_info += format_sql_create_statement(table_name, table_columns_df) + "\n\n"

  # Fill the prompt
  filled_prompt = sql_prompt.format(
        schema_description=table_info.strip(),
        question=question,
    )

  return filled_prompt

In [None]:
print(fill_prompt('What actor is featured in the movie that has the been rented most recently by the customer with the most rentals?'))

You are a SQL expert and you only generate SQL queries which are executable. You provide no extra explanations.
You respond with a SQL query that answers the user question in the below instruction by querying a database with the following schema:

CREATE TABLE actor (
  last_name VARCHAR(45) NOT NULL, -- The actor last name.The actor last name. Example values are KILMER, NOLTE, TEMPLE.,
  first_name VARCHAR(45) NOT NULL, -- The actor first name.The actor first name. Example values are PENELOPE, KENNETH, JULIA.,
  last_update TIMESTAMP NOT NULL, -- When the row was created or most recently updated.When the row was created or most recently updated. Example values are 2020-12-23 07:12:30, 2020-12-23 07:12:29, 2020-12-23 07:12:31.,
  actor_id numeric NOT NULL , -- A surrogate primary key used to uniquely identify each actor in the table.A surrogate primary key used to uniquely identify each actor in the table. Example values are 1, 2, 3.
);

CREATE TABLE film_actor (
  actor_id INT NOT NUL

In [None]:
# Now we just need to redefine a couple of functions with slight modifications.

def generate_sql_query(question):
  """Return the model's raw text reply to the dynamically built text-to-SQL prompt.

  The prompt comes from `fill_prompt(question)`, so no schema argument is taken.
  """
  user_message = {"role": "user", "content": fill_prompt(question)}

  response = client.chat.completions.create(
      model="Hermes-3-Llama-3.1-8B",
      messages=[user_message],
      max_tokens=500,
      temperature=0.1
  )
  first_choice = response["choices"][0]
  return first_choice["message"]["content"]


def get_answer_from_sql(question):
    """Generate a SQL query for `question` with the dynamic prompt, run it, and return the rows.

    Returns:
        A `(result, sql_query)` tuple. `result` is the list of fetched rows
        on success, or the string "There was an error executing the SQL
        query." when no query could be extracted or SQLite rejected it.
        `sql_query` is the cleaned query text that was attempted.
    """
    sql_query = generate_sql_query(question)
    sql_query = extract_and_refine_sql_query(sql_query)

    # An empty string executes without error and returns no rows, which would
    # be indistinguishable from a genuine empty result.
    if not sql_query:
        print("Error executing SQL query: no SQL statement found in the model output")
        return "There was an error executing the SQL query.", sql_query

    try:
        cursor = conn.cursor()
        try:
            cursor.execute(sql_query)
            result = cursor.fetchall()
        finally:
            # Close the cursor on the error path too.
            cursor.close()
        return result, sql_query

    except sqlite3.Error as e:
        print(f"Error executing SQL query: {e}")
        return "There was an error executing the SQL query.", sql_query

## Try out the dynamic prompts

In [None]:
question = "How many film can be rented for longer than 3 days?"

In [None]:
# Show the question being asked.
print('Question:')
print('------------------------')
print(question)
print('')

# Generate SQL using the dynamically retrieved schema, execute it, and show query and result.
context, sql_query = get_answer_from_sql(question)
print('Generated SQL Query:')
print('------------------------')
print(sql_query)
print('')
print('SQL result:')
print('------------------------')
print(context)
print('')

# Ask the LLM for a natural-language answer; it receives the raw `context` value.
answer = get_answer(question, context, sql_query)
# NOTE(review): context_str and answer_str are computed but not printed or passed on in this cell.
context_str = ', '.join([str(item) for item in context]) if isinstance(context, list) else str(context)
answer_str = str(answer)
print('Generate NL answer:')
print('------------------------')
print(answer)

Question:
------------------------
How many film can be rented for longer than 3 days?

Generated SQL Query:
------------------------
SELECT COUNT(*) 
FROM film 
WHERE rental_duration > 3;

SQL result:
------------------------
[(797,)]

Generate NL answer:
------------------------
797 films can be rented for longer than 3 days.
