In [1]:
%%capture
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

# We have to check which Torch version for Xformers (2.3 -> 0.0.27)
from torch import __version__; from packaging.version import Version as V
xformers = "xformers==0.0.27" if V(__version__) < V("2.4.0") else "xformers"
!pip install --no-deps {xformers} trl peft accelerate bitsandbytes triton
!pip install Gradio
!pip install langchain_community
!pip install sqlalchemy
!pip install sqlite3
!pip install requests


In [2]:
import os
import re
import sqlite3

import requests
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
from langchain.sql_database import SQLDatabase
import gradio as gr
from unsloth import FastLanguageModel
import pandas as pd
import torch

# Global variables
# Shared state mutated by connectDatabase() and load_model().
db = None                # SQLDatabase wrapper, set by connectDatabase()
connection = None        # raw sqlite3 connection, set by connectDatabase()
loaded_model = None      # set by load_model()
loaded_tokenizer = None  # set by load_model()

# Never commit access tokens to a notebook. The token is read from the
# environment instead (in Colab: set HF_TOKEN before running this cell).
# Falls back to an empty string when the variable is not set.
hf_api_key = os.environ.get("HF_TOKEN", "")

def connectDatabase(url):
    """Download a SQL script and load it into a fresh in-memory SQLite database.

    On success the module-level ``connection`` (raw sqlite3 connection) and
    ``db`` (``SQLDatabase`` built on an engine that hands out that same
    connection) are replaced together. On any failure both globals keep their
    previous values, so they can never point at different databases.

    Args:
        url: Address of a plain-text SQL script, fetched with ``requests.get``.

    Returns:
        str: ``"Successfully connected to the database."`` on success, otherwise
        a message starting with ``"Error connecting to the database: "``.
        Download, HTTP-status and SQL errors are all reported through the
        returned message rather than raised.
    """
    global db, connection

    if not url or not str(url).strip():
        return "Error connecting to the database: no database URL was provided."

    new_connection = None
    try:
        # A timeout keeps the UI from hanging forever on an unresponsive host;
        # raise_for_status() prevents an HTML error page from being run as SQL.
        response = requests.get(str(url).strip(), timeout=30)
        response.raise_for_status()
        sql_script = response.text

        # Create an in-memory SQLite database connection
        new_connection = sqlite3.connect(":memory:", check_same_thread=False)
        new_connection.executescript(sql_script)
        new_connection.commit()  # Commit the changes to ensure tables are created

        engine = create_engine(
            "sqlite://",
            creator=lambda: new_connection,
            poolclass=StaticPool,
            connect_args={"check_same_thread": False})
        new_db = SQLDatabase(engine)
    except Exception as e:
        if new_connection is not None:
            new_connection.close()
        return f"Error connecting to the database: {str(e)}"

    # Publish both handles only once everything has succeeded.
    connection, db = new_connection, new_db
    return "Successfully connected to the database."

def executeQuery(query):
    """Run ``query`` on the shared SQLite connection and return the outcome as a DataFrame.

    NOTE(security): in this app the query text is produced by the language
    model and is executed verbatim, so it can modify or drop data in the
    connected database.

    Args:
        query: SQL text to execute.

    Returns:
        pandas.DataFrame: one of
        * the fetched rows, with column names taken from ``cursor.description``;
        * a single ``'Result'`` column with a confirmation message when the
          statement produces no result set (e.g. INSERT / CREATE TABLE);
        * a single ``'Error'`` column when no connection exists, the query is
          empty, or execution raises.
    """
    if connection is None:
        return pd.DataFrame({'Error': ['Please connect to the database first.']})

    if not query or not str(query).strip():
        return pd.DataFrame({'Error': ['Error executing query: empty query.']})

    cursor = None
    try:
        cursor = connection.cursor()
        cursor.execute(query)

        # description is None for statements that return no rows; iterating it
        # would raise and turn a successful statement into a bogus error.
        if cursor.description is None:
            note = 'Statement executed successfully.'
            affected = cursor.rowcount
            if affected is not None and affected >= 0:
                note += f' Rows affected: {affected}'
            return pd.DataFrame({'Result': [note]})

        columns = [description[0] for description in cursor.description]  # Get column names
        return pd.DataFrame(cursor.fetchall(), columns=columns)
    except Exception as e:
        return pd.DataFrame({'Error': [f'Error executing query: {str(e)}']})
    finally:
        if cursor is not None:
            cursor.close()

