Query the fake bank database stored in [data/db/fake_bank.db](data/db/fake_bank.db) using natural-language questions.

In [1]:
import os
from dotenv import load_dotenv

# ibm genai
from genai.credentials import Credentials
from genai.model import Model
from genai.schemas import GenerateParams, ModelType
from genai.extensions.langchain import LangChainInterface  # langchain support

# langchain
from langchain.chains import SQLDatabaseChain, SQLDatabaseSequentialChain
from langchain.sql_database import SQLDatabase

# Load variables from a local .env file (if present) before reading the environment.
load_dotenv()

# Both lookups yield None when the variable is not set.
api_key = os.environ.get("GENAI_KEY")
api_url = os.environ.get("GENAI_API")

creds = Credentials(
    api_key,
    api_endpoint=api_url,
)

In [2]:
# load model
#
# Reuse the credentials configured in the first cell so that a custom
# GENAI_API endpoint is honoured. When no endpoint is configured, fall back to
# key-only credentials so the client library applies its own default endpoint.
llm_credentials = creds if api_url else Credentials(api_key)

# Generation parameters are passed to the LangChain wrapper as a plain dict.
generate_params = GenerateParams(
    decoding_method="greedy",
    max_new_tokens=300,
    min_new_tokens=1,
    repetition_penalty=2,
).dict()

llm = LangChainInterface(
    model=ModelType.FLAN_T5_11B,  # choice of model is not necessarily optimal
    credentials=llm_credentials,
    params=generate_params,
)

In [3]:
# Connect to the SQLite file (path is relative to the working directory) and
# build the sequential SQL chain on top of the LLM; verbose=True prints the
# intermediate steps (selected tables, generated SQL, result, answer).
db = SQLDatabase.from_uri(
    "sqlite:///./data/db/fake_bank.db"
)
db_chain = SQLDatabaseSequentialChain.from_llm(
    llm,
    db,
    verbose=True,
)

In [4]:
# Question 1: count of distinct client names.
db_chain.run(
    'how many unique client names are in the bank'
)



[1m> Entering new  chain...[0m




Table names to use:
[33;1m[1;3m['accounts', 'clients'][0m

[1m> Entering new  chain...[0m
how many unique client names are in the bank
SQLQuery:[32;1m[1;3mSELECT count(DISTINCT Client_Name) FROM clients[0m
SQLResult: [33;1m[1;3m[(6,)][0m
Answer:[32;1m[1;3mThere are 6 unique client names in the bank.[0m
[1m> Finished chain.[0m

[1m> Finished chain.[0m


'There are 6 unique client names in the bank.'

In [9]:
# Question 2: number of accounts in the "deposit" product family.
db_chain.run(
    'how many accounts are in the deposit product family in the bank'
)



[1m> Entering new  chain...[0m
Table names to use:
[33;1m[1;3m['accounts', 'clients', 'transactions'][0m

[1m> Entering new  chain...[0m
how many accounts are in the deposit product family  in the bank
SQLQuery:[32;1m[1;3mSELECT count(*) FROM accounts WHERE Product_Family = "deposit"[0m
SQLResult: [33;1m[1;3m[(11,)][0m
Answer:[32;1m[1;3m11[0m
[1m> Finished chain.[0m

[1m> Finished chain.[0m


'11'

In [11]:
# Question 3: total balance across "deposit" product family accounts.
db_chain.run(
    'sum the dollars in the deposit product family accounts'
)



[1m> Entering new  chain...[0m
Table names to use:
[33;1m[1;3m['accounts', 'clients', 'transactions'][0m

[1m> Entering new  chain...[0m
sum the dollars in the deposit product family accounts
SQLQuery:[32;1m[1;3mSELECT sum(Balance) FROM accounts WHERE Product_Family = "deposit"[0m
SQLResult: [33;1m[1;3m[(20197586.46,)][0m
Answer:[32;1m[1;3m[20197586.46][0m
[1m> Finished chain.[0m

[1m> Finished chain.[0m


'[20197586.46]'

In [17]:
# Question 4: per-client balance totals for the "deposit" product family.
# NOTE(review): in the recorded run below, the chain selected only the
# `transactions` table and generated SQL filtering on `Product_family`, which
# SQLite rejects with "no such column". Earlier cells filter on
# accounts.Product_Family successfully, so the question likely needs to steer
# the model toward the `accounts` table -- TODO confirm and rephrase.
db_chain.run(
    'sum the balance and group by client where product family is deposit'
)



[1m> Entering new  chain...[0m
Table names to use:
[33;1m[1;3m['transactions'][0m

[1m> Entering new  chain...[0m
sum the balance and group by client where product family is deposit
SQLQuery:[32;1m[1;3mSELECT sum(Amount) FROM transactions WHERE Product_family = "deposit" GROUP BY Client_key[0m

OperationalError: (sqlite3.OperationalError) no such column: Product_family
[SQL: SELECT sum(Amount) FROM transactions WHERE Product_family = "deposit" GROUP BY Client_key]
(Background on this error at: https://sqlalche.me/e/20/e3q8)