In [1]:
import os
import logging
import pandas as pd
from langchain_ollama.llms import OllamaLLM
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain_ollama import OllamaEmbeddings
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
import faiss


# Define database connection parameters

db_user = "root"
db_password = "q1w2e3r4"
db_host = "localhost"
db_name = "llm_asst"

# get primary key of table
tbl_pk = {"Schedules": "schedule_id", "RecurringSchedules": "recurring_schedule_id", "Memos": "memo_id"}

# models to load

llm = OllamaLLM(model="llama3.1")
embeddings = OllamaEmbeddings(model="nomic-embed-text",)
faiss_db = FAISS.load_local("faiss", embeddings, allow_dangerous_deserialization=True)

semantic_search_k = 3

In [2]:
input_query = "when is the team discussion?"

In [3]:
# Database connection

# db = SQLDatabase.from_uri("sqlite:///nba_roster.db", sample_rows_in_table_info=0)
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")

print(db.dialect)
print(db.get_usable_table_names())


mysql
['Memos', 'RecurringSchedules', 'Schedules', 'Users']


In [15]:
# intent identification

intent_examples = pd.read_csv('intent_fewshot.csv')
intent_list = ", ".join(intent_examples.intent.unique())

# few-shot prompt
fs_intent_template = """
Determine the operation type and table from the user queries below. Use only the following candidates: operation types (create, delete, update, query) and tables (Schedules, RecurringSchedules, Memos).
---
Example:

Query: 'Please remove my appointment from the schedule.'
Intent: delete Schedules
Query: 'Update the weekly team meeting to 11 AM.'
Intent: update RecurringSchedules
Query: 'Create a note for my presentation next week.'
Intent: create Memos
Query: 'Can you show my upcoming events?'
Intent: query Schedules

---
Reply ONE of the following: """+ intent_list + """
If no intent from the list is found, respond with: "Intent not found."
User Query: {input}
Intent:"

"""

intent_messages = [("system", fs_intent_template), ("human", "{input}")]
intent_prompt = ChatPromptTemplate.from_messages(intent_messages)
intent_chain = intent_prompt | llm

intent = intent_chain.invoke({"input": input_query})
intent

"Query: 'when is the team discussion?'\nIntent: query Schedules"

In [16]:
# process the intent

if "\nIntent:" in intent:
    intent = intent.split("\nIntent:")[1].strip()
elif "Intent not found." in intent:
    intent = "Intent not found."
else:
    intent = intent.strip()
    
intent_operation, intent_table = intent.split(" ")

In [17]:
intent_operation

'query'

In [18]:
# semantic search

retrieved_docs = faiss_db.similarity_search(input_query, 
                                            filter={"tbl_name": intent_table},
                                            k=semantic_search_k)

# get the primary key of the retrieved documents in the metadata
retrieved_id = [doc.metadata[tbl_pk[intent_table]] for doc in retrieved_docs]

In [21]:

sys_prompt = """Given an input question, convert it to an SQL query. No preamble. No explanation. Just the SQL query.

The database schema is as follows:
{schema}

"""

human_prompt = """

The original request from user is: {user_query}.
The operation type should be: {intent_operation}.
The table should be: {intent_table}.
The {pk} of the relevant records is: {retrieved_ids}.

Give a SQL statement:

"""

messages = [("system", sys_prompt), ("human", human_prompt)]
prompt = ChatPromptTemplate.from_messages(messages)
chain = prompt | llm

sql_statement = chain.invoke({"schema": db.get_table_info(),"user_query": input_query, 
                "intent_operation": intent_operation, "intent_table": intent_table, 
                "pk": tbl_pk[intent_table], "retrieved_ids": retrieved_id})
sql_statement

'SELECT * FROM `Schedules` WHERE schedule_id IN (5, 2, 1)'

In [22]:
response = db.run(sql_statement)
response

"[(1, 1, 'Project Kickoff', 'Kickoff meeting for the new project with stakeholders.', datetime.datetime(2024, 10, 30, 9, 0), datetime.datetime(2024, 10, 30, 10, 30), datetime.datetime(2024, 10, 23, 9, 43, 49)), (2, 1, 'Lunch with Client', 'Discuss project requirements over lunch.', datetime.datetime(2024, 10, 31, 12, 0), datetime.datetime(2024, 10, 31, 13, 0), datetime.datetime(2024, 10, 23, 9, 43, 49)), (5, 1, 'Weekly Team Meeting', 'Discuss project updates and blockers.', datetime.datetime(2024, 11, 2, 10, 0), datetime.datetime(2024, 11, 2, 11, 0), datetime.datetime(2024, 10, 23, 9, 43, 49))]"

In [23]:
sys_prompt2 = """

You are a personal assistant for task management. 
User's Question is: {question}

You have access to a database with the following schema:
    Table Schema: {schema}
You have checked the database by running the following SQL query:
    SQL Query: {query}
You have received the following response:
    SQL Response: {response}
    
Write a natural language answer to the user's question based on the information above.
(if there are multiple records, please provide a summary and introduce each record briefly)
"""

human_prompt2 = """
    User's Question is: {question}
"""

messages2 = [("system", sys_prompt2), ("human", human_prompt2)]
prompt2 = ChatPromptTemplate.from_messages(messages2)
chain2 = prompt2 | llm


final_response = chain2.invoke({"schema": db.get_table_info([intent_table]),"question": input_query, 
                "query": sql_statement, "response": response})
print(final_response)

The team discussion you're referring to is actually scheduled for this Friday. There is one record in our schedules database that matches your question. The "Weekly Team Meeting" is scheduled from 10am to 11am on November 2nd.

Would you like me to remind you about it or add any other tasks related to the team discussion?


In [28]:
# gradio for web interface(multi-turn conversation)

import gradio as gr

def chat(user_input, history):
    if not history:
        history = []
    else:
        history = history.split("\n")
    
    intent = intent_chain.invoke({"input": user_input})
    # intent = intent.lower()
    if "intent not found" in intent.lower():
        return llm.invoke(user_input)
        
    elif "\nIntent:" in intent:
        intent = intent.split("\nIntent:")[1].strip()
    else:
        intent = intent.strip()
        
    # intent_operation, intent_table = intent.split(" ")
    intent_operation, intent_table = "query", "Schedules"
    
    retrieved_docs = faiss_db.similarity_search(user_input, 
                                                filter={"tbl_name": intent_table},
                                                k=semantic_search_k)
    retrieved_id = [doc.metadata[tbl_pk[intent_table]] for doc in retrieved_docs]
    
    sql_statement = chain.invoke({"schema": db.get_table_info(),"user_query": user_input,
                                    "intent_operation": intent_operation, "intent_table": intent_table, 
                                    "pk": tbl_pk[intent_table], "retrieved_ids": retrieved_id})
    try:
        response = db.run(sql_statement)
    except:
        response = "No records found."
        
    final_response = chain2.invoke({"schema": db.get_table_info([intent_table]),"question": user_input,
                                    "query": sql_statement, "response": response})
    # history.append(f"User: {user_input}")
    # history.append(f"Assistant: {final_response}")
    
    return final_response
    
    # retrieved_docs = faiss_db.similarity_search(user_input,
    # return response

# chatbot = gr.Chatbot(height=300)
iface = gr.ChatInterface(fn=chat, type="messages")

iface.launch(share=True)

* Running on local URL:  http://127.0.0.1:7864
* Running on public URL: https://3350097a92043f8e5a.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


