# Multimodal Hybrid Product Search on AlloyDB - With Data Prep

This notebook provides a step-by-step example of implementing Hybrid Search in [AlloyDB for PostgreSQL](https://cloud.google.com/products/alloydb?e=48754805&hl=en) for Cymbal Shops, a fictional retailer with a large eCommerce presence. It combines multimodal vector embeddings ([`multimodalembedding@001`](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-multimodal-embeddings)), full-text search ([Generalized Inverted Index](https://www.postgresql.org/docs/current/gin.html)), and [BM25](https://en.wikipedia.org/wiki/Okapi_BM25) sparse embeddings ([pgvector 0.7.0+](https://github.com/pgvector/pgvector?tab=readme-ov-file#sparse-vectors)) with [Reciprocal Rank Fusion](https://medium.com/@devalshah1619/mathematical-intuition-behind-reciprocal-rank-fusion-rrf-explained-in-2-mins-002df0cc5e2a) re-ranking for enhanced product search.
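As a rough illustration of the re-ranking step, here is a minimal Python sketch of Reciprocal Rank Fusion. It assumes each search technique returns an ordered list of product IDs (best first) and uses the commonly cited constant `k = 60`; the function name `reciprocal_rank_fusion` and the sample IDs are hypothetical, not part of any library used in this notebook.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence, Tuple


def reciprocal_rank_fusion(
    ranked_lists: Sequence[Sequence[Hashable]], k: int = 60
) -> List[Tuple[Hashable, float]]:
    """Fuse ranked result lists with RRF: score(d) = sum over lists of 1 / (k + rank(d))."""
    scores: Dict[Hashable, float] = defaultdict(float)
    for ranking in ranked_lists:
        # Ranks are 1-based: the top result in each list contributes 1 / (k + 1).
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


# Hypothetical product IDs returned by three different search techniques
dense_results = ["p1", "p2", "p3"]
fts_results = ["p2", "p1", "p4"]
bm25_results = ["p2", "p3", "p1"]
fused = reciprocal_rank_fusion([dense_results, fts_results, bm25_results])
print([doc_id for doc_id, _ in fused])  # ['p2', 'p1', 'p3', 'p4']
```

Because RRF only looks at ranks, it sidesteps the fact that cosine distances, full-text rank scores, and BM25 scores live on different scales.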

> **IMPORTANT:** This notebook leverages Preview features in AlloyDB AI. Create your AlloyDB cluster (see below), then **[Sign up for the preview](https://docs.google.com/forms/d/e/1FAIpQLSfJ9vHIJ79nI7JWBDELPFL75pDQa4XVZQ2fxShfYddW0RwmLw/viewform)** to take full advantage of the features in this notebook. Use `ecom` as the database name in your request (you will create this database later).


The high-level flow is as follows:
- Imports a sample retail dataset (based on [theLook eCommerce dataset](https://console.cloud.google.com/marketplace/product/bigquery-public-data/thelook-ecommerce)) into an AlloyDB cluster.
- Asynchronously generates product descriptions for 29,120 products using the Gemini 2.0 Flash model.
- Asynchronously generates product images for 29,120 products using the Imagen 3 model.
- Asynchronously generates multimodal embeddings for the product images and product descriptions using the `multimodalembedding@001` model.
- Generates sparse embeddings for products using BM25.
- Creates [ScaNN vector indexes](https://cloud.google.com/blog/products/databases/understanding-the-scann-index-in-alloydb) for fast dense vector embedding queries.
- Creates an [HNSW](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw) index for efficient sparse vector embedding queries.
- Creates a [GIN index](https://www.postgresql.org/docs/current/gin.html) for fast full-text search.
- Demonstrates a variety of search techniques (Traditional SQL, Dense Vector Search, Full-text Search, BM25 Sparse Vector Search, and Multimodal Search).
- Demonstrates combining techniques with blended hybrid search (Vector + FTS + BM25 + Traditional SQL).
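pgvector's `sparsevec` type accepts a text literal of the form `{index:value,...}/dimensions`, with 1-based indices. The sketch below assumes you already have a BM25 sparse embedding as a mapping of 0-based vocabulary index to weight (for example, pulled out of the BM25 encoder's sparse output); the helper name `to_sparsevec_literal` is hypothetical.

```python
from typing import Mapping


def to_sparsevec_literal(weights: Mapping[int, float], dim: int) -> str:
    """Format {0-based index: weight} as a pgvector sparsevec text literal."""
    parts = []
    for idx in sorted(weights):
        if not 0 <= idx < dim:
            raise ValueError(f"index {idx} is out of range for dimension {dim}")
        weight = float(weights[idx])
        if weight != 0.0:
            # sparsevec text format uses 1-based indices, like SQL arrays
            parts.append(f"{idx + 1}:{weight}")
    return "{" + ",".join(parts) + "}/" + str(dim)


print(to_sparsevec_literal({2: 2.5, 0: 1.5}, dim=5))  # {1:1.5,3:2.5}/5
```

The resulting string can be bound as a query parameter and cast with `CAST(:embedding AS sparsevec)` inside an `UPDATE`.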

To Do:
- Fix multimodal text embeddings (image embeddings work, but multimodal text embeddings are inconsistent).
- Add Search by Video.
- Add Semantic Ranking

## Basic Setup

You will need an AlloyDB for PostgreSQL instance to use this notebook. Create one now if you have not already created it.

### Define Variables

In [None]:
# Update these variables to match your environment
project_id = "your-project"  # @param {type:"string"}
region = "your-region"  # @param {type:"string"}
vpc = "your-vpc"  # @param {type:"string"}
image_bucket = "your-bucket"  # @param {type:"string"}
index_bucket = "your-bucket"  # @param {type:"string"}
export_bucket = "your-bucket"  # @param {type:"string"}
alloydb_cluster = "your-alloydb-cluster"  # @param {type:"string"}
alloydb_instance = "your-alloydb-instance"  # @param {type:"string"}
from getpass import getpass
alloydb_password = getpass("Please provide a password to be used for the 'postgres' database user: ")

# Don't change values below this line.
alloydb_database = "ecom" 
database_backup_uri = "gs://pr-public-demo-data/alloydb-retail-demo/data/ecom.sql"


### Install Dependencies

In [None]:
! pip install --quiet google-cloud-storage==2.19.0 \
                      google-cloud-aiplatform==1.74.0 \
                      pymilvus.model==0.3.2 \
                      asyncpg==0.30.0 \
                      google.cloud.alloydb.connector==1.9.0 \
                      jupyter-server==1.24.0 \
                      google-genai==1.4.0


### Connect Your Google Cloud Project

In [None]:
# Configure gcloud.
!gcloud config set project {project_id}

### Configure Logging

In [None]:
import logging
import sys

# Configure the root logger to output messages with INFO level or above
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s[%(levelname)5s][%(name)14s] - %(message)s',  datefmt='%H:%M:%S', force=True)

### Enable APIs for AlloyDB, Vertex AI, and Discovery Engine

Enable these APIs to create an AlloyDB cluster, call Vertex AI models for embeddings and content generation, and use Discovery Engine for semantic ranking.

In [None]:
!gcloud services enable alloydb.googleapis.com aiplatform.googleapis.com discoveryengine.googleapis.com

### Initialize GenAI Client

In [None]:
from google import genai
from google.genai import types

genai_client = genai.Client(
    vertexai=True, project=project_id, location=region
)

## Define Helper Functions

### rest_api_helper()

In [None]:
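import json

import google.auth
import google.auth.transport.requests
import requests

# Get credentials for the current user, scoped for Google Cloud APIs
creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
authed_session = google.auth.transport.requests.AuthorizedSession(creds)

# creds.token stays None until the credentials are refreshed, so refresh once before reading it
creds.refresh(google.auth.transport.requests.Request())
access_token = creds.token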

if project_id:
  authed_session.headers.update({"x-goog-user-project": project_id}) # Required to workaround a project quota bug

def rest_api_helper(
    session: requests.Session,
    url: str,
    http_verb: str,
    request_body: dict = None,
    params: dict = None
  ) -> dict:
  """Calls a REST API using a pre-authenticated requests Session."""

  headers = {"Content-Type": "application/json"}

  try:

    if http_verb == "GET":
      response = session.get(url, headers=headers, params=params)
    elif http_verb == "POST":
      response = session.post(url, json=request_body, headers=headers, params=params)
    elif http_verb == "PUT":
      response = session.put(url, json=request_body, headers=headers, params=params)
    elif http_verb == "PATCH":
      response = session.patch(url, json=request_body, headers=headers, params=params)
    elif http_verb == "DELETE":
      response = session.delete(url, headers=headers, params=params)
    else:
      raise ValueError(f"Unknown HTTP verb: {http_verb}")

    # Raise an exception for bad status codes (4xx or 5xx)
    response.raise_for_status()

    # Check if response has content before trying to parse JSON
    if response.content:
        return response.json()
    else:
        return {} # Return empty dict for empty responses (like 204 No Content)

  except requests.exceptions.RequestException as e:
      # Catch potential requests library errors (network, timeout, etc.)
      # Log detailed error information
      print(f"Request failed: {e}")
      if e.response is not None:
          print(f"Request URL: {e.request.url}")
          print(f"Request Headers: {e.request.headers}")
          print(f"Request Body: {e.request.body}")
          print(f"Response Status: {e.response.status_code}")
          print(f"Response Text: {e.response.text}")
          # Re-raise a more specific error or a custom one
          raise RuntimeError(f"API call failed with status {e.response.status_code}: {e.response.text}") from e
      else:
          raise RuntimeError(f"API call failed: {e}") from e
  except json.JSONDecodeError as e:
      print(f"Failed to decode JSON response: {e}")
      print(f"Response Text: {response.text}")
      raise RuntimeError(f"Invalid JSON received from API: {response.text}") from e



### run_query()

In [None]:
# Create AlloyDB Query Helper Function
import sqlalchemy
from sqlalchemy import text, exc
import pandas as pd

async def run_query(pool, sql: str, params = None, output_as_df: bool = True):
    """Executes a SQL query or statement against the database pool.

    Handles various SQL statements:
    - SELECT/WITH: Returns results as a DataFrame (if output_as_df=True)
      or ResultProxy. Supports parameters. Does not commit.
    - EXPLAIN/EXPLAIN ANALYZE: Executes the explain, returns the query plan
      as a formatted multi-line string. Ignores output_as_df.
      Supports parameters. Does not commit.
    - INSERT/UPDATE/DELETE/CREATE/ALTER etc.: Executes the statement,
      commits the transaction, logs info, and returns the ResultProxy.
      Supports single or bulk parameters (executemany).

    Args:
      pool: An asynchronous SQLAlchemy connection pool.
      sql: A string containing the SQL query or statement template.
      params: Optional.
        - None: Execute raw SQL (Use with caution for non-SELECT/EXPLAIN).
        - dict or tuple: Parameters for a single execution.
        - list of dicts/tuples: Parameters for bulk execution (executemany).
      output_as_df (bool): If True and query is SELECT/WITH, return pandas DataFrame.
                           Ignored for EXPLAIN and non-data-returning statements.

    Returns:
      pandas.DataFrame | str | sqlalchemy.engine.Result | None:
        - DataFrame: For SELECT/WITH if output_as_df=True.
        - str: For EXPLAIN/EXPLAIN ANALYZE, containing the formatted query plan.
        - ResultProxy: For non-SELECT/WITH/EXPLAIN statements, or SELECT/WITH
                       if output_as_df=False.
        - None: If a SQLAlchemy ProgrammingError or other specific error occurs.

    Raises:
        Exception: Catches and logs `sqlalchemy.exc.ProgrammingError`, returning None.
                   May re-raise other database exceptions.

    Example Execution:
      Single SELECT:
        sql_select = "SELECT ticker, company_name from investments LIMIT 5"
        df_result = await run_query(pool, sql_select)

      Single non-SELECT - Parameterized (Safe!):
        Parameterized INSERT:
          sql_insert = "INSERT INTO investments (ticker, company_name) VALUES (:ticker, :name)"
          params_insert = {"ticker": "NEW", "name": "New Company"}
          insert_result = await run_query(pool, sql_insert, params_insert)

        Parameterized UPDATE:
          sql_update = "UPDATE products SET price = :price WHERE id = :product_id"
          params_update = {"price": 99.99, "product_id": 123}
          update_result = await run_query(pool, sql_update, params_update)

      Bulk Update:
        docs = pd.DataFrame([
            {'id': 101, 'sparse_embedding': '[0.1, 0.2]'},
            {'id': 102, 'sparse_embedding': '[0.3, 0.4]'},
            # ... more rows
        ])

        update_sql_template = '''
            UPDATE products
            SET sparse_embedding = :embedding,
                sparse_embedding_model = 'BM25'
            WHERE id = :product_id
        ''' # Using named parameters :param_name

        # Prepare list of dictionaries for params
        data_to_update = [
            {"embedding": row.sparse_embedding, "product_id": row.id}
            for row in docs.itertuples(index=False)
        ]

        if data_to_update:
          bulk_result = await run_query(pool, update_sql_template, data_to_update)
          # bulk_result is the SQLAlchemy ResultProxy

    """
    sql_lower_stripped = sql.strip().lower()
    is_select_with = sql_lower_stripped.startswith(('select', 'with'))
    is_explain = sql_lower_stripped.startswith('explain')

    # Determine if the statement is expected to return data rows or a plan
    is_data_returning = is_select_with or is_explain

    # Determine actual DataFrame output eligibility (only for SELECT/WITH)
    effective_output_as_df = output_as_df and is_select_with

    # Check if params suggest a bulk operation (for logging purposes)
    is_bulk_operation = isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], (dict, tuple, list))

    async with pool.connect() as conn:
        try:
          # Execute with or without params
          if params:
              result = await conn.execute(text(sql), params)
          else:
              # Add warning for raw SQL only if it's NOT data-returning
              #if not is_data_returning:
                  #logging.warning("Executing non-SELECT/EXPLAIN raw SQL without parameters. Ensure SQL is safe.")
              result = await conn.execute(text(sql))

          # --- Handle statements that return data or plan ---
          if is_data_returning:
              if is_explain:
                  # Fetch and format EXPLAIN output as a string
                    try:
                        plan_rows = result.fetchall()
                        # EXPLAIN output is usually text in the first column
                        query_plan = "\n".join([str(row[0]) for row in plan_rows])
                        #logging.info(f"EXPLAIN executed successfully for: {sql[:100]}...")
                        return query_plan
                    except Exception as e:
                        logging.error(f"Error fetching/formatting EXPLAIN result: {e}")
                        return None
              else: # Handle SELECT / WITH
                  if effective_output_as_df:
                      try:
                          rows = result.fetchall()
                          column_names = result.keys()
                          df = pd.DataFrame(rows, columns=column_names)
                          #logging.info(f"SELECT/WITH executed successfully, returning DataFrame for: {sql[:100]}...")
                          return df
                      except Exception as e:
                          logging.error(f"Error converting SELECT result to DataFrame: {e}")
                          logging.info(f"Returning raw ResultProxy for SELECT/WITH due to DataFrame conversion error for: {sql[:100]}...")
                          return result # Fallback to raw result
                  else:
                      # Return raw result proxy for SELECT/WITH if df output not requested
                      #logging.info(f"SELECT/WITH executed successfully, returning ResultProxy for: {sql[:100]}...")
                      return result

          # --- Handle Non-Data Returning Statements (INSERT, UPDATE, DELETE, CREATE, etc.) ---
          else:
              await conn.commit() # Commit changes ONLY for these statements
              operation_type = sql.strip().split()[0].upper()
              row_count = result.rowcount # Note: rowcount behavior varies

              if is_bulk_operation:
                  print(f"Bulk {operation_type} executed for {len(params)} items. Result rowcount: {row_count}")
              elif operation_type in ['INSERT', 'UPDATE', 'DELETE']:
                  print(f"{operation_type} statement executed successfully. {row_count} row(s) affected.")
              else: # CREATE, ALTER, etc.
                  print(f"{operation_type} statement executed successfully. Result rowcount: {row_count}")
              return result # Return the result proxy

        except exc.ProgrammingError as e:
            # Log the error with context
            logging.error(f"SQL Programming Error executing query:\nSQL: {sql[:500]}...\nParams (sample): {str(params)[:500]}...\nError: {e}")
            # Rollback might happen automatically on context exit with error, but explicit can be clearer
            # await conn.rollback() # Consider if needed based on pool/transaction settings
            return None # Return None on handled programming errors
        except Exception as e:
            # Log other unexpected errors
            logging.error(f"An unexpected error occurred during query execution:\nSQL: {sql[:500]}...\nError: {e}")
            # await conn.rollback() # Consider if needed
            raise # Re-raise unexpected errors


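`run_query()` decides whether to return rows, return a plan, or commit based only on the statement's leading keyword. The sketch below (a hypothetical `classify_statement` helper, not used elsewhere in the notebook) isolates that rule so its edge cases are easy to see: a statement that starts with a SQL comment falls through to the commit branch, and a data-modifying `WITH ... UPDATE` is treated as a read. Because the connection pools created later in this notebook use `isolation_level='AUTOCOMMIT'`, both cases still persist their changes, but the return type differs.

```python
def classify_statement(sql: str) -> str:
    """Mirror run_query()'s dispatch rule: 'explain', 'select', or 'write'."""
    head = sql.strip().lower()
    if head.startswith("explain"):
        return "explain"  # returned as a formatted plan string
    if head.startswith(("select", "with")):
        return "select"  # returned as a DataFrame or Result
    return "write"  # committed; Result returned


for stmt in ["SELECT 1", "EXPLAIN ANALYZE SELECT 1", "-- count rows\nSELECT 1", "UPDATE t SET x = 1"]:
    print(repr(stmt), "->", classify_statement(stmt))
```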

### retry_condition()

In [None]:
from tenacity import retry, wait_exponential, stop_after_attempt, before_sleep_log, retry_if_exception, wait_fixed

def retry_condition(error):
  error_string = str(error)
  print(error_string)

  retry_errors = [
      "429 Quota exceeded",
      #"The prompt could not be submitted",
  ]

  for retry_error in retry_errors:
    if retry_error in error_string:
      print("Retrying...")
      return True

  return False

### async_generate_text()

In [None]:
async def async_generate_text(prompt):
  result = await genai_client.aio.models.generate_content(
                model='gemini-2.0-flash',
                contents=[prompt]
            )
  return result.candidates[0].content.parts[0].text

### sync_generate_text()  

In [None]:
def sync_generate_text(prompt):
    result = genai_client.models.generate_content(
                  model='gemini-2.0-flash',
                  contents=[prompt]
              )
    return result.candidates[0].content.parts[0].text


### async_generate_embedding()

In [None]:
async def async_generate_embedding(input):
  result = await genai_client.aio.models.embed_content(
    model='gemini-embedding-001',
    contents=input
  )
  return result.embeddings[0].values

### generate_image()

In [None]:
# Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images
#            https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api

import os
import vertexai
from google.cloud import storage
from vertexai.preview.vision_models import ImageGenerationModel

vertexai.init(project=project_id, location=region)
storage_client = storage.Client()
bucket = storage_client.bucket(image_bucket)

imagen_model = ImageGenerationModel.from_pretrained("imagen-3.0-fast-generate-001")

@retry(wait=wait_exponential(multiplier=1, min=1, max=10), stop=stop_after_attempt(2), retry=retry_if_exception(retry_condition), before_sleep=before_sleep_log(logging.getLogger(), logging.INFO))
def generate_image(prompt, product_sku, id):

    image_name = f"{product_sku}.png"
    destination_blob_name = f"product-images/{image_name}"
    logging.info(f"Generating image for: {image_name}")

    images = imagen_model.generate_images(
        prompt=prompt,
        number_of_images=1,
        language="en",
        add_watermark=True,
        aspect_ratio="1:1",
        safety_filter_level="block_some",
        person_generation="dont_allow",
    )

    if not images.images:
      logging.info(f"RETRY 1: Retrying with a different prompt for {image_name}.")
      rewritten_prompt = sync_generate_text(f"Responding in 1 sentence, simplify this prompt for Imagen3: {prompt}")
      logging.info(f"Modified prompt for id {id} {image_name}: {rewritten_prompt}")

      images = imagen_model.generate_images(
          prompt=rewritten_prompt,
          number_of_images=1,
          language="en",
          add_watermark=True,
          aspect_ratio="1:1",
          safety_filter_level="block_only_high",
          person_generation="dont_allow",
      )

      if not images.images:
        logging.warning(f"FAILED: Image generation failed for {image_name}. Prompt: {prompt}")
        return None

    #logging.info(f"Done generating image for: {image_name})")

    # Write the image locally
    #logging.info(f"Writing image locally: {image_name})")
    local_filename = f"{image_name}"
    images[0].save(location=local_filename, include_generation_parameters=False)

    # Upload the image to GCS
    #logging.info(f"Uploading to GCS: {image_name})")
    blob = bucket.blob(destination_blob_name)
    blob.upload_from_filename(local_filename, content_type='image/png')


    # Clean up the local file
    #logging.info(f"Removing the local file: {image_name})")
    os.remove(local_filename)

    #logging.info(f"Returning URI: {image_name})")
    return f"gs://{image_bucket}/{destination_blob_name}" # Return GCS URI

### generate_multimodal_embeddings()

In [None]:
import vertexai
from vertexai.vision_models import Image as vai_image, MultiModalEmbeddingModel # Import Image as vai_image to avoid collisions with PIL Image

vertexai.init(project=project_id, location=region)
mme_model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")

@retry(wait=wait_fixed(10), stop=stop_after_attempt(7), retry=retry_if_exception(retry_condition))
def generate_multimodal_embeddings(uri, text):

  # If no image is provided, just generate text embeddings.
  if uri is None:
      embeddings = mme_model.get_embeddings(
          contextual_text=text,
          dimension=1408,
      )
      return embeddings

  # Load image
  image = vai_image.load_from_file(f"{uri}")

  # If no text is provided, just generate image embeddings
  if text is None:
      embeddings = mme_model.get_embeddings(
          image=image,
          dimension=1408,
      )
      return embeddings

  # If image and text are provided, generate both text and image embeddings
  embeddings = mme_model.get_embeddings(
      image=image,
      contextual_text=text,
      dimension=1408,
  )
  return embeddings
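When a single call returns both embeddings, a quick way to check whether an image and its description land near each other (relevant to the inconsistent text embeddings noted in the To Do list) is cosine similarity. This is a stdlib-only sketch; the helper name `cosine_similarity` is ours. pgvector's `<=>` operator returns cosine *distance*, which equals `1 - cosine_similarity`.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors (1.0 = same direction)."""
    if len(a) != len(b):
        raise ValueError(f"dimension mismatch: {len(a)} vs {len(b)}")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


similarity = cosine_similarity([1.0, 0.0], [0.0, 1.0])
print(similarity, 1 - similarity)  # similarity and the corresponding pgvector cosine distance
```

For example, with `resp = generate_multimodal_embeddings(uri, text)`, you could compare `resp.image_embedding` against `resp.text_embedding`.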

### df_image_formatter()

In [None]:
from IPython.display import HTML
from PIL import Image

def df_image_formatter(path, width = 200):
    """
    Converts a gs:// image URI into a storage.cloud.google.com URL (viewable by
    signed-in users with read access) and wraps it in an HTML <img> tag.
    """
    path = path.replace(f"gs://{image_bucket}",f"https://storage.cloud.google.com/{image_bucket}")
    return f"<img src='{path}' width={width}>"


## OPTIONAL: Create an AlloyDB Cluster

You will need an AlloyDB for PostgreSQL cluster to use this notebook. If you already have an AlloyDB cluster, you can skip to the `Connect to AlloyDB Cluster` section. Otherwise, use the cells below to create one.

> ⏳ - Creating an AlloyDB cluster may take a few minutes.

In [None]:
# Create the AlloyDB cluster in your VPC network
!gcloud beta alloydb clusters create {alloydb_cluster} --password={alloydb_password} --region={region} --network={vpc}

# Create the AlloyDB Instance
!gcloud beta alloydb instances create {alloydb_instance} --instance-type=PRIMARY --cpu-count=2 --region={region} --cluster={alloydb_cluster}

If this notebook runs outside your cluster's VPC (for example, in Colab), you will need to enable public IP on your instance and set `ip_type=IPTypes.PUBLIC` in `init_connection_pool()`. Alternatively, you can follow [these instructions](https://cloud.google.com/alloydb/docs/connect-external) to connect to an AlloyDB for PostgreSQL instance with private IP from outside your VPC. You can also use the `--authorized-external-networks` flag to restrict public IP access to specific IP address ranges.

In [None]:
# Enable Public IP on AlloyDB
!gcloud beta alloydb instances update {alloydb_instance} --region={region} --cluster={alloydb_cluster} --assign-inbound-public-ip=ASSIGN_IPV4 --database-flags="password.enforce_complexity=on"

## Connect to AlloyDB Cluster

This function will create a connection pool to your AlloyDB instance using the AlloyDB Python connector. The AlloyDB Python connector will automatically create secure connections to your AlloyDB instance using mTLS.

In [None]:
import asyncpg

import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from google.cloud.alloydb.connector import AsyncConnector, IPTypes

async def init_connection_pool(connector: AsyncConnector, db_name: str = alloydb_database, pool_size: int = 5) -> AsyncEngine:
    # Fully qualified AlloyDB instance URI used by the connector
    connection_string = f"projects/{project_id}/locations/{region}/clusters/{alloydb_cluster}/instances/{alloydb_instance}"

    async def getconn() -> asyncpg.Connection:
        conn: asyncpg.Connection = await connector.connect(
            connection_string,
            "asyncpg",
            user="postgres",
            password=alloydb_password,
            db=db_name,
            ip_type=IPTypes.PRIVATE, # Use IPTypes.PUBLIC if the notebook connects over public IP
        )
        return conn

    pool = create_async_engine(
        "postgresql+asyncpg://",
        async_creator=getconn,
        pool_size=pool_size,
        max_overflow=0,
        isolation_level='AUTOCOMMIT'
    )
    return pool

connector = AsyncConnector()

postgres_db_pool = await init_connection_pool(connector, "postgres")
ecom_db_pool = await init_connection_pool(connector, f"{alloydb_database}")

## Import Sample Data to AlloyDB

### Add Required Permissions

In [None]:
project_number = ! gcloud projects describe {project_id} --format='value(projectNumber)'
project_number = project_number[0]

# These permissions are required to read from GCS for the data import and to integrate with Vertex AI for on-the-fly embedding generation.
roles_array = [
    "roles/storage.admin",
    "roles/aiplatform.user",
    "roles/discoveryengine.admin",
    "roles/secretmanager.secretAccessor",
]

for r in roles_array:
  ! gcloud projects add-iam-policy-binding {project_id} \
      --member="serviceAccount:service-{project_number}@gcp-sa-alloydb.iam.gserviceaccount.com" \
      --role="{r}"

# These permissions are required to generate a token later for semantic ranking models
roles_array = [
    "roles/iam.serviceAccountTokenCreator"
]

current_user = ! gcloud auth list --format='value(account)'
current_user = current_user[0]

for r in roles_array:
  ! gcloud projects add-iam-policy-binding {project_id} \
      --member="user:{current_user}" \
      --role="{r}"


### OPTIONAL: Drop Existing Database

In [None]:
# Close existing connections to the database
sql = f"""SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{alloydb_database}'
  AND pid <> pg_backend_pid();"""
await run_query(postgres_db_pool, sql)

# Uncomment the lines below to drop an existing database before re-creating it
#sql = f"DROP DATABASE {alloydb_database};"
#await run_query(postgres_db_pool, sql)

# Reinitiate the connection pool
ecom_db_pool = await init_connection_pool(connector, f"{alloydb_database}")

### Create Database

In [None]:
# Create the database
sql = f"CREATE DATABASE {alloydb_database};"
await run_query(postgres_db_pool, sql)

### Install Prerequisite Extensions

In [None]:
sql_array = []

sql_array.append("CREATE EXTENSION IF NOT EXISTS vector;")

sql_array.append("CREATE EXTENSION IF NOT EXISTS google_ml_integration;")

for sql in sql_array:
  await run_query(ecom_db_pool, sql)

### Run the Import

In [None]:
# Reference: https://cloud.google.com/alloydb/docs/reference/rest/v1/projects.locations.clusters/import
#            https://cloud.google.com/alloydb/docs/import-sql-file

import time

url = f"https://alloydb.googleapis.com/v1/projects/{project_id}/locations/{region}/clusters/{alloydb_cluster}:import"
request_body = {
   "gcsUri": f"{database_backup_uri}",
   "database": f"{alloydb_database}",
   "user": "postgres",
   "sqlImportOptions": {}
}

result = rest_api_helper(authed_session, url, 'POST', request_body, {})
print(f"Kicked off import: {result}")

operation_id = result['name']

# Poll the long-running operation until it finishes.
# The JSON form of an Operation omits "done" while it is still false, so default to False.
while True:
  url = f"https://alloydb.googleapis.com/v1/{operation_id}"
  response = rest_api_helper(authed_session, url, 'GET')
  if response.get('done', False):
    print(f"Operation complete. Check result payload for potential errors. \nResult: {response}")
    break
  print(f"Import still running: {operation_id}")
  time.sleep(5)
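For longer imports, you may prefer a polling helper with a timeout that also surfaces operation errors. This is a minimal sketch, assuming the standard `google.longrunning.Operation` JSON shape (`name`, `done`, and either `error` or `response` once finished); `wait_for_operation` is a hypothetical name, and `sleep`/`clock` are injectable so it can be tested without waiting.

```python
import time
from typing import Any, Callable, Dict


def wait_for_operation(
    fetch: Callable[[], Dict[str, Any]],
    interval: float = 5.0,
    timeout: float = 3600.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> Dict[str, Any]:
    """Poll a Google Cloud long-running operation until it is done, fails, or times out."""
    deadline = clock() + timeout
    while True:
        op = fetch()
        # The JSON form of an Operation omits "done" while it is still false.
        if op.get("done", False):
            # A finished operation carries either "error" or "response".
            if "error" in op:
                raise RuntimeError(f"Operation {op.get('name')} failed: {op['error']}")
            return op
        if clock() >= deadline:
            raise TimeoutError(f"Operation {op.get('name')} not done after {timeout} seconds")
        sleep(interval)
```

In this notebook you could call it as `wait_for_operation(lambda: rest_api_helper(authed_session, f"https://alloydb.googleapis.com/v1/{operation_id}", 'GET'))`.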

### Check Row Counts

In [None]:
sql = """
SELECT 'distribution_centers' AS table_name, (SELECT COUNT(*) FROM distribution_centers) AS actual_row_count, 10 AS target_row_count
UNION ALL
SELECT 'events', (SELECT COUNT(*) FROM events), 2438862
UNION ALL
SELECT 'inventory_items', (SELECT COUNT(*) FROM inventory_items), 494254
UNION ALL
SELECT 'orders', (SELECT COUNT(*) FROM orders), 125905
UNION ALL
SELECT 'order_items', (SELECT COUNT(*) FROM order_items), 182905
UNION ALL
SELECT 'products', (SELECT COUNT(*) FROM products), 29120
UNION ALL
SELECT 'users', (SELECT COUNT(*) FROM users), 100000;
"""

await run_query(ecom_db_pool, sql)

### Test the Vertex AI Integration

In [None]:
sql = "SELECT embedding('gemini-embedding-001', 'This string will be transformed into an embedding.');"
await run_query(ecom_db_pool, sql)

## OPTIONAL: Remove Branding

### Remove Brands from Product Names

In [None]:
sql = "SELECT COUNT(*) FROM products WHERE name ILIKE CONCAT('%', brand, '%') "
await run_query(ecom_db_pool, sql)

In [None]:
sql = """UPDATE products SET name = regexp_replace(name, brand, '!!BRAND!!', 'gi')
WHERE name ILIKE '%' || brand || '%';"""

await run_query(ecom_db_pool, sql)
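`regexp_replace()` treats `brand` as a POSIX regular expression, so a brand containing metacharacters (for example, a hypothetical `A|X Studio`) can match more than the brand itself; the `ILIKE` filter likewise treats `%` and `_` in a brand as wildcards. The Python sketch below illustrates the literal, case-insensitive replacement the SQL is aiming for by escaping the brand with `re.escape`; `mask_brand` is a hypothetical helper, not part of the notebook.

```python
import re


def mask_brand(name: str, brand: str, placeholder: str = "!!BRAND!!") -> str:
    """Replace every case-insensitive, literal occurrence of brand in name."""
    if not brand:
        return name  # nothing to mask
    pattern = re.compile(re.escape(brand), flags=re.IGNORECASE)
    # A function replacement ensures backslashes in the placeholder are never interpreted
    return pattern.sub(lambda _match: placeholder, name)


print(mask_brand("a|x studio Slim Fit Jeans", "A|X Studio"))  # !!BRAND!! Slim Fit Jeans
# Unescaped, the "|" becomes alternation and also replaces the lone "a" in "Jeans"
print(re.sub("A|X Studio", "!!BRAND!!", "a|x studio Slim Fit Jeans", flags=re.IGNORECASE))
```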

### Create New Brand Names

#### Define Prompt

In [None]:
def build_brand_prompt(brand):
  prompt = f"""**Persona & Context:**
You are a skilled Copywriter for Cymbal Shops. Cymbal Shops is a specialty big box retailer offering a diverse, curated mix of trendy clothing, unique household knick-knacks, stylish furnishings, and interesting personal items. Our customers appreciate value, style, and finding items with personality.

**Task:**
Create a new, completely original Brand Name analogous to the old Brand Name.

**Output Requirements:**
* **Length:** 1-3 words.
* **Goal:** The new Brand Name should be analogous to the old Brand Name, but it should be entirely original.
* **Format:** Text with NO formatting. Output ONLY the new Brand Name.
* **Exclusions:** The new Brand Name MUST NOT be a well-known Brand Name that is already in use.

**Product Information to Use:**
Old Brand Name: {brand}

**Now, come up with a new, completely original Brand Name.**
"""
  return prompt
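Some of the prompt's output requirements can be checked mechanically before a new brand is written back. This is a minimal sketch; `is_acceptable_brand` is hypothetical, and it only checks what can be verified locally (word count, leftover formatting, reuse of the old name). It cannot tell whether a name is already a well-known brand.

```python
import re


def is_acceptable_brand(new_brand: str, old_brand: str, max_words: int = 3) -> bool:
    """Check a generated brand against the prompt's mechanically checkable requirements."""
    candidate = new_brand.strip()
    words = candidate.split()
    if not 1 <= len(words) <= max_words:
        return False  # the prompt asks for 1-3 words
    if re.search(r"[*_#`\n]", candidate):
        return False  # leftover Markdown or line breaks
    if old_brand and old_brand.strip().lower() in candidate.lower():
        return False  # must not reuse the old brand name
    return True


print(is_acceptable_brand("Nova Thread", "Acme Denim"))
```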

#### Create Mapping Table

In [None]:
sql_array = []

sql_array.append("DROP TABLE IF EXISTS brand_mapping;")

sql_array.append("""
CREATE TABLE brand_mapping
AS
SELECT DISTINCT brand FROM products;
""")

for sql in sql_array:
  await run_query(ecom_db_pool, sql)

In [None]:
sql = """
ALTER TABLE brand_mapping
ADD COLUMN new_brand TEXT
"""
await run_query(ecom_db_pool, sql)

In [None]:
sql = """SELECT brand FROM brand_mapping WHERE new_brand IS NULL;"""
brand_df = await run_query(ecom_db_pool, sql)
brand_df

#### Build Dataframe

In [None]:
brand_df['prompt'] = brand_df.apply(
    lambda row: build_brand_prompt(
        row['brand'],
    ),
    axis=1
)
brand_df

#### Generate New Brands

In [None]:
import asyncio

async def load_queue_from_dataframe(df: pd.DataFrame, queue: asyncio.Queue, num_consumers: int):
    """
    Iterates through DataFrame rows and puts them into the asyncio queue.

    Args:
        df: The Pandas DataFrame to process.
        queue: The asyncio.Queue to put items into.
        num_consumers: Number of consumer tasks; one sentinel (None) is queued per consumer.
    """
    logging.info(f"Producer: Starting to load {len(df)} items into the queue...")
    # Use itertuples for efficiency. index=False omits the DataFrame index, and
    # name='Product' sets the namedtuple type name.
    for row_tuple in df.itertuples(index=False, name='Product'):
        # Convert the named tuple to a dictionary - often easier for consumers
        item = row_tuple._asdict()
        await queue.put(item)
        logging.info(f"Producer: Put item {item.get('sku', 'N/A')} into queue. Queue size: {queue.qsize()}")
    logging.info("Producer: Finished loading all items.")

    # Add one sentinel (None) per consumer so that every consumer knows when to stop.
    for _ in range(num_consumers):
        await queue.put(None)
    logging.info(f"Producer: Added {num_consumers} sentinel(s) to queue.")


async def process_items_from_queue(queue: asyncio.Queue, worker_id: int):
    """
    Continuously gets items from the queue and processes them until a sentinel is received.
    """
    logging.info(f"Consumer {worker_id}: Started...")
    while True:
        item = await queue.get()
        if item is None:
            # Sentinel received. The producer queues one sentinel per consumer,
            # so there is no need to put it back for the others.
            logging.info(f"Consumer {worker_id}: Sentinel received. Exiting.")
            queue.task_done()
            break # Exit the loop

        try:
            # --- Process the item ---
            new_brand = (await async_generate_text(item.get('prompt'))).strip()
            old_brand = item.get('brand')
            # Bind values as parameters rather than interpolating them into the SQL string
            sql = "UPDATE brand_mapping SET new_brand = :new_brand WHERE brand = :old_brand;"
            await run_query(ecom_db_pool, sql, {"new_brand": new_brand, "old_brand": old_brand})
            logging.info(f"Consumer {worker_id}: Finished processing brand: {old_brand}. New brand: {new_brand}")
        except Exception as e:
            # Log and continue so that one failure does not leave queue.join() waiting forever
            logging.error(f"Consumer {worker_id}: Failed to process brand {item.get('brand')}: {e}")
        finally:
            queue.task_done() # Signal that this item's processing is complete


async def generate_product_descriptions_concurrently(products_df):
    # Create the queue. You can optionally set a maxsize.
    # If maxsize is reached, the producer's `await queue.put(item)` will pause
    # until a consumer calls `queue.get()`, providing backpressure.
    work_queue = asyncio.Queue(maxsize=100)
    num_consumers = 10

    # Create tasks for the producer and consumer(s)
    producer_task = asyncio.create_task(load_queue_from_dataframe(products_df, work_queue, num_consumers))

    # Create one or more consumer tasks
    consumer_tasks = []
    for i in range(num_consumers):
        consumer_tasks.append(
            asyncio.create_task(process_items_from_queue(work_queue, i + 1))
        )

    # Wait for the producer to finish loading (optional, but good to ensure all items are queued)
    await producer_task

    # Wait for all consumers to finish processing all items
    # This relies on consumers calling queue.task_done() for each item + the sentinel
    await work_queue.join() # Wait until the queue is fully processed
    logging.info("Main: Queue has been fully processed.")


await generate_product_descriptions_concurrently(brand_df)
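
The cell above uses a producer/consumer pattern that recurs throughout this notebook. A bounded `asyncio.Queue` provides backpressure, the producer adds one `None` sentinel per consumer, and `queue.join()` waits until every item has been marked done. The sketch below isolates that pattern so it can run without Vertex AI or AlloyDB. `fake_generate`, `producer`, `consumer`, and `run_pipeline` are hypothetical names, and `fake_generate` stands in for `async_generate_text()` plus the database update. The sketch also calls `task_done()` in a `finally` block, which is a design choice rather than something the notebook requires: it guarantees that every `get()`, including the sentinel, is balanced before `join()` returns.

```python
import asyncio


async def fake_generate(prompt: str) -> str:
    """Hypothetical stand-in for a model call such as async_generate_text()."""
    await asyncio.sleep(0.001)  # simulate network latency
    return prompt.upper()


async def producer(items: list[str], queue: asyncio.Queue, num_consumers: int) -> None:
    for item in items:
        await queue.put(item)  # waits while the queue is full (backpressure)
    for _ in range(num_consumers):
        await queue.put(None)  # one sentinel per consumer


async def consumer(queue: asyncio.Queue, results: dict[str, str]) -> None:
    while True:
        item = await queue.get()
        try:
            if item is None:
                return  # sentinel received: this consumer exits
            results[item] = await fake_generate(item)
        finally:
            queue.task_done()  # balance every get(), including the sentinel


async def run_pipeline(items: list[str], num_consumers: int = 3, maxsize: int = 2) -> dict[str, str]:
    queue: asyncio.Queue = asyncio.Queue(maxsize=maxsize)
    results: dict[str, str] = {}
    consumers = [asyncio.create_task(consumer(queue, results)) for _ in range(num_consumers)]
    await producer(items, queue, num_consumers)
    await queue.join()  # returns once every put() has a matching task_done()
    await asyncio.gather(*consumers)  # every consumer has exited via its sentinel
    return results
```

In a notebook cell you can `await run_pipeline(["a", "b"])` directly. In a plain script, wrap the call in `asyncio.run(...)`.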


#### Remove New Lines

In [None]:
sql = """
UPDATE brand_mapping
SET new_brand = REPLACE(new_brand, '\n', '')
"""
await run_query(ecom_db_pool, sql)

#### Update Products Table with New Brands

In [None]:
sql = """
UPDATE products p
SET brand = SUBSTRING(b.new_brand, 0, 254),
    name = SUBSTRING(REPLACE(name, '!!BRAND!!', b.new_brand), 0, 254)
FROM brand_mapping b
WHERE p.brand = b.brand;
"""
await run_query(ecom_db_pool, sql)
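
PostgreSQL's `SUBSTRING(string, start, count)` uses 1-based positions. It keeps only the part of the window `[start, start + count)` that falls inside the string, so `SUBSTRING(x, 0, 254)` returns at most 253 characters. The hypothetical `pg_substring` helper below mimics that behavior in Python for illustration. If the intent is to keep 254 characters, a start position of 1 would do it.

```python
def pg_substring(s: str, start: int, count: int) -> str:
    """Hypothetical emulation of PostgreSQL SUBSTRING(s, start, count).

    Positions are 1-based. The requested window [start, start + count) is
    clipped to the string, so a start of 0 silently drops one character.
    """
    if count < 0:
        raise ValueError("negative substring length not allowed")  # PostgreSQL also errors here
    first = max(start, 1)  # clamp to the first real position
    end = start + count    # exclusive end position (1-based)
    return s[first - 1:max(end - 1, first - 1)]


print(len(pg_substring("x" * 300, 0, 254)))  # 253
print(len(pg_substring("x" * 300, 1, 254)))  # 254
```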

#### Update Inventory Items Table with New Brands

In [None]:

sql = """
UPDATE inventory_items i
SET product_name = p.name,
    product_brand = p.brand
FROM products p
WHERE p.id = i.product_id
"""

await run_query(ecom_db_pool, sql)

## Generate Product Descriptions with Gemini 2.0 Flash

### Add product_description Column

In [None]:
sql = """ALTER TABLE products ADD COLUMN product_description TEXT;"""
await run_query(ecom_db_pool, sql)

### build_product_description_prompt()

In [None]:
def build_product_description_prompt(name, brand, category, department, retail_price, sku):
  prompt = f"""**Persona & Context:**
You are a skilled Copywriter for Cymbal Shops. Cymbal Shops is a specialty big box retailer offering a diverse, curated mix of trendy clothing, unique household knick-knacks, stylish furnishings, and interesting personal items. Our customers appreciate value, style, and finding items with personality. Our brand voice is:
* Approachable & Friendly
* Professional & Trustworthy
* Slightly Quirky & Stylish
* Focused on Value & Benefits

**Task:**
Write a compelling product description for the Cymbal Shops online product catalog page.

**Output Requirements:**
* **Length:** 50-75 words (approx. 3-5 sentences).
* **Tone:** Match the Cymbal Shops brand voice described above.
* **Goal:** Engage the customer and highlight the key benefits and appeal of the product. Translate features into benefits.
* **Format:** A single paragraph of prose. Optionally, include 2-3 key feature bullet points after the main paragraph if features are distinct and numerous.
* **Exclusions:** Do NOT mention the SKU or Retail Price within the written description.

**Product Information to Use:**
Product Name: {name}
Brand: {brand}
Category: {category}
Department: {department}
Retail Price: {retail_price}  (For context only, do not include in description)
SKU: {sku} (For context only, do not include in description)

**Now, write the description for the product detailed above.**
"""
  return prompt

### Get Products Without Descriptions

In [None]:
sql = """SELECT id, name, brand, category, department, retail_price, sku FROM products WHERE product_description IS NULL;"""
products_df = await run_query(ecom_db_pool, sql)
products_df

### Build Product Description Prompts

In [None]:
products_df['prompt'] = products_df.apply(
    lambda row: build_product_description_prompt(
        row['name'],
        row['brand'],
        row['category'],
        row['department'],
        row['retail_price'],
        row['sku']
    ),
    axis=1
)
products_df

### Generate Product Descriptions Asynchronously

> NOTE: You may want to grab some coffee or tea. This step takes about 50 minutes to complete. You can adjust the queue size (`maxsize` of `work_queue`) and `num_consumers` to balance processing speed against throttling and quota limits.
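
Here is a rough way to reason about the trade-off in the note above. With `num_consumers` workers each waiting on one request at a time, wall-clock time is about `num_items * seconds_per_call / num_consumers`. The helper below is a hypothetical back-of-the-envelope calculator. The 1-second latency in the example is an assumed figure, not a measurement, and the estimate ignores throttling and retries. Raising `maxsize` does not change this figure; it only controls how far the producer can run ahead of the consumers.

```python
def estimate_runtime_minutes(num_items: int, num_consumers: int, seconds_per_call: float) -> float:
    """Rough wall-clock estimate for a queue drained by num_consumers concurrent workers.

    Assumes every call takes about seconds_per_call and that the service never
    throttles, so treat the result as a lower bound.
    """
    if num_consumers <= 0:
        raise ValueError("num_consumers must be positive")
    total_seconds = num_items * seconds_per_call / num_consumers
    return total_seconds / 60


# 29,120 products, 10 consumers, assumed ~1 s per Gemini call
print(round(estimate_runtime_minutes(29_120, 10, 1.0), 1))  # 48.5
```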

In [None]:
import asyncio

async def load_queue_from_dataframe(df: pd.DataFrame, queue: asyncio.Queue, num_consumers: int):
    """
    Iterates through DataFrame rows and puts them into the asyncio queue.

    Args:
        df: The Pandas DataFrame to process.
        queue: The asyncio.Queue to put items into.
        num_consumers: Number of consumers; one None sentinel is queued for each.
    """
    logging.info(f"Producer: Starting to load {len(df)} items into the queue...")
    # Use itertuples for efficiency. index=False avoids adding the DataFrame index.
    # name='Product' names the generated namedtuple class (the default would be 'Pandas').
    for row_tuple in df.itertuples(index=False, name='Product'):
        # Convert the named tuple to a dictionary - often easier for consumers
        item = row_tuple._asdict()
        await queue.put(item)
        logging.info(f"Producer: Put item {item.get('sku', 'N/A')} into queue. Queue size: {queue.qsize()}")
    logging.info("Producer: Finished loading all items.")

    # Add one sentinel value (None) per consumer so that every consumer knows when to exit.
    for _ in range(num_consumers):
        await queue.put(None)
    print(f"Producer: Added {num_consumers} sentinel(s) to queue.")


async def process_items_from_queue(queue: asyncio.Queue, worker_id: int):
    """
    Continuously gets items from the queue and processes them until a sentinel is received.
    """
    logging.info(f"Consumer {worker_id}: Started...")
    while True:
        item = await queue.get()
        if item is None:
            # Sentinel received, signal task completion
            logging.info(f"Consumer {worker_id}: Sentinel received. Exiting.")
            queue.task_done()
            # No need to put the sentinel back: the producer queues one sentinel per consumer.
            break # Exit the loop

        # --- Process the item ---
        product_description = await async_generate_text(item.get('prompt'))
        product_description = product_description.replace("'","''")
        sql = f"UPDATE products SET product_description = '{product_description}' WHERE id = {item.get('id')};"
        await run_query(ecom_db_pool, sql)
        logging.info(f"Consumer {worker_id}: Finished processing ID: {item.get('id')}, SKU: {item.get('sku')}")
        # --- End processing ---

        queue.task_done() # Signal that this item processing is complete


async def generate_product_descriptions_concurrently(products_df):
    # Create the queue. You can optionally set a maxsize.
    # If maxsize is reached, the producer's `await queue.put(item)` will pause
    # until a consumer calls `queue.get()`, providing backpressure.
    work_queue = asyncio.Queue(maxsize=100)
    num_consumers = 10

    # Create tasks for the producer and consumer(s)
    producer_task = asyncio.create_task(load_queue_from_dataframe(products_df, work_queue, num_consumers))

    # Create one or more consumer tasks
    consumer_tasks = []
    for i in range(num_consumers):
        consumer_tasks.append(
            asyncio.create_task(process_items_from_queue(work_queue, i + 1))
        )

    # Wait for the producer to finish loading (optional, but good to ensure all items are queued)
    await producer_task

    # Wait for all consumers to finish processing all items
    # This relies on consumers calling queue.task_done() for each item + the sentinel
    await work_queue.join() # Wait until the queue is fully processed
    logging.info("Main: Queue has been fully processed.")


await generate_product_descriptions_concurrently(products_df)
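
The consumers above build `UPDATE` statements by doubling single quotes (`replace("'", "''")`), which is how a single quote is escaped inside a standard SQL string literal. Below is a minimal sketch of that escaping as a hypothetical `sql_literal` helper. It assumes `standard_conforming_strings` is on (the PostgreSQL default), so backslashes are not treated as escape characters. Bound parameters avoid hand-escaping altogether and are the safer choice for model-generated text; the embedding cell later in this notebook binds `:product_embedding` and `:id` this way.

```python
from typing import Optional


def sql_literal(value: Optional[str]) -> str:
    """Render a Python string as a standard SQL string literal (hypothetical helper).

    Doubling each single quote is the standard escape. This assumes the server
    uses standard_conforming_strings, so backslashes are left as-is.
    """
    if value is None:
        return "NULL"
    return "'" + value.replace("'", "''") + "'"


description = "Cymbal's coziest throw: it's soft, warm & washable."
sql = f"UPDATE products SET product_description = {sql_literal(description)} WHERE id = 42;"
print(sql)
```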


### View Product Descriptions

In [None]:
sql = "SELECT id, name, brand, category, department, sku, product_description FROM products LIMIT 5;"
await run_query(ecom_db_pool, sql)

## Generate Product Images with Imagen 3

> IMPORTANT: This section uses the `imagen-3.0-fast-generate-001` model to generate product images for a dataset containing 29,120 products. At the time this notebook was published, the model cost $0.02 per image, so generating images for the full catalog costs roughly $582. The `image_limit` variable set below controls how many products are processed. It is set to 30000, which covers the entire catalog, so lower it if you want to cap the cost of the job.

> NOTE: If you would like to generate pictures of people, ensure your project is allow-listed first. You can request to be allow-listed using [this form](https://docs.google.com/forms/d/e/1FAIpQLSduBp9w84qgim6vLriQ9p7sdz62bMJaL-nNmIVoyiOwd84SMw/viewform).
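
At the per-image price quoted above, the cost of a run scales linearly with the number of images. The hypothetical helper below turns that into a quick estimate. The $0.02 default comes from the note above and may no longer match current pricing, and the sketch assumes exactly one billed image per product.

```python
def estimate_image_cost(num_images: int, price_per_image: float = 0.02) -> float:
    """Estimated Imagen cost in USD (hypothetical helper; price may be out of date)."""
    return round(num_images * price_per_image, 2)


print(estimate_image_cost(29_120))  # 582.4
print(estimate_image_cost(1_000))   # 20.0
```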

### Create Coming Soon Image

You can run this cell repeatedly until you're happy with the placeholder image.

In [None]:
from IPython.display import Image, display

coming_soon_uri = generate_image('Professional image with gray background and bold white lettering that says "Coming Soon"', 'coming_soon', 1)
print(f"gsutil uri: {coming_soon_uri}")

display(Image(url=f"https://storage.cloud.google.com/{image_bucket}/product-images/coming_soon.png"))

### Add product_image_uri Column

In [None]:
sql = """ALTER TABLE products ADD COLUMN product_image_uri TEXT;"""
await run_query(ecom_db_pool, sql)

### Use Placeholder Image for Sensitive Categories

In [None]:
sql = f"""
UPDATE products
SET product_image_uri = '{coming_soon_uri}'
WHERE category IN ('Intimates','Underwear')
"""

await run_query(ecom_db_pool, sql)

In [None]:
sql = """
SELECT COUNT(*) FROM products
WHERE product_image_uri IS NULL
"""

await run_query(ecom_db_pool, sql)

### Define Product Image Prompt Builder Function

In [None]:
def build_product_image_prompt(name, brand, category, department):
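  # Remove single quotes: the prompt is embedded in a SQL string literal when a
  # failure is recorded, and a stray quote would break that statement. Strip them
  # from the brand as well so it still matches the (quote-free) name below.
  name = name.replace("'", "")
  brand = brand.replace("'", "")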

  # Remove terms that trigger image generation failures
  remove = [
      brand,
      '-',
      'Boys',
      'Boy',
      'Girls',
      'Girl',
      'Juniors',
      'Junior',
      '.',
      '&',
      '\n',
  ]
  for r in remove:
      name = name.lower().replace(r.lower(), '')
  name = name.strip()
  prompt = f"Product image (with NO logos): {department} {name} {category} {brand}"
  prompt = prompt[0:1200] # Cap prompt length at 1,200 characters (a rough proxy for the model's token limit)
  if not prompt:
    prompt = 'Coming Soon'
  return prompt
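
`build_product_image_prompt()` removes terms with plain substring replacement, so removing `Boy` also strips it out of words such as `Boyfriend`, leaving `friend`. If that matters for your catalog, below is a sketch of a word-boundary variant using Python's `re` module. `remove_terms` is a hypothetical helper, and it targets word-like terms only. Punctuation such as `-` or `&` would still need plain replacement, because `\b` only matches next to word characters.

```python
import re


def remove_terms(name: str, terms: list[str]) -> str:
    """Remove whole-word occurrences of each term (case-insensitive) and tidy whitespace."""
    for term in terms:
        # \b on both sides keeps "Boy" from matching inside "Boyfriend"
        pattern = r"\b" + re.escape(term) + r"\b"
        name = re.sub(pattern, " ", name, flags=re.IGNORECASE)
    return " ".join(name.split())


print(remove_terms("Boyfriend Jeans for Boys", ["Boys", "Boy", "Girls", "Girl"]))
# Boyfriend Jeans for
```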

### Get Products Without Product Images

> IMPORTANT: This is the cell that builds the dataset that will be used to generate photos. You can adjust the limit up or down as desired to control the cost of the image generation job.

In [None]:
# Set the number of products to generate images for here
image_limit = 30000

# Get products without images
sql = f"""SELECT id,
    name,
    brand,
    category,
    department,
    retail_price,
    sku,
    product_description
  FROM products
  WHERE product_image_uri IS NULL
  AND name IS NOT NULL
  AND brand IS NOT NULL
  AND name != 'Discontinued'
  LIMIT {image_limit};"""
products_df = await run_query(ecom_db_pool, sql)
products_df

### Build Image Prompts

In [None]:
products_df['image_prompt'] = products_df.apply(
    lambda row: build_product_image_prompt(
        row['name'],
        row['brand'],
        row['category'],
        row['department'],
    ),
    axis=1
)
products_df

### Generate Images Asynchronously

Generating images for all 29,120 products in the Cymbal Shops catalog takes about 4-5 hours when running 150 requests concurrently. The cell below defaults to `num_consumers = 10`, so expect a longer run at that setting. To shorten the job, lower `image_limit` in the cells above, or request a quota increase that allows more concurrent invocations and raise `num_consumers` below.

> NOTE: Errors are expected in this step due to ambiguous product names and content filter false positives. Failures will be marked in the `product_image_uri` column in the database. Early testing resulted in an ~80% success rate. You can tweak the prompt and retry for failed items if desired.
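
Retries only help with transient errors such as throttling or timeouts; a content-filter rejection will usually fail again with the same prompt. For those transient cases, one option is to wrap the blocking call in a retry loop with exponential backoff before marking an item as `FAILED`. The sketch below is a generic, hypothetical `call_with_retries` helper built on `asyncio.to_thread` (Python 3.9+). It is not one of the notebook's helpers. Which exception types count as retryable is an assumption you would adapt to your client library.

```python
import asyncio
import random
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


async def call_with_retries(
    func: Callable[..., T],
    *args,
    attempts: int = 4,
    base_delay: float = 1.0,
    retry_on: Tuple[Type[BaseException], ...] = (Exception,),
) -> T:
    """Run blocking func(*args) in a worker thread, retrying with exponential backoff."""
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    for attempt in range(1, attempts + 1):
        try:
            return await asyncio.to_thread(func, *args)
        except retry_on:
            if attempt == attempts:
                raise  # out of attempts: let the caller record the item as failed
            # Wait 1x, 2x, 4x, ... base_delay, plus up to 10% random jitter
            delay = base_delay * 2 ** (attempt - 1)
            await asyncio.sleep(delay + random.uniform(0, delay * 0.1))
    raise RuntimeError("unreachable when attempts >= 1")


# Hypothetical flaky stand-in for generate_image(): fails twice, then succeeds
calls = {"n": 0}


def flaky_generate(prompt: str) -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise TimeoutError("simulated throttling or timeout")
    return f"gs://example-bucket/{prompt}.png"
```

Inside a consumer, `await call_with_retries(generate_image, prompt, sku, product_id)` could take the place of the `run_in_executor` call. The `retry_on` parameter is where you would narrow retries to the throttling errors your client actually raises.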

In [None]:
import asyncio


async def load_queue_from_dataframe(df: pd.DataFrame, queue: asyncio.Queue, num_consumers: int):
    """
    Iterates through DataFrame rows and puts them into the asyncio queue.

    Args:
        df: The Pandas DataFrame to process.
        queue: The asyncio.Queue to put items into.
        num_consumers: Number of consumers; one None sentinel is queued for each.
    """
    print(f"Producer: Starting to load {len(df)} items into the queue...")
    # Use itertuples for efficiency. index=False avoids adding the DataFrame index.
    # name='Product' names the generated namedtuple class (the default would be 'Pandas').
    for row_tuple in df.itertuples(index=False, name='Product'):
        # Convert the named tuple to a dictionary - often easier for consumers
        item = row_tuple._asdict()
        await queue.put(item)
        logging.info(f"Producer: Put item {item.get('sku', 'N/A')} into queue. Queue size: {queue.qsize()}")
    logging.info("Producer: Finished loading all items.")

    # Add one sentinel value (e.g., None) per consumer to signal completion
    for _ in range(num_consumers):
        await queue.put(None)
    print(f"Producer: Added {num_consumers} sentinel(s) to queue.")


async def process_items_from_queue(queue: asyncio.Queue, worker_id: int):
    """
    Continuously gets items from the queue and processes them until a sentinel is received.
    Failures are logged rather than propagated, so one bad item does not stop the worker.
    """
    logging.info(f"Consumer {worker_id}: Started...")
    while True:
        item = await queue.get()

        # --- Check for Sentinel ---
        if item is None:
            print(f"Consumer {worker_id}: Sentinel received. Exiting.")
            queue.task_done() # Mark sentinel processing as done
            break # Exit the loop

        # --- Process the item ---

        try:
            log_prefix = f"Consumer {worker_id}: Item ID {item.get('id', 'N/A')}:"
            logging.info(f"{log_prefix} Starting processing.")

            # Get current running loop
            loop = asyncio.get_running_loop()

            # Run the blocking image generation function
            # Ensure generate_image raises exceptions on failure or returns None clearly
            image_uri = await loop.run_in_executor(
                None, # Use default executor (ThreadPoolExecutor)

                # --- Define the blocking function to run ---
                generate_image,

                # --- Add function parameters here ---
                item.get('image_prompt'), # First argument for generate_image
                item.get('sku'),          # Second argument for generate_image, etc
                item.get('id'),
            )

            # Handle image generation failure (if it returns None instead of raising)
            if image_uri is None:
                # Log a warning and record the failure (with its prompt) in the database
                logging.warning(f"{log_prefix} Image generation failed or returned None.")
                sql = f"UPDATE products SET product_image_uri = 'FAILED - Prompt: {item.get('image_prompt')}' WHERE id = {item.get('id')};"
                await run_query(ecom_db_pool, sql)
            else:
                logging.info(f"{log_prefix} Image generated: '{image_uri}'")
                # Run the database update
                sql = f"UPDATE products SET product_image_uri = '{image_uri}' WHERE id = {item.get('id')};"
                await run_query(ecom_db_pool, sql)

        except Exception as e:
            # Log the exception with its traceback, but don't re-raise,
            # so that the remaining items can still be processed.
            logging.error(f"Consumer {worker_id}: Unhandled exception processing item ID {item.get('id', 'N/A')}: {e}", exc_info=True)

        finally:
            # Mark the item done whether it succeeded or failed, so that
            # queue.join() (if used) can never hang on a failed item.
            queue.task_done()


async def process_dataframe_concurrently(products_df):
    # Create the queue with maxsize. If maxsize is reached, the producer's
    # `await queue.put(item)` will pause until a consumer calls `queue.get()`, providing backpressure.

    # --- Set queue max size and number of consumers/workers here ---
    work_queue = asyncio.Queue(maxsize=100)
    num_consumers = 10
    all_tasks = []

    # --- Create Tasks ---
    print("Main: Creating producer task...")
    producer_task = asyncio.create_task(
        load_queue_from_dataframe(products_df, work_queue, num_consumers),
        name="Producer"
    )
    all_tasks.append(producer_task)

    print(f"Main: Creating {num_consumers} consumer tasks...")
    for i in range(num_consumers):
        consumer_task = asyncio.create_task(
            process_items_from_queue(work_queue, i + 1),
            name=f"Consumer-{i+1}"
        )
        all_tasks.append(consumer_task)

    # Wait for the producer to finish loading (optional, but good to ensure all items are queued)
    await producer_task

    # --- Run Tasks and Handle Completion/Failure ---
    done, pending = set(), set()
    try:
        # Use asyncio.wait instead of gather to have more control over pending tasks on error.
        # FIRST_EXCEPTION returns as soon as any task raises; otherwise it waits until ALL
        # tasks finish. (FIRST_COMPLETED would return immediately here, because the
        # producer task has already completed.)
        print("Main: Waiting for tasks to complete...")
        done, pending = await asyncio.wait(all_tasks, return_when=asyncio.FIRST_EXCEPTION)

        # Check if any completed tasks raised an exception
        for task in done:
            if task.exception():
                raise task.exception() # Raise the exception from the failed task

    except Exception as e:
        logging.error(f"Main: An error occurred in a task: {e}", exc_info=True)
        print("Main: Attempting to cancel pending tasks...")
        for task in pending: # Cancel tasks found in the pending set from asyncio.wait
            task.cancel()

        # Give cancelled tasks a moment to process the cancellation
        # and collect any CancelledError exceptions (optional but cleaner)
        if pending:
            await asyncio.wait(pending, timeout=1.0) # Wait briefly

        # Important: Re-raise the original exception to stop execution
        raise e # Propagate the error out

    finally:
        # Ensure all tasks are truly finished one way or another (optional cleanup)
        remaining_tasks = [t for t in all_tasks if not t.done()]
        if remaining_tasks:
             print("Main: Waiting for final cleanup of any remaining tasks...")
             await asyncio.wait(remaining_tasks, timeout=1.0) # Brief wait

    # If execution reaches here, it means all tasks finished without unhandled exceptions propagating
    print("Main: Process finished.")

# This is a very verbose process. Changing the logging level to WARNING
logging.basicConfig(level=logging.WARNING, stream=sys.stdout, format='%(asctime)s[%(levelname)5s][%(name)14s] - %(message)s',  datefmt='%H:%M:%S', force=True)

# Kick off parallel image creation
await process_dataframe_concurrently(products_df)

# Switch logging back to INFO
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s[%(levelname)5s][%(name)14s] - %(message)s',  datefmt='%H:%M:%S', force=True)
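
`asyncio.wait()` decides when to return based on `return_when`. `FIRST_COMPLETED` returns as soon as any task has finished. `FIRST_EXCEPTION` returns when a task raises, and otherwise behaves like `ALL_COMPLETED`. Because the orchestration above awaits the producer task before calling `asyncio.wait()`, that already-finished task is enough to satisfy `FIRST_COMPLETED` immediately. The hypothetical `demo()` below reproduces the difference with two toy tasks.

```python
import asyncio


async def demo(return_when) -> tuple[int, int]:
    """Return (len(done), len(pending)) for a finished 'producer' plus a slow 'consumer'."""
    async def quick() -> str:
        return "producer"

    async def slow() -> str:
        await asyncio.sleep(0.05)
        return "consumer"

    producer = asyncio.create_task(quick())
    await producer  # mirrors awaiting the producer before asyncio.wait()
    consumer = asyncio.create_task(slow())
    done, pending = await asyncio.wait({producer, consumer}, return_when=return_when)

    # Clean up anything still running so the demo leaves no stray tasks
    for task in pending:
        task.cancel()
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)
    return len(done), len(pending)
```

In a notebook cell, `await demo(asyncio.FIRST_COMPLETED)` gives `(1, 1)`: the consumer is still pending. `await demo(asyncio.FIRST_EXCEPTION)` gives `(2, 0)`. In a script, use `asyncio.run(demo(...))` instead.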


### View Image Generation Success Rate

In [None]:
sql = """WITH a AS (
  SELECT
    (SELECT COUNT(*)  FROM products WHERE product_image_uri IS NOT NULL) AS processed,
    (SELECT COUNT(*) FROM products WHERE product_image_uri LIKE 'FAILED%') AS failed
) SELECT processed,
  processed - failed AS successful,
  failed,
  1-(failed/processed::FLOAT) AS success_rate
  FROM a;"""
await run_query(ecom_db_pool, sql)

## Clean Up Image Data

### Set Image to "Coming Soon" for NULL and Failed Images

In [None]:
coming_soon_uri = f"gs://{image_bucket}/product-images/coming_soon.png"

sql = "SELECT COUNT(*) FROM products WHERE product_image_uri IS NULL OR product_image_uri LIKE 'FAILED%';"
result = await run_query(ecom_db_pool, sql)
print(result)

sql = f"""
  UPDATE products
  SET product_image_uri = '{coming_soon_uri}'
  WHERE product_image_uri IS NULL
  OR product_image_uri LIKE 'FAILED%';"""

result = await run_query(ecom_db_pool, sql)
result


## Generate Multimodal Embeddings

### Add Embedding Columns

In [None]:
sql_array = []
sql_array.append("ALTER TABLE products ADD COLUMN product_description_embedding VECTOR(1408);")
sql_array.append("ALTER TABLE products ADD COLUMN product_description_embedding_model TEXT;")
sql_array.append("ALTER TABLE products ADD COLUMN product_image_embedding VECTOR(1408);")
sql_array.append("ALTER TABLE products ADD COLUMN product_image_embedding_model TEXT;")
for sql in sql_array:
  await run_query(ecom_db_pool, sql)
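
The embedding step below writes vectors by interpolating the Python list into SQL (`'{result.text_embedding}'`). That relies on pgvector's text input format, `[x1,x2,...]`. A `VECTOR(1408)` column rejects values with any other number of dimensions, and pgvector does not accept NaN or infinite elements. The hypothetical helper below formats a list the same way and checks both conditions up front, which can make a bad payload easier to spot than a database error.

```python
import math
from typing import Sequence


def to_pgvector_literal(values: Sequence[float], expected_dim: int = 1408) -> str:
    """Format floats as pgvector's text representation, e.g. '[0.5,1.0]' (hypothetical helper)."""
    if len(values) != expected_dim:
        raise ValueError(f"expected {expected_dim} dimensions, got {len(values)}")
    if not all(math.isfinite(v) for v in values):
        raise ValueError("pgvector does not accept NaN or infinite values")
    return "[" + ",".join(repr(float(v)) for v in values) + "]"


print(to_pgvector_literal([0.5, 1], expected_dim=2))  # [0.5,1.0]
```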

### Get Products to Embed

In [None]:
sql = f"""SELECT id,
            name,
            brand,
            category,
            department,
            retail_price,
            sku,
            product_description,
            product_image_uri
         FROM products
         WHERE product_description_embedding IS NULL
         AND product_image_uri != 'gs://{image_bucket}/product-images/coming_soon.png'
         LIMIT 30000;"""
products_df = await run_query(ecom_db_pool, sql)
products_df

### Generate Multimodal Embeddings Asynchronously

In [None]:
import asyncio
import vertexai
from vertexai.vision_models import Image, MultiModalEmbeddingModel


async def load_queue_from_dataframe(df: pd.DataFrame, queue: asyncio.Queue, num_consumers: int):
    """
    Iterates through DataFrame rows and puts them into the asyncio queue.

    Args:
        df: The Pandas DataFrame to process.
        queue: The asyncio.Queue to put items into.
        num_consumers: Number of consumers; one None sentinel is queued for each.
    """
    print(f"Producer: Starting to load {len(df)} items into the queue...")
    # Use itertuples for efficiency. index=False avoids adding the DataFrame index.
    # name='Product' names the generated namedtuple class (the default would be 'Pandas').
    for row_tuple in df.itertuples(index=False, name='Product'):
        # Convert the named tuple to a dictionary - often easier for consumers
        item = row_tuple._asdict()
        await queue.put(item)
        logging.info(f"Producer: Put item {item.get('sku', 'N/A')} into queue. Queue size: {queue.qsize()}")
    logging.info("Producer: Finished loading all items.")

    # Add one sentinel value (e.g., None) per consumer to signal completion
    for _ in range(num_consumers):
        await queue.put(None)
    print(f"Producer: Added {num_consumers} sentinel(s) to queue.")


async def process_items_from_queue(queue: asyncio.Queue, worker_id: int):
    """
    Continuously gets items from the queue and processes them until a sentinel is received.
    Failures are logged rather than propagated, so one bad item does not stop the worker.
    """
    logging.info(f"Consumer {worker_id}: Started...")
    while True:
        item = await queue.get()

        # --- Check for Sentinel ---
        if item is None:
            print(f"Consumer {worker_id}: Sentinel received. Exiting.")
            queue.task_done() # Mark sentinel processing as done
            break # Exit the loop

        # --- Process the item ---

        try:
            log_prefix = f"Consumer {worker_id}: Item ID {item.get('id', 'N/A')}:"
            logging.info(f"{log_prefix} Starting processing.")

            # Get current running loop
            loop = asyncio.get_running_loop()

            # Define variables
            prompt = f"{item.get('name')} {item.get('product_description')}"
            prompt = prompt.replace("'","")
            uri = item.get('product_image_uri')
            if uri and uri.startswith('FAILED'):
                uri = None  # Failed image generations have no image to embed

            # Run the blocking embedding generation function
            result = await loop.run_in_executor(
                None, # Use default executor (ThreadPoolExecutor)

                # --- Define the io blocking function to run ---
                generate_multimodal_embeddings,

                # --- Add function parameters here ---
                uri,      # First argument for generate_multimodal_embeddings
                prompt,   # Second argument for generate_multimodal_embeddings, etc
            )

            # Handle embedding generation failure (if it returns None instead of raising)
            if result is None:
                # Log warning and continue (skip DB update for this item)
                logging.warning(f"{log_prefix} Embedding generation failed or returned None.")
            else:
                # Build the UPDATE statement
                sql = f"""UPDATE products
                    SET product_description_embedding = '{result.text_embedding}',
                        product_description_embedding_model = 'multimodalembedding@001'"""

                if result.image_embedding:
                    sql = sql + f""",
                        product_image_embedding = '{result.image_embedding}',
                        product_image_embedding_model = 'multimodalembedding@001'"""
                else:
                    logging.info("No image embedding received in payload.")

                # Run the database UPDATE
                sql = sql + f" WHERE id = {item.get('id')};"
                await run_query(ecom_db_pool, sql)

        except Exception as e:
            # Log the exception with its traceback, but don't re-raise,
            # so that the remaining items can still be processed.
            logging.error(f"Consumer {worker_id}: Unhandled exception processing item ID {item.get('id', 'N/A')}: {e}", exc_info=True)

        finally:
            # Mark the item done whether it succeeded or failed, so that
            # queue.join() (if used) can never hang on a failed item.
            queue.task_done()


async def process_dataframe_concurrently(products_df):
    # Create the queue with maxsize. If maxsize is reached, the producer's
    # `await queue.put(item)` will pause until a consumer calls `queue.get()`, providing backpressure.

    # --- Set queue max size and number of consumers/workers here ---
    work_queue = asyncio.Queue(maxsize=800)
    num_consumers = 100
    all_tasks = []

    # --- Create Tasks ---
    print("Main: Creating producer task...")
    producer_task = asyncio.create_task(
        load_queue_from_dataframe(products_df, work_queue, num_consumers),
        name="Producer"
    )
    all_tasks.append(producer_task)

    print(f"Main: Creating {num_consumers} consumer tasks...")
    for i in range(num_consumers):
        consumer_task = asyncio.create_task(
            process_items_from_queue(work_queue, i + 1),
            name=f"Consumer-{i+1}"
        )
        all_tasks.append(consumer_task)

    # Wait for the producer to finish loading (optional, but good to ensure all items are queued)
    await producer_task

    # --- Run Tasks and Handle Completion/Failure ---
    done, pending = set(), set()
    try:
        # Use asyncio.wait instead of gather to have more control over pending tasks on error.
        # FIRST_EXCEPTION returns as soon as any task raises; otherwise it waits until ALL
        # tasks finish. (FIRST_COMPLETED would return immediately here, because the
        # producer task has already completed.)
        print("Main: Waiting for tasks to complete...")
        done, pending = await asyncio.wait(all_tasks, return_when=asyncio.FIRST_EXCEPTION)

        # Check if any completed tasks raised an exception
        for task in done:
            if task.exception():
                raise task.exception() # Raise the exception from the failed task

    except Exception as e:
        logging.error(f"Main: An error occurred in a task: {e}", exc_info=True)
        print("Main: Attempting to cancel pending tasks...")
        for task in pending: # Cancel tasks found in the pending set from asyncio.wait
            task.cancel()

        # Give cancelled tasks a moment to process the cancellation
        # and collect any CancelledError exceptions (optional but cleaner)
        if pending:
            await asyncio.wait(pending, timeout=1.0) # Wait briefly

        # Important: Re-raise the original exception to stop execution
        raise e # Propagate the error out

    finally:
        # Ensure all tasks are truly finished one way or another (optional cleanup)
        remaining_tasks = [t for t in all_tasks if not t.done()]
        if remaining_tasks:
             print("Main: Waiting for final cleanup of any remaining tasks...")
             await asyncio.wait(remaining_tasks, timeout=1.0) # Brief wait

    # If execution reaches here, it means all tasks finished without unhandled exceptions propagating
    print("Main: Process finished.")

# This is a very verbose process. Changing the logging level to WARNING
logging.basicConfig(level=logging.WARNING, stream=sys.stdout, format='%(asctime)s[%(levelname)5s][%(name)14s] - %(message)s',  datefmt='%H:%M:%S', force=True)

# Kick off parallel multimodal embedding generation
await process_dataframe_concurrently(products_df)

# Switch logging back to INFO
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s[%(levelname)5s][%(name)14s] - %(message)s',  datefmt='%H:%M:%S', force=True)



### Verify Embedding Counts

In [None]:
sql = f"""SELECT 'product_description' AS column_name, (SELECT COUNT(*) FROM products WHERE product_description IS NULL) AS null_count
UNION ALL
SELECT 'product_description_embedding', (SELECT COUNT(*) FROM products WHERE product_description_embedding IS NULL AND product_image_uri != 'gs://{image_bucket}/product-images/coming_soon.png')
UNION ALL
SELECT 'product_image_uri', (SELECT COUNT(*) FROM products WHERE product_image_uri IS NULL OR product_image_uri LIKE 'FAILED%')
UNION ALL
SELECT 'product_image_embedding', (SELECT COUNT(*) FROM products WHERE product_image_embedding IS NULL AND product_image_uri != 'gs://{image_bucket}/product-images/coming_soon.png');"""

await run_query(ecom_db_pool, sql)

## Generate Semantic Embeddings

### Replace Embedding Columns

In [None]:
sql_array = []
sql_array.append("ALTER TABLE products DROP COLUMN embedding;")
sql_array.append("ALTER TABLE products DROP COLUMN embedding_model_version;")
sql_array.append("ALTER TABLE products ADD COLUMN product_embedding VECTOR(3072);")
sql_array.append("ALTER TABLE products ADD COLUMN product_embedding_model TEXT;")
for sql in sql_array:
  await run_query(ecom_db_pool, sql)

### Embed Products Asynchronously
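
The cell below launches every call in a batch at once with `asyncio.gather()`. `gather()` returns results in the same order as the input coroutines, even when they finish out of order, which is why the results can be assigned straight back to a DataFrame column. If a batch ever outruns your embedding quota, one common option is to cap in-flight calls with an `asyncio.Semaphore`. The sketch below uses hypothetical `fake_embed` and `embed_all` coroutines, with `fake_embed` standing in for `async_generate_embedding`, and reports the peak concurrency so the cap can be checked.

```python
import asyncio
import random
from typing import List, Tuple


async def fake_embed(text: str) -> List[float]:
    """Hypothetical stand-in for async_generate_embedding(); finishes in random order."""
    await asyncio.sleep(random.uniform(0, 0.01))
    return [float(len(text))]


async def embed_all(texts: List[str], max_in_flight: int = 5) -> Tuple[List[List[float]], int]:
    """Embed texts concurrently with at most max_in_flight calls running at once.

    Returns the embeddings (in input order) and the peak concurrency observed.
    """
    semaphore = asyncio.Semaphore(max_in_flight)
    in_flight = 0
    peak = 0

    async def bounded(text: str) -> List[float]:
        nonlocal in_flight, peak
        async with semaphore:  # waits while max_in_flight calls are already running
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                return await fake_embed(text)
            finally:
                in_flight -= 1

    # gather() returns results in the order the coroutines were passed in,
    # regardless of which call finishes first
    results = await asyncio.gather(*(bounded(t) for t in texts))
    return list(results), peak
```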

In [None]:
import asyncio
import math

# Update logging to level=WARN due to verbose output
logging.basicConfig(level=logging.WARN, stream=sys.stdout, format='%(asctime)s[%(levelname)5s][%(name)14s] - %(message)s',  datefmt='%H:%M:%S', force=True)

async def apply_async_func_to_dataframe(df, func, column_to_embed, new_column_name):
    """
    Applies an asynchronous function to a specified column of a DataFrame
    and stores the results in a new column.
    """
    # Create a list of coroutine objects
    tasks = [func(row[column_to_embed]) for index, row in df.iterrows()]

    # Run all coroutines concurrently
    results = await asyncio.gather(*tasks)

    # Assign the results back to the DataFrame
    df[new_column_name] = results
    return df

async def main():
    # Get count of rows to embed
    sql = """
      SELECT COUNT(*)
      FROM products
      WHERE product_embedding IS NULL;
      """
    row_count = await run_query(ecom_db_pool, sql, output_as_df=False)
    rows_to_embed = row_count.fetchall()[0][0]
    print(f"Rows to embed: {rows_to_embed}")

    # Generate embeddings in batches
    batch_size = 100
    batch_count = math.ceil(rows_to_embed / batch_size)
    current_batch = 1

    while current_batch <= batch_count:
      # Define the dataframe
      sql = f"""
        SELECT id, CONCAT('Name: ', name, ' \\nCategory: ', category, ' \\nBrand: ', brand, ' \\nDepartment: ', department) AS embed_string
        FROM products
        WHERE product_embedding IS NULL
        LIMIT {batch_size};
        """
      semantic_embedding_df = await run_query(ecom_db_pool, sql)

      print(f"Batch {current_batch} of {batch_count}: Running async embedding job...")

      # Apply the async function
      semantic_embedding_df = await apply_async_func_to_dataframe(
          semantic_embedding_df,
          async_generate_embedding,
          'embed_string',
          'product_embedding'
      )

      # Update the database for the current batch
      # Use bulk executemany for more efficient update
      print(f"Batch {current_batch} of {batch_count}: Updating database with embeddings...")
      data_to_update = [
          {"product_embedding": str(row.product_embedding), "id": row.id}
          for row in semantic_embedding_df.itertuples(index=False) # Use index=False since id is a column, not a df index
      ]

      sql = f"""
      UPDATE products
      SET product_embedding = :product_embedding,
          product_embedding_model = 'gemini-embedding-001'
      WHERE id = :id;
      """
      await run_query(ecom_db_pool, sql, data_to_update)

      current_batch = current_batch + 1

# Jupyter already runs an event loop, so the coroutine can be awaited directly at the top level.
# (In a standalone script, use asyncio.run(main()) instead.)
await main()


# Restore logging to level=INFO
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s[%(levelname)5s][%(name)14s] - %(message)s',  datefmt='%H:%M:%S', force=True)
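`apply_async_func_to_dataframe` starts every request in a batch at once with `asyncio.gather`. If the embedding API starts returning quota or rate-limit errors, one option is to cap concurrency with `asyncio.Semaphore`. The sketch below is a minimal, hedged example of that pattern; `gather_with_limit`, `_fake_embed`, and `demo` are hypothetical names, and `_fake_embed` only stands in for a real embedding call.

```python
import asyncio

async def gather_with_limit(coroutines, limit: int = 10):
    """Run coroutines concurrently, at most `limit` at a time, preserving input order."""
    semaphore = asyncio.Semaphore(limit)

    async def run(coro):
        async with semaphore:  # wait here while `limit` calls are already in flight
            return await coro

    return await asyncio.gather(*(run(c) for c in coroutines))

async def _fake_embed(text: str, state: dict) -> int:
    """Hypothetical stand-in for a remote embedding call; tracks peak concurrency."""
    state["active"] += 1
    state["peak"] = max(state["peak"], state["active"])
    await asyncio.sleep(0.01)
    state["active"] -= 1
    return len(text)

async def demo(limit: int = 3):
    state = {"active": 0, "peak": 0}
    texts = [f"product {i}" for i in range(10)]
    results = await gather_with_limit([_fake_embed(t, state) for t in texts], limit)
    return results, state["peak"]
```

In the notebook you would `await gather_with_limit(tasks, limit=10)` inside `apply_async_func_to_dataframe` instead of `await asyncio.gather(*tasks)`; the limit value is an assumption to tune against your quota.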


### Ensure All Rows Have Embeddings

In [None]:
sql = "SELECT COUNT(*) FROM products WHERE product_embedding IS NULL"
await run_query(ecom_db_pool, sql)

## Generate BM25 Sparse Embeddings

### Create or Instantiate Bucket for BM25 Index Data

In [None]:
# Create index_bucket if none is provided
import uuid
from google.cloud import storage

storage_client = storage.Client()
random_suffix = uuid.uuid4().hex[:6]  # Get a 6-character hexadecimal suffix
index_bucket_name = f"bm25-index-{random_suffix}"

if not index_bucket:
    index_bucket = storage_client.create_bucket(index_bucket_name)
    print(f"Bucket {index_bucket.name} created.")
else:
    print(f"Using provided index_bucket: {index_bucket}")
    index_bucket = storage_client.bucket(index_bucket)

### Define Column Weights for BM25 Index

This is a brute-force way to give one column more weight than another: repeat it several times in the BM25 embedding input so that its terms occur more often. More sophisticated methods exist, but this one is simple and gets the job done.
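To see what repetition does to term frequencies, here is a minimal Python sketch that mirrors the `REPEAT(COALESCE(col,'') || ' ', weight)` pattern for a single product and counts whitespace-separated terms. The product values and the `build_weighted_content` helper are hypothetical, and the real BM25 analyzer also stems and removes stop words.

```python
from collections import Counter

# Same default weights as the BM25 input query in this notebook
WEIGHTS = {"name": 5, "brand": 3, "category": 2, "department": 2,
           "product_description": 1, "sku": 1}

def build_weighted_content(product: dict, weights: dict = WEIGHTS) -> str:
    """Repeat each column `weight` times, like REPEAT(COALESCE(col,'') || ' ', weight)."""
    parts = []
    for column, weight in weights.items():
        value = product.get(column) or ""  # COALESCE(col, '')
        parts.append((value + " ") * weight)
    return "".join(parts)

example_product = {"name": "Aviator Sunglasses", "brand": "Acme", "category": "Accessories",
                   "department": "Men", "product_description": "Classic metal frame", "sku": None}
term_counts = Counter(build_weighted_content(example_product).lower().split())
print(term_counts["sunglasses"], term_counts["accessories"], term_counts["classic"])  # 5 2 1
```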

In [None]:
# Higher weights give more priority to the column
name_weight = 5
brand_weight = 3
category_weight = 2
department_weight = 2
product_description_weight = 1
sku_weight = 1
bm25_index_sql = f"""SELECT id,
    REPEAT(COALESCE(name,'') || ' ', {name_weight}) ||
    REPEAT(COALESCE(brand,'') || ' ', {brand_weight}) ||
    REPEAT(COALESCE(category,'') || ' ', {category_weight}) ||
    REPEAT(COALESCE(department,'') || ' ', {department_weight}) ||
    REPEAT(COALESCE(product_description,'') || ' ', {product_description_weight}) ||
    REPEAT(COALESCE(sku,'') || ' ', {sku_weight}) AS content
    FROM products;"""

### Build BM25 Index from AlloyDB Content

> NOTE: In this example, we run `bm25_ef.fit()` in memory and serve the model locally because the product catalog is relatively small (29,120 items). For very large product catalogs, this operation may exceed local memory. In such cases, you might need to rewrite this step in Spark (or similar) and potentially host the model on a dedicated endpoint.

In [None]:
# Reference: https://milvus.io/api-reference/pymilvus/v2.4.x/EmbeddingModels/BM25EmbeddingFunction/BM25EmbeddingFunction.md

from pymilvus.model.sparse import BM25EmbeddingFunction
from pymilvus.model.sparse.bm25.tokenizers import build_default_analyzer

# Use the default English analyzer
analyzer = build_default_analyzer(language="en")

# Instantiate BM25 model
bm25_ef = BM25EmbeddingFunction(
    analyzer = analyzer,
    k1 = 1.5, # Controls term-frequency saturation
    b = 0.75, # Controls document length normalization
    epsilon = 0.25 # This is used to smooth idf values
)

# Get text content from AlloyDB
docs = await run_query(ecom_db_pool, bm25_index_sql)
# Strip apostrophes from the content (DataFrame.replace() alone only matches cells equal to "'")
docs['content'] = docs['content'].str.replace("'", "", regex=False)

# Fit the model to the AlloyDB content
bm25_ef.fit(docs['content'].to_list())


### Upload Model Parameters to GCS

In [None]:
import os

# Store the fitted parameters to expedite future processing.
bm25_params_file_name = "bm25_params.json"
bm25_ef.save(bm25_params_file_name)

# Upload saved model parameters to GCS
current_directory = os.getcwd()
blob_path = f"bm25_index/{bm25_params_file_name}"
file_name = f"{current_directory}/{bm25_params_file_name}"
blob = index_bucket.blob(blob_path)
blob.upload_from_filename(file_name)
print(f"Uploaded model parameters from: {file_name} to gs://{index_bucket.name}/{blob_path}")

### Download Model Parameters from GCS

In [None]:
# Download model parameters (update blob to use your own file)
blob = index_bucket.blob(f"bm25_index/{bm25_params_file_name}")
blob.download_to_filename(bm25_params_file_name)

# Load the saved params (optionally, point to your own saved parameters file)
bm25_ef = BM25EmbeddingFunction()
bm25_ef.load(bm25_params_file_name)

# Print out the max dims:
max_1_based_dims = bm25_ef.dim + 1
print(f"Max dimensionality: {max_1_based_dims}")

### Define Helper Function for `sparsevec` Encoding
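pgvector's `sparsevec` text format is `{index:value,...}/dimensions`, and its indices start at 1 like SQL arrays, while the BM25 model's indices are 0-based. The minimal sketch below shows only that formatting step, independent of the BM25 model; `to_sparsevec_literal` is a hypothetical helper, and sorting by index is done here just for readability.

```python
def to_sparsevec_literal(indices, values, dimensions: int) -> str:
    """Format 0-based indices and their values as a pgvector sparsevec literal."""
    if len(indices) != len(values):
        raise ValueError("indices and values must have the same length")
    # Shift each 0-based index by one because sparsevec indices are 1-based
    pairs = [f"{i + 1}:{v:.7g}" for i, v in sorted(zip(indices, values))]
    return "{" + ",".join(pairs) + "}/" + str(dimensions)

print(to_sparsevec_literal([4, 0], [0.25, 1.5], dimensions=10))  # {1:1.5,5:0.25}/10
```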

In [None]:
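def encode_sparsevec(text: str, dimensions: int = max_1_based_dims, is_document: bool = False) -> str:
    """Encode text as a pgvector sparsevec literal such as '{3:0.51,17:1.2}/N'.

    Documents use BM25 document weights (term-frequency saturation and length
    normalization); queries use IDF weights, so their inner product corresponds to the BM25 score.
    """
    # Generate the sparse embeddings
    encode = bm25_ef.encode_documents if is_document else bm25_ef.encode_queries
    sparse_embeddings = encode([text])
    lil = sparse_embeddings.tolil(copy=False)
    sparse_scores, sparse_indices = lil.data.tolist()[0], lil.rows.tolist()[0]

    # Ensure sparse_scores and sparse_indices lists are the same length
    assert len(sparse_scores) == len(sparse_indices)

    # sparsevec data type is 1-based. sparse_indices are zero-based.
    sparse_indices = [x + 1 for x in sparse_indices]

    # Zip results and transform to expected format for pgvector sparsevec type
    result = [f"{key}:{value:.7g}" for key, value in zip(sparse_indices, sparse_scores)]
    return f"{{{','.join(result)}}}/{dimensions}"

### Create Sparse Embeddings for Content

In [None]:
# Stored product vectors use BM25 document weights; search queries are encoded with encode_sparsevec(query)
docs['sparse_embedding'] = docs['content'].apply(lambda text: encode_sparsevec(text, is_document=True))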
docs

### Add Sparse Embedding Columns to AlloyDB

In [None]:
sql_array = []
sql_array.append(f"ALTER TABLE products DROP COLUMN IF EXISTS sparse_embedding;")
sql_array.append(f"ALTER TABLE products ADD COLUMN sparse_embedding sparsevec({bm25_ef.dim + 1});")
sql_array.append(f"ALTER TABLE products DROP COLUMN IF EXISTS sparse_embedding_model;")
sql_array.append(f"ALTER TABLE products ADD COLUMN sparse_embedding_model TEXT;")
for sql in sql_array:
  await run_query(ecom_db_pool, sql)

### Update AlloyDB with Sparse Embeddings

In [None]:
# Use bulk executemany for more efficient update
data_to_update = [
    {"sparse_embedding": row.sparse_embedding, "id": row.id}
    for row in docs.itertuples(index=False) # Use index=False since id is a column, not a df index
]

sql = f"""
UPDATE products
SET sparse_embedding = :sparse_embedding,
    sparse_embedding_model = 'BM25'
WHERE id = :id;
"""

await run_query(ecom_db_pool, sql, data_to_update)

## Create Weighted Full-text Search Column

Columns weighted `A` have more weight than columns weighted `B`, which have more weight than columns weighted `C`, etc.
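By default, `ts_rank` multiplies matches by `{0.1, 0.2, 0.4, 1.0}` for labels D, C, B, and A, and it accepts an optional weights array (in D, C, B, A order) as its first argument. The hedged sketch below builds such an expression so the label weights can be tuned from Python; `ts_rank_expression` is a hypothetical helper, and it assumes the query text is bound as a `:query` parameter rather than interpolated into the SQL.

```python
DEFAULT_TS_RANK_WEIGHTS = {"D": 0.1, "C": 0.2, "B": 0.4, "A": 1.0}

def ts_rank_expression(weights: dict = DEFAULT_TS_RANK_WEIGHTS,
                       document: str = "fts_document",
                       query_fn: str = "plainto_tsquery") -> str:
    """Build a ts_rank(...) SQL expression with an explicit weights array."""
    # ts_rank expects the weights array ordered D, C, B, A
    ordered = [weights[label] for label in ("D", "C", "B", "A")]
    array_literal = "{" + ", ".join(str(w) for w in ordered) + "}"
    return f"ts_rank('{array_literal}', {document}, {query_fn}('english', :query))"

print(ts_rank_expression())
# ts_rank('{0.1, 0.2, 0.4, 1.0}', fts_document, plainto_tsquery('english', :query))
```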

In [None]:
sql_array = []
sql_array.append(f"ALTER TABLE products DROP COLUMN IF EXISTS fts_document;")
sql_array.append(f"""ALTER TABLE products ADD COLUMN fts_document tsvector GENERATED ALWAYS AS (
      setweight(to_tsvector('english', coalesce(name, '')), 'A') ||
      setweight(to_tsvector('english', coalesce(brand, '')), 'B') ||
      setweight(to_tsvector('english', coalesce(category, '')), 'C') ||
      setweight(to_tsvector('english', coalesce(department, '')), 'C') ||
      setweight(to_tsvector('english', coalesce(product_description, '')), 'D') ||
      setweight(to_tsvector('english', coalesce(sku, '')), 'D')
    ) STORED;""")
for sql in sql_array:
  await run_query(ecom_db_pool, sql)

## Create Indexes

### Create Standard PostgreSQL Indexes for Efficient Facet Searches

In [None]:
sql_array = []

sql_array.append("DROP INDEX IF EXISTS idx_products_brand;")
sql_array.append("CREATE INDEX idx_products_brand ON products (brand);")
sql_array.append("DROP INDEX IF EXISTS idx_products_category;")
sql_array.append("CREATE INDEX idx_products_category ON products (category);")
sql_array.append("DROP INDEX IF EXISTS idx_products_retail_price;")
sql_array.append("CREATE INDEX idx_products_retail_price ON products (retail_price);")
sql_array.append("DROP INDEX IF EXISTS idx_products_sku;")
sql_array.append("CREATE INDEX idx_products_sku ON products (sku);")

for sql in sql_array:
  await run_query(ecom_db_pool, sql)

### Create ScaNN Indexes for Efficient ANN Vector Search

In [None]:
sql_array = []

sql_array.append("CREATE EXTENSION IF NOT EXISTS alloydb_scann")
sql_array.append("SET SESSION scann.num_leaves_to_search = 1")
sql_array.append("SET SESSION scann.pre_reordering_num_neighbors=50")
sql_array.append("DROP INDEX IF EXISTS embedding_scann")
sql_array.append("""
CREATE INDEX embedding_scann ON products
  USING scann (product_embedding cosine)
  WITH (num_leaves=2);
""")
sql_array.append("DROP INDEX IF EXISTS product_description_embedding_scann")
sql_array.append("""
CREATE INDEX product_description_embedding_scann ON products
  USING scann (product_description_embedding cosine)
  WITH (num_leaves=2);
""")
sql_array.append("DROP INDEX IF EXISTS product_image_embedding_scann")
sql_array.append("""
CREATE INDEX product_image_embedding_scann ON products
  USING scann (product_image_embedding cosine)
  WITH (num_leaves=2);
""")

for sql in sql_array:
  await run_query(ecom_db_pool, sql)


### Create HNSW Index for Efficient BM25 ANN Sparsevec Search

In [None]:
sql_array = []
sql_array.append("DROP INDEX IF EXISTS sparse_embedding_hnsw")
sql_array.append("""
CREATE INDEX sparse_embedding_hnsw ON products
  USING hnsw (sparse_embedding sparsevec_ip_ops)
  WITH (m = 16, ef_construction = 64);
""")
sql_array.append("SET hnsw.ef_search = 100;") # This is necessary for better recall with sparsevec

for sql in sql_array:
  await run_query(ecom_db_pool, sql)

### Create GIN Index for Efficient Full-text Search

In [None]:
# Ref: https://www.postgresql.org/docs/current/gin.html
sql_array = []
sql_array.append("DROP INDEX IF EXISTS products_fts_document_gin;")
sql_array.append("CREATE INDEX products_fts_document_gin ON products USING GIN (fts_document);")

for sql in sql_array:
  await run_query(ecom_db_pool, sql)

### Analyze `products` Table

Collect statistics about the contents of the table now that we've changed it substantially.

In [None]:
sql = "ANALYZE products;"
await run_query(ecom_db_pool, sql)

## Run Advanced Product Search

### View the Raw Data

In [None]:
sql = "SELECT * FROM products LIMIT 5;"
await run_query(ecom_db_pool, sql)

### Search with Traditional SQL

Traditional SQL remains one of the most popular and prevalent query languages among Business Analysts, Developers, DBAs, and Data Scientists due to its simplicity, expressiveness, and broad applicability. However, it is not usually the best choice for Product Search for many reasons, including:
- Inefficient query patterns (e.g. searching multiple fields with ILIKE '%keyword%') will result in full table scans.
- Intolerant to typos, synonyms, spelling variants, etc.
- No relevance ranking (keyword proximity, field weighting, popularity, sales data, recency, user reviews, personalization signals) - result order is arbitrary.
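The `ILIKE '%keyword%'` pattern is very literal: it matches only a contiguous, case-insensitive substring. The minimal Python sketch below emulates that for one hypothetical product name and contrasts it with matching each query term independently, which is closer to full-text search (but without its stemming and stop-word handling). Both helper names are hypothetical.

```python
def ilike_contains(text: str, query: str) -> bool:
    """Emulate `text ILIKE '%query%'` for a query containing no % or _ wildcards."""
    return query.lower() in text.lower()

def all_terms_match(text: str, query: str) -> bool:
    """Match when every whitespace-separated query term appears somewhere in the text."""
    words = text.lower().split()
    return all(term in words for term in query.lower().split())

example_name = "Polarized Sunglasses - Silver Frame"
print(ilike_contains(example_name, "sunglasses"))          # True
print(ilike_contains(example_name, "silver sunglasses"))   # False: not a contiguous substring
print(all_terms_match(example_name, "silver sunglasses"))  # True
print(ilike_contains(example_name, "sunglases"))           # False: typo
```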



In [None]:
# --- Uncomment one query at a time ---
query = "sunglasses" # Simple keyword search returns relevant but arbitrarily ranked and inefficient results
#query = "silver sunglasses" # Adding color returns NO results
#query = "sunglases" # Typo returns NO results
#query = "81A6F51D90AF2C00DFC715C5DC5FE88D" # Search by SKU returns the specific item, but inefficiently

# --- Set explain=True to view query performance ---
explain = False

sql = f"""
WITH trad_sql AS (
    SELECT
      ROW_NUMBER () OVER (ORDER BY name) AS trad_sql_rank,
      name,
      product_image_uri,
      brand,
      product_description,
      category,
      department,
      cost,
      retail_price,
      sku
    FROM products
    WHERE name ILIKE '%{query}%'
      OR sku ILIKE '%{query}%'
      OR category ILIKE '%{query}%'
      OR brand ILIKE '%{query}%'
      OR department ILIKE '%{query}%'
      OR product_description ILIKE '%{query}%'
    ORDER BY name
    LIMIT 10
) SELECT * FROM trad_sql;
"""

if explain:
  sql = 'EXPLAIN ANALYZE ' + sql

result = await run_query(ecom_db_pool, sql)
try:
  html_output = HTML(result.to_html(escape=False, formatters={'product_image_uri': df_image_formatter}))
  display(html_output)
except:
  print(result)

### Search with Full-text Search

Full-text Search in PostgreSQL provides the capability to identify natural-language documents that satisfy a query, and optionally to sort them by relevance to the query. The simplest search considers the query as a set of words and similarity as the frequency of query words in the document.

Full-text search provides several advantages over traditional SQL, including:
- Significant performance improvement with the `tsvector` data type (storing lexemes) and GIN indexes compared to the `ILIKE '%keyword%'` syntax.
- Linguistic processing during indexing and querying, including stemming, stop word removal, and case normalization.
- Relevance ranking with `ts_rank()` (based on term frequency) and `ts_rank_cd()` (based on term frequency and proximity of matching lexemes).
- Weighted search, allowing you to prioritize matches found in one column (e.g. `name` or `sku`) over another (e.g. `product_description`).

Weaknesses:
- Matches on lexical similarity, but does not understand semantic similarity.

References:
- https://www.postgresql.org/docs/current/textsearch-intro.html
- https://www.postgresql.org/docs/current/functions-textsearch.html
- https://www.postgresql.org/docs/current/textsearch-controls.html

In [None]:
# --- Choose appropriate query function ---
# plainto_tsquery: treats input as space-separated keywords, ANDs them. Good start.
# websearch_to_tsquery: More flexible, handles quotes for phrases, OR, '-' for negation. Often better for user input.
# phraseto_tsquery: Treats entire input as a phrase.
# We'll use plainto_tsquery here, adjust as needed.

fts_query_function = 'plainto_tsquery'

# --- Uncomment one query at a time ---
query = "sunglasses" # Simple keyword search returns relevant, ranked results ~200x faster than Traditional SQL
#query = "silver sunglasses" # Adding color returns relevant, ranked, efficient results
#query = "sunglases" # Typo returns NO results
#query = "sunglass" # Stemmed lexeme returns results
#query = "81A6F51D90AF2C00DFC715C5DC5FE88D" # Search by SKU efficiently returns the specific item
#query = "Shades" # FTS doesn't understand that 'shades' are semantically similar to 'sunglasses'

# --- Set explain=True to view query performance ---
explain = False

sql = f"""
WITH fts_search AS (
  SELECT
      ts_rank(fts_document, {fts_query_function}('english', '{query}')) AS fts_rank_score,
      name,
      product_image_uri,
      brand,
      product_description,
      category,
      department,
      cost,
      retail_price,
      sku
  FROM products
  WHERE fts_document @@ {fts_query_function}('english', '{query}')
  ORDER BY fts_rank_score DESC
  LIMIT 10
) SELECT
  ROW_NUMBER () OVER (ORDER BY fts_rank_score DESC) AS fts_rank, *
  FROM fts_search
"""


if explain:
  sql = 'EXPLAIN ANALYZE ' + sql

result = await run_query(ecom_db_pool, sql)
try:
  html_output = HTML(result.to_html(escape=False, formatters={'product_image_uri': df_image_formatter}))
  display(html_output)
except:
  print(result)

### Search with BM25 Keyword Search

BM25 (Okapi BM25) is the default ranking algorithm in many dedicated search engines and is generally considered a more advanced and often more effective solution than full-text search, due in part to the following characteristics:
- Term Frequency (TF) Saturation: PostgreSQL's ts_rank rewards repeated occurrences of a term through a fixed internal formula that cannot be tuned. BM25 implements TF saturation with an explicit parameter (`k1`). This means that after a term appears a certain number of times, each additional occurrence contributes less and less to the score. This better reflects human intuition (a document mentioning "sunglasses" 5 times is relevant, but one mentioning it 50 times isn't necessarily 10 times more relevant) and prevents keyword stuffing from dominating results.
- Document Length Normalization: PostgreSQL's ts_rank includes normalization options to penalize longer documents (controlled by flags), but the method is relatively simple. BM25 uses a more sophisticated normalization method controlled by a parameter (`b`). It considers the document's length relative to the average document length across the entire collection. This allows it to more effectively balance relevance across documents of varying lengths, often preventing very long or very short documents from being unfairly penalized or promoted.
- Tunability: PostgreSQL's ts_rank allows tuning by adjusting normalization flags or using `setweight` for different fields. BM25 provides explicit parameters (`k1` for TF saturation, `b` for length normalization) that directly control the core behavior of the ranking algorithm, allowing for fine-tuning based on the characteristics of the specific dataset and desired search behavior.
- Empirical Performance: BM25 is derived from probabilistic models and has consistently demonstrated strong performance in information retrieval benchmarks, often outperforming standard TF-IDF implementations in terms of relevance ranking quality.
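TF saturation (`k1`) and length normalization (`b`) both come from the standard Okapi BM25 term-weight formula. The minimal sketch below computes one query term's contribution to one document's score; `bm25_term_score` is a hypothetical name, `idf` is passed in directly to keep the example short, and pymilvus's internal implementation may differ in details.

```python
def bm25_term_score(tf: int, doc_len: int, avg_doc_len: float, idf: float,
                    k1: float = 1.5, b: float = 0.75) -> float:
    """Okapi BM25 contribution of one query term to one document's score."""
    # b = 0 disables length normalization; b = 1 normalizes fully by doc_len / avg_doc_len
    length_norm = 1 - b + b * (doc_len / avg_doc_len)
    # As tf grows, the score approaches idf * (k1 + 1) instead of growing without bound
    return idf * tf * (k1 + 1) / (tf + k1 * length_norm)

# TF saturation: an average-length document mentioning a term 1, 5, and 50 times
for tf in (1, 5, 50):
    print(tf, round(bm25_term_score(tf, doc_len=100, avg_doc_len=100, idf=1.0), 3))
# 1 1.0
# 5 1.923
# 50 2.427

# Length normalization: the same 5 mentions in a document twice the average length
print(round(bm25_term_score(5, doc_len=200, avg_doc_len=100, idf=1.0), 3))  # 1.639
```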

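The BM25 index input for this example was defined as follows (shown with the default column weights from `bm25_index_sql`):

```sql
SELECT id,
    REPEAT(COALESCE(name,'') || ' ', 5) ||
    REPEAT(COALESCE(brand,'') || ' ', 3) ||
    REPEAT(COALESCE(category,'') || ' ', 2) ||
    REPEAT(COALESCE(department,'') || ' ', 2) ||
    REPEAT(COALESCE(product_description,'') || ' ', 1) ||
    REPEAT(COALESCE(sku,'') || ' ', 1) AS content
FROM products;
```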

Weaknesses:
- New vocabulary requires rebuilding the index and re-embedding every record.
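The vocabulary limitation also explains why typos and unseen synonyms get no BM25 signal: a fitted BM25 model only has weights for terms it saw during fitting, so other query terms contribute nothing. The toy sketch below uses whitespace tokenization and a hypothetical three-term vocabulary; `known_term_indices` is a hypothetical helper, and the notebook's real analyzer additionally stems and removes stop words.

```python
def known_term_indices(query: str, vocabulary: dict) -> list:
    """Return vocabulary indices of the query's terms; out-of-vocabulary terms are dropped."""
    return sorted({vocabulary[term] for term in query.lower().split() if term in vocabulary})

# Toy vocabulary "fitted" on a tiny corpus (term -> 0-based dimension)
example_vocabulary = {"sunglasses": 0, "silver": 1, "jacket": 2}

print(known_term_indices("Silver Sunglasses", example_vocabulary))  # [0, 1]
print(known_term_indices("sunglases", example_vocabulary))          # [] (typo is out of vocabulary)
print(known_term_indices("shades", example_vocabulary))             # [] (unseen synonym)
```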


In [None]:
# --- Set explain=True to view query performance ---
explain = False

# --- Uncomment one query at a time ---
query = "sunglasses" # Simple keyword search gives good results
#query = "silver sunglasses" # Adding color returns lots of irrelevant results with 'silver' in the name
#query = "81A6F51D90AF2C00DFC715C5DC5FE88D" # Search by SKU does not return relevant results
#query = "Shades" # BM25 also doesn't understand that 'shades' are semantically similar to 'sunglasses'

sparse_query_embedding = encode_sparsevec(query)
print(f"Encoded query as BM25 sparse embedding: {sparse_query_embedding}")

sql = f"""
WITH bm25_search AS (
  SELECT sparse_embedding <#> '{sparse_query_embedding}' AS bm25_distance,
      name,
      product_image_uri,
      brand,
      product_description,
      category,
      department,
      cost,
      retail_price,
      sku
    FROM products
    ORDER BY bm25_distance
    LIMIT 10
) SELECT ROW_NUMBER () OVER (ORDER BY bm25_distance) AS bm25_rank, *
  FROM bm25_search
  -- <#> is the negative inner product, so only rows sharing a query term score below 0
  WHERE bm25_distance < 0;
"""
if explain:
  sql = 'EXPLAIN ANALYZE ' + sql

result = await run_query(ecom_db_pool, sql)
try:
  html_output = HTML(result.to_html(escape=False, formatters={'product_image_uri': df_image_formatter}))
  display(html_output)
except:
  print(result)
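pgvector's `<#>` operator returns the *negative* inner product, so sorting ascending puts the highest BM25 score first, and a product that shares no terms with the query scores exactly 0. The minimal sketch below emulates that on sparse vectors stored as `{index: value}` dicts; the vectors and the `negative_inner_product` helper are hypothetical.

```python
def negative_inner_product(a: dict, b: dict) -> float:
    """Emulate pgvector's <#> on sparse vectors stored as {index: value} dicts."""
    return -sum(value * b[index] for index, value in a.items() if index in b)

example_query = {1: 2.0, 5: 1.0}  # hypothetical BM25 query weights
example_docs = {
    "doc_a": {1: 0.9, 5: 0.8},
    "doc_b": {1: 0.4},
    "doc_c": {7: 1.2},  # shares no terms with the query
}
# Ascending order of the negative inner product = descending BM25 score
ranked = sorted(example_docs, key=lambda d: negative_inner_product(example_query, example_docs[d]))
print(ranked)  # ['doc_a', 'doc_b', 'doc_c']
```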

### Search with Text Embeddings

#### `gemini-embedding-001`

In [None]:
# --- Set explain=True to view query performance ---
explain = False

# --- Uncomment one query at a time ---
query = "sunglasses" # Simple keyword search gives good results
#query = "silver sunglasses" # All results are sunglasses, but two results that FTS found are missing
#query = "81A6F51D90AF2C00DFC715C5DC5FE88D" # Search by SKU does not return relevant results
#query = "Shades" # gemini-embedding-001 model understands that 'shades' is semantically similar to 'sunglasses'

sql = f"""
WITH vector_search AS (
  SELECT product_embedding <=> embedding('gemini-embedding-001', '{query}')::vector AS distance,
    name,
    product_image_uri,
    brand,
    product_description,
    category,
    department,
    cost,
    retail_price,
    sku
  FROM products
  ORDER BY distance
  LIMIT 10
) SELECT ROW_NUMBER () OVER (ORDER BY distance) AS vector_rank, *
FROM vector_search
"""

if explain:
  sql = 'EXPLAIN ANALYZE ' + sql

result = await run_query(ecom_db_pool, sql)
try:
  html_output = HTML(result.to_html(escape=False, formatters={'product_image_uri': df_image_formatter}))
  display(html_output)
except:
  print(result)
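The `<=>` operator used in this query is pgvector's cosine distance, defined as 1 minus cosine similarity, so 0 means the same direction and larger values mean less similar. The minimal sketch below computes it for small hypothetical vectors; `cosine_distance` is a hypothetical helper name.

```python
import math

def cosine_distance(a, b) -> float:
    """Emulate pgvector's <=> operator: 1 - cosine similarity."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)

print(cosine_distance([1, 0], [1, 0]))   # 0.0 (same direction)
print(cosine_distance([1, 0], [0, 1]))   # 1.0 (orthogonal)
print(cosine_distance([1, 0], [-1, 0]))  # 2.0 (opposite)
```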

#### `multimodalembedding@001`

> NOTE: Text search results with `multimodalembedding@001` are currently very inaccurate. The product descriptions need to be re-embedded with different input.

In [None]:
# --- Set explain=True to view query performance ---
explain = False

# --- Uncomment one query at a time ---
#query = "sunglasses" # TBD
query = "silver sunglasses" # TBD
#query = "81A6F51D90AF2C00DFC715C5DC5FE88D" # TBD
#query = "Shades" # TBD


# Get embedding
multimodal_embedding = generate_multimodal_embeddings(None, query)

# Search
sql = f"""
WITH multimodal_vector_search AS (
  SELECT product_description_embedding <=> '{multimodal_embedding.text_embedding}' AS distance,
    name,
    product_image_uri,
    brand,
    product_description,
    category,
    department,
    cost,
    retail_price,
    sku
  FROM products
  ORDER BY distance
  LIMIT 10
) SELECT RANK () OVER (ORDER BY distance) AS vector_rank, *
FROM multimodal_vector_search
"""

if explain:
  sql = 'EXPLAIN ANALYZE ' + sql

result = await run_query(ecom_db_pool, sql)
try:
  html_output = HTML(result.to_html(escape=False, formatters={'product_image_uri': df_image_formatter}))
  display(html_output)
except:
  print(result)

### Search by Image

#### Show sample images

In [None]:
test_images = {
  'object': ['Blue Jacket', 'Brown Jacket', 'Black Coat', 'Gray Jacket', 'Noogler Hat',],
  'gs_uri': ['gs://pr-public-demo-data/alloydb-retail-demo/user_photos/1.png',
             'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/2.png',
             'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/3.png',
             'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/4.png',
             'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/5.png'],
  'image': ['https://storage.cloud.google.com/pr-public-demo-data/alloydb-retail-demo/user_photos/1.png',
            'https://storage.cloud.google.com/pr-public-demo-data/alloydb-retail-demo/user_photos/2.png',
            'https://storage.cloud.google.com/pr-public-demo-data/alloydb-retail-demo/user_photos/3.png',
            'https://storage.cloud.google.com/pr-public-demo-data/alloydb-retail-demo/user_photos/4.png',
            'https://storage.cloud.google.com/pr-public-demo-data/alloydb-retail-demo/user_photos/5.png']
}

test_images_df = pd.DataFrame(test_images)
HTML(test_images_df.to_html(escape=False, formatters={'image': lambda x: df_image_formatter(x, 400)}))

#### Search by Image

In [None]:
# Uncomment 1 image at a time.
image_uri = 'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/1.png'
#image_uri = 'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/2.png'
#image_uri = 'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/3.png'
#image_uri = 'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/4.png'
#image_uri = 'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/5.png'

multimodal_embedding = generate_multimodal_embeddings(image_uri, None)

sql = f"""
WITH multimodal_vector_search AS (
  SELECT product_image_embedding <=> '{multimodal_embedding.image_embedding}' AS distance,
    id,
    name,
    product_image_uri,
    brand,
    product_description,
    category,
    department,
    cost,
    retail_price,
    sku
  FROM products
  ORDER BY distance
  LIMIT 24
) SELECT RANK () OVER (ORDER BY distance) AS vector_rank, *
FROM multimodal_vector_search
"""

result = await run_query(ecom_db_pool, sql)

HTML(result.to_html(escape=False, formatters={'product_image_uri': df_image_formatter}))

#### Use AlloyDB Multimodal Embedding Function

Instead of generating the embedding first and then executing the SQL query, you can generate the image embedding for the query on the fly with SQL using AlloyDB's new multimodal embedding feature.

> NOTE: This is a preview feature. See [this doc](https://cloud.google.com/alloydb/docs/ai/generate-multimodal-embeddings) for more info about getting access to the feature.

In [None]:
# Reference: https://cloud.google.com/alloydb/docs/ai/generate-multimodal-embeddings

image_uri = 'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/1.png'
#image_uri = 'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/2.png'
#image_uri = 'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/3.png'
#image_uri = 'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/4.png'
#image_uri = 'gs://pr-public-demo-data/alloydb-retail-demo/user_photos/5.png'


sql = f"""
WITH image_embedding AS (
  SELECT ai.image_embedding(
      model_id => 'multimodalembedding@001',
      image => '{image_uri}',
      mimetype => 'image/png')::vector AS embedding
), multimodal_vector_search AS (
  SELECT product_image_embedding <=> image_embedding.embedding AS distance,
    name,
    product_image_uri,
    brand,
    product_description,
    category,
    department,
    cost,
    retail_price,
    sku
  FROM products, image_embedding
  ORDER BY distance
  LIMIT 5
) SELECT RANK () OVER (ORDER BY distance) AS vector_rank, *
FROM multimodal_vector_search
"""

result = await run_query(ecom_db_pool, sql)

HTML(result.to_html(escape=False, formatters={'product_image_uri': df_image_formatter}))

### Combine Search Techniques with Hybrid Search

In testing, this efficient hybrid search query ran in less than 20ms.

> NOTE: Traditional SQL was modified to include only exact SKU matching for efficiency. Vector, BM25, and FTS techniques are implemented as shown above.

In [None]:
# TO DO: Improve handling of misspellings (e.g. 'sunglases'). FTS and BM25 only match terms in their vocabulary.

# --- Set explain=True to view query performance ---
explain = False

# --- Uncomment one query at a time ---
query = "sunglasses" # Simple keyword search gives good results
#query = "silver sunglasses" # We get more sunglasses than non-sunglasses in the result
#query = "81A6F51D90AF2C00DFC715C5DC5FE88D" # First result is our exact SKU
#query = "Shades" # Blended result
#query = "sunglases" # Misspelling

sparse_query_embedding = encode_sparsevec(query)

print(f"BM25 sparse embedding: {sparse_query_embedding}")

# Set top_k
top_k = 20

# Define RRF smoothing constant k
rrf_k = 60

sql = f"""
WITH trad_sql AS (
    SELECT
      RANK () OVER (ORDER BY name) AS trad_sql_rank,
      id,
      name,
      product_image_uri,
      brand,
      product_description,
      category,
      department,
      cost,
      retail_price,
      sku
    FROM products
    WHERE sku = '{query}'
    ORDER BY name
    LIMIT {top_k * 2}
), fts_search AS (
  SELECT
      ts_rank(fts_document, {fts_query_function}('english', '{query}')) AS fts_rank_score,
      RANK () OVER (ORDER BY ts_rank(fts_document, {fts_query_function}('english', '{query}')) DESC) as fts_rank,
      id,
      name,
      product_image_uri,
      brand,
      product_description,
      category,
      department,
      cost,
      retail_price,
      sku
  FROM products
  WHERE fts_document @@ {fts_query_function}('english', '{query}')
  ORDER BY fts_rank_score DESC
  LIMIT {top_k * 2}
), bm25_search AS (
  SELECT sparse_embedding <#> '{sparse_query_embedding}' AS bm25_distance,
      RANK () OVER (ORDER BY sparse_embedding <#> '{sparse_query_embedding}') AS bm25_rank,
      id,
      name,
      product_image_uri,
      brand,
      product_description,
      category,
      department,
      cost,
      retail_price,
      sku
    FROM products
    WHERE sparse_embedding <#> '{sparse_query_embedding}' < 0
    ORDER BY sparse_embedding <#> '{sparse_query_embedding}'
    LIMIT {top_k * 2}
), vector_search AS (
  SELECT product_embedding <=> embedding('gemini-embedding-001', '{query}')::vector AS distance,
    RANK () OVER (ORDER BY product_embedding <=> embedding('gemini-embedding-001', '{query}')::vector) AS vector_rank,
    id,
    name,
    product_image_uri,
    brand,
    product_description,
    category,
    department,
    cost,
    retail_price,
    sku
  FROM products
  ORDER BY distance
  LIMIT {top_k * 2}
) SELECT
    COALESCE(vector_search.id, fts_search.id, bm25_search.id, trad_sql.id) AS id,
    (
      COALESCE( (1.0 / ({rrf_k} + MAX(vector_search.vector_rank))), 0.0 ) +
      COALESCE( (1.0 / ({rrf_k} + MAX(fts_search.fts_rank))), 0.0 ) +
      COALESCE( (1.0 / ({rrf_k} + MAX(bm25_search.bm25_rank))), 0.0 ) +
      COALESCE( (1.0 / ({rrf_k} + MAX(trad_sql.trad_sql_rank))), 0.0 )
    ) AS rrf_score,
    CONCAT_WS(
        '+',
        CASE WHEN MAX(vector_search.vector_rank) IS NOT NULL THEN 'vector' ELSE NULL END,
        CASE WHEN MAX(fts_search.fts_rank) IS NOT NULL THEN 'fts' ELSE NULL END,
        CASE WHEN MAX(bm25_search.bm25_rank) IS NOT NULL THEN 'bm25' ELSE NULL END,
        CASE WHEN MAX(trad_sql.trad_sql_rank) IS NOT NULL THEN 'trad' ELSE NULL END
    ) AS result_type,
    COALESCE(vector_search.name, fts_search.name, bm25_search.name, trad_sql.name) AS name,
    COALESCE(vector_search.product_image_uri, fts_search.product_image_uri, bm25_search.product_image_uri, trad_sql.product_image_uri) AS product_image_uri,
    COALESCE(vector_search.brand, fts_search.brand, bm25_search.brand, trad_sql.brand) AS brand,
    COALESCE(vector_search.product_description, fts_search.product_description, bm25_search.product_description, trad_sql.product_description) AS product_description,
    COALESCE(vector_search.category, fts_search.category, bm25_search.category, trad_sql.category) AS category,
    COALESCE(vector_search.department, fts_search.department, bm25_search.department, trad_sql.department) AS department,
    COALESCE(vector_search.cost, fts_search.cost, bm25_search.cost, trad_sql.cost) AS cost,
    COALESCE(vector_search.retail_price, fts_search.retail_price, bm25_search.retail_price, trad_sql.retail_price) AS retail_price,
    COALESCE(vector_search.sku, fts_search.sku, bm25_search.sku, trad_sql.sku) AS sku
FROM vector_search
FULL OUTER JOIN fts_search ON vector_search.id = fts_search.id
FULL OUTER JOIN bm25_search ON COALESCE(vector_search.id, fts_search.id) = bm25_search.id
FULL OUTER JOIN trad_sql ON COALESCE(vector_search.id, fts_search.id, bm25_search.id) = trad_sql.id
GROUP BY COALESCE(vector_search.id, fts_search.id, bm25_search.id, trad_sql.id),
    COALESCE(vector_search.name, fts_search.name, bm25_search.name, trad_sql.name),
    COALESCE(vector_search.product_image_uri, fts_search.product_image_uri, bm25_search.product_image_uri, trad_sql.product_image_uri),
    COALESCE(vector_search.brand, fts_search.brand, bm25_search.brand, trad_sql.brand),
    COALESCE(vector_search.product_description, fts_search.product_description, bm25_search.product_description, trad_sql.product_description),
    COALESCE(vector_search.category, fts_search.category, bm25_search.category, trad_sql.category),
    COALESCE(vector_search.department, fts_search.department, bm25_search.department, trad_sql.department),
    COALESCE(vector_search.cost, fts_search.cost, bm25_search.cost, trad_sql.cost),
    COALESCE(vector_search.retail_price, fts_search.retail_price, bm25_search.retail_price, trad_sql.retail_price),
    COALESCE(vector_search.sku, fts_search.sku, bm25_search.sku, trad_sql.sku)
ORDER BY rrf_score DESC
LIMIT {top_k};
"""

if explain:
  sql = 'EXPLAIN ANALYZE ' + sql

result = await run_query(ecom_db_pool, sql)
try:
  html_output = HTML(result.to_html(escape=False, formatters={'product_image_uri': df_image_formatter}))
  display(html_output)
except Exception:
  print(result)
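The `rrf_score` expression above is Reciprocal Rank Fusion. Each retriever contributes `1 / (rrf_k + rank)` for every product it returned, and a retriever that did not return a product contributes 0 (the `COALESCE(..., 0.0)`). The pure-Python sketch below reproduces that arithmetic outside the database, which can help when tuning `rrf_k`. The function name and the dict-of-rankings input are illustrative assumptions, not part of the notebook, and unlike SQL `RANK()` (where ties share a rank) this sketch assigns each position its own rank.

```python
from collections import defaultdict


def reciprocal_rank_fusion(rankings, k=60, top_k=None):
    """Fuse ranked id lists with Reciprocal Rank Fusion (RRF).

    rankings: dict mapping a retriever label (e.g. "vector", "fts") to a list
              of ids ordered best-first; position 0 is rank 1.
    k:        smoothing constant, playing the role of rrf_k in the SQL.
    Returns a list of (id, score, result_type) tuples, highest score first.
    """
    scores = defaultdict(float)
    sources = defaultdict(list)
    for label, ids in rankings.items():
        for rank, item_id in enumerate(ids, start=1):
            # Same term as 1.0 / (rrf_k + rank); ids a retriever missed add nothing.
            scores[item_id] += 1.0 / (k + rank)
            sources[item_id].append(label)
    fused = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    if top_k is not None:
        fused = fused[:top_k]
    # Mirror the CONCAT_WS('+', ...) result_type column.
    return [(item_id, score, "+".join(sources[item_id])) for item_id, score in fused]


# Hypothetical ids: product 7 is found by both retrievers, so it ranks first.
example = reciprocal_rank_fusion({"vector": [3, 7, 9], "fts": [7, 12]}, k=60)
for item_id, score, result_type in example:
    print(item_id, round(score, 5), result_type)
```

A product that appears in two lists at mediocre ranks can outscore one that is first in a single list, which is why exact-SKU and keyword hits get pulled toward the top when they also match semantically.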


### Hybrid Search with Built-in Support

In [None]:
# --- Choose appropriate query function ---
# plainto_tsquery: treats input as space-separated keywords, ANDs them. Good start.
# websearch_to_tsquery: More flexible, handles quotes for phrases, OR, '-' for negation. Often better for user input.
# phraseto_tsquery: Treats entire input as a phrase.
# This cell uses websearch_to_tsquery (set below); adjust as needed.

fts_query_function = 'websearch_to_tsquery'

# --- Set explain=True to view query performance ---
explain = False

# --- Uncomment one query at a time ---
query = "sunglasses" # Simple keyword search gives good results
#query = "silver sunglasses" # We get more sunglasses than non-sunglasses in the result
#query = "81A6F51D90AF2C00DFC715C5DC5FE88D" # First result is our exact SKU
#query = "Shades" # Blended result
#query = "sunglases" # Misspelling

# Set top_k
top_k = 20

# Define RRF smoothing constant k
rrf_k = 60

sql = f"""
WITH trad_sql AS (
    SELECT
      RANK () OVER (ORDER BY name) AS trad_sql_rank,
      id,
      name,
      product_image_uri,
      brand,
      product_description,
      category,
      department,
      cost,
      retail_price,
      sku,
      'SQL' AS retrieval_method
    FROM products
    WHERE sku = '{query}'
    ORDER BY name
    LIMIT {top_k * 2}
), fts_search AS (
  SELECT
      ts_rank(fts_document, {fts_query_function}('english', '{query}')) AS fts_rank_score,
      RANK () OVER (ORDER BY ts_rank(fts_document, {fts_query_function}('english', '{query}')) DESC) as fts_rank,
      id,
      name,
      product_image_uri,
      brand,
      product_description,
      category,
      department,
      cost,
      retail_price,
      sku,
      'FTS' AS retrieval_method
  FROM products
  WHERE fts_document @@ {fts_query_function}('english', '{query}')
  ORDER BY fts_rank_score DESC
  LIMIT {top_k * 2}
), vector_search AS (
  SELECT embedding <=> embedding('gemini-embedding-001', '{query}')::vector AS distance,
      RANK () OVER (ORDER BY embedding <=> embedding('gemini-embedding-001', '{query}')::vector) AS vector_rank,
      id,
      name,
      product_image_uri,
      brand,
      product_description,
      category,
      department,
      cost,
      retail_price,
      sku,
      'VECTOR' AS retrieval_method
  FROM products
  ORDER BY distance
  LIMIT {top_k * 2}
) SELECT
    COALESCE(vector_search.id, fts_search.id, trad_sql.id) AS id,
    (
      COALESCE( (1.0 / ({rrf_k} + vector_search.vector_rank)), 0.0 ) +
      COALESCE( (1.0 / ({rrf_k} + fts_search.fts_rank)), 0.0 ) +
      COALESCE( (1.0 / ({rrf_k} + trad_sql.trad_sql_rank)), 0.0 )
    ) AS rrf_score,
    CONCAT_WS(
        '+',
        CASE WHEN vector_search.vector_rank IS NOT NULL THEN 'VECTOR' ELSE NULL END,
        CASE WHEN fts_search.fts_rank IS NOT NULL THEN 'FTS' ELSE NULL END,
        CASE WHEN trad_sql.trad_sql_rank IS NOT NULL THEN 'SQL' ELSE NULL END
    ) AS result_type,
    COALESCE(vector_search.name, fts_search.name, trad_sql.name) AS name,
    COALESCE(vector_search.product_image_uri, fts_search.product_image_uri, trad_sql.product_image_uri) AS product_image_uri,
    COALESCE(vector_search.brand, fts_search.brand, trad_sql.brand) AS brand,
    COALESCE(vector_search.product_description, fts_search.product_description, trad_sql.product_description) AS product_description,
    COALESCE(vector_search.category, fts_search.category, trad_sql.category) AS category,
    COALESCE(vector_search.department, fts_search.department, trad_sql.department) AS department,
    COALESCE(vector_search.cost, fts_search.cost, trad_sql.cost) AS cost,
    COALESCE(vector_search.retail_price, fts_search.retail_price, trad_sql.retail_price) AS retail_price,
    COALESCE(vector_search.sku, fts_search.sku, trad_sql.sku) AS sku
FROM vector_search
FULL OUTER JOIN fts_search ON vector_search.id = fts_search.id
FULL OUTER JOIN trad_sql ON COALESCE(vector_search.id, fts_search.id) = trad_sql.id
ORDER BY rrf_score DESC
LIMIT {top_k};
"""

if explain:
  sql = 'EXPLAIN ANALYZE ' + sql

print(sql)
result = await run_query(ecom_db_pool, sql)
try:
  html_output = HTML(result.to_html(escape=False, formatters={'product_image_uri': df_image_formatter}))
  display(html_output)
except Exception:
  print(result)
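Both hybrid queries splice `query` and `fts_query_function` into the SQL with an f-string. A search term that contains an apostrophe (for example `men's coat`) ends the string literal early and breaks the statement, and arbitrary user input is an SQL-injection risk. If `run_query` accepts bind parameters, those are the better fix, but its signature isn't shown in this section. The minimal sketch below instead doubles single quotes (standard SQL literal escaping, assuming `standard_conforming_strings` is on, which is the PostgreSQL default) and whitelists the tsquery function name. Both helper names are hypothetical.

```python
ALLOWED_TSQUERY_FUNCTIONS = {"plainto_tsquery", "websearch_to_tsquery", "phraseto_tsquery"}


def sql_string_literal(value: str) -> str:
    """Return value as a single-quoted PostgreSQL string literal.

    Doubles embedded single quotes. Assumes standard_conforming_strings=on,
    so backslashes are literal and need no escaping.
    """
    if "\x00" in value:
        raise ValueError("PostgreSQL text values cannot contain NUL characters")
    return "'" + value.replace("'", "''") + "'"


def tsquery_expression(function_name: str, query: str, config: str = "english") -> str:
    """Build e.g. websearch_to_tsquery('english', 'men''s coat') safely."""
    if function_name not in ALLOWED_TSQUERY_FUNCTIONS:
        raise ValueError(f"Unsupported tsquery function: {function_name}")
    return f"{function_name}({sql_string_literal(config)}, {sql_string_literal(query)})"


print(tsquery_expression("websearch_to_tsquery", "men's coat"))
```

In the cells above, `{fts_query_function}('english', '{query}')` would become `{tsquery_expression(fts_query_function, query)}`, and other occurrences of `'{query}'` would become `{sql_string_literal(query)}` (the helper supplies the quotes).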


### Enable The AlloyDB Columnar Engine

References:
* https://cloud.google.com/alloydb/docs/columnar-engine/configure
* https://cloud.google.com/alloydb/docs/instance-configure-database-flags#gcloud

In [None]:
result = ! gcloud beta alloydb instances update {alloydb_instance} \
   --database-flags google_columnar_engine.enabled=on,google_columnar_engine.enable_vectorized_join=on,password.enforce_complexity=on,google_ml_integration.enable_model_support=on \
   --region={region} \
   --cluster={alloydb_cluster} \
   --project={project_id} \
   --update-mode=FORCE_APPLY
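The cell re-specifies `password.enforce_complexity` and `google_ml_integration.enable_model_support` alongside the columnar flags, presumably because `--database-flags` sets the instance's complete flag list rather than appending to it (confirm against the flags reference linked above before relying on this). Under that assumption, a small hypothetical helper can merge the flags you want to add into the ones already set and render the comma-separated `name=value` string the command takes.

```python
def merge_database_flags(existing: dict, updates: dict) -> str:
    """Merge flag updates into existing flags and render name=value pairs.

    Values in `updates` win. Output is sorted by flag name so it is
    deterministic. This simple renderer rejects values containing commas
    rather than trying to escape them.
    """
    merged = {**existing, **updates}
    for name, value in merged.items():
        if "," in str(value):
            raise ValueError(f"Flag {name} has a comma in its value: {value}")
    return ",".join(f"{name}={value}" for name, value in sorted(merged.items()))


# Hypothetical current flags; in practice read them from the databaseFlags
# field of `gcloud alloydb instances describe` before updating.
current_flags = {
    "password.enforce_complexity": "on",
    "google_ml_integration.enable_model_support": "on",
}
columnar_flags = {
    "google_columnar_engine.enabled": "on",
    "google_columnar_engine.enable_vectorized_join": "on",
}
database_flags = merge_database_flags(current_flags, columnar_flags)
print(database_flags)
```

The result can be passed to the shell command as `--database-flags={database_flags}`, since IPython expands `{...}` in `!` commands as the cell above already does.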

### Add Facet Columns to Column Store

Reference: https://cloud.google.com/alloydb/docs/columnar-engine/manage-content-manually

In [None]:
sql = """SELECT google_columnar_engine_add(
    relation => 'products',
    columns => 'id,brand,category,retail_price'
);"""

await run_query(ecom_db_pool, sql)

### Validate Columns are Added to Columnar Engine

Reference: https://cloud.google.com/alloydb/docs/columnar-engine/monitor-tune

In [None]:
sql = "SELECT * FROM g_columnar_columns;"
await run_query(ecom_db_pool, sql)
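The validation query lists every column loaded into the column store. A quick check that all four facet columns requested from `google_columnar_engine_add` are present is sketched below. It assumes the result exposes `relation_name` and `column_name` columns, as `g_columnar_recommended_columns` does later in this notebook; the helper name is hypothetical.

```python
def missing_columnar_columns(rows, relation: str, expected_columns: str) -> list:
    """Return the expected columns that are not reported for `relation`.

    rows: iterable of (relation_name, column_name) pairs from the
          g_columnar_columns result (column names assumed, see above).
    expected_columns: the same comma-separated string passed to
          google_columnar_engine_add(columns => ...).
    """
    present = {column for rel, column in rows if rel == relation}
    expected = [c.strip() for c in expected_columns.split(",") if c.strip()]
    return [c for c in expected if c not in present]


# Hypothetical rows standing in for the query result. With the notebook's
# DataFrame: result[["relation_name", "column_name"]].itertuples(index=False, name=None)
sample_rows = [("products", "id"), ("products", "brand"), ("products", "category")]
print(missing_columnar_columns(sample_rows, "products", "id,brand,category,retail_price"))
```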

### Get Facets with Hybrid Search

This example shows how to efficiently compute facets (counts by brand, category, and price range) for your hybrid search queries.

> NOTE: The Columnar Engine may not be chosen for aggregations when existing indexes are more efficient. The query planner decides whether using the Columnar Engine would be more efficient than a standard plan. To see the details behind that decision, run the query with `EXPLAIN (COLUMNAR_ENGINE)`. For most queries on this small dataset (29,120 items), the Columnar Engine will not be chosen; for large product catalogs with millions of items, it is much more likely to be chosen.

In [None]:
# --- Set explain=True to view query performance ---
explain = False

# --- Turn columnar engine "on" or "off" to test performance impact of the columnar engine. ---
columnar_engine_state = "on"

# --- Choose a query term ---
query = "sunglasses" # Simple keyword search gives good results
#query = "silver sunglasses" # We get more sunglasses than non-sunglasses in the result
#query = "81A6F51D90AF2C00DFC715C5DC5FE88D" # First result is our exact SKU
#query = "Shades" # Blended result
#query = "sunglases" # Misspelling

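# NOTE: SET is session-scoped. If run_query hands the next call a different pooled connection
# (or rolls back this statement's transaction), the setting may not reach the facet query below.
await run_query(ecom_db_pool, f"SET google_columnar_engine.enable_columnar_scan={columnar_engine_state};")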

sql = f"""WITH
  -- 1. Define the pool of candidate IDs exactly once
  candidate_ids AS (
    WITH vector_candidates AS (
      SELECT id, embedding <=> embedding('gemini-embedding-001', '{query}')::vector AS distance FROM products ORDER BY distance LIMIT 500
    )
    SELECT id FROM vector_candidates WHERE distance < 0.4
    UNION
    SELECT id FROM products WHERE sku = '{query}'
    UNION
    SELECT id FROM products WHERE fts_document @@ websearch_to_tsquery('english', '{query}')
  ),
  -- 2. Join products with candidates and prepare facet columns (including calculated ones)
  products_for_faceting AS (
    SELECT
      p.brand,
      p.category,
      CASE
        WHEN p.retail_price < 50 THEN '$0 - $49.99'
        WHEN p.retail_price >= 50 AND p.retail_price < 100 THEN '$50 - $99.99'
        WHEN p.retail_price >= 100 AND p.retail_price < 250 THEN '$100 - $249.99'
        WHEN p.retail_price >= 250 AND p.retail_price < 500 THEN '$250 - $499.99'
        WHEN p.retail_price >= 500 THEN '$500+'
        ELSE NULL
      END AS price_range,
      p.retail_price
    FROM
      products AS p
      JOIN candidate_ids AS c ON p.id = c.id
  ),
  -- 3. Calculate Aggregations using GROUPING SETS
  facet_aggregations AS (
    SELECT
      COALESCE(brand, category, price_range) AS facet_value,
      CASE
        WHEN GROUPING(brand) = 0 THEN 'brand'
        WHEN GROUPING(category) = 0 THEN 'category'
        WHEN GROUPING(price_range) = 0 THEN 'price_range'
      END AS facet_type,
      COUNT(*) AS count,
      MIN(retail_price) as min_price_for_ordering
    FROM
      products_for_faceting
    WHERE
        brand IS NOT NULL OR category IS NOT NULL OR price_range IS NOT NULL
    GROUP BY
      GROUPING SETS (
        (brand),
        (category),
        (price_range)
      )
  )
-- 4. Final SELECT and ORDER BY from the aggregated results
SELECT
  facet_value,
  facet_type,
  count
FROM
  facet_aggregations
ORDER BY
  facet_type ASC,
  CASE WHEN facet_type = 'price_range' THEN min_price_for_ordering ELSE NULL END ASC NULLS LAST, -- Handle potential NULL min_price
  count DESC,
  facet_value ASC;"""

if explain:
  sql = 'EXPLAIN (COLUMNAR_ENGINE) ' + sql

result = await run_query(ecom_db_pool, sql)
try:
  html_output = HTML(result.to_html(escape=False, formatters={'product_image_uri': df_image_formatter}))
  display(html_output)
except Exception:
  print(result)
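`GROUPING SETS ((brand), (category), (price_range))` computes three independent `GROUP BY`s in one pass. In each output row only one of the three columns is grouped and the other two are NULL, which is what the `GROUPING()` checks and `COALESCE(brand, category, price_range)` rely on. The sketch below reproduces the price binning and the three facet counts in plain Python on a few hypothetical products, which can be handy for unit-testing bin edges. It illustrates the semantics only; it is not how AlloyDB evaluates the query.

```python
from collections import Counter

# Upper bounds and labels copied from the CASE expression in the facet query;
# anything at or above the last bound falls into "$500+".
PRICE_BINS = [
    (50, "$0 - $49.99"),
    (100, "$50 - $99.99"),
    (250, "$100 - $249.99"),
    (500, "$250 - $499.99"),
]


def price_range(retail_price):
    """Return the facet label for a price; None maps to None like ELSE NULL."""
    if retail_price is None:
        return None
    for upper_bound, label in PRICE_BINS:
        if retail_price < upper_bound:
            return label
    return "$500+"


def facet_counts(products):
    """Emulate COUNT(*) with GROUPING SETS ((brand), (category), (price_range))."""
    counts = {"brand": Counter(), "category": Counter(), "price_range": Counter()}
    for p in products:
        row = (p.get("brand"), p.get("category"), price_range(p.get("retail_price")))
        # Same as the SQL WHERE clause: skip rows where all three are NULL.
        if all(v is None for v in row):
            continue
        for facet_type, value in zip(("brand", "category", "price_range"), row):
            counts[facet_type][value] += 1
    return counts


# Hypothetical candidate products, standing in for the candidate_ids join.
sample_products = [
    {"brand": "BrandA", "category": "Accessories", "retail_price": 49.99},
    {"brand": "BrandA", "category": "Accessories", "retail_price": 150.0},
    {"brand": "BrandB", "category": "Accessories", "retail_price": 500.0},
]
for facet_type, counter in facet_counts(sample_products).items():
    print(facet_type, dict(counter))
```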

## Get Facets for Full Result Set

### Enable The AlloyDB Columnar Engine

References:
* https://cloud.google.com/alloydb/docs/columnar-engine/configure
* https://cloud.google.com/alloydb/docs/instance-configure-database-flags#gcloud

In [None]:
result = ! gcloud beta alloydb instances update {alloydb_instance} \
   --database-flags google_columnar_engine.enabled=on,google_columnar_engine.enable_vectorized_join=on,password.enforce_complexity=on,google_ml_integration.enable_model_support=on \
   --region={region} \
   --cluster={alloydb_cluster} \
   --project={project_id} \
   --update-mode=FORCE_APPLY

### Add Columns to Column Store Automatically

Reference: https://cloud.google.com/alloydb/docs/columnar-engine/manage-content-recommendations

In [None]:
sql = "SELECT google_columnar_engine_recommend();"

await run_query(ecom_db_pool, sql)

#### View Recommended Columns

In [None]:
sql = "SELECT database_name, schema_name, relation_name, column_name FROM g_columnar_recommended_columns;"

await run_query(ecom_db_pool, sql)
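If you want to review the recommendations, or re-apply them by hand as in the optional manual step, the rows from `g_columnar_recommended_columns` can be grouped per relation into the `relation => ..., columns => ...` form used by `google_columnar_engine_add`. The helper below is hypothetical; it assumes the four columns selected above, plain lowercase identifiers that need no quoting, and tables in the `public` schema, and it leaves execution to `run_query`.

```python
from collections import defaultdict


def columnar_add_statements(rows, database_name: str, schema_name: str = "public"):
    """Group recommended columns per relation and render add statements.

    rows: iterable of (database_name, schema_name, relation_name, column_name),
          the column order selected from g_columnar_recommended_columns above.
    Only rows for the given database and schema are used.
    """
    columns_by_relation = defaultdict(list)
    for db, schema, relation, column in rows:
        if db == database_name and schema == schema_name and column not in columns_by_relation[relation]:
            columns_by_relation[relation].append(column)
    statements = []
    for relation, columns in sorted(columns_by_relation.items()):
        statements.append(
            "SELECT google_columnar_engine_add("
            f"relation => '{relation}', columns => '{','.join(columns)}');"
        )
    return statements


# Hypothetical recommendation rows.
sample_recs = [
    ("ecom", "public", "products", "brand"),
    ("ecom", "public", "products", "category"),
    ("ecom", "public", "orders", "status"),
    ("other_db", "public", "products", "sku"),
]
for stmt in columnar_add_statements(sample_recs, "ecom"):
    print(stmt)
```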

### (OPTIONAL) Add Facet Columns to Column Store Manually

Reference: https://cloud.google.com/alloydb/docs/columnar-engine/manage-content-manually

In [None]:
sql = """SELECT google_columnar_engine_add(
    relation => 'products',
    columns => 'id,brand,category,retail_price'
);"""

await run_query(ecom_db_pool, sql)

### Validate Columns are Added to Columnar Engine

Reference: https://cloud.google.com/alloydb/docs/columnar-engine/monitor-tune

In [None]:
sql = "SELECT * FROM g_columnar_columns;"
await run_query(ecom_db_pool, sql)

### Facets with Hybrid Search

This example shows how to efficiently compute facets (counts by brand, category, and price range) for your hybrid search queries.

In testing, execution time ranged from 80 ms to 200 ms with the Columnar Engine turned off.
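To compare timings with the Columnar Engine on and off, a few repeated runs give a steadier picture than a single one. The timing sketch below is minimal; `time_async` and `summarize` are hypothetical helpers. Client-side wall-clock time also includes network round trips and whatever `run_query` does with the result, so it will read higher than the `Execution Time` line reported by `EXPLAIN ANALYZE`.

```python
import asyncio
import statistics
import time


async def time_async(make_call, runs=5, pause_s=0.0):
    """Await make_call() `runs` times and return each run's latency in ms.

    make_call: zero-argument callable returning a fresh awaitable per call,
               e.g. lambda: run_query(ecom_db_pool, sql) in this notebook.
    pause_s:   optional pause between runs (not included in the timings).
    """
    latencies_ms = []
    for i in range(runs):
        start = time.perf_counter()
        await make_call()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
        if pause_s and i < runs - 1:
            await asyncio.sleep(pause_s)
    return latencies_ms


def summarize(latencies_ms):
    """Summarize latencies; the median is less sensitive to one slow run than the mean."""
    return {
        "runs": len(latencies_ms),
        "median_ms": statistics.median(latencies_ms),
        "max_ms": max(latencies_ms),
    }


# Notebook usage (top-level await, as elsewhere in this notebook), after
# setting columnar_engine_state and running the SET and sql-building lines:
# print(columnar_engine_state, summarize(await time_async(lambda: run_query(ecom_db_pool, sql), runs=10)))
```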

In [None]:
# --- Set explain=True to view query performance ---
explain = False

# --- Turn columnar engine "on" or "off" to test performance impact of the columnar engine. ---
columnar_engine_state = "on"

# --- Choose a query term ---
query = "sunglasses" # Simple keyword search gives good results
#query = "silver sunglasses" # We get more sunglasses than non-sunglasses in the result
#query = "81A6F51D90AF2C00DFC715C5DC5FE88D" # First result is our exact SKU
#query = "Shades" # Blended result
#query = "sunglases" # Misspelling

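# NOTE: SET is session-scoped. If run_query hands the next call a different pooled connection
# (or rolls back this statement's transaction), the setting may not reach the facet query below.
await run_query(ecom_db_pool, f"SET google_columnar_engine.enable_columnar_scan={columnar_engine_state};")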

sql = f"""WITH
  -- 1. Define the pool of candidate IDs exactly once
  candidate_ids AS (
    WITH vector_candidates AS (
      SELECT id, embedding <=> embedding('gemini-embedding-001', '{query}')::vector AS distance FROM products ORDER BY distance LIMIT 500
    )
    SELECT id FROM vector_candidates WHERE distance < 0.4
    UNION
    SELECT id FROM products WHERE sku = '{query}'
    UNION
    SELECT id FROM products WHERE fts_document @@ websearch_to_tsquery('english', '{query}')
  ),
  -- 2. Join products with candidates and prepare facet columns (including calculated ones)
  products_for_faceting AS (
    SELECT
      p.brand,
      p.category,
      CASE
        WHEN p.retail_price < 50 THEN '$0 - $49.99'
        WHEN p.retail_price >= 50 AND p.retail_price < 100 THEN '$50 - $99.99'
        WHEN p.retail_price >= 100 AND p.retail_price < 250 THEN '$100 - $249.99'
        WHEN p.retail_price >= 250 AND p.retail_price < 500 THEN '$250 - $499.99'
        WHEN p.retail_price >= 500 THEN '$500+'
        ELSE NULL
      END AS price_range,
      p.retail_price
    FROM
      products AS p
      JOIN candidate_ids AS c ON p.id = c.id
  ),
  -- 3. Calculate Aggregations using GROUPING SETS
  facet_aggregations AS (
    SELECT
      COALESCE(brand, category, price_range) AS facet_value,
      CASE
        WHEN GROUPING(brand) = 0 THEN 'brand'
        WHEN GROUPING(category) = 0 THEN 'category'
        WHEN GROUPING(price_range) = 0 THEN 'price_range'
      END AS facet_type,
      COUNT(*) AS count,
      MIN(retail_price) as min_price_for_ordering
    FROM
      products_for_faceting
    WHERE
        brand IS NOT NULL OR category IS NOT NULL OR price_range IS NOT NULL
    GROUP BY
      GROUPING SETS (
        (brand),
        (category),
        (price_range)
      )
  )
-- 4. Final SELECT and ORDER BY from the aggregated results
SELECT
  facet_value,
  facet_type,
  count
FROM
  facet_aggregations
ORDER BY
  facet_type ASC,
  CASE WHEN facet_type = 'price_range' THEN min_price_for_ordering ELSE NULL END ASC NULLS LAST, -- Handle potential NULL min_price
  count DESC,
  facet_value ASC;"""

if explain:
  sql = 'EXPLAIN (COLUMNAR_ENGINE) ' + sql

result = await run_query(ecom_db_pool, sql)
try:
  html_output = HTML(result.to_html(escape=False, formatters={'product_image_uri': df_image_formatter}))
  display(html_output)
except Exception:
  print(result)
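The facet query returns one flat row per `(facet_value, facet_type)`. A search UI usually wants these nested per facet type. Because the query's `ORDER BY` already sorts within each type (price bins by price, other facets by count), the hypothetical sketch below only needs to group rows while preserving their order. `int(count)` converts numpy integers from a DataFrame into plain ints so the result is JSON-serializable; the output shape is an illustrative assumption.

```python
import json


def group_facets(rows):
    """Nest flat (facet_value, facet_type, count) rows by facet_type.

    Row order is preserved, so the facet query's ORDER BY carries over.
    NULL (None) facet values, e.g. products without a brand, are dropped.
    """
    grouped = {}
    for facet_value, facet_type, count in rows:
        if facet_value is None:
            continue
        grouped.setdefault(facet_type, []).append({"value": facet_value, "count": int(count)})
    return grouped


# Hypothetical rows in the query's column order. With the notebook's DataFrame:
# group_facets(result[["facet_value", "facet_type", "count"]].itertuples(index=False, name=None))
facet_rows = [
    ("BrandA", "brand", 12),
    ("BrandB", "brand", 3),
    ("Accessories", "category", 15),
    ("$0 - $49.99", "price_range", 9),
    ("$50 - $99.99", "price_range", 6),
]
print(json.dumps(group_facets(facet_rows), indent=2))
```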

## Use ai.if() SQL Operator for Advanced Filtering

Reference: https://cloud.google.com/alloydb/docs/ai/evaluate-semantic-queries-ai-operators

In [None]:
# --- Set explain=True to view query performance ---
explain = False

# --- Define the primary query and the AI filter ---
query = "winter coat" # Simple keyword search gives good results
ai_filter = "made by a luxury brand"

sql = f"""
WITH vector_search AS (
  SELECT embedding <=> embedding('gemini-embedding-001', '{query}')::vector AS distance,
    name,
    product_image_uri,
    brand,
    product_description,
    category,
    department,
    cost,
    retail_price,
    sku
  FROM products
  ORDER BY distance
  LIMIT 20
) SELECT ROW_NUMBER () OVER (ORDER BY distance) AS vector_rank, *
  FROM vector_search
  WHERE ai.if(prompt => 'The following product {ai_filter}: ' ||
                        ' Product name: ' || COALESCE(name, '') ||
                        ' Brand: ' || COALESCE(brand, '') ||
                        ' Category: ' || COALESCE(category, '') ||
                        ' Department: ' || COALESCE(department, '') ||
                        ' Price: ' || COALESCE(retail_price::text, '') ||
                        ' Description: ' || COALESCE(product_description, ''))
  ORDER BY vector_rank;
"""

if explain:
  sql = 'EXPLAIN ANALYZE ' + sql

result = await run_query(ecom_db_pool, sql)
try:
  html_output = HTML(result.to_html(escape=False, formatters={'product_image_uri': df_image_formatter}))
  display(html_output)
except Exception:
  print(result)
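The `ai.if()` predicate is applied to each of the 20 rows returned by `vector_search`, and its prompt is assembled by string concatenation, so it helps to see exactly what the model is asked. The hypothetical helper below rebuilds that text in Python for one product. NULL fields become empty strings, mirroring `COALESCE(..., '')`. Python's rendering of a float price may differ slightly from PostgreSQL's `::text` output.

```python
def ai_if_prompt(ai_filter: str, product: dict) -> str:
    """Rebuild the prompt text the query passes to ai.if() for one product row."""

    def field(name):
        # Mirrors COALESCE(column, ''): missing or NULL values become ''.
        value = product.get(name)
        return "" if value is None else str(value)

    return (
        f"The following product {ai_filter}: "
        + " Product name: " + field("name")
        + " Brand: " + field("brand")
        + " Category: " + field("category")
        + " Department: " + field("department")
        + " Price: " + field("retail_price")
        + " Description: " + field("product_description")
    )


# Hypothetical product row.
print(ai_if_prompt("made by a luxury brand", {
    "name": "Wool Overcoat",
    "brand": "BrandA",
    "category": "Outerwear & Coats",
    "department": "Men",
    "retail_price": 250.0,
    "product_description": None,
}))
```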

## Export Prepped Data

Export the data to GCS to allow for easier reproducibility and demo deployment.

In [None]:
# Reference: https://cloud.google.com/alloydb/docs/reference/rest/v1/projects.locations.clusters/export
#            https://cloud.google.com/alloydb/docs/export-sql-file

import time

url = f"https://alloydb.googleapis.com/v1/projects/{project_id}/locations/{region}/clusters/{alloydb_cluster}:export"
request_body = {
   "gcsDestination": {"uri": f"gs://{export_bucket}/alloydb_export/{alloydb_database}.sql"},
   "database": f"{alloydb_database}",
   "sqlExportOptions": {
      "tables": [
        "distribution_centers",
        "events",
        "inventory_items",
        "order_items",
        "orders",
        "products",
        "users",
      ]
    }
}

result = rest_api_helper(authed_session, url, 'POST', request_body, {})
print(f"Kicked off export: {result}")

operation_id = result['name']

operation_complete = False
while not operation_complete:
  print(f"Export still running: {operation_id}")
  url = f"https://alloydb.googleapis.com/v1/{operation_id}"
  response = rest_api_helper(authed_session, url, 'GET', request_body, {})
  # "done" may be omitted from the operation JSON until the export finishes
  operation_complete = response.get('done', False)
  if operation_complete:
    print(f"Operation complete. Check result payload for potential errors. \nResult: {response}")
    continue
  time.sleep(5)
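The loop above polls every 5 seconds with no upper bound and leaves error checking to the reader. A long-running operation (the standard `google.longrunning.Operation` shape) reports failure through an `error` field once `done` is true. The sketch below is a bounded poller with exponential backoff and that check. `wait_for_operation` is hypothetical, `fetch_operation` stands in for the `rest_api_helper(...)` GET call, and the timeout and delay values are assumptions.

```python
import time


def wait_for_operation(fetch_operation, timeout_s=3600.0, initial_delay_s=5.0,
                       max_delay_s=60.0, sleep=time.sleep):
    """Poll a long-running operation until done, failing fast on errors.

    fetch_operation: zero-argument callable returning the operation as a dict, e.g.
        lambda: rest_api_helper(authed_session, url, 'GET', request_body, {})
    Returns the final operation dict. Raises RuntimeError if the operation
    reports an error, TimeoutError if it is still running after timeout_s.
    """
    deadline = time.monotonic() + timeout_s
    delay = initial_delay_s
    while True:
        operation = fetch_operation()
        # "done" may be omitted while the operation is still running.
        if operation.get("done", False):
            if "error" in operation:
                raise RuntimeError(f"Operation failed: {operation['error']}")
            return operation
        if time.monotonic() + delay > deadline:
            raise TimeoutError(f"Operation {operation.get('name')} still running after {timeout_s}s")
        sleep(delay)
        delay = min(delay * 2, max_delay_s)
```

Passing `sleep` as a parameter keeps the backoff testable without real waits; in the notebook the default `time.sleep` is used.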