# Train and Deploy Merlin models with Vertex AI

In [1]:
# !pip install kfp
# !pip install google-cloud-pipeline-components --user

In [2]:
# ! python3 -c "import kfp; print('KFP SDK version: {}'.format(kfp.__version__))"
# ! python3 -c "import google_cloud_pipeline_components; print('google_cloud_pipeline_components version: {}'.format(google_cloud_pipeline_components.__version__))"
# ! python3 -c "import google.cloud.aiplatform; print('aiplatform SDK version: {}'.format(google.cloud.aiplatform.__version__))"

In [3]:
GCP_PROJECTS = !gcloud config get-value project
PROJECT_ID = GCP_PROJECTS[0]
PROJECT_NUM = !gcloud projects list --filter="$PROJECT_ID" --format="value(PROJECT_NUMBER)"
PROJECT_NUM = PROJECT_NUM[0]
LOCATION = 'us-central1'

print(f"PROJECT_ID: {PROJECT_ID}")
print(f"PROJECT_NUM: {PROJECT_NUM}")
print(f"LOCATION: {LOCATION}")

PROJECT_ID: hybrid-vertex
PROJECT_NUM: 934903580331
LOCATION: us-central1


In [4]:
# Default Compute Engine service account, derived from PROJECT_NUM rather than
# hard-coded so the notebook works in any project.
VERTEX_SA = f'{PROJECT_NUM}-compute@developer.gserviceaccount.com'

In [5]:
import os
import json
from datetime import datetime
from time import time
import pandas as pd
# logging.disable(logging.WARNING) below suppresses WARNING, INFO and DEBUG records everywhere
import logging
import time
from pprint import pprint

# Process-wide: drop every log record at WARNING severity or lower (WARNING, INFO,
# DEBUG); only ERROR and CRITICAL records are still emitted in this kernel.
logging.disable(logging.WARNING)

from google.cloud import aiplatform as vertex_ai
from google.cloud import storage

# Pipelines
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from google_cloud_pipeline_components import aiplatform as gcc_aip
from google_cloud_pipeline_components.types import artifact_types

# Kubeflow SDK
# TODO: fix these
from kfp.v2 import dsl
import kfp
import kfp.v2.dsl
from kfp.v2.google import client as pipelines_client
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component)

# GCS client bound to the active project.
storage_client = storage.Client(project=PROJECT_ID)

# Default project/region for all subsequent Vertex AI SDK calls.
vertex_ai.init(
    project=PROJECT_ID,
    location=LOCATION,
)

In [6]:
# Component modules are written to <REPO_DOCKER_PATH_PREFIX>/<PIPELINES_SUB_DIR>/
# and later imported as `src.train_pipes`.
REPO_DOCKER_PATH_PREFIX = "src"       # top-level source directory
PIPELINES_SUB_DIR = "train_pipes"     # sub-directory for the component modules

In [7]:
! rm -rf {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}
! mkdir {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}

# Pipeline Components

In [8]:
# The %%writefile paths below and the `src.train_pipes` import are relative to this directory.
os.path.abspath(os.curdir)

'/home/jupyter/merlin-on-vertex'

## Build Custom Image

In [9]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/build_custom_image.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)
@kfp.v2.dsl.component(
    base_image="gcr.io/google.com/cloudsdktool/cloud-sdk:latest",
    packages_to_install=[
        "google-cloud-build"
    ],
)
def build_custom_image(
    project: str,
    artifact_gcs_path: str,
    docker_name: str,
    app_dir_name: str,
    custom_image_uri: str,
) -> NamedTuple('Outputs', [
    ('custom_image_uri', str),
]):
    """Build and push a custom container image with Cloud Build.

    The build copies ``<artifact_gcs_path>/<app_dir_name>/`` and the Dockerfile
    ``<artifact_gcs_path>/<docker_name>`` out of GCS, runs ``docker build`` and
    pushes the resulting image to ``custom_image_uri``.

    Args:
        project: GCP project ID the Cloud Build job runs in.
        artifact_gcs_path: GCS prefix holding the Dockerfile and the app directory.
        docker_name: Dockerfile name under ``artifact_gcs_path`` (e.g. ``Dockerfile.XXXXX``).
        app_dir_name: Application code directory under ``artifact_gcs_path`` (e.g. ``trainer``).
        custom_image_uri: Image URI to tag and push.

    Returns:
        ``Outputs(custom_image_uri)``: the URI of the pushed image.

    Raises:
        RuntimeError: If the build completes with a status other than ``SUCCESS``.
    """
    # TODO: make output Artifact for image_uri
    import logging
    import os

    from google.cloud.devtools import cloudbuild_v1 as cloudbuild
    from google.protobuf.duration_pb2 import Duration

    # initialize client for cloud build
    logging.getLogger().setLevel(logging.INFO)
    build_client = cloudbuild.services.cloud_build.CloudBuildClient()

    # GCS locations of the Dockerfile and the application code
    _gcs_dockerfile_path = os.path.join(artifact_gcs_path, f"{docker_name}")  # Dockerfile.XXXXX
    _gcs_script_dir_path = os.path.join(artifact_gcs_path, f"{app_dir_name}/")  # "trainer/"

    logging.info(f"_gcs_dockerfile_path: {_gcs_dockerfile_path}")
    logging.info(f"_gcs_script_dir_path: {_gcs_script_dir_path}")

    # Steps: copy the app code and Dockerfile into the build workspace, then build
    # and push the image. A Kaniko-cached build step
    # (https://cloud.google.com/build/docs/kaniko-cache) is an alternative that is
    # currently not used.
    build = cloudbuild.Build()
    build.steps = [
        {
            "name": "gcr.io/cloud-builders/gsutil",
            "args": ["cp", "-r", _gcs_script_dir_path, "."],
        },
        {
            "name": "gcr.io/cloud-builders/gsutil",
            "args": ["cp", _gcs_dockerfile_path, "Dockerfile"],
        },
        {
            "name": "gcr.io/cloud-builders/docker",
            "args": ['build', '-t', f'{custom_image_uri}', '.'],
        },
        {
            "name": "gcr.io/cloud-builders/docker",
            "args": ['push', f'{custom_image_uri}'],
        },
    ]
    # override default timeout of 10min with 2 hours
    timeout = Duration()
    timeout.seconds = 7200
    build.timeout = timeout

    # create build (returns a long-running operation)
    operation = build_client.create_build(project_id=project, build=build)
    logging.info("IN PROGRESS:")
    logging.info(operation.metadata)

    # block until the build finishes
    result = operation.result()
    # Pass the value as a %-style argument: an extra positional argument with no
    # placeholder in the message makes logging report a formatting error instead.
    logging.info("RESULT: %s", result.status)
    # Do not hand an image URI downstream unless the image was actually pushed.
    if result.status != cloudbuild.Build.Status.SUCCESS:
        raise RuntimeError(
            f"Cloud Build for {custom_image_uri} finished with status {result.status}"
        )

    # return step outputs
    return (
        custom_image_uri,
    )

