In [33]:
import os
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine
from langchain.agents.agent_toolkits import create_sql_agent,SQLDatabaseToolkit
import pandas as pd
from langchain_fireworks import ChatFireworks
from langchain.prompts import PromptTemplate
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.memory import ChatMessageHistory
import tkinter as tk
from tkinter import scrolledtext, messagebox
import config
from langsmith import Client
from datetime import datetime, date

In [4]:
# API keys come from the local (uncommitted) config module; the LANGCHAIN_* values configure
# LangSmith tracing: endpoint, key and the project the runs are filed under.
os.environ.update(FIREWORKS_API_KEY=config.FIREWORKS_API_KEY)
os.environ.update(
    LANGCHAIN_TRACING_V2="true",
    LANGCHAIN_ENDPOINT="https://api.smith.langchain.com",
)
os.environ.update(
    LANGCHAIN_API_KEY=config.LANGCHAIN_API_KEY,
    LANGCHAIN_PROJECT="pr-giving-defeat-64",
)

In [5]:
# LangSmith client.
# NOTE(review): `client` is not referenced by any later cell in this notebook; presumably it was
# created for ad-hoc inspection of traced runs — confirm it is still needed.
client = Client()

In [6]:
# Annotation CSVs to load. Sorted so the row order of the rebuilt table is the same on every run
# (os.listdir order is arbitrary), and filtered so any non-.csv entry in the folder is skipped
# instead of crashing pd.read_csv.
csv_files = sorted(name for name in os.listdir("./csv") if name.lower().endswith(".csv"))

In [7]:
def parse_timestamp(timestamp: str) -> datetime:
    """Parse an 'hh:mm:ss:ms' (or 'hh:mm:ss') timestamp into a datetime on today's date.

    The last field is read as fractional-second digits, exactly as ``strptime``'s ``%f``
    reads them, so '13:21:33:869' becomes 13:21:33.869. A timestamp without that field
    is treated as if it ended in ':00'. Surrounding whitespace is ignored.

    Only differences between results are meaningful: the date part is always
    ``date.today()``, so an interval that crosses midnight comes out negative.

    Args:
        timestamp: time-of-day string such as '09:24:38:011'.

    Returns:
        A ``datetime`` combining today's date with the parsed time of day.

    Raises:
        ValueError: if the string does not have 3 or 4 colon-separated fields, or a
            field is not a valid number for its position.
    """
    parts = timestamp.strip().split(':')
    if len(parts) == 3:
        # No millisecond field: treat it as a whole second.
        parts.append('00')
    if len(parts) != 4:
        raise ValueError(f"expected an 'hh:mm:ss[:ms]' timestamp, got {timestamp!r}")

    # Split once; hh/mm/ss are normalised through int(), the fraction is left to %f.
    hours, minutes, seconds = (int(part) for part in parts[:3])
    fraction = parts[3]
    time_obj = datetime.strptime(f"{hours}:{minutes}:{seconds}.{fraction}", "%H:%M:%S.%f").time()

    return datetime.combine(date.today(), time_obj)

In [15]:
# Columns renamed to snake_case and moved to the end (in this order), followed by file_name.
_RENAMED_COLUMNS = {
    "File Start": "file_start",
    "Start time": "start_time",
    "End time": "end_time",
    "Channel names": "channel_names",
}


def _first_token_as_float(value):
    """Turn an age string such as '12 years' into 12.0; non-strings (numbers, NaN) pass through."""
    return float(value.split(' ')[0]) if isinstance(value, str) else value


