In [78]:
import os
import tkinter as tk
from tkinter import scrolledtext, messagebox

import pandas as pd
from langchain.memory import ChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
from langchain_ai21 import AI21LLM
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.chat_message_histories import RedisChatMessageHistory
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.chat import MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_experimental.sql.base import SQLDatabaseChain
from langchain_fireworks import ChatFireworks
from langsmith import Client
from sqlalchemy import create_engine
from sqlalchemy import inspect as sa_inspect

import config

In [79]:
from langchain.agents.agent_toolkits import create_sql_agent,SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.sql_database import SQLDatabase
from langchain.chat_models.openai import ChatOpenAI
from langchain.memory import ConversationBufferMemory

In [80]:
# Export every provider credential from the local `config` module into the
# process environment, under the same name it has in `config`.
for key_name in (
    "AI21_API_KEY",
    "GOOGLE_API_KEY",
    "OPENAI_API_KEY",
    "NVIDIA_API_KEY",
    "HUGGINGFACEHUB_API_TOKEN",
    "FIREWORKS_API_KEY",
    "LANGCHAIN_API_KEY",
):
    os.environ[key_name] = getattr(config, key_name)

In [81]:
# LangSmith tracing settings, applied to the environment in one call.
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
        "LANGCHAIN_API_KEY": config.LANGCHAIN_API_KEY,
        "LANGCHAIN_PROJECT": "pr-giving-defeat-64",
    }
)

In [82]:
# LangSmith client built with no explicit arguments.
# NOTE(review): presumably it picks up the LANGCHAIN_* environment variables
# set earlier in the notebook — confirm against the langsmith docs.
client = Client()

In [6]:
# os.listdir returns entries in arbitrary order; sort them so the files are
# processed (and their rows stored) in the same order on every run.
csv_files = sorted(os.listdir("./csv"))

In [None]:
def _parse_age(value):
    """Return the leading number of an age string (e.g. "12 Years") as a float.

    Non-string values (already-numeric ages, or NaN for blank cells) are
    returned unchanged so they can be converted / forward-filled afterwards.
    """
    if isinstance(value, str):
        return float(value.split(' ')[0])
    return value


def _clean_eeg_file(raw, file_name):
    """Normalise one raw annotation CSV into the column layout used downstream.

    Parameters
    ----------
    raw : pandas.DataFrame
        Frame as returned by ``pd.read_csv``; must contain the columns
        "Gender", "Age", "File Start", "Start time", "End time" and
        "Channel names" (a missing one raises ``KeyError``).
    file_name : str
        Name of the source file, stored in the new ``file_name`` column.

    Returns
    -------
    pandas.DataFrame
        A copy of ``raw`` in which Gender / Age / File Start blanks are filled
        from the closest preceding row *of the same file*, the four
        space-containing headers are renamed to snake_case, and the columns
        end with file_start, start_time, end_time, channel_names, file_name.
    """
    renamed = {
        "File Start": "file_start",
        "Start time": "start_time",
        "End time": "end_time",
        "Channel names": "channel_names",
    }
    cleaned = raw.copy()

    # Only string cells count as real values for these two columns; anything
    # else (blank cells read as NaN) inherits the last string seen above it.
    for column in ("Gender", "File Start"):
        values = cleaned[column]
        is_text = values.map(lambda cell: isinstance(cell, str))
        cleaned[column] = values.where(is_text).ffill()

    # Ages written as text are parsed, numeric ages are kept as they are, and
    # blanks inherit the previous row's age.
    cleaned["Age"] = pd.to_numeric(cleaned["Age"].map(_parse_age)).ffill()

    # errors="raise" keeps a missing source column a hard failure.
    cleaned = cleaned.rename(columns=renamed, errors="raise")
    cleaned["file_name"] = file_name

    # Untouched columns first, renamed/added ones last, in a fixed order.
    trailing = list(renamed.values()) + ["file_name"]
    leading = [column for column in cleaned.columns if column not in trailing]
    return cleaned[leading + trailing]


# One cleaned frame per CSV file; metadata is filled per file, so values from
# one recording can never spill into the next one.
dataframe = []

for csv_file in csv_files:
    # Skip stray directory entries (checkpoint folders, notes, ...).
    if not csv_file.lower().endswith(".csv"):
        continue
    raw_df = pd.read_csv(os.path.join("./csv", csv_file))
    dataframe.append(_clean_eeg_file(raw_df, csv_file))

