<a href="https://colab.research.google.com/github/ztor2/text2sql_for_postgres_demo/blob/main/text2sql_for_postgres_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive so the project folder under MyDrive
# (used by the %cd cell below) is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Work from the project folder on the mounted Drive; relative paths in later
# cells (e.g. the CUDA .deb passed to dpkg) resolve against this directory.
%cd /content/drive/MyDrive/sqlgen_agens

/content/drive/MyDrive/sqlgen_agens


In [None]:
# Install the CUDA 12.3 toolkit from NVIDIA's local .deb repository.
# NOTE(review): the download/pin steps below are commented out, so `dpkg -i` only
# works if the .deb already exists in the current (Drive) directory from an earlier
# session — on a fresh setup, uncomment the three lines first.
# !wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin
# !sudo mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600
# !wget https://developer.download.nvidia.com/compute/cuda/12.3.1/local_installers/cuda-repo-ubuntu2204-12-3-local_12.3.1-545.23.08-1_amd64.deb
!sudo dpkg -i cuda-repo-ubuntu2204-12-3-local_12.3.1-545.23.08-1_amd64.deb
# Register the local repo's signing key so apt trusts it.
!sudo cp /var/cuda-repo-ubuntu2204-12-3-local/cuda-*-keyring.gpg /usr/share/keyrings/
!sudo apt-get update
# Trailing ';' suppresses the cell's displayed output.
!sudo apt-get -y install cuda-toolkit-12-3;

In [None]:
!pip install ctransformers
!pip install gradio

In [None]:
import torch
# from transformers import AutoModelForCausalLM
from ctransformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import pipeline
import gradio as gr
import psycopg2
import psycopg2.extras

In [None]:
# Quantized (Q5_K_M GGUF) build of Intel's neural-chat-7b-v3-3, loaded through ctransformers.
# hf=True requests a transformers-compatible wrapper, which is what lets `model` be
# handed to the transformers pipeline in the next cell.
# gpu_layers=100 is the number of layers ctransformers is asked to offload to the GPU;
# presumably larger than the model's layer count, i.e. "offload everything" — confirm.
# (Previously tried alternative: TheBloke/zephyr-7B-beta-GGUF / zephyr-7b-beta.Q8_0.gguf
#  with gpu_layers=50 and the HuggingFaceH4/zephyr-7b-beta tokenizer.)
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/neural-chat-7B-v3-3-GGUF",
    model_file="neural-chat-7b-v3-3.Q5_K_M.gguf",
    model_type="mistral",
    gpu_layers=100,
    hf=True,
)

# The GGUF repo is paired with the tokenizer of the original, unquantized model repo.
tokenizer = AutoTokenizer.from_pretrained(
    "Intel/neural-chat-7b-v3-3"
)

In [None]:
# Wrap the ctransformers model and the HF tokenizer in a standard text-generation
# pipeline; get_response() calls `pipe` with the sampling settings.
# NOTE(review): `model` is already instantiated by ctransformers, so it is unclear
# whether torch_dtype / device_map have any effect here — confirm.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

In [None]:
# Instruction that opens the system turn; format_message() places the schema text
# (global `dbinfo`) directly after it, hence the trailing ": ".
system = (
    "Generate a PostgreSQL query or provide proper information to answer "
    "the following question based on database schema: "
)
# Guard-rail sentence appended after the schema text in the system turn.
warn = (
    " If the user's question is not related to SQL generation, "
    "tell them to submit a SQL-related question. "
)

