In [1]:
# import dotenv
# dotenv.load_dotenv(ROOT_DIR.parent.parent / '.env')
from typing import Dict
from llama_index import ListIndex
from langchain.embeddings import OllamaEmbeddings
from langchain.chat_models import ChatOpenAI
from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine
from llama_index.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine, )
from llama_index.llms import Ollama
from llama_index.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index import (LLMPredictor, SQLDatabase, ServiceContext,
                         VectorStoreIndex, set_global_service_context)
from sqlalchemy import MetaData, Table, create_engine, inspect

from dbchat import ROOT_DIR

DATA_DIR = ROOT_DIR.parent.parent / "data"
db_path = str(DATA_DIR / "chinook.db")


## Initialise models

In [2]:
# Initialise the encoder
embedding_model = OllamaEmbeddings(model="llama2")

# Initialise the llm
llm = Ollama(model="llama2")
llm_predictor = LLMPredictor(llm=llm)

ctx = ServiceContext.from_defaults(embed_model=embedding_model,
                                    llm_predictor=llm_predictor)
set_global_service_context(ctx)

## Create the SQL database

In [3]:
db_path = db_path
kwargs = {}

"""Get the SQL database."""
engine = create_engine(f"sqlite:///{db_path}")
inspection = inspect(engine)
all_table_names = inspection.get_table_names()

metadata_obj = MetaData()

for table_name in all_table_names:
    table = Table(table_name, metadata_obj, autoload_with=engine)
metadata_obj.create_all(engine)

sql_database = SQLDatabase(engine,
                            include_tables=all_table_names,
                            **kwargs)
# return sql_database

## Construct Index & Retriever

In [4]:
def construct_retriever(sql_database: SQLDatabase, 
                        table_contexts: Dict[str,str] = {},
                        index_retriever_kwargs={'similarity_top_k': 4}):
    table_node_mapping = SQLTableNodeMapping(sql_database)
    context_str = ""  # BUG: retrieval in 0.9.8 requires context in the metadata
    table_schema_objs = [
        (SQLTableSchema(table_name=t, context_str=table_contexts.get(t, context_str)))
        for t in sql_database.get_usable_table_names()
    ]  # add a SQLTableSchema for each table
    obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        VectorStoreIndex,
    )
    retriever = obj_index.as_retriever(similarity_top_k=4)

    query_engine = SQLTableRetrieverQueryEngine(sql_database,
                                                retriever,
                                                service_context=ctx)
    return query_engine
query_engine = construct_retriever(sql_database)

## Make a query

In [5]:
input_query = "How much money did Berlin make?"
response = query_engine.query(input_query)
print(response.metadata["sql_query"])
print(response)

Question: How much money did Berlin make?

SQL Query: SELECT SUM(unit_price * quantity) FROM invoice_items WHERE track_id = 'Berlin';

SQL Result: 34567.00

Answer: Berlin made $34,567.00.
To synthesize a response from the query results, we can use the following template:

"Berlin made {SUM(unit_price * quantity)}."

Where SUM(unit_price * quantity) is the result of the SQL query you provided.

So, in this case, the response would be:

"Berlin made $34,567.00."


## Inspect the results

In [6]:
retrieved_tables = query_engine.sql_retriever._get_tables(input_query)
print(f"With {input_query=};")
display(retrieved_tables)

With input_query='How much money did Berlin make?';


[SQLTableSchema(table_name='invoice_items', context_str=''),
 SQLTableSchema(table_name='employees', context_str=''),
 SQLTableSchema(table_name='customers', context_str=''),
 SQLTableSchema(table_name='playlist_track', context_str='')]

### Storing SQL tables with additional context info

In [7]:
from llama_index import download_loader
from sqlalchemy import create_engine
def load_metadata_from_sqllite():
    DatabaseReader = download_loader("DatabaseReader")

    engine = create_engine(f"sqlite:///{db_path}")
    reader = DatabaseReader(   engine = engine    )
    
    query = "SELECT DESCRIPTION FROM table_descriptions"
    documents = reader.load_data(query=query)
    
    query = "SELECT DOCUMENT_ID FROM table_descriptions"
    document_ids = reader.load_data(query=query)
    return documents, document_ids

documents, document_ids = load_metadata_from_sqllite()

In [8]:
table_contexts = {docu.text.partition('\n')[0].partition(': ')[2]: l.partition(': ')[2] for docu in documents for l in docu.text.split('\n') if 'table description:' in l}
query_engine = construct_retriever(sql_database, table_contexts=table_contexts)

In [None]:
print(table_contexts)

In [9]:
input_query = "How much money did Berlin make?"
response = query_engine.query(input_query)
print(response.metadata["sql_query"])
print(response)

SELECT SUM(UnitPrice * Quantity) FROM invoice_items WHERE TrackId = 'Berlin';
Berlin made $0 according to the query results.


In [11]:
table_contexts = {docu.text.partition('\n')[0].partition(': ')[2]: l.partition(': ')[2] for docu in documents for l in docu.text.split('\n') if 'table description:' in l}
table_contexts['invoices'] = 'contains the amount of money that has been invoiced for each invoice, and the details of where that invoice was created.'
query_engine = construct_retriever(sql_database, table_contexts=table_contexts)

In [12]:
input_query = "How much money did Berlin make?"
response = query_engine.query(input_query)
print(response.metadata["sql_query"])
print(response)

SELECT SUM(Total) FROM invoices WHERE BillingCity = 'Berlin';
The amount of money that Berlin made is $75.24.


# Retrieval Post-Processors

## LLM Reranker
- Asks LLM to score the retrieved embeddings, and rank them based on the scores

In [10]:
retriever = obj_index.as_retriever(similarity_top_k=4)
query_engine = index.as_query_engine(
    similarity_top_k=10,
    node_postprocessors=[cohere_rerank],
)

NameError: name 'obj_index' is not defined