# pd.concat rejects an empty list, so fall back to an empty frame.
df = pd.concat(dataframe, ignore_index=True) if dataframe else pd.DataFrame()

In [83]:
engine = create_engine("sqlite:///eeg.db")

# Create the `eeg` table only when the database does not have it yet. A fresh
# checkout therefore gets a populated database on a top-to-bottom run, while an
# existing eeg.db is left untouched (to_sql would otherwise raise because the
# table already exists).
if not sa_inspect(engine).has_table("eeg"):
    df.to_sql("eeg", engine, index=False)

In [84]:
db = SQLDatabase(engine=engine)

# Sanity check: print the SQL dialect, then the tables the wrapper exposes.
for detail in (db.dialect, db.get_usable_table_names()):
    print(detail)

# Spot-check one known row; as the cell's last expression its result is displayed.
sample_query = "SELECT * FROM eeg WHERE start_time == '13:14:29:034';"
db.run(sample_query)

sqlite
['eeg']


"[('Male', 12.0, 'spike and wave', '13:10:31', '13:14:29:034', '13:14:31', 'FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ PZ CZ', '1004.csv')]"

In [87]:
# Take the Fireworks key from the environment instead of embedding it in the
# notebook; a missing variable fails immediately with KeyError.
# NOTE(review): the key that used to be hard-coded here is exposed in the
# notebook history and should be rotated.
llm = ChatFireworks(
    api_key=os.environ["FIREWORKS_API_KEY"],
    model="accounts/fireworks/models/llama-v3p1-405b-instruct",
)

In [35]:
# System instructions describing the `eeg` table, written one sentence per
# literal; adjacent literals are concatenated into a single string.
default_prompt = (
    "You are a database agent for writing queries and returning answers. "
    "The table is eeg, containing CSV records of EEG recordings from different patients. "
    "Each file (identified by file_name) can have multiple entries with consistent gender, age, and file_start values. "
    "Events are single entries with a start_time, end_time, comment, and channel_names. "
    "If channel_names includes 'FP1', it implies the event was observed in FP1, and similarly for other channels. "
    "Ages are floats type. "
    "Subjects are characterized by age and gender, not filenames. "
    "File names referred to by an integer should be assumed to have a .csv extension. "
    "Timestamps in hh:mm:ss:ms format need to be cast to integers for comparisons. "
    "Use DISTINCT with file_name to handle files with multiple entries. "
    "You should also look for message history to be aware of the context completely."
)

In [112]:
# Chat history object that the RunnableWithMessageHistory wrapper below returns
# for every session.
# NOTE(review): later cells rebind the name `memory` to ConversationBufferMemory
# objects — confirm which object is meant wherever `memory` is read.
memory = ChatMessageHistory(session_id="test-session")

In [100]:
# System-message text for `chat_prompt`; {info} is filled with the table schema
# when the prompt is formatted. Whitespace inside the literal is part of the prompt.
chat_template = """ Based on the schema given {info} write an executable query for the user input. 
Execute it in the database and get sql results. Make a response to user from sql results based on 
the question. 
Input: "user input"
SQL query: "SQL Query here"
"""

In [101]:
# Message layout: system instructions, then the prior turns, then the new user input.
prompt_messages = [
    ("system", chat_template),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{input}"),
]
chat_prompt = ChatPromptTemplate.from_messages(prompt_messages)

In [115]:
# Earlier question/answer exchanges, one "Q: ... A: ..." string per exchange.
history_list = ["Q: Name a file that has the recording of a male patient. A: 1004.csv"]

# Flatten the exchanges into a single newline-separated block of text.
history_str = "\n".join(history_list)


In [147]:
# ReAct-style prompt text. Placeholders filled at format time: {tools},
# {history}, {tool_names}, {input} and {agent_scratchpad}.
template = '''Answer the following questions as best you can. You have access to the following tools:

{tools}

These are the relevant pieces of the previous conversations, use them if needed.
{history}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!

Question: {input}
Thought:{agent_scratchpad}'''

# NOTE(review): the name `prompt` is rebound to a formatted value by the
# console-loop cells further down — confirm nothing after them needs this template.
prompt = PromptTemplate.from_template(template)

