In [1]:
import psycopg2
from psycopg2 import sql
import os

# PostgreSQL server URL (no database path; db_name is appended when connecting).
# setdefault only supplies the local-development fallback: a DATABASE_URL that
# is already present in the environment is respected instead of being
# overwritten by the literal below.
os.environ.setdefault("DATABASE_URL", "postgresql://postgres:password@localhost:5432")
postgres_url = os.environ["DATABASE_URL"]
db_name = "rasa_prod"

# # Connect to PostgreSQL and create the new database
# conn = psycopg2.connect(postgres_url)
# conn.autocommit = True
# cursor = conn.cursor()

# # Create a new database
# cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(db_name)))

# # Close the initial connection
# cursor.close()
# conn.close()

# # Connect to the newly created database
# conn = psycopg2.connect(f"{postgres_url}/{db_name}")
# cursor = conn.cursor()

# # Create the CARD_PROD table
# cursor.execute("""
#     CREATE TABLE CARD_PROD (
#         Card_Prod_ID VARCHAR(3) PRIMARY KEY,
#         Cust_Face_Prod_NM VARCHAR(45)
#     )
# """)

# # Create the CARD_PROD_FETR table
# cursor.execute("""
#     CREATE TABLE CARD_PROD_FETR (
#         Card_Prod_ID VARCHAR(3),
#         Card_Prod_FETR_CD VARCHAR(40),
#         Card_Prod_FETR_Type VARCHAR(50),
#         Card_Prod_FETR_Desc VARCHAR(600),
#         FOREIGN KEY (Card_Prod_ID) REFERENCES CARD_PROD(Card_Prod_ID)
#     )
# """)

# # Insert data into CARD_PROD table
# card_prod_data = [
#     ('001', 'Disney'),
#     ('002', 'Freedom'),
#     ('003', 'Sapphire')
# ]

# cursor.executemany("""
#     INSERT INTO CARD_PROD (Card_Prod_ID, Cust_Face_Prod_NM)
#     VALUES (%s, %s)
# """, card_prod_data)

# # Insert data into CARD_PROD_FETR table
# card_prod_fetr_data = [
#     ('001', 'Annual_Fee', 'Optional_Feature', 'Annual Fee Charged on this card is 25 USD annually'),
#     ('001', 'Cash_Back', 'Mandatory_Feature', 'Cash back of 2% on grocery purchases,5% on Retail'),
#     ('001', 'Purchase_Protection', 'Complimentary_Benefit', 'Purchase protection for purchases above 500USD'),
#     ('001', 'BuyNowPayLater', 'Complimentary_Benefit', 'Payment Plan for any purchase above 100USD'),
#     ('001', 'ApplyByPhone', 'Optional_Feature', 'Card Onboarding and activation by phone'),
#     ('002', 'Annual_Fee', 'Optional_Feature', 'No Annual Fee'),
#     ('002', 'Cash_Back', 'Mandatory_Feature', 'Cash back of 2% on grocery purchases,5% on Retail'),
#     ('003', 'Annual_Fee', 'Optional_Feature', 'Annual Fee Charged on this card is 625 USD annually'),
#     ('003', 'Cash_Back', 'Mandatory_Feature', 'Cash back of 5% on grocery purchases,5% on Retail,10% on Airline Ticket Purchase'),
#     ('003', 'Purchase_Protection', 'Complimentary_Benefit', 'Purchase protection for purchases above 500USD'),
#     ('003', 'BuyNowPayLater', 'Complimentary_Benefit', 'Payment Plan for any purchase above 100USD'),
#     ('003', 'ApplyByPhone', 'Optional_Feature', 'Card Onboarding and activation by phone'),
#     ('003', 'AirlineMile', 'Complimentary_Benefit', 'Statement Point to be converted to Airline Miles'),
#     ('003', 'StatementCredit', 'Complimentary_Benefit', 'TSA pre-check credit per year upto 100'),
#     ('003', 'Travel_Lounge', 'Complimentary_Benefit', 'Free Access to Lounges across the globe'),
#     ('003', 'PayByPhone', 'Optional_Feature', 'Card Payment by phone')
# ]

