In [1]:
## Load environment variables
import os
from dotenv import load_dotenv

# Pull variables from a local .env file (if any) into os.environ.
load_dotenv()

# Fail fast with a readable message when a required key is absent OR empty.
# os.environ[...] would raise a bare KeyError for a missing variable before
# the assertion message could ever be shown, so .get() is used instead.
for required_var in ("LANGCHAIN_API_KEY", "GROQ_API_KEY", "OPENAI_API_KEY"):
    assert os.environ.get(required_var), f"Please set the {required_var} environment variable"

# Data locations; kept as plain strings because later cells build paths
# by string concatenation (e.g. DATA_DIR + "/orders.db").
DATA_DIR = "./data"
DATA_CSV_PATH = DATA_DIR + "/Fixed_Synthetic_Data.csv"

In [2]:
## Read CSV file with LangChain
# Loads the CSV at DATA_CSV_PATH (defined in the environment/config cell)
# through LangChain's CSVLoader.
from langchain_community.document_loaders.csv_loader import CSVLoader

# The loader is only configured here; no file access is requested until load().
loader = CSVLoader(file_path=DATA_CSV_PATH)

# `data` holds whatever CSVLoader.load() returns -- presumably a list of
# LangChain Document objects, one per CSV row (TODO confirm against the
# CSVLoader docs). NOTE(review): none of the cells visible in this part of
# the notebook read `data`; the SQL agent below works from the SQLite copy.
data = loader.load()

In [17]:
## Convert CSV file to SQL file
# Copies the CSV into a SQLite database (table "orders", replaced on every
# run so the cell is idempotent), then prints the reflected columns and a
# five-row sample as a sanity check.
import pandas as pd
from sqlalchemy import create_engine, MetaData, Table, text

sqlite_db_path = DATA_DIR + "/orders.db"
engine = create_engine(f"sqlite:///{sqlite_db_path}")
df = pd.read_csv(DATA_CSV_PATH)

try:
    # The write goes through the Engine (pandas opens and commits its own
    # connection), so it is done before -- not inside -- the read connection.
    inserted_rows = df.to_sql(name="orders", con=engine, if_exists="replace", index=False)
    if inserted_rows is None:
        # to_sql may return None instead of a count; fall back to the frame size.
        inserted_rows = len(df)
    print(f"Inserted {inserted_rows} rows into the orders table")

    ## Check database
    # Reflect the table definition back from the database file.
    table = Table('orders', MetaData(), autoload_with=engine)
    print(f"Columns in table '{table.name}':")
    print(table.columns.values())

    # Hold a connection only for as long as the sample query needs it.
    with engine.connect() as conn:
        rows = conn.execute(text("SELECT * FROM orders LIMIT 5")).fetchall()
    print(f"Sample rows in table '{table.name}':")
    for row in rows:
        print(row)
finally:
    # Release pooled connections even if one of the steps above raised.
    engine.dispose()

