In [None]:
import google.generativeai as genai

In [2]:
# Get an API key from https://aistudio.google.com/app/apikey and put it in .env as
# GOOGLE_API_KEY=<key>. For more details see https://ai.google.dev/gemini-api/docs/api-key
import os
from dotenv import load_dotenv, dotenv_values
# Copy the variables defined in .env into the process environment.
load_dotenv()
# os.getenv returns None when the variable is absent, so comparing against "" only
# caught an explicitly empty value and let a missing key slip through. Treat a
# missing, empty, or whitespace-only value as "not provided".
if not (os.getenv("GOOGLE_API_KEY") or "").strip():
    raise RuntimeError("GOOGLE_API_KEY in .env file not provided.")

In [3]:
import os

import psycopg2

# Connection settings. Each one can be overridden through the environment (for
# example via the .env file loaded by load_dotenv() earlier), so credentials do not
# have to live in the notebook. The fallbacks are the values this notebook was
# written against; DB_PASSWORD's fallback is only suitable for a local dev database.
DB_HOST = os.getenv("DB_HOST", "db")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME", "sampledb")
DB_USER = os.getenv("DB_USER", "user")
DB_PASSWORD = os.getenv("DB_PASSWORD", "password")

# Single shared connection used by the tool functions defined below.
db_conn = psycopg2.connect(
    dbname=DB_NAME,
    user=DB_USER,
    password=DB_PASSWORD,
    host=DB_HOST,
    port=DB_PORT,
)

In [None]:
def list_tables() -> list[str]:
    """Retrieve the names of all tables in the database.

    Tables in the system schemas (pg_catalog, information_schema) are excluded.
    Only the bare table name is returned, not the schema it belongs to.

    Returns:
      Table names, ordered by schema and then by table name.
    """
    # Include print logging statements so you can see when functions are being called.
    print(' - DB CALL: list_tables')

    # The cursor is used as a context manager so it is closed as soon as the rows
    # have been read, instead of accumulating one open cursor per call.
    with db_conn.cursor() as cursor:
        cursor.execute(
            "SELECT tablename FROM pg_catalog.pg_tables "
            "WHERE schemaname NOT IN ('pg_catalog', 'information_schema') "
            "ORDER BY schemaname, tablename;"
        )
        return [row[0] for row in cursor.fetchall()]


list_tables()

In [None]:
def describe_table(table_name: str) -> list[tuple[str, str]]:
    """Look up the table schema.

    Args:
      table_name: Exact, case-sensitive table name, e.g. as returned by list_tables.

    Returns:
      List of columns, where each entry is a tuple of (column, type).
    """
    print(' - DB CALL: describe_table')

    # table_name is supplied by the model through automatic function calling, so it
    # is passed as a bound query parameter rather than interpolated into the SQL
    # text. Interpolation allowed SQL injection and broke on names containing a quote.
    # ORDER BY ordinal_position returns the columns in their declared order.
    with db_conn.cursor() as cursor:
        cursor.execute(
            "SELECT column_name, data_type FROM information_schema.columns "
            "WHERE table_name = %s ORDER BY ordinal_position;",
            (table_name,),
        )
        # Each row is (column_name, data_type).
        return [(col[0], col[1]) for col in cursor.fetchall()]


describe_table("Customer")

In [None]:
def execute_query(sql: str) -> list[list[str]]:
    """Execute a SELECT statement, returning the results.

    Args:
      sql: A single PostgreSQL query, typically written by the model.

    Returns:
      All result rows as returned by the cursor's fetchall().
    """
    print(' - DB CALL: execute_query')

    # The SQL text is model-generated and therefore untrusted. Nothing in this
    # notebook enables autocommit, so the statement runs inside a transaction.
    # Rolling that transaction back afterwards does two things:
    #   * it discards any writes a non-SELECT statement may have made;
    #   * it clears the aborted-transaction state that a failed query leaves on
    #     db_conn, which would otherwise make every later tool call fail.
    # The rows are fetched before the rollback, so the returned data is unaffected.
    with db_conn.cursor() as cursor:
        try:
            cursor.execute(sql)
            return cursor.fetchall()
        finally:
            db_conn.rollback()


execute_query("select * from \"Customer\";")

In [None]:
# Retry policy: with automatic function calling the model may issue several
# consecutive requests for one question, so transient errors (e.g. quota limits)
# are retried instead of aborting the whole exchange.
from google.api_core import retry

# These are the Python functions defined above.
db_tools = [list_tables, describe_table, execute_query]

# The schema uses mixed-case table names such as "Customer". PostgreSQL folds
# unquoted identifiers to lower case, so the prompt tells the model to double-quote them.
instruction = """You are a helpful chatbot that can interact with a PostgreSQL 17 database.
You will take the user's questions and turn them into PostgreSQL-format SQL queries using the tools
available. Once you have the information you need, you will answer the user's question using
the data returned. Use list_tables to see what tables are present, describe_table to understand
the schema, and execute_query to issue an SQL SELECT query. Note: query should follow PostgreSQL syntax.
Always wrap table and column names in double quotes exactly as returned by the tools
(e.g. "Customer"), because unquoted identifiers are folded to lower case."""

model = genai.GenerativeModel(
    "models/gemini-1.5-flash-latest", tools=db_tools, system_instruction=instruction
)

retry_policy = {"retry": retry.Retry(predicate=retry.if_transient_error)}

# Start a chat with automatic function calling enabled.
chat = model.start_chat(enable_automatic_function_calling=True)
resp = chat.send_message("list of customer name starting with V?", request_options=retry_policy)
print(resp.text)

In [None]:
# Repeat the setup on the Pro model. It reuses db_tools, instruction and
# retry_policy from the previous cell and replaces the global `model` and `chat`.
model = genai.GenerativeModel(
    model_name="models/gemini-1.5-pro-latest",
    tools=db_tools,
    system_instruction=instruction,
)

chat = model.start_chat(enable_automatic_function_calling=True)
response = chat.send_message(
    'Which salesperson sold the cheapest product?',
    request_options=retry_policy,
)
print(response.text)

In [None]:
import textwrap


def print_chat_turns(chat):
    """Write the whole ``chat.history`` transcript to stdout.

    Every turn starts with its capitalised role followed by a colon. Each part of
    the turn is printed as one of the following, checked in this order:
      * quoted text;
      * a function call rendered as ``name(key=value, ...)``;
      * a function response, indented by four spaces.
    Parts with none of these set are skipped. A blank line ends each turn.
    """
    for turn in chat.history:
        print(f"{turn.role.capitalize()}:")

        for piece in turn.parts:
            txt = piece.text
            if txt:
                print(f'  "{txt}"')
                continue

            fn = piece.function_call
            if fn:
                args = ", ".join(f"{key}={val}" for key, val in fn.args.items())
                print(f"  Function call: {fn.name}({args})")
                continue

            reply = piece.function_response
            if reply:
                print("  Function response:")
                print(textwrap.indent(str(reply), "    "))

        print()


# Show the transcript of the most recently created chat (the Pro-model conversation
# above), including each tool call the model made and the response it received.
print_chat_turns(chat)