In [1]:
import os
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine
from langchain.agents.agent_toolkits import create_sql_agent,SQLDatabaseToolkit
import pandas as pd
from langchain_fireworks import ChatFireworks
from langchain.prompts import PromptTemplate
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.memory import ChatMessageHistory
import tkinter as tk
from tkinter import scrolledtext, messagebox
import config
from langsmith import Client


In [2]:
# Export every provider credential stored in the local `config` module as an
# environment variable of the same name, so client libraries that read their
# key from the environment can find it.
for _key_name in (
    "AI21_API_KEY",
    "GOOGLE_API_KEY",
    "OPENAI_API_KEY",
    "NVIDIA_API_KEY",
    "HUGGINGFACEHUB_API_TOKEN",
    "FIREWORKS_API_KEY",
    "LANGCHAIN_API_KEY",
):
    os.environ[_key_name] = getattr(config, _key_name)

In [3]:
# LangSmith tracing settings, applied in one go: switch tracing on, point at
# the hosted endpoint, supply the API key and choose the project name.
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
        "LANGCHAIN_API_KEY": config.LANGCHAIN_API_KEY,
        "LANGCHAIN_PROJECT": "pr-giving-defeat-64",
    }
)

In [4]:
# LangSmith client, constructed with no explicit arguments.
# NOTE(review): presumably it picks up the LANGCHAIN_* environment variables
# set in the previous cells — confirm. `client` is not referenced by any other
# cell visible in this notebook.
client = Client()

In [5]:
# Names of the CSV files to ingest from ./csv.
# os.listdir returns *every* directory entry in arbitrary order, so keep only
# names ending in ".csv" (skipping things like .ipynb_checkpoints or .DS_Store
# that pd.read_csv would choke on) and sort them so the rows are loaded in the
# same order on every run.
csv_files = sorted(
    entry for entry in os.listdir("./csv") if entry.lower().endswith(".csv")
)

In [6]:
def _load_eeg_file(csv_name, csv_dir="./csv"):
    """Read one EEG CSV and return it cleaned, one row per recorded event.

    Cleaning steps:
      * ``Gender`` is lower-cased; cells that are not strings take the value of
        the closest previous row *of the same file*.
      * ``Age`` strings such as ``"25 years"`` are reduced to the leading
        number as a float; numeric cells are kept as floats; missing cells
        take the closest previous value of the same file.
      * ``File Start``, ``Start time``, ``End time`` and ``Channel names`` are
        renamed to ``file_start``, ``start_time``, ``end_time`` and
        ``channel_names`` (``file_start`` is forward-filled like ``Gender``).
      * ``file_name`` is added so every row remembers which file it came from.

    The forward-fill never crosses file boundaries: rows that precede the
    first filled-in value of a file stay missing instead of inheriting values
    from a previously loaded file (or raising ``NameError`` on the first file).

    Args:
        csv_name: File name of the CSV inside ``csv_dir``.
        csv_dir: Directory holding the CSV files.

    Returns:
        A new ``pandas.DataFrame`` with the cleaned columns. All other columns
        of the file are passed through unchanged.

    Raises:
        KeyError: If one of the expected columns is absent from the file.
        ValueError: If an ``Age`` string does not start with a number.
    """
    raw = pd.read_csv(os.path.join(csv_dir, csv_name))
    missing = float("nan")

    def _lowercase_text(value):
        # Anything that is not a string (typically NaN) is treated as "blank".
        return value.lower() if isinstance(value, str) else missing

    def _text_only(value):
        return value if isinstance(value, str) else missing

    def _age_as_float(value):
        if isinstance(value, str):
            # e.g. "25 years" -> 25.0
            return float(value.split(' ')[0])
        if pd.isna(value):
            return missing
        return float(value)

    renamed_columns = ["File Start", "Start time", "End time", "Channel names"]
    # `assign` overwrites Gender/Age in place and appends the new columns at
    # the end, in the order listed here.
    return raw.drop(columns=renamed_columns).assign(
        Gender=raw["Gender"].map(_lowercase_text).ffill(),
        Age=raw["Age"].map(_age_as_float).ffill(),
        file_start=raw["File Start"].map(_text_only).ffill(),
        start_time=raw["Start time"],
        end_time=raw["End time"],
        channel_names=raw["Channel names"],
        file_name=csv_name,
    )


# One cleaned frame per input file.
dataframe = [_load_eeg_file(csv_file) for csv_file in csv_files]

