In [None]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table align="left">

  <td>
    <a href="https://console.cloud.google.com/vertex-ai/notebooks/deploy-notebook?download_url=https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/master/notebooks/community/cohere/cohere_embedding_with_matching_engine.ipynb">
      Run in Google Cloud Notebooks
    </a>
  </td>
  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/master/notebooks/community/cohere/cohere_embedding_with_matching_engine.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>
</table>

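## Overview

In this notebook, you learn how to use `co.embed` to quickly capture semantic information about input data. You then use these embeddings with the approximate nearest neighbor (ANN) service of Vertex AI Matching Engine to find similar text.

### Dataset

The dataset used for this tutorial is [ag_news_subset](https://www.tensorflow.org/datasets/catalog/ag_news_subset) from [TensorFlow Datasets](https://www.tensorflow.org/datasets). The final solution can find articles that are similar to a topic provided by the user.

### Objective

In this notebook, you learn how to create embeddings with Cohere's API, use Vertex AI Matching Engine to build an index from them, and query the index to find similar text.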

### Costs

This tutorial uses billable components of Google Cloud and Cohere:

* Vertex AI
* Cloud Storage
* Cohere `co.embed` endpoint usage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

Learn about [Cohere pricing](https://cohere.ai/pricing).

## Before you begin

Set the variables that are used throughout this demo.

In [None]:
COHERE_API_KEY = "{API KEY}"
GOOGLE_PROJECT_ID = "{Project ID}"
NETWORK_NAME = "{Network Name}"
PEERING_RANGE_NAME = "{Range Name}"
BUCKET_NAME = "gs://{Bucket Name}"
REGION = "us-central1"
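
It can help to fail fast if a placeholder from the cell above was left in place. A minimal sketch, assuming placeholders keep the `{...}` form used above; `find_unset_placeholders` is a hypothetical helper, not part of any SDK.

```python
import re

# Matches a "{...}" placeholder such as "{API KEY}" or "{Bucket Name}".
PLACEHOLDER_PATTERN = re.compile(r"\{[^{}]*\}")


def find_unset_placeholders(settings):
    """Return the names of settings whose value still contains a {placeholder}."""
    return [
        name
        for name, value in settings.items()
        if PLACEHOLDER_PATTERN.search(str(value))
    ]


# Usage in this notebook (after running the cell above):
# missing = find_unset_placeholders({
#     "COHERE_API_KEY": COHERE_API_KEY,
#     "GOOGLE_PROJECT_ID": GOOGLE_PROJECT_ID,
#     "NETWORK_NAME": NETWORK_NAME,
#     "PEERING_RANGE_NAME": PEERING_RANGE_NAME,
#     "BUCKET_NAME": BUCKET_NAME,
# })
# if missing:
#     raise ValueError(f"Set these variables first: {missing}")
```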

## Create a VPC network

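* **Prepare a VPC network.** To reduce network overhead that can add unnecessary latency, it is best to call the ANN endpoint from your VPC through a direct [VPC Peering](https://cloud.google.com/vertex-ai/docs/general/vpc-peering) connection. The following section describes how to set up a VPC Peering connection if you don't already have one. This is a one-time initial setup task. You can also reuse an existing VPC network and skip this section.
* **Warning:** The Match service gRPC API (used to run online queries against your deployed index) must be called from a Google Cloud Notebooks instance that meets the following requirements:
  * **It is in the same region as your ANN service deployment** (for example, if you set `REGION = "us-central1"` as in this tutorial, the notebook instance must be in `us-central1`).
  * **It uses the VPC network created for the ANN service** (not the "default" network). In other words, you need to create the VPC network below and then create a new notebook instance that uses that VPC.
  * If you run this notebook in Colab, or in a Google Cloud Notebooks instance in a different VPC network or region, the gRPC API cannot reach the index over the peered network (InactiveRPCError).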

In [None]:
PROJECT_ID = GOOGLE_PROJECT_ID  # @param {type:"string"}

# Create a VPC network
! gcloud compute networks create {NETWORK_NAME} --bgp-routing-mode=regional --subnet-mode=auto --project={PROJECT_ID}

# Add necessary firewall rules
! gcloud compute firewall-rules create {NETWORK_NAME}-allow-icmp --network {NETWORK_NAME} --priority 65534 --project {PROJECT_ID} --allow icmp

! gcloud compute firewall-rules create {NETWORK_NAME}-allow-internal --network {NETWORK_NAME} --priority 65534 --project {PROJECT_ID} --allow all --source-ranges 10.128.0.0/9

! gcloud compute firewall-rules create {NETWORK_NAME}-allow-rdp --network {NETWORK_NAME} --priority 65534 --project {PROJECT_ID} --allow tcp:3389

! gcloud compute firewall-rules create {NETWORK_NAME}-allow-ssh --network {NETWORK_NAME} --priority 65534 --project {PROJECT_ID} --allow tcp:22

# Reserve IP range
! gcloud compute addresses create {PEERING_RANGE_NAME} --global --prefix-length=16 --network={NETWORK_NAME} --purpose=VPC_PEERING --project={PROJECT_ID} --description="peering range for uCAIP Haystack."

# Set up peering with service networking
! gcloud services vpc-peerings connect --service=servicenetworking.googleapis.com --network={NETWORK_NAME} --ranges={PEERING_RANGE_NAME} --project={PROJECT_ID}
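
The `--source-ranges 10.128.0.0/9` firewall rule and the `--prefix-length=16` reservation above are easier to reason about with Python's standard `ipaddress` module. A small sketch: auto-mode VPC networks create their subnets inside 10.128.0.0/9, so the allow-internal rule admits traffic from all of them, and a /16 reservation holds 65,536 addresses for peering. The /20 subnet below is only an illustrative example, and both helpers are hypothetical.

```python
import ipaddress

# Source range of the "allow-internal" firewall rule created above.
INTERNAL_RANGE = ipaddress.ip_network("10.128.0.0/9")


def covered_by_internal_rule(cidr):
    """Return True if every address in `cidr` lies inside INTERNAL_RANGE."""
    return ipaddress.ip_network(cidr).subnet_of(INTERNAL_RANGE)


def addresses_in_prefix(prefix_length):
    """Number of IPv4 addresses in a block with the given prefix length."""
    return 2 ** (32 - prefix_length)


# Example: a hypothetical /20 subnet inside the auto-mode block
print(covered_by_internal_rule("10.128.0.0/20"))  # True
# Size of the peering range reserved with --prefix-length=16
print(addresses_in_prefix(16))  # 65536
```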

Authentication: if you are logged out and need credentials, rerun `$ gcloud auth login` in the Google Cloud Notebooks terminal.

## Installation

Install `tensorflow_datasets` to prepare the sample dataset, and `grpcio-tools` to query the index.

In [None]:
! pip install -U grpcio-tools --user
! pip install -U tensorflow==2.9.1 --user
! pip install -U tensorflow-datasets --user

### Download and install the latest (preview) version of the Vertex SDK for Python.

In [None]:
! pip install -U git+https://github.com/googleapis/python-aiplatform.git@main-test --user

Install the Cohere Python SDK (`cohere`).

In [None]:
! pip install -U cohere --user

### Restart the kernel (Colab)

After you install the additional packages, you need to restart the notebook kernel so that it can find them.

In [None]:
import os

if not os.getenv("IS_TESTING"):
    # Automatically restart kernel after installs
    import IPython

    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)

## Create embeddings with Cohere

In [None]:
# Let's start by loading the dataset with tensorflow-datasets
import tensorflow_datasets as tfds

dataset = tfds.load("ag_news_subset", split="train", shuffle_files=True)

In [None]:
# For speed and cost considerations, let's limit the dataset to 1000 examples
df = tfds.as_dataframe(dataset.take(1000), tfds.builder("ag_news_subset").info)
df["text"] = df["description"].apply(lambda x: x.decode())

In [None]:
# Finally, let's import cohere and create a client to compute representations for these 1000 articles
import cohere

co = cohere.Client(COHERE_API_KEY)

In [None]:
# Run each of the examples through the embedding endpoint
response = co.embed(model="small", texts=list(df["text"].values))

cohere_embeddings = response.embeddings
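
Sending all 1,000 texts in one `co.embed` call only works if the endpoint accepts that many texts per request. A minimal batching sketch, assuming such a per-request limit exists; the default `batch_size=96` is an assumption, so check Cohere's API reference for the current limit. `batched` and `embed_in_batches` are hypothetical helpers.

```python
def batched(items, batch_size):
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def embed_in_batches(embed_fn, texts, batch_size=96):
    """Call embed_fn on each batch and concatenate the embeddings in order."""
    embeddings = []
    for batch in batched(texts, batch_size):
        embeddings.extend(embed_fn(batch))
    return embeddings


# Usage with the client from this notebook (same call as above, per batch):
# cohere_embeddings = embed_in_batches(
#     lambda batch: co.embed(model="small", texts=batch).embeddings,
#     list(df["text"].values),
# )
```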

### Set up your Google Cloud project

**The following steps are required, regardless of your notebook environment.**

1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager).

1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

1. [Enable the Vertex AI API, Compute Engine API, and Service Networking API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com,compute_component,servicenetworking.googleapis.com).

1. Enter your project ID in the cell below. Then run the cell to make sure the Cloud SDK uses the right project for all the commands in this notebook.

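**Note**: Jupyter runs lines prefixed with `!` as shell commands, and it interpolates Python variables prefixed with `$` (or wrapped in `{}`) into these commands.

Set your project ID

**If you don't know your project ID**, you can get it with `gcloud`.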

In [None]:
import os

PROJECT_ID = GOOGLE_PROJECT_ID

# Get your Google Cloud project ID from gcloud only if it wasn't set above
if not os.getenv("IS_TESTING") and PROJECT_ID in ("", None, "{Project ID}"):
    shell_output = !gcloud config list --format 'value(core.project)' 2>/dev/null
    PROJECT_ID = shell_output[0]

print("Project ID: ", PROJECT_ID)

Otherwise, set your project ID here.

In [None]:
if PROJECT_ID == "" or PROJECT_ID is None:
    PROJECT_ID = GOOGLE_PROJECT_ID  # @param {type:"string"}

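### Create a Cloud Storage bucket

**The following steps are required, regardless of your notebook environment.**

Set the name of your Cloud Storage bucket in `BUCKET_NAME` (defined at the top of this notebook); if it is left unset, the cell below generates one. The name must be unique across all Cloud Storage buckets.

You can also change the `REGION` variable, which is used for operations throughout the rest of this notebook. Make sure to [choose a region where Vertex AI services are available](https://cloud.google.com/vertex-ai/docs/general/locations#available_regions). You may not use a Multi-Regional Storage bucket for training with Vertex AI.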

In [None]:
from datetime import datetime

TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")

# Generate a bucket name if BUCKET_NAME is empty or still the placeholder
if BUCKET_NAME == "" or BUCKET_NAME is None or BUCKET_NAME == "gs://{Bucket Name}":
    BUCKET_NAME = "gs://" + PROJECT_ID + "-aip-" + TIMESTAMP
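
A generated or hand-typed name still has to satisfy Cloud Storage naming rules. A simplified sketch, assuming a name without dots: 3 to 63 characters of lowercase letters, digits, hyphens, and underscores, starting and ending with a letter or digit. Cloud Storage has further rules (for example for dotted names and reserved words) that this hypothetical `is_simple_bucket_name` helper does not check.

```python
import re

# Simplified subset of Cloud Storage bucket naming rules (no dots allowed here).
_SIMPLE_BUCKET_NAME = re.compile(r"[a-z0-9][a-z0-9_-]{1,61}[a-z0-9]")


def is_simple_bucket_name(bucket_uri):
    """Check the bucket part of a gs:// URI against the simplified rules."""
    name = bucket_uri[len("gs://"):] if bucket_uri.startswith("gs://") else bucket_uri
    return _SIMPLE_BUCKET_NAME.fullmatch(name) is not None


# Usage: assert is_simple_bucket_name(BUCKET_NAME), BUCKET_NAME
```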

Run the following cell to create your Cloud Storage bucket only if it doesn't already exist.

In [None]:
! gsutil mb -l $REGION $BUCKET_NAME

Finally, validate access to your Cloud Storage bucket by examining its contents.

In [None]:
# this will not return anything if the bucket is empty
! gsutil ls -al $BUCKET_NAME

### Import libraries and define constants

Import the Vertex AI (unified) client library into your Python environment.

In [None]:
import time

import grpc
from google.cloud import aiplatform_v1beta1
from google.protobuf import struct_pb2

In [None]:
ENDPOINT = "{}-aiplatform.googleapis.com".format(REGION)


AUTH_TOKEN = !gcloud auth print-access-token
PROJECT_NUMBER = !gcloud projects list --filter="PROJECT_ID:'{PROJECT_ID}'" --format='value(PROJECT_NUMBER)'
PROJECT_NUMBER = PROJECT_NUMBER[0]

PARENT = "projects/{}/locations/{}".format(PROJECT_ID, REGION)

print("ENDPOINT: {}".format(ENDPOINT))
print("PROJECT_ID: {}".format(PROJECT_ID))
print("REGION: {}".format(REGION))

!gcloud config set project {PROJECT_ID}
!gcloud config set ai_platform/region {REGION}

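## Prepare the embeddings

This step takes the embeddings generated by Cohere and formats them for use with Matching Engine.

Save the data in JSONL format (one JSON object per line).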

In [None]:
# Convert the list of embeddings to the JSON Lines format expected by Matching Engine:
# one {"id": "<row index>", "embedding": [...]} object per line
import json

with open("cohere_embeddings.json", "w") as f:
    for i, embedding in enumerate(cohere_embeddings):
        f.write(json.dumps({"id": str(i), "embedding": embedding}) + "\n")
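
Before uploading, a quick local check can catch malformed records. A minimal sketch, assuming the file layout written above (one object per line with a string `"id"` and an `"embedding"` list); the expected length of 1024 matches `DIMENSIONS`, which the notebook sets later for the `small` model. `validate_embeddings_file` is a hypothetical helper.

```python
import json


def validate_embeddings_file(path, expected_dimensions):
    """Check a JSONL embeddings file and return its record count.

    Each non-empty line must be a JSON object with a unique string "id" and an
    "embedding" list of exactly `expected_dimensions` numbers.
    """
    seen_ids = set()
    with open(path) as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue
            record = json.loads(line)
            record_id, embedding = record["id"], record["embedding"]
            if not isinstance(record_id, str):
                raise ValueError(f"line {line_number}: id must be a string")
            if record_id in seen_ids:
                raise ValueError(f"line {line_number}: duplicate id {record_id!r}")
            if len(embedding) != expected_dimensions or not all(
                isinstance(value, (int, float)) for value in embedding
            ):
                raise ValueError(f"line {line_number}: expected {expected_dimensions} numbers")
            seen_ids.add(record_id)
    return len(seen_ids)


# Usage: validate_embeddings_file("cohere_embeddings.json", 1024)
```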

Upload the data to GCS.

In [None]:
# NOTE: Everything in this GCS DIR will be DELETED before uploading the data.
# A CommandException is expected if no data is present

! gsutil rm -rf {BUCKET_NAME}/*

In [None]:
! gsutil cp cohere_embeddings.json {BUCKET_NAME}/cohere_embeddings.json

In [None]:
! gsutil ls {BUCKET_NAME}

## Create the indexes

### Create the ANN index (for production use)

In [None]:
index_client = aiplatform_v1beta1.IndexServiceClient(
    client_options=dict(api_endpoint=ENDPOINT)
)

In [None]:
# The Cohere small model produces 1024-dimensional embeddings; update DIMENSIONS if you use another model
DIMENSIONS = 1024
DISPLAY_NAME = "cohere_embeddings"
DISPLAY_NAME_BRUTE_FORCE = DISPLAY_NAME + "_brute_force"

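Create the ANN index configuration:

See the Vertex AI Matching Engine documentation for the configuration parameters you can use to tune the index.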

In [None]:
treeAhConfig = struct_pb2.Struct(
    fields={
        "leafNodeEmbeddingCount": struct_pb2.Value(number_value=500),
        "leafNodesToSearchPercent": struct_pb2.Value(number_value=7),
    }
)

algorithmConfig = struct_pb2.Struct(
    fields={"treeAhConfig": struct_pb2.Value(struct_value=treeAhConfig)}
)

config = struct_pb2.Struct(
    fields={
        "dimensions": struct_pb2.Value(number_value=DIMENSIONS),
        "approximateNeighborsCount": struct_pb2.Value(number_value=150),
        "distanceMeasureType": struct_pb2.Value(string_value="DOT_PRODUCT_DISTANCE"),
        "algorithmConfig": struct_pb2.Value(struct_value=algorithmConfig),
    }
)

metadata = struct_pb2.Struct(
    fields={
        "config": struct_pb2.Value(struct_value=config),
        "contentsDeltaUri": struct_pb2.Value(string_value=BUCKET_NAME),
    }
)

ann_index = {
    "display_name": DISPLAY_NAME,
    "description": "Cohere embeddings ANN index",
    "metadata": struct_pb2.Value(struct_value=metadata),
}
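
The nested `struct_pb2` calls above encode a plain JSON-like structure. As a sketch, the same config can be written as a Python dict, which makes it easy to check that the brute-force index created later agrees with this one on `dimensions`, `distanceMeasureType`, and `featureNormType`. `build_index_config` and `configs_match_for_ground_truth` are hypothetical helpers; they do not replace the `struct_pb2` objects that the API call uses.

```python
def build_index_config(dimensions, algorithm="tree_ah",
                       distance_measure_type="DOT_PRODUCT_DISTANCE",
                       approximate_neighbors_count=150,
                       leaf_node_embedding_count=500,
                       leaf_nodes_to_search_percent=7):
    """Return the index config built with struct_pb2 above, as a nested dict."""
    if algorithm == "tree_ah":
        algorithm_config = {
            "treeAhConfig": {
                "leafNodeEmbeddingCount": leaf_node_embedding_count,
                "leafNodesToSearchPercent": leaf_nodes_to_search_percent,
            }
        }
    elif algorithm == "brute_force":
        algorithm_config = {"bruteForceConfig": {}}
    else:
        raise ValueError(f"unknown algorithm: {algorithm!r}")
    return {
        "dimensions": dimensions,
        "approximateNeighborsCount": approximate_neighbors_count,
        "distanceMeasureType": distance_measure_type,
        "algorithmConfig": algorithm_config,
    }


def configs_match_for_ground_truth(production, brute_force,
                                   keys=("dimensions", "distanceMeasureType", "featureNormType")):
    """True if both configs agree on the fields a fair recall comparison needs."""
    return all(production.get(key) == brute_force.get(key) for key in keys)


ann_config = build_index_config(1024)
brute_force_config = build_index_config(1024, algorithm="brute_force")
print(configs_match_for_ground_truth(ann_config, brute_force_config))  # True
```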

In [None]:
ann_index = index_client.create_index(parent=PARENT, index=ann_index)

In [None]:
# Poll the operation until it's done successfully.
# This will take some time (~30 minutes)

while True:
    if ann_index.done():
        break
    print("Poll the operation to create index...")
    time.sleep(60)
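
The same polling loop appears for every long-running operation in this notebook, and none of the loops has an upper bound. A sketch of a reusable helper with an optional timeout, assuming only that the operation object has `done()` and `result()` methods, as the operations returned by the Vertex AI clients here do. Those operation objects also accept `result(timeout=...)`, which blocks until completion, so an explicit loop is mainly useful for progress messages. `wait_for_operation` is a hypothetical helper.

```python
import time


def wait_for_operation(operation, poll_seconds=60, timeout_seconds=None,
                       sleep=time.sleep, clock=time.monotonic):
    """Poll operation.done() until True, then return operation.result().

    Raises TimeoutError if timeout_seconds elapses first. `sleep` and `clock`
    can be replaced so the helper can be tested without waiting.
    """
    start = clock()
    while not operation.done():
        if timeout_seconds is not None and clock() - start >= timeout_seconds:
            raise TimeoutError(f"operation not done after {timeout_seconds} seconds")
        print("Waiting for the operation to finish...")
        sleep(poll_seconds)
    return operation.result()


# Usage, for example for the index creation above:
# ann_index_result = wait_for_operation(ann_index, timeout_seconds=2 * 60 * 60)
```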

In [None]:
INDEX_RESOURCE_NAME = ann_index.result().name
INDEX_RESOURCE_NAME

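### Create the brute-force index (for ground truth)

The brute-force index uses a naive brute-force method to find the nearest neighbors. This method is neither fast nor efficient, so brute-force indexes are not recommended for production use. They are used to find the "ground truth" set of neighbors, so that this set can be used to measure the recall of the indexes being tuned for production use. To ensure an apples-to-apples comparison, the `distanceMeasureType`, `featureNormType`, and `dimensions` of the brute-force index should match those of the production index being tuned.

Create the brute-force index configuration: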

In [None]:
algorithmConfig = struct_pb2.Struct(
    fields={"bruteForceConfig": struct_pb2.Value(struct_value=struct_pb2.Struct())}
)

config = struct_pb2.Struct(
    fields={
        "dimensions": struct_pb2.Value(number_value=DIMENSIONS),
        "approximateNeighborsCount": struct_pb2.Value(number_value=150),
        "distanceMeasureType": struct_pb2.Value(string_value="DOT_PRODUCT_DISTANCE"),
        "algorithmConfig": struct_pb2.Value(struct_value=algorithmConfig),
    }
)

metadata = struct_pb2.Struct(
    fields={
        "config": struct_pb2.Value(struct_value=config),
        "contentsDeltaUri": struct_pb2.Value(string_value=BUCKET_NAME),
    }
)

brute_force_index = {
    "display_name": DISPLAY_NAME_BRUTE_FORCE,
    "description": "Cohere embeddings index (brute force)",
    "metadata": struct_pb2.Value(struct_value=metadata),
}

In [None]:
brute_force_index = index_client.create_index(parent=PARENT, index=brute_force_index)

In [None]:
# Poll the operation until it's done successfully.
# This will take ~45 min.

while True:
    if brute_force_index.done():
        break
    print("Poll the operation to create index...")
    time.sleep(60)

In [None]:
INDEX_BRUTE_FORCE_RESOURCE_NAME = brute_force_index.result().name
INDEX_BRUTE_FORCE_RESOURCE_NAME
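
The brute-force index exists to provide ground-truth neighbors for measuring recall, as described above, but the notebook never computes that metric. A minimal sketch of recall@k, assuming you send the same query to both deployed indexes and collect the returned neighbor IDs; `recall_at_k` and `mean_recall_at_k` are hypothetical helpers.

```python
def recall_at_k(approximate_ids, exact_ids, k=None):
    """Fraction of the exact top-k neighbor ids also present in the approximate top-k."""
    if k is None:
        k = len(exact_ids)
    truth = set(exact_ids[:k])
    if not truth:
        raise ValueError("exact_ids must contain at least one id")
    return len(truth & set(approximate_ids[:k])) / len(truth)


def mean_recall_at_k(pairs, k=None):
    """Average recall_at_k over (approximate_ids, exact_ids) pairs, one per query."""
    scores = [recall_at_k(approx, exact, k) for approx, exact in pairs]
    if not scores:
        raise ValueError("need at least one query")
    return sum(scores) / len(scores)


# Hypothetical usage with the gRPC stub created later in this notebook:
# ann_request.deployed_index_id = DEPLOYED_INDEX_ID
# exact_request.deployed_index_id = DEPLOYED_BRUTE_FORCE_INDEX_ID
# ann_ids = [n.id for n in stub.Match(ann_request).neighbor]
# exact_ids = [n.id for n in stub.Match(exact_request).neighbor]
# print(recall_at_k(ann_ids, exact_ids, k=10))
```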

## Create an IndexEndpoint with the VPC network

In [None]:
index_endpoint_client = aiplatform_v1beta1.IndexEndpointServiceClient(
    client_options=dict(api_endpoint=ENDPOINT)
)

In [None]:
VPC_NETWORK_NAME = "projects/{}/global/networks/{}".format(PROJECT_NUMBER, NETWORK_NAME)
VPC_NETWORK_NAME

In [None]:
index_endpoint = {
    "display_name": "index_endpoint_for_demo",
    "network": VPC_NETWORK_NAME,
}

In [None]:
r = index_endpoint_client.create_index_endpoint(
    parent=PARENT, index_endpoint=index_endpoint
)

In [None]:
r.result()

In [None]:
INDEX_ENDPOINT_NAME = r.result().name
INDEX_ENDPOINT_NAME

## Deploy the indexes

### Deploy the ANN index

In [None]:
DEPLOYED_INDEX_ID = "cohere_embedding_deployed"

In [None]:
deploy_ann_index = {
    "id": DEPLOYED_INDEX_ID,
    "display_name": DEPLOYED_INDEX_ID,
    "index": INDEX_RESOURCE_NAME,
}

In [None]:
r = index_endpoint_client.deploy_index(
    index_endpoint=INDEX_ENDPOINT_NAME, deployed_index=deploy_ann_index
)

In [None]:
# Poll the operation until it's done successfully.
while True:
    if r.done():
        break
    print("Poll the operation to deploy index...")
    time.sleep(60)

In [None]:
r.result()

### Deploy the brute-force index

In [None]:
DEPLOYED_BRUTE_FORCE_INDEX_ID = "cohere_brute_force_deployed"

In [None]:
deploy_brute_force_index = {
    "id": DEPLOYED_BRUTE_FORCE_INDEX_ID,
    "display_name": DEPLOYED_BRUTE_FORCE_INDEX_ID,
    "index": INDEX_BRUTE_FORCE_RESOURCE_NAME,
}

In [None]:
r = index_endpoint_client.deploy_index(
    index_endpoint=INDEX_ENDPOINT_NAME, deployed_index=deploy_brute_force_index
)

In [None]:
# Poll the operation until it's done successfully.

while True:
    if r.done():
        break
    print("Poll the operation to deploy index...")
    time.sleep(60)

In [None]:
r.result()

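## Create online queries

After you build the indexes, you can query a deployed index through the online query gRPC API (the Match service) from virtual machine instances in the same region (for example, `us-central1` in this tutorial).

A client uses this gRPC API as follows:

* Write the `match_service.proto` file locally
* Compile the protocol buffer (see below)
* Get the index endpoint
* Use the generated stub code to make the call, passing the parameter values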

In [None]:
!git clone https://github.com/googleapis/googleapis.git

In [None]:
%%writefile match_service.proto

syntax = "proto3";

package google.cloud.aiplatform.container.v1beta1;

import "google/rpc/status.proto";

// MatchService is a Google managed service for efficient vector similarity
// search at scale.
service MatchService {
  // Returns the nearest neighbors for the query. If it is a sharded
  // deployment, calls the other shards and aggregates the responses.
  rpc Match(MatchRequest) returns (MatchResponse) {}

  // Returns the nearest neighbors for batch queries. If it is a sharded
  // deployment, calls the other shards and aggregates the responses.
  rpc BatchMatch(BatchMatchRequest) returns (BatchMatchResponse) {}
}

// Parameters for a match query.
message MatchRequest {
  // The ID of the DeployedIndex that will serve the request.
  // This MatchRequest is sent to a specific IndexEndpoint of the Control API,
  // as per the IndexEndpoint.network. That IndexEndpoint also has
  // IndexEndpoint.deployed_indexes, and each such index has an
  // DeployedIndex.id field.
  // The value of the field below must equal one of the DeployedIndex.id
  // fields of the IndexEndpoint that is being called for this request.
  string deployed_index_id = 1;

  // The embedding values.
  repeated float float_val = 2;

  // The number of nearest neighbors to be retrieved from database for
  // each query. If not set, will use the default from
  // the service configuration.
  int32 num_neighbors = 3;

  // The list of restricts.
  repeated Namespace restricts = 4;

  // Crowding is a constraint on a neighbor list produced by nearest neighbor
  // search requiring that no more than some value k' of the k neighbors
  // returned have the same value of crowding_attribute.
  // It's used for improving result diversity.
  // This field is the maximum number of matches with the same crowding tag.
  int32 per_crowding_attribute_num_neighbors = 5;

  // The number of neighbors to find via approximate search before
  // exact reordering is performed. If not set, the default value from scam
  // config is used; if set, this value must be > 0.
  int32 approx_num_neighbors = 6;

  // The fraction of the number of leaves to search, set at query time, allows
  // the user to tune search performance. Increasing this value increases both
  // search accuracy and latency. The value should be between 0.0 and 1.0. If
  // not set or set to 0.0, query uses the default value specified in
  // NearestNeighborSearchConfig.TreeAHConfig.leaf_nodes_to_search_percent.
  int32 leaf_nodes_to_search_percent_override = 7;
}

// Response of a match query.
message MatchResponse {
  message Neighbor {
    // The ids of the matches.
    string id = 1;

    // The distances of the matches.
    double distance = 2;
  }
  // All its neighbors.
  repeated Neighbor neighbor = 1;
}

// Parameters for a batch match query.
message BatchMatchRequest {
  // Batched requests against one index.
  message BatchMatchRequestPerIndex {
    // The ID of the DeployedIndex that will serve the request.
    string deployed_index_id = 1;

    // The requests against the index identified by the above deployed_index_id.
    repeated MatchRequest requests = 2;

    // Selects the optimal batch size to use for low-level batching. Queries
    // within each low level batch are executed sequentially while low level
    // batches are executed in parallel.
    // This field is optional, defaults to 0 if not set. A non-positive number
    // disables low level batching, i.e. all queries are executed sequentially.
    int32 low_level_batch_size = 3;
  }

  // The batch requests grouped by indexes.
  repeated BatchMatchRequestPerIndex requests = 1;
}

// Response of a batch match query.
message BatchMatchResponse {
  // Batched responses for one index.
  message BatchMatchResponsePerIndex {
    // The ID of the DeployedIndex that produced the responses.
    string deployed_index_id = 1;

    // The match responses produced by the index identified by the above
    // deployed_index_id. This field is set only when the query against that
    // index succeeds.
    repeated MatchResponse responses = 2;

    // The status of response for the batch query identified by the above
    // deployed_index_id.
    google.rpc.Status status = 3;
  }

  // The batched responses grouped by indexes.
  repeated BatchMatchResponsePerIndex responses = 1;
}

// Namespace specifies the rules for determining the datapoints that are
// eligible for each matching query, overall query is an AND across namespaces.
message Namespace {
  // The string name of the namespace that this proto is specifying,
  // such as "color", "shape", "geo", or "tags".
  string name = 1;

  // The allowed tokens in the namespace.
  repeated string allow_tokens = 2;

  // The denied tokens in the namespace.
  // The denied tokens have exactly the same format as the token fields, but
  // represents a negation. When a token is denied, then matches will be
  // excluded whenever the other datapoint has that token.
  //
  // For example, if a query specifies {color: red, blue, !purple}, then that
  // query will match datapoints that are red or blue, but if those points are
  // also purple, then they will be excluded even if they are red/blue.
  repeated string deny_tokens = 3;
}
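
The `Namespace` comments above describe how `restricts` filter eligible datapoints: allow tokens admit a datapoint, deny tokens exclude it, and multiple namespaces are combined with AND. This notebook does not use restricts. A plain-Python sketch of those documented rules, for intuition only; how the service treats edge cases (such as an empty allow list or a datapoint with no tokens in a namespace) is an assumption here, and `is_eligible` is a hypothetical helper.

```python
def is_eligible(datapoint_tokens, restricts):
    """Apply query restricts to one datapoint.

    datapoint_tokens: dict mapping namespace name -> iterable of the datapoint's tokens.
    restricts: iterable of (name, allow_tokens, deny_tokens), mirroring Namespace.
    """
    for name, allow_tokens, deny_tokens in restricts:
        tokens = set(datapoint_tokens.get(name, ()))
        # A denied token excludes the datapoint even if it also has an allowed token.
        if tokens & set(deny_tokens):
            return False
        # Assumption: an empty allow list places no positive requirement.
        if allow_tokens and not tokens & set(allow_tokens):
            return False
    # Every namespace passed: namespaces are combined with AND.
    return True


# The {color: red, blue, !purple} example from the proto comments
color = [("color", ["red", "blue"], ["purple"])]
print(is_eligible({"color": ["red"]}, color))            # True
print(is_eligible({"color": ["red", "purple"]}, color))  # False
```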

Compile the protocol buffer to generate `match_service_pb2.py` and `match_service_pb2_grpc.py`.

In [None]:
! python -m grpc_tools.protoc -I=. --proto_path=googleapis --python_out=. --grpc_python_out=. match_service.proto

Get the private endpoint:

In [None]:
DEPLOYED_INDEX_SERVER_IP = (
    list(index_endpoint_client.list_index_endpoints(parent=PARENT))[0]
    .deployed_indexes[0]
    .private_endpoints.match_grpc_address
)
DEPLOYED_INDEX_SERVER_IP
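
The cell above takes the first deployed index of the first endpoint in the project, which may not be `DEPLOYED_INDEX_ID` if the project has other endpoints or if the brute-force index is listed first. A sketch that looks the address up by deployed index ID instead, using the same `deployed_indexes`, `id`, and `private_endpoints.match_grpc_address` fields; `find_match_grpc_address` is a hypothetical helper.

```python
def find_match_grpc_address(index_endpoints, deployed_index_id):
    """Return the Match gRPC address of the deployed index with this ID."""
    for endpoint in index_endpoints:
        for deployed_index in endpoint.deployed_indexes:
            if deployed_index.id == deployed_index_id:
                return deployed_index.private_endpoints.match_grpc_address
    raise LookupError(f"no deployed index with id {deployed_index_id!r}")


# Usage in this notebook:
# DEPLOYED_INDEX_SERVER_IP = find_match_grpc_address(
#     index_endpoint_client.list_index_endpoints(parent=PARENT), DEPLOYED_INDEX_ID
# )
```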

Test your query:

In [None]:
import match_service_pb2
import match_service_pb2_grpc

channel = grpc.insecure_channel("{}:10000".format(DEPLOYED_INDEX_SERVER_IP))
stub = match_service_pb2_grpc.MatchServiceStub(channel)

### Test the search with a query embedded by Cohere

In [None]:
raw_query = "Articles about the climate"

In [None]:
query = co.embed(model="small", texts=[raw_query]).embeddings[0]
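
Both indexes use `DOT_PRODUCT_DISTANCE`, so longer vectors can score higher regardless of direction. If you want dot-product scores to behave like cosine similarity, one option is to L2-normalize both the indexed embeddings and the query (another is the index-level `featureNormType` setting mentioned in the brute-force section). A plain-Python sketch of the normalization; whether Cohere's embeddings already have unit length is not assumed here, and the helpers are hypothetical.

```python
import math


def l2_normalize(vector):
    """Scale `vector` to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in vector))
    if norm == 0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in vector]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine_similarity(a, b):
    return dot(a, b) / math.sqrt(dot(a, a) * dot(b, b))


a, b = [3.0, 4.0], [4.0, 3.0]
# After normalization, the dot product equals the cosine similarity.
print(math.isclose(dot(l2_normalize(a), l2_normalize(b)), cosine_similarity(a, b)))  # True

# Hypothetical usage: normalize consistently on both sides, i.e.
# cohere_embeddings = [l2_normalize(e) for e in cohere_embeddings]  # before writing the JSONL
# query = l2_normalize(query)
```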

In [None]:
# Test query
request = match_service_pb2.MatchRequest()
request.deployed_index_id = DEPLOYED_INDEX_ID
for val in query:
    request.float_val.append(val)

response = stub.Match(request)
response
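
The response contains neighbor IDs and distances rather than article text. Because the JSONL file above used each article's row position as its `id`, a sketch like the following can map results back to `df["text"]`; it assumes the IDs are exactly those positions, and `neighbors_to_texts` is a hypothetical helper.

```python
def neighbors_to_texts(neighbors, texts):
    """Pair each neighbor's source text with its distance, in response order."""
    return [(texts[int(neighbor.id)], neighbor.distance) for neighbor in neighbors]


# Usage in this notebook:
# for text, distance in neighbors_to_texts(response.neighbor, list(df["text"])):
#     print(f"{distance:.4f}  {text[:80]}")
```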

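## Cleaning up

To clean up all Google Cloud resources used in this project, you can delete the [Google Cloud project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial. You can also manually delete the resources you created by running the following code.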

In [None]:
index_client.delete_index(name=INDEX_RESOURCE_NAME)

In [None]:
index_endpoint_client.delete_index_endpoint(name=INDEX_ENDPOINT_NAME)
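
The cells above try to delete the ANN index while it is still deployed, and they leave the brute-force index in place. A sketch of a fuller cleanup order, assuming Vertex AI requires an index to be undeployed before it can be deleted: undeploy both indexes, delete both indexes, then delete the endpoint. `cleanup_matching_engine` is a hypothetical helper; the client methods it calls are the ones used in this notebook plus `undeploy_index`.

```python
def cleanup_matching_engine(index_client, index_endpoint_client, index_endpoint_name,
                            deployed_index_ids, index_names):
    """Undeploy indexes, delete them, then delete the endpoint.

    Each client call returns a long-running operation; .result() waits for it.
    """
    for deployed_index_id in deployed_index_ids:
        index_endpoint_client.undeploy_index(
            index_endpoint=index_endpoint_name, deployed_index_id=deployed_index_id
        ).result()
    for index_name in index_names:
        index_client.delete_index(name=index_name).result()
    index_endpoint_client.delete_index_endpoint(name=index_endpoint_name).result()


# Usage in this notebook:
# cleanup_matching_engine(
#     index_client, index_endpoint_client, INDEX_ENDPOINT_NAME,
#     [DEPLOYED_INDEX_ID, DEPLOYED_BRUTE_FORCE_INDEX_ID],
#     [INDEX_RESOURCE_NAME, INDEX_BRUTE_FORCE_RESOURCE_NAME],
# )
```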