Writing src/train_pipes/build_custom_image.py


## Train Job

In [10]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/train_merlin.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.18.1',
        # 'google-cloud-storage',
    ],
)
def train_merlin(
    project: str,
    location: str,
    version: str,
    train_image_uri: str,     # TODO: Artifact
    tb_resource: str,
    batch_size: int,
    train_epochs: int,
    train_dir: str,
    valid_dir: str,
    workflow_dir: str,
    experiment_name: str,
    experiment_run: str,
    service_account: str,
    output_bucket: str = 'jt-merlin-scaling',
) -> NamedTuple('Outputs', [
    ('merlin_model_gcs_dir', str),
    ('query_tower_gcs_dir', str),
    ('candidate_tower_gcs_uri', str),
    ('candidate_embeddings_gcs_uri', str),
]):
    """Train the Merlin retrieval towers as a Vertex AI CustomJob and return their GCS paths.

    Runs ``python -m trainer.train_task`` inside ``train_image_uri`` on one
    ``a2-highgpu-1g`` machine with a single NVIDIA A100 and blocks until the job
    finishes (``sync=True``).

    Args:
        project: Project ID, used for ``vertex_ai.init`` and forwarded as ``--project``.
        location: Region, used for ``vertex_ai.init`` and forwarded as ``--location``.
        version: Suffix of the job display name ``train-merlin-retrieval-<version>``.
        train_image_uri: Training container image.
        tb_resource: Forwarded to the trainer as ``--tb_name``.
        batch_size: Forwarded to the trainer as ``--per_gpu_batch_size``.
        train_epochs: Forwarded as ``--num_epochs``.
        train_dir: Forwarded as ``--train_dir``.
        valid_dir: Forwarded as ``--valid_dir``.
        workflow_dir: Forwarded as ``--workflow_dir``.
        experiment_name: Forwarded as ``--experiment_name``; also the folder under
            ``output_bucket`` used for job staging and trainer outputs.
        experiment_run: Run-name prefix; a ``-YYYYmmdd-HHMMSS`` timestamp is appended.
        service_account: Service account the CustomJob runs as.
        output_bucket: Bucket name (without ``gs://``) for staging and trainer
            outputs; forwarded as ``--train_output_bucket``. Defaults to
            ``'jt-merlin-scaling'``.

    Returns:
        Outputs with the GCS paths of the model dir, query tower, candidate tower
        and candidate embeddings. The paths are built from the naming convention
        below; they are assumed (not checked here) to match what
        ``trainer.train_task`` writes.
    """
    import logging
    import time

    from google.cloud import aiplatform as vertex_ai

    vertex_ai.init(
        project=project,
        location=location,
    )

    TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
    # ====================================================
    # Helper function for workerpool specs
    # ====================================================
    def prepare_worker_pool_specs(
        image_uri,
        cmd,
        replica_count=1,
        machine_type="n1-standard-16",
        accelerator_count=1,
        accelerator_type="ACCELERATOR_TYPE_UNSPECIFIED",
        reduction_server_count=0,
        reduction_server_machine_type="n1-highcpu-16",
        reduction_server_image_uri="us-docker.pkg.dev/vertex-ai-restricted/training/reductionserver:latest",
    ):
        """Build ``worker_pool_specs``: one chief replica, an optional worker pool
        with ``replica_count - 1`` replicas, and an optional reduction-server pool."""

        if accelerator_count > 0:
            machine_spec = {
                "machine_type": machine_type,
                "accelerator_type": accelerator_type,
                "accelerator_count": accelerator_count,
            }
        else:
            machine_spec = {"machine_type": machine_type}

        container_spec = {
            "image_uri": image_uri,
            "command": cmd,
        }

        chief_spec = {
            "replica_count": 1,
            "machine_spec": machine_spec,
            "container_spec": container_spec,
        }

        worker_pool_specs = [chief_spec]
        if replica_count > 1:
            workers_spec = {
                "replica_count": replica_count - 1,
                "machine_spec": machine_spec,
                "container_spec": container_spec,
            }
            worker_pool_specs.append(workers_spec)
        # NOTE(review): `> 1` skips the pool when exactly one reduction server is
        # requested; confirm whether `> 0` was intended. Vertex AI presumably
        # assigns pool roles by list position, so with replica_count == 1 this
        # spec would occupy the worker slot; verify before enabling.
        if reduction_server_count > 1:
            workers_spec = {
                "replica_count": reduction_server_count,
                "machine_spec": {
                    "machine_type": reduction_server_machine_type,
                },
                "container_spec": {"image_uri": reduction_server_image_uri},
            }
            worker_pool_specs.append(workers_spec)

        return worker_pool_specs

    # ====================================================
    # Define device strategy
    # ====================================================
    # TODO: parameterize

    WORKER_MACHINE_TYPE = 'a2-highgpu-1g'
    REPLICA_COUNT = 1
    ACCELERATOR_TYPE = 'NVIDIA_TESLA_A100'
    PER_MACHINE_ACCELERATOR_COUNT = 1
    REDUCTION_SERVER_COUNT = 0
    REDUCTION_SERVER_MACHINE_TYPE = "n1-highcpu-16"
    DISTRIBUTE_STRATEGY = 'single'

    # ====================================================
    # # DEFINE ARGS
    # ====================================================
    # The per-GPU batch size comes from the `batch_size` parameter; no
    # model-architecture flags (e.g. layer sizes) are passed from here.
    LEARNING_RATE = 0.001

    OUTPUT_BUCKET = output_bucket

    EXPERIMENT_RUN = f'{experiment_run}-{TIMESTAMP}'

    WORKER_CMD = [
        'sh',
        '-euc',
        f'''pip freeze && python -m trainer.train_task --tb_name={tb_resource} --per_gpu_batch_size={batch_size} \
        --train_output_bucket={OUTPUT_BUCKET} --train_dir={train_dir} --valid_dir={valid_dir} --workflow_dir={workflow_dir} \
        --num_epochs={train_epochs} --learning_rate={LEARNING_RATE} --distribute={DISTRIBUTE_STRATEGY} \
        --experiment_name={experiment_name} --experiment_run={EXPERIMENT_RUN} --project={project} --location={location}'''
    ]

    WORKER_POOL_SPECS = prepare_worker_pool_specs(
        image_uri=train_image_uri,
        cmd=WORKER_CMD,
        replica_count=REPLICA_COUNT,
        machine_type=WORKER_MACHINE_TYPE,
        accelerator_count=PER_MACHINE_ACCELERATOR_COUNT,
        accelerator_type=ACCELERATOR_TYPE,
        reduction_server_count=REDUCTION_SERVER_COUNT,
        reduction_server_machine_type=REDUCTION_SERVER_MACHINE_TYPE,
    )
    # ==============================================================================
    # Submit Train Job
    # ==============================================================================
    STAGING_BUCKET = f'gs://{OUTPUT_BUCKET}/{experiment_name}'
    JOB_NAME = f'train-merlin-retrieval-{version}'
    gpu_type = ACCELERATOR_TYPE.lower()  # lowercase for labels

    job = vertex_ai.CustomJob(
        display_name=JOB_NAME,
        worker_pool_specs=WORKER_POOL_SPECS,
        staging_bucket=STAGING_BUCKET,
        labels={
            'gpu': f'{gpu_type}',
            'gpu_per_replica' : f'{PER_MACHINE_ACCELERATOR_COUNT}',
            'replica_cnt' : f'{REPLICA_COUNT}',
        }
    )

    # blocks until the training job finishes
    job.run(
        sync=True,
        service_account=service_account,
        # tensorboard=EXPERIMENT_TB,
        restart_job_on_worker_restart=False,
        enable_web_access=True,
    )

    # uris set during train script
    WORKING_DIR_GCS_URI = f'gs://{OUTPUT_BUCKET}/{experiment_name}/{EXPERIMENT_RUN}'
    MODEL_DIR = f"{WORKING_DIR_GCS_URI}/model-dir"
    QUERY_TOWER_PATH = f"{MODEL_DIR}/query-tower"
    CANDIDATE_TOWER_PATH = f"{MODEL_DIR}/candidate-tower"
    EMBEDDINGS_PATH = f"{MODEL_DIR}/candidate-embeddings"

    logging.info(f'WORKING_DIR_GCS_URI: {WORKING_DIR_GCS_URI}')
    logging.info(f'MODEL_DIR: {MODEL_DIR}')
    logging.info(f'QUERY_TOWER_PATH: {QUERY_TOWER_PATH}')
    logging.info(f'CANDIDATE_TOWER_PATH: {CANDIDATE_TOWER_PATH}')
    logging.info(f'EMBEDDINGS_PATH: {EMBEDDINGS_PATH}')

    return (
        f'{MODEL_DIR}',
        f'{QUERY_TOWER_PATH}',
        f'{CANDIDATE_TOWER_PATH}',
        f'{EMBEDDINGS_PATH}',
    )