# Single table holding every event of every file; an empty frame when there
# was nothing to load (pd.concat refuses an empty list).
df = pd.concat(dataframe, ignore_index=True) if dataframe else pd.DataFrame()

In [7]:
# SQLite database file in the working directory that backs the agent.
engine = create_engine("sqlite:///eeg.db")
# Write the cleaned events into table `eeg`. if_exists="replace" drops and
# recreates the table, so re-running the notebook (Restart & Run All) works
# instead of failing because the table already exists from a previous run.
df.to_sql("eeg", engine, index=False, if_exists="replace")

9893

In [9]:
# Wrap the SQLAlchemy engine in LangChain's SQLDatabase helper.
db = SQLDatabase(engine=engine)
# Sanity checks: print the SQL dialect and the table names the wrapper reports.
print(db.dialect)
print(db.get_usable_table_names())
# Smoke-test query: distinct files whose gender is 'female'. As the last
# expression of the cell, its return value is shown as the cell output.
db.run("SELECT DISTINCT file_name FROM eeg WHERE gender == 'female';")

sqlite
['eeg']


"[('1018.csv',), ('1029.csv',), ('1064.csv',), ('1090.csv',), ('127.csv',), ('243.csv',), ('250.csv',), ('266.csv',), ('289.csv',), ('353.csv',), ('354.csv',), ('472.csv',), ('494.csv',), ('507.csv',), ('510.csv',), ('545.csv',), ('573.csv',), ('577.csv',), ('586.csv',), ('594.csv',), ('598.csv',), ('599.csv',), ('620.csv',), ('621.csv',), ('623.csv',), ('624.csv',), ('627.csv',), ('676.csv',), ('743.csv',), ('764.csv',), ('765.csv',), ('8.csv',), ('828.csv',), ('875.csv',), ('884.csv',), ('908.csv',)]"

In [10]:
# Chat model used by the SQL agent.
# The API key is read from the local `config` module rather than being written
# into the notebook: a literal key ends up in version control and in every
# shared copy of the file. Any key that was previously committed here should
# be revoked and replaced.
llm = ChatFireworks(
    api_key=config.FIREWORKS_API_KEY,
    model="accounts/fireworks/models/llama-v3p1-405b-instruct",
)

In [11]:
# Instructions describing the `eeg` table and the querying conventions for the agent.
# NOTE(review): this string is not referenced by any other cell visible in this
# notebook, and the same wording opens `template` — confirm whether it is still needed.
default_prompt = "You are a database agent for writing queries and returning answers. The table is eeg, containing CSV records of EEG recordings from different patients. Each file (identified by file_name) can have multiple entries with consistent gender, age, and file_start values. Events are single entries with a start_time, end_time, comment, and channel_names. If channel_names includes 'FP1', it implies the event was observed in FP1, and similarly for other channels. Ages are floats type. Subjects are characterized by age and gender, not filenames. File names referred to by an integer should be assumed to have a .csv extension. Timestamps in hh:mm:ss:ms format need to be cast to integers for comparisons. Use DISTINCT with file_name to handle files with multiple entries. You should also look for message history to be aware of the context completely."

In [12]:
# Conversation history object; the same instance is handed out for every
# session id by the history factory used further down, so all requests share it.
# NOTE(review): confirm that ChatMessageHistory accepts and uses `session_id`.
memory = ChatMessageHistory(session_id="test-session")

In [13]:
# ReAct-style prompt for the SQL agent.
# Placeholders filled at run time:
#   {tools}            descriptions of the tools the agent may call
#   {tool_names}       names of those tools, listed as the allowed actions
#   {history}          earlier turns of the conversation
#   {input}            the user's current question
#   {agent_scratchpad} the agent's Thought/Action/Observation trace so far
# Compared with the earlier wording, a stray unbalanced double quote at the end
# of the first paragraph was removed and the misspelling of "multiple" fixed,
# so the model is not handed malformed instructions.
template = '''You are a database agent for writing queries and returning answers. The table is eeg, containing CSV records of 
EEG recordings from different patients. Each file (identified by file_name) can have multiple entries with consistent gender, age, 
and file_start values. Events are single entries with a start_time, end_time, comment, and channel_names. 
If channel_names includes 'FP1', it implies the event was observed in FP1, and similarly for other channels. Ages are floats type. 
Subjects are characterized by age and gender, not filenames. File names referred to by an integer should be assumed to have a .csv extension. 
Timestamps in hh:mm:ss:ms format need to be cast to integers for comparisons. Use DISTINCT with file_name to handle files with multiple entries. 
You should also look for message history to be aware of the context completely.

Answer the following questions as best you can. You have access to the following tools:

{tools}

These are the relevant pieces of the previous conversations, use them if needed.
{history}

If there are multiple files mentioned in the previous conversations and the input refers to a file but doesn't specify which one, you should always 
assume that it is the most recently mentioned file being talked about

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!

Question: {input}
Thought:{agent_scratchpad}'''