def getDatabaseSchema():
    """Return the table info of the connected database.

    Returns:
        The value of ``db.get_table_info()`` when the module-level ``db`` is
        set, otherwise the string ``"Please connect to the database first."``.
    """
    if not db:
        return "Please connect to the database first."
    return db.get_table_info()

def extract_sql_statement(response_text):
    """Pull the generated SQL statement out of the model's decoded output.

    The statement is the text following the first ``### Response:`` marker,
    ending at whichever comes first: a semicolon, the start of another
    ``###`` section, or the end of the text. Accepting end-of-text means a
    query the model emitted without a trailing ``;`` is still returned
    instead of being dropped.

    Args:
        response_text: Full decoded model output (prompt plus completion).

    Returns:
        str: The statement without its terminating semicolon and without
        surrounding whitespace, or ``""`` when the marker is missing or
        nothing follows it.
    """
    if not response_text:
        return ""
    pattern = r'### Response:\s*(.*?)(?:;|\n###|\Z)'
    match = re.search(pattern, response_text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return ""

def getQueryFromLLM(question, max_new_tokens=64):
    """Ask the loaded model for a SQL query answering ``question`` and execute it.

    The prompt embeds the question and the schema returned by
    ``getDatabaseSchema()``. The decoded output is passed to
    ``extract_sql_statement`` and any extracted query is run with
    ``executeQuery``.

    Args:
        question: Natural-language question about the connected database.
        max_new_tokens: Generation budget passed to ``generate``. Defaults to
            64, the value previously hard-coded; raise it for longer queries.

    Returns:
        tuple: ``(query, result)``.
        * Model or tokenizer not loaded: ``(None, <message string>)``.
        * No database connected: ``(None, DataFrame with an 'Error' column)``.
        * No query could be extracted: ``(None, DataFrame with an 'Error' column)``.
        * Otherwise: ``(query, <whatever executeQuery returned>)``.
    """
    global loaded_model, loaded_tokenizer
    if not loaded_model or not loaded_tokenizer:
        return None, "Model not loaded. Please load the model first."

    # Without a database the "schema" would be a placeholder sentence; do not
    # spend a generation on a prompt that cannot yield a runnable query.
    if not db:
        return None, pd.DataFrame({'Error': ['Please connect to the database first.']})

    schema = getDatabaseSchema()
    input_prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Generate an SQL query to answer the following question: {question}

### Input:
Database schema:
{schema}

### Response:
"""

    # Put the inputs wherever the model reports it lives; fall back to the
    # previously hard-coded "cuda" when the model exposes no `device`.
    device = getattr(loaded_model, "device", "cuda")
    inputs = loaded_tokenizer([input_prompt], return_tensors="pt").to(device)
    outputs = loaded_model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)
    generated_response = loaded_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    print(f"Generated Response: {generated_response}")  # Debug output

    query = extract_sql_statement(generated_response)
    print(f"Extracted Query: {query}")  # Debug output

    if query:
        result_df = executeQuery(query)
        return query, result_df
    return None, pd.DataFrame({'Error': ['Failed to generate a valid SQL query.']})

def load_model(model_name):
    """Load ``model_name`` through ``FastLanguageModel`` and store it in the module globals.

    ``from_pretrained`` is called with ``max_seq_length=2048``, ``dtype=None``
    and ``load_in_4bit=True``; the returned pair is stored in ``loaded_model``
    and ``loaded_tokenizer``, and ``FastLanguageModel.for_inference`` is then
    applied to the model. Progress and errors are also printed.

    Args:
        model_name: Identifier passed to ``FastLanguageModel.from_pretrained``.

    Returns:
        str: ``"Model loaded successfully."`` or ``"Error loading model: <reason>"``.
        Exceptions are reported through the return value, not raised.
    """
    global loaded_model, loaded_tokenizer
    try:
        print(f"Loading model: {model_name}...")
        load_options = {
            "max_seq_length": 2048,
            "dtype": None,
            "load_in_4bit": True,
        }
        loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained(
            model_name, **load_options
        )
        FastLanguageModel.for_inference(loaded_model)

        success = "Model loaded successfully."
        print(success)
        return success
    except Exception as load_error:
        failure = f"Error loading model: {str(load_error)}"
        print(failure)
        return failure

def chat_with_sql(question, database):
    """Answer ``question`` by generating and running a SQL query.

    Args:
        question: The user's natural-language question.
        database: Content of the database field; only checked for being non-empty.

    Returns:
        list: a two-element list.
        * ``[None, <message string>]`` when ``database`` is empty or the model
          or tokenizer is not loaded.
        * ``[question, <result of getQueryFromLLM>]`` when a query was produced.
        * ``[question, DataFrame with an 'Error' column]`` when no query was
          produced or an exception was raised along the way.
    """
    if not database:
        return [None, "Please connect to a database first."]

    if not (loaded_model and loaded_tokenizer):
        return [None, "Please load the model first."]

    try:
        print(f"Generating query from LLM for question: {question}")
        generated_query, query_result = getQueryFromLLM(question)
        print(f"Generated query: {generated_query}")

        if not generated_query:
            no_query = pd.DataFrame({'Error': ['Failed to generate a valid SQL query.']})
            return [question, no_query]

        print("Generating response based on query result...")
        return [question, query_result]

    except Exception as failure:
        error_message = f"An error occurred: {str(failure)}"
        print(error_message)
        return [question, pd.DataFrame({'Error': [error_message]})]

def clear_history():
    """Return reset values for the four components wired to the Clear button.

    The Clear button writes to the question box, the database box, the chat
    history and the result table, in that order, so four values are needed:
    two empty strings, an empty history list, and ``None`` for the table.
    """
    return "", "", [], None

# Initialize Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Chat with SQL DB 🤖")

    with gr.Row():
        database = gr.Textbox(label="Database", placeholder="ex: https://raw.githubusercontent.com/....sql")
        model = gr.Dropdown(choices=["PlatoisHere/8B_58"], label="Model", value="PlatoisHere/8B_58")

    with gr.Row():
        load_model_btn = gr.Button("Load Model")
        connect_btn = gr.Button("Connect to Database")
        clear_btn = gr.Button("Clear")

    question = gr.Textbox(label="Chat with an SQL database", placeholder="Enter your question here...")
    submit_btn = gr.Button("Submit Question")
    chat_output = gr.Chatbot(height=400)
    dataframe_output = gr.Dataframe(interactive=True)

    def _error_text(frame):
        """Return the message of an error frame, or None for a normal result.

        Every error produced by this app is a DataFrame whose only column is
        'Error'; a query result with exactly that shape would be treated the
        same way.
        """
        if list(frame.columns) == ['Error'] and len(frame) > 0:
            return str(frame['Error'].iloc[0])
        return None

    def load_model_callback(model_name):
        """Load the selected model; show the status as the only chat entry and clear the table."""
        result = load_model(model_name)
        return [[None, result]], None

    def connect_callback(database):
        """Connect to the database URL; show the status as the only chat entry and clear the table."""
        result = connectDatabase(database)
        return [[None, result]], None

    def submit_callback(question, database, history):
        """Run the question through chat_with_sql and update the chat and the table.

        A DataFrame result is shown in the table. The chat entry says
        "Query executed successfully." only when the frame is a real result;
        for an error frame the error text itself is put in the chat, so a
        failure is no longer announced as a success.
        """
        history = history or []  # tolerate an unset chat history
        response = chat_with_sql(question, database)
        if isinstance(response[1], pd.DataFrame):
            error_text = _error_text(response[1])
            status = "Query executed successfully." if error_text is None else error_text
            return history + [[question, status]], response[1]
        else:
            return history + [response], None

    def clear_callback():
        """Reset the question box, database box, chat history and result table."""
        return "", "", [], None

    load_model_btn.click(load_model_callback, inputs=[model], outputs=[chat_output, dataframe_output])
    connect_btn.click(connect_callback, inputs=[database], outputs=[chat_output, dataframe_output])
    submit_btn.click(submit_callback, inputs=[question, database, chat_output], outputs=[chat_output, dataframe_output])
    clear_btn.click(clear_callback, inputs=None, outputs=[question, database, chat_output, dataframe_output])

if __name__ == "__main__":
    # debug=True keeps the cell running so errors and logs stay visible.
    demo.launch(debug=True)


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://0fd19f6be587b2f367.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Loading model: PlatoisHere/8B_58...
==((====))==  Unsloth 2024.8: Fast Llama patching. Transformers = 4.44.2.
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.4.0+cu121. CUDA = 7.5. CUDA Toolkit = 12.1.
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.27.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/5.70G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/230 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/345 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/2.68G [00:00<?, ?B/s]

Unsloth 2024.8 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


Model loaded successfully.
Generating query from LLM for question: Give me the list of all artists from artist table
Generated Response: Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Generate an SQL query to answer the following question: Give me the list of all artists from artist table

### Input:
Database schema:

CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Cust