# Setup and AlloyDB Parameterized Secure Views

Parameterized Secure Views (PSV) in AlloyDB allow you to define tables and columns that natural-language queries can pull data from, and add additional restrictions on the range of rows available to an individual application user. These restrictions let you tightly control the data that your application users can view through natural-language queries, no matter how your users phrase these queries. PSVs are designed to protect against prompt injection attacks and help ensure that end users can view only the data that they are supposed to access.

## Basic Setup

### Define Notebook Variables

Update the variables below to match your environment. You will be prompted for the AlloyDB password you chose when you provisioned the environment with Terraform.

In [None]:
import getpass

# Project variables
project_id = "your-project"
region = "your-region"
vpc = "demo-vpc"
gcs_bucket_name = f"project-files-{project_id}"

# AlloyDB variables
alloydb_cluster = "my-alloydb-cluster"
alloydb_instance = "my-alloydb-instance"
alloydb_database = "finance"
# getpass hides what is typed, so the password is not echoed into the notebook
# output the way input() would display it.
alloydb_password = getpass.getpass("Please enter the password for the AlloyDB 'postgres' database user: ")

In [None]:
# Set GRPC_ENABLE_FORK_SUPPORT=1 for this kernel to suppress noisy warnings
# (presumably from gRPC) that otherwise appear when running shell (`!`) commands.
%env GRPC_ENABLE_FORK_SUPPORT=1

### Connect to your Google Cloud Project

In [None]:
# Configure gcloud: make `project_id` the default project for subsequent gcloud commands.
!gcloud config set project {project_id}

### Configure Logging

In [None]:
import logging
import sys

# Route root-logger records at INFO and above to stdout so they appear in the
# cell output. force=True first removes any handlers already attached to the
# root logger (e.g. from re-running this cell).
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s[%(levelname)5s][%(name)14s] - %(message)s',
    datefmt='%H:%M:%S',
    force=True,
)

### Install Dependencies

In [None]:
# Install pinned versions of the Cloud Storage client, the asyncpg driver and the
# AlloyDB Python connector.
# NOTE(review): sqlalchemy, pandas, requests and google-auth are imported by later
# cells but not installed here; this assumes the notebook runtime already provides them.
! pip install --quiet google-cloud-storage==2.19.0 \
                      asyncpg==0.30.0 \
                      google.cloud.alloydb.connector==1.9.0 

### Define Helper Functions

#### REST API Helper Function

In [None]:
import requests
import google.auth
import json

# `import google.auth` alone does not load the transport.requests submodule,
# which provides both Request and AuthorizedSession used below.
import google.auth.transport.requests

# Get Application Default Credentials for the current user/environment.
creds, _ = google.auth.default()
# Credentials from google.auth.default() carry no token until refreshed; refresh
# once so that access_token holds a usable bearer token instead of None.
creds.refresh(google.auth.transport.requests.Request())
authed_session = google.auth.transport.requests.AuthorizedSession(creds)
access_token = creds.token

if project_id:
  # Attribute quota to this project (required to work around a project quota bug).
  authed_session.headers.update({"x-goog-user-project": project_id})

def rest_api_helper(
    url: str,
    http_verb: str,
    request_body: dict = None,
    params: dict = None,
    session: requests.Session = authed_session,
  ) -> dict:
  """Call a REST API with a pre-authenticated requests Session and return the JSON body.

  Args:
    url: Fully qualified URL of the endpoint.
    http_verb: "GET", "POST", "PUT", "PATCH" or "DELETE" (case-insensitive).
    request_body: Sent as the JSON body for POST/PUT/PATCH; ignored for GET/DELETE.
    params: Optional query-string parameters.
    session: Session used to send the request; defaults to `authed_session`.

  Returns:
    The decoded JSON response, or {} when the response has no body
    (e.g. 204 No Content).

  Raises:
    ValueError: If `http_verb` is not one of the supported verbs.
    RuntimeError: If the request fails, the server returns a 4xx/5xx status,
      or the response body is not valid JSON.
  """
  headers = {"Content-Type": "application/json"}
  verb = http_verb.upper()

  try:
    if verb == "GET":
      response = session.get(url, headers=headers, params=params)
    elif verb == "POST":
      response = session.post(url, json=request_body, headers=headers, params=params)
    elif verb == "PUT":
      response = session.put(url, json=request_body, headers=headers, params=params)
    elif verb == "PATCH":
      response = session.patch(url, json=request_body, headers=headers, params=params)
    elif verb == "DELETE":
      response = session.delete(url, headers=headers, params=params)
    else:
      raise ValueError(f"Unknown HTTP verb: {http_verb}")

    # Raise an exception for bad status codes (4xx or 5xx)
    response.raise_for_status()

  except requests.exceptions.RequestException as e:
    # Network/timeout errors and HTTP error statuses end up here.
    print(f"Request failed: {e}")
    if e.response is not None:
      if e.request is not None:
        # Never echo credentials: the authorized session attaches an
        # Authorization header, and notebook output is often saved or shared.
        safe_headers = {
            name: "<redacted>" if name.lower() in ("authorization", "proxy-authorization") else value
            for name, value in e.request.headers.items()
        }
        print(f"Request URL: {e.request.url}")
        print(f"Request Headers: {safe_headers}")
        print(f"Request Body: {e.request.body}")
      print(f"Response Status: {e.response.status_code}")
      print(f"Response Text: {e.response.text}")
      raise RuntimeError(f"API call failed with status {e.response.status_code}: {e.response.text}") from e
    raise RuntimeError(f"API call failed: {e}") from e

  # Empty bodies (e.g. 204 No Content) have nothing to decode.
  if not response.content:
    return {}

  try:
    return response.json()
  except ValueError as e:
    # Decoding happens outside the RequestException handler on purpose: newer
    # `requests` raises its own JSONDecodeError, which subclasses both
    # ValueError and RequestException and would otherwise be misreported as a
    # generic request failure.
    print(f"Failed to decode JSON response: {e}")
    print(f"Response Text: {response.text}")
    raise RuntimeError(f"Invalid JSON received from API: {response.text}") from e