Writing src/train_pipes/train_merlin.py


## Create ANN Index

In [11]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_ann_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.18.1',
        # 'google-cloud-storage',
    ],
)
def create_ann_index(
    project: str,
    location: str,
    version: str,
    vpc_network_name: str,
    emb_index_gcs_uri: str,
    dimensions: int,
    ann_index_display_name: str,
    approximate_neighbors_count: int,
    distance_measure_type: str,
    leaf_node_embedding_count: int,
    leaf_nodes_to_search_percent: int,
    ann_index_description: str,
    # ann_index_labels: Dict,
) -> NamedTuple('Outputs', [
    ('ann_index_resource_uri', str),
    ('ann_index', Artifact),
]):
    """Create a Tree-AH (approximate nearest neighbour) Matching Engine index.

    Args:
        project: Project ID used for ``vertex_ai.init``.
        location: Region used for ``vertex_ai.init``.
        version: Not used by the body.
        vpc_network_name: Not used by the body (index creation takes no network).
        emb_index_gcs_uri: GCS directory with the embeddings; passed as
            ``contents_delta_uri`` with exactly one trailing slash.
        dimensions: Embedding dimensionality.
        ann_index_display_name: Display-name prefix; a ``-YYYYmmddHHMMSS`` timestamp is appended.
        approximate_neighbors_count: Forwarded to ``create_tree_ah_index``.
        distance_measure_type: Forwarded to ``create_tree_ah_index``.
        leaf_node_embedding_count: Forwarded to ``create_tree_ah_index``.
        leaf_nodes_to_search_percent: Forwarded to ``create_tree_ah_index``.
        ann_index_description: Index description.

    Returns:
        Outputs(index resource name, the ``MatchingEngineIndex`` object).
    """
    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )

    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")

    ENDPOINT = "{}-aiplatform.googleapis.com".format(location)
    INDEX_DIR_GCS = emb_index_gcs_uri
    # Exactly one trailing slash, whether or not the caller supplied one.
    contents_delta_uri = emb_index_gcs_uri.rstrip('/') + '/'

    logging.info(f"ENDPOINT: {ENDPOINT}")
    logging.info(f"project: {project}")
    logging.info(f"location: {location}")
    logging.info(f"INDEX_DIR_GCS: {INDEX_DIR_GCS}")

    # ==============================================================================
    # Create Index
    # ==============================================================================

    start = time.time()

    tree_ah_index = vertex_ai.MatchingEngineIndex.create_tree_ah_index(
        display_name=f'{ann_index_display_name}-{TIMESTAMP}',
        contents_delta_uri=contents_delta_uri,
        dimensions=dimensions,
        approximate_neighbors_count=approximate_neighbors_count,
        distance_measure_type=distance_measure_type,
        leaf_node_embedding_count=leaf_node_embedding_count,
        leaf_nodes_to_search_percent=leaf_nodes_to_search_percent,
        description=ann_index_description,
        # labels=ann_index_labels,
    )

    elapsed_time = round((time.time() - start), 2)
    logging.info(f'Elapsed time creating index: {elapsed_time} seconds\n')

    ann_index_resource_uri = tree_ah_index.resource_name
    # %-style argument: an extra positional argument without a placeholder makes
    # logging report a formatting error instead of the message.
    logging.info("ann_index_resource_uri: %s", ann_index_resource_uri)

    return (
      f'{ann_index_resource_uri}',
      tree_ah_index,
    )

Writing src/train_pipes/create_ann_index.py


## Create Brute-Force Index

