In [48]:
import os
import tkinter as tk
from tkinter import scrolledtext, messagebox

import pandas as pd
from langchain.memory import ChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.chat_message_histories import RedisChatMessageHistory
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_fireworks import ChatFireworks
from langsmith import Client
from sqlalchemy import create_engine, inspect

import config

In [49]:
# Export each provider credential from the local `config` module under the
# environment variable of the same name.
os.environ.update(
    {
        key_name: getattr(config, key_name)
        for key_name in (
            "AI21_API_KEY",
            "GOOGLE_API_KEY",
            "OPENAI_API_KEY",
            "NVIDIA_API_KEY",
            "HUGGINGFACEHUB_API_TOKEN",
            "FIREWORKS_API_KEY",
            "LANGCHAIN_API_KEY",
        )
    }
)

In [50]:
# LangSmith tracing settings: switch tracing on, point at the hosted endpoint,
# supply the API key, and choose the project the runs are logged under.
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
        "LANGCHAIN_API_KEY": config.LANGCHAIN_API_KEY,
        "LANGCHAIN_PROJECT": "pr-giving-defeat-64",
    }
)

In [51]:
# LangSmith client, built with no explicit arguments.
# NOTE(review): presumably configured from the LANGCHAIN_* environment
# variables exported earlier — confirm. It is not referenced again in the
# cells visible here.
client = Client()

In [37]:
# Names of the EEG annotation files to load from ./csv.
# Only *.csv entries are kept: anything else in the folder (hidden files,
# checkpoint directories, ...) would be handed to pd.read_csv by the loading
# cell. The list is sorted because os.listdir returns entries in arbitrary
# order, which would make the row order of the resulting table vary by machine.
csv_files = sorted(
    entry for entry in os.listdir("./csv") if entry.lower().endswith(".csv")
)

In [None]:
# CSV headers that are replaced by snake_case names. The renamed columns are
# placed after the untouched ones, in this order, followed by `file_name`.
COLUMN_RENAMES = {
    "File Start": "file_start",
    "Start time": "start_time",
    "End time": "end_time",
    "Channel names": "channel_names",
}


def parse_age(value):
    """Convert one "Age" cell to a number.

    A string such as "12 years" yields the float before the first space.
    Anything else (a missing value or an already-numeric age) is returned
    unchanged so that it can be forward-filled or kept as is.
    """
    if isinstance(value, str):
        return float(value.split(' ')[0])
    return value


dataframe = []  # one cleaned DataFrame per CSV file

for csv_file in csv_files:
    recording = pd.read_csv(os.path.join("./csv", csv_file))

    # Rows that leave Gender / Age / File Start blank inherit the most recent
    # value above them. The fill is done per file, so values never leak from
    # one recording into the next, and a file whose first row is blank keeps
    # NaN there instead of failing on an unset carry-over variable.
    recording["Gender"] = recording["Gender"].ffill()
    recording["Age"] = recording["Age"].map(parse_age).astype(float).ffill()
    recording["File Start"] = recording["File Start"].ffill()

    recording = recording.rename(columns=COLUMN_RENAMES)
    renamed_columns = list(COLUMN_RENAMES.values())
    untouched_columns = [
        column for column in recording.columns if column not in renamed_columns
    ]
    # Selecting by label raises KeyError if an expected column is missing from
    # the file; `assign` then tags every row with the file it came from.
    recording = recording[untouched_columns + renamed_columns].assign(
        file_name=csv_file
    )

    dataframe.append(recording)

# pd.concat refuses an empty list, so fall back to an empty frame when no CSV
# files were found.
df = pd.concat(dataframe, ignore_index=True) if dataframe else pd.DataFrame()

In [52]:
# SQLite database file that backs the SQL agent.
engine = create_engine("sqlite:///eeg.db")

# Load the cleaned annotations into the "eeg" table, but only when the table
# is not there yet: with a fresh eeg.db the later cells would otherwise query
# a table that does not exist, and an unconditional to_sql raises on re-run
# because the table already exists.
if not inspect(engine).has_table("eeg"):
    df.to_sql("eeg", engine, index=False)

In [53]:
# Wrap the engine for LangChain, then sanity-check the connection: show the
# SQL dialect and the tables the wrapper exposes, and fetch one known event by
# its start time (the cell displays the query result).
db = SQLDatabase(engine=engine)
print(db.dialect, db.get_usable_table_names(), sep="\n")
db.run("SELECT * FROM eeg WHERE start_time == '13:14:29:034';")

sqlite
['eeg']


"[('Male', 12.0, 'spike and wave', '13:10:31', '13:14:29:034', '13:14:31', 'FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ PZ CZ', '1004.csv')]"

In [54]:
# Chat model handed to the SQL agent, selected by its Fireworks model id.
# NOTE(review): presumably authenticates via the FIREWORKS_API_KEY environment
# variable exported earlier — confirm.
llm = ChatFireworks(model="accounts/fireworks/models/llama-v3p1-405b-instruct")

