In [1]:
!pip install gradio langchain_core langchain_community ollama
from IPython.display import clear_output
clear_output()

In [2]:
import gradio as gr
from langchain_community.chat_models import ChatOllama
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
import ollama as Ollama
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
import sqlite3

In [3]:
!sudo apt-get install -y pciutils
!curl -fsSL https://ollama.com/install.sh | sh # download ollama api
clear_output()

# Define a helper that starts the Ollama API server; it is launched from a separate thread

import os
import threading
import subprocess
import requests
import json

def ollama():
  """Start the Ollama API server as a child process.

  The server settings are written into this process's environment first, so
  the spawned ``ollama serve`` process inherits them. The call returns as soon
  as the process has been spawned; it does not wait for the server.
  """
  server_settings = {
      'OLLAMA_HOST': '127.0.0.1:11434',
      'OLLAMA_ORIGINS': '*',
  }
  os.environ.update(server_settings)

  serve_command = ['ollama', 'serve']
  subprocess.Popen(serve_command)

In [4]:
# Call ollama() (defined in the previous cell) on a separate thread and start
# it right away. The thread is not joined in this cell.
ollama_thread = threading.Thread(target=ollama)
ollama_thread.start()

In [5]:
!ollama pull llama3.1:70b
clear_output()

In [6]:
!ollama pull llama3.1:8b
clear_output()

In [23]:
# Global variable to store the database connection object
db = None

# Function to connect to the database
def connectDatabase(url):
    """Download a SQL script and load it into a fresh in-memory SQLite database.

    On success the module-level ``db`` is rebound to a ``SQLDatabase`` wrapping
    the new database; ``runQuery`` and ``getDatabaseSchema`` read that global.

    Args:
        url: HTTP(S) address of a plain-text SQL script (schema and data).

    Raises:
        requests.RequestException: the download failed, timed out, or the
            server answered with an error status.
        sqlite3.Error: the downloaded script could not be executed.
    """
    # Without this declaration the assignment at the end would only bind a
    # local name and the module-level ``db`` would stay ``None`` forever.
    global db

    response = requests.get(url, timeout=30)
    # Fail loudly on 4xx/5xx instead of feeding an HTML error page to SQLite.
    response.raise_for_status()
    sql_script = response.text

    # check_same_thread=False: the connection is created here but may be used
    # from whichever thread later runs a query.
    connection = sqlite3.connect(":memory:", check_same_thread=False)
    try:
        connection.executescript(sql_script)
    except Exception:
        # Do not leak the half-initialised in-memory database.
        connection.close()
        raise

    # StaticPool plus a creator that always hands back the same connection
    # keeps every SQLAlchemy checkout on the one in-memory database.
    engine = create_engine(
        "sqlite://",
        creator=lambda: connection,
        poolclass=StaticPool,
        connect_args={"check_same_thread": False})
    db = SQLDatabase(engine)


# Function to run a query on the database
def runQuery(query):
  """Execute ``query`` against the connected database.

  Returns whatever ``db.run`` returns, or a fixed reminder string when the
  module-level ``db`` is not set.
  """
  if not db:
    return "Please connect to the database first."
  query_output = db.run(query)
  return query_output


# Function to get the database schema
def getDatabaseSchema():
  """Return the table info of the connected database.

  Yields the result of ``db.get_table_info()``, or a fixed reminder string
  when the module-level ``db`` is not set.
  """
  if not db:
    return "Please connect to the database first."
  table_info = db.get_table_info()
  return table_info


def _extractSqlText(text):
    """Return ``text`` without surrounding whitespace or a wrapping Markdown code fence."""
    cleaned = text.strip()
    if len(cleaned) >= 6 and cleaned.startswith("```") and cleaned.endswith("```"):
        inner = cleaned[3:-3]
        first_line, newline, rest = inner.partition("\n")
        # Drop a language tag such as ```sql that sits on the opening fence line.
        if newline and first_line.strip().lower() in ("", "sql", "sqlite"):
            inner = rest
        cleaned = inner.strip()
    return cleaned


