In [1]:
import os
import ast
import sqlite3
import openai
import pandas as pd
from pathlib import Path, PosixPath
from typing import NamedTuple
from tabulate import tabulate
from dotenv import dotenv_values
from langchain import OpenAI, SQLDatabase, PromptTemplate, LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough  # Try without .passthrough
from langchain_experimental.sql import SQLDatabaseChain

# from langchain.utilities import SQLDatabase

In [2]:
def get_env_values() -> dict[str, str | None]:
    return dotenv_values()

In [3]:
# Read the dotenv values once; the Parameters tuple built further down pulls
# the OPENAI_API_KEY and HUGGINGFACEHUB_API_TOKEN entries out of this mapping.
env_values: dict[str, str | None] = get_env_values()


# NamedTuple type hint
class ParametersType(NamedTuple):
    data_dir: PosixPath  # Platform neutral pathlib PosixPath to data directory
    acs_path: PosixPath  # Platform neutral pathlib PosixPath to ACS data
    db_path: PosixPath  # Platform neutral pathlib PosixPath to SQLite3 database
    db_connection: sqlite3.Connection  # SQLite3 database connection
    openai_api_key: str  # OpenAI API key
    huggingfacehub_api_token: str  # HuggingFace API token


# Single shared parameter bundle. Every path hangs off the current working
# directory, and the SQLite connection is opened on the same file as db_path.
Parameters: ParametersType = ParametersType(
    data_dir=Path.cwd().joinpath("Data"),
    acs_path=Path.cwd().joinpath("Data/ACS_2012_21.csv"),
    db_path=Path.cwd().joinpath("Data/data.sqlite3"),
    # Alternatives that were tried: ":memory:", "Data/data.sqlite3", "Data/acs.sqlite3"
    db_connection=sqlite3.connect(Path.cwd().joinpath("Data/data.sqlite3")),
    openai_api_key=env_values["OPENAI_API_KEY"],
    huggingfacehub_api_token=env_values["HUGGINGFACEHUB_API_TOKEN"],
)

In [4]:
class ColumnDetail(NamedTuple):
    """One column of a table: its name, its declared type and sample values."""

    name: str  # column name
    type: str  # column type string
    values: list[str | int | float]  # values of this column taken from the sampled records

In [5]:
def get_table_details(
    table_name: str, db_info: dict[str, dict[str, list]]
) -> dict[int, ColumnDetail]:
    """Turn the row-oriented sample of one table into per-column details.

    Args:
        table_name: Key of the table to look up in ``db_info``.
        db_info: Mapping of table name to a dict holding the ``"columns"``,
            ``"column_types"`` and ``"records"`` lists for that table.

    Returns:
        A dict keyed by column position whose values are ``ColumnDetail``
        tuples of (column name, column type, values of that column across the
        records).
    """
    table_entry = db_info[table_name]

    # Transpose the records: one list of values per column position.
    per_column_values: list[list[str | int | float]] = []
    for position in range(len(table_entry["columns"])):
        per_column_values.append([record[position] for record in table_entry["records"]])

    table_detail: dict[int, ColumnDetail] = {}
    triples = zip(table_entry["columns"], table_entry["column_types"], per_column_values)
    for position, (column_name, column_type, column_values) in enumerate(triples):
        table_detail[position] = ColumnDetail(column_name, column_type, column_values)

    return table_detail

