In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table align="left">

  <td>
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/matching_engine/matching_engine_for_indexing.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Colab logo">
      Run in Colab
    </a>
  </td>
  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/matching_engine/matching_engine_for_indexing.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>
  <td>
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/community/matching_engine/matching_engine_for_indexing.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo">
      Open in Vertex AI Workbench
    </a>
  </td>                                                                                               
</table>

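## Overview

This example demonstrates how to use Vertex AI Vector Search, a high-scale, low-latency solution for finding similar vectors (more precisely, "embeddings") in a large corpus. It is also a fully managed service, which further reduces operational overhead. It is built on the [Approximate Nearest Neighbor (ANN) technology](https://ai.googleblog.com/2020/07/announcing-scann-efficient-vector.html) developed by Google Research.

### Objective

In this notebook, you learn how to create an approximate nearest neighbor (ANN) index, query the index, and validate its performance.

The steps performed include:

* Create a Vertex AI Vector Search index and a brute-force index
* Create an IndexEndpoint with a VPC network
* Deploy the Vertex AI Vector Search index and the brute-force index
* Run online queries
* Submit batch queries
* Compute the recall metric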

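### Dataset

The dataset used in this tutorial is the [GloVe dataset](https://nlp.stanford.edu/projects/glove/).

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.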

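## Before you begin

### Set up your Google Cloud project

**The following steps are required, regardless of your notebook environment.**

1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.

2. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

3. [Enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).

4. If you are running this notebook locally, you need to install the [Cloud SDK](https://cloud.google.com/sdk).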

### Installation

Download and install the latest version of the Vertex AI SDK for Python from its GitHub `main` branch.

In [None]:
! pip install -U git+https://github.com/googleapis/python-aiplatform.git@main --user

Install `h5py` to prepare the sample dataset, and `grpcio-tools` to query the index.

In [None]:
! pip install protobuf==3.20.*
! pip install -U google-api-python-client==1.8.0 --user
! pip install -U grpcio-tools==1.47.0 --user
! pip install -U grpcio==1.47.0 --user
! pip install -U grpcio-status==1.47.0 --user
! pip install -U h5py --user

### Restart the kernel

After you install the additional packages, you need to restart the notebook kernel so it can find the packages.

In [None]:
# Automatically restart kernel after installs
import os

if not os.getenv("IS_TESTING"):
    # Automatically restart kernel after installs
    import IPython

    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)

### Set your project ID

**If you don't know your project ID**, try the following:
* Run `gcloud config list`.
* Run `gcloud projects list`.
* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)

In [None]:
PROJECT_ID = "[your-project-id]"  # @param {type:"string"}

# Set the project id
! gcloud config set project {PROJECT_ID}

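### Region

You can also change the `REGION` variable used by Vertex AI. Learn more about [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations).

- **Warning:**
    - **Make sure to [choose a region where Vertex AI services are available](https://cloud.google.com/vertex-ai/docs/general/locations#available_regions).**
    - **If you are using Vertex AI Workbench, the notebook instance must be in the same region as your Vertex AI Vector Search deployment.** (For example, if you set `REGION = "us-central1"` as in this tutorial, the notebook instance must also be in `us-central1`.)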

In [None]:
REGION = "us-central1"  # @param {type: "string"}

# Set the regions
! gcloud config set ai_platform/region {REGION}

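### Authenticate your Google Cloud account

Depending on your Jupyter environment, you may need to authenticate manually. Follow the relevant instructions below.

**1. Vertex AI Workbench**
* Do nothing, as you are already authenticated.

**2. Local JupyterLab instance**, uncomment and run: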

In [None]:
# ! gcloud auth login

**3. Colab**, uncomment and run:

In [None]:
# from google.colab import auth
# auth.authenticate_user()

See how to grant Cloud Storage permissions to your service account at https://cloud.google.com/storage/docs/gsutil/commands/iam#ch-examples.

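### Prepare a VPC network

To reduce network overhead that can add unnecessary latency, it is best to call the Vertex AI Vector Search endpoints from your VPC through a direct [VPC Peering](https://cloud.google.com/vertex-ai/docs/general/vpc-peering) connection. The following section describes how to set up a VPC Peering connection if you don't have one already. This is a one-time initial setup task. You can also reuse an existing VPC network and skip this section.

* **Warning:** The Match service gRPC API (used for online queries against the deployed index) must be called from a Google Cloud Notebook instance that meets the following requirements:
    * **Make sure you select the VPC network created for the Vertex AI Vector Search service** (instead of the "default" network). That is, you need to create the VPC network below, and then create a new notebook instance that uses that VPC.
    * If you run in Colab, or in a Google Cloud Notebook instance in a different VPC network or region, the gRPC API cannot reach the network (`_InactiveRpcError`).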

In [None]:
NETWORK_NAME = "ucaip-haystack-vpc-network"  # @param {type:"string"}
PEERING_RANGE_NAME = "ucaip-haystack-range"

In [None]:
# Create a VPC network
! gcloud compute networks create {NETWORK_NAME} --bgp-routing-mode=regional --subnet-mode=auto --project={PROJECT_ID}

# Add necessary firewall rules
! gcloud compute firewall-rules create {NETWORK_NAME}-allow-icmp --network {NETWORK_NAME} --priority 65534 --project {PROJECT_ID} --allow icmp

! gcloud compute firewall-rules create {NETWORK_NAME}-allow-internal --network {NETWORK_NAME} --priority 65534 --project {PROJECT_ID} --allow all --source-ranges 10.128.0.0/9

! gcloud compute firewall-rules create {NETWORK_NAME}-allow-rdp --network {NETWORK_NAME} --priority 65534 --project {PROJECT_ID} --allow tcp:3389

! gcloud compute firewall-rules create {NETWORK_NAME}-allow-ssh --network {NETWORK_NAME} --priority 65534 --project {PROJECT_ID} --allow tcp:22

# Reserve IP range
! gcloud compute addresses create {PEERING_RANGE_NAME} --global --prefix-length=16 --network={NETWORK_NAME} --purpose=VPC_PEERING --project={PROJECT_ID} --description="peering range for uCAIP Haystack."

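Create the VPC Peering connection. If you are running this from Vertex AI Workbench, you may need to make sure that your notebook instance's service account or user account has the Service Networking Admin role.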

In [None]:
# Set up peering with service networking
! gcloud services vpc-peerings connect --service=servicenetworking.googleapis.com --network={NETWORK_NAME} --ranges={PEERING_RANGE_NAME} --project={PROJECT_ID}

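### Create a Cloud Storage bucket

**The following steps are required, regardless of your notebook environment.**

Create a storage bucket to store intermediate artifacts such as datasets. Set the name of your Cloud Storage bucket below. It must be unique across all Cloud Storage buckets.

* **Warning:**
    * **You cannot use a multi-regional storage bucket for training with Vertex AI.**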

In [None]:
BUCKET_NAME = "gs://[your-bucket-name-unique]"  # @param {type:"string"}

In [None]:
from datetime import datetime

UUID = datetime.now().strftime("%Y%m%d%H%M%S")

if (
    BUCKET_NAME == ""
    or BUCKET_NAME is None
    or BUCKET_NAME == "gs://[your-bucket-name-unique]"
):
    BUCKET_NAME = "gs://" + PROJECT_ID + "-aip-" + UUID

**Only if your bucket doesn't already exist**: run the following cell to create your Cloud Storage bucket.

In [None]:
! gsutil mb -l $REGION -p $PROJECT_ID $BUCKET_NAME

Finally, validate access to your Cloud Storage bucket by examining its contents.

In [None]:
! gsutil ls -al $BUCKET_NAME

### Import libraries and define constants

Import the Vertex AI (unified) client library into your Python environment.

In [None]:
import time

import grpc
import h5py
from google.cloud import aiplatform_v1beta1
from google.protobuf import struct_pb2

In [None]:
ENDPOINT = "{}-aiplatform.googleapis.com".format(REGION)

AUTH_TOKEN = !gcloud auth print-access-token
PROJECT_NUMBER = !gcloud projects list --filter="PROJECT_ID:'{PROJECT_ID}'" --format='value(PROJECT_NUMBER)'
PROJECT_NUMBER = PROJECT_NUMBER[0]

PARENT = "projects/{}/locations/{}".format(PROJECT_ID, REGION)

print("ENDPOINT: {}".format(ENDPOINT))
print("PROJECT_ID: {}".format(PROJECT_ID))
print("REGION: {}".format(REGION))

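## Prepare the data

The GloVe dataset consists of a set of pre-trained embeddings, divided into a "train" split and a "test" split.
You create a vector search index from the "train" split and use the embedding vectors in the "test" split as query vectors to test the index.

Note: Although the data is split into "train" and "test", these are all pre-trained embeddings and can therefore be indexed for search. The terms "train" and "test" are used only to stay consistent with standard machine learning terminology.

Download the GloVe dataset.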

In [None]:
! gsutil cp gs://cloud-samples-data/vertex-ai/matching_engine/glove-100-angular.hdf5 .

Read the data into memory.

In [None]:
# The number of nearest neighbors to be retrieved from database for each query.
k = 10

h5 = h5py.File("glove-100-angular.hdf5", "r")
train = h5["train"]
test = h5["test"]

In [None]:
train[0]

Save the train split in JSONL format.

In [None]:
with open("glove100.json", "w") as f:
    for i in range(len(train)):
        f.write('{"id":"' + str(i) + '",')
        f.write('"embedding":[' + ",".join(str(x) for x in train[i]) + "]}")
        f.write("\n")
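Each line of `glove100.json` is a standalone JSON object with an `id` string and an `embedding` array. As a minimal sketch (the helpers `to_jsonl_line` and `parse_jsonl` are hypothetical, not part of any SDK), the same record format can be produced and read back with the standard `json` module, which handles escaping and number formatting for you:

```python
import json


def to_jsonl_line(datapoint_id, embedding):
    """Serialize one datapoint as a single JSON line (hypothetical helper)."""
    # float() also converts NumPy scalars (e.g. float32) into JSON-serializable floats.
    record = {"id": str(datapoint_id), "embedding": [float(x) for x in embedding]}
    # Compact separators keep each record short and on one line.
    return json.dumps(record, separators=(",", ":")) + "\n"


def parse_jsonl(text):
    """Parse JSONL text back into a list of (id, embedding) tuples."""
    records = []
    for line in text.splitlines():
        if line.strip():  # skip blank lines
            obj = json.loads(line)
            records.append((obj["id"], obj["embedding"]))
    return records


lines = to_jsonl_line(0, [0.5, -1.25]) + to_jsonl_line(1, [2.0, 0.0])
print(lines, end="")
print(parse_jsonl(lines))
```

Note that converting `float32` values to Python floats can produce long decimal expansions (for example `0.10000000149011612`); the values are still exact, just more verbose than the `str(x)` output used in the cell above.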

Upload the training data to Google Cloud Storage.

In [None]:
# NOTE: Everything in this Google Cloud Storage directory will be DELETED before uploading the data

! gsutil rm -raf {BUCKET_NAME}/** 2> /dev/null || true

In [None]:
! gsutil cp glove100.json {BUCKET_NAME}/glove100.json

In [None]:
! gsutil ls {BUCKET_NAME}

## Create indexes

### Create a Vertex AI Vector Search index for production

In [None]:
index_client = aiplatform_v1beta1.IndexServiceClient(
    client_options=dict(api_endpoint=ENDPOINT)
)

Set constants.

In [None]:
DIMENSIONS = 100
DISPLAY_NAME = "glove_100_1"
DISPLAY_NAME_BRUTE_FORCE = DISPLAY_NAME + "_brute_force"

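#### Create the Vertex AI Vector Search index configuration

Read the [documentation](https://cloud.google.com/vertex-ai/docs/matching-engine/configuring-indexes) to learn about the configuration parameters you can use to tune the index.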
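To build intuition for the two tree-AH parameters used in the next cell, here is a rough back-of-the-envelope sketch. It assumes every leaf holds about `leafNodeEmbeddingCount` embeddings, which simplifies how the service actually partitions data, and `tree_ah_search_budget` is a hypothetical name used only for illustration.

```python
import math


def tree_ah_search_budget(num_embeddings, leaf_node_embedding_count, leaf_nodes_to_search_percent):
    """Estimate leaf count and per-query search budget for a tree-AH index.

    Rough sketch: assumes leaves are evenly filled with
    `leaf_node_embedding_count` embeddings each.
    """
    num_leaves = math.ceil(num_embeddings / leaf_node_embedding_count)
    # Fraction of leaves visited per query, rounded up to whole leaves.
    leaves_searched = math.ceil(num_leaves * leaf_nodes_to_search_percent / 100)
    embeddings_scanned = leaves_searched * leaf_node_embedding_count
    return num_leaves, leaves_searched, embeddings_scanned


# A 1,000,000-embedding corpus with the values used in this notebook (500 and 7).
print(tree_ah_search_budget(1_000_000, 500, 7))  # (2000, 140, 70000)
```

Under this simplification, raising `leafNodesToSearchPercent` scans more leaves per query, which generally trades higher latency for better recall, matching the description of `leaf_nodes_to_search_percent_override` in `match_service.proto` later in this notebook.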

In [None]:
treeAhConfig = struct_pb2.Struct(
    fields={
        "leafNodeEmbeddingCount": struct_pb2.Value(number_value=500),
        "leafNodesToSearchPercent": struct_pb2.Value(number_value=7),
    }
)

algorithmConfig = struct_pb2.Struct(
    fields={"treeAhConfig": struct_pb2.Value(struct_value=treeAhConfig)}
)

config = struct_pb2.Struct(
    fields={
        "dimensions": struct_pb2.Value(number_value=DIMENSIONS),
        "approximateNeighborsCount": struct_pb2.Value(number_value=150),
        "distanceMeasureType": struct_pb2.Value(string_value="DOT_PRODUCT_DISTANCE"),
        "algorithmConfig": struct_pb2.Value(struct_value=algorithmConfig),
    }
)

metadata = struct_pb2.Struct(
    fields={
        "config": struct_pb2.Value(struct_value=config),
        "contentsDeltaUri": struct_pb2.Value(string_value=BUCKET_NAME),
    }
)

matching_engine_index = {
    "display_name": DISPLAY_NAME,
    "description": "Glove 100 Vertex AI Vector Search Index",
    "metadata": struct_pb2.Value(struct_value=metadata),
}
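The nested `struct_pb2.Struct` objects above are a typed encoding of an ordinary JSON document. For readability, the sketch below writes the same index metadata as a plain Python dictionary and prints it as JSON. The field names mirror the cell above; `gs://your-bucket` is a placeholder, not a real bucket.

```python
import json

# Plain-dict mirror of the `metadata` Struct built above (placeholder bucket URI).
index_metadata = {
    "contentsDeltaUri": "gs://your-bucket",
    "config": {
        "dimensions": 100,
        "approximateNeighborsCount": 150,
        "distanceMeasureType": "DOT_PRODUCT_DISTANCE",
        "algorithmConfig": {
            "treeAhConfig": {
                "leafNodeEmbeddingCount": 500,
                "leafNodesToSearchPercent": 7,
            }
        },
    },
}

print(json.dumps(index_metadata, indent=2))
```

If you prefer building metadata this way, `google.protobuf.json_format.ParseDict(index_metadata, struct_pb2.Struct())` should produce an equivalent `Struct`; keep in mind that `Struct` stores every number as a double.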

In [None]:
matching_engine_index = index_client.create_index(
    parent=PARENT, index=matching_engine_index
)

In [None]:
# Poll the operation until it's done successfully.
# This will take ~45 min.

while True:
    if matching_engine_index.done():
        break
    print("Poll the operation to create index...")
    time.sleep(60)

In [None]:
INDEX_RESOURCE_NAME = matching_engine_index.result().name
INDEX_RESOURCE_NAME

### Create a brute-force index (for ground truth)

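A brute-force index uses a naive brute-force method to find the nearest neighbors. This method is neither fast nor efficient, so brute-force indexes are not recommended for production use. They are used to find the "ground truth" set of neighbors, so that this set can be used to measure the recall of an index tuned for production. To ensure an apples-to-apples comparison, the `distanceMeasureType`, `featureNormType`, and `dimensions` of the brute-force index should match those of the production index.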
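Conceptually, a brute-force search scores every stored vector against the query and keeps the best k. `DOT_PRODUCT_DISTANCE` is defined as the negative dot product, so a larger dot product means a closer neighbor. The pure-Python sketch below illustrates that exact search on a toy database; it shows the idea only and is not how the managed service is implemented.

```python
import heapq


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def brute_force_knn(query, database, k):
    """Return the ids of the k database vectors with the largest dot product.

    `database` maps datapoint id -> embedding. Every vector is scored, so the
    cost per query is O(len(database) * dimensions).
    """
    scores = ((dot(query, emb), dp_id) for dp_id, emb in database.items())
    return [dp_id for _, dp_id in heapq.nlargest(k, scores)]


database = {"0": [1.0, 0.0], "1": [0.0, 1.0], "2": [0.7, 0.7], "3": [-1.0, 0.0]}
print(brute_force_knn([1.0, 0.2], database, k=2))  # ['0', '2']
```

This exhaustive scan is what makes brute-force results a trustworthy ground truth, and also why it does not scale to large corpora.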

Create the brute-force index configuration:

In [None]:
algorithmConfig = struct_pb2.Struct(
    fields={"bruteForceConfig": struct_pb2.Value(struct_value=struct_pb2.Struct())}
)

config = struct_pb2.Struct(
    fields={
        "dimensions": struct_pb2.Value(number_value=DIMENSIONS),
        "approximateNeighborsCount": struct_pb2.Value(number_value=150),
        "distanceMeasureType": struct_pb2.Value(string_value="DOT_PRODUCT_DISTANCE"),
        "algorithmConfig": struct_pb2.Value(struct_value=algorithmConfig),
    }
)

metadata = struct_pb2.Struct(
    fields={
        "config": struct_pb2.Value(struct_value=config),
        "contentsDeltaUri": struct_pb2.Value(string_value=BUCKET_NAME),
    }
)

brute_force_index = {
    "display_name": DISPLAY_NAME_BRUTE_FORCE,
    "description": "Glove 100 index (brute force)",
    "metadata": struct_pb2.Value(struct_value=metadata),
}

In [None]:
brute_force_index = index_client.create_index(parent=PARENT, index=brute_force_index)

In [None]:
# Poll the operation until it's done successfully.
# This will take ~45 min.

while True:
    if brute_force_index.done():
        break
    print("Poll the operation to create index...")
    time.sleep(60)

In [None]:
INDEX_BRUTE_FORCE_RESOURCE_NAME = brute_force_index.result().name
INDEX_BRUTE_FORCE_RESOURCE_NAME

## Update the index

Create an incremental data file.

In [None]:
import json

# A single datapoint with id "0" and an all-zero embedding of DIMENSIONS values.
with open("glove100_incremental.json", "w") as f:
    f.write(json.dumps({"id": "0", "embedding": [0] * DIMENSIONS}, separators=(",", ":")) + "\n")

Copy the incremental data file to a new subdirectory.

In [None]:
! gsutil cp glove100_incremental.json {BUCKET_NAME}/incremental/glove100.json

Create the update-index request.

In [None]:
metadata = struct_pb2.Struct(
    fields={
        "contentsDeltaUri": struct_pb2.Value(string_value=BUCKET_NAME + "/incremental"),
    }
)

matching_engine_index = {
    "name": INDEX_RESOURCE_NAME,
    "display_name": DISPLAY_NAME,
    "description": "Glove 100 Vertex AI Vector Search Index",
    "metadata": struct_pb2.Value(struct_value=metadata),
}

In [None]:
matching_engine_index = index_client.update_index(index=matching_engine_index)

In [None]:
# Poll the operation until it's done successfully.
# This will take ~45 min.

while True:
    if matching_engine_index.done():
        break
    print("Poll the operation to update index...")
    time.sleep(60)

In [None]:
INDEX_RESOURCE_NAME = matching_engine_index.result().name
INDEX_RESOURCE_NAME

## Create an IndexEndpoint with a VPC network

In [None]:
index_endpoint_client = aiplatform_v1beta1.IndexEndpointServiceClient(
    client_options=dict(api_endpoint=ENDPOINT)
)

In [None]:
VPC_NETWORK_NAME = "projects/{}/global/networks/{}".format(PROJECT_NUMBER, NETWORK_NAME)
VPC_NETWORK_NAME

In [None]:
index_endpoint = {
    "display_name": "index_endpoint_for_demo",
    "network": VPC_NETWORK_NAME,
}

In [None]:
r = index_endpoint_client.create_index_endpoint(
    parent=PARENT, index_endpoint=index_endpoint
)

In [None]:
r.result()

In [None]:
INDEX_ENDPOINT_NAME = r.result().name
INDEX_ENDPOINT_NAME

## Deploy indexes

### Deploy a Vertex AI Vector Search index

In [None]:
DEPLOYED_INDEX_ID = "matching_engine_glove_deployed"

In [None]:
deploy_matching_engine_index = {
    "id": DEPLOYED_INDEX_ID,
    "display_name": DEPLOYED_INDEX_ID,
    "index": INDEX_RESOURCE_NAME,
}

If the next command returns an error, wait a few minutes for the index endpoint to be created, and then try again.

In [None]:
r = index_endpoint_client.deploy_index(
    index_endpoint=INDEX_ENDPOINT_NAME, deployed_index=deploy_matching_engine_index
)

In [None]:
# Poll the operation until it's done successfully.

while True:
    if r.done():
        break
    print("Poll the operation to deploy index...")
    time.sleep(60)

In [None]:
r.result()

### Deploy the brute-force index

In [None]:
DEPLOYED_BRUTE_FORCE_INDEX_ID = "glove_brute_force_deployed"

In [None]:
deploy_brute_force_index = {
    "id": DEPLOYED_BRUTE_FORCE_INDEX_ID,
    "display_name": DEPLOYED_BRUTE_FORCE_INDEX_ID,
    "index": INDEX_BRUTE_FORCE_RESOURCE_NAME,
}

In [None]:
r = index_endpoint_client.deploy_index(
    index_endpoint=INDEX_ENDPOINT_NAME, deployed_index=deploy_brute_force_index
)

In [None]:
# Poll the operation until it's done successfully.

while True:
    if r.done():
        break
    print("Poll the operation to deploy index...")
    time.sleep(60)

In [None]:
r.result()

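## Create online queries

After you build the indexes, you can query a deployed index through the online query gRPC API (Match service) from a virtual machine instance in the same region (for example, 'us-central1' in this tutorial). A client uses this gRPC API through the following steps:

* Write `match_service.proto` locally
* In a terminal, clone the repository that contains the dependencies of `match_service.proto`:

`$ mkdir third_party && cd third_party`

`$ git clone https://github.com/googleapis/googleapis.git`

* Compile the protocol buffer (see below)
* Get the index endpoint
* Use the generated stub to make calls, passing the parameter values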

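### Troubleshooting connection issues

If you run into connection issues, try the following:

* Verify that the index endpoint, the index, and the VPC are all in the same Google Cloud project
* Verify that the index endpoint, the index, and the VPC are all in the same region, and that the region is valid (for example, us-central1)
* Check whether the network has a firewall rule that denies all egress connections. If it does, disable that rule or override it with another rule that allows connections to the index endpoint IP.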
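To check the firewall item before digging into gRPC errors, you can test whether a plain TCP connection to the match server opens from the notebook instance. The sketch below uses only the standard library; `can_connect` is a hypothetical helper, and port `10000` is the port this notebook later uses for the gRPC channel.

```python
import socket


def can_connect(host, port, timeout_seconds=5.0):
    """Return True if a TCP connection to (host, port) succeeds within the timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout_seconds):
            return True
    except OSError:
        # Covers refused connections, unreachable networks, and timeouts.
        return False
```

Once `DEPLOYED_INDEX_SERVER_IP` is defined, `can_connect(DEPLOYED_INDEX_SERVER_IP, 10000)` should return `True` from a notebook inside the peered VPC. A `False` result from Colab or from a notebook in another network is expected, per the warning in the VPC section.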

In [None]:
%%writefile match_service.proto

syntax = "proto3";

package google.cloud.aiplatform.container.v1beta1;

import "google/rpc/status.proto";

// MatchService is a Google managed service for efficient vector similarity
// search at scale.
service MatchService {
  // Returns the nearest neighbors for the query. If it is a sharded
  // deployment, calls the other shards and aggregates the responses.
  rpc Match(MatchRequest) returns (MatchResponse) {}

  // Returns the nearest neighbors for batch queries. If it is a sharded
  // deployment, calls the other shards and aggregates the responses.
  rpc BatchMatch(BatchMatchRequest) returns (BatchMatchResponse) {}
}

// Parameters for a match query.
message MatchRequest {
  // The ID of the DeployedIndex that will serve the request.
  // This MatchRequest is sent to a specific IndexEndpoint of the Control API,
  // as per the IndexEndpoint.network. That IndexEndpoint also has
  // IndexEndpoint.deployed_indexes, and each such index has an
  // DeployedIndex.id field.
  // The value of the field below must equal one of the DeployedIndex.id
  // fields of the IndexEndpoint that is being called for this request.
  string deployed_index_id = 1;

  // The embedding values.
  repeated float float_val = 2;

  // The number of nearest neighbors to be retrieved from database for
  // each query. If not set, will use the default from
  // the service configuration.
  int32 num_neighbors = 3;

  // The list of restricts.
  repeated Namespace restricts = 4;

  // Crowding is a constraint on a neighbor list produced by nearest neighbor
  // search requiring that no more than some value k' of the k neighbors
  // returned have the same value of crowding_attribute.
  // It's used for improving result diversity.
  // This field is the maximum number of matches with the same crowding tag.
  int32 per_crowding_attribute_num_neighbors = 5;

  // The number of neighbors to find via approximate search before
  // exact reordering is performed. If not set, the default value from scam
  // config is used; if set, this value must be > 0.
  int32 approx_num_neighbors = 6;

  // The fraction of the number of leaves to search, set at query time allows
  // user to tune search performance. This value increase result in both search
  // accuracy and latency increase. The value should be between 0.0 and 1.0. If
  // not set or set to 0.0, query uses the default value specified in
  // NearestNeighborSearchConfig.TreeAHConfig.leaf_nodes_to_search_percent.
  int32 leaf_nodes_to_search_percent_override = 7;
}

// Response of a match query.
message MatchResponse {
  message Neighbor {
    // The ids of the matches.
    string id = 1;

    // The distances of the matches.
    double distance = 2;
  }
  // All its neighbors.
  repeated Neighbor neighbor = 1;
}

// Parameters for a batch match query.
message BatchMatchRequest {
  // Batched requests against one index.
  message BatchMatchRequestPerIndex {
    // The ID of the DeployedIndex that will serve the request.
    string deployed_index_id = 1;

    // The requests against the index identified by the above deployed_index_id.
    repeated MatchRequest requests = 2;

    // Selects the optimal batch size to use for low-level batching. Queries
    // within each low level batch are executed sequentially while low level
    // batches are executed in parallel.
    // This field is optional, defaults to 0 if not set. A non-positive number
    // disables low level batching, i.e. all queries are executed sequentially.
    int32 low_level_batch_size = 3;
  }

  // The batch requests grouped by indexes.
  repeated BatchMatchRequestPerIndex requests = 1;
}

// Response of a batch match query.
message BatchMatchResponse {
  // Batched responses for one index.
  message BatchMatchResponsePerIndex {
    // The ID of the DeployedIndex that produced the responses.
    string deployed_index_id = 1;

    // The match responses produced by the index identified by the above
    // deployed_index_id. This field is set only when the query against that
    // index succeed.
    repeated MatchResponse responses = 2;

    // The status of response for the batch query identified by the above
    // deployed_index_id.
    google.rpc.Status status = 3;
  }

  // The batched responses grouped by indexes.
  repeated BatchMatchResponsePerIndex responses = 1;
}

// Namespace specifies the rules for determining the datapoints that are
// eligible for each matching query, overall query is an AND across namespaces.
message Namespace {
  // The string name of the namespace that this proto is specifying,
  // such as "color", "shape", "geo", or "tags".
  string name = 1;

  // The allowed tokens in the namespace.
  repeated string allow_tokens = 2;

  // The denied tokens in the namespace.
  // The denied tokens have exactly the same format as the token fields, but
  // represents a negation. When a token is denied, then matches will be
  // excluded whenever the other datapoint has that token.
  //
  // For example, if a query specifies {color: red, blue, !purple}, then that
  // query will match datapoints that are red or blue, but if those points are
  // also purple, then they will be excluded even if they are red/blue.
  repeated string deny_tokens = 3;
}

Compile the protocol buffer to generate `match_service_pb2.py` and `match_service_pb2_grpc.py`.

In [None]:
! python -m grpc_tools.protoc -I=. --proto_path=third_party/googleapis --python_out=. --grpc_python_out=. match_service.proto
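The `Namespace` message in `match_service.proto` encodes a filtering rule: a query is an AND across namespaces, a datapoint must carry one of the allowed tokens, and any denied token excludes it. The pure-Python sketch below models that description, for example `{color: red, blue, !purple}` from the proto comment. `passes_restricts` is a hypothetical helper, and treating an empty allow list as "allow anything" and a missing namespace as "no tokens" are assumptions, not documented service behavior.

```python
def passes_restricts(datapoint_tokens, restricts):
    """Return True if a datapoint satisfies every namespace restrict (AND).

    datapoint_tokens: namespace name -> set of tokens on the datapoint.
    restricts: list of (name, allow_tokens, deny_tokens) tuples, mirroring the
    Namespace message. Assumption: an empty allow list allows any token.
    """
    for name, allow_tokens, deny_tokens in restricts:
        tokens = datapoint_tokens.get(name, set())
        if allow_tokens and not tokens & set(allow_tokens):
            return False  # none of the allowed tokens are present
        if tokens & set(deny_tokens):
            return False  # a denied token is present
    return True


# {color: red, blue, !purple} from the proto comment.
color_rule = [("color", ["red", "blue"], ["purple"])]
print(passes_restricts({"color": {"red"}}, color_rule))            # True
print(passes_restricts({"color": {"red", "purple"}}, color_rule))  # False
print(passes_restricts({"color": {"green"}}, color_rule))          # False
```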

Get the private endpoint:

In [None]:
# Look up the endpoint created above rather than the first endpoint in the project.
DEPLOYED_INDEX_SERVER_IP = (
    index_endpoint_client.get_index_endpoint(name=INDEX_ENDPOINT_NAME)
    .deployed_indexes[0]
    .private_endpoints.match_grpc_address
)
DEPLOYED_INDEX_SERVER_IP

Test your query:

In [None]:
import match_service_pb2
import match_service_pb2_grpc

channel = grpc.insecure_channel("{}:10000".format(DEPLOYED_INDEX_SERVER_IP))
stub = match_service_pb2_grpc.MatchServiceStub(channel)

In [None]:
# Test query
query = [
    -0.11333,
    0.48402,
    0.090771,
    -0.22439,
    0.034206,
    -0.55831,
    0.041849,
    -0.53573,
    0.18809,
    -0.58722,
    0.015313,
    -0.014555,
    0.80842,
    -0.038519,
    0.75348,
    0.70502,
    -0.17863,
    0.3222,
    0.67575,
    0.67198,
    0.26044,
    0.4187,
    -0.34122,
    0.2286,
    -0.53529,
    1.2582,
    -0.091543,
    0.19716,
    -0.037454,
    -0.3336,
    0.31399,
    0.36488,
    0.71263,
    0.1307,
    -0.24654,
    -0.52445,
    -0.036091,
    0.55068,
    0.10017,
    0.48095,
    0.71104,
    -0.053462,
    0.22325,
    0.30917,
    -0.39926,
    0.036634,
    -0.35431,
    -0.42795,
    0.46444,
    0.25586,
    0.68257,
    -0.20821,
    0.38433,
    0.055773,
    -0.2539,
    -0.20804,
    0.52522,
    -0.11399,
    -0.3253,
    -0.44104,
    0.17528,
    0.62255,
    0.50237,
    -0.7607,
    -0.071786,
    0.0080131,
    -0.13286,
    0.50097,
    0.18824,
    -0.54722,
    -0.42664,
    0.4292,
    0.14877,
    -0.0072514,
    -0.16484,
    -0.059798,
    0.9895,
    -0.61738,
    0.054169,
    0.48424,
    -0.35084,
    -0.27053,
    0.37829,
    0.11503,
    -0.39613,
    0.24266,
    0.39147,
    -0.075256,
    0.65093,
    -0.20822,
    -0.17456,
    0.53571,
    -0.16537,
    0.13582,
    -0.56016,
    0.016964,
    0.1277,
    0.94071,
    -0.22608,
    -0.021106,
]

request = match_service_pb2.MatchRequest()
request.deployed_index_id = DEPLOYED_INDEX_ID
for val in query:
    request.float_val.append(val)

response = stub.Match(request)
response

## Submit batch queries

You can use the BatchMatch API to run multiple queries in a single RPC call.
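`match_service.proto` also defines an optional `low_level_batch_size` on `BatchMatchRequestPerIndex`: queries inside one low-level batch run sequentially while separate low-level batches run in parallel, and a non-positive value disables low-level batching. The sketch below only models how a list of queries would be grouped under that description; `split_into_low_level_batches` is a hypothetical helper, not part of the generated stubs.

```python
def split_into_low_level_batches(queries, low_level_batch_size):
    """Group queries the way the proto comment describes low-level batching.

    Each returned group would run sequentially; groups would run in parallel.
    A non-positive size disables low-level batching (one sequential group).
    """
    if low_level_batch_size <= 0:
        return [list(queries)] if queries else []
    return [
        list(queries[i : i + low_level_batch_size])
        for i in range(0, len(queries), low_level_batch_size)
    ]


print(split_into_low_level_batches(["q0", "q1", "q2", "q3", "q4"], 2))
# [['q0', 'q1'], ['q2', 'q3'], ['q4']]
```

On a real request you would set the field on the per-index message, for example `batch_request_matching_engine.low_level_batch_size = 2`.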

In [None]:
def get_request(embedding, deployed_index_id):
    request = match_service_pb2.MatchRequest(num_neighbors=k)
    request.deployed_index_id = deployed_index_id
    for val in embedding:
        request.float_val.append(val)
    return request

In [None]:
# Test query
queries = [
    [
        -0.11333,
        0.48402,
        0.090771,
        -0.22439,
        0.034206,
        -0.55831,
        0.041849,
        -0.53573,
        0.18809,
        -0.58722,
        0.015313,
        -0.014555,
        0.80842,
        -0.038519,
        0.75348,
        0.70502,
        -0.17863,
        0.3222,
        0.67575,
        0.67198,
        0.26044,
        0.4187,
        -0.34122,
        0.2286,
        -0.53529,
        1.2582,
        -0.091543,
        0.19716,
        -0.037454,
        -0.3336,
        0.31399,
        0.36488,
        0.71263,
        0.1307,
        -0.24654,
        -0.52445,
        -0.036091,
        0.55068,
        0.10017,
        0.48095,
        0.71104,
        -0.053462,
        0.22325,
        0.30917,
        -0.39926,
        0.036634,
        -0.35431,
        -0.42795,
        0.46444,
        0.25586,
        0.68257,
        -0.20821,
        0.38433,
        0.055773,
        -0.2539,
        -0.20804,
        0.52522,
        -0.11399,
        -0.3253,
        -0.44104,
        0.17528,
        0.62255,
        0.50237,
        -0.7607,
        -0.071786,
        0.0080131,
        -0.13286,
        0.50097,
        0.18824,
        -0.54722,
        -0.42664,
        0.4292,
        0.14877,
        -0.0072514,
        -0.16484,
        -0.059798,
        0.9895,
        -0.61738,
        0.054169,
        0.48424,
        -0.35084,
        -0.27053,
        0.37829,
        0.11503,
        -0.39613,
        0.24266,
        0.39147,
        -0.075256,
        0.65093,
        -0.20822,
        -0.17456,
        0.53571,
        -0.16537,
        0.13582,
        -0.56016,
        0.016964,
        0.1277,
        0.94071,
        -0.22608,
        -0.021106,
    ],
    [
        -0.99544,
        -2.3651,
        -0.24332,
        -1.0321,
        0.42052,
        -1.1817,
        -0.16451,
        -1.683,
        0.49673,
        -0.27258,
        -0.025397,
        0.34188,
        1.5523,
        1.3532,
        0.33297,
        -0.0056677,
        -0.76525,
        0.49587,
        1.2211,
        0.83394,
        -0.20031,
        -0.59657,
        0.38485,
        -0.23487,
        -1.0725,
        0.95856,
        0.16161,
        -1.2496,
        1.6751,
        0.73899,
        0.051347,
        -0.42702,
        0.16257,
        -0.16772,
        0.40146,
        0.29837,
        0.96204,
        -0.36232,
        -0.47848,
        0.78278,
        0.14834,
        1.3407,
        0.47834,
        -0.39083,
        -1.037,
        -0.24643,
        -0.75841,
        0.7669,
        -0.37363,
        0.52741,
        0.018563,
        -0.51301,
        0.97674,
        0.55232,
        1.1584,
        0.73715,
        1.3055,
        -0.44743,
        -0.15961,
        0.85006,
        -0.34092,
        -0.67667,
        0.2317,
        1.5582,
        1.2308,
        -0.62213,
        -0.032801,
        0.1206,
        -0.25899,
        -0.02756,
        -0.52814,
        -0.93523,
        0.58434,
        -0.24799,
        0.37692,
        0.86527,
        0.069626,
        1.3096,
        0.29975,
        -1.3651,
        -0.32048,
        -0.13741,
        0.33329,
        -1.9113,
        -0.60222,
        -0.23921,
        0.12664,
        -0.47961,
        -0.89531,
        0.62054,
        0.40869,
        -0.08503,
        0.6413,
        -0.84044,
        -0.74325,
        -0.19426,
        0.098722,
        0.32648,
        -0.67621,
        -0.62692,
    ],
]

batch_request = match_service_pb2.BatchMatchRequest()
batch_request_matching_engine = (
    match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex()
)
batch_request_brute_force = (
    match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex()
)
batch_request_matching_engine.deployed_index_id = DEPLOYED_INDEX_ID
batch_request_brute_force.deployed_index_id = DEPLOYED_BRUTE_FORCE_INDEX_ID
for query in queries:
    batch_request_matching_engine.requests.append(get_request(query, DEPLOYED_INDEX_ID))
    batch_request_brute_force.requests.append(
        get_request(query, DEPLOYED_BRUTE_FORCE_INDEX_ID)
    )
batch_request.requests.append(batch_request_matching_engine)
batch_request.requests.append(batch_request_brute_force)

response = stub.BatchMatch(batch_request)
response

### Compute recall

Use the deployed brute-force index as the ground truth to compute the recall of the Vertex AI Vector Search index.
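Recall@k here is the number of true top-k neighbors (from the brute-force index) that the tuned index also returns, summed over all queries and divided by `len(test) * k`, which is what the recall cell below computes. A small pure-Python sketch of the same metric, using a hypothetical `recall_at_k` helper that takes already-fetched neighbor lists:

```python
def recall_at_k(ground_truth, approximate, k):
    """Average fraction of the true top-k neighbors that the ANN result also found.

    `ground_truth` and `approximate` are parallel lists with one ranked
    neighbor-id list per query (e.g. from the brute-force and tuned indexes).
    """
    if len(ground_truth) != len(approximate):
        raise ValueError("need one approximate result per ground-truth result")
    if not ground_truth:
        raise ValueError("need at least one query")
    hits = sum(
        len(set(truth[:k]) & set(found[:k]))
        for truth, found in zip(ground_truth, approximate)
    )
    return hits / (len(ground_truth) * k)


truth = [[1, 2, 3], [4, 5, 6]]
found = [[1, 2, 9], [4, 5, 6]]
print(recall_at_k(truth, found, k=3))  # 0.8333333333333334
```

Fetching the neighbor lists once and passing them in also lets you compute recall for several smaller `k` values without re-querying the endpoint.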

In [None]:
def get_neighbors(embedding, deployed_index_id):
    request = match_service_pb2.MatchRequest(num_neighbors=k)
    request.deployed_index_id = deployed_index_id
    for val in embedding:
        request.float_val.append(val)
    response = stub.Match(request)
    return [int(n.id) for n in response.neighbor]

In [None]:
# This will take 5-10 min

recall = sum(
    len(
        set(get_neighbors(test[i], DEPLOYED_BRUTE_FORCE_INDEX_ID)).intersection(
            set(get_neighbors(test[i], DEPLOYED_INDEX_ID))
        )
    )
    for i in range(len(test))
) / (1.0 * len(test) * k)

print("Recall: {}".format(recall))

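## Cleaning up

To clean up all Google Cloud resources used in this project, you can [delete the Google Cloud project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.

Otherwise, you can delete the individual resources you created in this tutorial.

### Delete Vertex AI Vector Search resources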

In [None]:
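# Undeploy both indexes first; deployed indexes cannot be deleted.
for deployed_index_id in [DEPLOYED_INDEX_ID, DEPLOYED_BRUTE_FORCE_INDEX_ID]:
    index_endpoint_client.undeploy_index(
        index_endpoint=INDEX_ENDPOINT_NAME, deployed_index_id=deployed_index_id
    ).result()

In [None]:
index_endpoint_client.delete_index_endpoint(name=INDEX_ENDPOINT_NAME).result()
index_client.delete_index(name=INDEX_RESOURCE_NAME).result()
index_client.delete_index(name=INDEX_BRUTE_FORCE_RESOURCE_NAME).result()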

Delete the Cloud Storage bucket (set `delete_bucket = True` to do so).

In [None]:
import os

delete_bucket = False
if delete_bucket or os.getenv("IS_TESTING"):
    ! gsutil -m rm -r $BUCKET_NAME