# cursor.executemany("""
#     INSERT INTO CARD_PROD_FETR (Card_Prod_ID, Card_Prod_FETR_CD, Card_Prod_FETR_Type, Card_Prod_FETR_Desc)
#     VALUES (%s, %s, %s, %s)
# """, card_prod_fetr_data)

# # Commit the transactions
# conn.commit()

# # Close the connection
# cursor.close()
# conn.close()


In [1]:
import os
from sqlalchemy import create_engine
from llama_index.llms.openai import OpenAI
from llama_index.core import SQLDatabase, ServiceContext
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
from llama_index.core.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)
from llama_index.core import VectorStoreIndex
import logging

from llama_index.core.output_parsers import LangchainOutputParser
from langchain_core.output_parsers import JsonOutputParser
from dotenv import load_dotenv

# Load variables from a .env file (if one is found) into os.environ.
load_dotenv()

# PostgreSQL server URL (no database path; db_name is appended below).
# setdefault only supplies the local-development fallback, so a DATABASE_URL
# provided by the environment or by the .env file loaded above is kept instead
# of being overwritten by the hard-coded literal.
os.environ.setdefault("DATABASE_URL", "postgresql://postgres:password@localhost:5432")
postgres_url = os.environ["DATABASE_URL"]
db_name = "rasa_prod"

# JSON output parser attached to the LLM constructed below.
output_parser = LangchainOutputParser(JsonOutputParser())

# os.environ["OPENAI_API_KEY"] = ""

# rstrip("/") avoids a doubled slash when the configured URL ends with "/".
engine = create_engine(f"{postgres_url.rstrip('/')}/{db_name}")

# Choose LLM and configure ServiceContext
llm = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), model="gpt-4o-mini", output_parser=output_parser)
from llama_index.llms.ollama import Ollama

# llm = Ollama(model="llama2:7b-chat", request_timeout=60.0)

# NOTE(review): service_context is not referenced by any later cell visible in
# this notebook — confirm whether it is still needed.
service_context = ServiceContext.from_defaults(llm=llm)#, embed_model="local")

# Tables exposed to the text-to-SQL pipeline. Each entry pairs a table name
# with a short natural-language description ("context") that is attached to
# the table's schema object further down.
tables = [
    dict(
        table_name="card_prod",
        context="List of card products, contains product ID and customer-facing product name.",
    ),
    dict(
        table_name="card_prod_fetr",
        context="List of product features associated with card products from(card_prod), contains product ID, feature code, feature type, and feature description.",
    ),
]

# Restrict the SQLDatabase wrapper to exactly the tables listed above.
sql_database = SQLDatabase(
    engine,
    include_tables=[spec["table_name"] for spec in tables],
)

# Create table node mapping and object index
table_node_mapping = SQLTableNodeMapping(sql_database)

# One schema object per configured table, carrying its description as context.
table_schema_objs = [
    SQLTableSchema(table_name=table["table_name"], context_str=table["context"])
    for table in tables
]

# index_cls is passed by keyword on purpose: in llama-index-core releases where
# from_objects accepts from_node_fn / to_node_fn ahead of index_cls, a third
# positional argument would bind to from_node_fn rather than to the index class.
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    index_cls=VectorStoreIndex,
)


  service_context = ServiceContext.from_defaults(llm=llm, embed_model="local")
  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from typing import List

from llama_index.core.query_pipeline import FnComponent
from llama_index.core.retrievers import SQLRetriever

# Retriever bound to the SQL database wrapper.
sql_retriever = SQLRetriever(sql_database)

# Retriever over the table-schema object index; asks for up to 10 matches.
obj_retriever = obj_index.as_retriever(similarity_top_k=10)

