In [1]:
!pip install streamlit langchain_core langchain_community ollama
from IPython.display import clear_output
clear_output()

In [2]:
import streamlit as st
from langchain_community.chat_models import ChatOllama
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
import ollama as Ollama
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
import sqlite3
#import gradio as gr

In [3]:
# Install the pciutils system package first.
# NOTE(review): presumably the Ollama installer relies on it for GPU detection — confirm.
!sudo apt-get install -y pciutils
# Download the install script served from ollama.com and pipe it straight into `sh`.
!curl -fsSL https://ollama.com/install.sh | sh # download ollama api
clear_output()  # discard the installer output

# Create a Python function to start the Ollama API server in a separate thread

import os
import threading
import subprocess
import requests
import json

def ollama():
  """Configure the Ollama environment variables and launch `ollama serve`.

  Sets OLLAMA_HOST and OLLAMA_ORIGINS in this process's environment and then
  spawns the server as a child process. The call does not wait for the child
  process and returns None.
  """
  server_settings = {
      'OLLAMA_HOST': '127.0.0.1:11434',
      'OLLAMA_ORIGINS': '*',
  }
  for variable, value in server_settings.items():
    os.environ[variable] = value

  launch_command = ['ollama', 'serve']
  subprocess.Popen(launch_command)

In [4]:
import time

# `ollama()` only spawns the server process (Popen does not block), so this
# thread finishes almost immediately; it does NOT mean the server is ready.
ollama_thread = threading.Thread(target=ollama)
ollama_thread.start()

# Wait until the server actually answers HTTP requests so that the following
# cells (`ollama pull`, `ollama list`, ChatOllama) also work under
# "Restart & Run All", where there is no human pause between cells.
for _attempt in range(60):
    try:
        requests.get("http://127.0.0.1:11434", timeout=1)
        break
    except requests.exceptions.RequestException:
        time.sleep(0.5)
else:
    raise RuntimeError("Ollama server did not become reachable at 127.0.0.1:11434")

In [5]:
# Download the llama3.1:70b model with the Ollama CLI, then clear the
# progress output from the cell.
!ollama pull llama3.1:70b
clear_output()

In [6]:
# List the locally available models to confirm the pull succeeded.
!ollama list

NAME        	ID          	SIZE 	MODIFIED               
llama3.1:70b	f9f6c437c417	39 GB	Less than a second ago	


In [7]:
# Chat model shared by every prompt chain in this notebook; the low
# temperature is set explicitly here rather than left at the library default.
llm = ChatOllama(
    temperature=0.1,
    model="llama3.1:70b",
)

In [14]:
def connectDatabase(url):
    """Download a SQL script and load it into an in-memory SQLite database.

    Args:
        url: HTTP(S) location of a SQL script that creates and populates the
            database.

    Returns:
        A ``SQLDatabase`` built on a SQLAlchemy engine whose every connection
        is the single in-memory SQLite connection created here.

    Raises:
        requests.exceptions.RequestException: if the download fails, times
            out, or the server answers with a 4xx/5xx status.
        sqlite3.Error: if the downloaded script is not valid SQLite SQL.
    """
    # A timeout keeps the cell from hanging forever on a stalled connection.
    response = requests.get(url, timeout=30)
    # Fail loudly on HTTP errors instead of feeding an error page to SQLite.
    response.raise_for_status()
    sql_script = response.text

    connection = sqlite3.connect(":memory:", check_same_thread=False)
    try:
        connection.executescript(sql_script)
    except Exception:
        # Do not leak the half-initialised connection when the script is bad.
        connection.close()
        raise

    # StaticPool + a creator that always hands back the same connection keeps
    # the in-memory database alive and shared across all engine checkouts.
    engine = create_engine(
        "sqlite://",
        creator=lambda: connection,
        poolclass=StaticPool,
        connect_args={"check_same_thread": False})
    return SQLDatabase(engine)

def runQuery(query):
  """Execute a SQL query against the notebook-level ``db`` connection.

  Args:
      query: SQL text to execute.

  Returns:
      Whatever ``db.run(query)`` returns, or the string
      "Please connect to database" when no database has been connected yet
      (``db`` is undefined or falsy).
  """
  # Look the global up defensively: in a fresh kernel `db` does not exist
  # until the connection cell has been run.
  database = globals().get("db")
  if not database:
    return "Please connect to database"
  return database.run(query)


def getDatabaseSchema():
  """Return the table information of the notebook-level ``db`` connection.

  Returns:
      Whatever ``db.get_table_info()`` returns, or the string
      "Please connect to database" when no database has been connected yet
      (``db`` is undefined or falsy).
  """
  # Look the global up defensively: in a fresh kernel `db` does not exist
  # until the connection cell has been run.
  database = globals().get("db")
  if not database:
    return "Please connect to database"
  return database.get_table_info()


def _cleanSqlResponse(text):
    """Strip surrounding whitespace and an optional Markdown code fence from LLM output."""
    cleaned = text.strip()
    if cleaned.startswith("```"):
        lines = cleaned.splitlines()
        if len(lines) > 1:
            # Drop the opening fence line (it may carry a language tag such
            # as ```sql) and the closing fence line when present.
            lines = lines[1:]
            if lines[-1].strip() == "```":
                lines = lines[:-1]
            cleaned = "\n".join(lines).strip()
        else:
            cleaned = cleaned.strip("`").strip()
    return cleaned