In [None]:
def format_message(message: str, history: list, memory_limit: int=3) -> str:
    """Build the full model prompt from the schema context and the chat history.

    The prompt uses the neural-chat layout: `### System:`, `### User:` and
    `### Assistant:` headers, each followed by its text on a new line. It ends
    with an open `### Assistant:` turn so the model answers the latest message
    instead of continuing it.

    Args:
        message: the user's latest message.
        history: prior turns as [user_msg, model_answer] pairs (Gradio tuple format).
        memory_limit: how many of the most recent turns to keep; 0 or less keeps none.

    Returns:
        The assembled prompt. It is also stored in the module-level
        `completed_prompt` so it can be inspected from other cells.
    """
    global completed_prompt

    # history[-0:] is the WHOLE list, so a non-positive limit must be special-cased.
    recent_turns = history[-memory_limit:] if memory_limit > 0 else []

    # `dbinfo` is only created by tab1_click / tab2_click; use an empty schema until
    # then instead of raising NameError when the user chats before giving DB info.
    schema_info = globals().get("dbinfo", "")

    sections = [f"### System:\n{system}{schema_info} {warn}".rstrip()]
    for user_msg, model_answer in recent_turns:
        sections.append(f"### User:\n{user_msg}")
        sections.append(f"### Assistant:\n{model_answer}")
    sections.append(f"### User:\n{message}")
    # Open assistant turn: cues the model to reply rather than extend the user turn.
    sections.append("### Assistant:\n")

    completed_prompt = "\n".join(sections)
    return completed_prompt

In [None]:
def get_response(message: str, history: list) -> str:
    """Generate the assistant's reply for one gr.ChatInterface turn.

    Args:
        message: the user's latest chat message.
        history: prior [user, assistant] turns supplied by gr.ChatInterface.

    Returns:
        The newly generated text with surrounding whitespace stripped
        (also printed to the notebook log).
    """
    prompt = format_message(message, history)
    outputs = pipe(
        prompt,
        max_new_tokens=400,
        do_sample=True,
        temperature=0.5,
        top_k=25,
        top_p=0.95,
        # NOTE(review): disables the KV cache (slower generation); kept as in the
        # original — confirm it is required for the ctransformers backend.
        use_cache=False,
        # Return only the continuation. Slicing generated_text by len(prompt) breaks
        # whenever decode(encode(prompt)) != prompt (e.g. whitespace normalisation),
        # chopping the start of the reply or leaking prompt text into it.
        return_full_text=False,
    )

    reply = outputs[0]["generated_text"].strip()
    print("Chatbot:", reply)
    return reply

In [None]:
def _build_columns_query(schema: str, table: str):
    """Return (sql, params) listing columns from information_schema.columns, optionally filtered."""
    conditions = []
    params = []
    if schema:
        conditions.append("table_schema = %s")
        params.append(schema)
    else:
        # No schema given: skip PostgreSQL's own catalog schemas, whose hundreds of
        # columns would otherwise flood the LLM prompt.
        conditions.append("table_schema NOT IN ('pg_catalog', 'information_schema')")
    if table:
        conditions.append("table_name = %s")
        params.append(table)
    sql = (
        "SELECT table_schema AS schema_name, table_name, column_name "
        "FROM information_schema.columns "
        f"WHERE {' AND '.join(conditions)} "
        # Stable ordering keeps the generated prompt deterministic between connects.
        "ORDER BY table_schema, table_name, ordinal_position;"
    )
    return sql, tuple(params)


def tab1_click(a, b, c, d, e, f, g, h):
    """Connect to a PostgreSQL-family DB and store its column list in global `dbinfo`.

    Args:
        a: host.  b: port.  c: user.  d: password.  e: database name.
        f: optional schema filter ('' = all non-system schemas).
        g: optional table filter ('' = all tables).
        h: optional free-text description appended to the column dump.

    Returns:
        A status string for the UI: a success message (with a warning when no
        columns matched) or "Error occurred: <details>". On error `dbinfo` is
        left unchanged.
    """
    global dbinfo
    schema = (f or "").strip()
    table = (g or "").strip()
    description = (h or "").strip()

    conn = None
    try:
        conn = psycopg2.connect(host=a, port=b, user=c, password=d, database=e)
        cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
        sql, params = _build_columns_query(schema, table)
        # Schema/table values are bound by the driver, never interpolated into SQL.
        cursor.execute(sql, params or None)
        rows = [dict(row) for row in cursor.fetchall()]
    except Exception as exc:
        return f"Error occurred: {exc}"
    finally:
        # Closing the connection also releases its cursor.
        if conn is not None:
            conn.close()

    new_info = str(rows)
    if description:
        new_info += f" additional schema description: {description}"
    dbinfo = new_info

    status = f"Successfully connected to {a} / {b} / {c}"
    if not rows:
        status += " (warning: no matching columns found)"
    return status