def get_table_context_str(table_schema_objs: List[SQLTableSchema]):
    """Build one schema-context string from the given table schema objects.

    For every schema object the table info reported by
    ``sql_database.get_single_table_info`` is used; when the object carries a
    non-empty ``context_str`` it is appended after the fixed phrase
    " The table description is: ". The per-table strings are joined with a
    blank line between them.
    """

    def _describe(schema_obj):
        # Table info from the database wrapper, optionally followed by the
        # human-written description.
        description = sql_database.get_single_table_info(schema_obj.table_name)
        if not schema_obj.context_str:
            return description
        return description + " The table description is: " + schema_obj.context_str

    return "\n\n".join(_describe(schema_obj) for schema_obj in table_schema_objs)


# Wrap get_table_context_str as a query-pipeline function component.
# Note: a later cell rebinds table_parser_component to a component built from
# get_table_context_and_rows_str, so this binding only holds until that cell runs.
table_parser_component = FnComponent(fn=get_table_context_str)

In [None]:
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core import PromptTemplate
from llama_index.core.query_pipeline import FnComponent
from llama_index.core.llms import ChatResponse


def parse_response_to_sql(response: ChatResponse) -> str:
    """Extract the bare SQL statement from a text-to-SQL chat response.

    Steps:
      1. Keep only the text after the first "SQLQuery:" marker, if present.
      2. Drop everything from the first "SQLResult:" marker onwards, if present.
      3. If what remains opens with a Markdown code fence, keep only the fenced
         content (up to the closing fence when there is one) and drop a leading
         language tag such as ``sql``.
      4. Strip surrounding whitespace and any stray backticks.

    Args:
        response: Chat response whose ``message.content`` holds the LLM text.
            ``None`` content is treated as an empty string.

    Returns:
        The extracted SQL text, or "" when the response has no content.
    """
    text = response.message.content or ""

    marker = "SQLQuery:"
    marker_at = text.find(marker)
    if marker_at != -1:
        text = text[marker_at + len(marker):]

    result_at = text.find("SQLResult:")
    if result_at != -1:
        text = text[:result_at]

    text = text.strip()
    if text.startswith("```"):
        # Fenced answer: keep only what sits between the opening and closing
        # fence so prose after the closing fence does not leak into the query.
        fenced = text.lstrip("`")
        closing_at = fenced.find("```")
        if closing_at != -1:
            fenced = fenced[:closing_at]

        # Stripping backticks alone would leave the fence's language tag
        # (e.g. "sql\nSELECT ...") glued to the front of the statement.
        parts = fenced.split(None, 1)
        if parts and parts[0].lower() in ("sql", "postgresql", "postgres", "mysql", "sqlite"):
            fenced = parts[1] if len(parts) > 1 else ""
        text = fenced

    return text.strip().strip("`").strip()


# Default text-to-SQL prompt with the SQL dialect pre-filled from the engine.
text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
    dialect=engine.dialect.name,
)
# print(text2sql_prompt.template)

# Pipeline component that extracts the SQL statement from the LLM response.
sql_parser_component = FnComponent(fn=parse_response_to_sql)

In [None]:
# Prompt used for the final answer. It is filled with the user question
# ({query_str}), the generated SQL ({sql_query}) and the SQL result
# ({context_str}); the model continues after "Response: ".
response_synthesis_prompt_str = "\n".join(
    [
        "Given an input question, synthesize a response from the query results.",
        "Query: {query_str}",
        "SQL: {sql_query}",
        "SQL Response: {context_str}",
        "Response: ",
    ]
)
response_synthesis_prompt = PromptTemplate(response_synthesis_prompt_str)

In [None]:
# from llama_index.core.query_pipeline import (
#     QueryPipeline as QP,
#     Link,
#     InputComponent,
#     CustomQueryComponent,
# )

# qp = QP(
#     modules={
#         "input": InputComponent(),
#         "table_retriever": obj_retriever,
#         "table_output_parser": table_parser_component,
#         "text2sql_prompt": text2sql_prompt,
#         "text2sql_llm": llm,
#         "sql_output_parser": sql_parser_component,
#         "sql_retriever": sql_retriever,
#         "response_synthesis_prompt": response_synthesis_prompt,
#         "response_synthesis_llm": llm,
#     },
#     verbose=True,
# )