In [41]:
# Instructions describing the `eeg` table to the agent, one sentence per
# literal (adjacent literals are concatenated into a single string).
# NOTE(review): this prompt is not referenced by any other cell visible here.
default_prompt = (
    "You are a database agent for writing queries and returning answers. "
    "The table is eeg, containing CSV records of EEG recordings from different patients. "
    "Each file (identified by file_name) can have multiple entries with consistent gender, age, and file_start values. "
    "Events are single entries with a start_time, end_time, comment, and channel_names. "
    "If channel_names includes 'FP1', it implies the event was observed in FP1, and similarly for other channels. "
    "Ages are floats type. "
    "Subjects are characterized by age and gender, not filenames. "
    "File names referred to by an integer should be assumed to have a .csv extension. "
    "Timestamps in hh:mm:ss:ms format need to be cast to integers for comparisons. "
    "Use DISTINCT with file_name to handle files with multiple entries. "
    "You should also look for message history to be aware of the context completely."
)

In [55]:
# Single in-memory chat history object; the history factory passed to
# RunnableWithMessageHistory returns this same object for every session id.
# NOTE(review): confirm that ChatMessageHistory actually uses the `session_id`
# argument — it may simply be ignored.
memory = ChatMessageHistory(session_id="test-session")

In [61]:
# Build the SQL agent over `db`, driven by `llm`; verbose=True echoes the
# agent's intermediate steps to the cell output.
# NOTE(review): `use_history` does not appear to be a documented
# create_sql_agent parameter — confirm it has any effect.
agent_executor = create_sql_agent(
    llm,
    db=db,
    agent_type="zero-shot-react-description",
    verbose=True,
    use_history=True,
)

In [62]:
# Wrap the SQL agent so that it is invoked together with a chat message
# history: the user's text is read from the "input" key and the history is
# supplied under the "chat_history" key.
# NOTE(review): confirm the agent's prompt actually consumes "chat_history".
agent_with_chat_history = RunnableWithMessageHistory(
    agent_executor,
    # A history factory taking a session id is required by the wrapper.
    # The id is ignored here: every session shares the single in-memory
    # ChatMessageHistory object `memory`.
    lambda session_id: memory,
    input_messages_key="input",
    history_messages_key="chat_history",
)

In [63]:
def execute_query():
    """Run the text from the query box through the agent and show the answer.

    Button callback: reads the contents of `query_text`, warns and returns if
    it is empty, otherwise invokes the history-aware agent and writes its
    "output" value into the read-only `result_text` widget. Any exception
    raised along the way is reported in an error dialog instead of propagating
    into the Tk event loop.
    """
    query = query_text.get("1.0", tk.END).strip()
    if not query:
        messagebox.showwarning("Input Error", "Please enter a query.")
        return

    try:
        # Go through the RunnableWithMessageHistory wrapper rather than the
        # bare executor: the wrapper is what consumes the `session_id` entry
        # in `config`, so calling the executor directly bypasses the history.
        result = agent_with_chat_history.invoke(
            {"input": query},
            config={"configurable": {"session_id": "<foo>"}},
        )
        output = result.get("output", "No result found")
        # The result box is kept DISABLED so the user cannot edit it; enable
        # it just long enough to replace its contents.
        result_text.config(state=tk.NORMAL)
        result_text.delete("1.0", tk.END)
        result_text.insert(tk.END, output)
        result_text.config(state=tk.DISABLED)
    except Exception as e:
        messagebox.showerror("Execution Error", str(e))

In [64]:
# Build the Tk front end: a query box, an execute button, and a result box,
# stacked vertically in that order.

# Create the main window
root = tk.Tk()
root.title("SQL Agent GUI")

# Create and place widgets
query_label = tk.Label(root, text="Enter SQL Query:")
query_label.pack(pady=5)

# Editable, scrollable input box; execute_query reads its full contents.
query_text = scrolledtext.ScrolledText(root, width=80, height=10)
query_text.pack(pady=5)

# Clicking the button runs execute_query.
execute_button = tk.Button(root, text="Execute Query", command=execute_query)
execute_button.pack(pady=5)

result_label = tk.Label(root, text="Query Result:")
result_label.pack(pady=5)

# Output box starts DISABLED (read-only); execute_query switches it to NORMAL
# while writing the answer and back to DISABLED afterwards.
result_text = scrolledtext.ScrolledText(root, width=80, height=15, state=tk.DISABLED)
result_text.pack(pady=5)


In [65]:
# Start the Tk event loop; this call blocks the cell until the window is closed.
root.mainloop()



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mHere's my response:

Action: sql_db_list_tables
Action Input: [0m[38;5;200m[1;3meeg[0m[32;1m[1;3mWith the list of tables, I can now query the schema of the most relevant table, which is "eeg".

Action: sql_db_schema
Action Input: eeg[0m[33;1m[1;3m
CREATE TABLE eeg (
	"Gender" TEXT, 
	"Age" FLOAT, 
	"Comment" TEXT, 
	file_start TEXT, 
	start_time TEXT, 
	end_time TEXT, 
	channel_names TEXT, 
	file_name TEXT
)

/*
3 rows from eeg table:
Gender	Age	Comment	file_start	start_time	end_time	channel_names	file_name
Male	12.0	spike and wave	13:10:31	13:11:36:739	13:11:38:478	FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ PZ CZ	1004.csv
Male	12.0	spike and wave	13:10:31	13:14:29:034	13:14:31	FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ PZ CZ	1004.csv
Male	12.0	spike and wave	13:10:31	13:16:20	13:16:21:982	FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ PZ CZ	1004.csv
*/[0m[32;1m[1;3mI can now constru