# Query Pipeline over Pandas DataFrames

This example builds a query pipeline that performs structured operations over a Pandas DataFrame to satisfy a user query, using an LLM to infer the set of operations.

This can be treated as the "from-scratch" version of our `PandasQueryEngine`.

In [None]:
%poetry add llama-index==0.9.45.post1 arize-phoenix==2.2.1 pyvis


In [1]:
from llama_index.core.query_pipeline import (
    QueryPipeline as QP,
    Link,
    InputComponent,
)
from llama_index.core.query_engine.pandas import PandasInstructionParser
from llama_index.llms.openai import OpenAI
from llama_index.core.prompts import PromptTemplate

In [2]:
import os
from dotenv import load_dotenv
load_dotenv()

True

## Download Data

Here we load the Titanic CSV dataset.

In [None]:
!wget 'https://raw.githubusercontent.com/jerryjliu/llama_index/main/docs/examples/data/csv/titanic_train.csv' -O 'titanic_train.csv'


In [None]:
import pandas as pd

df = pd.read_csv("./titanic_train.csv")
df


## Define Modules

Here we define the set of modules:
1. Pandas prompt to infer pandas instructions from user query
2. Pandas output parser to execute pandas instructions on dataframe, get back dataframe
3. Response synthesis prompt to synthesize a final response given the dataframe
4. LLM

The pandas output parser is designed to execute Python code safely. It includes a number of safety checks that would be tedious to write from scratch: it only allows imports from a set of approved modules (e.g. no modules that could alter the file system, like `os`), and it ensures that no private/dunder methods are called.
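To make the idea of these checks concrete, here is a minimal sketch of how such a filter could be written with the standard `ast` module. It is not the `PandasInstructionParser` implementation; the `ALLOWED_NAMES` allow-list and the `is_safe_expression` helper are hypothetical names chosen for illustration, and AST filtering on its own should not be treated as a complete sandbox.

```python
import ast

# Hypothetical allow-list of names an expression may reference (an assumption,
# not the library's actual list).
ALLOWED_NAMES = {"df", "pd", "len", "min", "max", "sum", "round"}


def is_safe_expression(source: str) -> bool:
    """Return True if `source` is a single expression that passes the checks."""
    try:
        # mode="eval" only accepts expressions, so statements such as
        # `import os` are rejected with a SyntaxError.
        tree = ast.parse(source, mode="eval")
    except SyntaxError:
        return False
    for node in ast.walk(tree):
        # Reject access to private or dunder attributes, e.g. df.__class__
        if isinstance(node, ast.Attribute) and node.attr.startswith("_"):
            return False
        # Reject any bare name that is not explicitly allowed, e.g. __import__
        if isinstance(node, ast.Name) and node.id not in ALLOWED_NAMES:
            return False
    return True


print(is_safe_expression("df['Age'].mean()"))
print(is_safe_expression("__import__('os').system('ls')"))
```

The check is deliberately conservative: an expression that uses a lambda argument, for example, is rejected because the argument name is not on the allow-list.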

In [None]:
instruction_str = (
    "1. Convert the query to executable Python code using Pandas.\n"
    "2. The final line of code should be a Python expression that can be called with the `eval()` function.\n"
    "3. The code should represent a solution to the query.\n"
    "4. PRINT ONLY THE EXPRESSION.\n"
    "5. Do not quote the expression.\n"
)

pandas_prompt_str = (
    "You are working with a pandas dataframe in Python.\n"
    "The name of the dataframe is `df`.\n"
    "This is the result of `print(df.head())`:\n"
    "{df_str}\n\n"
    "Follow these instructions:\n"
    "{instruction_str}\n"
    "Query: {query_str}\n\n"
    "Expression:"
)
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results.\n"
    "Query: {query_str}\n\n"
    "Pandas Instructions (optional):\n{pandas_instructions}\n\n"
    "Pandas Output: {pandas_output}\n\n"
    "Response: "
)

pandas_prompt = PromptTemplate(pandas_prompt_str).partial_format(
    instruction_str=instruction_str, df_str=df.head(5)
)
pandas_output_parser = PandasInstructionParser(df)
response_synthesis_prompt = PromptTemplate(response_synthesis_prompt_str)
llm = OpenAI(model="gpt-3.5-turbo")
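The effect of partially formatting a prompt can be illustrated with plain Python strings. The sketch below is not how `PromptTemplate.partial_format` is implemented; `_KeepMissing` and the shortened template are hypothetical, and the approach assumes the pre-filled values contain no curly braces.

```python
class _KeepMissing(dict):
    """Hypothetical helper: leave unknown placeholders untouched."""

    def __missing__(self, key: str) -> str:
        return "{" + key + "}"


TEMPLATE = (
    "The name of the dataframe is `df`.\n"
    "{df_str}\n\n"
    "Follow these instructions:\n"
    "{instruction_str}\n"
    "Query: {query_str}\n\n"
    "Expression:"
)

# Fill in the values known at definition time; {query_str} stays open.
partial_prompt = TEMPLATE.format_map(
    _KeepMissing(
        df_str="   Age  Survived\n0   22         0",
        instruction_str="1. Convert the query to executable Python code.",
    )
)

# At query time only the remaining placeholder has to be supplied.
final_prompt = partial_prompt.format(query_str="What is the average age?")
print(final_prompt)
```

This mirrors why the pipeline only needs `query_str` as input: the dataframe preview and the instructions are bound once, up front.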


## Build Query Pipeline

The pipeline looks like this:

`input query_str -> pandas_prompt -> llm1 -> pandas_output_parser -> response_synthesis_prompt -> llm2`

Additional connections to `response_synthesis_prompt`: `input -> query_str`, `llm1 -> pandas_instructions`, and `pandas_output_parser -> pandas_output`.
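As a sanity check on this structure, the same graph can be written down with the standard library and sorted topologically. This is only a sketch of the dependency structure described above, not of how `QueryPipeline` schedules modules internally; the `predecessors` mapping is an illustrative reconstruction of the chain and links.

```python
from graphlib import TopologicalSorter

# Map each module to the set of modules whose output it consumes.
predecessors = {
    "pandas_prompt": {"input"},
    "llm1": {"pandas_prompt"},
    "pandas_output_parser": {"llm1"},
    "response_synthesis_prompt": {"input", "llm1", "pandas_output_parser"},
    "llm2": {"response_synthesis_prompt"},
}

# static_order() raises graphlib.CycleError if the graph is not a DAG.
order = list(TopologicalSorter(predecessors).static_order())
print(order)
```

Because the main chain already passes through every module, there is only one valid execution order here; the extra links into `response_synthesis_prompt` add inputs but do not change the ordering.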

In [None]:
qp = QP(
    modules={
        "input": InputComponent(),
        "pandas_prompt": pandas_prompt,
        "llm1": llm,
        "pandas_output_parser": pandas_output_parser,
        "response_synthesis_prompt": response_synthesis_prompt,
        "llm2": llm,
    },
    verbose=True,
)
qp.add_chain(["input", "pandas_prompt", "llm1", "pandas_output_parser"])
qp.add_links(
    [
        Link("input", "response_synthesis_prompt", dest_key="query_str"),
        Link(
            "llm1", "response_synthesis_prompt", dest_key="pandas_instructions"
        ),
        Link(
            "pandas_output_parser",
            "response_synthesis_prompt",
            dest_key="pandas_output",
        ),
    ]
)
# add link from response synthesis prompt to llm2
qp.add_link("response_synthesis_prompt", "llm2")


In [None]:
from pyvis.network import Network

net = Network(notebook=True, cdn_resources="in_line", directed=True)
net.from_nx(qp.dag)
net.show("text2sql_dag.html")


## Run Query

In [None]:
response = qp.run(
    query_str="What is the correlation between survival and age?",
)


In [None]:
print(response.message.content)


# Query Pipeline for Advanced Text-to-SQL

In this guide we show you how to set up a text-to-SQL pipeline over your data with our [query pipeline](https://docs.llamaindex.ai/en/stable/module_guides/querying/pipeline/root.html) syntax.

This gives you the flexibility to enhance text-to-SQL with additional techniques. We show these in the sections below:
1. **Query-Time Table Retrieval**: Dynamically retrieve relevant tables in the text-to-SQL prompt.
2. **Query-Time Sample Row Retrieval**: Embed/index each row, and dynamically retrieve example rows for each table in the text-to-SQL prompt.

Our out-of-the-box pipelines include `NLSQLTableQueryEngine` and `SQLTableRetrieverQueryEngine` (if you want to check out our text-to-SQL guide using these modules, take a look [here](https://docs.llamaindex.ai/en/stable/examples/index_structs/struct_indices/SQLIndexDemo.html)). This guide implements an advanced version of those modules, giving you the flexibility to apply this to your own setting.

## Load and Ingest Data


### Load Data
We use the [WikiTableQuestions dataset](https://ppasupat.github.io/WikiTableQuestions/) (Pasupat and Liang 2015) as our test dataset.

We go through all the CSV files in one folder and store each in a SQLite database (we will then build an object index over each table schema).

In [9]:
!wget "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip" -O data.zip
!unzip -o data.zip


In [3]:
import pandas as pd
from pathlib import Path

data_dir = Path("./WikiTableQuestions/csv/200-csv")
csv_files = sorted([f for f in data_dir.glob("*.csv")])
dfs = []
for csv_file in csv_files:
    print(f"processing file: {csv_file}")
    try:
        df = pd.read_csv(csv_file)
        dfs.append(df)
    except Exception as e:
        print(f"Error parsing {csv_file}: {str(e)}")

processing file: WikiTableQuestions/csv/200-csv/0.csv
processing file: WikiTableQuestions/csv/200-csv/1.csv
processing file: WikiTableQuestions/csv/200-csv/10.csv
processing file: WikiTableQuestions/csv/200-csv/11.csv
processing file: WikiTableQuestions/csv/200-csv/12.csv
processing file: WikiTableQuestions/csv/200-csv/14.csv
processing file: WikiTableQuestions/csv/200-csv/15.csv
Error parsing WikiTableQuestions/csv/200-csv/15.csv: Error tokenizing data. C error: Expected 4 fields in line 16, saw 5

processing file: WikiTableQuestions/csv/200-csv/17.csv
Error parsing WikiTableQuestions/csv/200-csv/17.csv: Error tokenizing data. C error: Expected 6 fields in line 5, saw 7

processing file: WikiTableQuestions/csv/200-csv/18.csv
processing file: WikiTableQuestions/csv/200-csv/20.csv
processing file: WikiTableQuestions/csv/200-csv/22.csv
processing file: WikiTableQuestions/csv/200-csv/24.csv
processing file: WikiTableQuestions/csv/200-csv/25.csv
processing file: WikiTableQuestions/csv/200-

### Extract Table Name and Summary from each Table

Here we use gpt-3.5 to extract a table name (with underscores) and summary from each table with our Pydantic program.

In [4]:
tableinfo_dir = "WikiTableQuestions_TableInfo"
!mkdir -p {tableinfo_dir}


In [5]:
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.llms.openai import OpenAI


class TableInfo(BaseModel):
    """Information regarding a structured table."""

    table_name: str = Field(
        ..., description="table name (must be underscores and NO spaces)"
    )
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )


prompt_str = """\
Give me a summary of the table with the following JSON format.

- The table name must be unique to the table and describe it while being concise.
- Do NOT output a generic table name (e.g. table, my_table).

Do NOT make the table name one of the following: {exclude_table_name_list}

Table:
{table_str}

Summary: """

program = LLMTextCompletionProgram.from_defaults(
    output_cls=TableInfo,
    llm=OpenAI(model="gpt-3.5-turbo", api_key=os.getenv("OPENAI_API_KEY")),
    prompt_template_str=prompt_str,
)

In [7]:
import json


def _get_tableinfo_with_index(idx: int) -> TableInfo | None:
    """Return the cached TableInfo for table `idx`, or None if not cached."""
    results_list = list(Path(tableinfo_dir).glob(f"{idx}_*"))
    if len(results_list) == 0:
        return None
    elif len(results_list) == 1:
        # return the parsed object (not its string form) so that callers
        # can access `.table_name` and `.table_summary`
        return TableInfo.parse_file(results_list[0])
    else:
        raise ValueError(
            f"More than one file matching index: {results_list}"
        )


table_names = set()
table_infos = []
for idx, df in enumerate(dfs):
    table_info = _get_tableinfo_with_index(idx)
    if table_info is None:
        # no cached result: ask the LLM until it produces an unused name
        while True:
            df_str = df.head(10).to_csv()
            table_info = program(
                table_str=df_str,
                exclude_table_name_list=str(list(table_names)),
            )
            table_name = table_info.table_name
            print(f"Processed table: {table_name}")
            if table_name not in table_names:
                break
            # try again
            print(f"Table name {table_name} already exists, trying again.")

        out_file = f"{tableinfo_dir}/{idx}_{table_name}.json"
        with open(out_file, "w") as f:
            json.dump(table_info.dict(), f)
    # record the name (cached or new) and append each table exactly once
    table_names.add(table_info.table_name)
    table_infos.append(table_info)

In [8]:
print([str(t) for t in table_infos])

["table_name='Renaissance_Album_Chart_Positions' table_summary='Summary of chart positions for Renaissance albums in the UK, US, and NL from 1969 to 1981.'", "table_name='Actress_Filmography' table_summary='List of films and roles for actress in various productions'", "table_name='Yearly_Deaths_and_Accidents' table_summary='Summary of yearly deaths and number of accidents from 2003 to 2012.'", "table_name='Award_Nominations_and_Wins' table_summary='Table containing information on award nominations and wins for various categories and nominees.'", ...

### Put Data in SQL Database

We use `sqlalchemy`, a popular SQL database toolkit, to load all the tables.

In [None]:
# put data into sqlite db
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
)
import re


# Function to create a sanitized column name
def sanitize_column_name(col_name):
    # Remove special characters and replace spaces with underscores
    return re.sub(r"\W+", "_", col_name)


# Function to create a table from a DataFrame using SQLAlchemy
def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    # Sanitize column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    # Dynamically create columns based on DataFrame columns and data types
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]

    # Create a table with the defined columns
    table = Table(table_name, metadata_obj, *columns)

    # Create the table in the database
    metadata_obj.create_all(engine)

    # Insert data from DataFrame into the table
    with engine.connect() as conn:
        for _, row in df.iterrows():
            insert_stmt = table.insert().values(**row.to_dict())
            conn.execute(insert_stmt)
        conn.commit()


engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
for idx, df in enumerate(dfs):
    tableinfo = _get_tableinfo_with_index(idx)
    print(f"Creating table: {tableinfo.table_name}")
    create_table_from_dataframe(df, tableinfo.table_name, engine, metadata_obj)
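It helps to see what the column sanitization actually produces. The sketch below repeats the same regular expression on a few made-up header strings (the sample column names are assumptions, not taken from the dataset) and shows two things worth knowing: trailing punctuation leaves a trailing underscore, and distinct headers can collapse to the same name.

```python
import re


def sanitize_column_name(col_name: str) -> str:
    # Same rule as above: every run of non-word characters becomes one "_"
    return re.sub(r"\W+", "_", col_name)


# Hypothetical header strings, for illustration only
samples = ["Peak chart position", "Year (UK)", "Year [UK]", "No."]
sanitized = [sanitize_column_name(s) for s in samples]
print(sanitized)

# Distinct headers may map to the same sanitized name, which would produce
# duplicate column names in the SQL table.
has_collision = len(set(sanitized)) < len(sanitized)
print(has_collision)
```

If your CSVs contain headers that differ only in punctuation, you may want to de-duplicate the sanitized names (for example by appending a counter) before creating the table.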


In [None]:
# setup Arize Phoenix for logging/observability
import phoenix as px
import llama_index

px.launch_app()
llama_index.set_global_handler("arize_phoenix")


## Advanced Capability 1: Text-to-SQL with Query-Time Table Retrieval

We now show you how to set up an end-to-end text-to-SQL pipeline with table retrieval.

### Define Modules

Here we define the core modules.
1. Object index + retriever to store table schemas
2. SQLDatabase object to connect to the above tables + SQLRetriever.
3. Text-to-SQL Prompt
4. Response synthesis Prompt
5. LLM

#### Object index, retriever, SQLDatabase

In [27]:
%pip install pymysql

Collecting pymysql
  Downloading PyMySQL-1.1.1-py3-none-any.whl.metadata (4.4 kB)
Downloading PyMySQL-1.1.1-py3-none-any.whl (44 kB)
Installing collected packages: pymysql
Successfully installed pymysql-1.1.1
Note: you may need to restart the kernel to use updated packages.


In [11]:
import os
from llama_index.core import SQLDatabase
from pydantic import BaseModel, Field
from llmtext.llms.openai import OpenAILLM
import asyncio


class TableInfo(BaseModel):
    """Information regarding a structured table."""

    table_name: str = Field(
        description="table name (must be underscores and NO spaces)"
    )
    table_summary: str = Field(
        description="short, concise summary/caption of the table in business perspective"
    )


async def arun(
    sql_database: SQLDatabase,
    api_key: str = os.getenv("OPENAI_API_KEY", ""),
    model: str = "gpt-3.5-turbo",
) -> list[TableInfo]:
    tables = sql_database.get_usable_table_names()

    llm = OpenAILLM(api_key=api_key, model=model)

    gather = []
    for table in tables:
        # retrieve table schema
        schema = sql_database.get_single_table_info(table_name=table)
        dialect = sql_database.dialect
        data = sql_database.run_sql(
            f"""SELECT *
FROM {table}
LIMIT 3;"""
        )
        prompt = f"""Let's think step by step.
Create a summary of the table 
{table}

Database is in dialect 
{dialect}

Here's the schema
{schema}

Here's the sample data
{data}
"""

        gather.append(llm.astructured_extraction(text=prompt, output_class=TableInfo))
    table_infos = await asyncio.gather(*gather)
    return table_infos


In [12]:
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import SQLDatabase, VectorStoreIndex
from sqlalchemy import create_engine

db_url = os.getenv("DATABASE_URL", "")

engine = create_engine(url=db_url, pool_recycle=3600, echo=True)

sql_database = SQLDatabase(
    engine,
    ignore_tables=["admin", "admin_block", "api_key", "refresh_token"],
)

table_node_mapping = SQLTableNodeMapping(sql_database)

db_tables = sql_database.get_usable_table_names()
print(db_tables)

table_infos = await arun(sql_database)
print(table_infos)

2024-06-14 18:33:26,635 INFO sqlalchemy.engine.Engine SELECT DATABASE()
2024-06-14 18:33:26,636 INFO sqlalchemy.engine.Engine [raw sql] {}
2024-06-14 18:33:26,856 INFO sqlalchemy.engine.Engine SELECT @@sql_mode
2024-06-14 18:33:26,857 INFO sqlalchemy.engine.Engine [raw sql] {}
2024-06-14 18:33:26,968 INFO sqlalchemy.engine.Engine SELECT @@lower_case_table_names
2024-06-14 18:33:26,968 INFO sqlalchemy.engine.Engine [raw sql] {}
2024-06-14 18:33:27,301 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-06-14 18:33:27,301 INFO sqlalchemy.engine.Engine SHOW FULL TABLES FROM `kepler`
2024-06-14 18:33:27,302 INFO sqlalchemy.engine.Engine [raw sql] {}
2024-06-14 18:33:27,414 INFO sqlalchemy.engine.Engine ROLLBACK
2024-06-14 18:33:27,635 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-06-14 18:33:27,635 INFO sqlalchemy.engine.Engine SHOW FULL TABLES FROM `kepler`
2024-06-14 18:33:27,636 INFO sqlalchemy.engine.Engine [raw sql] {}
2024-06-14 18:33:27,748 INFO sqlalchemy.engine.Engine SHOW C

In [13]:
table_schema_objs = [
    SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
    for t in table_infos
]  # add a SQLTableSchema for each table

print(table_schema_objs)
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping
)
obj_retriever = obj_index.as_retriever(similarity_top_k=3)

[SQLTableSchema(table_name='ca_change_name', context_str="Table 'ca_change_name' contains information about changes in company names including Japanese and English versions, stock codes, exchange information, disclosure IDs, and timestamps for creation and update."), SQLTableSchema(table_name='ca_delisting', context_str="Table 'ca_delisting' captures delisting information such as delisting date, stock details, company names, and exchange details."), SQLTableSchema(table_name='ca_finance', context_str='A table containing financial information including disclosure details and stock data such as stock code, name, dates, and exchange information.'), SQLTableSchema(table_name='ca_increase', context_str="Table 'ca_increase' contains information about increases in CA (Convertible Arbitrage) with columns for id, disclosure_ca_id, assignment_date, stock_code, name, kname, ename, exchg, kexchg, old, new, crash, created_at, and updated_at."), SQLTableSchema(table_name='ca_ipo', context_str='Table

In [15]:
print(obj_retriever.retrieve("stock info"))

[SQLTableSchema(table_name='stock_info', context_str='Table containing information about stocks including stock details, company information, and important dates related to the stock market.'), SQLTableSchema(table_name='finance_info', context_str="Table 'finance_info' contains financial information such as stock code, revenue, operating profit, gross profit, net profit, EPS, and other related data."), SQLTableSchema(table_name='dividend_info', context_str="Table 'dividend_info' contains information about dividends, including stock details, dividend amounts, and dates.")]


#### SQLRetriever + Table Parser

In [17]:
from llama_index.core.retrievers import SQLRetriever
from typing import List
from llama_index.core.query_pipeline import FnComponent

sql_retriever = SQLRetriever(sql_database)


def get_table_context_str(table_schema_objs: List[SQLTableSchema]):
    """Get table context string."""
    context_strs = []
    for table_schema_obj in table_schema_objs:
        table_info = sql_database.get_single_table_info(
            table_schema_obj.table_name
        )
        if table_schema_obj.context_str:
            table_opt_context = " The table description is: "
            table_opt_context += table_schema_obj.context_str
            table_info += table_opt_context

        context_strs.append(table_info)
    return "\n\n".join(context_strs)


table_parser_component = FnComponent(fn=get_table_context_str)
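To see what kind of string this component hands to the text-to-SQL prompt, here is a self-contained sketch that mimics the same concatenation logic without a database. `TableSchemaStub`, `fake_table_info`, and the column lists are hypothetical stand-ins for `SQLTableSchema` and `sql_database.get_single_table_info`; the exact wording returned by the real method may differ.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TableSchemaStub:
    """Hypothetical stand-in for SQLTableSchema."""

    table_name: str
    context_str: Optional[str] = None


# Made-up column listings, for illustration only
_COLUMNS = {
    "stock_info": "id (INTEGER), stock_code (VARCHAR(10))",
    "ca_ipo": "id (INTEGER), exchange_date (DATE)",
}


def fake_table_info(table_name: str) -> str:
    """Stand-in for sql_database.get_single_table_info()."""
    return f"Table '{table_name}' has columns: {_COLUMNS[table_name]}."


def build_table_context(schemas: List[TableSchemaStub]) -> str:
    parts = []
    for schema in schemas:
        info = fake_table_info(schema.table_name)
        # Append the LLM-generated summary when one is available
        if schema.context_str:
            info += " The table description is: " + schema.context_str
        parts.append(info)
    # One blank line between tables keeps the prompt readable
    return "\n\n".join(parts)


context = build_table_context(
    [
        TableSchemaStub("stock_info", "Basic information about listed stocks."),
        TableSchemaStub("ca_ipo"),
    ]
)
print(context)
```

Only the retrieved tables end up in this string, which is what keeps the text-to-SQL prompt short when the database has many tables.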

#### Text-to-SQL Prompt + Output Parser

In [18]:
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_pipeline import FnComponent
from llama_index.core.llms import ChatResponse


def parse_response_to_sql(res: ChatResponse) -> str:
    """Parse response to SQL."""
    if res.message is None:
        return ""
    if res.message.content is None:
        return ""
    response = str(res.message.content)
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        # TODO: move to removeprefix after Python 3.9+
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    return response.strip().strip("```").strip()


sql_parser_component = FnComponent(fn=parse_response_to_sql)

text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
    dialect=engine.dialect.name
)
print(text2sql_prompt.template)

Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use tables listed below.
{schema}

Question: {query_str}
SQLQuery: 
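Because the prompt asks the model to answer in a `SQLQuery:` / `SQLResult:` / `Answer:` layout, the parser has to cut the SQL out of a longer completion. The sketch below applies the same string logic as `parse_response_to_sql` to plain strings (the helper name `extract_sql` and the sample completions are made up for illustration), including one case that shows a limitation of stripping backticks.

```python
def extract_sql(response: str) -> str:
    """String-only version of the parsing logic used above."""
    start = response.find("SQLQuery:")
    if start != -1:
        # Drop everything up to and including the "SQLQuery:" marker
        response = response[start + len("SQLQuery:"):]
    end = response.find("SQLResult:")
    if end != -1:
        # Drop any hallucinated result and answer that follow the query
        response = response[:end]
    # str.strip("```") removes backtick characters from both ends
    return response.strip().strip("```").strip()


# Hypothetical completions
plain = "SQLQuery: SELECT COUNT(*) FROM stock_info;\nSQLResult: 42\nAnswer: 42"
fenced = "```sql\nSELECT 1\n```"

parsed_plain = extract_sql(plain)
parsed_fenced = extract_sql(fenced)
print(repr(parsed_plain))
print(repr(parsed_fenced))
```

Note the second case: the backticks are removed, but a language tag such as `sql` survives at the start of the string. If your model tends to wrap queries in fenced blocks, you may want to strip that tag as well before executing the query.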


#### Response Synthesis Prompt

In [19]:
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)
response_synthesis_prompt = PromptTemplate(
    response_synthesis_prompt_str,
)

In [20]:
llm = OpenAI(model="gpt-3.5-turbo")

### Define Query Pipeline

Now that the components are in place, let's define the query pipeline!

In [21]:
from llama_index.core.query_pipeline import (
    QueryPipeline as QP,
    Link,
    InputComponent,
    CustomQueryComponent,
)

qp = QP(
    modules={
        "input": InputComponent(),
        "table_retriever": obj_retriever,
        "table_output_parser": table_parser_component,
        "text2sql_prompt": text2sql_prompt,
        "text2sql_llm": llm,
        "sql_output_parser": sql_parser_component,
        "sql_retriever": sql_retriever,
        "response_synthesis_prompt": response_synthesis_prompt,
        "response_synthesis_llm": llm,
    },
    verbose=True,
)

In [22]:
qp.add_chain(["input", "table_retriever", "table_output_parser"])
qp.add_link("input", "text2sql_prompt", dest_key="query_str")
qp.add_link("table_output_parser", "text2sql_prompt", dest_key="schema")
qp.add_chain(
    ["text2sql_prompt", "text2sql_llm", "sql_output_parser", "sql_retriever"]
)
qp.add_link(
    "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
)
qp.add_link(
    "sql_retriever", "response_synthesis_prompt", dest_key="context_str"
)
qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
qp.add_link("response_synthesis_prompt", "response_synthesis_llm")

### Visualize Query Pipeline

A nice property of the query pipeline syntax is that you can easily visualize it as a graph: the pipeline exposes its DAG as a networkx graph, which pyvis can render.

In [45]:
%pip install pyvis

Collecting pyvis
  Downloading pyvis-0.3.2-py3-none-any.whl.metadata (1.7 kB)
Downloading pyvis-0.3.2-py3-none-any.whl (756 kB)
Installing collected packages: pyvis
Successfully installed pyvis-0.3.2
Note: you may need to restart the kernel to use updated packages.


In [46]:
from pyvis.network import Network

net = Network(notebook=True, cdn_resources="in_line", directed=True)
net.from_nx(qp.dag)
net.show("text2sql_dag.html")

text2sql_dag.html


### Run Some Queries!

Now we're ready to run some queries across this entire pipeline.

In [23]:
response = qp.run(
    query="how many stocks are there ?"
)
print(str(response))

> Running module input with input: 
query: how many stocks are there ?

> Running module table_retriever with input: 
input: how many stocks are there ?

> Running module table_output_parser with input: 
table_schema_objs: [SQLTableSchema(table_name='theme_stocks', context_str="Table 'theme_stocks' contains information about stocks related to different themes. It has columns id (INTEGER), theme_info_id (INTEGER), stock_...

> Running module text2sql_prompt with input: 
query_str: how many stocks are there ?
schema: Table 'theme_stocks' has columns: id (INTEGER), theme_info_id (INTEGER), stock_code (VARCHAR(10)), related_degree (INTEGER), created_at (DATETIME), updated_at (DATETIME), deleted_at (DATETIME), and fo...

> Running module text2sql_llm with input: 
messages: Given an input question, first create a syntactically correct mysql query to run, then loo

In [50]:
response = qp.run(query="When was the most recent IPO ? And what stock ?")
print(str(response))

> Running module input with input: 
query: When was the most recent IPO ? And what stock ?

> Running module table_retriever with input: 
input: When was the most recent IPO ? And what stock ?

> Running module table_output_parser with input: 
table_schema_objs: [SQLTableSchema(table_name='ca_ipo', context_str="Table 'ca_ipo' contains information about initial public offerings (IPOs) in Japan, including details such as ID, disclosure CA ID, exchange date, sto...

> Running module text2sql_prompt with input: 
query_str: When was the most recent IPO ? And what stock ?
schema: Table 'ca_ipo' has columns: id (INTEGER), disclosure_ca_id (INTEGER), exchange_date (DATE), stock_code (VARCHAR(10)), name (VARCHAR(100)), kname (VARCHAR(100)), ename (VARCHAR(100)), sexchg (VARCHAR(1...

> Running module text2sql_llm with input: 
messages: Given an input question, first

In [51]:
response = qp.run(query="list of delisted stocks in May 2024")
print(str(response))

> Running module input with input: 
query: list of delisted stocks in May 2024

> Running module table_retriever with input: 
input: list of delisted stocks in May 2024

> Running module table_output_parser with input: 
table_schema_objs: [SQLTableSchema(table_name='ca_delisting', context_str='This table contains information about delisting events, including delisting date, stock details, company names in different languages, exchange ...

> Running module text2sql_prompt with input: 
query_str: list of delisted stocks in May 2024
schema: Table 'ca_delisting' has columns: id (INTEGER), disclosure_ca_id (INTEGER), delisting_date (DATE), stock_code (VARCHAR(10)), name (VARCHAR(100)), kname (VARCHAR(100)), ename (VARCHAR(100)), exchg (VAR...

> Running module text2sql_llm with input: 
messages: Given an input question, first create a syntactically correct mysq

In [52]:
response = qp.run(query="technical analysis of Toyota motor")
print(str(response))

> Running module input with input: 
query: technical analysis of Toyota motor

> Running module table_retriever with input: 
input: technical analysis of Toyota motor

> Running module table_output_parser with input: 
table_schema_objs: [SQLTableSchema(table_name='jp_etf_info', context_str="Table 'jp_etf_info' contains information about Japanese ETFs, including stock codes, base dates, Japan indexes, ETF management company codes, rew...

> Running module text2sql_prompt with input: 
query_str: technical analysis of Toyota motor
schema: Table 'jp_etf_info' has columns: id (INTEGER), stock_code (CHAR(5)), base_date (DATE), japan_index (VARCHAR(150)), etf_manage_company_code (CHAR(5)), reward (DECIMAL(5, 4)), long_invest (CHAR(1)), mar...

> Running module text2sql_llm with input: 
messages: Given an input question, first create a syntactically correct mysql q

## Advanced Capability 2: Text-to-SQL with Query-Time Row Retrieval (along with Table Retrieval)

One problem in the previous example is that if the user asks for "The Notorious BIG" but the artist is stored as "The Notorious B.I.G", the generated SELECT statement will likely return no matches.

We can alleviate this problem by fetching a small number of example rows per table. A naive option would be to just take the first k rows. Instead, we embed, index, and retrieve k relevant rows given the user query to give the text-to-SQL LLM the most contextually relevant information for SQL generation.
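The difference between "first k rows" and "k most relevant rows" can be illustrated without any embedding model. The sketch below uses character-trigram overlap as a crude stand-in for embedding similarity; this is an assumption made to keep the example self-contained, and the rows and helper names are made up. The pipeline in this guide uses a vector index instead.

```python
from typing import List, Set, Tuple


def char_trigrams(text: str) -> Set[str]:
    """Lower-cased character trigrams of a string."""
    normalized = text.lower()
    return {normalized[i : i + 3] for i in range(len(normalized) - 2)}


def jaccard(a: Set[str], b: Set[str]) -> float:
    """Jaccard similarity between two sets (0.0 if either is empty)."""
    if not a or not b:
        return 0.0
    return len(a & b) / len(a | b)


def top_k_rows(query: str, rows: List[Tuple], k: int = 2) -> List[Tuple]:
    """Return the k rows whose string form is most similar to the query."""
    query_grams = char_trigrams(query)
    ranked = sorted(
        rows,
        key=lambda row: jaccard(query_grams, char_trigrams(str(row))),
        reverse=True,
    )
    return ranked[:k]


# Hypothetical table rows: (artist, album)
rows = [
    ("2Pac", "All Eyez on Me"),
    ("Nas", "Illmatic"),
    ("The Notorious B.I.G", "Ready to Die"),
]

best = top_k_rows("The Notorious BIG", rows, k=1)
print(best)
```

Taking the first row of the table would have surfaced an unrelated artist; ranking by similarity surfaces the row that shows the LLM how the name is actually spelled in the database, so it can write the matching `WHERE` clause.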

We now extend our query pipeline.

In [53]:
from llama_index.core.query_pipeline import QueryPipeline as QP
from llama_index.core.service_context import ServiceContext

qp = QP(verbose=True)
# NOTE: service context will be deprecated in v0.10 (though will still be backwards compatible)
service_context = ServiceContext.from_defaults(callback_manager=qp.callback_manager)



### Index Each Table

We embed/index the rows of each table, resulting in one index per table.

In [55]:
from llama_index.core import VectorStoreIndex, load_index_from_storage
from sqlalchemy import text
from llama_index.core.schema import TextNode
from llama_index.core.storage import StorageContext
import os
from pathlib import Path
from typing import Dict


def index_all_tables(
    sql_database: SQLDatabase, table_index_dir: str = "table_index_dir"
) -> Dict[str, VectorStoreIndex]:
    """Index all tables."""
    if not Path(table_index_dir).exists():
        os.makedirs(table_index_dir)

    vector_index_dict = {}
    engine = sql_database.engine
    for table_name in sql_database.get_usable_table_names():
        print(f"Indexing rows in table: {table_name}")
        if not os.path.exists(f"{table_index_dir}/{table_name}"):
            # get all rows from table
            with engine.connect() as conn:
                cursor = conn.execute(text(f'SELECT * FROM {table_name}'))
                result = cursor.fetchall()
                row_tups = []
                for row in result:
                    row_tups.append(tuple(row))

            # index each row, put into vector store index
            nodes = [TextNode(text=str(t)) for t in row_tups]

            # put into vector store index (use OpenAIEmbeddings by default)
            index = VectorStoreIndex(nodes, service_context=service_context)

            # save index
            index.set_index_id("vector_index")
            index.storage_context.persist(f"{table_index_dir}/{table_name}")
        else:
            # rebuild storage context
            storage_context = StorageContext.from_defaults(
                persist_dir=f"{table_index_dir}/{table_name}"
            )
            # load index
            index = load_index_from_storage(
                storage_context, index_id="vector_index", service_context=service_context
            )
        vector_index_dict[table_name] = index

    return vector_index_dict


vector_index_dict = index_all_tables(sql_database)

Indexing rows in table: ca_change_name
2024-06-14 15:18:55,474 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-06-14 15:18:55,475 INFO sqlalchemy.engine.Engine SELECT * FROM ca_change_name
2024-06-14 15:18:55,475 INFO sqlalchemy.engine.Engine [generated in 0.00116s] {}
2024-06-14 15:18:56,172 INFO sqlalchemy.engine.Engine ROLLBACK
Indexing rows in table: ca_delisting
2024-06-14 15:19:14,563 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-06-14 15:19:14,564 INFO sqlalchemy.engine.Engine SELECT * FROM ca_delisting
2024-06-14 15:19:14,564 INFO sqlalchemy.engine.Engine [generated in 0.00084s] {}
2024-06-14 15:19:15,154 INFO sqlalchemy.engine.Engine ROLLBACK
Indexing rows in table: ca_finance
2024-06-14 15:19:30,366 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-06-14 15:19:30,366 INFO sqlalchemy.engine.Engine SELECT * FROM ca_finance
2024-06-14 15:19:30,367 INFO sqlalchemy.engine.Engine [generated in 0.00096s] {}
2024-06-14 15:19:32,508 INFO sqlalchemy.engine.Engine ROLLBACK

In [None]:
test_retriever = vector_index_dict["Grammy_Award_Nominations_and_Wins"].as_retriever(
    similarity_top_k=10
)
nodes = test_retriever.retrieve("1")
for node in nodes:
    print(node.get_text())

### Define Expanded Table Parser Component

We expand the capability of our `table_parser_component` to not only return the relevant table schemas, but also return relevant rows per table schema.

It now takes in both `table_schema_objs` (the output of the table retriever) and the original `query_str`, which is then used for vector retrieval of relevant rows.
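
As a rough sketch of the string this component produces for a single table, the snippet below assembles the same three parts (schema info, optional description, example rows) without touching a database or vector index. The helper name `build_table_context` and the sample schema and rows are hypothetical and exist only for illustration.

```python
from typing import List, Optional


def build_table_context(
    table_info: str,
    description: Optional[str],
    example_rows: List[str],
) -> str:
    """Assemble per-table context: schema, optional description, example rows."""
    context = table_info
    if description:
        context += " The table description is: " + description
    if example_rows:
        context += (
            "\nHere are some relevant example rows "
            "(values in the same order as columns above)\n"
        )
        context += "\n".join(example_rows) + "\n"
    return context


# Hypothetical schema string and rows, for illustration only.
demo_context = build_table_context(
    table_info="Table 'albums' has columns: id (INTEGER), title (VARCHAR).",
    description="Albums released by each artist.",
    example_rows=["(1, 'Blue')", "(2, 'Kind of Blue')"],
)
print(demo_context)
```

In the actual component, the schema string comes from `sql_database.get_single_table_info` and the example rows come from the per-table vector retriever.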

In [None]:
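# use the `llama_index.core` namespace, consistent with the other imports in this notebook
from llama_index.core.retrievers import SQLRetriever
from typing import List
from llama_index.core.query_pipeline import FnComponent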

sql_retriever = SQLRetriever(sql_database)


def get_table_context_and_rows_str(
    query_str: str, table_schema_objs: List[SQLTableSchema]
):
    """Get table context string."""
    context_strs = []
    for table_schema_obj in table_schema_objs:
        # first append table info + additional context
        table_info = sql_database.get_single_table_info(
            table_schema_obj.table_name
        )
        if table_schema_obj.context_str:
            table_opt_context = " The table description is: "
            table_opt_context += table_schema_obj.context_str
            table_info += table_opt_context

        # also lookup vector index to return relevant table rows
        vector_retriever = vector_index_dict[
            table_schema_obj.table_name
        ].as_retriever(similarity_top_k=2)
        relevant_nodes = vector_retriever.retrieve(query_str)
        if len(relevant_nodes) > 0:
            table_row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
            for node in relevant_nodes:
                table_row_context += str(node.get_content()) + "\n"
            table_info += table_row_context

        context_strs.append(table_info)
    return "\n\n".join(context_strs)


table_parser_component = FnComponent(fn=get_table_context_and_rows_str)

### Define Expanded Query Pipeline

This looks similar to the query pipeline in section 1, but with an upgraded `table_parser_component`.
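
The wiring defined in the cells below can be summarized as a dependency graph. The sketch below mirrors those `add_link` / `add_chain` calls in a plain dictionary and uses the standard library's `graphlib` (Python 3.9+) to derive one valid execution order. It only illustrates the topology and makes no claim about how `QueryPipeline` schedules modules internally.

```python
from graphlib import TopologicalSorter

# Maps each module to the modules whose outputs it consumes,
# mirroring the add_link / add_chain calls in this section.
dependencies = {
    "table_retriever": {"input"},
    "table_output_parser": {"input", "table_retriever"},
    "text2sql_prompt": {"input", "table_output_parser"},
    "text2sql_llm": {"text2sql_prompt"},
    "sql_output_parser": {"text2sql_llm"},
    "sql_retriever": {"sql_output_parser"},
    "response_synthesis_prompt": {
        "input",
        "sql_output_parser",
        "sql_retriever",
    },
    "response_synthesis_llm": {"response_synthesis_prompt"},
}

# static_order() yields every node after all of its dependencies.
execution_order = list(TopologicalSorter(dependencies).static_order())
print(execution_order)
```

Note that `input` feeds four modules directly (the table retriever, the table parser, and both prompts), which is why the pipeline is expressed as a DAG rather than a single chain.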

In [None]:
# use the `llama_index.core` namespace, consistent with the other imports in this notebook
from llama_index.core.query_pipeline import (
    QueryPipeline as QP,
    Link,
    InputComponent,
    CustomQueryComponent,
)

qp.add_modules({
    "input": InputComponent(),
    "table_retriever": obj_retriever,
    "table_output_parser": table_parser_component,
    "text2sql_prompt": text2sql_prompt,
    "text2sql_llm": llm,
    "sql_output_parser": sql_parser_component,
    "sql_retriever": sql_retriever,
    "response_synthesis_prompt": response_synthesis_prompt,
    "response_synthesis_llm": llm,
})

In [None]:
qp.add_link("input", "table_retriever")
qp.add_link("input", "table_output_parser", dest_key="query_str")
qp.add_link(
    "table_retriever", "table_output_parser", dest_key="table_schema_objs"
)
qp.add_link("input", "text2sql_prompt", dest_key="query_str")
qp.add_link("table_output_parser", "text2sql_prompt", dest_key="schema")
qp.add_chain(
    ["text2sql_prompt", "text2sql_llm", "sql_output_parser", "sql_retriever"]
)
qp.add_link(
    "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
)
qp.add_link(
    "sql_retriever", "response_synthesis_prompt", dest_key="context_str"
)
qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
qp.add_link("response_synthesis_prompt", "response_synthesis_llm")

In [None]:
from pyvis.network import Network

net = Network(notebook=True, cdn_resources="in_line", directed=True)
net.from_nx(qp.dag)
net.show("text2sql_dag.html")

### Run Some Queries

We can now ask about relevant entries even if the query doesn't exactly match the values stored in the database.

In [None]:
response = qp.run(
    query="list all the table names"
)
print(str(response))

In [None]:
from sqlagent.agent import SQLAgent

agent = SQLAgent(
    db_url=os.getenv("DATABASE_URL", ""),
    api_key=os.getenv("OPENAI_API_KEY", ""),
    object_index_dir="./object_index",
    model="gpt-3.5-turbo",
)