# qp.add_chain(["input", "table_retriever", "table_output_parser"])
# qp.add_link("input", "text2sql_prompt", dest_key="query_str")
# qp.add_link("table_output_parser", "text2sql_prompt", dest_key="schema")
# qp.add_chain(
#     ["text2sql_prompt", "text2sql_llm", "sql_output_parser", "sql_retriever"]
# )
# qp.add_link(
#     "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
# )
# qp.add_link(
#     "sql_retriever", "response_synthesis_prompt", dest_key="context_str"
# )
# qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
# qp.add_link("response_synthesis_prompt", "response_synthesis_llm")

In [None]:
# response = qp.run(
#     query="look for the word cashback or cash or discount"
# )
# print(str(response))

In [None]:
from llama_index.core import VectorStoreIndex, load_index_from_storage
from sqlalchemy import text
from llama_index.core.schema import TextNode
from llama_index.core import StorageContext
import os
from pathlib import Path
from typing import Dict


def index_all_tables(
    sql_database: SQLDatabase, table_index_dir: str = "table_index_dir"
) -> Dict[str, VectorStoreIndex]:
    """Build or load one row-level vector index per usable table.

    Args:
        sql_database: Database wrapper; its engine is used to read rows and
            ``get_usable_table_names()`` decides which tables are indexed.
        table_index_dir: Directory under which each table's index is persisted
            in a sub-directory named after the table. Created if missing.

    Returns:
        Mapping from table name to its ``VectorStoreIndex``.

    When ``<table_index_dir>/<table_name>`` already exists the index is loaded
    from it; otherwise every row of the table is fetched, each row tuple is
    stringified into a ``TextNode``, and the resulting index is persisted there.
    """
    if not Path(table_index_dir).exists():
        os.makedirs(table_index_dir)

    db_engine = sql_database.engine
    indexes_by_table = {}
    for table_name in sql_database.get_usable_table_names():
        print(f"Indexing rows in table: {table_name}")
        persist_dir = f"{table_index_dir}/{table_name}"

        if os.path.exists(persist_dir):
            # A persisted index exists for this table: reload it from disk.
            table_index = load_index_from_storage(
                StorageContext.from_defaults(persist_dir=persist_dir),
                index_id="vector_index",
            )
        else:
            # No persisted index yet: read every row of the table.
            with db_engine.connect() as connection:
                fetched = connection.execute(
                    text(f'SELECT * FROM "{table_name}"')
                ).fetchall()
                row_tuples = [tuple(record) for record in fetched]

            # One text node per row, holding the stringified row tuple.
            row_nodes = [TextNode(text=str(row_tuple)) for row_tuple in row_tuples]

            # Build the vector index (uses the default embedding model).
            table_index = VectorStoreIndex(row_nodes)

            # Persist under a fixed index id so it can be reloaded next run.
            table_index.set_index_id("vector_index")
            table_index.storage_context.persist(persist_dir)

        indexes_by_table[table_name] = table_index

    return indexes_by_table


# Build (or reload from the default "table_index_dir" directory) one vector
# index per usable table; keyed by table name.
vector_index_dict = index_all_tables(sql_database)

Indexing rows in table: card_prod
Indexing rows in table: card_prod_fetr


In [None]:
from llama_index.core.retrievers import SQLRetriever
from typing import List
from llama_index.core.query_pipeline import FnComponent

# Retriever bound to sql_database; registered below as the "sql_retriever"
# pipeline module. This repeats the identical assignment made in an earlier cell.
sql_retriever = SQLRetriever(sql_database)


