<a href="https://colab.research.google.com/github/sudarshan-koirala/youtube-stuffs/blob/main/llamaindex/text_to_sql.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LlamaIndex Text-to-SQL

In [None]:
%%capture
!pip install duckdb duckdb-engine llama-index

In [None]:
from llama_index import SQLDatabase, SimpleDirectoryReader, WikipediaReader, Document
from llama_index.indices.struct_store import (
    NLSQLTableQueryEngine,
    SQLTableRetrieverQueryEngine,
)

In [None]:
from IPython.display import Markdown, display

## Basic Text-to-SQL with `NLSQLTableQueryEngine`


In [None]:
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    column,
)

In [None]:
# Fresh in-memory DuckDB database (https://duckdb.org/) reached through SQLAlchemy;
# the ":memory:" URL means nothing is written to disk.
metadata_obj = MetaData()
engine = create_engine("duckdb:///:memory:")

In [None]:
# Define the city_stats table on metadata_obj, then create it in the database.
table_name = "city_stats"
city_columns = (
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
city_stats_table = Table(table_name, metadata_obj, *city_columns)

metadata_obj.create_all(engine)

In [None]:
# Show the names of the tables registered on metadata_obj (the notebook displays
# the value of a cell's last expression; nothing is printed explicitly).
metadata_obj.tables.keys()

We insert some test data into the `city_stats` table.

In [None]:
from sqlalchemy import insert

# Sample rows for city_stats.
rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Chicago", "population": 2679000, "country": "United States"},
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
# Insert every row inside one transaction: a single commit instead of one per row,
# and either all rows are stored or none are.
with engine.begin() as connection:
    for row in rows:
        connection.execute(insert(city_stats_table).values(**row))

In [None]:
# Read back everything stored in city_stats and print it as a list of tuples.
with engine.connect() as conn:
    city_rows = conn.exec_driver_sql("SELECT * FROM city_stats").fetchall()
print(city_rows)

### Create SQLDatabase Object

In [None]:
from llama_index import SQLDatabase

In [None]:
# Wrap the engine for llama-index, exposing only the city_stats table.
visible_tables = ["city_stats"]
sql_database = SQLDatabase(engine, include_tables=visible_tables)

### Query Index
- An index is a data structure that lets us quickly retrieve relevant context for a user query.
- We use `NLSQLTableQueryEngine` as a query engine and run natural-language queries against it.

#### Using an OpenAI model

In [None]:
import os
# https://platform.openai.com/account/api-keys
# Do not hard-code the key: keep one that is already exported in the environment,
# otherwise prompt for it without echoing it into the notebook output.
if not os.environ.get("OPENAI_API_KEY"):
    from getpass import getpass

    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

In [None]:
#define LLM
from llama_index import ServiceContext, set_global_service_context
from llama_index.llms import OpenAI

# Build a ServiceContext around gpt-3.5-turbo so it can be handed to the query engines.
service_context = ServiceContext.from_defaults(llm=OpenAI(model="gpt-3.5-turbo"))
llm = service_context.llm

In [None]:
# Text-to-SQL engine over sql_database, driven by the service_context configured above.
query_engine_openai = NLSQLTableQueryEngine(
    sql_database,
    service_context=service_context,
)

In [None]:
question = "Which city has the highest population?"
response = query_engine_openai.query(question)

In [None]:
# Display the text answer returned by the query engine.
response.response

In [None]:
# Display the metadata attached to the response.
# NOTE(review): presumably includes the generated SQL and its raw result — confirm
# for the installed llama-index version.
response.metadata

In [None]:
population_question = (
    "Which city has the highest population. Also provide the population?"
)
response_with_population = query_engine_openai.query(population_question)

In [None]:
# Display the text answer, which should name the city and its population.
response_with_population.response

## Advanced Text-to-SQL with `SQLTableRetrieverQueryEngine`

- Suppose your database has a large number of tables: putting every table schema into the text-to-SQL prompt may overflow the LLM's context window.

- We first index the table schemas with an `ObjectIndex`, then build a `SQLTableRetrieverQueryEngine` on top of it, so only the schemas most relevant to each question are retrieved and put into the prompt.

In [None]:
# Start over with a brand-new in-memory DuckDB database and an empty MetaData registry.
engine, metadata_obj = create_engine("duckdb:///:memory:"), MetaData()

In [None]:
# Create the real city_stats table plus n dummy tables. The dummies simulate a
# database whose full schema would be too large to put into a single prompt.
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)

