In [1]:
import google.generativeai as genai

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# get api key from https://aistudio.google.com/app/apikey and place in .env
# for more details https://ai.google.dev/gemini-api/docs/api-key
import os
from dotenv import load_dotenv, dotenv_values 
load_dotenv()
if(os.getenv("GOOGLE_API_KEY")==""):
    raise RuntimeError("GOOGLE_API_KEY in .env file not provided.")

In [3]:
import psycopg2

DB_HOST = "db"
DB_PORT = "5432"
DB_NAME = "employees"
DB_USER = "user"
DB_PASSWORD = "password"
db_conn = psycopg2.connect(
            dbname=DB_NAME,
            user=DB_USER,
            password=DB_PASSWORD,
            host=DB_HOST,
            port=DB_PORT
        )

In [4]:
def list_tables() -> list[str]:
    """Retrieve the names of all tables in the database."""
    # Include print logging statements so you can see when functions are being called.
    print(' - DB CALL: list_tables')

    cursor = db_conn.cursor()

    # Fetch the table names.
    cursor.execute("SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname NOT IN ('pg_catalog', 'information_schema') ORDER BY schemaname, tablename;")

    tables = cursor.fetchall()
    return [t[0] for t in tables]


list_tables()

 - DB CALL: list_tables


['department',
 'department_employee',
 'department_manager',
 'employee',
 'salary',
 'title']

In [5]:
def describe_table(table_name: str) -> list[tuple[str, str]]:
    """Look up the table schema.

    Returns:
      List of columns, where each entry is a tuple of (column, type).
    """
    print(' - DB CALL: describe_table')

    cursor = db_conn.cursor()

    cursor.execute(f'SELECT column_name, data_type FROM information_schema.columns WHERE table_name = \'{table_name}\';')

    schema = cursor.fetchall()
    # [column index, column name, column type, ...]
    return [(col[0], col[1]) for col in schema]


describe_table("title")

 - DB CALL: describe_table


[('employee_id', 'bigint'),
 ('title', 'character varying'),
 ('from_date', 'date'),
 ('to_date', 'date')]

In [6]:
def execute_query(sql: str) -> list[list[str]]:
    """Execute a SELECT statement, returning the results."""
    print(' - DB CALL: execute_query')

    cursor = db_conn.cursor()

    cursor.execute(sql)
    return cursor.fetchall()


execute_query("select * from title;")

 - DB CALL: execute_query


[(10001,
  'Senior Engineer',
  datetime.date(1986, 6, 26),
  datetime.date(9999, 1, 1)),
 (10002, 'Staff', datetime.date(1996, 8, 3), datetime.date(9999, 1, 1)),
 (10003,
  'Senior Engineer',
  datetime.date(1995, 12, 3),
  datetime.date(9999, 1, 1)),
 (10004, 'Engineer', datetime.date(1986, 12, 1), datetime.date(1995, 12, 1)),
 (10004,
  'Senior Engineer',
  datetime.date(1995, 12, 1),
  datetime.date(9999, 1, 1)),
 (10005,
  'Senior Staff',
  datetime.date(1996, 9, 12),
  datetime.date(9999, 1, 1)),
 (10005, 'Staff', datetime.date(1989, 9, 12), datetime.date(1996, 9, 12)),
 (10006,
  'Senior Engineer',
  datetime.date(1990, 8, 5),
  datetime.date(9999, 1, 1)),
 (10007,
  'Senior Staff',
  datetime.date(1996, 2, 11),
  datetime.date(9999, 1, 1)),
 (10007, 'Staff', datetime.date(1989, 2, 10), datetime.date(1996, 2, 11)),
 (10008,
  'Assistant Engineer',
  datetime.date(1998, 3, 11),
  datetime.date(2000, 7, 31)),
 (10009,
  'Assistant Engineer',
  datetime.date(1985, 2, 18),
  datetim

In [7]:
# These are the Python functions defined above.
db_tools = [list_tables, describe_table, execute_query]

instruction = """You are a helpful chatbot that can interact with an PostgreSQL database.
You will take the users questions and turn them into postgresql fomrat SQL queries using the tools
available. Once you have the information you need, you will answer the user's question using
the data returned. Use list_tables to see what tables are present, describe_table to understand
the schema, and execute_query to issue an SQL SELECT query.
Note: No need of backslash in LIKE Eg: use 'A%' and do not use \'A%\' """

model = genai.GenerativeModel(
    "models/gemini-1.5-flash-latest", tools=db_tools, system_instruction=instruction
)

# Define a retry policy. The model might make multiple consecutive calls automatically
# for a complex query, this ensures the client retries if it hits quota limits.
from google.api_core import retry

retry_policy = {"retry": retry.Retry(predicate=retry.if_transient_error)}

# Start a chat with automatic function calling enabled.
chat = model.start_chat(enable_automatic_function_calling=True)
resp = chat.send_message("What does this database have?", request_options=retry_policy)
print(resp.text)

 - DB CALL: list_tables
This database contains the following tables: department, department_employee, department_manager, employee, salary, and title.


In [8]:
response = chat.send_message("What emplyees details does db have?", request_options=retry_policy)
print(response.text)

 - DB CALL: describe_table
The `employee` table has the following columns:

*   id (bigint)
*   birth\_date (date)
*   first\_name (character varying)
*   last\_name (character varying)
*   gender (USER-DEFINED)
*   hire\_date (date)




In [9]:
response = chat.send_message("Do we have emplyee address?", request_options=retry_policy)
print(response.text)

Based on the schema of the `employee` table, there is no address information available for employees in this database.



In [10]:
chat = model.start_chat(enable_automatic_function_calling=True)
response = chat.send_message("workers count in each department", request_options=retry_policy)
print(response.text)

 - DB CALL: list_tables
 - DB CALL: describe_table
 - DB CALL: describe_table
 - DB CALL: describe_table
 - DB CALL: execute_query
The number of workers in each department is as follows:

* Customer Service: 23580
* Development: 85707
* Finance: 17346
* Human Resources: 17786
* Marketing: 20211
* Production: 73485
* Quality Management: 20117
* Research: 21126
* Sales: 52245


In [11]:
import textwrap


def print_chat_turns(chat):
    """Prints out each turn in the chat history, including function calls and responses."""
    for event in chat.history:
        print(f"{event.role.capitalize()}:")

        for part in event.parts:
            if txt := part.text:
                print(f'  "{txt}"')
            elif fn := part.function_call:
                args = ", ".join(f"{key}={val}" for key, val in fn.args.items())
                print(f"  Function call: {fn.name}({args})")
            elif resp := part.function_response:
                print("  Function response:")
                print(textwrap.indent(str(resp), "    "))

        print()


print_chat_turns(chat)

User:
  "workers count in each department"

Model:
  Function call: list_tables()

User:
  Function response:
    name: "list_tables"
    response {
      fields {
        key: "result"
        value {
          list_value {
            values {
              string_value: "department"
            }
            values {
              string_value: "department_employee"
            }
            values {
              string_value: "department_manager"
            }
            values {
              string_value: "employee"
            }
            values {
              string_value: "salary"
            }
            values {
              string_value: "title"
            }
          }
        }
      }
    }


Model:
  Function call: describe_table(table_name=department)

User:
  Function response:
    name: "describe_table"
    response {
      fields {
        key: "result"
        value {
          list_value {
            values {
              list_value {
                val