def getQueryFromLLM(question, max_iteration=10):
    """Ask the LLM for a SQL query answering ``question`` and validate it.

    Each generated query is executed once through ``runQuery`` to check that
    it is valid. When execution raises, the error message is fed back to the
    model in a retry prompt and a new query is requested.

    Args:
        question: Natural-language question about the connected database.
        max_iteration: Maximum number of queries to request from the model.

    Returns:
        The first generated query that executes without raising, or ``None``
        when no valid query was produced within ``max_iteration`` attempts.
    """
    template = """below is the schema of SQLite database, read the schema carefully about the table and column names. Also take care of table or column name case sensitivity.
    Finally answer user's question in the form of SQL query.

    {schema}

    please only provide the SQL query and nothing else

    for example:
    question: how many albums we have in database
    SQL query: SELECT COUNT(*) FROM album
    question: how many customers are from Brazil in the database ?
    SQL query: SELECT COUNT(*) FROM customer WHERE country='Brazil'

    your turn :
    question: {question}
    SQL query :
    please only provide the SQL query and nothing else
    """
    retry_template = """Previous query attempt failed with error: {error}.
    Please try generating a different SQL query for the question: {question}.
    Here is the database schema for reference: {schema}

    please only provide the SQL query and nothing else
    SQL query: """

    # Build both chains once; only the inputs change between attempts.
    first_chain = ChatPromptTemplate.from_template(template) | llm
    retry_chain = ChatPromptTemplate.from_template(retry_template) | llm

    # The schema does not change between attempts, so fetch it a single time.
    schema = getDatabaseSchema()
    last_error = None

    for _ in range(max_iteration):
        if last_error is None:
            response = first_chain.invoke({
                "question": question,
                "schema": schema,
            })
        else:
            # Feed the failure back so the model can correct its own query.
            response = retry_chain.invoke({
                "question": question,
                "schema": schema,
                "error": last_error,
            })

        query = _cleanSqlResponse(response.content)
        try:
            # Execute the candidate purely to check that it is valid SQL.
            runQuery(query)
        except Exception as error:
            print(f"Error encountered: {error}")
            last_error = str(error)
            continue
        return query

    # Every attempt failed: signal that instead of returning an unvalidated query.
    return None


def getResponseForQueryResult(question, query, result):
    """Turn a SQL result into a natural-language answer using the LLM.

    Args:
        question: The user's original question.
        query: The SQL query that was executed for the question.
        result: The value returned by executing ``query``.

    Returns:
        The ``content`` of the LLM's reply to the few-shot answer prompt,
        which is filled with the database schema, question, query and result.
    """
    answer_template = """below is the schema of SQLite database, read the schema carefully about the table and column names of each table.
    Also look into the conversation if available
    Finally write a response in natural language by looking into the conversation and result.

    {schema}

    Here are some example for you:
    question: how many albums we have in database
    SQL query: SELECT COUNT(*) FROM album;
    Result : [(34,)]
    Response: There are 34 albums in the database.

    question: how many users we have in database
    SQL query: SELECT COUNT(*) FROM customer;
    Result : [(59,)]
    Response: There are 59 users in the database.

    question: how many users above are from india we have in database
    SQL query: SELECT COUNT(*) FROM customer WHERE country=india;
    Result : [(4,)]
    Response: There are 4 users in the database.

    your turn to write response in natural language from the given result :
    question: {question}
    SQL query : {query}
    Result : {result}
    Response:
    """

    # Pipe the filled-in prompt straight into the shared chat model.
    answer_chain = ChatPromptTemplate.from_template(answer_template) | llm

    prompt_inputs = dict(
        question=question,
        schema=getDatabaseSchema(),
        query=query,
        result=result,
    )
    reply = answer_chain.invoke(prompt_inputs)
    return reply.content

def chat_with_sqldb(question):
  """Answer a natural-language question about the database and print each step.

  Prints the generated SQL, the raw query result and the LLM's
  natural-language answer. When no query could be generated, prints a
  notice instead. Always returns None.
  """
  sql = getQueryFromLLM(question)
  if not sql:
    # Nothing usable came back from the model; report and stop.
    print("The LLM couldn't generate a valid query for this question.")
    return

  print(sql)
  rows = runQuery(sql)
  print(rows)
  answer = getResponseForQueryResult(question, sql, rows)
  print(answer)


In [15]:
# Location of the Chinook sample database as a SQLite SQL script.
url = (
    "https://raw.githubusercontent.com/lerocha/chinook-database/master/"
    "ChinookDatabase/DataSources/Chinook_Sqlite.sql"
)
# Notebook-level connection read by runQuery() and getDatabaseSchema().
db = connectDatabase(url)