n = 100
dummy_table_names = [f"tmp_table_{i}" for i in range(n)]
for i, tmp_table_name in enumerate(dummy_table_names):
    # Constructing the Table registers it on metadata_obj; the object itself is not reused.
    tmp_table = Table(
        tmp_table_name,
        metadata_obj,
        Column(f"tmp_field_{i}_1", String(16), primary_key=True),
        Column(f"tmp_field_{i}_2", Integer),
        Column(f"tmp_field_{i}_3", String(16), nullable=False),
    )

all_table_names = [table_name, *dummy_table_names]

metadata_obj.create_all(engine)

In [None]:
# Display every table name: "city_stats" followed by tmp_table_0 .. tmp_table_99.
all_table_names

In [None]:
# city_stats was only just created in this new database, so this prints an empty list.
with engine.connect() as conn:
    city_rows = conn.exec_driver_sql("SELECT * FROM city_stats").fetchall()
print(city_rows)

In [None]:
# insert dummy data
from sqlalchemy import insert

# Sample rows for city_stats.
rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Chicago", "population": 2679000, "country": "United States"},
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
# Insert every row inside one transaction: a single commit instead of one per row,
# and either all rows are stored or none are.
with engine.begin() as connection:
    for row in rows:
        connection.execute(insert(city_stats_table).values(**row))

In [None]:
# Confirm the four city rows are now stored.
with engine.connect() as conn:
    stored_rows = conn.exec_driver_sql("SELECT * FROM city_stats").fetchall()
print(stored_rows)

In [None]:
# The dummy tables are never populated, so this prints an empty list.
with engine.connect() as conn:
    dummy_rows = conn.exec_driver_sql("SELECT * FROM tmp_table_99").fetchall()
print(dummy_rows)

In [None]:
# Expose every table created above (city_stats plus the dummy tables), so the
# SQLDatabase agrees with the table schemas indexed for retrieval below. Restricting
# it to "city_stats" would leave the retrievable dummy schemas outside the wrapper.
sql_database = SQLDatabase(engine, include_tables=all_table_names)

### Construct Object Index

In [None]:
from llama_index.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
from llama_index import VectorStoreIndex

In [None]:
# Maps each SQLTableSchema to an index node, using sql_database to describe the table.
table_node_mapping = SQLTableNodeMapping(sql_database)

# One schema object per table: the real city_stats table plus all dummy tables.
# A comprehension also avoids leaking the loop variable over the global `table_name`.
table_schema_objs = [SQLTableSchema(table_name=name) for name in all_table_names]

# Store the schema nodes in a VectorStoreIndex so the query engine can retrieve the
# tables most similar to a question.
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)

In [None]:
# Display the SQLTableSchema objects that were indexed (one per table name).
table_schema_objs

### Query Index with `SQLTableRetrieverQueryEngine`


In [None]:
# Retrieve only the single most similar table schema per question (similarity_top_k=1).
# Pass the same gpt-3.5-turbo service_context as the basic NLSQLTableQueryEngine, so
# both sections use the configured LLM instead of library defaults.
query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=1),
    service_context=service_context,
)

In [None]:
question = "Which city has the highest population?"
response = query_engine.query(question)

In [None]:
# Display the full response object returned by the retriever-based query engine.
response

In [None]:
# Display only the text answer.
response.response

In [None]:
# Display the metadata attached to the response.
# NOTE(review): presumably includes the generated SQL and its raw result — confirm
# for the installed llama-index version.
response.metadata