def _normalise_eeg_csv(raw: pd.DataFrame, file_name: str) -> pd.DataFrame:
    """Return a cleaned copy of one annotation CSV, ready to be concatenated with the others.

    - Gender is lower-cased, Age is parsed to float, and File Start is kept only where it is a string.
    - Missing Gender/Age/File Start values are forward-filled from earlier rows *of the same file*.
      The old row-by-row loop carried them across file boundaries, and raised NameError when the
      very first row lacked one.
    - The columns in ``_RENAMED_COLUMNS`` are renamed and moved to the end, then ``file_name`` is
      appended. This is the same column layout the old loop produced.
    """
    out = raw.copy()
    out["Gender"] = out["Gender"].map(lambda g: g.lower() if isinstance(g, str) else None).ffill()
    out["Age"] = out["Age"].map(_first_token_as_float).astype(float).ffill()
    file_start = out["File Start"]
    out["File Start"] = file_start.where(file_start.map(lambda v: isinstance(v, str))).ffill()

    out = out.rename(columns=_RENAMED_COLUMNS)
    out["file_name"] = file_name
    tail = [*_RENAMED_COLUMNS.values(), "file_name"]
    return out[[col for col in out.columns if col not in tail] + tail]


df = pd.concat(
    [_normalise_eeg_csv(pd.read_csv(os.path.join("./csv", name)), name) for name in csv_files],
    ignore_index=True,
)

# Event length in seconds, from the hh:mm:ss:ms start/end strings.
df['duration'] = (df['end_time'].apply(parse_timestamp) - df['start_time'].apply(parse_timestamp)).dt.total_seconds()


In [8]:
# SQLAlchemy engine for the SQLite file eeg.db (path relative to the working directory);
# SQLDatabase in the next cell wraps it.
# NOTE(review): the to_sql call that writes `df` into the `eeg` table is commented out, so every later
# cell assumes eeg.db already contains that table. Re-enable it (e.g. with if_exists="replace")
# if the table must be rebuilt from the CSVs on a fresh run.
engine = create_engine("sqlite:///eeg.db")
# df.to_sql("eeg", engine, index=False)

In [9]:
# Wrap the engine for LangChain's SQL tooling and sanity-check it before building the agent.
db = SQLDatabase(engine=engine)

# Expected: the dialect name, then the tables visible to the agent.
print(db.dialect)
print(db.get_usable_table_names())

# Peek at a few rows of `eeg`; the returned string is the cell's displayed output.
db.run("SELECT start_time, end_time, duration FROM eeg WHERE gender == 'female' LIMIT 5;")

sqlite
['eeg']


"[('13:21:33:869', '13:21:35:965', 2.096), ('09:24:38:011', '09:24:44:960', 6.949), ('09:24:45:083', '09:24:45:758', 0.675), ('09:24:45:240', '09:24:45:593', 0.353), ('09:24:45:059', '09:24:45:774', 0.715)]"

In [10]:
# Chat model served by Fireworks (model id below). The API key comes from the config module, the same
# value exported as FIREWORKS_API_KEY in the setup cell, instead of a literal saved in the notebook.
# NOTE(review): the key that used to be hard-coded on this line should be treated as leaked and revoked.
llm = ChatFireworks(api_key=config.FIREWORKS_API_KEY, model="accounts/fireworks/models/llama-v3p1-405b-instruct")

In [20]:
# NOTE(review): `default_prompt` is not referenced by any later cell. The agent is built from
# `template`/`prompt`, which repeats this text. It is presumably an earlier draft; consider deleting it.
default_prompt = "You are a database agent for writing queries and returning answers. The table is eeg, containing CSV records of EEG recordings from different patients. Each file (identified by file_name) can have multiple entries with consistent gender, age, and file_start values. Events are single entries with a start_time, end_time, comment, and channel_names. If channel_names includes 'FP1', it implies the event was observed in FP1, and similarly for other channels. Ages are floats type. Subjects are characterized by age and gender, not filenames. File names referred to by an integer should be assumed to have a .csv extension. Timestamps in hh:mm:ss:ms format need to be cast to integers for comparisons. Use DISTINCT with file_name to handle files with multiple entries. You should also look for message history to be aware of the context completely."