#### AlloyDB Helper Function

In [None]:
# Create AlloyDB Query Helper Function
import sqlalchemy
from sqlalchemy import text, exc
import pandas as pd

def _leading_sql_keyword(sql: str) -> str:
    """Return the first keyword of `sql`, lower-cased, skipping leading comments and parentheses.

    Returns "" when `sql` contains no keyword (empty or comment-only text).
    """
    import re

    rest = sql
    while True:
        rest = rest.lstrip()
        if rest.startswith("--"):
            # Line comment: resume after the next newline (or stop at end of text).
            rest = rest.partition("\n")[2]
        elif rest.startswith("/*"):
            # Block comment: resume after the closing "*/" (or stop if unterminated).
            rest = rest[2:].partition("*/")[2]
        elif rest.startswith("("):
            # Parenthesised query, e.g. "(SELECT ...) UNION (SELECT ...)".
            rest = rest[1:]
        else:
            break
    match = re.match(r"[A-Za-z_]+", rest)
    return match.group(0).lower() if match else ""


async def run_alloydb_query(pool, sql: str, params = None, output_as_df: bool = True):
    """Execute a SQL query or statement against an async SQLAlchemy connection pool.

    The statement type is taken from its first keyword. Leading whitespace,
    `--` / `/* */` comments and opening parentheses are skipped, so a query
    such as "-- top holdings\\nSELECT ..." is still handled as a SELECT.

    - SELECT/WITH: returns a pandas DataFrame (if output_as_df=True) or the raw
      SQLAlchemy Result. Supports parameters. Does not commit.
      Note: a data-modifying CTE (WITH ... INSERT/UPDATE) is classified here
      too and is therefore not explicitly committed.
    - EXPLAIN/EXPLAIN ANALYZE: returns the query plan as a multi-line string
      (first column of each plan row). Ignores output_as_df. Does not commit.
    - Anything else (INSERT/UPDATE/DELETE/CREATE/ALTER, ...): commits, prints a
      short summary, and returns the Result. Supports single or bulk
      (executemany) parameters.

    Args:
      pool: An asynchronous SQLAlchemy engine/pool exposing `connect()`.
      sql: The SQL query or statement template (named params like `:name`).
      params: Optional.
        - None: Execute raw SQL (use with caution for non-SELECT/EXPLAIN).
        - dict or tuple: Parameters for a single execution.
        - list of dicts/tuples: Parameters for bulk execution (executemany).
      output_as_df: If True and the query is SELECT/WITH, return a DataFrame.
        Ignored for EXPLAIN and non-data-returning statements.

    Returns:
      pandas.DataFrame | str | sqlalchemy.engine.Result | None:
        - DataFrame: for SELECT/WITH if output_as_df=True (falls back to the
          raw Result if DataFrame conversion fails).
        - str: for EXPLAIN, the formatted query plan.
        - Result: for non-SELECT/WITH/EXPLAIN statements, or SELECT/WITH if
          output_as_df=False.
        - None: if a `sqlalchemy.exc.ProgrammingError` occurs, or if fetching
          an EXPLAIN plan fails.

    Raises:
      Any exception other than `sqlalchemy.exc.ProgrammingError` raised while
      executing the statement is logged and re-raised.

    Example Execution:
      Single SELECT:
        sql_select = "SELECT ticker, company_name from investments LIMIT 5"
        df_result = await run_alloydb_query(pool, sql_select)

      Parameterized INSERT:
        sql_insert = "INSERT INTO investments (ticker, company_name) VALUES (:ticker, :name)"
        params_insert = {"ticker": "NEW", "name": "New Company"}
        insert_result = await run_alloydb_query(pool, sql_insert, params_insert)

      Parameterized UPDATE:
        sql_update = "UPDATE products SET price = :price WHERE id = :product_id"
        params_update = {"price": 99.99, "product_id": 123}
        update_result = await run_alloydb_query(pool, sql_update, params_update)

      Bulk UPDATE (list of dicts -> executemany):
        update_sql_template = '''
            UPDATE products
            SET sparse_embedding = :embedding,
                sparse_embedding_model = 'BM25'
            WHERE id = :product_id
        '''
        data_to_update = [
            {"embedding": row.sparse_embedding, "product_id": row.id}
            for row in docs.itertuples(index=False)
        ]
        if data_to_update:
            bulk_result = await run_alloydb_query(pool, update_sql_template, data_to_update)
    """
    leading_keyword = _leading_sql_keyword(sql)
    is_select_with = leading_keyword in ("select", "with")
    is_explain = leading_keyword == "explain"

    # DataFrame output only applies to row-returning SELECT/WITH queries.
    effective_output_as_df = output_as_df and is_select_with

    # A non-empty list/tuple whose first element is itself a parameter set means
    # executemany; this only affects the summary message printed below.
    is_bulk_operation = (
        isinstance(params, (list, tuple))
        and len(params) > 0
        and isinstance(params[0], (dict, tuple, list))
    )

    async with pool.connect() as conn:
        try:
            if params:
                result = await conn.execute(text(sql), params)
            else:
                result = await conn.execute(text(sql))

            # --- EXPLAIN: return the plan as text ---
            if is_explain:
                try:
                    plan_rows = result.fetchall()
                    # The plan text is in the first column of each row.
                    return "\n".join(str(row[0]) for row in plan_rows)
                except Exception as e:
                    logging.error(f"Error fetching/formatting EXPLAIN result: {e}")
                    return None

            # --- SELECT / WITH: return a DataFrame or the raw result ---
            if is_select_with:
                if not effective_output_as_df:
                    return result
                try:
                    rows = result.fetchall()
                    return pd.DataFrame(rows, columns=result.keys())
                except Exception as e:
                    logging.error(f"Error converting SELECT result to DataFrame: {e}")
                    logging.info(f"Returning raw ResultProxy for SELECT/WITH due to DataFrame conversion error for: {sql[:100]}...")
                    return result  # Fallback to raw result

            # --- Everything else (INSERT, UPDATE, DELETE, CREATE, ...) ---
            await conn.commit()  # Commit only for non-data-returning statements
            operation_type = leading_keyword.upper() or "SQL"
            row_count = result.rowcount  # Note: rowcount semantics vary by statement type

            if is_bulk_operation:
                print(f"Bulk {operation_type} executed for {len(params)} items. Result rowcount: {row_count}")
            elif operation_type in ("INSERT", "UPDATE", "DELETE"):
                print(f"{operation_type} statement executed successfully. {row_count} row(s) affected.")
            else:  # CREATE, ALTER, etc.
                print(f"{operation_type} statement executed successfully. Result rowcount: {row_count}")
            return result

        except exc.ProgrammingError as e:
            # Deliberately handled: log with context and signal failure via None.
            logging.error(f"SQL Programming Error executing query:\nSQL: {sql[:500]}...\nParams (sample): {str(params)[:500]}...\nError: {e}")
            return None
        except Exception as e:
            logging.error(f"An unexpected error occurred during query execution:\nSQL: {sql[:500]}...\nError: {e}")
            raise