In [12]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_brute_force_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.18.1',
        # 'google-cloud-storage',
    ],
)
def create_brute_force_index(
    project: str,
    location: str,
    version: str,
    vpc_network_name: str,
    emb_index_gcs_uri: str,
    dimensions: int,
    brute_force_index_display_name: str,
    approximate_neighbors_count: int,
    distance_measure_type: str,
    brute_force_index_description: str,
    # brute_force_index_labels: Dict,
) -> NamedTuple('Outputs', [
    ('brute_force_index_resource_uri', str),
    ('brute_force_index', Artifact),
]):
    """Create a brute-force (exact search) Matching Engine index.

    Args:
        project: Project ID used for ``vertex_ai.init``.
        location: Region used for ``vertex_ai.init``.
        version: Not used by the body.
        vpc_network_name: Not used by the body (index creation takes no network).
        emb_index_gcs_uri: GCS directory with the embeddings; passed as
            ``contents_delta_uri`` with exactly one trailing slash.
        dimensions: Embedding dimensionality.
        brute_force_index_display_name: Display-name prefix; a ``-YYYYmmddHHMMSS``
            timestamp is appended.
        approximate_neighbors_count: Not forwarded to ``create_brute_force_index``
            (presumably kept so the signature mirrors the ANN component).
        distance_measure_type: Forwarded to ``create_brute_force_index``.
        brute_force_index_description: Index description.

    Returns:
        Outputs(index resource name, the ``MatchingEngineIndex`` object).
    """
    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )

    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")

    ENDPOINT = "{}-aiplatform.googleapis.com".format(location)
    # Exactly one trailing slash, whether or not the caller supplied one.
    contents_delta_uri = emb_index_gcs_uri.rstrip('/') + '/'

    logging.info("ENDPOINT: {}".format(ENDPOINT))
    logging.info("PROJECT_ID: {}".format(project))
    logging.info("REGION: {}".format(location))

    # ==============================================================================
    # Create Index
    # ==============================================================================

    start = time.time()

    brute_force_index = vertex_ai.MatchingEngineIndex.create_brute_force_index(
        display_name=f'{brute_force_index_display_name}-{TIMESTAMP}',
        contents_delta_uri=contents_delta_uri,
        dimensions=dimensions,
        distance_measure_type=distance_measure_type,
        description=brute_force_index_description,
        # labels=brute_force_index_labels,
    )

    # Same timing log as the ANN index component (`start` was previously unused).
    elapsed_time = round((time.time() - start), 2)
    logging.info(f'Elapsed time creating index: {elapsed_time} seconds\n')

    brute_force_index_resource_uri = brute_force_index.resource_name
    print("brute_force_index_resource_uri:", brute_force_index_resource_uri)

    return (
      f'{brute_force_index_resource_uri}',
      brute_force_index,
    )

Writing src/train_pipes/create_brute_force_index.py


## Create ANN index endpoint

In [13]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_ann_index_endpoint_vpc.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.18.1',
        # 'google-cloud-storage',
    ],
)
def create_ann_index_endpoint_vpc(
    ann_index_artifact: Input[Artifact],
    project: str,
    project_number: str,
    location: str,
    version: str,
    vpc_network_name: str,
    ann_index_endpoint_display_name: str,
    ann_index_endpoint_description: str,
) -> NamedTuple('Outputs', [
    ('vpc_network_resource_uri', str),
    ('ann_index_endpoint_resource_uri', str),
    ('ann_index_endpoint', Artifact),
    ('ann_index_endpoint_display_name', str),
]):
    """Create a Matching Engine index endpoint attached to a VPC network (ANN flavour).

    ``ann_index_artifact`` and ``version`` are not read by the body; the artifact
    input presumably only orders this step after index creation (TODO confirm).

    Returns:
        Outputs(VPC network resource name, endpoint resource name,
        the ``MatchingEngineIndexEndpoint`` object, the display name).
    """
    import logging
    from google.cloud import aiplatform as vertex_ai

    vertex_ai.init(project=project, location=location)

    # The network is addressed by project *number*, not project ID.
    network_uri = f'projects/{project_number}/global/networks/{vpc_network_name}'
    logging.info(f"vpc_network_resource_uri: {network_uri}")

    endpoint = vertex_ai.MatchingEngineIndexEndpoint.create(
        display_name=f'{ann_index_endpoint_display_name}',
        description=ann_index_endpoint_description,
        network=network_uri,
    )
    endpoint_uri = endpoint.resource_name
    logging.info(f"ann_index_endpoint_resource_uri: {endpoint_uri}")

    return (
        network_uri,
        f'{endpoint_uri}',
        endpoint,
        f'{ann_index_endpoint_display_name}',
    )

Writing src/train_pipes/create_ann_index_endpoint_vpc.py


## Create Brute-Force Index Endpoint

In [14]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_brute_index_endpoint_vpc.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.18.1',
        # 'google-cloud-storage',
    ],
)
def create_brute_index_endpoint_vpc(
    bf_index_artifact: Input[Artifact],
    project: str,
    project_number: str,
    location: str,
    version: str,
    vpc_network_name: str,
    brute_index_endpoint_display_name: str,
    brute_index_endpoint_description: str,
) -> NamedTuple('Outputs', [
    ('vpc_network_resource_uri', str),
    ('brute_index_endpoint_resource_uri', str),
    ('brute_index_endpoint', Artifact),
    ('brute_index_endpoint_display_name', str),
]):
    """Create a Matching Engine index endpoint attached to a VPC network (brute-force flavour).

    ``bf_index_artifact`` and ``version`` are not read by the body; the artifact
    input presumably only orders this step after index creation (TODO confirm).

    Returns:
        Outputs(VPC network resource name, endpoint resource name,
        the ``MatchingEngineIndexEndpoint`` object, the display name).
    """
    import logging
    from google.cloud import aiplatform as vertex_ai

    vertex_ai.init(project=project, location=location)

    # The network is addressed by project *number*, not project ID.
    network_uri = f'projects/{project_number}/global/networks/{vpc_network_name}'
    logging.info(f"vpc_network_resource_uri: {network_uri}")

    endpoint = vertex_ai.MatchingEngineIndexEndpoint.create(
        display_name=f'{brute_index_endpoint_display_name}',
        description=brute_index_endpoint_description,
        network=network_uri,
    )
    endpoint_uri = endpoint.resource_name
    logging.info(f"brute_index_endpoint_resource_uri: {endpoint_uri}")

    return (
        network_uri,
        f'{endpoint_uri}',
        endpoint,
        f'{brute_index_endpoint_display_name}',
    )

Writing src/train_pipes/create_brute_index_endpoint_vpc.py


## Deploy ANN Index

In [15]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/deploy_ann_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.18.1'
    ]
)
def deploy_ann_index(
    project: str,
    location: str,
    version: str,
    deployed_ann_index_name: str,
    ann_index_resource_uri: str,
    index_endpoint_resource_uri: str,
) -> NamedTuple('Outputs', [
    ('index_endpoint_resource_uri', str),
    ('ann_index_resource_uri', str),
    ('deployed_ann_index_name', str),
    ('deployed_ann_index', Artifact),
]):
    """Deploy an existing ANN index to an existing Matching Engine index endpoint.

    Args:
        project: Project ID used for ``vertex_ai.init``.
        location: Region used for ``vertex_ai.init``.
        version: Not used by the body.
        deployed_ann_index_name: ``deployed_index_id`` for the deployment.
        ann_index_resource_uri: Resource name of the index to deploy.
        index_endpoint_resource_uri: Resource name of the target endpoint.

    Returns:
        Outputs(endpoint resource name, index resource name, deployed index id,
        the ``MatchingEngineIndex`` object).
    """
    import logging
    from google.cloud import aiplatform as vertex_ai

    vertex_ai.init(project=project, location=location)

    index = vertex_ai.MatchingEngineIndex(index_name=ann_index_resource_uri)
    index_uri = index.resource_name

    endpoint = vertex_ai.MatchingEngineIndexEndpoint(index_endpoint_resource_uri)

    deployed_id = f'{deployed_ann_index_name}'
    # deploy_index returns the endpoint object, which is used for the log line below.
    endpoint = endpoint.deploy_index(index=index, deployed_index_id=deployed_id)
    logging.info(f"index_endpoint.deployed_indexes: {endpoint.deployed_indexes}")

    return (
        f'{index_endpoint_resource_uri}',
        f'{index_uri}',
        deployed_id,
        index,
    )