# Turn the raw string into a PromptTemplate; the input variables are taken
# from the {placeholders} above.
prompt = PromptTemplate.from_template(template)

In [14]:
# Bundle the chat model and the database wrapper into LangChain's SQL toolkit.
toolkit = SQLDatabaseToolkit(llm = llm, db = db)

In [15]:
# Tool objects provided by the toolkit; passed along with each agent request.
tools = toolkit.get_tools()

In [16]:
# Names of the toolkit's tools, in the same order as `tools`.
tool_names = list(map(lambda sql_tool: sql_tool.name, tools))

In [17]:
# Build the SQL agent from the chat model, the database wrapper and the custom
# ReAct prompt; verbose=True is requested so intermediate steps are logged.
# The agent is created from `llm` and `db` directly — the `toolkit` object
# built earlier is not passed in here.
# NOTE(review): confirm that this version of create_sql_agent honours
# `use_history`; it may be accepted and silently ignored.
agent_executor = create_sql_agent(llm, db=db, agent_type="zero-shot-react-description", verbose=True, use_history = True, prompt = prompt)

In [18]:
# Wrap the agent so that conversation history is attached to each call.
# NOTE(review): history_messages_key is "chat_history", but the prompt template
# in this notebook uses the placeholder {history}; the prompt's history is
# filled by the caller passing "history" explicitly. Confirm the mismatch is
# intentional.
agent_with_chat_history = RunnableWithMessageHistory(
    agent_executor,
    # This is needed because in most real world scenarios, a session id is needed
    # It isn't really used here because we are using a simple in memory ChatMessageHistory
    # The factory ignores the session id and always returns the shared `memory`.
    lambda session_id: memory,
    input_messages_key="input",
    history_messages_key="chat_history",
)

In [19]:
def execute_query():
    """Send the text typed in the query box to the agent and display its answer.

    Reads the whole content of ``query_text``. When it is blank, a warning
    dialog is shown and nothing else happens. Otherwise the text is sent to
    ``agent_with_chat_history`` together with the tools, their names, an empty
    scratchpad and the current message history; the agent's ``"output"`` (or
    ``"No result found"`` when that key is absent) replaces the content of the
    read-only ``result_text`` box. Any exception raised along the way is
    reported in an error dialog instead of propagating.

    Returns:
        None.
    """
    question = query_text.get("1.0", tk.END).strip()
    if len(question) == 0:
        messagebox.showwarning("Input Error", "Please enter a query.")
        return

    try:
        ask_agent = agent_with_chat_history.invoke
        payload = {
            "input": question,
            "tools": tools,
            "tool_names": tool_names,
            "agent_scratchpad": "",
            "history": memory.messages,
        }
        session_settings = {"configurable": {"session_id": "<foo>"}}
        response = ask_agent(payload, config=session_settings)
        answer = response.get("output", "No result found")

        # The result box is kept read-only; unlock it only while rewriting it.
        result_text.config(state=tk.NORMAL)
        result_text.delete("1.0", tk.END)
        result_text.insert(tk.END, answer)
        result_text.config(state=tk.DISABLED)
    except Exception as error:
        messagebox.showerror("Execution Error", str(error))

In [20]:
# Top-level application window.
root = tk.Tk()
root.title("SQL Agent GUI")

# Build the widgets in the order they appear on screen (top to bottom):
# prompt label, query editor, run button, result label, read-only result box.
query_label = tk.Label(root, text="Enter SQL Query:")
query_text = scrolledtext.ScrolledText(root, width=80, height=10)
execute_button = tk.Button(root, text="Execute Query", command=execute_query)
result_label = tk.Label(root, text="Query Result:")
result_text = scrolledtext.ScrolledText(root, width=80, height=15, state=tk.DISABLED)

# Stack them vertically, each with the same vertical padding.
for _widget in (query_label, query_text, execute_button, result_label, result_text):
    _widget.pack(pady=5)


In [None]:
# Start the Tk event loop. This call blocks the cell until the window is closed.
root.mainloop()