Inserted 5000 rows into the orders table
Columns in table 'orders':
[Column('Instance', TEXT(), table=<orders>), Column('OrderNo', BIGINT(), table=<orders>), Column('ParentOrderNo', BIGINT(), table=<orders>), Column('RootParentOrderNo', BIGINT(), table=<orders>), Column('CreateDate', TEXT(), table=<orders>), Column('DeleteDate', TEXT(), table=<orders>), Column('AccID', BIGINT(), table=<orders>), Column('AccCode', TEXT(), table=<orders>), Column('BuySell', TEXT(), table=<orders>), Column('Side', BIGINT(), table=<orders>), Column('OrderSide', TEXT(), table=<orders>), Column('SecID', BIGINT(), table=<orders>), Column('SecCode', TEXT(), table=<orders>), Column('Exchange', TEXT(), table=<orders>), Column('Destination', TEXT(), table=<orders>), Column('Quantity', BIGINT(), table=<orders>), Column('PriceMultiplier', FLOAT(), table=<orders>), Column('Price', FLOAT(), table=<orders>), Column('Value', FLOAT(), table=<orders>), Column('ValueMultiplier', FLOAT(), table=<orders>), Column('DoneVolum

In [18]:
## Initialize LLM and SQL toolkit
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI

# Two chat models are created; only the OpenAI one is wired into the
# toolkit in this cell.
groq_llm = ChatGroq(model="llama3-8b-8192")
openai_llm = ChatOpenAI(
    model="gpt-4o-mini",
    api_key=os.environ['OPENAI_API_KEY'],
)

# Point LangChain at the SQLite file written by the CSV-to-SQL cell and
# show which tables it can see.
db = SQLDatabase.from_uri(f"sqlite:///{sqlite_db_path}")
print(db.get_usable_table_names())

# Build the SQL tools that the agent will be allowed to call.
toolkit = SQLDatabaseToolkit(db=db, llm=openai_llm)
tools = toolkit.get_tools()
print(tools)

['orders']
[QuerySQLDataBaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x0000010803C57090>), InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x0000010803C57090>), ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x0000010803C57090>), QuerySQLCheckerTool(description='Use this 

In [19]:
## Initialize agent
from langchain import hub
from langgraph.prebuilt import create_react_agent

## Pull the shared SQL-agent system prompt from the LangChain hub and fill
## in its two placeholders: the SQL dialect and the row limit (top_k).
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
system_message = prompt_template.format(
    dialect="SQLite",
    top_k=5,
)

# ReAct-style agent: the OpenAI model plus the SQL tools, steered by the
# formatted system message.
agent_executor = create_react_agent(
    model=openai_llm,
    tools=tools,
    state_modifier=system_message,
)

In [20]:
## Issue query
example_query = "Which order had the highest quantity?"

# Stream the agent run; with stream_mode="values" each event carries the
# message list so far, so only the newest message is printed per step.
agent_input = {"messages": [("user", example_query)]}
events = agent_executor.stream(agent_input, stream_mode="values")
for event in events:
    latest_message = event["messages"][-1]
    latest_message.pretty_print()


Which order had the highest quantity?
Tool Calls:
  sql_db_list_tables (call_xAbB8XDMFjvrlxwNFje1Kyyh)
 Call ID: call_xAbB8XDMFjvrlxwNFje1Kyyh
  Args:
Name: sql_db_list_tables

orders
Tool Calls:
  sql_db_schema (call_0jrnpjgXSBBJ4stqtM1dFSWe)
 Call ID: call_0jrnpjgXSBBJ4stqtM1dFSWe
  Args:
    table_names: orders
Name: sql_db_schema


CREATE TABLE orders (
	"Instance" TEXT, 
	"OrderNo" BIGINT, 
	"ParentOrderNo" BIGINT, 
	"RootParentOrderNo" BIGINT, 
	"CreateDate" TEXT, 
	"DeleteDate" TEXT, 
	"AccID" BIGINT, 
	"AccCode" TEXT, 
	"BuySell" TEXT, 
	"Side" BIGINT, 
	"OrderSide" TEXT, 
	"SecID" BIGINT, 
	"SecCode" TEXT, 
	"Exchange" TEXT, 
	"Destination" TEXT, 
	"Quantity" BIGINT, 
	"PriceMultiplier" FLOAT, 
	"Price" FLOAT, 
	"Value" FLOAT, 
	"ValueMultiplier" FLOAT, 
	"DoneVolume" BIGINT, 
	"DoneValue" BIGINT, 
	"Currency" TEXT, 
	"OrderType" BIGINT, 
	"PriceInstruction" TEXT, 
	"TimeInForce" BIGINT, 
	"Lifetime" TEXT, 
	"ClientOrderID" FLOAT, 
	"SecondaryClientOrderID" FLOAT, 
	"DestOrde

In [None]:
## Parse output