In [124]:
# Bundle the database handle and the LLM into the SQL toolkit.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

In [141]:
# Exploratory: display the toolkit's dict() representation.
toolkit.dict()

{}

In [136]:
# Tool objects provided by the SQL toolkit.
tools = toolkit.get_tools()

In [148]:
# Collect the name of every tool, in toolkit order.
tool_names = []
for sql_tool in tools:
    tool_names.append(sql_tool.name)

In [149]:
# Display the collected tool names.
tool_names

['sql_db_query', 'sql_db_schema', 'sql_db_list_tables', 'sql_db_query_checker']

In [150]:
# Pass this history into your prompt template
prompt.format(history = history_str, tools = toolkit.get_tools(), input = "What is the age of the patient in this file ?", agent_scratchpad = "", tool_names = tool_names)

'Answer the following questions as best you can. You have access to the following tools:\n\n[QuerySQLDataBaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column \'xxxx\' in \'field list\', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x0000020B05485910>), InfoSQLDatabaseTool(description=\'Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3\', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x0000020B05485910>), ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database

In [151]:
# SQL agent over `db`, driven by the custom prompt built earlier.
agent_executor = create_sql_agent(
    llm,
    db=db,
    agent_type="zero-shot-react-description",
    prompt=prompt,
    use_history=True,
    verbose=True,
)

In [152]:
# Bind the chat history object now. Later cells rebind the name `memory` to
# ConversationBufferMemory objects; a factory that looked `memory` up at call
# time would hand one of those to the wrapper instead of this message history.
agent_chat_history = memory


def _get_session_history(session_id):
    """Return the single shared in-memory chat history.

    A session id is part of the factory contract, but it is not used here:
    every session shares the same history object.
    """
    return agent_chat_history


agent_with_chat_history = RunnableWithMessageHistory(
    agent_executor,
    _get_session_history,
    input_messages_key="input",
    # NOTE(review): the agent prompt's placeholder is named {history}, while the
    # stored messages are supplied under "chat_history" — confirm this is intended.
    history_messages_key="chat_history",
)

In [158]:
# Sample follow-up question; "this file" only makes sense with prior history.
query = "What is the age of the patient in this file ?"