def getQueryFromLLM(llm, question, max_iteration=10):
    """Ask the LLM for a SQL query answering ``question`` and validate it by running it.

    Every generated query is executed through ``runQuery``. When execution
    raises, the error text is fed back to the LLM through a retry prompt and
    the new answer is validated the same way, so no unvalidated query is ever
    returned and each attempt costs exactly one LLM call.

    Args:
        llm: chat model that can be piped after a ``ChatPromptTemplate`` and
            whose responses expose ``.content``.
        question: the user's natural-language question.
        max_iteration: maximum number of generate-and-validate attempts.

    Returns:
        The first query text that executed without raising, or ``None`` when
        every attempt failed (or ``max_iteration`` is not positive).
    """
    template = """below is the schema of SQLite database, read the schema carefully about the table and column names. Also take care of table or column name case sensitivity.
    Finally answer user's question in the form of SQL query.

    {schema}

    please only provide the SQL query and nothing else

    for example:
    question: how many albums we have in database
    SQL query: SELECT COUNT(*) FROM album
    question: how many customers are from Brazil in the database ?
    SQL query: SELECT COUNT(*) FROM customer WHERE country='Brazil'

    your turn :
    question: {question}
    SQL query :
    please only provide the SQL query and nothing else
    """
    retry_template = """Previous query attempt failed with error: {error}.
                           Please try generating a different SQL query for the question: {question}.
                           Here is the database schema for reference: {schema}
                           SQL query: """

    # The schema does not change between attempts, so fetch it once.
    prompt_inputs = {
        "question": question,
        "schema": getDatabaseSchema(),
        "error": "",
    }
    chain = ChatPromptTemplate.from_template(template) | llm
    retry_chain = None  # built lazily, only if the first attempt fails

    for _ in range(max_iteration):
        response = chain.invoke(prompt_inputs)
        # Chat models frequently wrap the query in a Markdown code fence.
        candidate = _extractSqlText(response.content)
        try:
            # Attempt to execute the query to check its validity
            runQuery(candidate)
        except Exception as error:
            print(f"Error encountered: {error}")
            # Next attempt: tell the model what went wrong with this query.
            prompt_inputs["error"] = str(error)
            if retry_chain is None:
                retry_chain = ChatPromptTemplate.from_template(retry_template) | llm
            chain = retry_chain
        else:
            return candidate

    # Every attempt failed validation: signal failure instead of handing the
    # caller a query that is known (or likely) to be broken.
    return None


def getResponseForQueryResult(llm, question, query, result):
    """Turn a query result into a natural-language answer.

    Builds a few-shot prompt from the database schema, the user's question,
    the SQL query and its result, sends it to ``llm`` and returns the
    ``.content`` of the model's reply.
    """
    answer_template = """below is the schema of SQLite database, read the schema carefully about the table and column names of each table.
    Also look into the conversation if available
    Finally write a response in natural language by looking into the conversation and result.

    {schema}

    Here are some example for you:
    question: how many albums we have in database
    SQL query: SELECT COUNT(*) FROM album;
    Result : [(34,)]
    Response: There are 34 albums in the database.

    question: how many users we have in database
    SQL query: SELECT COUNT(*) FROM customer;
    Result : [(59,)]
    Response: There are 59 users in the database.

    question: how many users above are from india we have in database
    SQL query: SELECT COUNT(*) FROM customer WHERE country=india;
    Result : [(4,)]
    Response: There are 4 users in the database.

    your turn to write response in natural language from the given result :
    question: {question}
    SQL query : {query}
    Result : {result}
    Response:
    """

    answer_chain = ChatPromptTemplate.from_template(answer_template) | llm

    # Values substituted into the placeholders of the template above.
    prompt_inputs = {
        "question": question,
        "schema": getDatabaseSchema(),
        "query": query,
        "result": result,
    }
    return answer_chain.invoke(prompt_inputs).content

In [24]:
def clear_history():
    """Reset the conversation and blank the input widgets.

    Besides returning empty values for the UI, this empties the module-level
    ``chat_history`` list (when one exists). Without that, the stored
    transcript would survive the click and reappear on the next question.

    Returns:
        A 4-tuple ``([], "", "", "")``, one value per output component the
        clear button is wired to.
    """
    # Looked up by name at call time because the list is created later in the
    # notebook than this function; clear it in place so other holders of the
    # same list object see it emptied too.
    shared_history = globals().get("chat_history")
    if isinstance(shared_history, list):
        shared_history.clear()
    return [], "", "", ""