In [6]:
class DatabaseInfoExtractor:
    """
    A class for extracting information about tables and columns from a SQLite database.

    Attributes:
    - db_path (Path): The path to the SQLite database file.
    - conn (sqlite3.Connection): The connection object to the database.
    - cursor (sqlite3.Cursor): The cursor object for executing SQL queries.

    Methods:
    - extract_info() -> dict[str, dict[str, list]]: Returns a dictionary containing information about each table in the database.
    - close() -> None: Closes the database connection.

    The instance can also be used as a context manager (``with DatabaseInfoExtractor(path) as x:``),
    in which case the connection is closed when the block exits.
    """

    def __init__(self, db_path: Path) -> None:
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()

    def __enter__(self) -> "DatabaseInfoExtractor":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def close(self) -> None:
        """Close the underlying connection; calling it more than once is harmless."""
        self.conn.close()

    @staticmethod
    def _quote_identifier(identifier: str) -> str:
        """Return ``identifier`` as a double-quoted SQL identifier (embedded quotes doubled)."""
        return '"' + identifier.replace('"', '""') + '"'

    def extract_info(self) -> dict[str, dict[str, list]]:
        """Collect column names, column types and up to five records per table.

        Returns:
            A dict keyed by table name. Each value is a dict with the keys
            ``"columns"`` (names), ``"column_types"`` (type strings) and
            ``"records"`` (the first five rows as tuples).
        """
        tables: dict[str, dict[str, list]] = {}
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names: list[str] = [row[0] for row in self.cursor.fetchall()]
        for table_name in table_names:
            # Identifiers cannot be bound as SQL parameters, so quote the name
            # explicitly; an unquoted name containing spaces or keywords would
            # make both statements below fail.
            quoted_name = self._quote_identifier(table_name)

            # In each table_info row, index 1 is the column name and index 2 its type.
            self.cursor.execute(f"PRAGMA table_info({quoted_name})")
            column_rows = self.cursor.fetchall()

            # Fetch the first five records from the table
            self.cursor.execute(f"SELECT * FROM {quoted_name} LIMIT 5;")
            records: list[tuple] = self.cursor.fetchall()

            tables[table_name] = {
                "columns": [row[1] for row in column_rows],
                "column_types": [row[2] for row in column_rows],
                "records": records,
            }

        return tables

In [7]:
# Instantiate DatabaseInfoExtractor
db_info_extractor = DatabaseInfoExtractor(Parameters.db_path)

# Extract information about the database, then release the extractor's own
# connection: only the returned dictionary is used afterwards, so there is no
# reason to keep a second handle on the database file open.
try:
    db_info = db_info_extractor.extract_info()
finally:
    db_info_extractor.conn.close()
print(db_info)