Writing src/train_pipes/deploy_ann_index.py


## Deploy Brute-Force Index

In [16]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/deploy_brute_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.18.1',
        # 'google-cloud-storage',
    ],
)
def deploy_brute_index(
    project: str,
    location: str,
    version: str,
    deployed_brute_force_index_name: str,
    brute_force_index_resource_uri: str,
    index_endpoint_resource_uri: str,
) -> NamedTuple('Outputs', [
    ('index_endpoint_resource_uri', str),
    ('brute_force_index_resource_uri', str),
    ('deployed_brute_force_index_name', str),
    ('deployed_brute_force_index', Artifact),
]):
    """Deploy an existing brute-force index to an existing Matching Engine index endpoint.

    Args:
        project: Project ID used for ``vertex_ai.init``.
        location: Region used for ``vertex_ai.init``.
        version: Not used by the body.
        deployed_brute_force_index_name: ``deployed_index_id`` for the deployment.
        brute_force_index_resource_uri: Resource name of the index to deploy.
        index_endpoint_resource_uri: Resource name of the target endpoint.

    Returns:
        Outputs(endpoint resource name, index resource name, deployed index id,
        the ``MatchingEngineIndex`` object).
    """
    import logging
    from google.cloud import aiplatform as vertex_ai

    vertex_ai.init(project=project, location=location)

    index = vertex_ai.MatchingEngineIndex(index_name=brute_force_index_resource_uri)
    index_uri = index.resource_name

    endpoint = vertex_ai.MatchingEngineIndexEndpoint(index_endpoint_resource_uri)

    deployed_id = f'{deployed_brute_force_index_name}'
    # deploy_index returns the endpoint object, which is used for the log line below.
    endpoint = endpoint.deploy_index(index=index, deployed_index_id=deployed_id)
    logging.info(f"index_endpoint.deployed_indexes: {endpoint.deployed_indexes}")

    return (
        f'{index_endpoint_resource_uri}',
        f'{index_uri}',
        deployed_id,
        index,
    )

Writing src/train_pipes/deploy_brute_index.py


# Build & Compile Pipeline

### Pipeline configuration

In [17]:
# Version label of the pipeline code, embedded in the pipeline tag.
PIPELINE_VERSION = 'v9'
PIPELINE_TAG = 'merlin-train-deploy--' + PIPELINE_VERSION
print("PIPELINE_TAG:", PIPELINE_TAG)

PIPELINE_TAG: merlin-train-deploy--v9


## Build pipeline

In [32]:
from src.train_pipes import train_merlin, build_custom_image, \
                            create_ann_index, create_brute_force_index, create_ann_index_endpoint_vpc, \
                            create_brute_index_endpoint_vpc, deploy_ann_index, deploy_brute_index