In [11]:
# Single in-memory chat history shared by every question. The RunnableWithMessageHistory factory always
# returns this object, and execute_query also passes `memory.messages` as the prompt's {history}.
# NOTE(review): that factory ignores its session_id argument, so the id given here is never consulted.
memory = ChatMessageHistory(session_id="test-session")

In [32]:
# Prompt for the ReAct SQL agent. Placeholders: {schema}, {tools}, {history}, {tool_names}, {input} and
# {agent_scratchpad}; execute_query supplies a value for each of them in its invoke payload.
# Timestamps are zero-padded 'hh:mm:ss:ms' strings. Casting them to INTEGER in SQLite keeps only the leading
# hour digits, so the prompt tells the agent to compare them as strings or use the duration column.
template = '''You are a database agent for writing queries and returning answers. The table is eeg, containing CSV records of
EEG recordings from different patients.

Its schema is {schema}

Each file (identified by file_name) can have multiple entries with consistent gender, age,
and file_start values. Events/abnormalities are single entries with a start_time, end_time, comment, and channel_names.
The event/abnormality type is defined by the comment. If channel_names includes 'FP1', it implies the event was observed
in FP1, and similarly for other channels. Ages are floats.
Subjects are characterized by age and gender, not filenames. File names referred to by an integer should be assumed to
have a .csv extension.
Timestamps are zero-padded text in hh:mm:ss:ms format, so compare them as strings; do not CAST them to integers
(that keeps only the hour). Use the duration column (in seconds) for event lengths.
Use DISTINCT with file_name to handle files with multiple entries.
Use the message history below to be aware of the full context.

Answer the following questions as best you can. You have access to the following tools:

{tools}

These are the relevant pieces of the previous conversations, use them if needed. They are shown here directly;
there is no tool for reviewing the conversation.
{history}

If there are multiple files mentioned in the previous conversations and the input refers to a file but doesn't specify which one, you should always
assume that it is the most recently mentioned (either by human or by you) file being talked about.

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!

Question: {input}
Thought:{agent_scratchpad}'''

# Build the agent's prompt object from `template`; it is passed to create_sql_agent below.
prompt = PromptTemplate.from_template(template)

In [13]:
# SQL toolkit over `db`. It is used here only to obtain `tools` / `tool_names`, which execute_query
# passes in its invoke payload.
# NOTE(review): create_sql_agent below receives `db=db`, not this toolkit.
toolkit = SQLDatabaseToolkit(llm = llm, db = db)

In [14]:
# Tool objects from the toolkit. The agent's log output further down names sql_db_query, sql_db_schema,
# sql_db_list_tables and sql_db_query_checker; presumably this is the same set — confirm.
tools = toolkit.get_tools()

In [15]:
# Just the tool names, in toolkit order; sent as `tool_names` in execute_query's invoke payload.
tool_names = list(map(lambda each_tool: each_tool.name, tools))

In [16]:
# The `db.table_info` string; execute_query passes it as `schema`, which fills the prompt's {schema} slot.
schema = db.table_info

In [23]:
# SQL agent of type "zero-shot-react-description" using the custom `prompt`. verbose=True produces the
# "Entering new SQL Agent Executor chain" logs shown in the GUI cell's output.
# NOTE(review): confirm that `use_history` is accepted and used by this langchain version. The history
# visibly reaches the prompt through the `history` key that execute_query passes.
agent_executor = create_sql_agent(llm, db=db, agent_type="zero-shot-react-description", verbose=True, use_history = True, prompt = prompt)

In [24]:
# Wrap the executor with message history. The factory ignores session_id and always returns the shared
# `memory`, which is where the recorded Human/AI turns (inspected in the last cell) accumulate.
# NOTE(review): history is injected under "chat_history", but the prompt template has no {chat_history}
# placeholder. It reads {history}, which execute_query fills explicitly from memory.messages.
agent_with_chat_history = RunnableWithMessageHistory(
    agent_executor,
    # This is needed because in most real world scenarios, a session id is needed
    # It isn't really used here because we are using a simple in memory ChatMessageHistory
    lambda session_id: memory,
    input_messages_key="input",
    history_messages_key="chat_history",
)