{'saacempratio': {'columns': ['GeoFIPS', 'GeoName', 'Region', 'TableName', 'LineCode', 'IndustryClassification', 'Description', 'Unit', 'Year', 'Value'], 'column_types': ['TEXT', 'TEXT', 'REAL', 'TEXT', 'REAL', 'TEXT', 'TEXT', 'TEXT', 'TEXT', 'TEXT'], 'records': [(' 01000', 'Alabama', 5.0, 'SAACEmpRatio', 10.0, '...', 'Total arts employment ratio ', 'Ratio', '2001', '0.033'), (' 01000', 'Alabama', 5.0, 'SAACEmpRatio', 100.0, '...', ' Core arts and cultural production ', 'Ratio', '2001', '0.245'), (' 01000', 'Alabama', 5.0, 'SAACEmpRatio', 111.0, '...', '  Performing arts companies ', 'Ratio', '2001', '0.911'), (' 01000', 'Alabama', 5.0, 'SAACEmpRatio', 112.0, '...', '  Promoters of performing arts and similar events ', 'Ratio', '2001', '0.826'), (' 01000', 'Alabama', 5.0, 'SAACEmpRatio', 113.0, '...', '  Agents/managers for artists ', 'Ratio', '2001', '0.512')]}, 'saacartsva': {'columns': ['GeoFIPS', 'GeoName', 'Region', 'TableName', 'LineCode', 'IndustryClassification', 'Description',

In [8]:
# One entry per table; each entry is a list of
# [table name, column name, column type, sampled values] rows.
table_summaries = []

for table_name, table_info in db_info.items():
    table_detail: dict[int, ColumnDetail] = get_table_details(table_name, db_info)
    table_summary = [
        [table_name, col_detail.name, col_detail.type, col_detail.values]
        for col_detail in table_detail.values()
    ]
    table_summaries.append(table_summary)

table_headers = ["Table Name", "Column Name", "Column Type", "First Five Column Values"]

# Render every table as a grid, followed by the same blank spacing as before.
for table_summary in table_summaries:
    table_text = tabulate(table_summary, headers=table_headers, tablefmt="grid")
    print(table_text, "\n", sep="\n")

+--------------+------------------------+---------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Table Name   | Column Name            | Column Type   | First Five Column Values                                                                                                                                                                        |
| saacempratio | GeoFIPS                | TEXT          | [' 01000', ' 01000', ' 01000', ' 01000', ' 01000']                                                                                                                                              |
+--------------+------------------------+---------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [18]:
# Map each table name to its summary rows. Every row is
# [table name, column name, column type, sampled values].
table_summaries: dict[str, list[list[str | list[str | int | float]]]] = {}

# Only the table names are needed here, so iterate over the keys.
for table_name in db_info:
    table_detail: dict[int, ColumnDetail] = get_table_details(table_name, db_info)
    table_summaries[table_name] = [
        [table_name, col_detail.name, col_detail.type, col_detail.values]
        for col_detail in table_detail.values()
    ]

table_headers = ["Table Name", "Column Name", "Column Type", "First Five Column Values"]

for table_summary in table_summaries.values():
    table_text = tabulate(table_summary, headers=table_headers, tablefmt="grid") + "\n"
    print(table_text)

+--------------+------------------------+---------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Table Name   | Column Name            | Column Type   | First Five Column Values                                                                                                                                                                        |
| saacempratio | GeoFIPS                | TEXT          | [' 01000', ' 01000', ' 01000', ' 01000', ' 01000']                                                                                                                                              |
+--------------+------------------------+---------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [16]:
table_summaries["acs"]

[['acs', 'Unnamed: 0', 'INTEGER', [0, 1, 2, 3, 4]],
 ['acs',
  'Geography',
  'TEXT',
  ['0400000US01', '0400000US01', '0400000US01', '0400000US01', '0400000US01']],
 ['acs',
  'Geographic Area Name',
  'TEXT',
  ['Alabama', 'Alabama', 'Alabama', 'Alabama', 'Alabama']],
 ['acs',
  'Total population',
  'INTEGER',
  [4777326, 4817678, 4799277, 4841164, 4830620]],
 ['acs', 'Male', 'INTEGER', [2317520, 2336020, 2328592, 2346193, 2341093]],
 ['acs', 'Female', 'INTEGER', [2459806, 2481658, 2470685, 2494971, 2489527]],
 ['acs', 'Under 5 years', 'INTEGER', [305091, 299571, 301925, 292771, 295054]],
 ['acs', '5 to 9 years', 'INTEGER', [309360, 304412, 306456, 305707, 305714]],
 ['acs',
  '10 to 14 years',
  'INTEGER',
  [318484, 321104, 320031, 313980, 318437]],
 ['acs',
  '15 to 19 years',
  'INTEGER',
  [337159, 327579, 332287, 324809, 324020]],
 ['acs',
  '20 to 24 years',
  'INTEGER',
  [340808, 347110, 345240, 342489, 348044]],
 ['acs',
  '25 to 34 years',
  'INTEGER',
  [607797, 618482, 

In [22]:
table_text = (
    tabulate(table_summaries["acs"], headers=table_headers, tablefmt="grid") + "\n"
)
print(table_text)

+--------------+--------------------------------------------------------------+---------------+-----------------------------------------------------------------------------+
| Table Name   | Column Name                                                  | Column Type   | First Five Column Values                                                    |
| acs          | Unnamed: 0                                                   | INTEGER       | [0, 1, 2, 3, 4]                                                             |
+--------------+--------------------------------------------------------------+---------------+-----------------------------------------------------------------------------+
| acs          | Geography                                                    | TEXT          | ['0400000US01', '0400000US01', '0400000US01', '0400000US01', '0400000US01'] |
+--------------+--------------------------------------------------------------+---------------+-----------------------------------

In [78]:
# Natural-language question the model is asked to translate into SQL.
question: str = "How many nine year olds were in Kentucky in 2020?"

# Column summary of the "acs" table rendered in tabulate's "plain" format
# (the "grid" format was the earlier alternative).
table_text: str = (
    tabulate(table_summaries["acs"], headers=table_headers, tablefmt="plain") + "\n"
)  # grid

# Prompt = instruction + question + table summary + quoting rule + the exact
# fallback sentence to return when the question cannot be answered.
prompt: str = f"""
Please return a syntactically correct SQLite query answering the following question:

{question}

Using the table information below:

{table_text}

Please quote the columns in the query.

If you are not able to exactly answer the question, please return the following message:

I am not able to answer the question exactly.
"""

print(prompt)


Please return a syntactically correct SQLite query answering the following question:

How many nine year olds were in Kentucky in 2020?

Using the table information below:

Table Name    Column Name                                                   Column Type    First Five Column Values
acs           Unnamed: 0                                                    INTEGER        [0, 1, 2, 3, 4]
acs           Geography                                                     TEXT           ['0400000US01', '0400000US01', '0400000US01', '0400000US01', '0400000US01']
acs           Geographic Area Name                                          TEXT           ['Alabama', 'Alabama', 'Alabama', 'Alabama', 'Alabama']
acs           Total population                                              INTEGER        [4777326, 4817678, 4799277, 4841164, 4830620]
acs           Male                                                          INTEGER        [2317520, 2336020, 2328592, 2346193, 2341093]
acs           F

In [79]:
# Simple LLM call Using LangChain
# "text-davinci-003" has been retired by OpenAI; "gpt-3.5-turbo-instruct" is
# the completion-style replacement that works with the same OpenAI wrapper.
llm = OpenAI(
    model_name="gpt-3.5-turbo-instruct", openai_api_key=Parameters.openai_api_key
)

result: str = llm(prompt)

print(prompt, result)


Please return a syntactically correct SQLite query answering the following question:

How many nine year olds were in Kentucky in 2020?

Using the table information below:

Table Name    Column Name                                                   Column Type    First Five Column Values
acs           Unnamed: 0                                                    INTEGER        [0, 1, 2, 3, 4]
acs           Geography                                                     TEXT           ['0400000US01', '0400000US01', '0400000US01', '0400000US01', '0400000US01']
acs           Geographic Area Name                                          TEXT           ['Alabama', 'Alabama', 'Alabama', 'Alabama', 'Alabama']
acs           Total population                                              INTEGER        [4777326, 4817678, 4799277, 4841164, 4830620]
acs           Male                                                          INTEGER        [2317520, 2336020, 2328592, 2346193, 2341093]
acs           F

In [80]:
print(result)


SELECT SUM("5 to 9 years") AS nine_year_olds
FROM acs
WHERE Geography LIKE '0400000US21' AND YEAR = 2020;


In [124]:
question: str = "How many nine year olds were in Kentucky in 2020?"

table_feature: str = (
    tabulate(table_summaries["acs"], headers=table_headers, tablefmt="plain") + "\n"
)  # grid

id_data_prompt: str = f"""
What columns in the table below:

{table_feature}

Can be used to answer the following question:

{question}
"""

print(id_data_prompt)


What columns in the table below:

Table Name    Column Name                                                   Column Type    First Five Column Values
acs           Unnamed: 0                                                    INTEGER        [0, 1, 2, 3, 4]
acs           Geography                                                     TEXT           ['0400000US01', '0400000US01', '0400000US01', '0400000US01', '0400000US01']
acs           Geographic Area Name                                          TEXT           ['Alabama', 'Alabama', 'Alabama', 'Alabama', 'Alabama']
acs           Total population                                              INTEGER        [4777326, 4817678, 4799277, 4841164, 4830620]
acs           Male                                                          INTEGER        [2317520, 2336020, 2328592, 2346193, 2341093]
acs           Female                                                        INTEGER        [2459806, 2481658, 2470685, 2494971, 2489527]
acs           Und

In [125]:
result: str = llm(id_data_prompt)

print(
    id_data_prompt,
    result,
    table_summaries["acs"][0][0],
    table_summaries["acs"][0][1],
    sep=" - ",
)


What columns in the table below:

Table Name    Column Name                                                   Column Type    First Five Column Values
acs           Unnamed: 0                                                    INTEGER        [0, 1, 2, 3, 4]
acs           Geography                                                     TEXT           ['0400000US01', '0400000US01', '0400000US01', '0400000US01', '0400000US01']
acs           Geographic Area Name                                          TEXT           ['Alabama', 'Alabama', 'Alabama', 'Alabama', 'Alabama']
acs           Total population                                              INTEGER        [4777326, 4817678, 4799277, 4841164, 4830620]
acs           Male                                                          INTEGER        [2317520, 2336020, 2328592, 2346193, 2341093]
acs           Female                                                        INTEGER        [2459806, 2481658, 2470685, 2494971, 2489527]
acs           Und

In [104]:
len(table_summaries["acs"])

96

In [131]:
question: str = "How many nine year olds were in Kentucky in 2020?"

# Ask the model about one column at a time. Each row of the summary is
# [table name, column name, column type, sampled values].
for column_row in table_summaries["acs"]:
    table_feature: str = (
        tabulate([column_row], headers=table_headers, tablefmt="plain")
        + "\n"
    )  # grid

    id_data_prompt: str = f"""
    Is the following column:

    {table_feature}

    A candidate to help answer the following question:

    {question}

    Your answer should only be exactly one of the following:
    1) yes,
    2) no, or
    3) maybe
    Ommit punctuation from your response and return it in lower case.
    """

    # The completion comes back wrapped in newlines; strip it so the value is
    # exactly the requested token and the report prints one line per column.
    result: str = llm(id_data_prompt).strip()

    print(result, column_row[0], column_row[1], sep=" - ")


no - acs - Unnamed: 0

no - acs - Geography

no - acs - Geographic Area Name

no - acs - Total population

no - acs - Male

no - acs - Female

no - acs - Under 5 years

yes - acs - 5 to 9 years

no - acs - 10 to 14 years

no - acs - 15 to 19 years

no - acs - 20 to 24 years

no - acs - 25 to 34 years

no - acs - 35 to 44 years

no - acs - 45 to 54 years

no - acs - 55 to 59 years

no - acs - 60 to 64 years

no - acs - 65 to 74 years

no - acs - 75 to 84 years

no - acs - 85 years and over

no - acs - Median age (years)

no - acs - 18 years and over

no - acs - 21 years and over

no - acs - 65 years and over

no - acs - One race

no - acs - Two or more races

no - acs - White

no - acs - Black or African American

no - acs - American Indian and Alaska Native

no - acs - Asian

no - acs - Native Hawaiian and Other Pacific Islander

no - acs - Some other race

no - acs - Hispanic or Latino (of any race)

no - acs - Not Hispanic or Latino

no - acs - Total housing units_x

no - acs - YEAR

# ACS Data

In [32]:
# Load the ACS extract and drop the "Unnamed: 0" column in the same expression,
# so re-running the cell always yields the same frame (no in-place mutation).
acs_df: pd.DataFrame = pd.read_csv(Parameters.acs_path).drop(columns=["Unnamed: 0"])

# DataFrame.info() prints its report itself and returns None; wrapping it in
# display() only added a stray "None" to the output.
acs_df.info()
display(acs_df.head())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 520 entries, 0 to 519
Data columns (total 95 columns):
 #   Column                                                        Non-Null Count  Dtype  
---  ------                                                        --------------  -----  
 0   Geography                                                     520 non-null    object 
 1   Geographic Area Name                                          520 non-null    object 
 2   Total population                                              520 non-null    int64  
 3   Male                                                          520 non-null    int64  
 4   Female                                                        520 non-null    int64  
 5   Under 5 years                                                 520 non-null    int64  
 6   5 to 9 years                                                  520 non-null    int64  
 7   10 to 14 years                                                520 non-n

None

Unnamed: 0,Geography,Geographic Area Name,Total population,Male,Female,Under 5 years,5 to 9 years,10 to 14 years,15 to 19 years,20 to 24 years,...,Vacant housing units,Homeowner vacancy rate,Rental vacancy rate,Median rooms,"Median (dollars), Value",Owner-occupied units,Housing units with a mortgage,Housing units without a mortgage,"Median (dollars), Rent",No rent paid
0,0400000US01,Alabama,4777326,2317520,2459806,305091,309360,318484,337159,340808,...,335071,2.5,9.0,5.7,122300,1289324,776946,512378,691,63064
1,0400000US01,Alabama,4817678,2336020,2481658,299571,304412,321104,327579,347110,...,339433,2.5,8.9,5.7,122500,1281604,762450,519154,705,62178
2,0400000US01,Alabama,4799277,2328592,2470685,301925,306456,320031,332287,345240,...,348464,2.6,9.0,5.7,123800,1274196,751234,522962,715,62842
3,0400000US01,Alabama,4841164,2346193,2494971,292771,305707,313980,324809,342489,...,351004,2.5,9.0,5.7,125500,1269145,738618,530527,717,63906
4,0400000US01,Alabama,4830620,2341093,2489527,295054,305714,318437,324020,348044,...,358274,2.4,9.4,5.7,128500,1267824,730637,537187,728,65161


In [33]:
"""
SELECT SUM(Under 5 years + 5 to 9 years) AS nine_year_olds
FROM acs
WHERE Geography LIKE '0400000US21'
AND YEAR=2020;
"""

# pandas equivalent of the query sketched above: filter the rows once, then
# add the totals of the two age-band columns.
kentucky_2020 = acs_df.query("Geography == '0400000US21' and YEAR == 2020")
kentucky_2020["Under 5 years"].sum() + kentucky_2020["5 to 9 years"].sum()

548588

In [34]:
acs_df[acs_df["Geography"] == "0400000US21"]["Geographic Area Name"].unique()

array(['Kentucky'], dtype=object)

# List all tables in SQLite database

In [None]:
# List the names of all tables through the shared connection held by Parameters.
cursor = Parameters.db_connection.cursor()
tables: list[tuple] = cursor.execute(
    "SELECT name FROM sqlite_master WHERE type='table';"
).fetchall()
tables

# Test OpenAI API Key

In [None]:
"""
Please provide python code to ask openai a question and retrieve the answer

To ask OpenAI a question and retrieve the answer, you can use the OpenAI API.
Here's an example Python code that demonstrates how to do this using the openai package:

In this example, you need to replace "YOUR_API_KEY" with your actual API key,
which you can obtain by signing up for OpenAI's API at https://beta.openai.com/signup/.

The ask_openai function takes three parameters: the question you want to ask,
the name of the OpenAI model you want to use (e.g., "davinci" or "curie"),
and the max_length of the generated answer (in number of tokens).

The function sends a request to the OpenAI API using the openai.Completion.create method,
which takes the engine, prompt, and max_tokens as parameters. The response is a list of
completions (i.e., possible answers), and we take the first one (which is usually the most likely answer)
and return it as a string.
"""

openai.api_key = Parameters.openai_api_key  # "YOUR_API_KEY" # replace with your API key


def ask_openai(question, model, max_length):
    """Send ``question`` to an OpenAI completion model and return the answer text.

    Args:
        question: Text inserted into a ``"Q: ...\\nA:"`` prompt.
        model: Value passed as ``engine`` to ``openai.Completion.create``.
        max_length: Value passed as ``max_tokens``.

    Returns:
        The text of the first returned choice with surrounding whitespace removed.
    """
    response = openai.Completion.create(
        engine=model, prompt=f"Q: {question}\nA:", max_tokens=max_length
    )
    first_choice = response.choices[0]
    return first_choice.text.strip()


# Example usage
# NOTE(review): "davinci" is a legacy base-model id; confirm it is still
# served for this account before relying on this cell as an API-key check.
question = "What is the capital of France?"
model = "davinci"
max_length = 100  # upper bound on generated tokens, forwarded as max_tokens
answer = ask_openai(question, model, max_length)
print(answer)

# [LangChain](https://pypi.org/project/langchain/)
https://coinsbench.com/chat-with-your-databases-using-langchain-bb7d31ed2e76  
https://medium.com/@hannanmentor/python-custom-chatgpt-with-your-own-data-f307635dd5bd  

## Check that LangChain works

In [26]:
# Simple LLM call Using LangChain
# "text-davinci-003" has been retired by OpenAI; "gpt-3.5-turbo-instruct" is
# the completion-style replacement that works with the same OpenAI wrapper.
llm = OpenAI(
    model_name="gpt-3.5-turbo-instruct", openai_api_key=Parameters.openai_api_key
)
question = "Which language is used to create chatgpt ?"
print(question, llm(question))

Which language is used to create chatgpt ? 

ChatGPT is written in Python.


## Prompt template

In [None]:
# Creating a prompt template and running the LLM chain
template = "What are the top {n} resources to learn {language} programming?"
prompt = PromptTemplate(template=template, input_variables=["n", "language"])
chain = LLMChain(llm=llm, prompt=prompt)
input_ = {"n": 3, "language": "Python"}
print(chain.run(input_))

In [None]:
type(prompt)

In [None]:
# SQLDatabase.create_table_from_df(df=acs_df, table_name="acs", db_connection=Parameters.db_connection)  # Create table from dataframe

In [None]:
# Use LangChain to answer questions using a SQLite3 database
# "text-davinci-003" has been retired; use the completion-style replacement.
llm = OpenAI(
    model_name="gpt-3.5-turbo-instruct", openai_api_key=Parameters.openai_api_key
)
dburi = "sqlite:///Data/acs.sqlite3"
db = SQLDatabase.from_uri(dburi)
question = "What is the population of Kentucky?"
# Calling llm(question, db) hands `db` to the LLM as its second positional
# argument (the stop sequences), so the database is never consulted. Route
# the question through a SQLDatabaseChain, which writes and runs the SQL.
print(question, SQLDatabaseChain.from_llm(llm=llm, db=db).run(question))

In [None]:
template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_template(template)

In [None]:
type(prompt)

In [None]:
print(prompt)

In [None]:
# db = SQLDatabase.from_uri("sqlite:///Data/data.sqlite3")
db = SQLDatabase.from_uri("sqlite:///Data/acs.sqlite3")

In [None]:
result = db.run("SELECT * FROM acs LIMIT 5")
result

In [None]:
result: list[tuple] = ast.literal_eval(db.run("SELECT * FROM acs LIMIT 5"))
print(result)

In [None]:
def get_schema(_):
    """Return ``db.get_table_info()`` for the module-level ``db`` object.

    The single positional argument is accepted and ignored, so the function
    can be used wherever a one-argument callable is expected.
    """
    schema_text = db.get_table_info()
    return schema_text

In [None]:
def run_query(query):
    """Run ``query`` through the module-level ``db`` object and return its result."""
    query_result = db.run(query)
    return query_result

In [None]:
# The API key is loaded into the Parameters tuple, and every other model in
# this notebook receives it explicitly; do the same here instead of relying
# on an OPENAI_API_KEY environment variable being set separately.
model = ChatOpenAI(openai_api_key=Parameters.openai_api_key)

# Chain: add the schema to the input -> fill the prompt -> call the chat model
# (stopping before it invents an "SQLResult:" section) -> plain string.
sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | model.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

In [None]:
sql_response.invoke({"question": "How many geographic areas are in the acs table?"})

In [None]:
print(get_schema(_))

In [None]:
type(sql_response)

In [None]:
sql_response

In [None]:
openai.api_key = Parameters.openai_api_key


def count_tokens(prompt):
    """Return the number of tokens the completion endpoint reports for ``prompt``.

    The request generates nothing (``max_tokens=0``) and echoes the prompt with
    log-probabilities enabled, so the response lists the prompt's own tokens.
    Counting whitespace-separated words of the echoed text, as was done
    before, yields a word count, not a token count.

    Args:
        prompt: The text to measure.

    Returns:
        The length of the token list in the response. If the response carries
        no token list, the whitespace word count of the echoed text is
        returned as an approximation.
    """
    # NOTE(review): "davinci" is a legacy base-model id — confirm it is still
    # served and that its tokenizer matches the model you are budgeting for.
    response = openai.Completion.create(
        engine="davinci",
        prompt=prompt,
        max_tokens=0,
        n=1,
        stop=None,
        temperature=0.0,
        logprobs=0,
        echo=True,
    )
    choice = response.choices[0]
    logprobs = getattr(choice, "logprobs", None)
    tokens = getattr(logprobs, "tokens", None) if logprobs is not None else None
    if tokens is not None:
        return len(tokens)
    # Fallback keeps the previous (approximate) behaviour.
    return len(choice.text.split())


prompt = "This is a prompt to count tokens."
print(count_tokens(prompt))  # Output: 6

In [None]:
prompt = get_schema(_)
print(count_tokens(prompt))

In [None]:
# setup llm
# Pass the key loaded into Parameters explicitly, as the other models in this
# notebook do, instead of relying on an OPENAI_API_KEY environment variable.
llm = ChatOpenAI(openai_api_key=Parameters.openai_api_key)

dburi = "sqlite:///Data/acs.sqlite3"
db = SQLDatabase.from_uri(dburi)

# Create db chain
# Wrapper text placed around each user question before it is sent to the chain.
QUERY = """
Given an input question: first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.

Use the following format:
Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

{question}
"""

# Setup the database chain
db_chain = SQLDatabaseChain.from_llm(llm=llm, db=db, verbose=True)


def get_prompt():
    """Interactive loop: read questions from stdin and print the chain's answers.

    Each non-empty line is wrapped with the module-level ``QUERY`` template and
    passed to the module-level ``db_chain``. The loop ends when the user types
    ``exit`` (case-insensitive, surrounding whitespace ignored) or when stdin
    is closed or interrupted. Errors raised while answering a question are
    printed and the loop continues.
    """
    print("Type 'exit' to quit")

    while True:
        try:
            prompt = input("Enter a prompt: ").strip()
        except (EOFError, KeyboardInterrupt):
            # No more input can arrive; leave instead of raising out of the loop.
            print("Exiting...")
            break

        if prompt.lower() == "exit":
            print("Exiting...")
            break

        if not prompt:
            # Nothing was typed; do not spend a model call on an empty question.
            continue

        try:
            question = QUERY.format(question=prompt)
            print(db_chain.run(question))
        except Exception as e:
            # Best effort: report the failure and keep the session alive.
            print(e)


get_prompt()

In [None]:
template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_template(template)

In [None]:
# The API key is loaded into the Parameters tuple, and every other model in
# this notebook receives it explicitly; do the same here instead of relying
# on an OPENAI_API_KEY environment variable being set separately.
model = ChatOpenAI(openai_api_key=Parameters.openai_api_key)

# Chain: add the schema to the input -> fill the prompt -> call the chat model
# (stopping before it invents an "SQLResult:" section) -> plain string.
sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | model.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

In [None]:
sql_response.invoke({"question": "How many geographic areas are in the acs table?"})

In [None]:
# setup llm
# Pass the key loaded into Parameters explicitly, as the other models in this
# notebook do, instead of relying on an OPENAI_API_KEY environment variable.
llm = ChatOpenAI(openai_api_key=Parameters.openai_api_key)

dburi = "sqlite:///Data/acs.sqlite3"
db = SQLDatabase.from_uri(dburi)

# Create db chain
# Wrapper text placed around the user question before it is sent to the chain.
QUERY = """
Given an input question: first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.

Use the following format:
Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

{question}
"""

# Setup the database chain
db_chain = SQLDatabaseChain.from_llm(llm=llm, db=db, verbose=True)

print("Type 'exit' to quit")

prompt = input("Enter a prompt: ").strip()

if prompt.lower() == "exit":
    print("Exiting...")
else:
    try:
        # Show the SQL proposed by the standalone chain. Previously its result
        # was thrown away and the chain object itself was handed to
        # db_chain.run() in place of a question.
        print(sql_response.invoke({"question": prompt}))
        question = QUERY.format(question=prompt)
        print(db_chain.run(question))
    except Exception as e:
        # Best effort: report the failure instead of aborting the cell.
        print(e)

In [None]:
# How many nine year olds were in Kentucky in 2020

In [None]:
# Dump every field of the prompt the chain's LLM step uses: the field name,
# its value, then an empty line as a separator.
for field_name, field_value in db_chain.dict()["llm_chain"]["prompt"].items():
    print(field_name, field_value, "", sep="\n")