In [1]:
from pyprojroot import here

# Location of the SQLite file: "sqldb.db" inside the directory returned by here("data").
db_path = f"{here('data')}/sqldb.db"

In [2]:
from langchain_community.utilities import SQLDatabase

# Wrap the SQLite file at db_path in a LangChain SQLDatabase handle.
db = SQLDatabase.from_uri("sqlite:///" + db_path)

# Report the SQL dialect and the usable table names, one per line.
print(db.dialect, db.get_usable_table_names(), sep="\n")

# Preview the first ten Employee rows; as the last expression, the result is the cell output.
db.run("SELECT * FROM Employee LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'Adams', 'Andrew', 'General Manager', None, '1962-02-18 00:00:00', '2002-08-14 00:00:00', '11120 Jasper Ave NW', 'Edmonton', 'AB', 'Canada', 'T5K 2N1', '+1 (780) 428-9482', '+1 (780) 428-3457', 'andrew@chinookcorp.com'), (2, 'Edwards', 'Nancy', 'Sales Manager', 1, '1958-12-08 00:00:00', '2002-05-01 00:00:00', '825 8 Ave SW', 'Calgary', 'AB', 'Canada', 'T2P 2T3', '+1 (403) 262-3443', '+1 (403) 262-3322', 'nancy@chinookcorp.com'), (3, 'Peacock', 'Jane', 'Sales Support Agent', 2, '1973-08-29 00:00:00', '2002-04-01 00:00:00', '1111 6 Ave SW', 'Calgary', 'AB', 'Canada', 'T2P 5M5', '+1 (403) 262-3443', '+1 (403) 262-6712', 'jane@chinookcorp.com'), (4, 'Park', 'Margaret', 'Sales Support Agent', 2, '1947-09-19 00:00:00', '2003-05-03 00:00:00', '683 10 Street SW', 'Calgary', 'AB', 'Canada', 'T2P 5G3', '+1 (403) 263-4423', '+1 (403) 263-4289', 'margaret@chinookcorp.com'), (5, 'Johnson', 'Steve', 'Sales Support Agent', 2, '1965-03-03 00:00:00', '2003-10-17 00:00:00', '7727B 41 Ave', 'Calgar

In [16]:
from sqlalchemy import create_engine, inspect

# Create an engine for the SQLite database file located at db_path
engine = create_engine(f"sqlite:///{db_path}")

# Inspect through the engine directly. The original cell also opened an explicit
# connection that nothing used and that stayed open if any inspection call raised;
# no manual connect()/close() pair is needed here.
inspector = inspect(engine)

# Retrieve the names of all the tables in the database
table_names = inspector.get_table_names()
print("Tables:", table_names)

# get_schema_names() takes no table argument, so its result is the same for every
# table: look it up once instead of once per loop iteration.
schema_names = inspector.get_schema_names()

# Loop over each table to print its columns, primary key and foreign keys
for table_name in table_names:
    print(f"Information for table: {table_name}")

    # Schema names available in the database (not specific to this table)
    print(f"Schema: {schema_names}")

    # Columns and their declared types
    columns = inspector.get_columns(table_name)
    for column in columns:
        print(f"Column: {column['name']} Type: {column['type']}")

    # Primary-key constraint of the table
    pk_constraint = inspector.get_pk_constraint(table_name)
    print(f"Primary Key Constraint: {pk_constraint}")

    # Foreign keys declared on the table
    foreign_keys = inspector.get_foreign_keys(table_name)
    print(f"Foreign Keys: {foreign_keys}")

Tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
Information for table: Album
Schema: ['main']
Column: AlbumId Type: INTEGER
Column: Title Type: NVARCHAR(160)
Column: ArtistId Type: INTEGER
Primary Key Constraint: {'constrained_columns': ['AlbumId'], 'name': None}
Foreign Keys: [{'name': None, 'constrained_columns': ['ArtistId'], 'referred_schema': None, 'referred_table': 'Artist', 'referred_columns': ['ArtistId'], 'options': {}}]
Information for table: Artist
Schema: ['main']
Column: ArtistId Type: INTEGER
Column: Name Type: NVARCHAR(120)
Primary Key Constraint: {'constrained_columns': ['ArtistId'], 'name': None}
Foreign Keys: []
Information for table: Customer
Schema: ['main']
Column: CustomerId Type: INTEGER
Column: FirstName Type: NVARCHAR(40)
Column: LastName Type: NVARCHAR(20)
Column: Company Type: NVARCHAR(80)
Column: Address Type: NVARCHAR(70)
Column: City Type: NVARCHAR(40)
Column: State T

## Query and QA SQL DB

In [7]:
import warnings

from langchain_community.utilities import SQLDatabase
from pyprojroot import here

# Install a blanket "ignore" filter so warnings do not clutter the cell outputs below.
warnings.filterwarnings("ignore")

In [8]:
# Rebuild the path to data/sqldb.db and reopen it as a LangChain SQLDatabase for this section.
db_path = f"{here('data')}/sqldb.db"
db = SQLDatabase.from_uri("sqlite:///" + db_path)

In [6]:
import os

from dotenv import load_dotenv

# Print the value returned by load_dotenv() as a quick "was the env file picked up" signal.
print("Environment variables are loaded:", load_dotenv())

# Check the variable the ChatOpenAI client actually reads (OPENAI_API_KEY) rather than
# an unrelated one, and report only whether it is present so the secret never reaches
# the saved cell output.
print("OPENAI_API_KEY is set:", os.getenv("OPENAI_API_KEY") is not None)

Environment variables are loaded: True
test by reading a variable: None


In [4]:
import os
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage

# Chat model handle: the API key and model name are read from environment variables.
client = ChatOpenAI(
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    # Falls back to gpt-3.5-turbo when GPT_MODEL_NAME is not set
    model_name=os.getenv("GPT_MODEL_NAME", "gpt-3.5-turbo"),
    # Sampling temperature; adjust if needed
    temperature=0.7,
)

# Minimal system + user exchange used as a smoke test of the client.
messages = [SystemMessage(content="You're a helpful assistant"), HumanMessage(content="hello")]

# Send the chat completion request and print the returned message object.
response = client.invoke(messages)
print(response)


content='Hello! How can I assist you today?' additional_kwargs={'refusal': None} response_metadata={'token_usage': {'completion_tokens': 9, 'prompt_tokens': 17, 'total_tokens': 26}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None} id='run-a67429fd-06be-4cfc-b4c8-a5c735206ca9-0' usage_metadata={'input_tokens': 17, 'output_tokens': 9, 'total_tokens': 26}


In [9]:
from langchain.chains import create_sql_query_chain

# Chain built from the chat model and the database handle; invoked with a question,
# it returns the SQL text printed below.
chain = create_sql_query_chain(client, db)

employee_count_request = {"question": "How many employees are there"}
response = chain.invoke(employee_count_request)
print(response)

SELECT COUNT(EmployeeId) as NumEmployees
FROM Employee


In [11]:
# Execute the text currently held in `response` against the database and show the result.
# NOTE(review): this relies on `response` still being the SQL string produced by the
# query-chain cell, not the chat message assigned to the same name earlier — confirm cell order.
db.run(response)

'[(8,)]'

In [12]:
# Pretty-print the first prompt template exposed by the chain.
first_prompt = chain.get_prompts()[0]
first_prompt.pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

In [10]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Second stage: tool that runs a SQL string against `db`.
execute_query = QuerySQLDataBaseTool(db=db)
# First stage: chain that writes a SQL query for a natural-language question.
write_query = create_sql_query_chain(client, db)

# Pipe the generated query straight into the execution tool.
chain = write_query | execute_query

# The last expression is displayed: the raw result of running the generated query.
chain.invoke({"question": "How many employees are there"})

'[(8,)]'

In [18]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Prompt that asks the model to phrase an answer from the question, the SQL and its result.
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

# Final stage: fill the prompt, call the model, reduce the reply to a plain string.
answer = answer_prompt | client | StrOutputParser()

# Stage 1: add "query" (the generated SQL) to the input mapping.
with_query = RunnablePassthrough.assign(query=write_query)
# Stage 2: add "result" by running the SQL stored under "query".
with_result = with_query.assign(result=itemgetter("query") | execute_query)
# Stage 3: hand question, query and result to the answer stage.
chain = with_result | answer

chain.invoke({"question": "How many employees are there"})

'There are a total of 8 employees.'

## Agents

In [22]:
from langchain_community.agent_toolkits import create_sql_agent

# SQL agent over `db`, driven by the chat model; verbose=True is passed so the
# intermediate tool calls show up in the cell output.
agent_executor = create_sql_agent(
    client,
    db=db,
    agent_type="openai-tools",
    verbose=True,
)

In [23]:
# Two-part question: per-country totals plus the top-spending country.
agent_executor.invoke(
    {"input": "List the total sales per country. Which country's customers spent the most?"}
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'Invoice,Customer'}`


[0m[33;1m[1;3m
CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
1	Luís	Gonçalves	Embraer - Empr

{'input': "List the total sales per country. Which country's customers spent the most?",
 'output': 'The total sales per country are as follows:\n\n1. USA: $523.06\n2. Canada: $303.96\n3. France: $195.10\n4. Brazil: $190.10\n5. Germany: $156.48\n6. United Kingdom: $112.86\n7. Czech Republic: $90.24\n8. Portugal: $77.24\n9. India: $75.26\n10. Chile: $46.62\n\nCustomers from the USA spent the most with a total of $523.06.'}

In [22]:
# Ask the agent to describe a single table; the question is supplied under the "input" key.
agent_executor.invoke({"input": "Describe the playlisttrack table"})
# Disabled alternative left by the author: the same question passed as a bare string.
# agent_executor.invoke("Describe the playlisttrack table")



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'PlaylistTrack'}`


[0m[33;1m[1;3m
CREATE TABLE "PlaylistTrack" (
	"PlaylistId" INTEGER NOT NULL, 
	"TrackId" INTEGER NOT NULL, 
	PRIMARY KEY ("PlaylistId", "TrackId"), 
	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
	FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)

/*
3 rows from PlaylistTrack table:
PlaylistId	TrackId
1	3402
1	3389
1	3390
*/[0m[32;1m[1;3mThe `PlaylistTrack` table has the following columns:
- PlaylistId (INTEGER, NOT NULL)
- TrackId (INTEGER, NOT NULL)

It has a composite primary key on the columns PlaylistId and TrackId. Additionally, there are foreign key constraints on the TrackId column referencing the Track table and on the PlaylistId

{'input': 'Describe the playlisttrack table',
 'output': 'The `PlaylistTrack` table has the following columns:\n- PlaylistId (INTEGER, NOT NULL)\n- TrackId (INTEGER, NOT NULL)\n\nIt has a composite primary key on the columns PlaylistId and TrackId. Additionally, there are foreign key constraints on the TrackId column referencing the Track table and on the PlaylistId column referencing the Playlist table.\n\nHere are 3 sample rows from the `PlaylistTrack` table:\n- PlaylistId: 1, TrackId: 3402\n- PlaylistId: 1, TrackId: 3389\n- PlaylistId: 1, TrackId: 3390'}

## Spotify Dataset

In [13]:
import pandas as pd
from pyprojroot import here

In [14]:
# Load the Spotify spreadsheet from data/csv_xlsx into a DataFrame.
df = pd.read_excel(here("data/csv_xlsx/spotify_dataset.xlsx"))

# Print the shape and the column labels (one per line), then render the first three rows.
print(df.shape, df.columns.tolist(), sep="\n")
display(df.head(3))

(953, 25)
['track_name', 'artist(s)_name', 'artist_count', 'released_year', 'released_month', 'released_day', 'in_spotify_playlists', 'in_spotify_charts', 'streams', 'in_apple_playlists', 'in_apple_charts', 'in_deezer_playlists', 'in_deezer_charts', 'in_shazam_charts', 'bpm', 'key', 'mode', 'danceability_%', 'valence_%', 'energy_%', 'acousticness_%', 'instrumentalness_%', 'liveness_%', 'speechiness_%', 'cover_url']


Unnamed: 0,track_name,artist(s)_name,artist_count,released_year,released_month,released_day,in_spotify_playlists,in_spotify_charts,streams,in_apple_playlists,...,key,mode,danceability_%,valence_%,energy_%,acousticness_%,instrumentalness_%,liveness_%,speechiness_%,cover_url
0,Seven (feat. Latto) (Explicit Ver.),"Latto, Jung Kook",2,2023,7,14,553,147,141381703,43,...,B,Major,80,89,83,31,0,8,4,Not Found
1,LALA,Myke Towers,1,2023,3,23,1474,48,133716286,48,...,C#,Major,71,61,74,7,0,10,4,https://i.scdn.co/image/ab67616d0000b2730656d5...
2,vampire,Olivia Rodrigo,1,2023,6,30,1397,113,140003974,94,...,F,Major,51,32,53,17,0,31,6,https://i.scdn.co/image/ab67616d0000b273e85259...


In [25]:
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine

# File that holds the tables built from the CSV/XLSX data.
spotify_db_file = str(here("data")) + "/csv_xlsx_sqldb.db"

# SQLAlchemy URI for that file; `db_path` ends up holding the URI, as before.
db_path = f"sqlite:///{spotify_db_file}"

# Build the engine here. Without this line the following cell picks up whatever
# `engine` an earlier cell left in the kernel, which on a fresh run is the engine
# bound to sqldb.db rather than to this database.
engine = create_engine(db_path)

In [26]:
# Wrap the current `engine` in a LangChain SQLDatabase handle.
db = SQLDatabase(engine=engine)
print(db.dialect)
print(db.get_usable_table_names())
# Preview a bounded number of rows instead of dumping the entire table into the
# cell output (same LIMIT as the Employee preview earlier in the notebook).
db.run("SELECT * FROM spotify_dataset LIMIT 10;")

sqlite
['spotify', 'spotify_dataset']


'[(\'Seven (feat. Latto) (Explicit Ver.)\', \'Latto, Jung Kook\', 2, 2023, 7, 14, 553, 147, \'141381703\', 43, 263, 45, 10, 826.0, 125, \'B\', \'Major\', 80, 89, 83, 31, 0, 8, 4, \'Not Found\'), (\'LALA\', \'Myke Towers\', 1, 2023, 3, 23, 1474, 48, \'133716286\', 48, 126, 58, 14, 382.0, 92, \'C#\', \'Major\', 71, 61, 74, 7, 0, 10, 4, \'https://i.scdn.co/image/ab67616d0000b2730656d5ce813ca3cc4b677e05\'), (\'vampire\', \'Olivia Rodrigo\', 1, 2023, 6, 30, 1397, 113, \'140003974\', 94, 207, 91, 14, 949.0, 138, \'F\', \'Major\', 51, 32, 53, 17, 0, 31, 6, \'https://i.scdn.co/image/ab67616d0000b273e85259a1cae29a8d91f2093d\'), (\'Cruel Summer\', \'Taylor Swift\', 1, 2019, 8, 23, 7858, 100, \'800840817\', 116, 207, 125, 12, 548.0, 170, \'A\', \'Major\', 55, 58, 72, 11, 0, 11, 15, \'https://i.scdn.co/image/ab67616d0000b273e787cffec20aa2a396a61647\'), (\'WHERE SHE GOES\', \'Bad Bunny\', 1, 2023, 5, 18, 3133, 50, \'303236322\', 84, 133, 87, 15, 425.0, 144, \'A\', \'Minor\', 65, 23, 80, 14, 63, 11,

In [27]:
from langchain_community.agent_toolkits import create_sql_agent

# Rebuild the agent so it works against the database currently bound to `db`.
agent_executor = create_sql_agent(
    client,
    db=db,
    agent_type="openai-tools",
    verbose=True,
)

In [28]:
# Ask the agent for one artist's top five tracks.
# NOTE(review): the schema printed by the agent lists `streams` as TEXT, so if the
# generated SQL orders by `streams` without a numeric cast the ranking is lexicographic
# rather than numeric — verify the SQL the agent ran before trusting the order.
# NOTE(review): the artist name inside the question opens a single quote that is never closed.
agent_executor.invoke({"input": "Hi, can you provide the top 5 tracks of an artist 'Taylor swift?"})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mspotify, spotify_dataset[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'spotify'}`


[0m[33;1m[1;3m
CREATE TABLE spotify (
	track_name TEXT, 
	"artist(s)_name" TEXT, 
	artist_count BIGINT, 
	released_year BIGINT, 
	released_month BIGINT, 
	released_day BIGINT, 
	in_spotify_playlists BIGINT, 
	in_spotify_charts BIGINT, 
	streams TEXT, 
	in_apple_playlists BIGINT, 
	in_apple_charts BIGINT, 
	in_deezer_playlists BIGINT, 
	in_deezer_charts BIGINT, 
	in_shazam_charts FLOAT, 
	bpm BIGINT, 
	"key" TEXT, 
	mode TEXT, 
	"danceability_%" BIGINT, 
	"valence_%" BIGINT, 
	"energy_%" BIGINT, 
	"acousticness_%" BIGINT, 
	"instrumentalness_%" BIGINT, 
	"liveness_%" BIGINT, 
	"speechiness_%" BIGINT, 
	cover_url TEXT
)

/*
3 rows from spotify table:
track_name	artist(s)_name	artist_count	released_year	released_month	released_day	in_spotify_playlists	in_

{'input': "Hi, can you provide the top 5 tracks of an artist 'Taylor swift?",
 'output': "The top 5 tracks of the artist 'Taylor Swift' are as follows:\n1. Track: Anti-Hero, Streams: 999,748,277\n2. Track: Lover, Streams: 882,831,184\n3. Track: cardigan, Streams: 812,019,557\n4. Track: Cruel Summer, Streams: 800,840,817\n5. Track: Style, Streams: 786,181,836"}