In [73]:
import os
import csv
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import CharacterTextSplitter
from langchain_ai21 import AI21Embeddings
from langchain_community.docstore.document import Document
from langchain_community.vectorstores import Chroma
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_ai21 import AI21LLM
from langchain_openai import ChatOpenAI
import pandas as pd
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_huggingface import HuggingFaceEndpoint
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents import AgentExecutor
from langchain_fireworks import ChatFireworks
from langchain.prompts import PromptTemplate
from langchain_core.prompts.base import BasePromptTemplate
from math import isnan
import config

In [74]:
os.environ["AI21_API_KEY"] = config.AI21_API_KEY
os.environ["GOOGLE_API_KEY"] = config.GOOGLE_API_KEY
os.environ["OPENAI_API_KEY"] = config.OPENAI_API_KEY
os.environ["NVIDIA_API_KEY"] = config.NVIDIA_API_KEY
os.environ["HUGGINGFACEHUB_API_TOKEN"] = config.HUGGINGFACEHUB_API_TOKEN
os.environ["FIREWORKS_API_KEY"] = config.FIREWORKS_API_KEY

In [3]:
columns_to_metadata = ["Gender", "Age", "File Start", "Start time", "End time", "Channel names", "Comment"]

In [4]:
csv_files = os.listdir("./csv")

In [5]:
dataframe = []

for csv_file in csv_files:
    df = pd.read_csv(os.path.join("./csv", csv_file))
    for index, row in df.iterrows():
        if type(row["Gender"]) == str:
            gender = row["Gender"]
        else:
            row["Gender"] = gender 

        if type(row["Age"]) == str:
            age = row["Age"]
        else:
            row["Age"] = age 
        
        if type(row["File Start"]) == str:
            file_start = row["File Start"]
            del row["File Start"]
            row["file_start"] = file_start
        else:
            row["file_start"] = file_start
            del row['File Start']

        row['start_time'] = row['Start time']
        del row['Start time']

        row['end_time'] = row['End time']
        del row['End time']

        row['channel_names'] = row['Channel names']
        del row['Channel names']

        row['file_name'] = csv_file

        dataframe.append(row)

df = pd.DataFrame(dataframe)

In [6]:
df.head()

Unnamed: 0,Gender,Age,Comment,file_start,start_time,end_time,channel_names,file_name
0,Male,12 years,spike and wave,13:10:31,13:11:36:739,13:11:38:478,FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5...,1004.csv
1,Male,12 years,spike and wave,13:10:31,13:14:29:034,13:14:31,FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5...,1004.csv
2,Male,12 years,spike and wave,13:10:31,13:16:20,13:16:21:982,FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5...,1004.csv
3,Male,12 years,spike and wave,13:10:31,13:17:19:034,13:17:19:965,FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5...,1004.csv
4,Male,12 years,spike and wave,13:10:31,13:17:51:026,13:17:53:295,FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5...,1004.csv


In [7]:
engine = create_engine("sqlite:///eeg.db")
df.to_sql("eeg", engine, index=False)

9893

In [8]:
db = SQLDatabase(engine=engine)
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM eeg WHERE start_time == '13:14:29:034';")

sqlite
['eeg']


"[('Male', '12 years', 'spike and wave', '13:10:31', '13:14:29:034', '13:14:31', 'FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ PZ CZ', '1004.csv')]"

In [90]:
llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest")

In [17]:
llm = AI21LLM(model="j2-ultra")

In [88]:
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)


In [28]:
llm = ChatNVIDIA(model="meta/llama3-70b-instruct")

In [None]:
llm=HuggingFaceEndpoint(repo_id="google/flan-t5-xxl")


In [9]:
llm = ChatFireworks(model="accounts/fireworks/models/llama-v3p1-70b-instruct")

In [10]:
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

In [11]:
tools = toolkit.get_tools()

In [62]:
default_prompt = "You are database agent to help write queries and return answers. The table that we are concerned with is eeg and it contains records of csv files comprising of eeg recordings of different patients. Each file is characterized the by the file_name column and a file can have multiple entries in the database table. But in all of those entries of a single fie, the gender, age and file_start columns will have the same value because the recording is of the same patient. Each patient is characterized by their gender and age. An event is a single entry (row) in the database and can only belong to one file but a file can contain multiple events. An event begins at start_time and ends at end_time. Comment shows the type of the event and channel names show the channels in which the event was observed. If there are multiple names in the channel_names column but one of them is 'FP1' then its safe to say that the event was observed in FP1, same goes for all other channels. The age attribute is not an integer so you can't do comparisons straightaway instead you will have to manipulate the column to get the age value and then compare. If a file is referred to by only an integer, you should assume that .csv is there in the extension of the file name. The timestamps are strings, so to compare them in terms of hour, minute, or seconds, or milliseconds, or as a whole, you should take information from them accordingly, and for comparison you should cast them to integers. They are in the format hh:mm:ss:ms. Unless very necessary, you should not use LIMIT in your queries, especially when the results ask for 'all'. When asked for file names or instances, you should almost always use DISTINCT in your queries with file_name because a file can contain more than one entries and may appear more than once in the database"

