Adapted for open-source & private AI from https://python.langchain.com/docs/how_to/sql_large_db/

In [1]:
import os
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
# from langchain_ollama.llms import OllamaLLM
from langchain_ollama.chat_models import ChatOllama
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import OpenSearchVectorSearch

from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field

# Setup Database and Ollama Connections

In [2]:
OLLAMA_BASE_URL = "http://host.docker.internal:11434"
OLLAMA_MODEL_ID = "llama3.1"
OLLAMA_EMBEDDING_MODEL_ID = "nomic-embed-text"

SQLITE_DATABASE_URI = "sqlite:///Chinook.db"

OPENSEARCH_URL = "http://host.docker.internal:9200"

In [3]:
db = SQLDatabase.from_uri(SQLITE_DATABASE_URI)
print(f"Available tables are {db.get_usable_table_names()}")

Available tables are ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


In [4]:
# Test db connection
db.run("SELECT * FROM Artist LIMIT 10;")

"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [5]:
llm = ChatOllama(model=OLLAMA_MODEL_ID, temperature=0, base_url=OLLAMA_BASE_URL, verbose=True)

# Initial Example - Use LLM to Choose Relevant Tables Based on User Query

In [6]:
class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of the table in SQL database.")

table_names = "\n".join(db.get_usable_table_names())
system = f"""Below are the available table names:

{table_names}

Return the names of ALL the SQL tables that MIGHT be relevant to the user question. Respond with a list of Table objects.
Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)
llm_with_tools = llm.bind_tools([Table])
output_parser = PydanticToolsParser(tools=[Table])

table_chain = prompt | llm_with_tools | output_parser

table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Genre'),
 Table(name='Track'),
 Table(name='Artist'),
 Table(name='Album')]

# Simplify Model Table Selection By Grouping Tables and Asking LLM To Choose Category (Schema)

In [7]:
system = """Return the names of any SQL tables that are relevant to the user question.
The tables are:

Music
Business
"""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)

category_chain = prompt | llm_with_tools | output_parser
category_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Music')]

In [8]:
from typing import List


def get_tables(categories: list[Table]) -> List[str]:
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables


table_chain = category_chain | get_tables
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs?"})

['Album', 'Artist', 'Genre', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']

# Create a Full SQL Query Chain, Including Passing Relevant Table Names During Prompt Generation, To Generate SQL Statements

In [9]:
from operator import itemgetter

from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough

def postprocessor(text:str)->str:
    processed = text.replace("`","")
    return processed

query_chain = create_sql_query_chain(llm, db)
# Convert "question" key to the "input" key expected by current table_chain.
table_chain = {"input": itemgetter("question")} | table_chain
# Set table_names_to_use using table_chain.
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain | postprocessor

In [10]:
query = full_chain.invoke(
    {"question": "What are all the genres of Alanis Morisette songs"}
)
print(query)

Since there is no information about Alanis Morissette in the provided tables, I will assume that you want to know the genres of tracks from a specific artist.

However, since we don't have any data about Alanis Morissette's songs in our database, I'll provide an answer based on the available data. Let's say we want to find the genres of AC/DC's songs instead.

Question: What are all the genres of Alanis Morisette songs
SQLQuery: SELECT T.GenreId, G.Name FROM Track AS T INNER JOIN Genre AS G ON T.GenreId = G.GenreId WHERE T.AlbumId IN (SELECT A.AlbumId FROM Album AS A INNER JOIN Artist AS Ar ON A.ArtistId = Ar.ArtistId WHERE Ar.Name = 'AC/DC')


# Handle High Cardinality Columns

In [12]:
import ast
import re


def query_as_list(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return res


proper_nouns = query_as_list(db, "SELECT Name FROM Artist")
proper_nouns += query_as_list(db, "SELECT Title FROM Album")
proper_nouns += query_as_list(db, "SELECT Name FROM Genre")
proper_nouns = list(set(proper_nouns))
len(proper_nouns)
proper_nouns[:5]

['Academy of St. Martin in the Fields & Sir Neville Marriner',
 'Live At Donington  (Disc )',
 'Machine Head',
 'Comedy',
 'Xis']

In [13]:
embedding = OllamaEmbeddings(model=OLLAMA_EMBEDDING_MODEL_ID, base_url=OLLAMA_BASE_URL)

vector_db = OpenSearchVectorSearch.from_texts(
    texts = proper_nouns, 
    embedding = embedding,
    opensearch_url=OPENSEARCH_URL,
    bulk_size=10000
    )
retriever = vector_db.as_retriever(search_kwargs={"k": 15})

ImportError: Could not import OpenSearch. Please install it with `pip install opensearch-py`.

In [None]:
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

system = """You are a SQLite expert. Given an input question, create a syntactically
correct SQLite query to run. Unless otherwise specificed, do not return more than
{top_k} rows.

Only return the SQL query with no markup or explanation.

Here is the relevant table info: {table_info}

Here is a non-exhaustive list of possible feature values. If your query includes a 'WHERE' clause, make sure to check the spelling of the predicate value against this list of proper noun spellings first:

{proper_nouns}
"""

prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{input}")])

table_chain = category_chain | get_tables
table_chain = {"input": itemgetter("question")} | table_chain

query_chain = create_sql_query_chain(llm, db, prompt=prompt)
retriever_chain = (
    itemgetter("question")
    | retriever
    | (lambda docs: "\n".join(doc.page_content for doc in docs))
)
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain | postprocessor
chain = RunnablePassthrough.assign(proper_nouns=retriever_chain) | full_chain

In [27]:
## Example without using retrieval
query = full_chain.invoke(
    {"question": "What are all the genres of Alania Morisqete songs", "proper_nouns": ""}
)
print(query)
db.run(query)

SELECT T3.Name FROM Album AS T1 JOIN Artist AS T2 ON T1.ArtistId = T2.ArtistId JOIN Track AS T3 ON T1.AlbumId = T3.AlbumId WHERE T2.Name = 'Alania Morisqete' LIMIT 5


''

In [30]:
# With retrieval
query = chain.invoke({"question": "What are all the genres of Alania Morisqete songs"})
print(query)
db.run(query)

SELECT T3.Name FROM Track AS T1 JOIN Album AS T2 ON T1.AlbumId = T2.AlbumId JOIN Genre AS T3 ON T1.GenreId = T3.GenreId WHERE T2.ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alanis Morissette') LIMIT 5


"[('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',)]"