In [160]:
# Every value the custom prompt refers to, supplied explicitly for this call.
agent_inputs = {
    'input': query,
    'tools': tools,
    'tool_names': tool_names,
    'agent_scratchpad': "",
    'history': history_str,
}
agent_with_chat_history.invoke(
    agent_inputs,
    config={"configurable": {"session_id": "<foo>"}},
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mTo find the age of the patient in the file "1004.csv", I need to first check if there is a table in the database that contains patient information, including age. 

Action: sql_db_list_tables
Action Input: None[0m[38;5;200m[1;3meeg[0m[32;1m[1;3mAction: sql_db_schema
Action Input: eeg[0m[33;1m[1;3m
CREATE TABLE eeg (
	"Gender" TEXT, 
	"Age" FLOAT, 
	"Comment" TEXT, 
	file_start TEXT, 
	start_time TEXT, 
	end_time TEXT, 
	channel_names TEXT, 
	file_name TEXT
)

/*
3 rows from eeg table:
Gender	Age	Comment	file_start	start_time	end_time	channel_names	file_name
Male	12.0	spike and wave	13:10:31	13:11:36:739	13:11:38:478	FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ PZ CZ	1004.csv
Male	12.0	spike and wave	13:10:31	13:14:29:034	13:14:31	FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ PZ CZ	1004.csv
Male	12.0	spike and wave	13:10:31	13:16:20	13:16:21:982	FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ 

{'input': 'What is the age of the patient in this file ?',
 'tools': [QuerySQLDataBaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x0000020B05485910>),
  InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x0000020B05485910>),
  ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x

In [14]:
# Buffer memory exposed under 'history' and keyed on the 'input' field; it is
# passed to the agent executor built below. This rebinds the name `memory`.
memory = ConversationBufferMemory(input_key='input', memory_key='history')

In [69]:
# Prompt suffix for the agent built below; placeholders: {history}, {input}
# and {agent_scratchpad}.
suffix = """Begin!

Relevant pieces of previous conversation:
{history}
(You do not need to use these pieces of information if not relevant)

Question: {input}
Thought: I should look at the tables in the database to see what I can query.  Then I should query the schema of the most relevant tables.
{agent_scratchpad}
"""

In [77]:
# Second agent variant: default prompt with the history-aware suffix, and the
# buffer memory attached to the agent executor.
sql_toolkit = SQLDatabaseToolkit(db=db, llm=llm)
executor = create_sql_agent(
    llm=llm,
    toolkit=sql_toolkit,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    suffix=suffix,
    input_variables=["input", "agent_scratchpad", "history"],
    agent_executor_kwargs={'memory': memory},
    handle_parsing_errors=True,
)

In [153]:
def execute_query():
    """Send the text in the query box to the agent and show its answer.

    An empty query box produces a warning dialog and nothing else. Any
    exception raised while invoking the agent or updating the result box is
    reported in an error dialog.
    """
    user_query = query_text.get("1.0", tk.END).strip()
    if user_query == "":
        messagebox.showwarning("Input Error", "Please enter a query.")
        return

    session_config = {"configurable": {"session_id": "<foo>"}}
    try:
        response = agent_with_chat_history.invoke(
            {"input": user_query},
            config=session_config,
        )
        answer = response.get("output", "No result found")
        # Switch the result box to NORMAL so it can be rewritten, then back
        # to DISABLED once the new text is in place.
        result_text.config(state=tk.NORMAL)
        result_text.delete("1.0", tk.END)
        result_text.insert(tk.END, answer)
        result_text.config(state=tk.DISABLED)
    except Exception as error:
        messagebox.showerror("Execution Error", str(error))

In [85]:
# Re-creates the database wrapper over the same engine (same statement as the
# earlier `db` cell).
db = SQLDatabase(engine=engine)

In [20]:
# Table description taken from the wrapper; used as {info} when `chat_prompt`
# is formatted.
table_info = db.table_info


In [88]:
# Window memory for the console chat loop: k=4, returned as message objects.
m1 = ConversationBufferWindowMemory(return_messages=True, k=4)

In [90]:
# SQL chain over `db` with verbose tracing switched on.
db_chain = SQLDatabaseChain.from_llm(llm=llm, db=db, verbose=True)

In [None]:
# Console chat loop: read a question, answer it through the SQL chain with the
# remembered turns in the prompt, and store the exchange. An empty line ends it.
while True:
    query = input('human:')
    if query == '':
        break

    chat = m1.load_memory_variables({})['history']
    prompt = chat_prompt.format(input=query, history=chat, info=table_info)
    response = db_chain.run(prompt)
    m1.save_context({'input': query}, {'output': response})

In [1]:
from langchain.chains.llm import LLMChain

In [None]:
# Memory-backed variant: the LLM chain gets the chat prompt and a buffer
# memory, and is then wrapped in a SQL chain. This rebinds the name `memory`.
memory = ConversationBufferMemory(memory_key="history", input_key='input')
query_llm_chain = LLMChain(llm=llm, prompt=chat_prompt, memory=memory)
dbchain = SQLDatabaseChain(
    llm_chain=query_llm_chain,
    database=db,
    verbose=True,
)

In [22]:
# Sample question for the memory-backed chain; rebinds `query`.
query = "How many files are there ?"

In [None]:
# Exploratory: display what the memory currently holds under 'history'.
memory.load_memory_variables({})['history']

In [None]:
# Format the chat prompt by hand from the stored history and the sample question.
stored_memory = memory.load_memory_variables({})
chat = stored_memory['history']
prompt = chat_prompt.format(input=query, history=chat, info=table_info)

In [None]:
# Run the memory-backed SQL chain on a literal question.
dbchain.run("how many files are there ?")

In [154]:
# Main application window.
root = tk.Tk()
root.title("SQL Agent GUI")

# Instantiate the widgets in the order they appear on screen.
query_label = tk.Label(root, text="Enter SQL Query:")
query_text = scrolledtext.ScrolledText(root, width=80, height=10)
execute_button = tk.Button(root, text="Execute Query", command=execute_query)
result_label = tk.Label(root, text="Query Result:")
# The result box starts out disabled.
result_text = scrolledtext.ScrolledText(root, width=80, height=15, state=tk.DISABLED)

# Stack them top to bottom with the same vertical padding.
for widget in (query_label, query_text, execute_button, result_label, result_text):
    widget.pack(pady=5)


In [155]:
# Start the Tk event loop for the window built above.
root.mainloop()