@kfp.v2.dsl.pipeline(
    name=f'{PIPELINE_VERSION}-{PIPELINE_TAG}'.replace('_', '-')
)
def pipeline(
    project: str,
    project_number: str,
    location: str,
    version: str,
    vpc_network_name: str,
    pipe_gcs_path: str,
    training_image_uri: str,
    serving_image_uri: str,
    train_docker_name: str,
    serving_docker_name: str,
    tb_resource: str,
    batch_size: int,
    train_epochs: int,
    train_dir: str,
    valid_dir: str,
    workflow_dir: str,
    experiment_name: str,
    experiment_run: str,
    service_account: str,
):
    """Train the Merlin two-tower model and deploy it on Vertex AI.

    Steps: build the training image, train the towers, build the serving image,
    register the query tower and deploy it to a GPU endpoint, then create and
    deploy ANN and brute-force Matching Engine indexes over the candidate
    embeddings produced by training.

    Args:
        project: GCP project ID.
        project_number: GCP project number (passed to the index-endpoint steps).
        location: Vertex AI region.
        version: Suffix appended to the display names of created resources.
        vpc_network_name: VPC network name used by the index and index-endpoint steps.
        pipe_gcs_path: GCS directory handed to the image-build steps as `artifact_gcs_path`.
        training_image_uri: Target URI for the training image build.
        serving_image_uri: Target URI for the serving image build.
        train_docker_name: Dockerfile name for the training image.
        serving_docker_name: Dockerfile name for the serving image.
        tb_resource: Vertex AI TensorBoard resource name used by training.
        batch_size: Training batch size.
        train_epochs: Number of training epochs.
        train_dir: GCS directory of the training data.
        valid_dir: GCS directory of the validation data.
        workflow_dir: GCS directory of the analyzed workflow; also exposed to the
            serving container as the WORKFLOW_URI environment variable.
        experiment_name: Vertex AI Experiments experiment name.
        experiment_run: Vertex AI Experiments run name.
        service_account: Service account passed to the training step.
    """
    from kfp.v2.components import importer_node
    from google_cloud_pipeline_components.types import artifact_types
    # ========================================================================
    # TODO: data processing steps
    # ========================================================================

    # ========================================================================
    # Build TRAIN Image
    # ========================================================================
    build_custom_train_image_op = (
        build_custom_image.build_custom_image(
            project=project,
            artifact_gcs_path=f'{pipe_gcs_path}',
            app_dir_name='trainer',
            docker_name=train_docker_name,
            custom_image_uri=training_image_uri,
        )
        .set_display_name("Build Train Image")
        .set_caching_options(True)
    )

    # ========================================================================
    # Train Merlin Towers
    # ========================================================================
    train_merlin_op = (
        train_merlin.train_merlin(
            project=project,
            location=location,
            version=version,
            train_image_uri=build_custom_train_image_op.outputs['custom_image_uri'],
            tb_resource=tb_resource,
            batch_size=batch_size,
            train_epochs=train_epochs,
            train_dir=train_dir,
            valid_dir=valid_dir,
            workflow_dir=workflow_dir,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            service_account=service_account,
        )
        .set_display_name("Train Merlin Towers")
        .set_caching_options(True)
    )

    # ========================================================================
    # Build SERVING Image (independent of the train image build)
    # ========================================================================
    build_custom_serving_image_op = (
        build_custom_image.build_custom_image(
            project=project,
            artifact_gcs_path=f'{pipe_gcs_path}',
            app_dir_name='serving',
            docker_name=serving_docker_name,
            custom_image_uri=serving_image_uri,
        )
        .set_display_name("Build Serving Image")
        .set_caching_options(True)
    )

    # ========================================================================
    # Import Trained Towers to Pipeline DAG
    # ========================================================================
    # Every serving setting must live INSIDE 'containerSpec' (it maps onto the
    # Vertex AI ModelContainerSpec). Keys placed next to 'containerSpec' are not
    # part of the container spec, so routes/env would be silently dropped.
    # NOTE: '--gpus all' is a `docker run` flag, not a container argument; the
    # GPU is requested through dedicated_resources_accelerator_* at deploy time.
    import_query_model_task = (
        importer_node.importer(
            artifact_uri=train_merlin_op.outputs['merlin_model_gcs_dir'],
            artifact_class=artifact_types.UnmanagedContainerModel,  # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/ModelContainerSpec
            metadata={
                'containerSpec': {
                    'imageUri': build_custom_serving_image_op.outputs["custom_image_uri"],
                    'command': ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port $AIP_HTTP_PORT"],
                    'healthRoute': '/health',
                    'predictRoute': '/predict',
                    'env': [{'name': 'WORKFLOW_URI', 'value': workflow_dir}],
                },
            },
        )
        .set_display_name("Import Query Tower")
        .set_caching_options(True)
    )

    # NOTE: this artifact is not consumed by any downstream step yet.
    import_candidate_model_task = (
        importer_node.importer(
            artifact_uri=train_merlin_op.outputs['candidate_tower_gcs_uri'],
            artifact_class=artifact_types.UnmanagedContainerModel,
        )
        .set_display_name("Import Candidate Tower")
        .set_caching_options(True)
    )

    # ========================================================================
    # Upload Models to Vertex AI Model Registry
    # ========================================================================
    query_model_upload_op = (
        gcc_aip.ModelUploadOp(
            project=project,
            location=location,
            unmanaged_container_model=import_query_model_task.outputs["artifact"],
            display_name=f'merlin-query-tower-{version}',
        )
        .set_display_name("Register Query Tower")
        .set_caching_options(True)
    )

    # ========================================================================
    # Deploy Model to Endpoint
    # ========================================================================
    # Pass `location` explicitly so the endpoint lands in the same region as
    # the uploaded model instead of the component's default region.
    endpoint_create_op = (
        gcc_aip.EndpointCreateOp(
            project=project,
            location=location,
            display_name=f'query-tower-endpoint-{version}',
        )
        .after(query_model_upload_op)
        .set_display_name("Create Query Endpoint")
        .set_caching_options(True)
    )

    model_deploy_op = (
        gcc_aip.ModelDeployOp(
            endpoint=endpoint_create_op.outputs['endpoint'],
            model=query_model_upload_op.outputs['model'],
            deployed_model_display_name=f'deployed-query-tower-{version}',
            dedicated_resources_machine_type="n1-standard-4",
            dedicated_resources_accelerator_type="NVIDIA_TESLA_T4",
            dedicated_resources_accelerator_count=1,
            dedicated_resources_max_replica_count=1,
            dedicated_resources_min_replica_count=1,
            traffic_split={"0": 100},  # all traffic to the model deployed by this op
        )
        .set_display_name("Deploy Query Tower")
        .set_caching_options(True)
    )

    # ========================================================================
    # Create ME indexes
    # ========================================================================
    create_ann_index_op = (
        create_ann_index.create_ann_index(
            project=project,
            location=location,
            version=version,
            vpc_network_name=vpc_network_name,
            emb_index_gcs_uri=train_merlin_op.outputs['candidate_embeddings_gcs_uri'],
            dimensions=128,
            ann_index_display_name=f'ann_index_pipeline_test_{version}',
            approximate_neighbors_count=50,
            distance_measure_type="DOT_PRODUCT_DISTANCE",
            leaf_node_embedding_count=500,
            leaf_nodes_to_search_percent=7,
            ann_index_description="testing ann index for Merlin deployment",
            # ann_index_labels=ann_index_labels,
        )
        .set_display_name("Create ANN Index")
        .set_caching_options(True)
    )

    create_brute_force_index_op = (
        create_brute_force_index.create_brute_force_index(
            project=project,
            location=location,
            version=version,
            vpc_network_name=vpc_network_name,
            emb_index_gcs_uri=train_merlin_op.outputs['candidate_embeddings_gcs_uri'],
            dimensions=128,
            brute_force_index_display_name=f'bf_index_pipeline_test_{version}',
            approximate_neighbors_count=50,
            distance_measure_type="DOT_PRODUCT_DISTANCE",
            brute_force_index_description="testing bf index for Merlin deployment",
            # brute_force_index_labels=brute_force_index_labels,
        )
        .set_display_name("Create BF Index")
        .set_caching_options(True)
    )

    # ========================================================================
    # Create ME index endpoints
    # ========================================================================
    create_ann_index_endpoint_vpc_op = (
        create_ann_index_endpoint_vpc.create_ann_index_endpoint_vpc(
            ann_index_artifact=create_ann_index_op.outputs['ann_index'],
            project=project,
            project_number=project_number,
            version=version,
            location=location,
            vpc_network_name=vpc_network_name,
            ann_index_endpoint_display_name=f'ann-index-endpoint-{version}',
            ann_index_endpoint_description='endpoint for ann index',
        )
        .set_display_name("Create ANN Index Endpoint")
    )

    create_brute_index_endpoint_vpc_op = (
        create_brute_index_endpoint_vpc.create_brute_index_endpoint_vpc(
            bf_index_artifact=create_brute_force_index_op.outputs['brute_force_index'],
            project=project,
            project_number=project_number,
            version=version,
            location=location,
            vpc_network_name=vpc_network_name,
            brute_index_endpoint_display_name=f'bf-index-endpoint-{version}',
            brute_index_endpoint_description='endpoint for brute force index',
        )
        .set_display_name("Create BF Index Endpoint")
    )

    # ========================================================================
    # Deploy Indexes
    # ========================================================================
    deploy_ann_index_op = (
        deploy_ann_index.deploy_ann_index(
            project=project,
            location=location,
            version=version,
            deployed_ann_index_name=f'deployed_ann_index_{version}',
            ann_index_resource_uri=create_ann_index_op.outputs['ann_index_resource_uri'],
            index_endpoint_resource_uri=create_ann_index_endpoint_vpc_op.outputs['ann_index_endpoint_resource_uri'],
        )
        .set_display_name("Deploy ANN Index")
        .set_caching_options(True)
    )

    deploy_brute_index_op = (
        deploy_brute_index.deploy_brute_index(
            project=project,
            location=location,
            version=version,
            deployed_brute_force_index_name=f'deployed_bf_index_{version}',
            brute_force_index_resource_uri=create_brute_force_index_op.outputs['brute_force_index_resource_uri'],
            index_endpoint_resource_uri=create_brute_index_endpoint_vpc_op.outputs['brute_index_endpoint_resource_uri'],
        )
        .set_display_name("Deploy BF Index")
        .set_caching_options(True)
    )


