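This notebook showcases an agent designed to interact with SQL databases. The agent builds on SQLDatabaseChain and is designed to answer more general questions about a database, as well as to recover from errors. The first part of the notebook builds a retrieval-QA "Financial Assistant" chatbot over local SQLite and XML data. The SQL-agent examples that follow run against the Chinook sample database.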

In [None]:
# Importing required libraries
import os
import re
import sqlite3
import xml.etree.ElementTree as ET
from contextlib import closing
from urllib.parse import quote

import pandas as pd
import requests
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

In [None]:
# Configuring OpenAI API key
# setdefault only installs the placeholder when OPENAI_API_KEY is not already set,
# so a real key exported before starting Jupyter is not clobbered.
# Replace the placeholder (or export OPENAI_API_KEY) before running the LLM cells.
os.environ.setdefault("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY")

In [None]:
# Loading data from SQLite databases
def _read_table(db_path, table):
    """Return every row of `table` in the SQLite file `db_path` as a DataFrame.

    The file is opened read-only through a URI, so a mistyped path raises
    sqlite3.OperationalError instead of silently creating an empty database.
    `closing` guarantees the connection is closed; sqlite3's own context
    manager only commits/rolls back and leaves the connection open.

    `table` is interpolated into the SQL text, so only pass trusted,
    hard-coded table names.
    """
    with closing(sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)) as conn:
        return pd.read_sql_query(f"SELECT * FROM {table}", conn)


advisors_df = _read_table("advisors.db", "advisors")
users_df = _read_table("users.db", "users")
user_1p_data_df = _read_table("user_1p_data.db", "user_1p_data")

In [None]:
# Loading data from XML file
FIRMS_XML_PATH = 'IA_FIRM_SEC_Feed_05_13_2024.xml'
# Attributes read from each firm's <Info> element; also fixes the column order of firms_df.
FIRM_INFO_ATTRIBUTES = ('FirmCrdNb', 'FirmName', 'FirmCity', 'FirmState', 'FirmCountry')

tree = ET.parse(FIRMS_XML_PATH)
root = tree.getroot()
firms_data = []
# './/Firm' matches <Firm> elements at any depth below the root, so firms are still
# found if the feed wraps them in a container element; findall('Firm') only sees
# direct children of the root.
for firm in root.findall('.//Firm'):
    info = firm.find('Info')
    if info is None:
        # NOTE(review): firms without an <Info> child are skipped instead of crashing the cell.
        continue
    # .get() yields None for a missing attribute instead of raising KeyError.
    firms_data.append({attr: info.get(attr) for attr in FIRM_INFO_ATTRIBUTES})
# Explicit columns keep the expected schema even when no firms were found.
firms_df = pd.DataFrame(firms_data, columns=list(FIRM_INFO_ATTRIBUTES))

In [None]:
# Function to fetch news data from API
def fetch_news_data(query, api_key=None, timeout=10):
    """Search NewsAPI.org's ``/v2/everything`` endpoint and return the matching articles.

    Args:
        query: Free-text search string, sent as the ``q`` query parameter.
        api_key: NewsAPI key. Defaults to the ``NEWS_API_KEY`` environment
            variable, falling back to the original placeholder string.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The list stored under the response's ``articles`` key.

    Raises:
        requests.HTTPError: if NewsAPI answers with a non-2xx status.
        requests.Timeout: if no response arrives within `timeout` seconds.
        KeyError: if a successful response has no ``articles`` key.
    """
    if api_key is None:
        api_key = os.environ.get("NEWS_API_KEY", "YOUR_NEWS_API_KEY")
    # Passing the query through `params` URL-encodes it, so spaces, '&', '#', etc.
    # in user text cannot corrupt the URL or inject extra parameters.
    response = requests.get(
        "https://newsapi.org/v2/everything",
        params={"q": query, "apiKey": api_key},
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()["articles"]

In [None]:
# Function to fetch stock data from API
def fetch_stock_data(ticker, api_key=None, timeout=10):
    """Fetch the previous-day aggregate bar(s) for `ticker` from Polygon.io.

    Calls the ``/v2/aggs/ticker/{ticker}/prev`` endpoint.

    Args:
        ticker: Symbol to look up (e.g. "AAPL"); percent-encoded into the URL path.
        api_key: Polygon.io key. Defaults to the ``POLYGON_API_KEY`` environment
            variable, falling back to the original placeholder string.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The list stored under the response's ``results`` key, or an empty list
        when the response carries no ``results`` key. The chatbot reads the
        ``'o'``, ``'h'``, ``'l'``, ``'c'`` and ``'v'`` fields of the first element.

    Raises:
        requests.HTTPError: if Polygon.io answers with a non-2xx status.
        requests.Timeout: if no response arrives within `timeout` seconds.
    """
    if api_key is None:
        api_key = os.environ.get("POLYGON_API_KEY", "YOUR_POLYGON_API_KEY")
    # quote(..., safe='') keeps a malformed ticker (e.g. containing '/') from
    # changing the URL path.
    url = f"https://api.polygon.io/v2/aggs/ticker/{quote(ticker, safe='')}/prev"
    response = requests.get(url, params={"apiKey": api_key}, timeout=timeout)
    response.raise_for_status()
    return response.json().get("results", [])

In [None]:
# Creating vector store and retrieval QA chain
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)

# (source label, frame, row -> text) for every corpus fed into the vector store.
# The label is stored on each chunk as metadata["source"]; the chatbot prints that
# key under "Sources:", and without metadata the lookup raises KeyError.
corpus_specs = [
    ("advisors.db:advisors", advisors_df,
     lambda row: f"{row['Name']} - {row['Areas of Expertise']}"),
    ("users.db:users", users_df,
     lambda row: f"{row['Name']} - {row['Occupation']} - {row['Investment Focus']}"),
    ("user_1p_data.db:user_1p_data", user_1p_data_df,
     lambda row: f"{row['Headline/Title']} - {row['Keywords']}"),
    ("IA_FIRM_SEC_Feed_05_13_2024.xml", firms_df,
     lambda row: f"{row['FirmName']} - {row['FirmCity']} - {row['FirmState']} - {row['FirmCountry']}"),
]

texts = []
metadatas = []
for source_label, frame, render in corpus_specs:
    for row_index, row in frame.iterrows():
        # One metadata dict per chunk, so texts and metadatas stay index-aligned.
        for chunk in text_splitter.split_text(render(row)):
            texts.append(chunk)
            metadatas.append({"source": f"{source_label} (row {row_index})"})

embeddings = OpenAIEmbeddings()
vectorstore = Chroma.from_texts(texts, embeddings, metadatas=metadatas)

qa = RetrievalQA.from_chain_type(
    llm=OpenAI(),
    chain_type="stuff",
    retriever=vectorstore.as_retriever(),
    return_source_documents=True
)

In [None]:
# Interactive chatbot
# Tickers the chatbot answers price questions for. They are matched as whole words,
# case-insensitively, so words that merely contain a ticker (e.g. "offbeat" contains
# "fb") no longer trigger a lookup. META is Meta Platforms' current ticker; FB is
# kept so existing phrasing still matches.
STOCK_TICKERS = ("AAPL", "MSFT", "AMZN", "GOOGL", "FB", "META")
EXIT_COMMANDS = {"exit", "quit"}

print("Welcome to the Financial Assistant Chatbot!")
while True:
    try:
        query = input("Ask me anything about finance, investments, or stocks: ").strip()
    except EOFError:
        # stdin was closed (e.g. non-interactive execution): leave the loop cleanly.
        break
    if query.lower() in EXIT_COMMANDS:
        break
    if not query:
        # Ignore blank input instead of sending an empty query to an API/LLM.
        continue

    tickers = [
        ticker for ticker in STOCK_TICKERS
        if re.search(rf"\b{ticker}\b", query, flags=re.IGNORECASE)
    ]

    if "news" in query.lower():
        articles = fetch_news_data(query)
        for article in articles:
            print(f"Title: {article['title']}")
            print(f"Summary: {article['description']}")
            print(f"Link: {article['url']}")
            print("-" * 50)

    elif tickers:
        for ticker in tickers:
            stock_data = fetch_stock_data(ticker)
            print(f"Stock: {ticker}")
            if stock_data:
                bar = stock_data[0]
                print(f"Open: {bar['o']}")
                print(f"High: {bar['h']}")
                print(f"Low: {bar['l']}")
                print(f"Close: {bar['c']}")
                print(f"Volume: {bar['v']}")
            else:
                # The API returned no bars for this ticker; avoid IndexError on [0].
                print("No price data available.")
            print("-" * 50)

    else:
        result = qa({"query": query})
        print(f"Answer: {result['result']}")
        print("-" * 50)
        print("Sources:")
        for doc in result["source_documents"]:
            # Documents indexed without metadata have no "source" key.
            print(f"- {doc.metadata.get('source', 'unknown source')}")
        print("-" * 50)

In [6]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain.llms.openai import OpenAI
from langchain.agents import AgentExecutor

In [7]:
# Open the Chinook sample database (SQLite file expected in the working directory).
db = SQLDatabase.from_uri("sqlite:///./Chinook.db")

# Expose the database to the agent as tools; the recorded runs below show them
# as list_tables_sql_db, schema_sql_db and query_sql_db.
toolkit = SQLDatabaseToolkit(db=db)

# temperature=0 for (near-)deterministic SQL generation; verbose=True prints each
# Action / Observation step, as seen in the recorded outputs.
agent_executor = create_sql_agent(
    toolkit=toolkit,
    llm=OpenAI(temperature=0),
    verbose=True,
)

# Example: running queries

In [None]:
# Aggregation over a join of Customer and Invoice; returns only the final answer string.
agent_executor.run(input="List the total sales per country. Which country's customers spent the most?")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: ""[0m
Observation: [38;5;200m[1;3mArtist, Customer, Invoice, Genre, Album, InvoiceLine, MediaType, Track, Employee, Playlist, PlaylistTrack[0m
Thought:[32;1m[1;3m I should check the schema of the relevant tables.
Action: schema_sql_db
Action Input: "Invoice, Customer"[0m
Observation: [33;1m[1;3m
CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	Postal

  sample_rows = connection.execute(command)


[32;1m[1;3m I should query the database to get the total sales per country.
Action: query_sql_db
Action Input: SELECT Customer.Country, SUM(Invoice.Total) AS TotalSales FROM Customer INNER JOIN Invoice ON Customer.CustomerId = Invoice.CustomerId GROUP BY Customer.Country ORDER BY TotalSales DESC LIMIT 10[0m
Observation: [36;1m[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)][0m
Thought:[32;1m[1;3m I now know the final answer
Final Answer: The customers from the USA spent the most, with a total of $523.06.[0m

[1m> Finished chain.[0m


'The customers from the USA spent the most, with a total of $523.06.'

In [None]:
# The agent inspects the Playlist and PlaylistTrack schemas before counting tracks per playlist.
agent_executor.run(input="Show the total number of tracks in each playlist. The Playlist name should be included in the result.")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: ""[0m
Observation: [38;5;200m[1;3mArtist, Customer, Invoice, Genre, Album, InvoiceLine, MediaType, Track, Employee, Playlist, PlaylistTrack[0m
Thought:[32;1m[1;3m I should look at the schema of the Playlist and PlaylistTrack tables to see what columns I can use.
Action: schema_sql_db
Action Input: "Playlist, PlaylistTrack"[0m
Observation: [33;1m[1;3m
CREATE TABLE "Playlist" (
	"PlaylistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("PlaylistId")
)

/*
3 rows from Playlist table:
PlaylistId	Name
1	Music
2	Movies
3	TV Shows
*/


CREATE TABLE "PlaylistTrack" (
	"PlaylistId" INTEGER NOT NULL, 
	"TrackId" INTEGER NOT NULL, 
	PRIMARY KEY ("PlaylistId", "TrackId"), 
	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
	FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)

/*
3 rows from PlaylistTrack table:
PlaylistId	TrackId
1	3402
1	3389
1	3390
*/[0m


"The total number of tracks in each playlist is: '90’s Music' (1477), 'Brazilian Music' (39), 'Classical' (75), 'Classical 101 - Deep Cuts' (25), 'Classical 101 - Next Steps' (25), 'Classical 101 - The Basics' (25), 'Grunge' (15), 'Heavy Metal Classic' (26), 'Music' (6580), 'Music Videos' (1), 'On-The-Go 1' (1), 'TV Shows' (426)."

# Recovering from an error
In this example, the agent is able to recover from an error after initially trying to query a column (`Track.ArtistId`) that doesn't exist. The recorded output below is truncated, so the failed query itself is not shown.

In [None]:
# Answering this needs Artist -> Album -> Track -> InvoiceLine; see the markdown above.
agent_executor.run(input="Who are the top 3 best selling artists?")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: ""[0m
Observation: [38;5;200m[1;3mArtist, Customer, Invoice, Genre, Album, InvoiceLine, MediaType, Track, Employee, Playlist, PlaylistTrack[0m
Thought:[32;1m[1;3m I should check the schema of the Artist and Invoice tables to see what columns I can use.
Action: schema_sql_db
Action Input: "Artist, Invoice"[0m
Observation: [33;1m[1;3m
CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Invoice" (
	"InvoiceId" INTEGER NOT NULL, 
	"CustomerId" INTEGER NOT NULL, 
	"InvoiceDate" DATETIME NOT NULL, 
	"BillingAddress" NVARCHAR(70), 
	"BillingCity" NVARCHAR(40), 
	"BillingState" NVARCHAR(40), 
	"BillingCountry" NVARCHAR(40), 
	"BillingPostalCode" NVARCHAR(10), 
	"Total" NUMERIC(10, 2) NOT NULL, 
	PRIMARY KEY ("InvoiceId"), 
	FOREIGN KEY("Cus

'The top 3 best selling artists are Iron Maiden, U2, and Metallica.'

In [None]:
# Calling the executor directly returns a dict with both 'input' and 'output'.
# NOTE: in the recorded output the agent only queried the Track table and answered
# with the track name, not the artist the question asks for.
agent_executor({"input": "Which Artist recorded the longest track?"})



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: ""[0m
Observation: [38;5;200m[1;3mArtist, Customer, Invoice, Genre, Album, InvoiceLine, MediaType, Track, Employee, Playlist, PlaylistTrack[0m
Thought:[32;1m[1;3m I should check the schema of the Track table to see what columns I can query.
Action: schema_sql_db
Action Input: "Track"[0m
Observation: [33;1m[1;3m
CREATE TABLE "Track" (
	"TrackId" INTEGER NOT NULL, 
	"Name" NVARCHAR(200) NOT NULL, 
	"AlbumId" INTEGER, 
	"MediaTypeId" INTEGER NOT NULL, 
	"GenreId" INTEGER, 
	"Composer" NVARCHAR(220), 
	"Milliseconds" INTEGER NOT NULL, 
	"Bytes" INTEGER, 
	"UnitPrice" NUMERIC(10, 2) NOT NULL, 
	PRIMARY KEY ("TrackId"), 
	FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"), 
	FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"), 
	FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)

/*
3 rows from Track table:
TrackId	Name	AlbumId	MediaTypeId	GenreId	Composer	Mill

  sample_rows = connection.execute(command)


[32;1m[1;3m I should query the Track table for the longest track.
Action: query_sql_db
Action Input: SELECT Name, Milliseconds FROM Track ORDER BY Milliseconds DESC LIMIT 10[0m
Observation: [36;1m[1;3m[('Occupation / Precipice', 5286953), ('Through a Looking Glass', 5088838), ('Greetings from Earth, Pt. 1', 2960293), ('The Man With Nine Lives', 2956998), ('Battlestar Galactica, Pt. 2', 2956081), ('Battlestar Galactica, Pt. 1', 2952702), ('Murder On the Rising Star', 2935894), ('Battlestar Galactica, Pt. 3', 2927802), ('Take the Celestra', 2927677), ('Fire In Space', 2926593)][0m
Thought:[32;1m[1;3m I now know the final answer.
Final Answer: The longest track is Occupation / Precipice with a length of 5286953 milliseconds.[0m

[1m> Finished chain.[0m


{'input': 'Which Artist recorded the longest track?',
 'output': 'The longest track is Occupation / Precipice with a length of 5286953 milliseconds.'}

In [14]:
# NOTE: the agent returns plotting code as text in 'output'; this cell itself draws
# nothing. The recorded output was produced with a slightly different prompt
# ("plot the trend length of tracks ...").
agent_executor({"input": "plot the trend in length of tracks and give me the python code of it"})



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: ""[0m
Observation: [38;5;200m[1;3mInvoice, MediaType, PlaylistTrack, Album, Genre, InvoiceLine, Artist, Playlist, Customer, Track, Employee[0m
Thought:[32;1m[1;3m I should look at the schema of the Track table to see what columns I can use.
Action: schema_sql_db
Action Input: "Track"[0m
Observation: [33;1m[1;3m
CREATE TABLE "Track" (
	"TrackId" INTEGER NOT NULL, 
	"Name" NVARCHAR(200) NOT NULL, 
	"AlbumId" INTEGER, 
	"MediaTypeId" INTEGER NOT NULL, 
	"GenreId" INTEGER, 
	"Composer" NVARCHAR(220), 
	"Milliseconds" INTEGER NOT NULL, 
	"Bytes" INTEGER, 
	"UnitPrice" NUMERIC(10, 2) NOT NULL, 
	PRIMARY KEY ("TrackId"), 
	FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"), 
	FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"), 
	FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)

/*
3 rows from Track table:
TrackId	Name	AlbumId	MediaTypeId	GenreId	Composer	Mill

{'input': 'plot the trend length of tracks and give me the python code of it',
 'output': "The trend length of tracks can be plotted using the following python code:\n\nimport matplotlib.pyplot as plt\n\nmilliseconds = [1071, 4884, 6373, 6635, 7941, 11650, 21211, 29048, 32287, 33149]\n\nplt.plot(milliseconds)\nplt.title('Trend Length of Tracks')\nplt.xlabel('Track Number')\nplt.ylabel('Milliseconds')\nplt.show()"}