In [25]:
def execute_query():
    """Send the text typed into `query_text` to the agent and show its answer in `result_text`.

    Blank input only shows a warning dialog; nothing is sent. Any exception raised while
    running the agent or rewriting the result box is shown in an error dialog instead of
    propagating out of the Tk callback.
    """
    question = query_text.get("1.0", tk.END).strip()
    if not question:
        messagebox.showwarning("Input Error", "Please enter a query.")
        return

    try:
        payload = {
            "input": question,
            "tools": tools,
            "tool_names": tool_names,
            "agent_scratchpad": "",
            "history": memory.messages,
            "schema": schema,
        }
        # "<foo>" is a placeholder: the history factory ignores the session id.
        response = agent_with_chat_history.invoke(
            payload,
            config={"configurable": {"session_id": "<foo>"}},
        )
        answer = response.get("output", "No result found")

        # The result box is read-only; unlock it just long enough to replace its contents.
        result_text.configure(state=tk.NORMAL)
        result_text.delete("1.0", tk.END)
        result_text.insert(tk.END, answer)
        result_text.configure(state=tk.DISABLED)
    except Exception as exc:
        messagebox.showerror("Execution Error", str(exc))

In [30]:
# Create the main window
root = tk.Tk()
root.title("SQL Agent GUI")

# Create and place widgets
query_label = tk.Label(root, text="Enter SQL Query:")
query_label.pack(pady=5)

query_text = scrolledtext.ScrolledText(root, width=80, height=10)
query_text.pack(pady=5)

execute_button = tk.Button(root, text="Execute Query", command=execute_query)
execute_button.pack(pady=5)

result_label = tk.Label(root, text="Query Result:")
result_label.pack(pady=5)

result_text = scrolledtext.ScrolledText(root, width=80, height=15, state=tk.DISABLED)
result_text.pack(pady=5)


In [31]:
# Run the Tk event loop for the window built above. The agent's verbose logs from each button press
# appear below as this cell's output.
root.mainloop()



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mTo find the gender of the patient in this file, we need to look at the 'Gender' column in the eeg table for the most recently mentioned file.

Action: sql_db_query
Action Input: SELECT DISTINCT Gender FROM eeg WHERE file_name = '1029.csv'[0m[36;1m[1;3m[('female',)][0m[32;1m[1;3mI now know the final answer
Final Answer: The gender of the patient in this file is female.[0m

[1m> Finished chain.[0m


[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mThought: The question is asking for the age of the patient in the most recently mentioned file. Since the most recently mentioned file is not explicitly stated, I need to refer to the previous conversations to determine which file is being referred to.

Action: Review previous conversations
Action Input: None[0mReview previous conversations is not a valid tool, try one of [sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker].[32;1m[1;3

In [28]:
# Inspect the accumulated chat history: alternating Human/AI messages from the GUI session.
memory.messages

[HumanMessage(content='what is the average duration of all the events in the file 1004?'),
 AIMessage(content='The average duration of all the events in the file 1004 is 1.701857142857143 seconds.'),
 HumanMessage(content='what are the different types of events in this file ?'),
 AIMessage(content="There is only one type of event in this file, which is 'spike and wave'."),
 HumanMessage(content='what is the gender of the patient in this file ?'),
 AIMessage(content='The gender of the patient in this file is male.'),
 HumanMessage(content='how many events are there in this file ?'),
 AIMessage(content='There are 7 events in this file.'),
 HumanMessage(content='what is the age of the patient ?'),
 AIMessage(content='The age of the patient is 12.0 years old.'),
 HumanMessage(content='can you name a file that has the recording of a female patient of age 20 years or more ?'),
 AIMessage(content='1029.csv is a file that has the recording of a female patient of age 20 years or more.'),
 Human