## Compile Pipeline

In [34]:
BUCKET = 'jt-merlin-scaling'
BUCKET_URI = 'gs://' + BUCKET
PIPELINE_ROOT = 'pipelines_root'
# Per-version root for pipeline artifacts, the compiled spec and image-build sources.
PIPE_GCS_DIR = '/'.join([BUCKET_URI, PIPELINE_ROOT, PIPELINE_VERSION])
print(f"PIPE_GCS_DIR: {PIPE_GCS_DIR}")

# Dockerfile suffixes and local source sub-directories (under REPO_DOCKER_PATH_PREFIX)
TRAIN_DOCKERNAME = 'merlintf-22_09_v2'
SERVING_DOCKERNAME = 'merlin-retriever'
TRAIN_SUB_DIR = 'trainer'
SERVING_SUB_DIR = 'serving'

PIPE_GCS_DIR: gs://jt-merlin-scaling/pipelines_root/v9


In [35]:
# Compile the pipeline function into the JSON job spec that is uploaded and submitted below.
kfp.v2.compiler.Compiler().compile(
    pipeline_func=pipeline,
    package_path='custom_container_pipeline_spec.json',
)

In [36]:
!gsutil cp custom_container_pipeline_spec.json $PIPE_GCS_DIR/pipeline_spec.json

Copying file://custom_container_pipeline_spec.json [Content-Type=application/json]...
/ [1 files][ 87.9 KiB/ 87.9 KiB]                                                
Operation completed over 1 objects/87.9 KiB.                                     


In [37]:
!gsutil cp ./$REPO_DOCKER_PATH_PREFIX/Dockerfile.$SERVING_DOCKERNAME $PIPE_GCS_DIR/
!gsutil cp ./$REPO_DOCKER_PATH_PREFIX/Dockerfile.$TRAIN_DOCKERNAME $PIPE_GCS_DIR/

!gsutil -m cp -r ./$REPO_DOCKER_PATH_PREFIX/$SERVING_SUB_DIR $PIPE_GCS_DIR/
!gsutil -m cp -r ./$REPO_DOCKER_PATH_PREFIX/$TRAIN_SUB_DIR $PIPE_GCS_DIR/

Copying file://./src/Dockerfile.merlin-retriever [Content-Type=application/octet-stream]...
/ [1 files][  270.0 B/  270.0 B]                                                
Operation completed over 1 objects/270.0 B.                                      
Copying file://./src/Dockerfile.merlintf-22_09_v2 [Content-Type=application/octet-stream]...
/ [1 files][  353.0 B/  353.0 B]                                                
Operation completed over 1 objects/353.0 B.                                      
Copying file://./src/serving/requirements.txt [Content-Type=text/plain]...
Copying file://./src/serving/instances.json [Content-Type=application/json]...  
Copying file://./src/serving/app/predictor.py [Content-Type=text/x-python]...   
Copying file://./src/serving/app/__init__.py [Content-Type=text/x-python]...
Copying file://./src/serving/app/prestart.sh [Content-Type=text/x-sh]...        
Copying file://./src/serving/app/main.py [Content-Type=text/x-python]...        
Copying file:

In [38]:
# Optional: save and load the registry of submitted pipeline runs so it survives
# kernel restarts. PIPELINES maps a key (e.g. 'train') to the run name returned at
# submission time.
PIPELINES_FILEPATH = 'pipelines.json'  # previously referenced but never defined (NameError)
PIPELINES = {}


def save_pipelines(filepath=None):
    """Write the PIPELINES registry to a JSON file.

    Args:
        filepath: Destination path. Defaults to PIPELINES_FILEPATH (read at call time).
    """
    with open(filepath or PIPELINES_FILEPATH, 'w') as f:
        json.dump(PIPELINES, f)


def load_pipelines(filepath=None):
    """Merge a previously saved registry into PIPELINES.

    A missing file is not an error, so this is safe to call on a fresh checkout.

    Args:
        filepath: Source path. Defaults to PIPELINES_FILEPATH (read at call time).

    Returns:
        The (updated) PIPELINES dict.
    """
    path = filepath or PIPELINES_FILEPATH
    if os.path.exists(path):
        with open(path) as f:
            PIPELINES.update(json.load(f))
    return PIPELINES

## Submit Pipeline Job

### Pipeline args

In [39]:
EXPERIMENT_PREFIX = 'pipes'
MODEL_NAME = '2tower'
FRAMEWORK = 'merlin-tf'

# Vertex AI Experiments naming: <prefix>-<model>-<framework>-<pipeline version>
EXPERIMENT_NAME = '-'.join([EXPERIMENT_PREFIX, MODEL_NAME, FRAMEWORK, PIPELINE_VERSION])
EXPERIMENT_RUN = 'run-v1'

print(f"EXPERIMENT_NAME: {EXPERIMENT_NAME}")

EXPERIMENT_NAME: pipes-2tower-merlin-tf-v9


In [40]:

VPC_NETWORK_NAME = 'ucaip-haystack-vpc-network'

# TRAIN JOB CONFIG
# Build the TensorBoard resource name from the project config instead of
# hard-coding the project number and region.
TENSORBOARD_ID = '70659015247396864'
TENSORBOARD_RESOURCE = f'projects/{PROJECT_NUM}/locations/{LOCATION}/tensorboards/{TENSORBOARD_ID}'
BATCH_SIZE = 4096 * 4
EPOCHS = 100

# IMAGES (both in the current project's registry, not a hard-coded project ID)
TRAIN_IMAGE_URI = f'gcr.io/{PROJECT_ID}/merlin-tf-2tower-training-jtv1-22_09_v2'
SERVING_IMAGE_URI = f'gcr.io/{PROJECT_ID}/merlin-vertex-serv-v11'

# DOCKERFILES
TRAIN_DOCKERFILE_NAME = f'Dockerfile.{TRAIN_DOCKERNAME}'
SERVE_DOCKERFILE_NAME = f'Dockerfile.{SERVING_DOCKERNAME}'

