In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI模型花园 - JAX Vision Transformer

<table align="left">
  <td>
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_jax_vision_transformer.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Colab logo"> 在Colab中运行
    </a>
  </td>

  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_jax_vision_transformer.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      在GitHub上查看
    </a>
  </td>
  <td>
    <a href="https://console.cloud.google.com/vertex-ai/notebooks/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/community/model_garden/model_garden_jax_vision_transformer.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo">
在Vertex AI工作台中打开
    </a>
  </td>
</table>

注意：此笔记本已在以下环境中进行测试：

- Python 版本 = 3.9

## 概述

本笔记本演示了如何在GPU上微调[JAX ViT-B16模型](https://github.com/google-research/vision_transformer#available-vit-models)以进行图像分类任务，并在Vertex AI上部署它们进行在线预测。

了解更多关于[Vertex AI中生成式人工智能支持](https://cloud.google.com/blog/products/ai-machine-learning/vertex-ai-model-garden-and-generative-ai-studio)。

### 目标

在本教程中，您将学习如何使用 Vertex AI 预训练的 JAX Vision Transformer 模型进行微调、部署和预测。

本教程使用以下谷歌云 ML 服务和资源:

- Vertex AI Model Garden
- Vertex AI 训练
- Vertex AI 模型注册表
- Vertex AI 在线预测

执行的步骤包括:

- 对基于 JAX Vision Transformer 的模型进行微调。
- 将模型上传至 [模型注册表](https://cloud.google.com/vertex-ai/docs/model-registry/introduction)。
- 在 [端点](https://cloud.google.com/vertex-ai/docs/predictions/using-private-endpoints) 上部署模型。
- 运行用于图像分类的在线预测。

### 数据集

本笔记本使用 [tf_flowers 数据集](https://www.tensorflow.org/datasets/catalog/tf_flowers)，并有一个部分展示了如何下载和准备它。您也可以遵循类似的过程来使用您自己的自定义数据集。

### 费用

本教程使用 Google Cloud 中的付费组件：

* Vertex AI
* Cloud Storage

了解[Vertex AI 价格](https://cloud.google.com/vertex-ai/pricing)和[Cloud Storage 价格](https://cloud.google.com/storage/pricing)，并使用[价格计算器](https://cloud.google.com/products/calculator/)根据您预期的使用量生成成本估算。

## 安装

安装以下必要的软件包以执行此笔记本。

In [None]:
# Install the packages.
! pip3 install --upgrade google-cloud-aiplatform

### 仅限 Colab

如果您在 Colab 中运行此笔记本，请取消注释以下单元格并运行，以便在安装完成后重启内核。

In [None]:
# Automatically restart kernel after installs so that your environment can access the new packages.
# import IPython

# app = IPython.Application.instance()
# app.kernel.do_shutdown(True)

## 在您开始之前

### 设置您的 Google Cloud 项目

**无论您使用什么笔记本环境，下面的步骤都是必需的。**

1. [选择或创建一个 Google Cloud 项目](https://console.cloud.google.com/cloud-resource-manager)。当您第一次创建一个帐户时，您将获得 $300 的免费信用用于计算/存储成本。

1. [确保为您的项目启用了计费](https://cloud.google.com/billing/docs/how-to/modify-project)。

1. [启用 Vertex AI API 和 Compute Engine API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com,compute_component)。

1. 如果您在本地运行这个笔记本，您需要安装[Cloud SDK](https://cloud.google.com/sdk)。

#### 设置您的项目ID

**如果您不知道您的项目ID**，请尝试以下操作：
* 运行`gcloud config list`。
* 运行`gcloud projects list`。
* 查看支持页面：[查找项目ID](https://support.google.com/googleapi/answer/7014113)

In [None]:
# Google Cloud project ID used by the rest of the notebook; replace the
# placeholder with your own project ID before running.
PROJECT_ID = "[your-project-id]"  # @param {type:"string"}

# Make PROJECT_ID the active project in the gcloud configuration.
! gcloud config set project {PROJECT_ID}

#### 区域

您还可以更改 Vertex AI 使用的 `REGION` 变量。了解有关 [Vertex AI 区域](https://cloud.google.com/vertex-ai/docs/general/locations) 的更多信息。

In [None]:
# Region passed to `gsutil mb -l`, `aiplatform.init(location=...)` and
# `aiplatform.Model.upload(location=...)` later in this notebook.
REGION = "us-central1"  # @param {type: "string"}

### 验证您的Google Cloud账户

根据您的Jupyter环境，您可能需要手动进行身份验证。请按下面的相关说明进行操作。

1. Vertex AI Workbench
* 无需做任何操作，因为您已经通过验证。

2. 本地 JupyterLab 实例，请取消注释并运行:

In [None]:
# ! gcloud auth login

3. Colab，请取消注释并运行:

In [None]:
# from google.colab import auth
# auth.authenticate_user()

4. 服务账号或其他
* 查看如何为您的服务账号授予Cloud Storage权限，请访问https://cloud.google.com/storage/docs/gsutil/commands/iam#ch-examples。

### 创建一个云存储桶

创建一个存储桶，用于存储中间产物，例如数据集。

In [None]:
# Cloud Storage bucket URI for intermediate artifacts; the f-string embeds
# PROJECT_ID in the bucket name.
BUCKET_URI = f"gs://your-bucket-name-{PROJECT_ID}-unique"  # @param {type:"string"}

**仅当您的存储桶尚不存在时**：运行以下单元格来创建您的云存储桶。

In [None]:
# Create the bucket BUCKET_URI in REGION under PROJECT_ID (`gsutil mb` = make bucket).
! gsutil mb -l {REGION} -p {PROJECT_ID} {BUCKET_URI}

### 导入库

In [None]:
import base64
import glob
import os
import random
import shutil
from datetime import datetime
from io import BytesIO

import numpy as np
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
from PIL import Image

### 初始化用于 Python 的 Vertex AI SDK

为您的项目初始化用于 Python 的 Vertex AI SDK。

In [None]:
# Folder inside the bucket that is handed to the SDK as its staging location.
staging_bucket = os.path.join(BUCKET_URI, "jax_vit_staging")
# Configure the SDK with this notebook's project, region and staging bucket.
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=staging_bucket)

### 定义常量

In [None]:
# The pre-built training docker image (used as `container_uri` of the fine-tuning job).
TRAIN_DOCKER_URI = "us-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/jax-vit-train-gpu"
# The pre-built TF SavedModel conversion docker image (used as `container_uri`
# of the model conversion job).
MODEL_CONVERSION_DOCKER_URI = "us-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/jax-vit-model-conversion"
# The pre-built prediction docker image (used as the serving container on upload).
# NOTE(review): the `nightly:latest` tag is not pinned, so the serving image
# may differ between runs — confirm whether a fixed tag should be used.
OPTIMIZED_TF_RUNTIME_IMAGE_URI = (
    "us-docker.pkg.dev/vertex-ai-restricted/prediction/tf_opt-gpu.nightly:latest"
)

### 定义常用函数

本节定义以下功能：

- 将 [tf_flowers 数据集](https://www.tensorflow.org/datasets/catalog/tf_flowers) 的图像分割为 `train` 和 `test` 文件夹。
- 将类似 `gs://bucket-name` 的 Cloud Storage 路径转换为 GCSFuse 路径格式，例如 `/gcs/bucket-name`。
- 将本地图像文件编码为字符串，用于预测输入。

In [None]:
def split(base_dir, test_ratio=0.1, seed=None):
    """Splits images and moves them to train and test folders.

    Every ``{base_dir}/{class_name}/*.jpg`` file is moved to either
    ``{base_dir}/test/{class_name}/`` or ``{base_dir}/train/{class_name}/``.
    A summary of how many files went to each split is printed.

    Args:
        base_dir: Directory whose immediate sub-folders are the class names.
        test_ratio: Fraction of the images that is assigned to the test split.
        seed: Optional seed for the shuffle. When given, the same set of files
            always produces the same split. ``None`` (the default) shuffles
            with the module-level ``random`` generator, as before.
    """
    # Sort before shuffling: glob order depends on the filesystem, so a seeded
    # shuffle is only reproducible when it starts from a fixed order.
    paths = sorted(glob.glob(f"{base_dir}/*/*.jpg"))
    if seed is None:
        random.shuffle(paths)
    else:
        # A private generator leaves the global random state untouched.
        random.Random(seed).shuffle(paths)

    num_test = test_ratio * len(paths)
    counts = dict(test=0, train=0)
    for i, path in enumerate(paths):
        # `subset` (not `split`) so the loop does not shadow this function's name.
        subset = "test" if i < num_test else "train"
        # Use os.path instead of splitting on "/" so that paths returned by
        # glob with the platform's own separator are parsed correctly.
        class_name = os.path.basename(os.path.dirname(path))
        dst = os.path.join(base_dir, subset, class_name, os.path.basename(path))
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        shutil.move(path, dst)
        counts[subset] += 1
    print(f'Moved {counts["train"]:,} train and {counts["test"]:,} test images.')


def gcs_fuse_path(path: str) -> str:
    """Map a ``gs://`` URI to its ``/gcs/`` mount path.

    Surrounding whitespace is stripped first. Anything that does not start
    with ``gs://`` is returned stripped but otherwise unchanged.
    """
    scheme = "gs://"
    cleaned = path.strip()
    if not cleaned.startswith(scheme):
        return cleaned
    return "/gcs/" + cleaned[len(scheme):]


def load_bytes_from_local_image(local_image_path, new_width=-1):
    """Returns encoded image string for prediction input.

    The image is optionally resized, re-encoded as JPEG and returned as a
    base64 string.

    Args:
        local_image_path: Path of the image file to load.
        new_width: Target width in pixels. The height is scaled to keep the
            aspect ratio. A value <= 0 keeps the original size.

    Returns:
        The base64-encoded (UTF-8 decoded) JPEG bytes of the image.
    """
    # The context manager closes the underlying file once encoding is done.
    with Image.open(local_image_path) as image:
        if new_width <= 0:
            new_image = image
        else:
            width, height = image.size
            print("original input image size: ", width, " , ", height)
            # Clamp to 1 so an extremely wide image never yields a zero height.
            new_height = max(1, int(height * new_width / width))
            print("new input image size: ", new_width, " , ", new_height)
            new_image = image.resize((new_width, new_height))
        # JPEG can only store L, RGB or CMYK data; convert anything else
        # (e.g. RGBA or palette images) so that save() does not raise.
        if new_image.mode not in ("L", "RGB", "CMYK"):
            new_image = new_image.convert("RGB")
        buffered = BytesIO()
        new_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

## 准备数据集

如果您没有使用[TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/overview#all_datasets)，那么您需要准备您的数据集并将其存储在云存储中。以下示例展示了如何为[tf_flowers数据集](https://www.tensorflow.org/datasets/catalog/tf_flowers)执行此操作。如果使用了TensorFlow Datasets，则将数据集名称（如`tf_flowers`）传递给`--config.dataset`标志，并跳过此部分。

In [None]:
local_flower_data_directory = "./flower_photos"  # @param {type:"string"}
FLOWER_DATA_GCS_PATH = os.path.join(BUCKET_URI, "flower_dataset")
# NOTE: For custom dataset, the training code picks the class names
# from the folder structure and then sorts them to create a mapping
# from class-index to class-name. This is why the mapping below
# looks different from default `tf_flowers` documentation.
LABEL_IDX_TO_STR = dict(
    enumerate(["daisy", "dandelion", "roses", "sunflowers", "tulips"])
)
# The flower dataset has 5 classes (one per entry in the mapping above).
NUM_CLASSES = len(LABEL_IDX_TO_STR)

In [None]:
# Download flower data to a local directory.
! rm -rf $local_flower_data_directory;
! (cd "./" && curl https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz | tar xz)

In [None]:
# Since the default file format of above "tf_flowers" dataset is
# flower_photos/{class_name}/{filename}.jpg
# we first need to split it into a "train" (90%) and a "test" (10%) set:
# flower_photos/train/{class_name}/{filename}.jpg
# flower_photos/test/{class_name}/{filename}.jpg
# The 90/10 ratio comes from split()'s default `test_ratio=0.1`; the files
# are moved in place, not copied.

split(local_flower_data_directory)

In [None]:
# Copy Flower data from local directory to Cloud Storage (the local files are kept).
# This step takes around 2 mins to finish.
# `-m` enables parallel copying; `-R` recurses into the per-class folders.
! gsutil -m cp -R $local_flower_data_directory/train/* $FLOWER_DATA_GCS_PATH/train/
! gsutil -m cp -R $local_flower_data_directory/test/* $FLOWER_DATA_GCS_PATH/test/

使用 Vertex AI SDK，基于 Model Garden 的 JAX Vision Transformer 训练 Docker 镜像创建并运行训练作业。训练使用一个 V100 GPU，作业启动后大约需要运行 10 分钟。

In [None]:
# Set up training docker arguments.

# The timestamp suffix gives each run its own job name and work directory.
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
JOB_NAME = f"jax_vision_transformer{TIMESTAMP}"

finetuning_workdir = os.path.join(BUCKET_URI, JOB_NAME)
pre_trained_dir = "gs://vit_models/imagenet21k"

# Written as (flag, value) pairs and flattened into the flat
# [flag, value, flag, value, ...] list that is passed to the container.
# Cloud Storage locations are handed over as gcsfuse paths.
docker_args_list = [
    token
    for flag_and_value in (
        ("--config", "vit_jax/configs/vit.py:b16"),
        ("--config.dataset", gcs_fuse_path(FLOWER_DATA_GCS_PATH)),
        ("--config.pp.train", "train"),
        ("--config.pp.test", "test"),
        ("--config.pretrained_dir", gcs_fuse_path(pre_trained_dir)),
        ("--config.batch", "128"),
        ("--config.batch_eval", "128"),
        ("--config.base_lr", "0.01"),
        ("--config.shuffle_buffer", "1000"),
        ("--config.total_steps", "100"),
        ("--config.warmup_steps", "10"),
        ("--config.pp.crop", "224"),
        ("--workdir", gcs_fuse_path(finetuning_workdir)),
    )
    for token in flag_and_value
]
print(docker_args_list)

In [None]:
# Create and run the training job.
# Click on the generated link in the output under "View backing custom job:" to see your run in the Cloud Console.
# Number of GPUs attached to the single training replica.
NUM_GPU = 1
container_uri = TRAIN_DOCKER_URI
job = aiplatform.CustomContainerTrainingJob(
    display_name=JOB_NAME,
    container_uri=container_uri,
)
# NOTE(review): no serving container is configured on the job, so `model` is
# presumably None here; later cells do not read it — confirm.
model = job.run(
    args=docker_args_list,
    base_output_dir=f"{finetuning_workdir}",
    replica_count=1,
    machine_type="n1-standard-4",
    accelerator_type="NVIDIA_TESLA_V100",
    accelerator_count=NUM_GPU,
)

将之前微调的JAX模型转换为TF SavedModel，用于在线预测。

In [None]:
# Set up model conversion docker arguments.
# Note: Many of the arguments below are similar to the training job
# such as the model name and train and test data related parameters.

# The checkpoints to convert live in the fine-tuning work directory.
jax_checkpoint_dir = finetuning_workdir

# TIMESTAMP and JOB_NAME are re-assigned here for the conversion job.
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
JOB_NAME = f"jax_model_conversion{TIMESTAMP}"
saved_model_dir = os.path.join(BUCKET_URI, f"jax2tf_{TIMESTAMP}")

# Written as (flag, value) pairs and flattened into the flat
# [flag, value, flag, value, ...] list that is passed to the container.
# Only the dataset location is converted to a gcsfuse path; the other
# Cloud Storage locations are passed as gs:// URIs.
docker_args_list = [
    token
    for flag_and_value in (
        ("--config", "vit_jax/configs/vit.py:b16"),
        ("--num_classes", str(NUM_CLASSES)),
        ("--saved_model_dir", saved_model_dir),
        ("--jax_checkpoint_dir", jax_checkpoint_dir),
        ("--config.pretrained_dir", pre_trained_dir),
        ("--config.dataset", gcs_fuse_path(FLOWER_DATA_GCS_PATH)),
        ("--config.pp.train", "train"),
        ("--config.pp.test", "test"),
        ("--config.pp.crop", "224"),
    )
    for token in flag_and_value
]
print(docker_args_list)

In [None]:
# Create and run the model conversion job.
# Click on the generated link in the output under "View backing custom job:" to see your run in the Cloud Console.
container_uri = MODEL_CONVERSION_DOCKER_URI
job = aiplatform.CustomContainerTrainingJob(
    display_name=JOB_NAME,
    container_uri=container_uri,
)
# Output directory of the conversion job itself; the SavedModel location is
# passed separately through `--saved_model_dir` in `docker_args_list`.
model_conversion_workdir = os.path.join(BUCKET_URI, JOB_NAME)
# No accelerator is requested for this job.
model = job.run(
    args=docker_args_list,
    base_output_dir=f"{model_conversion_workdir}",
    replica_count=1,
    machine_type="n1-standard-4",
)

## 运行在线预测

使用转换后的TF SavedModel 运行在线预测。

上传 TF SavedModel 并将其部署到一个端点以进行预测。这一步大约需要 15 分钟完成。

In [None]:
# Environment variables set on the serving container.
serving_env = {
    "MODEL_ID": "ViT-JAX-",
    "DEPLOY_SOURCE": "notebook",
}

# Register the converted SavedModel (at `saved_model_dir`) as a Vertex AI
# model served by the optimized TF runtime image.
jax_vit_model = aiplatform.Model.upload(
    display_name="jax_vit",
    artifact_uri=saved_model_dir,
    serving_container_image_uri=OPTIMIZED_TF_RUNTIME_IMAGE_URI,
    serving_container_args=[],
    location=REGION,
    serving_container_environment_variables=serving_env,
)

# Deploy the uploaded model on a single replica with one V100 GPU
# (min and max replica count are both 1).
# NOTE(review): the "0" key in `traffic_split` presumably refers to the model
# being deployed by this call — confirm against the SDK documentation.
jax_vit_endpoint = jax_vit_model.deploy(
    deployed_model_display_name="jax_vit_deployed",
    traffic_split={"0": 100},
    machine_type="n1-standard-4",
    accelerator_type="NVIDIA_TESLA_V100",
    accelerator_count=1,
    min_replica_count=1,
    max_replica_count=1,
)

加载本地测试图像文件，将其编码为字符串，发送到端点进行预测，然后从预测的类别概率生成最终的类别标签。

In [None]:
# Take the first file listed in the local tulips test folder as the sample.
test_directory = os.path.join(local_flower_data_directory, "test/tulips")
first_test_file = os.listdir(test_directory)[0]
local_test_image_path = os.path.join(test_directory, first_test_file)
print(local_test_image_path)

# Resize the sample to 240 px wide and wrap its base64 JPEG bytes in the
# {"bytes_inputs": {"b64": ...}} request payload.
encoded_image = load_bytes_from_local_image(local_test_image_path, new_width=240)
instances_list = [{"bytes_inputs": {"b64": encoded_image}}]
instances = [
    json_format.ParseDict(instance, Value()) for instance in instances_list
]

# The index of the largest score of the first prediction selects the label.
results = jax_vit_endpoint.predict(instances=instances)
logits = results.predictions[0]
predicted_label = LABEL_IDX_TO_STR[int(np.argmax(logits))]
print("predicted_label: ", predicted_label)

## 清理

要清理此项目中使用的所有Google Cloud资源，您可以删除用于该教程的[Google Cloud 项目](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects)。

否则，您可以删除在此教程中创建的各个资源。

In [None]:
# Delete endpoint resource.
# `force=True` is passed so the endpoint is deleted while the model is still deployed on it.
jax_vit_endpoint.delete(force=True)

# Delete model resource.
jax_vit_model.delete()

# Delete Cloud Storage objects that were created.
# The bucket is removed when `delete_bucket` is True or when the IS_TESTING
# environment variable is set to a non-empty value; set `delete_bucket = False`
# to keep the bucket and its contents.
delete_bucket = True
if delete_bucket or os.getenv("IS_TESTING"):
    ! gsutil -m rm -r $BUCKET_URI