# Code 1 - SQL Agent 
- OpenAI API
- Langchain
- sqlite database

### References
- DB Browser for SQLite (download the portable package): https://sqlitebrowser.org/dl/

## Step 1 - Connect to Database

### Import Packages

In [1]:
from sqlite3 import connect
from langchain_community.utilities.sql_database import SQLDatabase

### Create Connection URL

In [2]:
db = SQLDatabase.from_uri("sqlite:///chinook.db", sample_rows_in_table_info = 3)

### Check Database Tables
- If the database is empty, use the SQLite browser to populate it.
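An empty database is easy to hit by accident: `sqlite3.connect` creates a new, empty file when the path does not exist, so a typo in the filename fails silently. The sketch below is one way to check for tables using only the standard library. The helper name `list_tables` is hypothetical and not part of the notebook.

```python
import sqlite3

def list_tables(db_path: str) -> list[str]:
    """Return the names of user tables in a SQLite database, sorted by name."""
    con = sqlite3.connect(db_path)
    try:
        rows = con.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' "
            "ORDER BY name"
        ).fetchall()
    finally:
        con.close()
    return [name for (name,) in rows]

# Demonstration on a fresh in-memory database, which has no tables.
# For the notebook, pass "chinook.db" instead (assumed to be in the working directory).
print(list_tables(":memory:"))
```

If the list comes back empty for `chinook.db`, the file is most likely blank or was created at the wrong path.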

In [3]:
print(len(db.get_usable_table_names()), db.get_usable_table_names())

11 ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


In [4]:
print(db.table_info)


CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Empl

### Connect to sqlite database

In [5]:
con = connect("chinook.db")

### Check Sample Query

In [6]:
cur = con.execute("Select Count(Distinct(AlbumId)) from Album;")

In [7]:
print(cur.fetchall())

[(347,)]


## Step 2 - Setup Langchain SQL Chain

### Load Environment Variables 
- Mostly API Keys

In [8]:
import os
from dotenv import load_dotenv
load_dotenv()

True

In [9]:
API_KEY = os.getenv("OPENAI_API_KEY")
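`os.getenv` returns `None` when a variable is not set, and the failure then only shows up later, when the LLM client is called. A minimal sketch of an early check follows. The helper `missing_keys` is a hypothetical name, and the list of required variables is an assumption based on the two keys this notebook reads.

```python
import os

def missing_keys(names, env=os.environ):
    """Return the variable names that are unset or empty in env."""
    return [name for name in names if not env.get(name)]

# Assumed list: the two keys referenced in this notebook.
required = ["OPENAI_API_KEY", "GROQ_API_KEY"]
missing = missing_keys(required)
if missing:
    print("Missing environment variables:", ", ".join(missing))
```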

### Load Packages

In [10]:
from langchain_community.callbacks import get_openai_callback  # newer import path; langchain.callbacks is deprecated
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq

### Create an Instance of LLM with configuration

In [16]:
# llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key = API_KEY)
llm = ChatGroq(
    model="llama-3.3-70b-versatile",
    temperature=0,
    api_key=os.getenv("GROQ_API_KEY")
)


### Check LLM 

In [17]:
llm

ChatGroq(client=<groq.resources.chat.completions.Completions object at 0x000001C3607D34D0>, async_client=<groq.resources.chat.completions.AsyncCompletions object at 0x000001C363E96F10>, model_name='llama-3.3-70b-versatile', temperature=1e-08, model_kwargs={}, groq_api_key=SecretStr('**********'))

## Step 3 - Create Chain

### Import Packages

In [18]:
from langchain_experimental.sql.base import SQLDatabaseChain

### Setup SQL DB Chain

In [19]:
sql_db_chain = SQLDatabaseChain.from_llm(llm = llm, db = db, verbose = True)

### Run and Print Chain Output

In [20]:
result = sql_db_chain.invoke("Which country's customers spent the most?")
print(result)



> Entering new SQLDatabaseChain chain...
Which country's customers spent the most?
SQLQuery:To answer this question, we need to join the "Customer" table with the "Invoice" table, group the results by country, and sum the total spent by each country. Here's the SQL query:

SQLQuery: 
```sql
SELECT 
  "Country", 
  SUM("Total") AS "TotalSpent"
FROM 
  "Customer"
  JOIN "Invoice" ON "Customer"."CustomerId" = "Invoice"."CustomerId"
GROUP BY 
  "Country"
ORDER BY 
  "TotalSpent" DESC
LIMIT 1;
```

Let's assume we run this query and get the result:

OperationalError: (sqlite3.OperationalError) near "```sql
SELECT 
  "Country", 
  SUM("Total") AS "TotalSpent"
FROM 
  "Customer"
  JOIN "Invoice" ON "Customer"."CustomerId" = "Invoice"."CustomerId"
GROUP BY 
  "Country"
ORDER BY 
  "TotalSpent" DESC
LIMIT 1;
```": syntax error
[SQL: ```sql
SELECT 
  "Country", 
  SUM("Total") AS "TotalSpent"
FROM 
  "Customer"
  JOIN "Invoice" ON "Customer"."CustomerId" = "Invoice"."CustomerId"
GROUP BY 
  "Country"
ORDER BY 
  "TotalSpent" DESC
LIMIT 1;
```

Let's assume we run this query and get the result:]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
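The error above shows that the whole LLM reply, Markdown fence and trailing commentary included, was passed to SQLite as the statement. One possible workaround is to extract only the SQL before executing it. The sketch below is a standalone illustration using the standard library. `extract_sql` is a hypothetical helper, and how it would be wired into the chain is left open here.

```python
import re

# Build the fence marker programmatically so this example contains no literal fence.
TICKS = "`" * 3
FENCE = re.compile(TICKS + r"(?:sql)?\s*(.*?)" + TICKS, re.DOTALL | re.IGNORECASE)

def extract_sql(llm_text: str) -> str:
    """Return the SQL from an LLM reply that may wrap it in a Markdown fence."""
    match = FENCE.search(llm_text)
    if match:
        # Keep only what sits between the opening and closing fence.
        return match.group(1).strip()
    # No fence found: drop an optional leading "SQLQuery:" label.
    return re.sub(r"^\s*SQLQuery:\s*", "", llm_text).strip()

reply = (
    "Here's the SQL query:\n\nSQLQuery:\n"
    + TICKS + "sql\nSELECT 1;\n" + TICKS
    + "\n\nLet's assume we run this query and get the result:"
)
print(extract_sql(reply))
```

The regular expression only handles a fence tagged `sql` or left untagged. Other tags would need a broader pattern.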

### Run and Print Output with Callback
- Get # Tokens used
- Get # Requests
- Get # Completion Tokens

In [21]:
with get_openai_callback() as cb:
    result = sql_db_chain.invoke("Which country's customers spent the most?")
print("LLM Result: ", result['result'])
print(cb)



> Entering new SQLDatabaseChain chain...
Which country's customers spent the most?
SQLQuery:Question: Which country's customers spent the most?
SQLQuery: SELECT "Country", SUM("Total") AS "TotalSpent" FROM "Customer" JOIN "Invoice" ON "Customer"."CustomerId" = "Invoice"."CustomerId" GROUP BY "Country" ORDER BY "TotalSpent" DESC LIMIT 1
SQLResult: [('USA', 523.0600000000003)]
Answer:Question: Which country's customers spent the most?
SQLQuery: SELECT "Country", SUM("Total") AS "TotalSpent" FROM "Customer" JOIN "Invoice" ON "Customer"."CustomerId" = "Invoice"."CustomerId" GROUP BY "Country" ORDER BY "TotalSpent" DESC LIMIT 1
> Finished chain.
LLM Result:  Question: Which country's customers spent the most?
SQLQuery: SELECT "Country", SUM("Total") AS "TotalSpent" FROM "Customer" JOIN "Invoice" ON "Customer"."CustomerId" = "Invoice"."CustomerId" GROUP BY "Country" ORDER BY "TotalSpent" DESC LIMIT 1
Tokens Used: 5106
	Prom

In [22]:
with get_openai_callback() as cb:
    result = sql_db_chain.invoke("List the total sales per country. Which country's customers spent the most?")
print("LLM Result: ", result['result'])
print(cb)



> Entering new SQLDatabaseChain chain...
List the total sales per country. Which country's customers spent the most?
SQLQuery:To answer this question, we need to join the "Invoice" table with the "Customer" table to get the country of each customer and the total sales. We will use the SUM function to calculate the total sales per country.

SQLQuery: 
```sql
SELECT 
    "c"."Country", 
    SUM("i"."Total") AS "TotalSales"
FROM 
    "Invoice" "i"
JOIN 
    "Customer" "c" ON "i"."CustomerId" = "c"."CustomerId"
GROUP BY 
    "c"."Country"
ORDER BY 
    "TotalSales" DESC
LIMIT 5;
```

Let's assume the result of the query is:

OperationalError: (sqlite3.OperationalError) near "```sql
SELECT 
    "c"."Country", 
    SUM("i"."Total") AS "TotalSales"
FROM 
    "Invoice" "i"
JOIN 
    "Customer" "c" ON "i"."CustomerId" = "c"."CustomerId"
GROUP BY 
    "c"."Country"
ORDER BY 
    "TotalSales" DESC
LIMIT 5;
```": syntax error
[SQL: ```sql
SELECT 
    "c"."Country", 
    SUM("i"."Total") AS "TotalSales"
FROM 
    "Invoice" "i"
JOIN 
    "Customer" "c" ON "i"."CustomerId" = "c"."CustomerId"
GROUP BY 
    "c"."Country"
ORDER BY 
    "TotalSales" DESC
LIMIT 5;
```

Let's assume the result of the query is:]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
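The generated SQL is not the problem in this failure. Only the surrounding fence and commentary are. To illustrate, the sketch below runs the same statement, without the fence, through `sqlite3` against a tiny in-memory stand-in. The two-table schema and the rows are made up for the demonstration and are not the Chinook data.

```python
import sqlite3

con = sqlite3.connect(":memory:")
# Hypothetical miniature schema and rows, only the columns the query touches.
con.executescript("""
CREATE TABLE "Customer" ("CustomerId" INTEGER PRIMARY KEY, "Country" NVARCHAR(40));
CREATE TABLE "Invoice" ("InvoiceId" INTEGER PRIMARY KEY, "CustomerId" INTEGER, "Total" NUMERIC(10,2));
INSERT INTO "Customer" VALUES (1, 'USA'), (2, 'Canada'), (3, 'USA');
INSERT INTO "Invoice" VALUES (1, 1, 5), (2, 2, 12), (3, 3, 9);
""")

# The statement from the failed run, with the Markdown fence removed.
query = """
SELECT "c"."Country", SUM("i"."Total") AS "TotalSales"
FROM "Invoice" "i"
JOIN "Customer" "c" ON "i"."CustomerId" = "c"."CustomerId"
GROUP BY "c"."Country"
ORDER BY "TotalSales" DESC
LIMIT 5;
"""
rows = con.execute(query).fetchall()
print(rows)
con.close()
```

Against the real database, the same statement can be passed to the `con` created in Step 1.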

### Check Template

In [None]:
print(sql_db_chain.llm_chain.prompt.template)

# END - Next Topic - SQL Agents