## Setup Parameterized Secure Views

### Connect to the AlloyDB Cluster

This function will create a connection pool to your AlloyDB instance using the AlloyDB Python connector. The AlloyDB Python connector will automatically create secure connections to your AlloyDB instance using mTLS.

In [None]:
import asyncpg

import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from google.cloud.alloydb.connector import AsyncConnector, IPTypes

async def init_connection_pool(connector: AsyncConnector, db_name: str = alloydb_database, pool_size: int = 5) -> AsyncEngine:
    """Create an async SQLAlchemy engine whose connections are opened via the AlloyDB connector.

    Args:
      connector: An existing AlloyDB `AsyncConnector`; it is reused for every
        pooled connection, not created here.
      db_name: Name of the database to connect to.
      pool_size: Number of pooled connections (overflow is disabled).

    Returns:
      An `AsyncEngine` using the asyncpg driver with AUTOCOMMIT isolation,
      connecting as the `postgres` user over the instance's private IP.
    """
    instance_uri = (
        f"projects/{project_id}/locations/{region}"
        f"/clusters/{alloydb_cluster}/instances/{alloydb_instance}"
    )

    async def getconn() -> asyncpg.Connection:
        # Called by the pool whenever it needs a new connection.
        return await connector.connect(
            instance_uri,
            "asyncpg",
            user="postgres",
            password=alloydb_password,
            db=db_name,
            ip_type=IPTypes.PRIVATE,  # Optionally use IPTypes.PUBLIC
        )

    return create_async_engine(
        "postgresql+asyncpg://",
        async_creator=getconn,
        pool_size=pool_size,
        max_overflow=0,
        isolation_level="AUTOCOMMIT",
    )

connector = AsyncConnector()

finance_db_pool = await init_connection_pool(connector, f"{alloydb_database}")

### Create `parameterized_views` Extension

In [None]:
sql = "CREATE EXTENSION IF NOT EXISTS parameterized_views;"
await run_alloydb_query(finance_db_pool, sql)

### Create a Parameterized Secure View

In [None]:
sql = "SELECT extversion FROM pg_extension WHERE extname = 'alloydb_ai_nl';"
await run_alloydb_query(postgres_db_pool, sql)