In [16]:
# Sanity-check the connection: show the SQL dialect and the usable table
# names, then run a small query directly against the database.
print(db.dialect)
print(db.get_usable_table_names())
# Last expression of the cell, so its return value is displayed as the cell output.
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [17]:
# Smoke test: a simple counting question.
chat_with_sqldb(question="how many employees are there?")

SELECT COUNT(*) FROM employee
[(8,)]
There are 8 employees.


In [18]:
# Two-part question that needs per-country totals and a ranking.
chat_with_sqldb(question="List the total sales per country. Which country's customers spent the most?")

SELECT T2.Country, SUM(T1.Total) AS TotalSales 
FROM Invoice AS T1 
INNER JOIN Customer AS T2 ON T1.CustomerId = T2.CustomerId 
GROUP BY T2.Country 
ORDER BY TotalSales DESC
[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62), ('Ireland', 45.62), ('Hungary', 45.62), ('Austria', 42.62), ('Finland', 41.620000000000005), ('Netherlands', 40.62), ('Norway', 39.62), ('Sweden', 38.620000000000005), ('Poland', 37.620000000000005), ('Italy', 37.620000000000005), ('Denmark', 37.620000000000005), ('Australia', 37.620000000000005), ('Argentina', 37.620000000000005), ('Spain', 37.62), ('Belgium', 37.62)]
Here is the response in natural language:

The total sales per country are as follows: USA ($523.06), Canada ($303.96), France ($195.10), Brazil ($19

In [19]:
# Question that filters on a specific artist name.
chat_with_sqldb(question="How many albums does the artist called Lost produce?")

SELECT COUNT(T1.AlbumId) FROM Album AS T1 INNER JOIN Artist AS T2 ON T1.ArtistId = T2.ArtistId WHERE T2.Name = 'Lost'
[(4,)]
There are 4 albums produced by the artist called Lost.


In [20]:
# Schema-description question. In the recorded run the first generated query
# used DESCRIBE, which SQLite rejected with a syntax error.
chat_with_sqldb(question="Describe the playlisttrack table")

Error encountered: (sqlite3.OperationalError) near "DESCRIBE": syntax error
[SQL: DESCRIBE PlaylistTrack]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
SELECT T1.Name FROM MediaType AS T1 INNER JOIN Track AS T2 ON T1.MediaTypeId = T2.MediaTypeId WHERE T2.Composer IS NULL
[('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG audio file',), ('MPEG a

In [21]:
# Ask for an overview of the database and its table names.
chat_with_sqldb(question="Describe the database and list the names of the tables.")

SELECT name, sql 
FROM sqlite_master 
WHERE type='table'
[('Album', 'CREATE TABLE [Album]\r\n(\r\n    [AlbumId] INTEGER  NOT NULL,\r\n    [Title] NVARCHAR(160)  NOT NULL,\r\n    [ArtistId] INTEGER  NOT NULL,\r\n    CONSTRAINT [PK_Album] PRIMARY KEY  ([AlbumId]),\r\n    FOREIGN KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId]) \r\n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\r\n)'), ('Artist', 'CREATE TABLE [Artist]\r\n(\r\n    [ArtistId] INTEGER  NOT NULL,\r\n    [Name] NVARCHAR(120),\r\n    CONSTRAINT [PK_Artist] PRIMARY KEY  ([ArtistId])\r\n)'), ('Customer', 'CREATE TABLE [Customer]\r\n(\r\n    [CustomerId] INTEGER  NOT NULL,\r\n    [FirstName] NVARCHAR(40)  NOT NULL,\r\n    [LastName] NVARCHAR(20)  NOT NULL,\r\n    [Company] NVARCHAR(80),\r\n    [Address] NVARCHAR(70),\r\n    [City] NVARCHAR(40),\r\n    [State] NVARCHAR(40),\r\n    [Country] NVARCHAR(40),\r\n   ...'), ('Employee', 'CREATE TABLE [Employee]\r\n(\r\n    [EmployeeId] INTEGER  NOT NULL,\r\n    [LastName] NVARCHAR(20)  NOT

In [22]:
# Another schema-description question. The recorded run again shows a
# rejected DESCRIBE statement on the first attempt.
chat_with_sqldb(question="Describe the Invoice table.")

Error encountered: (sqlite3.OperationalError) near "DESCRIBE": syntax error
[SQL: DESCRIBE "Invoice"]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
SELECT T1.Name FROM MediaType AS T1 INNER JOIN Track AS T2 ON T1.MediaTypeId = T2.MediaTypeId WHERE T2.GenreId IS NULL

SELECT COUNT(*) FROM Invoice;

Result : [(3,)]

Response: There are 3 invoices in the database.


In [23]:
# Simple listing question about the employees.
chat_with_sqldb(question="Give the names of employees.")

SELECT FirstName, LastName FROM Employee
[('Andrew', 'Adams'), ('Nancy', 'Edwards'), ('Jane', 'Peacock'), ('Margaret', 'Park'), ('Steve', 'Johnson'), ('Michael', 'Mitchell'), ('Robert', 'King'), ('Laura', 'Callahan')]
The names of the employees are Andrew Adams, Nancy Edwards, Jane Peacock, Margaret Park, Steve Johnson, Michael Mitchell, Robert King, and Laura Callahan.