In [69]:
default_prompt = "You are a database agent for writing queries and returning answers. The table is eeg, containing CSV records of EEG recordings from different patients. Each file (identified by file_name) can have multiple entries with consistent gender, age, and file_start values. Events are single entries with a start_time, end_time, comment, and channel_names. If channel_names includes 'FP1', it implies the event was observed in FP1, and similarly for other channels. age is a non-integer string requiring conversion for comparisons. Subjects are characterized by age and gender, not filenames. File names referred to by an integer should be assumed to have a .csv extension. Timestamps in hh:mm:ss:ms format need to be cast to integers for comparisons. Use DISTINCT with file_name to handle files with multiple entries. Avoid using LIMIT unless absolutely necessary."

In [None]:
prompt_template = BasePromptTemplate(
    template=(
        default_prompt + "\n\n"
        "Available tools: {tools}\n\n"
        "Available tools: {tool_names}\n\n"
        "{agent_scratchpad}\n\n"
        "Question: {input}\n\n"
        "Answer:"
    ),
    input_variables=["input", "tools", "tool_names", "agent_scratchpad"]
)

In [39]:
agent = create_sql_agent(llm, db=db, agent_type = 'zero-shot-react-description', verbose=True, handle_parsing_errors = True)

In [67]:
query = "Name all the subjects of the recordings"

In [71]:
result = agent.invoke({"input": f"{default_prompt} query: {query}"})
result["output"]



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mThought: I should look at the tables in the database to see what I can query.  Then I should query the schema of the most relevant tables.
Action: sql_db_list_tables
Action Input: [0m[38;5;200m[1;3meeg[0m[32;1m[1;3mAction: sql_db_schema
Action Input: eeg[0m[33;1m[1;3m
CREATE TABLE eeg (
	"Gender" TEXT, 
	"Age" TEXT, 
	"Comment" TEXT, 
	file_start TEXT, 
	start_time TEXT, 
	end_time TEXT, 
	channel_names TEXT, 
	file_name TEXT
)

/*
3 rows from eeg table:
Gender	Age	Comment	file_start	start_time	end_time	channel_names	file_name
Male	12 years	spike and wave	13:10:31	13:11:36:739	13:11:38:478	FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ PZ CZ	1004.csv
Male	12 years	spike and wave	13:10:31	13:14:29:034	13:14:31	FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ PZ CZ	1004.csv
Male	12 years	spike and wave	13:10:31	13:16:20	13:16:21:982	FP1 FP2 F3 F4 C3 C4 P3 P4 O1 O2 F7 F8 T3 T4 T5 T6 FZ PZ CZ	1004.csv
*/[

"The subjects of the recordings are: \n('Male', '12 years'), \n('Female', '11 years'), \n('Male', '1 year'), \n('Male', '17 Years'), \n('Male', '22 years'), \n('Female', '23 years'), \n('Male', '3 years'), \n('Male', '8 years'), \n('Female', '18 years'), \n('Male', '11 years'), \n('Male', '10 years'), \n('Male', '42 years'), \n('Male', '15 years'), \n('Male', '35 years'), \n('Female', '67 years'), \n('Female ', '5 years'), \n('Female', '28 years'), \n('Male', '3.5 years'), \n('Male', '7 years'), \n('Male', '21 years'), \n('Female', '16 years'), \n('Male', '5 years'), \n('Male', '6 years'), \n('Female', '28 Years'), \n('Male', '17 years'), \n('Male', '14 years'), \n('Male', '45 years'), \n('Male', '46 years'), \n('Female', '1 year'), \n('Male', '2 years'), \n('Male', '4.5 Years'), \n('Male', '39 years'), \n('Male', '31 years'), \n('Female', '44 years'), \n('Female', '50 years'), \n('Female', '7 years'), \n('Female', '5 years'), \n('Female', '5.5 years'), \n('Male', '4 years'), \n('Femal