Query the fake bank database stored in [data/db/fake_bank.db](data/db/fake_bank.db) with natural-language questions, using a LangChain SQL agent backed by an IBM GenAI model.

In [1]:
import os
from dotenv import load_dotenv

# ibm genai
from genai.credentials import Credentials
from genai.model import Model
from genai.schemas import GenerateParams, ModelType
from genai.extensions.langchain import LangChainInterface  # langchain support

# langchain
from langchain.sql_database import SQLDatabase
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import AgentExecutor
from langchain.agents.agent_types import AgentType

# Load variables from a .env file (if one is found) into the process environment.
load_dotenv()
api_key = os.getenv("GENAI_KEY")
api_url = os.getenv("GENAI_API")

# Fail fast with an actionable message instead of an obscure error at first API call.
if not api_key:
    raise RuntimeError("GENAI_KEY is not set; add it to the environment or to a .env file")

# Only override the SDK's endpoint when GENAI_API is configured; passing
# api_endpoint=None explicitly would replace the SDK default with None.
creds = Credentials(api_key, api_endpoint=api_url) if api_url else Credentials(api_key)

In [8]:
# load model
# NOTE: FLAN_UL2 is not necessarily the optimal choice; CODEGEN_MONO_16B was also
# tried and did not work well. In the recorded run of the agent below, FLAN_UL2 did
# not emit the "Action:" / "Action Input:" format the ReAct SQL agent expects.
llm = LangChainInterface(
    model=ModelType.FLAN_UL2,
    # Reuse the credentials from the setup cell so GENAI_API (when set) is honored,
    # rather than rebuilding Credentials from GENAI_KEY alone with the default endpoint.
    credentials=creds,
    params=GenerateParams(
        decoding_method="greedy",  # no sampling
        max_new_tokens=1024,
        min_new_tokens=1,
        repetition_penalty=2,
    ).dict(),
)

In [11]:
# database connection and chain
DB_PATH = "./data/db/fake_bank.db"  # relative to the notebook's working directory

# Opening a missing SQLite file would create an empty database, leaving the agent
# with no tables to query; fail fast with the resolved path instead.
if not os.path.isfile(DB_PATH):
    raise FileNotFoundError(f"SQLite database not found at {os.path.abspath(DB_PATH)}")

db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

db_chain = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    # Feed malformed LLM output back to the model as an observation instead of raising
    # (configured at construction rather than patched onto the executor afterwards).
    agent_executor_kwargs={"handle_parsing_errors": True},
)

In [12]:
# Ask the SQL agent for the number of distinct clients.
question = 'how many unique clients are in the bank'
db_chain.run(question)



[1m> Entering new  chain...[0m
[32;1m[1;3mAction: sql_db_schema , sql_db_list_tables; input().split(); output = []; for i in range(len(sql_db_schema)): list_tables_sql_db[i] = "Client"; query_sql_db=sql_db_query; #print(sql_db_query); print("n".join(map(str,sql_db_query)));[0m
Observation: Invalid Format: Missing 'Action Input:' after 'Action:'
Thought:[32;1m[1;3msql_db_query , sql_db_schema; input().split(); output = []; for i in range(len(sql_db_schema)): list_tables_sql_db[i] = "Client"; query_sql_db=sql_db_query; #print(sql_db_query); print("n".join(map(str,sql_db_query)));[0m
Observation: Invalid Format: Missing 'Action:' after 'Thought:'
Thought:[32;1m[1;3mAction: sql_db_query , sql_db_schema; input().split(); output = []; for i in range(len(sql_db_schema)): list_tables_sql_db[i] = "Client"; query_sql_db=sql_db_query; #print(sql_db_query); print("n".join(map(str,sql_db_query)));[0m
Observation: Invalid Format: Missing 'Action Input:' after 'Action:'
Thought:[32;1m[

KeyboardInterrupt: 

In [None]:
# Ask the SQL agent to count accounts in the deposit product family.
question = 'how many accounts are in the deposit product family in the bank'
db_chain.run(question)

In [None]:
# Ask the SQL agent for the total dollars held in deposit-family accounts.
question = 'sum the dollars in the deposit product family accounts'
db_chain.run(question)

In [None]:
# Ask the SQL agent for per-client deposit balances.
question = 'sum the balance and group by client where product family is deposit'
db_chain.run(question)