def get_table_context_and_rows_str(
    query_str: str, table_schema_objs: List[SQLTableSchema]
):
    """Build the schema-context string, enriched with example rows per table.

    For each schema object the output contains, in order:
      * the table info from ``sql_database.get_single_table_info``;
      * the table description (when ``context_str`` is non-empty);
      * up to two rows retrieved from that table's vector index for
        ``query_str``, one per line, under a short heading.

    The per-table strings are joined with a blank line between them.
    """
    rendered_tables = []
    for schema_obj in table_schema_objs:
        name = schema_obj.table_name

        # Table info first, then the optional human-written description.
        pieces = [sql_database.get_single_table_info(name)]
        if schema_obj.context_str:
            pieces.append(" The table description is: " + schema_obj.context_str)

        # Look up the rows of this table most similar to the question.
        row_hits = vector_index_dict[name].as_retriever(
            similarity_top_k=2
        ).retrieve(query_str)
        if row_hits:
            pieces.append(
                "\nHere are some relevant example rows (values in the same order as columns above)\n"
            )
            pieces.extend(str(hit.get_content()) + "\n" for hit in row_hits)

        rendered_tables.append("".join(pieces))

    return "\n\n".join(rendered_tables)


# Rebind table_parser_component to the rows-aware context builder; this
# replaces the component created earlier from get_table_context_str.
table_parser_component = FnComponent(fn=get_table_context_and_rows_str)

In [None]:
from llama_index.core.query_pipeline import (
    QueryPipeline as QP,
    Link,
    InputComponent,
    CustomQueryComponent,
)

# Assemble the text-to-SQL query pipeline. Each key names a module that the
# add_link / add_chain calls below refer to. The same llm object is registered
# twice: once for SQL generation and once for answer synthesis.
qp = QP(
    modules={
        "input": InputComponent(),
        "table_retriever": obj_retriever,
        "table_output_parser": table_parser_component,
        "text2sql_prompt": text2sql_prompt,
        "text2sql_llm": llm,
        "sql_output_parser": sql_parser_component,
        "sql_retriever": sql_retriever,
        "response_synthesis_prompt": response_synthesis_prompt,
        "response_synthesis_llm": llm,
    },
    verbose=True,
)

# The user query feeds the table retriever ...
qp.add_link("input", "table_retriever")
# ... and is also passed to the context builder as its query_str argument,
# alongside the retrieved schema objects (table_schema_objs).
qp.add_link("input", "table_output_parser", dest_key="query_str")
qp.add_link(
    "table_retriever", "table_output_parser", dest_key="table_schema_objs"
)
# The text-to-SQL prompt receives the question and the built schema context.
qp.add_link("input", "text2sql_prompt", dest_key="query_str")
qp.add_link("table_output_parser", "text2sql_prompt", dest_key="schema")
# Prompt -> LLM -> SQL extraction -> SQL retriever, linked in sequence.
qp.add_chain(
    ["text2sql_prompt", "text2sql_llm", "sql_output_parser", "sql_retriever"]
)
# The synthesis prompt is filled with the parsed SQL, the SQL retriever's
# output, and the original question.
qp.add_link(
    "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
)
qp.add_link(
    "sql_retriever", "response_synthesis_prompt", dest_key="context_str"
)
qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
# Final step: the synthesis LLM turns the filled prompt into the answer.
qp.add_link("response_synthesis_prompt", "response_synthesis_llm")

In [None]:
# Run the pipeline end to end for a sample question and print the answer.
response = qp.run(query="show cashback option or retail discount")
print(response)

[1;3;38;2;155;135;227m> Running module input with input: 
query: show cashback option or retail discount

[0m[1;3;38;2;155;135;227m> Running module table_retriever with input: 
input: show cashback option or retail discount

[0m[1;3;38;2;155;135;227m> Running module table_output_parser with input: 
query_str: show cashback option or retail discount
table_schema_objs: [SQLTableSchema(table_name='card_prod', context_str='List of card products, contains product ID and customer-facing product name.'), SQLTableSchema(table_name='card_prod_fetr', context_str='List of pr...

[0m[1;3;38;2;155;135;227m> Running module text2sql_prompt with input: 
query_str: show cashback option or retail discount
schema: Table 'card_prod' has columns: card_prod_id (VARCHAR(3)), cust_face_prod_nm (VARCHAR(45)), and foreign keys: . The table description is: List of card products, contains product ID and customer-facing ...

[0m[1;3;38;2;155;135;227m> Running module text2sql_llm with input: 
messages: Giv