def chat_with_sql(question, database, model, chat_history):
    """Answer ``question`` against the SQL database at ``database`` using ``model``.

    Loads the database, asks the LLM for a SQL query, runs it, and asks the
    LLM to phrase the result in natural language. Any failure along the way is
    reported as the assistant's reply instead of being raised.

    Args:
        question: the user's natural-language question.
        database: URL of the SQL script to load; empty means "not set yet".
        model: name of the Ollama model to use.
        chat_history: list of ``(user_message, assistant_message)`` pairs;
            it is extended in place with the new exchange.

    Returns:
        ``(chat_history, "", database)``: the updated transcript, an empty
        string that clears the question box, and the database URL unchanged so
        follow-up questions do not require re-entering it.
    """
    if not database:
        # Report the problem in the chat rather than writing it into the
        # database box, where it would later be mistaken for a URL.
        chat_history.append((question, "Please connect to a database first."))
        return chat_history, "", ""

    try:
        print("Connecting to the database...")
        connectDatabase(url=database)
        print("Database connected.")

        print(f"Loading model: {model}...")
        llm = ChatOllama(model=model, temperature=0.1)
        print("Model loaded.")

        print(f"Generating query from LLM for question: {question}")
        query = getQueryFromLLM(llm, question)
        print(f"Generated query: {query}")
        if not query:
            # getQueryFromLLM returns None when it could not produce a query;
            # fail with a clear message instead of executing None.
            raise ValueError("Could not generate a valid SQL query for this question.")

        print("Running query on the database...")
        result = runQuery(query)
        print(f"Query result: {result}")

        print("Generating response based on query result...")
        response = getResponseForQueryResult(llm, question, query, result)
        print(f"Response: {response}")

        reply = response

    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        reply = error_message

    # The Chatbot component renders each entry as (user message, bot message),
    # so the exchange is stored as one pair rather than two role-tagged rows.
    chat_history.append((question, reply))
    return chat_history, "", database

# Collect the "name" of every entry under the "models" key of Ollama.list();
# these become the choices of the model dropdown below.
# NOTE(review): presumably these are the locally pulled models — confirm.
models = [model["name"] for model in Ollama.list()["models"]]

# Build the Gradio UI. Components are created inside the Blocks/Row context
# managers, so the order of the statements below is significant.
with gr.Blocks() as demo:
    gr.Markdown("# Chat with SQL DB 🤖")

    with gr.Row():
        database = gr.Textbox(label="Database", placeholder="ex: https://raw.githubusercontent.com/....sql")
        # models[0] is the default selection; this raises IndexError when the
        # list above is empty.
        model = gr.Dropdown(choices=models, label="Model", value=models[0])

    # Plain Python list (not a Gradio component). It is created once when this
    # cell runs and is the single list object passed to every
    # chat_with_sql call made through submit_callback.
    chat_history = []

    question = gr.Textbox(label="Chat with an SQL database", placeholder="Enter your question here...")

    with gr.Row():
        connect_btn = gr.Button("Connect")
        clear_btn = gr.Button("Clear message history")

    chat_output = gr.Chatbot(height=400)

    def submit_callback(question, database, model):
        """Forward the current widget values plus the shared ``chat_history`` list to ``chat_with_sql``."""
        return chat_with_sql(question, database, model, chat_history)

    # Submitting the question box and clicking "Connect" run the same
    # callback; its three return values are written to chat_output, question
    # and database, in that order.
    question.submit(submit_callback, inputs=[question, database, model], outputs=[chat_output, question, database])
    connect_btn.click(submit_callback, inputs=[question, database, model], outputs=[chat_output, question, database])
    # clear_history takes no inputs; its four return values are written to
    # chat_output, question, database and model.
    clear_btn.click(clear_history, inputs=None, outputs=[chat_output, question, database, model])


In [25]:
demo.launch()

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://8a188e2cc0f5574dd8.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