# data and schema from nvtabular pipes (kept under the same bucket as the pipeline root)
DATA_DIR = f'{BUCKET_URI}/nvt-last5-v1full/nvt-processed'
TRAIN_DIR = f'{DATA_DIR}/train'
VALID_DIR = f'{DATA_DIR}/valid'
WORKFLOW_DIR = f'{BUCKET_URI}/nvt-last5-v1full/nvt-analyzed'

print(f"TRAIN_DOCKERFILE_NAME: {TRAIN_DOCKERFILE_NAME}")
print(f"SERVE_DOCKERFILE_NAME: {SERVE_DOCKERFILE_NAME}\n")
print(f"EXPERIMENT_NAME: {EXPERIMENT_NAME}")
print(f"EXPERIMENT_RUN: {EXPERIMENT_RUN}")

TRAIN_DOCKERFILE_NAME: Dockerfile.merlintf-22_09_v2
SERVE_DOCKERFILE_NAME: Dockerfile.merlin-retriever

EXPERIMENT_NAME: pipes-2tower-merlin-tf-v9
EXPERIMENT_RUN: run-v1


In [41]:
overwrite = True
# overwrite = False

from kfp.v2.google.client import AIPlatformClient

# Client used to submit the compiled job spec to Vertex AI Pipelines.
pipeline_client = AIPlatformClient(project_id=PROJECT_ID, region=LOCATION)

# Submit when forced, or when no 'train' run has been recorded yet.
if overwrite or not PIPELINES.get('train'):
    response = pipeline_client.create_run_from_job_spec(
        job_spec_path='custom_container_pipeline_spec.json',
        pipeline_root=PIPE_GCS_DIR,
        service_account=VERTEX_SA,
        # Run the job on the same VPC network the index endpoints use.
        network='projects/{}/global/networks/{}'.format(PROJECT_NUM, VPC_NETWORK_NAME),
        parameter_values=dict(
            project=PROJECT_ID,
            project_number=PROJECT_NUM,
            location=LOCATION,
            vpc_network_name=VPC_NETWORK_NAME,
            version=PIPELINE_VERSION,
            pipe_gcs_path=PIPE_GCS_DIR,
            training_image_uri=TRAIN_IMAGE_URI,
            serving_image_uri=SERVING_IMAGE_URI,
            train_docker_name=TRAIN_DOCKERFILE_NAME,
            serving_docker_name=SERVE_DOCKERFILE_NAME,
            tb_resource=TENSORBOARD_RESOURCE,
            batch_size=BATCH_SIZE,
            train_epochs=EPOCHS,
            train_dir=TRAIN_DIR,
            valid_dir=VALID_DIR,
            workflow_dir=WORKFLOW_DIR,
            experiment_name=EXPERIMENT_NAME,
            experiment_run=EXPERIMENT_RUN,
            service_account=VERTEX_SA,
        ),
    )
    PIPELINES['train'] = response['name']

In [42]:
# Handle to an existing Vertex AI endpoint, built from the project config rather
# than a hard-coded project number and region. Presumably this is the query-tower
# endpoint created by a previous pipeline run; update the ID after redeploying.
ENDPOINT_ID = '8515115024753098752'
endpoint = vertex_ai.Endpoint(f'projects/{PROJECT_NUM}/locations/{LOCATION}/endpoints/{ENDPOINT_ID}')

In [43]:
# A single raw playlist instance sent to the query-tower endpoint.
# NOTE(review): 'Beyonc\xc3\xa9' is the UTF-8 byte sequence of 'é' decoded as
# Latin-1 (mojibake). Confirm that the training data encodes artist names the
# same way; otherwise this value should be 'Beyonc\u00e9'.
TEST_INSTANCE = dict(
    collaborative='false',
    album_name_pl=[
        "There's Really A Wolf",
        'Late Nights: The Album',
        'American Teen',
        'Crazy In Love',
        'Pony',
    ],
    artist_genres_pl=[
        "'hawaiian hip hop', 'rap'",
        "'chicago rap', 'dance pop', 'pop', 'pop rap', 'r&b', 'southern hip hop', 'trap', 'urban contemporary'",
        "'pop', 'pop r&b'",
        "'dance pop', 'pop', 'r&b'",
        "'chill r&b', 'pop', 'pop r&b', 'r&b', 'urban contemporary'",
    ],
    artist_name_pl=['Russ', 'Jeremih', 'Khalid', 'Beyonc\xc3\xa9', 'William Singe'],
    artist_pop_can=82.0,
    description_pl='',
    duration_ms_songs_pl=[237506.0, 217200.0, 219080.0, 226400.0, 121739.0],
    n_songs_pl=8.0,
    name='Lit Tunes ',
    num_albums_pl=8.0,
    num_artists_pl=8.0,
    track_name_pl=['Losin Control', 'Paradise', 'Location', 'Crazy In Love - Remix', 'Pony'],
    track_pop_pl=[79.0, 58.0, 83.0, 71.0, 57.0],
    duration_ms_seed_pl=51023.1,
    pid=1,
    track_uri_pl=[
        'spotify:track:4cxMGhkinTocPSVVKWIw0d',
        'spotify:track:1wNEBPo3nsbGCZRryI832I',
        'spotify:track:152lZdxL1OR0ZMW6KquMif',
        'spotify:track:2f4IuijXLxYOeBncS60GUD',
        'spotify:track:4Lj8paMFwyKTGfILLELVxt',
    ],
)

In [44]:
endpoint.predict(instances=[TEST_INSTANCE])  # one-instance batch; returns the query embedding

Prediction(predictions=[[0.4707373678684235, 0.8641103506088257, 0.0, 0.0, 1.018792748451233, 0.09464387595653534, 0.5575847029685974, 0.7692021131515503, 2.001033544540405, 3.422805786132812, 0.5328009128570557, 0.7689167261123657, 0.0, 1.141236186027527, 1.746711254119873, 0.6004542112350464, 0.4158501327037811, 0.0, 0.4438650012016296, 1.380507230758667, 0.1513499319553375, 0.183375671505928, 0.6494946479797363, 0.922292172908783, 0.0, 2.127224206924438, 0.0, 2.229432582855225, 0.3340595066547394, 0.4078029990196228, 0.1877781450748444, 1.94303035736084, 1.281924962997437, 0.0, 0.2793006598949432, 0.7752009034156799, 0.4282656610012054, 0.2450124770402908, 0.7832157015800476, 1.133267641067505, 0.9308550357818604, 0.0, 0.0, 0.0, 0.2122098505496979, 0.0, 0.953433632850647, 0.5627561807632446, 1.377547264099121, 0.0, 0.0, 0.0009004436433315277, 0.0, 0.4524123966693878, 2.027101755142212, 0.0, 0.0, 0.0, 4.159066200256348, 0.04837673902511597, 0.1159927695989609, 0.0, 0.0, 0.0, 0.0, 0.0