<a href="https://colab.research.google.com/github/redswimmer/natural-language-to-sql/blob/main/Natural_Language_to_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Natural Language to SQL Demo

This example demonstrates how to query a SQL database in natural language using an OpenAI LLM and [LlamaIndex](https://www.llamaindex.ai/). It was adapted from [this example by Jerry Liu](https://docs.llamaindex.ai/en/stable/examples/index_structs/struct_indices/SQLIndexDemo.html) to add support for multiple tables and foreign keys.

### Demo Video

I recorded a [demo video here](https://www.loom.com/share/1aa2b1f582784127bcb49f432c867818?sid=bdf2e506-0be2-4a6e-87e3-37ef6d6435d4) if you don't feel like reading the code.

## Setup

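Install the dependencies. The imports in this notebook use the llama_index package layout from before version 0.10; newer releases moved these modules, so the import paths may need adjusting.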

In [None]:
!pip install llama-index

In [138]:
import os
import openai

In [139]:
from google.colab import userdata

os.environ["OPENAI_API_KEY"] = userdata.get('OPENAI_API_KEY')
openai.api_key = os.environ["OPENAI_API_KEY"]

In [140]:
import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.ERROR, force=True)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

In [141]:
from IPython.display import Markdown, display

## Create Database Schema

Create an in-memory SQLite database using SQLAlchemy.

In [142]:
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    ForeignKey,
)

In [143]:
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

The city_stats table stores the population and country of each city. The city_mayor table stores the mayor of each city and points back to city_stats through the city_id foreign key.

In [144]:
# create city stats SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_id", Integer, primary_key=True),
    Column("city_name", String(32)),
    Column("population", Integer),
    Column("country", String(32), nullable=False),
)

# create city mayor SQL table
table_name = "city_mayor"
city_mayor_table = Table(
    table_name,
    metadata_obj,
    Column("mayor_id", Integer, primary_key=True),
    Column("city_id", Integer, ForeignKey('city_stats.city_id')),
    Column("mayor_name", String(64), nullable=False),
)
metadata_obj.create_all(engine)
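
The foreign key declared above can be checked directly in SQLite. The following is a minimal sketch using the standard-library sqlite3 module rather than the notebook's SQLAlchemy engine. The CREATE TABLE statements are a hand-written approximation of the schema (an assumption, not the exact DDL SQLAlchemy emits), and `orphan_insert_accepted` is a hypothetical helper. It also illustrates that SQLite only enforces foreign keys when the `foreign_keys` pragma is switched on for the connection.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Hand-written approximation of the two tables defined above (assumption).
conn.executescript(
    """
    CREATE TABLE city_stats (
        city_id INTEGER PRIMARY KEY,
        city_name VARCHAR(32),
        population INTEGER,
        country VARCHAR(32) NOT NULL
    );
    CREATE TABLE city_mayor (
        mayor_id INTEGER PRIMARY KEY,
        city_id INTEGER REFERENCES city_stats (city_id),
        mayor_name VARCHAR(64) NOT NULL
    );
    """
)

# Each pragma row is (id, seq, referenced table, from column, to column, ...).
fk_targets = [
    (row[2], row[3], row[4])
    for row in conn.execute("PRAGMA foreign_key_list(city_mayor)")
]
print(fk_targets)


def orphan_insert_accepted(conn, enforce):
    """Try to insert a mayor whose city_id has no matching city (hypothetical helper)."""
    # The pragma only takes effect outside a transaction, so set it before inserting.
    conn.execute(f"PRAGMA foreign_keys = {'ON' if enforce else 'OFF'}")
    try:
        conn.execute(
            "INSERT INTO city_mayor (city_id, mayor_name) VALUES (99, 'Nobody')"
        )
        return True
    except sqlite3.IntegrityError:
        return False
    finally:
        conn.rollback()  # discard the probe row either way


accepted_when_off = orphan_insert_accepted(conn, enforce=False)
accepted_when_on = orphan_insert_accepted(conn, enforce=True)
print(accepted_when_off, accepted_when_on)
```

For this demo the distinction does not matter, since every mayor row refers to an existing city, but it is worth knowing when adapting the schema to real data.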

## Define SQL Database

In [145]:
from llama_index import SQLDatabase, ServiceContext
from llama_index.llms import OpenAI
from sqlalchemy import insert

In [146]:
llm = OpenAI(temperature=0.1, model="gpt-3.5-turbo")
service_context = ServiceContext.from_defaults(llm=llm)

In [147]:
sql_database = SQLDatabase(engine, include_tables=["city_stats", "city_mayor"])

Populate tables with test data.

In [148]:
rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Chicago", "population": 2679000, "country": "United States"},
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
for row in rows:
    stmt = insert(city_stats_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)

In [149]:
rows = [
    {"city_id": 1, "mayor_name": "Jim Bob"},
    {"city_id": 2, "mayor_name": "Mary Smith"},
    {"city_id": 3, "mayor_name": "Jimmy Johnson"},
    {"city_id": 4, "mayor_name": "Andrew Savala"},
]
for row in rows:
    stmt = insert(city_mayor_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)
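
The loops above open one transaction per row, which is fine for four rows. As an alternative sketch, assuming SQLAlchemy 1.4 or later, a list of dictionaries can be passed to a single `execute` call so that all rows are inserted in one transaction. The table definition is repeated here only to keep the example self-contained.

```python
from sqlalchemy import (
    Column,
    Integer,
    MetaData,
    String,
    Table,
    create_engine,
    func,
    insert,
    select,
)

engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
city_stats_table = Table(
    "city_stats",
    metadata_obj,
    Column("city_id", Integer, primary_key=True),
    Column("city_name", String(32)),
    Column("population", Integer),
    Column("country", String(32), nullable=False),
)
metadata_obj.create_all(engine)

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Chicago", "population": 2679000, "country": "United States"},
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]

# Passing a list of parameter dictionaries inserts every row in one transaction.
with engine.begin() as connection:
    connection.execute(insert(city_stats_table), rows)

# Count the rows to confirm the insert.
with engine.connect() as connection:
    row_count = connection.execute(
        select(func.count()).select_from(city_stats_table)
    ).scalar_one()
print(row_count)
```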

In [150]:
list(sql_database._all_tables)

['city_stats', 'city_mayor']

## Query Index

Perform test queries to make sure we can retrieve our test data.

In [151]:
from sqlalchemy import text

In [152]:
with engine.connect() as con:
    rows = con.execute(text("SELECT * from city_stats"))
    for row in rows:
        print(row)

(1, 'Toronto', 2930000, 'Canada')
(2, 'Tokyo', 13960000, 'Japan')
(3, 'Chicago', 2679000, 'United States')
(4, 'Seoul', 9776000, 'South Korea')


In [153]:
with engine.connect() as con:
    rows = con.execute(text("SELECT * from city_mayor"))
    for row in rows:
        print(row)

(1, 1, 'Jim Bob')
(2, 2, 'Mary Smith')
(3, 3, 'Jimmy Johnson')
(4, 4, 'Andrew Savala')


## Part 1: Text-to-SQL Query Engine

Once we have constructed our SQL database, we can use the NLSQLTableQueryEngine to translate natural language questions into SQL queries.

Note that we need to specify the tables we want to use with this query engine. If we don't, the query engine will pull in the schema of every table, which could overflow the context window of the LLM.
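
To make the context-window concern concrete, here is a rough sketch that measures how much schema text a prompt would carry with and without a table filter. It uses the standard-library sqlite3 module and treats the raw CREATE TABLE statements as the schema text. That is an assumption for illustration only: LlamaIndex formats table information in its own way, and `schema_text` is a hypothetical helper, not part of any library.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Hand-written approximation of the notebook's two tables (assumption).
conn.executescript(
    """
    CREATE TABLE city_stats (
        city_id INTEGER PRIMARY KEY,
        city_name VARCHAR(32),
        population INTEGER,
        country VARCHAR(32) NOT NULL
    );
    CREATE TABLE city_mayor (
        mayor_id INTEGER PRIMARY KEY,
        city_id INTEGER REFERENCES city_stats (city_id),
        mayor_name VARCHAR(64) NOT NULL
    );
    """
)


def schema_text(conn, tables=None):
    """Join the CREATE TABLE statements, optionally limited to `tables`."""
    rows = conn.execute(
        "SELECT name, sql FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    return "\n".join(sql for name, sql in rows if tables is None or name in tables)


full_schema = schema_text(conn)
subset_schema = schema_text(conn, tables=["city_stats"])
print(len(full_schema), len(subset_schema))
```

With two small tables the difference is negligible, but the schema text grows with every table and column, which is why limiting the tables matters for larger databases.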

In [154]:
from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine

query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["city_stats", "city_mayor"],
)

In [155]:
query_str = "Which city has the highest population?"
response = query_engine.query(query_str)
display(Markdown(f"<b>{response}</b>"))

<b>The city with the highest population is Tokyo, with a population of 13,960,000.</b>

In [156]:
query_str = "Which cities have a mayor who's name contains the letter A?  Include the city and mayor name"
response = query_engine.query(query_str)
display(Markdown(f"<b>{response}</b>"))

<b>The cities with mayors whose names contain the letter A are Tokyo, with mayor Mary Smith, and Seoul, with mayor Andrew Savala.</b>
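
One way to sanity-check an answer like this is to run an equivalent hand-written query against the same data. The sketch below uses the standard-library sqlite3 module with the test rows from above. The SQL is one plausible translation of the question, not the query the engine actually generated, which may differ.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE city_stats (
        city_id INTEGER PRIMARY KEY,
        city_name VARCHAR(32),
        population INTEGER,
        country VARCHAR(32) NOT NULL
    );
    CREATE TABLE city_mayor (
        mayor_id INTEGER PRIMARY KEY,
        city_id INTEGER REFERENCES city_stats (city_id),
        mayor_name VARCHAR(64) NOT NULL
    );
    """
)
conn.executemany(
    "INSERT INTO city_stats (city_name, population, country) VALUES (?, ?, ?)",
    [
        ("Toronto", 2930000, "Canada"),
        ("Tokyo", 13960000, "Japan"),
        ("Chicago", 2679000, "United States"),
        ("Seoul", 9776000, "South Korea"),
    ],
)
conn.executemany(
    "INSERT INTO city_mayor (city_id, mayor_name) VALUES (?, ?)",
    [(1, "Jim Bob"), (2, "Mary Smith"), (3, "Jimmy Johnson"), (4, "Andrew Savala")],
)

# Join the two tables on the foreign key and filter on the mayor's name.
# SQLite's LIKE is case-insensitive for ASCII letters, so '%a%' matches "A" and "a".
sql = """
SELECT s.city_name, m.mayor_name
FROM city_stats AS s
JOIN city_mayor AS m ON m.city_id = s.city_id
WHERE m.mayor_name LIKE '%a%'
ORDER BY s.city_id
"""
matches = conn.execute(sql).fetchall()
print(matches)
# [('Tokyo', 'Mary Smith'), ('Seoul', 'Andrew Savala')]
```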

Use this query engine whenever you can specify the tables to query ahead of time, or when the schemas of all tables plus the rest of the prompt fit within your context window.

## Part 2: Query-Time Retrieval of Tables for Text-to-SQL

If we don't know ahead of time which tables we would like to use, and the combined table schemas would overflow the context window, we should store the table schemas in an index so that we can retrieve the right schema at query time.

We can do this with the SQLTableNodeMapping object, which takes in a SQLDatabase and produces a Node object for each SQLTableSchema object passed into the ObjectIndex constructor.

In [157]:
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, force=True)
from llama_index.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)
from llama_index.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index import VectorStoreIndex

# manually set context text
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(
        table_name="city_stats",
        context_str="This table gives information regarding the population and country of a given city.",
    ),
    SQLTableSchema(
        table_name="city_mayor",
        context_str="This table gives information regarding the mayor of each city.",
    ),
]

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=1)
)
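
The idea behind query-time table retrieval can be sketched without any LLM or embedding model. The toy example below ranks the two table descriptions by plain word overlap with the question and keeps the top matches. This is a deliberately simplified stand-in: the notebook's VectorStoreIndex compares embeddings, not shared words, and `tokenize` and `retrieve_tables` are hypothetical helpers written for this illustration.

```python
import re

# Same context strings as the SQLTableSchema objects above.
table_context = {
    "city_stats": "This table gives information regarding the population and country of a given city.",
    "city_mayor": "This table gives information regarding the mayor of each city.",
}


def tokenize(text):
    """Lowercase the text and return its set of alphabetic words."""
    return set(re.findall(r"[a-z]+", text.lower()))


def retrieve_tables(question, top_k=1):
    """Return the top_k table names whose context shares the most words with the question."""
    question_words = tokenize(question)
    ranked = sorted(
        table_context,
        key=lambda name: len(question_words & tokenize(table_context[name])),
        reverse=True,
    )
    return ranked[:top_k]


print(retrieve_tables("Which city has the highest population?"))
print(retrieve_tables("Who is the mayor of Seoul?"))
```

In this sketch, `top_k=1` returns a single table, so a question that spans both tables would need `top_k=2` to see both descriptions. The `similarity_top_k` argument in the cell above plays the analogous role for the real retriever.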

In [158]:
response = query_engine.query("Which city has the highest population?")
display(Markdown(f"<b>{response}</b>"))

<b>The city with the highest population is Tokyo.</b>

In [159]:
response = query_engine.query("Which cities have a mayor who's name contains the letter A.  Include the cities and the full mayor's name?")
display(Markdown(f"<b>{response}</b>"))

<b>The cities with mayors whose names contain the letter A are Seoul, with mayor Andrew Savala, and Tokyo, with mayor Mary Smith.</b>

Let's try combining results from both tables.

In [160]:
response = query_engine.query("Which cities have a population greater than 3 million and mayor who's name contains the letter A.  Include the cities and the full mayor's name?")
display(Markdown(f"<b>{response}</b>"))

<b>The cities with a population greater than 3 million and a mayor whose name contains the letter A are Seoul, with mayor Andrew Savala, and Tokyo, with mayor Mary Smith.</b>