In [1]:
import os
import sqlite3
from uuid import uuid4

from langchain_community.utilities import SQLDatabase
from langchain_ollama import ChatOllama
from langchain_ollama.embeddings import OllamaEmbeddings
from pydantic import BaseModel, Field

In [2]:
# Connection settings. Each one can be overridden by an environment variable of
# the same name. The defaults point at `host.docker.internal`, a Docker-provided
# hostname, so they presume this kernel runs inside a container.
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
OLLAMA_MODEL_ID = os.environ.get("OLLAMA_MODEL_ID", "qwen2.5")
OLLAMA_EMBEDDING_MODEL_ID = os.environ.get("OLLAMA_EMBEDDING_MODEL_ID", "nomic-embed-text")

# SQLAlchemy-style URI for the Chinook sample database (relative path).
SQLITE_DATABASE_URI = os.environ.get("SQLITE_DATABASE_URI", "sqlite:///Chinook.db")

OPENSEARCH_URL = os.environ.get("OPENSEARCH_URL", "http://host.docker.internal:9200")

In [3]:
# Open the database through LangChain's SQLDatabase wrapper and report which
# tables it will expose to the query generator.
db = SQLDatabase.from_uri(SQLITE_DATABASE_URI)
usable_tables = db.get_usable_table_names()
print(f"Available tables are {usable_tables}")

Available tables are ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


In [4]:
# Smoke-test the connection by fetching ten rows from Artist.
sample_artists = db.run("SELECT * FROM Artist LIMIT 10;")
sample_artists

"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [5]:
# Chat model used to generate SQL; temperature 0 keeps responses as repeatable as possible.
llm = ChatOllama(
    model=OLLAMA_MODEL_ID,
    base_url=OLLAMA_BASE_URL,
    temperature=0,
    verbose=True,
)

In [6]:
def get_table_names() -> str:
    """Return the usable table names of ``db`` as a newline-separated string."""
    return "\n".join(db.get_usable_table_names())

In [12]:
# Structured-output schema for the LLM's reply. The class docstring and the
# field description are presumably forwarded to the model by
# `with_structured_output` as part of the schema, so treat them as prompt text.
class Query(BaseModel):
    """Query response object"""
    query: str = Field(..., description="The query to be executed against the database")

def generate_query(input_text: str, num_attempts_at_valid_query: int = 5, return_results: bool = True):
    """Ask the LLM for a SQLite query that answers ``input_text``, retrying on errors.

    The model is shown the usable table names from ``db`` and must reply with a
    ``Query`` object. Each candidate is executed with ``db.run``. If execution
    raises, the error text is sent back to the model as a user turn and a new
    query is requested.

    Args:
        input_text: The user's natural-language question.
        num_attempts_at_valid_query: Maximum number of LLM calls; must be >= 1.
        return_results: If True, also return the result string from ``db.run``.

    Returns:
        ``(query_text, results)`` when ``return_results`` is True, otherwise
        just ``query_text``.

    Raises:
        ValueError: If ``num_attempts_at_valid_query`` is less than 1.
        RuntimeError: If no attempt produced a query that executed successfully
            (chained from the last error seen).
    """
    if num_attempts_at_valid_query < 1:
        # Without this guard the loop never runs and the function would fall off
        # the end returning None, breaking callers that unpack the result.
        raise ValueError(
            f"num_attempts_at_valid_query must be >= 1, got {num_attempts_at_valid_query}"
        )

    structured_llm = llm.with_structured_output(schema=Query)
    messages = [{
        "role": "user",
        "content": f"""\
Here are the available table names:
{get_table_names()}

Generate a SQL query for sqlite3 that can be executed to satisfy the users question. If an error is raised by your query, update the query to resolve the error and satisfy the users question.

{input_text}.
""",
    }]

    last_error = None
    for attempt in range(1, num_attempts_at_valid_query + 1):
        query = structured_llm.invoke(messages)
        if query is None:
            # NOTE(review): assumes an unparseable structured reply can surface as
            # None; ask again instead of crashing with AttributeError on `.query`.
            last_error = ValueError("LLM response did not contain a Query")
            messages.append({
                "role": "user",
                "content": "Your previous response did not contain a query. Reply with the SQL query.",
            })
            print(f"Attempt {attempt}/{num_attempts_at_valid_query}: no query returned")
            continue

        query_text = query.query
        messages.append({"role": "assistant", "content": query_text})
        try:
            # NOTE(review): this executes model-generated SQL verbatim. If the
            # database must not be modified, open it read-only.
            results = db.run(query_text)
        except Exception as e:
            last_error = e
            error_message = f"Error raised - {e}"
            # Feed the error back as a user turn. The assistant turn above carries
            # no tool call, so a "tool" message with an invented tool_call_id would
            # be answering a call that never happened.
            messages.append({"role": "user", "content": error_message})
            print(f"Attempt {attempt}/{num_attempts_at_valid_query} failed: {error_message}")
            continue

        if return_results:
            return query_text, results
        return query_text

    raise RuntimeError(
        f"Exhausted {num_attempts_at_valid_query} attempts to generate query"
    ) from last_error
            

In [13]:
# Scratch cell: rebuilds, for one hard-coded question, the same prompt that
# `generate_query` constructs internally. Nothing later in the visible cells
# reads `structured_llm` or `messages`. TODO: consider removing this cell.
structured_llm = llm.with_structured_output(schema=Query)
messages = [
    {
        "role": "user",
        "content": (
            "Here are the available table names:\n"
            f"{get_table_names()}\n"
            "\n"
            "Generate a SQL query for sqlite3 that can be executed to satisfy the users question. "
            "If an error is raised by your query, update the query to resolve the error and satisfy the users question.\n"
            "\n"
            "What are all the genres of Alanis Morisette songs.\n"
        ),
    }
]

In [17]:
# Run the generate-and-execute loop; unpack the SQL text and its result string.
generated_query, results = generate_query(
    "What are all the genres of Alanis Morisette songs",
    return_results=True,
)

In [18]:
# Show the SQL text that generate_query returned.
generated_query

"SELECT DISTINCT g.Name AS GenreName FROM Track t JOIN Album a ON t.AlbumId = a.AlbumId JOIN Artist ar ON a.ArtistId = ar.ArtistId JOIN Genre g ON t.GenreId = g.GenreId WHERE ar.Name = 'Alanis Morissette'"

In [19]:
# Show the result string that `db.run` produced for that query.
results

"[('Rock',)]"