In [None]:
def tab2_click(a):
    """Use a manually entered schema description as the LLM's database context.

    Args:
        a: free-text description from the "Enter DB info manually" tab.

    Returns:
        A status string for the UI. On failure `dbinfo` is left unchanged.
    """
    global dbinfo
    try:
        new_info = f"{a}"
    except Exception as exc:
        # Only reachable for objects whose __format__ fails; a Textbox value is
        # presumably a plain str. Report the cause instead of hiding it.
        return f"Error occurred: {exc}"
    dbinfo = new_info
    return "Successfully updated database information."

In [None]:
# torch.cuda.empty_cache()
# Gradio UI: the left column offers two ways to supply schema context (a live DB
# connection handled by tab1_click, or free text handled by tab2_click); the right
# column is the chat interface backed by get_response.
# NOTE(review): gr.ChatInterface's retry_btn/undo_btn/clear_btn arguments exist in
# Gradio 4.x but were removed in Gradio 5; an unpinned `pip install gradio` will
# install a version that rejects them — pin gradio<5 or drop these arguments.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    with gr.Row():
        gr.Markdown("# 🐘 SQL Generator for AgensSQL&PostgreSQL")
    with gr.Row():
        gr.Markdown("""- Convert natural language queries into SQL queries that match your DB.
                    \n - Connect to a DB by entering the DB info of AgensSQL or PostgreSQL family in '**Connect to DB**'.
                    \n - For DBs that don't support connections, you can enter info directly under "**Enter DB info manually**".
                    """)
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Tab("Connect to DB"):
                # Connection fields; the defaults are demo placeholders (the host is masked).
                # NOTE(review): avoid shipping real credentials as textbox defaults.
                host = gr.Textbox(value="175.125.**.***", label="Host", lines=1)
                port = gr.Textbox(value="5432", label="Port", lines=1)
                user = gr.Textbox(value="postgres", label="User", lines=1)
                pwd = gr.Textbox(value="password", label="Password", lines=1, type="password")
                db = gr.Textbox(value="agchat", label="Database", lines=1)
                # Optional filters for tab1_click; an empty value means "no filter".
                scm = gr.Textbox(value="agchat", label="Schema(optional)", lines=1)
                tbl = gr.Textbox(value="tb_chat", label="Table(optional)", lines=1)
                add = gr.Textbox(value="This is the log table of AgensDesk app.", label="Description(optional)", lines=1)
                info = gr.Textbox(value="", label="Status", lines=2)
                btn = gr.Button(value="Connect")
                # Inputs are passed positionally, in this order, as tab1_click(a, ..., h).
                btn.click(tab1_click, inputs=[host, port, user, pwd, db, scm, tbl, add], outputs=[info])
            with gr.Tab("Enter DB info manually"):
                input_dbinfo = gr.Textbox(value="I have table named 'test', it has 3 columns, 'student', 'subect', 'score'", label="Description", lines=16)
                # NOTE(review): this rebinds `info` to a second status box. btn.click above was
                # already wired to the first box, so both tabs work, but a distinct name is clearer.
                info = gr.Textbox(value="", label="Status", lines=2)
                btn2 = gr.Button(value="Enter")
                btn2.click(tab2_click, inputs=[input_dbinfo], outputs=[info])
        with gr.Column(scale=3):
            # get_response(message, history) produces each assistant reply.
            chat = gr.ChatInterface(get_response,
                                    retry_btn='Retry',
                                    undo_btn='Undo',
                                    clear_btn='Clear ',
                                    submit_btn='Submit ')
# share=True requests a public Gradio share link in addition to the local server.
demo.launch(share=True)
# demo.launch()