## Step 1: Environment Setup and Installation

This step installs the dependencies listed in `requirements.txt`.  

If installation fails, it retries up to 3 times before exiting.  

Once complete, it clears the output and prints a success message.  


In [None]:
# Boilerplate: This block goes into every notebook.
# It installs the packages listed in requirements.txt, retrying on failure.

from IPython.display import clear_output
import os

requirements_installed = False
max_retries = 3
retries = 0


def install_requirements():
    """Installs the requirements from requirements.txt file"""
    global requirements_installed, retries
    if requirements_installed:
        print("Requirements already installed.")
        return

    print("Installing requirements...")
    install_status = os.system("pip install -r requirements.txt")
    if install_status == 0:
        print("Requirements installed successfully.")
        requirements_installed = True
    else:
        print("Failed to install requirements.")
        if retries < max_retries:
            print("Retrying...")
            retries += 1
            return install_requirements()
        exit(1)
    return


install_requirements()
clear_output()
print("🚀 Setup complete. Continue to the next cell.")
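
The retry behaviour described above can also be written as a loop instead of recursion with globals. The following is a minimal sketch, not the notebook's own implementation: the names `run_pip_install` and `install_with_retries` and the injectable `runner` parameter are assumptions made for illustration. It calls pip through `sys.executable` so that packages are installed for the interpreter running the notebook.

```python
import subprocess
import sys
from typing import Callable


def run_pip_install(requirements_path: str) -> int:
    """Run pip for the current interpreter and return its exit code."""
    cmd = [sys.executable, "-m", "pip", "install", "-r", requirements_path]
    return subprocess.run(cmd).returncode


def install_with_retries(
    requirements_path: str = "requirements.txt",
    max_retries: int = 3,
    runner: Callable[[str], int] = run_pip_install,  # hypothetical hook, handy for testing
) -> bool:
    """Try the install once, then retry up to max_retries more times."""
    for attempt in range(max_retries + 1):
        if runner(requirements_path) == 0:
            return True
        print(f"Attempt {attempt + 1} failed.")
    return False
```

Returning `False` instead of calling `exit(1)` leaves the decision about what to do on failure to the calling cell.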

## Step 2: Environment Variable Setup

This step loads environment variables from `.env` using `dotenv`.  

It checks if `ANTHROPIC_API_KEY` is set; if missing, it exits.  

After validation, it confirms successful setup.  


In [None]:
from dotenv import load_dotenv

REQUIRED_ENV_VARS = ["ANTHROPIC_API_KEY"]


def setup_env():
    """Load variables from .env and verify that every required one is set."""

    def check_env(env_var):
        value = os.getenv(env_var)
        if value is None:
            print(f"Please set the {env_var} environment variable.")
            exit(1)
        else:
            print(f"{env_var} is set.")

    load_dotenv(override=True, dotenv_path=".env")

    variables_to_check = REQUIRED_ENV_VARS

    for var in variables_to_check:
        check_env(var)

    print("Environment variables are set.")


setup_env()
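
As a variation on the check above, the missing variables can be collected and reported together instead of exiting on the first one. This is a minimal sketch; `find_missing_env_vars` is a hypothetical helper, and unlike the cell above it also treats an empty string as missing, which is an assumption about the desired behaviour.

```python
import os
from typing import Iterable, List


def find_missing_env_vars(names: Iterable[str]) -> List[str]:
    """Return the names that are unset or empty in the current environment."""
    return [name for name in names if not os.getenv(name)]


missing = find_missing_env_vars(["ANTHROPIC_API_KEY"])
if missing:
    print("Missing environment variables:", ", ".join(missing))
```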

The script begins by importing the necessary modules: `sqlite3` for handling SQLite databases, `traceback` for detailed error reporting, `Union` from the `typing` module to specify multiple return types, and `os` for file operations.  

The `unlock_database` function is designed to unlock a database file by creating a backup copy. It first prints a message indicating that the database is being unlocked. It then generates a new filename by appending `.new` to the original database file name. After that, it establishes connections to both the original database and the new backup file. Using the `backup` method, it safely copies the original database into the backup file. Once the backup is complete, both database connections are closed, and the original database file is replaced with the new backup copy. If any errors occur during this process, they are caught in the `except` block, where an error message is printed along with the full traceback for debugging.  

The `create_connection` function is responsible for establishing a connection to the specified SQLite database. Before doing so, it first calls the `unlock_database` function to ensure the database is accessible. It then attempts to create a connection using `sqlite3.connect`, setting a timeout of 10 seconds. If the connection is successful, it returns the connection object. If any errors occur, they are caught in the `except` block, where an error message is printed along with the traceback for debugging. If the connection fails, the function returns `None`.  


In [None]:
import sqlite3
import traceback
from typing import Union
import os


def unlock_database(db_file: str) -> None:
    """
    Unlock the database file by creating a new copy.

    Args:
        db_file (str): Path to the database file.
    """
    try:
        print("Unlocking database:", db_file)

        new_db = f"{db_file}.new"

        # Open the original database
        conn = sqlite3.connect(db_file)
        backup_conn = sqlite3.connect(new_db)

        with backup_conn:
            conn.backup(backup_conn)  # Safely backup the database

        conn.close()
        backup_conn.close()

        # Replace the old database with the new copy
        os.replace(new_db, db_file)

        print("Database unlocked successfully.")

    except Exception as e:
        print(f"Error unlocking database: {e}")
        traceback.print_exc()


def create_connection(db_file: str) -> Union[sqlite3.Connection, None]:
    """
    Create a database connection to the SQLite database specified by DB file.

    Args:
        db_file (str): database file

    Returns:
        Connection object or None
    """
    try:
        unlock_database(db_file)
        con = sqlite3.connect(db_file, timeout=10)
        return con
    except Exception as e:
        print(f"Error creating connection: {e}")
        traceback.print_exc()
        return None
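
To see the `backup` call used by `unlock_database` in isolation, here is a small sketch that copies one in-memory database into another. The table name `notes` and its contents are made up for the demonstration.

```python
import sqlite3

# Source database with a single row (illustrative data only)
source = sqlite3.connect(":memory:")
source.execute("CREATE TABLE notes (id INTEGER PRIMARY KEY, body TEXT)")
source.execute("INSERT INTO notes (body) VALUES ('hello')")
source.commit()

# Connection.backup copies the full contents of source into target
target = sqlite3.connect(":memory:")
source.backup(target)

copied_rows = target.execute("SELECT id, body FROM notes").fetchall()
print(copied_rows)

source.close()
target.close()
```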

In [None]:
import sqlite3


def execute_script(
    connection: sqlite3.Connection, script: str, script_key: str
) -> None:
    try:
        print("Executing script: ", script_key)
        cursor = connection.cursor()
        cursor.executescript(script)
        connection.commit()
        print("Script executed successfully")
    except Exception as e:
        print(f"Error executing script: {e}")
        traceback.print_exc()

In [None]:
bank_db = "data/bank.db"

bank_db_connection = create_connection(bank_db)

seed_script = """
-- Enforce foreign key constraints on this connection
PRAGMA foreign_keys = ON;

BEGIN TRANSACTION;

-- Error handling setup
SAVEPOINT start;

-- Create Customers Table
SAVEPOINT customers;
CREATE TABLE IF NOT EXISTS Customers (
    customer_id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL,
    email TEXT UNIQUE NOT NULL,
    phone TEXT NOT NULL,
    address TEXT
);
RELEASE customers;

-- Create Accounts Table
SAVEPOINT accounts;
CREATE TABLE IF NOT EXISTS Accounts (
    account_id INTEGER PRIMARY KEY AUTOINCREMENT,
    customer_id INTEGER NOT NULL,
    account_type TEXT CHECK(account_type IN ('Checking', 'Savings')) NOT NULL,
    balance DECIMAL(10,2) NOT NULL DEFAULT 0.00,
    FOREIGN KEY (customer_id) REFERENCES Customers(customer_id) ON DELETE CASCADE
);
RELEASE accounts;

-- Create Transactions Table
SAVEPOINT transactions;
CREATE TABLE IF NOT EXISTS Transactions (
    transaction_id INTEGER PRIMARY KEY AUTOINCREMENT,
    account_id INTEGER NOT NULL,
    transaction_type TEXT CHECK(transaction_type IN ('Deposit', 'Withdrawal', 'Transfer')) NOT NULL,
    amount DECIMAL(10,2) NOT NULL,
    transaction_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (account_id) REFERENCES Accounts(account_id) ON DELETE CASCADE
);
RELEASE transactions;

-- Ensure customers exist before inserting accounts
SAVEPOINT check_customers;
INSERT OR IGNORE INTO Customers (customer_id, name, email, phone, address) VALUES
    (1, 'Alice Johnson', 'alice@example.com', '123-456-7890', '123 Elm Street'),
    (2, 'Bob Smith', 'bob@example.com', '234-567-8901', '456 Oak Street'),
    (3, 'Charlie Brown', 'charlie@example.com', '345-678-9012', '789 Maple Street');
RELEASE check_customers;

-- Insert sample accounts with explicit IDs. INSERT OR IGNORE only skips rows
-- that violate a constraint, so fixed primary keys keep re-runs from adding duplicates.
SAVEPOINT insert_accounts;
INSERT OR IGNORE INTO Accounts (account_id, customer_id, account_type, balance) VALUES
    (1, 1, 'Checking', 1500.00),
    (2, 1, 'Savings', 3000.00),
    (3, 2, 'Checking', 2000.00),
    (4, 3, 'Savings', 500.00);
RELEASE insert_accounts;

-- Insert sample transactions with explicit IDs so re-runs do not duplicate them
SAVEPOINT insert_transactions;
INSERT OR IGNORE INTO Transactions (transaction_id, account_id, transaction_type, amount) VALUES
    (1, 1, 'Deposit', 500.00),
    (2, 2, 'Withdrawal', 200.00),
    (3, 3, 'Deposit', 1000.00),
    (4, 4, 'Withdrawal', 50.00);
RELEASE insert_transactions;

COMMIT;

-- Re-assert foreign key enforcement (a no-op if it is already on)
PRAGMA foreign_keys = ON;
"""

execute_script(bank_db_connection, seed_script, "seed_bank_db")
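
Seed scripts are often run more than once, so it helps to know when `INSERT OR IGNORE` actually skips a row. The sketch below uses a throwaway in-memory table (`demo_accounts` is a made-up name) to show that a row is ignored only when it would violate a constraint, such as a repeated primary key. Rows inserted without an explicit ID get a fresh key each time and are therefore never ignored.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE demo_accounts ("
    "account_id INTEGER PRIMARY KEY AUTOINCREMENT, balance REAL NOT NULL)"
)


def count_rows() -> int:
    return conn.execute("SELECT COUNT(*) FROM demo_accounts").fetchone()[0]


# Same explicit primary key twice: the second insert is ignored
for _ in range(2):
    conn.execute(
        "INSERT OR IGNORE INTO demo_accounts (account_id, balance) VALUES (1, 1500.0)"
    )
rows_after_explicit = count_rows()

# No primary key given: each insert gets a new ID, so nothing is ignored
for _ in range(2):
    conn.execute("INSERT OR IGNORE INTO demo_accounts (balance) VALUES (1500.0)")
rows_after_implicit = count_rows()

print(rows_after_explicit, rows_after_implicit)
```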

In [None]:
from typing import List, Any
from litellm import completion
import json


class DBoss:
    """
    DBoss enables you to interact with your SQLite database in natural language.
    """

    DEFAULT_LLM_MODEL = "anthropic/claude-3-7-sonnet-latest"
    DEFAULT_TEMPERATURE = 0.5

    def __init__(self, db_file: str):
        """
        Initialize DBoss with the database file.

        Args:
            db_file (str): Database file
        """
        self.db_file = db_file
        self.connection = create_connection(db_file)

    def execute_query(self, query: str) -> List[Any]:
        """
        Execute a query on the database.

        Args:
            query (str): SQL query

        Returns:
            List: Rows returned by the query, or an empty list on error
        """
        try:
            cursor = self.connection.cursor()
            cursor.execute(query)
            rows = cursor.fetchall()
            return rows
        except Exception as e:
            print(f"Error executing query: {e}")
            traceback.print_exc()
            return []

    def _get_full_schema(self) -> List[Any]:
        """
        List the tables in the database.

        Returns:
            List: List of one-element tuples, each holding a table name
        """
        query = "SELECT name FROM sqlite_master WHERE type='table';"
        return self.execute_query(query)

    def get_table_schema(self, table_name: str) -> List[Any]:
        """
        Get the schema of a specific table.

        Args:
            table_name (str): Name of the table

        Returns:
            List: List of tuples containing the schema details
        """
        query = f"PRAGMA table_info({table_name});"
        return self.execute_query(query)

    def get_all_table_schema(self) -> dict:
        """
        Get the schema of all tables in the database.

        Returns:
            dict: Dictionary containing the schema details of all tables
        """
        schema = {}
        tables = self._get_full_schema()
        for table in tables:
            table_name = table[0]
            schema[table_name] = self.get_table_schema(table_name)
        return schema

    def generate_sql_query(self, user_query: str) -> str:
        """
        Generate a SQL query from a user query.

        Args:
            user_query (str): User query

        Returns:
            str: SQL query
        """

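        def strip_backticks_and_sql_block(query: str) -> str:
            # Drop an optional Markdown fence (```sql ... ```) around the reply,
            # leaving any backticks inside the query itself untouched.
            query = query.strip()
            if query.startswith("```"):
                query = query[3:]
                if query.lower().startswith("sql"):
                    query = query[3:]
            if query.endswith("```"):
                query = query[:-3]
            return query.strip()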

        try:
            user_prompt = f"""
            Given the following user query and database schema, generate the corresponding SQL query:
            User Query: {user_query}
            Database Schema: {json.dumps(self.get_all_table_schema())}
            Simply respond with the SQL query that would generate the desired result and nothing else.
            """
            response = completion(
                messages=[{"role": "user", "content": user_prompt}],
                model=self.DEFAULT_LLM_MODEL,
                temperature=self.DEFAULT_TEMPERATURE,
            )
            return strip_backticks_and_sql_block(response.choices[0].message.content)
        except Exception as e:
            print(f"Error generating SQL query: {e}")
            traceback.print_exc()
            return ""

    def run_text_query(self, user_query: str) -> List[Any]:
        """
        Run a text query on the database.

        Args:
            user_query (str): User query

        Returns:
            List: List of tuples containing the query results
        """
        sql_query = self.generate_sql_query(user_query)
        if not sql_query:
            raise Exception("Failed to generate SQL query.")
        return self.execute_query(sql_query)
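
To get a feel for what the model receives as "Database Schema", the sketch below rebuilds the same kind of payload against a throwaway in-memory database. The one-table schema is made up for illustration. Note that `json.dumps` serializes the row tuples as JSON arrays, so column names appear only as values inside each array, not as keys.

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE Customers (customer_id INTEGER PRIMARY KEY, name TEXT NOT NULL)"
)

# Same two queries the class uses: list the tables, then describe each one
tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
schema = {
    name: conn.execute(f"PRAGMA table_info({name});").fetchall()
    for (name,) in tables
}

schema_json = json.dumps(schema)
print(schema_json)
```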

In [None]:
## Simple querying example

bank_db = "data/bank.db"
dboss = DBoss(bank_db)

In [None]:
# 1. Simple query execution

query = "SELECT * FROM Customers;"
result = dboss.execute_query(query=query)
print(query)

for row in result:
    print(row)

In [None]:
# 2. Get the schema of a specific table

table_name = "Customers"

schema = dboss.get_table_schema(table_name)

print(f"\nSchema of table '{table_name}':")

for row in schema:
    print(row)
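
Each row printed above follows the column order of `PRAGMA table_info`: `cid`, `name`, `type`, `notnull`, `dflt_value`, `pk`. As a small sketch, the rows can be labelled with those names to make the output easier to read. The in-memory table here is a simplified stand-in for `Customers`, and `TABLE_INFO_COLUMNS` is a name introduced for this example.

```python
import sqlite3

# Column order of the rows returned by PRAGMA table_info
TABLE_INFO_COLUMNS = ("cid", "name", "type", "notnull", "dflt_value", "pk")

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE Customers ("
    "customer_id INTEGER PRIMARY KEY AUTOINCREMENT, "
    "name TEXT NOT NULL, "
    "address TEXT)"
)

described = [
    dict(zip(TABLE_INFO_COLUMNS, row))
    for row in conn.execute("PRAGMA table_info(Customers);")
]

for column in described:
    print(column)
```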

In [None]:
# 3. Get the schema of all tables

full_schema = dboss.get_all_table_schema()

print("\nFull schema of the database:")

for table_name, schema in full_schema.items():
    print(f"\nSchema of table '{table_name}':")
    for row in schema:
        print(row)

In [None]:
query = "Get all accounts with savings above 100"

sql_query = dboss.generate_sql_query(query)
print(f"\n{query}")
print(f"\n{sql_query}")

In [None]:
bank_db = "data/bank.db"
query = "Get all customer names, amount and account details with Savings above 100"

dboss = DBoss(bank_db)
result = dboss.run_text_query(query)

print(f"\n{query}")
for row in result:
    print(row)