In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 使用Vertex AI向量搜索创建文本到图像嵌入
![ ](https://www.google-analytics.com/collect?v=2&tid=G-L6X3ECH596&cid=1&en=page_view&sid=1&dt=sdk_matching_engine_create_text_to_image_embeddings.ipynb&dl=notebooks%2Fofficial%2Fmatching_engine%2Fsdk_matching_engine_create_text_to_image_embeddings.ipynb)

<table align="left">
  <td>
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/official/matching_engine/sdk_matching_engine_create_text_to_image_embeddings.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Colab logo"> 在Colab中运行
    </a>
  </td>
  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/official/matching_engine/sdk_matching_engine_create_text_to_image_embeddings.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      查看GitHub上的代码
    </a>
  </td>
      <td>
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/official/matching_engine/sdk_matching_engine_create_text_to_image_embeddings.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo">
      在Vertex AI Workbench中打开
    </a>
  </td>
</table>

## 概述

此示例演示了如何使用DiffusionDB数据集和CLIP模型创建文本到图像嵌入，并将其上传到Vertex AI向量搜索服务。这是一个高规模、低延迟的解决方案，用于在大型语料库中查找相似的向量。此外，它是一项完全托管的服务，进一步减少了运营开销。它建立在 Google Research 开发的[近似最近邻（ANN）技术](https://ai.googleblog.com/2020/07/announcing-scann-efficient-vector.html)之上。

**先决条件**：此笔记本要求您已经设置了一个VPC网络。请参阅[创建Vertex AI向量搜索索引笔记本](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/official/matching_engine/sdk_matching_engine_for_indexing.ipynb)中的“准备VPC网络”部分。

了解更多关于[Vertex AI向量搜索](https://cloud.google.com/vertex-ai/docs/matching-engine/overview)。

### 目标

在这本笔记本中，您将学习如何对自定义文本嵌入进行编码，创建一个近似最近邻（ANN）索引，并针对索引进行查询。

本教程使用以下 Google Cloud ML 服务：

- `Vertex AI Vector Search`

执行的步骤包括：

* 创建ANN索引
* 使用VPC网络创建索引端点
* 部署ANN索引
* 执行在线查询

### 数据集

本教程使用的数据集是[DiffusionDB数据集](https://github.com/poloclub/diffusiondb)。

DiffusionDB是第一个大规模的文本到图像提示数据集。它包含由真实用户指定的提示和超参数生成的1400万张图像。这个前所未有的规模和多样性的人为驱动数据集为理解提示和生成模型之间的相互作用、检测深度伪造以及设计人工智能交互工具以帮助用户更容易地使用这些模型提供了令人兴奋的研究机会。

## 安装

安装最新版本的 Cloud Storage 客户端库和 Vertex AI Python SDK。

In [None]:
# Install the packages
! pip3 install --upgrade google-cloud-aiplatform \
                         google-cloud-storage --upgrade

安装最新版本的transformers和torch库以编码文本和图像嵌入。

In [None]:
# Install the packages
! pip3 install --upgrade transformers torch --upgrade

安装最新版本的google-cloud-vision，用于过滤安全图像，同时安装pyrate_limiter，以限制对Google Cloud Vision API的调用。

In [None]:
# Install the packages
! pip install google-cloud-vision pyrate_limiter==2.10

仅限 Colab：取消以下单元格的注释以重新启动内核。

In [None]:
# Automatically restart kernel after installs so that your environment can access the new packages
# import IPython

# app = IPython.Application.instance()
# app.kernel.do_shutdown(True)

## 在开始之前
#### 设置您的项目ID

如果您不知道您的项目ID，可以尝试以下操作：
* 运行 `gcloud config list`。
* 运行 `gcloud projects list`。
* 参考支持页面：[查找项目ID](https://support.google.com/googleapi/answer/7014113)

In [None]:
PROJECT_ID = "[your-project-id]"  # @param {type:"string"}

# Set the project id
! gcloud config set project {PROJECT_ID}

#### 区域

您还可以更改由Vertex AI使用的`REGION`变量。了解有关[Vertex AI区域](https://cloud.google.com/vertex-ai/docs/general/locations)的更多信息。

In [None]:
REGION = "us-central1"  # @param {type: "string"}

### 验证您的 Google Cloud 帐户

根据您的 Jupyter 环境，您可能需要手动进行身份验证。请按照以下相关说明操作。

1. Vertex AI Workbench
* 无需操作，因为您已经验证通过。

2. 本地 JupyterLab 实例，取消注释并运行：

In [None]:
# ! gcloud auth login

3. Colab，取消注释并运行：

In [None]:
# from google.colab import auth
# auth.authenticate_user()

请参考https://cloud.google.com/storage/docs/gsutil/commands/iam#ch-examples，了解如何向您的服务帐号授予云存储权限。

身份验证：在从 Vertex AI Workbench 笔记本终端注销并且需要凭据时，请重新运行 `gcloud auth login` 命令。

仅限 Colab：取消对以下单元格的注释以重新启动内核。

In [None]:
# Automatically restart kernel after installs so that your environment can access the new packages
# import IPython

# app = IPython.Application.instance()
# app.kernel.do_shutdown(True)

### 创建 Cloud Storage 存储桶

创建一个存储桶来存储中间产物，例如数据集。

In [None]:
BUCKET_URI = f"gs://your-bucket-name-{PROJECT_ID}-unique"  # @param {type:"string"}

只有当您还没有这个存储桶时：运行以下单元格来创建您的云存储存储桶。

In [None]:
! gsutil mb -l {REGION} -p {PROJECT_ID} {BUCKET_URI}

## 准备数据

您将使用[DiffusionDB数据集](https://github.com/poloclub/diffusiondb)中的图像提示和图像对。

### 克隆DiffusionDB存储库

In [None]:
! git clone https://github.com/poloclub/diffusiondb

### 为了下载数据集安装依赖项

In [None]:
! pip install -r diffusiondb/requirements.txt

### 下载图像文件

In [None]:
# Download image files from 1 to 5. Each file is 1000 images.
! python diffusiondb/scripts/download.py -i 1 -r 5

### 提取图像存档

In [None]:
# Unzip all image files
image_directory = "extracted"

! unzip -n 'images/*.zip' -d '{image_directory}'

### 加载图像元数据

In [None]:
import json
import os

# Merge every JSON metadata file found in the extraction directory into one
# dictionary.
metadatas = {}
json_file_names = [
    entry for entry in os.listdir(image_directory) if entry.endswith(".json")
]
for json_file_name in json_file_names:
    with open(os.path.join(image_directory, json_file_name)) as metadata_file:
        metadatas.update(json.load(metadata_file))

# The dictionary keys are used as image file names relative to the
# extraction directory.
image_names = list(metadatas)
image_paths = [os.path.join(image_directory, name) for name in image_names]

len(metadatas)

### 定义检测显性图像的函数

定义一个函数来查询Cloud Vision API以检测潜在的显性图像。

了解更多关于[检测显性内容](https://cloud.google.com/vision/docs/detecting-safe-search)。

In [None]:
from typing import Optional

from google.cloud import vision

client = vision.ImageAnnotatorClient()


def detect_safe_search(path: str) -> "Optional[vision.SafeSearchAnnotation]":
    """Runs Cloud Vision SafeSearch detection on a local image file.

    Args:
        path: Path of the image file to send to the Cloud Vision API.

    Returns:
        The `safe_search_annotation` of the API response, or None when the
        response carries an error message (the message is printed).
    """
    # The context manager closes the file even if reading raises.
    with open(path, "rb") as image_file:
        content = image_file.read()

    image = vision.Image(content=content)

    response = client.safe_search_detection(image=image)

    if response.error.message:
        print(response.error.message)
        return None

    return response.safe_search_annotation

### 定义将安全搜索标注转换为布尔值的函数

定义一个将安全搜索标注结果转换为布尔值的函数。

In [None]:
from google.cloud.vision_v1.types.image_annotator import (Likelihood,
                                                          SafeSearchAnnotation)


def convert_annotation_to_safety(safe_search_annotation: SafeSearchAnnotation) -> bool:
    """Tells whether a SafeSearch annotation describes a safe image.

    Args:
        safe_search_annotation: SafeSearch result for one image, or None when
            no annotation could be obtained for that image.

    Returns:
        True only if the adult, medical, violence and racy likelihoods are all
        VERY_UNLIKELY or UNLIKELY. False otherwise, including when the
        annotation is None, so images that could not be checked are treated
        as unsafe instead of raising AttributeError.
    """
    if safe_search_annotation is None:
        return False

    safe_levels = (Likelihood.VERY_UNLIKELY, Likelihood.UNLIKELY)
    return all(
        safe_level in safe_levels
        for safe_level in (
            safe_search_annotation.adult,
            safe_search_annotation.medical,
            safe_search_annotation.violence,
            safe_search_annotation.racy,
        )
    )

### 执行限速的显性图像检测

Google Cloud Vision 对API请求设置了速率限制。

使用速率限制器确保请求在此限制范围内。
为了获得更好的性能，请使用线程池进行并行请求。这超出了本笔记本的范围。

了解更多关于[配额和限制](https://cloud.google.com/vision/quotas?hl=en)。

In [None]:
import numpy as np
from pyrate_limiter import Duration, Limiter, RequestRate
from tqdm import tqdm

# Create a rate limiter with a limit of 1800 requests per minute
limiter = Limiter(RequestRate(1800, Duration.MINUTE))

# Identity under which the limiter tracks the Cloud Vision requests
RATE_LIMIT_IDENTITY = "cloud-vision-safe-search"

safe_search_annotations = []
for image_path in tqdm(image_paths, total=len(image_paths)):
    # ratelimit(..., delay=True) waits until a slot is available instead of
    # raising when the limit is reached.
    # NOTE(review): try_acquire() called without an identity appears to track
    # no bucket at all in pyrate_limiter 2.x — confirm against the library docs.
    with limiter.ratelimit(RATE_LIMIT_IDENTITY, delay=True):
        safe_search_annotations.append(detect_safe_search(image_path))

# Convert annotations to boolean values
is_safe_values_cloud_vision = list(
    map(convert_annotation_to_safety, safe_search_annotations)
)

# Print number of safe images found
print(
    f"Safe images = {np.array(is_safe_values_cloud_vision).sum()} out of {len(is_safe_values_cloud_vision)} images"
)

In [None]:
# Filter images by safety
# The collections and the safety flags must be aligned one-to-one. After a
# first run the collections are shorter than the flag list, so re-running this
# cell would silently pair images with the wrong flags; fail loudly instead.
if not (
    len(metadatas)
    == len(image_names)
    == len(image_paths)
    == len(is_safe_values_cloud_vision)
):
    raise ValueError(
        "Image collections and safety flags are out of sync; "
        "re-run the notebook from the metadata loading cell."
    )

# metadatas is a dict keyed by image name: filter its items so that it stays
# a dict instead of collapsing into a list of its keys.
metadatas = {
    image_name: metadata
    for (image_name, metadata), is_safe in zip(
        metadatas.items(), is_safe_values_cloud_vision
    )
    if is_safe
}
image_names = [
    image_name
    for image_name, is_safe in zip(image_names, is_safe_values_cloud_vision)
    if is_safe
]
image_paths = [
    image_path
    for image_path, is_safe in zip(image_paths, is_safe_values_cloud_vision)
    if is_safe
]

#### 实例化文本到图像编码模型

使用由OpenAI开发的[clip-vit-base-patch32编码器](https://huggingface.co/openai/clip-vit-base-patch32)将文本和图像转换为嵌入向量。

> CLIP模型是由OpenAI研究人员开发的，用于了解什么对计算机视觉任务的鲁棒性有贡献。该模型还被开发用于测试模型在零样本方式下对任意图像分类任务进行泛化的能力。

In [None]:
import torch
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast

MODEL_ID = "openai/clip-vit-base-patch32"

# Pick the compute backend: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Load the tokenizer, the processor and the model for the same checkpoint;
# only the model is moved to the selected device.
tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
model = CLIPModel.from_pretrained(MODEL_ID)
model = model.to(device)

#### 定义编码函数

定义后续将接受文本和图像并将它们转换为嵌入的函数。

在此处查看更多信息: https://huggingface.co/openai/clip-vit-base-patch32#use-with-transformers

In [None]:
import copy
from typing import List

import numpy as np
from PIL import Image
from tqdm.auto import tqdm


def encode_text_to_embedding(model, tokenizer, text: List[str]) -> np.ndarray:
    """Encodes a batch of texts into embeddings.

    Args:
        model: Model exposing `get_text_features` and a `device` attribute.
        tokenizer: Callable tokenizer returning a batch that supports `.to()`.
        text: Texts to encode.

    Returns:
        The output of `model.get_text_features` as a NumPy array on the CPU.
    """
    # padding=True lets texts of different lengths share one batch; the
    # tokenized batch must live on the same device as the model.
    inputs = tokenizer(text, return_tensors="pt", padding=True).to(model.device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        embeddings = model.get_text_features(**inputs)
    return embeddings.cpu().detach().numpy()


def encode_images_to_embedding(model, device, image_paths: List[str]) -> np.ndarray:
    """Encodes the images at the given paths into embeddings.

    Uses the notebook-level `processor` to preprocess the images.

    Args:
        model: Model exposing `get_image_features`.
        device: Device the pixel values are moved to before the forward pass.
        image_paths: Paths of the image files to encode.

    Returns:
        The output of `model.get_image_features` as a NumPy array on the CPU.
        Because of `squeeze(0)`, a leading dimension of size 1 is dropped, so
        a single input path does not keep a batch dimension.
    """
    images = []
    for path in image_paths:
        # copy() loads the pixel data, so the file can be closed right away.
        with Image.open(path) as image_file:
            images.append(image_file.copy())

    image_pixel_values = processor(
        text=None, images=images, return_tensors="pt", padding=True
    )["pixel_values"].to(device)

    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        embeddings = model.get_image_features(pixel_values=image_pixel_values)
    return embeddings.squeeze(0).cpu().detach().numpy()


def encode_images_to_embedding_chunked(
    model, device, image_paths: List[str], batch_size: int = 16
) -> np.ndarray:
    """Encodes images batch by batch and stacks the per-batch embeddings.

    Args:
        model: Forwarded to `encode_images_to_embedding`.
        device: Forwarded to `encode_images_to_embedding`.
        image_paths: Paths of the image files to encode.
        batch_size: Number of paths passed to each encoding call.

    Returns:
        The per-batch results stacked vertically with `np.vstack`.
    """
    # Process images in batches to prevent out-of-memory errors.
    batch_starts = range(0, len(image_paths), batch_size)
    per_batch_embeddings = [
        encode_images_to_embedding(
            model=model,
            device=device,
            image_paths=image_paths[start : start + batch_size],
        )
        for start in tqdm(batch_starts)
    ]
    return np.vstack(per_batch_embeddings)

#### 测试编码函数

对一部分数据进行编码，看看嵌入和距离度量是否合理。

根据[CLIP研究论文](https://arxiv.org/pdf/2103.00020.pdf)，嵌入的相似性是使用余弦相似度进行计算的。

In [None]:
import random

# Encode a random sample of at most SAMPLE_SIZE images.
# random.sample raises ValueError when asked for more items than available,
# which can happen if the safety filter leaves fewer images than SAMPLE_SIZE.
SAMPLE_SIZE = 1000
image_paths_filtered = random.sample(
    image_paths, min(SAMPLE_SIZE, len(image_paths))
)
image_embeddings = encode_images_to_embedding_chunked(
    model=model, device=device, image_paths=image_paths_filtered
)

In [None]:
import numpy as np


def cosine_similarity(
    text_embedding: np.ndarray, image_embeddings: np.ndarray
) -> np.ndarray:
    """Computes the cosine similarity between a text and image embeddings.

    Args:
        text_embedding: Embedding of the text; its norm is taken over the
            whole array.
        image_embeddings: Image embeddings, one per row (norms are taken
            along axis 1).

    Returns:
        The dot products of the text embedding with every image embedding,
        divided by the product of the corresponding L2 norms.
    """
    dot_products = np.dot(text_embedding, image_embeddings.T)

    text_norm = np.linalg.norm(text_embedding)
    image_norms = np.linalg.norm(image_embeddings, axis=1, keepdims=True)
    norm_products = (text_norm * image_norms).squeeze()

    return np.divide(dot_products, norm_products)

In [None]:
import math

import matplotlib.pyplot as plt

text_query = "Birds in flight"

# Calculate text embedding of query
text_embedding = encode_text_to_embedding(
    model=model, tokenizer=tokenizer, text=[text_query]
)[0]

# Calculate cosine similarity
scores = cosine_similarity(
    text_embedding=text_embedding, image_embeddings=image_embeddings
)

# Set the maximum number of images to display
MAX_IMAGES = 20

# Sort images and scores by descending order of scores and select the top MAX_IMAGES
sorted_data = sorted(
    zip(image_paths_filtered, scores), key=lambda x: x[1], reverse=True
)[:MAX_IMAGES]

# Calculate the number of rows and columns needed to display the images
num_cols = 4
num_rows = math.ceil(len(sorted_data) / num_cols)

# Create a grid of subplots to display the images.
# squeeze=False keeps `axs` two-dimensional even when there is a single row,
# so the [row, column] indexing below always works.
fig, axs = plt.subplots(
    nrows=num_rows, ncols=num_cols, figsize=(10, 12), squeeze=False
)

# Loop through the top images and display them in the subplots
for i, (image_path, score) in enumerate(sorted_data):
    # Select the subplot for the current image (row-major order)
    ax = axs[i // num_cols, i % num_cols]

    # Display the image in the current subplot; the file is closed as soon as
    # the image has been handed to matplotlib.
    with Image.open(image_path) as image:
        ax.imshow(image, cmap="gray")

    # Set the title of the subplot to the image rank and score
    ax.set_title(f"Rank {i+1}, Score = {score:.2f}")

    # Remove ticks from the subplot
    ax.set_xticks([])
    ax.set_yticks([])

# Hide the subplots of the last row that received no image
for ax in axs.flat[len(sorted_data) :]:
    ax.axis("off")

# Adjust the spacing between subplots and display the plot
plt.subplots_adjust(hspace=0.3, wspace=0.1)
plt.show()

创建索引时保存维度大小以便将来使用。

In [None]:
DIMENSIONS = len(text_embedding)

DIMENSIONS

#### 将训练集保存为JSONL格式

数据必须以JSONL格式格式化，这意味着每个嵌入字典都以自己的一行JSON字符串形式写入。

In [None]:
import tempfile

# Create temporary file to write embeddings to
embeddings_file = tempfile.NamedTemporaryFile(suffix=".json", delete=False)

# Only the path is used from here on: close the handle so the file is not
# kept open for the rest of the session. delete=False keeps it on disk.
embeddings_file.close()

embeddings_file.name

In [None]:
import json

BATCH_SIZE = 1000

# Open in write mode so that re-running this cell rewrites the file instead
# of appending a second copy of every record.
with open(embeddings_file.name, "w") as f:
    for i in tqdm(range(0, len(image_names), BATCH_SIZE)):
        image_names_chunk = image_names[i : i + BATCH_SIZE]
        image_paths_chunk = image_paths[i : i + BATCH_SIZE]

        embeddings = encode_images_to_embedding_chunked(
            model=model, device=device, image_paths=image_paths_chunk
        )

        # One JSON object per line: the image name is the record id and the
        # embedding is written as a list of JSON numbers (NumPy scalars are
        # converted to Python floats so json.dumps can serialize them).
        embeddings_formatted = [
            json.dumps(
                {
                    "id": str(image_name),
                    "embedding": [float(value) for value in embedding],
                }
            )
            + "\n"
            for image_name, embedding in zip(image_names_chunk, embeddings)
        ]
        f.writelines(embeddings_formatted)

将训练数据上传到GCS。

In [None]:
UNIQUE_FOLDER_NAME = "embeddings_folder_unique"
EMBEDDINGS_INITIAL_URI = f"{BUCKET_URI}/{UNIQUE_FOLDER_NAME}/"
! gsutil cp {embeddings_file.name} {EMBEDDINGS_INITIAL_URI}

## 创建索引

### 创建 ANN 索引（用于生产环境）

In [None]:
DISPLAY_NAME = "text_to_image"
DESCRIPTION = "CLIP text_to_image embeddings"

创建ANN索引配置：

要了解更多关于配置索引的信息，请参阅[输入数据格式和结构](https://cloud.google.com/vertex-ai/docs/matching-engine/match-eng-setup#input-data-format)。

In [None]:
from google.cloud import aiplatform

aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)

In [None]:
# Create a Tree-AH index from the embeddings uploaded under
# EMBEDDINGS_INITIAL_URI, using the embedding size saved in DIMENSIONS.
# COSINE_DISTANCE is consistent with the cosine similarity used for the
# local check of the embeddings earlier in this notebook.
# NOTE(review): index creation is presumably a long-running call — confirm.
tree_ah_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
    display_name=DISPLAY_NAME,
    contents_delta_uri=EMBEDDINGS_INITIAL_URI,
    dimensions=DIMENSIONS,
    approximate_neighbors_count=150,
    distance_measure_type="COSINE_DISTANCE",
    leaf_node_embedding_count=500,
    leaf_nodes_to_search_percent=7,
    description=DESCRIPTION,
)

In [None]:
INDEX_RESOURCE_NAME = tree_ah_index.resource_name
INDEX_RESOURCE_NAME

使用资源名称，您可以检索现有的MatchingEngineIndex。

In [None]:
tree_ah_index = aiplatform.MatchingEngineIndex(index_name=INDEX_RESOURCE_NAME)

## 在VPC网络中创建索引端点

In [None]:
# Retrieve the project number
# The `!` shell capture returns the command output as a list of lines.
PROJECT_NUMBER = !gcloud projects list --filter="PROJECT_ID:'{PROJECT_ID}'" --format='value(PROJECT_NUMBER)'
# Keep the first output line as the project number.
PROJECT_NUMBER = PROJECT_NUMBER[0]

# Placeholder: replace with the name of your VPC network before running.
VPC_NETWORK = "[YOUR-VPC-NETWORK]"
# Full network resource path built from the project number and network name.
VPC_NETWORK_FULL = "projects/{}/global/networks/{}".format(PROJECT_NUMBER, VPC_NETWORK)
VPC_NETWORK_FULL

In [None]:
# Create the index endpoint inside the VPC network resolved above.
# The description uses the DESCRIPTION constant, as the index creation does,
# instead of repeating the display name.
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
    display_name=DISPLAY_NAME,
    description=DESCRIPTION,
    network=VPC_NETWORK_FULL,
)

## 部署索引

### 部署 ANN 索引

In [None]:
DEPLOYED_INDEX_ID = "deployed_index_id_unique"

In [None]:
# Deploy the index to the endpoint under DEPLOYED_INDEX_ID and keep the
# object returned by deploy_index() as the endpoint handle.
# NOTE(review): deployment is presumably a long-running call — confirm.
my_index_endpoint = my_index_endpoint.deploy_index(
    index=tree_ah_index, deployed_index_id=DEPLOYED_INDEX_ID
)

# Display the indexes currently deployed on the endpoint.
my_index_endpoint.deployed_indexes

## 创建在线查询

在建立了索引之后，您可以查询部署的索引以找到最近的邻居。

In [None]:
# Encode query
text_embeddings = encode_text_to_embedding(
    model=model, tokenizer=tokenizer, text=["New York skyline"]
)

In [None]:
# Define number of neighbors to return
NUM_NEIGHBORS = 20

# Query the deployed index with the encoded text. The plotting cell that
# follows reads the neighbors of the single query from response[0].
response = my_index_endpoint.match(
    deployed_index_id=DEPLOYED_INDEX_ID,
    queries=text_embeddings,
    num_neighbors=NUM_NEIGHBORS,
)

response

绘制响应并验证图像是否与文本查询匹配。

In [None]:
# Sort the neighbors of the first (and only) query by descending distance value
sorted_data = sorted(response[0], key=lambda x: x.distance, reverse=True)

# Size the grid from the number of neighbors actually returned rather than
# from values left over by an earlier cell.
num_cols = 4
num_rows = max(1, math.ceil(len(sorted_data) / num_cols))

# Create a grid of subplots to display the images.
# squeeze=False keeps `axs` two-dimensional even when there is a single row.
fig, axs = plt.subplots(
    nrows=num_rows, ncols=num_cols, figsize=(10, 12), squeeze=False
)

# Loop through the neighbors and display them in the subplots.
# The loop variable is not named `response`, so the query response is not
# overwritten and this cell can be re-run.
for i, neighbor in enumerate(sorted_data):
    image_path = f"{image_directory}/{neighbor.id}"
    score = neighbor.distance

    # Select the subplot for the current image (row-major order)
    ax = axs[i // num_cols, i % num_cols]

    # Leave the subplot blank when the image is not available locally
    if not os.path.exists(image_path):
        ax.axis("off")
        continue

    # Display the image in the current subplot; the file is closed as soon as
    # the image has been handed to matplotlib.
    with Image.open(image_path) as image:
        ax.imshow(image, cmap="gray")

    # Set the title of the subplot to the image rank and score
    ax.set_title(f"Rank {i+1}, Score = {score:.2f}")

    # Remove ticks from the subplot
    ax.set_xticks([])
    ax.set_yticks([])

# Hide the subplots that received no neighbor
for ax in axs.flat[len(sorted_data) :]:
    ax.axis("off")

# Adjust the spacing between subplots and display the plot
plt.subplots_adjust(hspace=0.3, wspace=0.1)
plt.show()

## 清理

要清理此项目中使用的所有Google Cloud资源，您可以删除用于本教程的[Google Cloud项目](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects)。
您也可以通过运行以下代码手动删除您创建的资源。

In [None]:
import os

# Set to True to also remove the Cloud Storage bucket at BUCKET_URI.
delete_bucket = False

# Force undeployment of indexes and delete endpoint
my_index_endpoint.delete(force=True)

# Delete indexes
tree_ah_index.delete()

# The bucket is removed only on request, or when the IS_TESTING environment
# variable is set.
if delete_bucket or os.getenv("IS_TESTING"):
    ! gsutil -m rm -r $BUCKET_URI