# Train and Deploy Merlin models with Vertex AI

In [1]:
# ! pip3 install --upgrade --user -q google-cloud-aiplatform
# ! pip3 install --upgrade --user -q google-cloud-storage
# ! pip3 install --upgrade --user -q kfp
# ! pip3 install --upgrade --user -q google-cloud-pipeline-components

In [2]:
# Print the installed versions of the KFP SDK, Google Cloud Pipeline Components and the
# Vertex AI SDK. NOTE(review): these run in a separate `python3` subprocess, which may not
# be the same interpreter/environment as this kernel.
! python3 -c "import kfp; print('KFP SDK version: {}'.format(kfp.__version__))"
! python3 -c "import google_cloud_pipeline_components; print('google_cloud_pipeline_components version: {}'.format(google_cloud_pipeline_components.__version__))"
! python3 -c "import google.cloud.aiplatform; print('aiplatform SDK version: {}'.format(google.cloud.aiplatform.__version__))"

KFP SDK version: 1.8.19
google_cloud_pipeline_components version: 1.0.40
aiplatform SDK version: 1.23.0


In [3]:
GCP_PROJECTS = !gcloud config get-value project
PROJECT_ID = GCP_PROJECTS[0]
PROJECT_NUM = !gcloud projects list --filter="$PROJECT_ID" --format="value(PROJECT_NUMBER)"
PROJECT_NUM = PROJECT_NUM[0]
LOCATION = 'us-central1'

! gcloud config set project $PROJECT_ID

print(f"PROJECT_ID: {PROJECT_ID}")
print(f"PROJECT_NUM: {PROJECT_NUM}")
print(f"LOCATION: {LOCATION}")

Updated property [core/project].
PROJECT_ID: hybrid-vertex
PROJECT_NUM: 934903580331
LOCATION: us-central1


In [4]:
# Notebook-wide imports plus GCS / Vertex AI client setup.
# Relies on PROJECT_ID and LOCATION defined in the previous cell.
import os
import json
from datetime import datetime
# NOTE(review): `import time` a few lines below rebinds `time` to the module,
# so this function import is shadowed and not usable as `time()`.
from time import time
import pandas as pd
# `logging.disable(logging.WARNING)` below disables DEBUG, INFO *and* WARNING
# records for every logger; only ERROR and CRITICAL are still emitted.
import logging
import time
from pprint import pprint

logging.disable(logging.WARNING)

from google.cloud import aiplatform as vertex_ai
from google.cloud import storage

# Pipelines
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from google_cloud_pipeline_components import aiplatform as gcc_aip
from google_cloud_pipeline_components.types import artifact_types

# Kubeflow SDK
# TODO: fix these (`dsl` / `kfp.v2.dsl` are imported more than once below)
from kfp.v2 import dsl
import kfp
import kfp.v2.dsl
from kfp.v2.google import client as pipelines_client
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component)

# GCS client for the notebook, and default project/region for the Vertex AI SDK.
storage_client = storage.Client(project=PROJECT_ID)

vertex_ai.init(project=PROJECT_ID,location=LOCATION)

In [5]:
# Local source root, and the sub-directory under it that receives the pipeline component files.
REPO_DOCKER_PATH_PREFIX, PIPELINES_SUB_DIR = 'src', 'train_pipes'

In [6]:
! rm -rf {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}
! mkdir {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}

# Pipeline Components

In [8]:
# os.getcwd()

## Build Custom Image

In [9]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/build_custom_image.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)
@kfp.v2.dsl.component(
    base_image="gcr.io/google.com/cloudsdktool/cloud-sdk:latest",
    packages_to_install=[
        "google-cloud-build"
    ],
)
def build_custom_image(
    project: str,
    artifact_gcs_path: str,
    docker_name: str,
    app_dir_name: str,
    custom_image_uri: str,
    use_existing: str,
) -> NamedTuple('Outputs', [
    ('custom_image_uri', str),
]):
    # TODO: make output Artifact for image_uri
    """Build and push a custom container image with Cloud Build, or reuse one.

    When ``use_existing`` is exactly the string ``'False'``, a Cloud Build job
    copies ``<artifact_gcs_path>/<app_dir_name>/`` and
    ``<artifact_gcs_path>/<docker_name>`` (saved as ``Dockerfile``) into the
    build workspace, runs ``docker build -t <custom_image_uri> .`` and pushes
    the image. Any other value skips the build and reuses ``custom_image_uri``.

    Args:
        project: Project ID the Cloud Build job is created in.
        artifact_gcs_path: GCS prefix holding the Dockerfile and the app dir.
        docker_name: Object name of the Dockerfile under ``artifact_gcs_path``.
        app_dir_name: Application code directory under ``artifact_gcs_path``.
        custom_image_uri: Image URI to build and push, or to reuse.
        use_existing: ``'False'`` to build; any other string reuses the image.

    Returns:
        Outputs with ``custom_image_uri`` passed through unchanged.

    Raises:
        RuntimeError: if the Cloud Build job finishes with a status other
            than SUCCESS.
    """
    import logging
    import os

    from google.cloud.devtools import cloudbuild_v1 as cloudbuild
    from google.protobuf.duration_pb2 import Duration

    # Raise the level before branching so the "use existing" path logs at INFO
    # too (previously only the build path set the level).
    logging.getLogger().setLevel(logging.INFO)

    if use_existing == 'False':
        build_client = cloudbuild.services.cloud_build.CloudBuildClient()

        # GCS locations of the Dockerfile (e.g. Dockerfile.XXXXX) and of the
        # application code directory (e.g. "trainer/").
        _gcs_dockerfile_path = os.path.join(artifact_gcs_path, docker_name)
        _gcs_script_dir_path = os.path.join(artifact_gcs_path, f"{app_dir_name}/")

        logging.info(f"_gcs_dockerfile_path: {_gcs_dockerfile_path}")
        logging.info(f"_gcs_script_dir_path: {_gcs_script_dir_path}")

        # Pull the app code and Dockerfile into the workspace, then build and
        # push with the plain docker builder. (A Kaniko cached build,
        # https://cloud.google.com/build/docs/kaniko-cache, was tried here
        # earlier and is disabled.)
        build = cloudbuild.Build()
        build.steps = [
            {
                "name": "gcr.io/cloud-builders/gsutil",
                "args": ["cp", "-r", _gcs_script_dir_path, "."],
            },
            {
                "name": "gcr.io/cloud-builders/gsutil",
                "args": ["cp", _gcs_dockerfile_path, "Dockerfile"],
            },
            {
                "name": "gcr.io/cloud-builders/docker",
                "args": ['build', '-t', custom_image_uri, '.'],
            },
            {
                "name": "gcr.io/cloud-builders/docker",
                "args": ['push', custom_image_uri],
            },
        ]
        # Override the default 10 minute build timeout with 2 hours.
        build.timeout = Duration(seconds=7200)

        operation = build_client.create_build(project_id=project, build=build)
        logging.info("IN PROGRESS:")
        logging.info(operation.metadata)

        # Block until the build completes.
        result = operation.result()
        # Format the status into the message. `logging.info("RESULT:", status)`
        # passes status as a %-format argument to a message with no placeholder,
        # which produces a logging error instead of the log line.
        logging.info(f"RESULT: {result.status}")
        if result.status != cloudbuild.Build.Status.SUCCESS:
            raise RuntimeError(
                f"Cloud Build for {custom_image_uri} finished with status {result.status}"
            )
    else:
        logging.info(f"Using existing (prebuilt) image: {custom_image_uri}")

    return (
        custom_image_uri,
    )

Writing src/train_pipes/build_custom_image.py


## Create Managed Tensorboard

In [10]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_tensorboard.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)
@kfp.v2.dsl.component(
    base_image='python:3.9',
    packages_to_install=[
        'google-cloud-aiplatform==1.23.0',
        'numpy',
        'google-cloud-storage',
    ],
    # output_component_file="./pipelines/train_custom_model.yaml",
)
def create_tensorboard(
    project: str,
    location: str,
    model_version: str,
    pipeline_version: str,
    # model_name: str,
    experiment_name: str,
    experiment_run: str,
) -> NamedTuple('Outputs', [
    ('tensorboard_resource_name', str),
    ('tensorboard_display_name', str),
]):
    """Create a new Vertex AI managed TensorBoard instance.

    A new instance with display name ``<experiment_name>-v1`` is created on
    every call. No lookup or reuse of an existing instance is attempted.

    Args:
        project: GCP project ID.
        location: Vertex AI region.
        model_version: Not referenced in the body.
        pipeline_version: Not referenced in the body.
        experiment_name: Used to derive the TensorBoard display name.
        experiment_run: Not referenced in the body.

    Returns:
        Outputs with the created TensorBoard's resource name and display name.
    """
    import logging
    from google.cloud import aiplatform as vertex_ai

    # Raise the root logger to INFO, as build_custom_image and the test
    # components do, so the logging.info calls below are not filtered out by
    # the default WARNING level.
    logging.getLogger().setLevel(logging.INFO)

    vertex_ai.init(
        project=project,
        location=location,
    )
    logging.info(f'experiment_name: {experiment_name}')

    TENSORBOARD_DISPLAY_NAME = f"{experiment_name}-v1"
    tensorboard = vertex_ai.Tensorboard.create(
        display_name=TENSORBOARD_DISPLAY_NAME,
        project=project,
        location=location,
    )
    TB_RESOURCE_NAME = tensorboard.resource_name

    logging.info(f'TENSORBOARD_DISPLAY_NAME: {TENSORBOARD_DISPLAY_NAME}')
    logging.info(f'TB_RESOURCE_NAME: {TB_RESOURCE_NAME}')

    return (
        f'{TB_RESOURCE_NAME}',
        f'{TENSORBOARD_DISPLAY_NAME}',
    )

Writing src/train_pipes/create_tensorboard.py


## Train Job

In [11]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/train_merlin.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.23.0',
        # 'google-cloud-storage',
    ],
)
def train_merlin(
    project: str,
    location: str,
    version: str,
    train_image_uri: str,     # TODO: Artifact
    train_output_gcs_bucket: str,
    tb_resource: str,
    batch_size: int,
    train_epochs: int,
    train_dir: str,
    valid_dir: str,
    workflow_dir: str,
    experiment_name: str,
    experiment_run: str,
    service_account: str,
    worker_pool_specs: dict,
) -> NamedTuple('Outputs', [
    ('merlin_model_gcs_dir', str),
    ('query_tower_gcs_dir', str),
    ('candidate_tower_gcs_uri', str),
    ('candidate_embeddings_gcs_uri', str),
    ('working_dir_gcs_path', str),
]):
    """Run Merlin training as a Vertex AI CustomJob and return its output paths.

    The job is defined by ``worker_pool_specs`` and writes under
    ``gs://<train_output_gcs_bucket>/<experiment_name>/<experiment_run>``.
    The component waits for the job to finish, so a failed job fails this step.

    Args:
        project: GCP project ID.
        location: Vertex AI region.
        version: Suffix for the CustomJob display name.
        tb_resource: Vertex TensorBoard resource name the job reports to.
        train_output_gcs_bucket: Bucket name (no ``gs://``) for job outputs.
        experiment_name: Vertex experiment; also the first output path level.
        experiment_run: Run name; the second output path level.
        service_account: Service account the job runs as.
        worker_pool_specs: Worker pool specs passed to ``CustomJob`` as-is.
        train_image_uri, batch_size, train_epochs, train_dir, valid_dir,
        workflow_dir: Not referenced in the body; presumably already baked
            into ``worker_pool_specs`` by the caller (confirm).

    Returns:
        Outputs with the model dir, query tower dir, candidate tower dir,
        candidate embeddings dir and working dir. These are paths the training
        script is expected to write to; this component does not verify them.
    """
    import logging
    from google.cloud import aiplatform as vertex_ai

    # Same INFO level as build_custom_image and the test components.
    logging.getLogger().setLevel(logging.INFO)

    BASE_OUTPUT_DIR = f'gs://{train_output_gcs_bucket}/{experiment_name}/{experiment_run}'
    STAGING_BUCKET = f'{BASE_OUTPUT_DIR}/staging'

    logging.info(f'BASE_OUTPUT_DIR: {BASE_OUTPUT_DIR}')
    logging.info(f'STAGING_BUCKET: {STAGING_BUCKET}')

    vertex_ai.init(
        project=project,
        location=location,
        experiment=experiment_name,
    )

    # TODO: parameterize
    # worker_pool_specs[0]['container_spec']['command'].append(f'--tb_name={tb_resource}')
    JOB_NAME = f'mm-2t-pipe-train-{version}'

    logging.info(f'tensorboard_resource_name: {tb_resource}')
    logging.info(f'service_account: {service_account}')
    logging.info(f'worker_pool_specs: {worker_pool_specs}')
    logging.info(f'JOB_NAME: {JOB_NAME}')

    job = vertex_ai.CustomJob(
        display_name=JOB_NAME,
        worker_pool_specs=worker_pool_specs,
        base_output_dir=BASE_OUTPUT_DIR,
        staging_bucket=STAGING_BUCKET,
    )

    # Run synchronously. With sync=False, run() returns immediately and the
    # job is driven from a background thread; an error there (e.g. a failed
    # training job) is never raised in this function. The step could then
    # succeed and hand model paths to downstream steps for a job that failed.
    job.run(
        sync=True,
        service_account=service_account,
        tensorboard=tb_resource,
        restart_job_on_worker_restart=False,
        enable_web_access=True,
    )

    # Locations written by the training script, relative to the same base dir.
    WORKING_DIR_GCS_URI = BASE_OUTPUT_DIR
    MODEL_DIR = f"{WORKING_DIR_GCS_URI}/model_dir"
    QUERY_TOWER_PATH = f"{MODEL_DIR}/query_tower"
    CANDIDATE_TOWER_PATH = f"{MODEL_DIR}/candidate_tower"
    EMBEDDINGS_PATH = f"{MODEL_DIR}/candidate_embeddings"

    logging.info(f'WORKING_DIR_GCS_URI: {WORKING_DIR_GCS_URI}')
    logging.info(f'MODEL_DIR: {MODEL_DIR}')
    logging.info(f'QUERY_TOWER_PATH: {QUERY_TOWER_PATH}')
    logging.info(f'CANDIDATE_TOWER_PATH: {CANDIDATE_TOWER_PATH}')
    logging.info(f'EMBEDDINGS_PATH: {EMBEDDINGS_PATH}')

    return (
        f'{MODEL_DIR}',
        f'{QUERY_TOWER_PATH}',
        f'{CANDIDATE_TOWER_PATH}',
        f'{EMBEDDINGS_PATH}',
        f'{WORKING_DIR_GCS_URI}',
    )

Writing src/train_pipes/train_merlin.py


## Custom Model Upload

In [12]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/upload_custom_model.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.23.0',
        # 'google-cloud-storage',
    ],
)
def upload_custom_model(
    project: str,
    location: str,
    version: str,
    display_name: str,
    artifact_uri: str,
    unmanaged_container_model: Input[Artifact],
    serving_container_image_uri: str,
) -> NamedTuple('Outputs', [
    ('model', Artifact),
    ('model_resource_name', str),
]):
    """Upload a model with a custom serving container to Vertex AI Model Registry.

    The container is started with ``sh -c "uvicorn main:app ..."`` on
    ``$AIP_HTTP_PORT``. It serves predictions on ``/predict`` and health checks
    on ``/health``.

    Args:
        project: GCP project ID.
        location: Vertex AI region.
        version: Not referenced in the body.
        display_name: Display name for the uploaded model.
        artifact_uri: GCS directory with the model artifacts.
        unmanaged_container_model: Only logged; not otherwise used.
        serving_container_image_uri: Serving container image.

    Returns:
        Outputs with the uploaded ``vertex_ai.Model`` and its resource name.
    """
    import logging
    from google.cloud import aiplatform as vertex_ai

    # Same INFO level as build_custom_image and the test components.
    logging.getLogger().setLevel(logging.INFO)

    vertex_ai.init(
        project=project,
        location=location,
    )
    logging.info(f" display_name: {display_name}")
    logging.info(f" artifact_uri: {artifact_uri}")
    logging.info(f" unmanaged_container_model: {unmanaged_container_model}")
    logging.info(f" serving_container_image_uri: {serving_container_image_uri}")

    logging.info("Uploading model to Vertex...")
    model = vertex_ai.Model.upload(
        display_name=display_name,
        artifact_uri=artifact_uri,
        serving_container_image_uri=serving_container_image_uri,
        serving_container_predict_route='/predict',
        serving_container_health_route='/health',
        serving_container_command=["sh", "-c", "uvicorn main:app --host 0.0.0.0 --port $AIP_HTTP_PORT"],
        # The SDK expects a sequence of strings here; a bare str is not a list
        # of arguments. NOTE(review): with the `sh -c` command above, these
        # args only become the shell's $0/$1 and never reach uvicorn. Confirm
        # they are needed at all.
        serving_container_args=['--gpus', 'all'],
        # parent_model=PARENT_MODEL,
        sync=True,
    )

    MODEL_RESOURCE_NAME = model.resource_name
    logging.info(f" MODEL_RESOURCE_NAME: {MODEL_RESOURCE_NAME}")

    return (
        model,
        f'{MODEL_RESOURCE_NAME}',
    )

Writing src/train_pipes/upload_custom_model.py


## Create ANN Index

In [13]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_ann_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.23.0',
        # 'google-api-core==2.10.0',
        # 'google-cloud-storage',
    ],
)
def create_ann_index(
    project: str,
    location: str,
    version: str,
    vpc_network_name: str,
    emb_index_gcs_uri: str,
    dimensions: int,
    ann_index_display_name: str,
    approximate_neighbors_count: int,
    distance_measure_type: str,
    leaf_node_embedding_count: int,
    leaf_nodes_to_search_percent: int,
    ann_index_description: str,
    # ann_index_labels: Dict,
) -> NamedTuple('Outputs', [
    ('ann_index_resource_uri', str),
    ('ann_index', Artifact),
]):
    """Create a Tree-AH (approximate nearest neighbour) Matching Engine index.

    The index is built from the embeddings under ``emb_index_gcs_uri``. Its
    display name gets a ``-<YYYYmmddHHMMSS>`` timestamp suffix.

    Args:
        project: GCP project ID.
        location: Vertex AI region.
        version, vpc_network_name: Not referenced in the body.
        emb_index_gcs_uri: GCS directory with the embedding files; a trailing
            ``/`` is appended before it is passed to the SDK.
        dimensions, approximate_neighbors_count, distance_measure_type,
        leaf_node_embedding_count, leaf_nodes_to_search_percent: Passed
            through to ``MatchingEngineIndex.create_tree_ah_index``.
        ann_index_display_name: Display name prefix.
        ann_index_description: Index description.

    Returns:
        Outputs with the index resource name and the index object.
    """
    import logging
    import time
    from datetime import datetime

    from google.cloud import aiplatform as vertex_ai

    # Same INFO level as build_custom_image and the test components.
    logging.getLogger().setLevel(logging.INFO)

    vertex_ai.init(
        project=project,
        location=location,
    )

    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
    ENDPOINT = f"{location}-aiplatform.googleapis.com"

    logging.info(f"ENDPOINT: {ENDPOINT}")
    logging.info(f"project: {project}")
    logging.info(f"location: {location}")
    logging.info(f"INDEX_DIR_GCS: {emb_index_gcs_uri}")

    start = time.time()

    tree_ah_index = vertex_ai.MatchingEngineIndex.create_tree_ah_index(
        display_name=f'{ann_index_display_name}-{TIMESTAMP}',
        contents_delta_uri=f'{emb_index_gcs_uri}/',
        dimensions=dimensions,
        approximate_neighbors_count=approximate_neighbors_count,
        distance_measure_type=distance_measure_type,
        leaf_node_embedding_count=leaf_node_embedding_count,
        leaf_nodes_to_search_percent=leaf_nodes_to_search_percent,
        description=ann_index_description,
        # labels=ann_index_labels,
    )

    elapsed_time = round(time.time() - start, 2)
    logging.info(f'Elapsed time creating index: {elapsed_time} seconds\n')

    ann_index_resource_uri = tree_ah_index.resource_name
    # Format the value into the message. Passing it as a second positional
    # argument with no %-placeholder produces a logging error instead.
    logging.info(f"ann_index_resource_uri: {ann_index_resource_uri}")

    return (
        f'{ann_index_resource_uri}',
        tree_ah_index,
    )

Writing src/train_pipes/create_ann_index.py


## Create Brute-Force Index

In [14]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_brute_force_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.23.0',
        # 'google-api-core==2.10.0',
    ],
)
def create_brute_force_index(
    project: str,
    location: str,
    version: str,
    vpc_network_name: str,
    emb_index_gcs_uri: str,
    dimensions: int,
    brute_force_index_display_name: str,
    approximate_neighbors_count: int,
    distance_measure_type: str,
    brute_force_index_description: str,
    # brute_force_index_labels: Dict,
) -> NamedTuple('Outputs', [
    ('brute_force_index_resource_uri', str),
    ('brute_force_index', Artifact),
]):
    """Create a brute-force Matching Engine index from the embeddings in GCS.

    The display name gets a ``-<YYYYmmddHHMMSS>`` timestamp suffix.

    Args:
        project: GCP project ID.
        location: Vertex AI region.
        version, vpc_network_name: Not referenced in the body.
        emb_index_gcs_uri: GCS directory with the embedding files; a trailing
            ``/`` is appended before it is passed to the SDK.
        dimensions, distance_measure_type: Passed through to
            ``MatchingEngineIndex.create_brute_force_index``.
        approximate_neighbors_count: Accepted for parity with
            ``create_ann_index`` but not passed to the SDK call.
        brute_force_index_display_name: Display name prefix.
        brute_force_index_description: Index description.

    Returns:
        Outputs with the index resource name and the index object.
    """
    import logging
    import time
    from datetime import datetime

    from google.cloud import aiplatform as vertex_ai

    # Same INFO level as build_custom_image and the test components.
    logging.getLogger().setLevel(logging.INFO)

    vertex_ai.init(
        project=project,
        location=location,
    )

    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
    ENDPOINT = f"{location}-aiplatform.googleapis.com"

    logging.info(f"ENDPOINT: {ENDPOINT}")
    logging.info(f"PROJECT_ID: {project}")
    logging.info(f"REGION: {location}")

    start = time.time()

    brute_force_index = vertex_ai.MatchingEngineIndex.create_brute_force_index(
        display_name=f'{brute_force_index_display_name}-{TIMESTAMP}',
        contents_delta_uri=f'{emb_index_gcs_uri}/',
        dimensions=dimensions,
        distance_measure_type=distance_measure_type,
        description=brute_force_index_description,
        # labels=brute_force_index_labels,
    )

    # `start` was previously measured but never reported. Log it as
    # create_ann_index does.
    elapsed_time = round(time.time() - start, 2)
    logging.info(f'Elapsed time creating index: {elapsed_time} seconds\n')

    brute_force_index_resource_uri = brute_force_index.resource_name
    # Use logging rather than print, consistent with the other components.
    logging.info(f"brute_force_index_resource_uri: {brute_force_index_resource_uri}")

    return (
        f'{brute_force_index_resource_uri}',
        brute_force_index,
    )

Writing src/train_pipes/create_brute_force_index.py


## Create ANN Index Endpoint

In [15]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_ann_index_endpoint_vpc.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.23.0',
        # 'google-api-core==2.10.0',
    ],
)
def create_ann_index_endpoint_vpc(
    ann_index_artifact: Input[Artifact],
    project: str,
    project_number: str,
    location: str,
    version: str,
    vpc_network_name: str,
    ann_index_endpoint_display_name: str,
    ann_index_endpoint_description: str,
) -> NamedTuple('Outputs', [
    ('vpc_network_resource_uri', str),
    ('ann_index_endpoint_resource_uri', str),
    ('ann_index_endpoint', Artifact),
    ('ann_index_endpoint_display_name', str),
]):
    """Create a Matching Engine index endpoint peered to a VPC network.

    Args:
        ann_index_artifact: Not read in the body. Presumably wired in only so
            this step runs after the ANN index is created (confirm).
        project: GCP project ID.
        project_number: Numeric project number used in the network resource name.
        location: Vertex AI region.
        version: Not referenced in the body.
        vpc_network_name: VPC network name.
        ann_index_endpoint_display_name: Display name for the endpoint.
        ann_index_endpoint_description: Description for the endpoint.

    Returns:
        Outputs with the network resource name, the endpoint resource name,
        the endpoint object and the endpoint display name.
    """
    import logging
    from google.cloud import aiplatform as vertex_ai

    # Same INFO level as build_custom_image and the test components.
    logging.getLogger().setLevel(logging.INFO)

    vertex_ai.init(
        project=project,
        location=location,
    )

    # The network is referenced by project *number*, not project ID.
    vpc_network_resource_uri = f'projects/{project_number}/global/networks/{vpc_network_name}'
    logging.info(f"vpc_network_resource_uri: {vpc_network_resource_uri}")

    ann_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint.create(
        display_name=f'{ann_index_endpoint_display_name}',
        description=ann_index_endpoint_description,
        network=vpc_network_resource_uri,
    )
    ann_index_endpoint_resource_uri = ann_index_endpoint.resource_name
    logging.info(f"ann_index_endpoint_resource_uri: {ann_index_endpoint_resource_uri}")

    return (
        f'{vpc_network_resource_uri}',
        f'{ann_index_endpoint_resource_uri}',
        ann_index_endpoint,
        f'{ann_index_endpoint_display_name}'
    )

Writing src/train_pipes/create_ann_index_endpoint_vpc.py


## Create Brute-Force Index Endpoint

In [16]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_brute_index_endpoint_vpc.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.23.0',
        # 'google-api-core==2.10.0',
    ],
)
def create_brute_index_endpoint_vpc(
    bf_index_artifact: Input[Artifact],
    project: str,
    project_number: str,
    location: str,
    version: str,
    vpc_network_name: str,
    brute_index_endpoint_display_name: str,
    brute_index_endpoint_description: str,
) -> NamedTuple('Outputs', [
    ('vpc_network_resource_uri', str),
    ('brute_index_endpoint_resource_uri', str),
    ('brute_index_endpoint', Artifact),
    ('brute_index_endpoint_display_name', str),
]):
    """Create a Matching Engine index endpoint (for the brute-force index) on a VPC.

    Args:
        bf_index_artifact: Not read in the body. Presumably wired in only so
            this step runs after the brute-force index is created (confirm).
        project: GCP project ID.
        project_number: Numeric project number used in the network resource name.
        location: Vertex AI region.
        version: Not referenced in the body.
        vpc_network_name: VPC network name.
        brute_index_endpoint_display_name: Display name for the endpoint.
        brute_index_endpoint_description: Description for the endpoint.

    Returns:
        Outputs with the network resource name, the endpoint resource name,
        the endpoint object and the endpoint display name.
    """
    import logging
    from google.cloud import aiplatform as vertex_ai

    # Same INFO level as build_custom_image and the test components.
    logging.getLogger().setLevel(logging.INFO)

    vertex_ai.init(
        project=project,
        location=location,
    )

    # The network is referenced by project *number*, not project ID.
    vpc_network_resource_uri = f'projects/{project_number}/global/networks/{vpc_network_name}'
    logging.info(f"vpc_network_resource_uri: {vpc_network_resource_uri}")

    brute_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint.create(
        display_name=f'{brute_index_endpoint_display_name}',
        description=brute_index_endpoint_description,
        network=vpc_network_resource_uri,
    )
    brute_index_endpoint_resource_uri = brute_index_endpoint.resource_name
    logging.info(f"brute_index_endpoint_resource_uri: {brute_index_endpoint_resource_uri}")

    return (
        f'{vpc_network_resource_uri}',
        f'{brute_index_endpoint_resource_uri}',
        brute_index_endpoint,
        f'{brute_index_endpoint_display_name}'
    )

Writing src/train_pipes/create_brute_index_endpoint_vpc.py


## Deploy ANN Index

In [17]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/deploy_ann_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.23.0',
        # 'google-api-core==2.10.0',
    ]
)
def deploy_ann_index(
    project: str,
    location: str,
    version: str,
    deployed_ann_index_name: str,
    ann_index_resource_uri: str,
    index_endpoint_resource_uri: str,
) -> NamedTuple('Outputs', [
    ('index_endpoint_resource_uri', str),
    ('ann_index_resource_uri', str),
    ('deployed_ann_index_name', str),
    ('deployed_ann_index', Artifact),
]):
    """Deploy an existing ANN Matching Engine index to an index endpoint.

    Args:
        project: GCP project ID.
        location: Vertex AI region.
        version: Not referenced in the body.
        deployed_ann_index_name: ``deployed_index_id`` to deploy under.
        ann_index_resource_uri: Index to deploy.
        index_endpoint_resource_uri: Endpoint to deploy the index to.

    Returns:
        Outputs with the endpoint resource name, the index resource name as
        reported by the SDK, the deployed index ID, and the index object.
    """
    import logging
    from google.cloud import aiplatform as vertex_ai

    # Same INFO level as build_custom_image and the test components.
    logging.getLogger().setLevel(logging.INFO)

    vertex_ai.init(
        project=project,
        location=location,
    )

    ann_index = vertex_ai.MatchingEngineIndex(
        index_name=ann_index_resource_uri
    )
    # Return the resource name as reported by the SDK object.
    ann_index_resource_uri = ann_index.resource_name

    index_endpoint = vertex_ai.MatchingEngineIndexEndpoint(
        index_endpoint_resource_uri
    )

    # The returned endpoint object is used below to log its deployed indexes.
    index_endpoint = index_endpoint.deploy_index(
        index=ann_index,
        deployed_index_id=f'{deployed_ann_index_name}',
    )

    logging.info(f"index_endpoint.deployed_indexes: {index_endpoint.deployed_indexes}")

    return (
        f'{index_endpoint_resource_uri}',
        f'{ann_index_resource_uri}',
        f'{deployed_ann_index_name}',
        ann_index,
    )

Writing src/train_pipes/deploy_ann_index.py


## Deploy Brute-Force Index

In [18]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/deploy_brute_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.23.0',
        # 'google-api-core==2.10.0',
    ],
)
def deploy_brute_index(
    project: str,
    location: str,
    version: str,
    deployed_brute_force_index_name: str,
    brute_force_index_resource_uri: str,
    index_endpoint_resource_uri: str,
) -> NamedTuple('Outputs', [
    ('index_endpoint_resource_uri', str),
    ('brute_force_index_resource_uri', str),
    ('deployed_brute_force_index_name', str),
    ('deployed_brute_force_index', Artifact),
]):
    """Deploy an existing brute-force Matching Engine index to an index endpoint.

    Args:
        project: GCP project ID.
        location: Vertex AI region.
        version: Not referenced in the body.
        deployed_brute_force_index_name: ``deployed_index_id`` to deploy under.
        brute_force_index_resource_uri: Index to deploy.
        index_endpoint_resource_uri: Endpoint to deploy the index to.

    Returns:
        Outputs with the endpoint resource name, the index resource name as
        reported by the SDK, the deployed index ID, and the index object.
    """
    import logging
    from google.cloud import aiplatform as vertex_ai

    # Same INFO level as build_custom_image and the test components.
    logging.getLogger().setLevel(logging.INFO)

    vertex_ai.init(
        project=project,
        location=location,
    )

    brute_index = vertex_ai.MatchingEngineIndex(
        index_name=brute_force_index_resource_uri
    )
    # Return the resource name as reported by the SDK object.
    brute_force_index_resource_uri = brute_index.resource_name

    index_endpoint = vertex_ai.MatchingEngineIndexEndpoint(index_endpoint_resource_uri)

    # The returned endpoint object is used below to log its deployed indexes.
    index_endpoint = index_endpoint.deploy_index(
        index=brute_index,
        deployed_index_id=f'{deployed_brute_force_index_name}',
    )

    logging.info(f"index_endpoint.deployed_indexes: {index_endpoint.deployed_indexes}")

    return (
        f'{index_endpoint_resource_uri}',
        f'{brute_force_index_resource_uri}',
        f'{deployed_brute_force_index_name}',
        brute_index,
    )

Writing src/train_pipes/deploy_brute_index.py


## Test Query Model Deployment

In [19]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/test_deployed_query_model_v7.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.23.0',
        # 'google-api-core==2.10.0',
        'google-cloud-pipeline-components'
    ],
)
def test_deployed_query_model(
    project: str,
    location: str,
    version: str,
    deployed_endpoint: str,
    data_dir_bucket_name: str,
    test_instance_gcs_blob_name: str,
    # instances: list,
    metrics: Output[Metrics],
):
    """Send one test instance to a deployed query-model endpoint and record latency.

    The test instance is a pickled dict read from
    ``gs://<data_dir_bucket_name>/<test_instance_gcs_blob_name>``. It is sent
    as a single-instance ``predict`` request. The round-trip time in seconds
    (rounded to 4 decimals) is logged as the ``"endpoint latency"`` metric.

    Args:
        project: GCP project ID.
        location: Vertex AI region.
        version: Not referenced in the body.
        deployed_endpoint: JSON-serialized ``GcpResources`` whose first
            resource URI identifies the endpoint.
        data_dir_bucket_name: Bucket holding the pickled test instance.
        test_instance_gcs_blob_name: Blob name of the pickled test instance.
        metrics: KFP metrics output that receives the latency.
    """
    import logging
    import pickle as pkl
    import time

    from google.cloud import aiplatform
    from google.cloud import storage
    from google.protobuf.json_format import Parse
    from google_cloud_pipeline_components.proto.gcp_resources_pb2 import \
        GcpResources

    logging.getLogger().setLevel(logging.INFO)
    # Pass location as the other components do (it was previously unused here).
    aiplatform.init(project=project, location=location)
    storage_client = storage.Client(project=project)

    # Parse the endpoint resource.
    logging.info(f"Endpoint = {deployed_endpoint}")
    gcp_resources = Parse(deployed_endpoint, GcpResources())
    endpoint_uri = gcp_resources.resources[0].resource_uri
    # Keep path segments [-8:-2]. Assumes a URI ending in
    # .../projects/P/locations/L/endpoints/E/<x>/<y> (e.g. an operation URI),
    # which yields "projects/P/locations/L/endpoints/E". Confirm against the
    # producing step.
    endpoint_id = "/".join(endpoint_uri.split("/")[-8:-2])
    logging.info(f"Endpoint ID = {endpoint_id}")

    _endpoint = aiplatform.Endpoint(endpoint_id)

    # Load the test instance.
    LOCAL_INSTANCE_FILE = 'merlin_last5_test_instance.pkl'
    logging.info(f"LOCAL_INSTANCE_FILE: {LOCAL_INSTANCE_FILE}")

    bucket = storage_client.bucket(data_dir_bucket_name)
    blob = bucket.blob(test_instance_gcs_blob_name)
    blob.download_to_filename(LOCAL_INSTANCE_FILE)

    # SECURITY NOTE: pickle.load executes arbitrary code embedded in the file.
    # Only point this at a bucket/blob you control.
    with open(LOCAL_INSTANCE_FILE, 'rb') as filehandler:
        test_instances_dict = pkl.load(filehandler)

    logging.info(f'test_instances_dict: {test_instances_dict}')

    # Time a single prediction request.
    start = time.time()
    playlist_emb = _endpoint.predict(instances=[test_instances_dict])
    elapsed_time = round(time.time() - start, 4)

    logging.info(f'Deployed query model latency: {elapsed_time} seconds')
    logging.info(f'query embeddings: {playlist_emb.predictions}')

    metrics.log_metric("endpoint latency", elapsed_time)

Writing src/train_pipes/test_deployed_query_model_v7.py


## Test Index Deployment

In [38]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/test_model_index_endpoint_v5.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.23.0',
        'google-cloud-pipeline-components',
        'google-cloud-storage',
        'numpy'
    ],
)
def test_model_index_endpoint(
    project: str,
    location: str,
    version: str,
    data_dir_bucket_name: str,
    test_instance_gcs_blob_name: str,
    ann_index_endpoint_resource_uri: str,
    brute_index_endpoint_resource_uri: str,
    endpoint: str,
    metrics: Output[Metrics],
):
    """Smoke-test the deployed query tower together with both Matching Engine indexes.

    Downloads a pickled test instance from GCS, sends it to the query-tower
    endpoint, queries the ANN and brute-force index endpoints with the
    returned embeddings, and logs latencies plus ANN recall (measured
    against the brute-force results) to ``metrics``.

    Args:
        project: GCP project used for the Vertex AI and Cloud Storage clients.
        location: Vertex AI region.
        version: Version tag; not used by this component.
        data_dir_bucket_name: Bucket holding the pickled test instance.
        test_instance_gcs_blob_name: Object name of the pickled test instance.
        ann_index_endpoint_resource_uri: Index endpoint serving the ANN index.
        brute_index_endpoint_resource_uri: Index endpoint serving the brute-force index.
        endpoint: JSON-serialized GcpResources produced by the model-deploy step.
        metrics: Output metrics artifact.

    Raises:
        ValueError: If the endpoint resource name cannot be parsed, an index
            endpoint has no deployed index, or brute-force search returns no
            neighbors (recall would be undefined).
    """
    import logging
    import pickle as pkl
    import re
    import time

    from google.cloud import aiplatform as vertex_ai
    from google.cloud import storage
    from google.protobuf.json_format import Parse
    from google_cloud_pipeline_components.proto.gcp_resources_pb2 import \
        GcpResources

    LOCAL_INSTANCE_FILE = 'merlin_last5_test_instance.pkl'
    # Neighbors requested from each index; recall is measured over these.
    NUM_NEIGHBORS = 50

    def _deployed_index_id(index_endpoint, resource_uri):
        """Return the ID of the first index deployed on ``index_endpoint``."""
        deployed = index_endpoint.deployed_indexes
        if not deployed:
            raise ValueError(f"No index is deployed on index endpoint {resource_uri}")
        return deployed[0].id

    def _timed_match(index_endpoint, deployed_index_id, queries):
        """Run one match() call; return (response, latency in seconds rounded to 4 places)."""
        start = time.perf_counter()
        response = index_endpoint.match(
            deployed_index_id=deployed_index_id,
            queries=queries,
            num_neighbors=NUM_NEIGHBORS,
        )
        return response, round(time.perf_counter() - start, 4)

    def _recall(approx_response, exact_response):
        """Fraction of exact (brute-force) neighbors that the approximate index also returned."""
        total = sum(len(neighbors) for neighbors in exact_response)
        if total == 0:
            raise ValueError("Brute-force index returned no neighbors; recall is undefined")
        recalled = 0
        for approx_neighbors, exact_neighbors in zip(approx_response, exact_response):
            exact_ids = {neighbor.id for neighbor in exact_neighbors}
            recalled += len(exact_ids.intersection(neighbor.id for neighbor in approx_neighbors))
        return recalled / total

    logging.getLogger().setLevel(logging.INFO)

    vertex_ai.init(
        project=project,
        location=location,
    )
    storage_client = storage.Client(project=project)

    # ====================================================
    # get deployed model endpoint
    # ====================================================
    logging.info(f"Endpoint = {endpoint}")
    gcp_resources = Parse(endpoint, GcpResources())
    logging.info(f"gcp_resources = {gcp_resources}")

    _endpoint_resource = gcp_resources.resources[0].resource_uri
    logging.info(f"_endpoint_resource = {_endpoint_resource}")

    # Extract "projects/<p>/locations/<l>/endpoints/<id>" by pattern. A fixed
    # slice of the URI would depend on how many path segments follow the
    # endpoint ID (presumably an operation suffix) and fail silently otherwise.
    endpoint_match = re.search(
        r"projects/[^/]+/locations/[^/]+/endpoints/[^/]+", _endpoint_resource
    )
    if endpoint_match is None:
        raise ValueError(f"Cannot parse an endpoint resource name from {_endpoint_resource}")
    _endpoint_uri = endpoint_match.group(0)
    logging.info(f"_endpoint_uri = {_endpoint_uri}")

    # define endpoint resource in component
    _endpoint = vertex_ai.Endpoint(_endpoint_uri)
    logging.info("_endpoint defined")

    # ====================================================
    # Get indexes
    # ====================================================
    logging.info(f"ann_index_endpoint_resource_uri: {ann_index_endpoint_resource_uri}")
    logging.info(f"brute_index_endpoint_resource_uri: {brute_index_endpoint_resource_uri}")

    deployed_ann_index = vertex_ai.MatchingEngineIndexEndpoint(ann_index_endpoint_resource_uri)
    deployed_bf_index = vertex_ai.MatchingEngineIndexEndpoint(brute_index_endpoint_resource_uri)

    DEPLOYED_ANN_ID = _deployed_index_id(deployed_ann_index, ann_index_endpoint_resource_uri)
    DEPLOYED_BF_ID = _deployed_index_id(deployed_bf_index, brute_index_endpoint_resource_uri)
    logging.info(f"DEPLOYED_ANN_ID: {DEPLOYED_ANN_ID}")
    logging.info(f"DEPLOYED_BF_ID: {DEPLOYED_BF_ID}")

    # ====================================================
    # Load test instance
    # ====================================================
    logging.info(f"LOCAL_INSTANCE_FILE: {LOCAL_INSTANCE_FILE}")

    bucket = storage_client.bucket(data_dir_bucket_name)
    blob = bucket.blob(test_instance_gcs_blob_name)
    blob.download_to_filename(LOCAL_INSTANCE_FILE)

    # NOTE: pickle.load can execute arbitrary code. Only point this component at
    # objects written by this project's own notebook into a trusted bucket.
    with open(LOCAL_INSTANCE_FILE, 'rb') as filehandler:
        test_instances_dict = pkl.load(filehandler)

    logging.info(f'test_instances_dict: {test_instances_dict}')

    # ====================================================
    # get query response
    # ====================================================
    start = time.perf_counter()
    playlist_emb = _endpoint.predict(instances=[test_instances_dict])
    elapsed_query_time = round(time.perf_counter() - start, 4)
    logging.info(f'Query endpoint latency: {elapsed_query_time} seconds')

    # ====================================================
    # call matching engine with predicted emb vectors
    # ====================================================
    logging.info('Retrieving neighbors from ANN index...')
    ANN_response, elapsed_ann_time = _timed_match(
        deployed_ann_index, DEPLOYED_ANN_ID, playlist_emb.predictions
    )
    logging.info(f'ANN latency: {elapsed_ann_time} seconds')

    logging.info('Retrieving neighbors from BF index...')
    BF_response, elapsed_bf_time = _timed_match(
        deployed_bf_index, DEPLOYED_BF_ID, playlist_emb.predictions
    )
    logging.info(f'Bruteforce latency: {elapsed_bf_time} seconds')

    # TODO: write results to file -> GCS

    # Log raw responses before computing recall so they are visible even if
    # recall cannot be computed.
    logging.info(f'playlist_emb: {playlist_emb.predictions}')
    logging.info(f'ANN_response: {ANN_response}')
    logging.info(f'BF_response: {BF_response}')

    # ====================================================
    # Recall: share of brute-force neighbors also retrieved by the ANN index
    # ====================================================
    recall = _recall(ANN_response, BF_response)
    logging.info("Recall: {}".format(recall))

    metrics.log_metric("elapsed_query_time", elapsed_query_time)
    metrics.log_metric("elapsed_ann_time", elapsed_ann_time)
    metrics.log_metric("elapsed_bf_time", elapsed_bf_time)
    metrics.log_metric("Recall", recall)

Overwriting src/train_pipes/test_model_index_endpoint_v5.py


## Compute config for pipeline steps

In [39]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/pipeline_config.py
# Compute resources for pipeline steps.
# NOTE(review): a2-highgpu-2g is documented as 24 vCPUs / 170 GB, so
# CPU_LIMIT='96' (the author's alternative was 24) looks inconsistent with
# INSTANCE_TYPE. Confirm before relying on it.
INSTANCE_TYPE = "a2-highgpu-2g"
GPU_TYPE = "NVIDIA_TESLA_A100"
GPU_LIMIT = 2
CPU_LIMIT = '96'        # alt: 24
MEMORY_LIMIT = '170G'   # alts: 170G, 624G

Overwriting src/train_pipes/pipeline_config.py


# Prepare Job Specs

## Accelerators and Device Strategy

In [40]:
# ====================================================
# Single | Single machine, single GPU
# ====================================================
DISTRIBUTE_STRATEGY = 'single'

# one worker replica with one A100 attached
WORKER_MACHINE_TYPE = 'a2-highgpu-1g'
REPLICA_COUNT = 1
ACCELERATOR_TYPE = 'NVIDIA_TESLA_A100'
PER_MACHINE_ACCELERATOR_COUNT = 1

# no reduction-server replicas; the machine type is presumably ignored while the count is 0
REDUCTION_SERVER_COUNT = 0
REDUCTION_SERVER_MACHINE_TYPE = "n1-highcpu-16"

## Vertex AI Experiments

In [41]:
VERSION = 'v11'  # version tag embedded in experiment, pipeline, and test-instance names

In [24]:
# custom identifier for organizing experiments
EXPERIMENT_PREFIX = 'test-e2e-pipe'
EXPERIMENT_NAME = '{}-{}'.format(EXPERIMENT_PREFIX, VERSION)

# timestamped run name; pin an earlier value (e.g. 'run-20230308-171530') to reuse a run
RUN_NAME = 'run-' + time.strftime("%Y%m%d-%H%M%S")

print(f"EXPERIMENT_NAME: {EXPERIMENT_NAME}")
print(f"RUN_NAME: {RUN_NAME}")

EXPERIMENT_NAME: test-e2e-pipe-v11
RUN_NAME: run-20230321-100615


## Training Config

### train image

In [25]:
# =================================================
# train image
# =================================================
# URI of an existing image, or the name to give a newly built one
_TRAIN_IMAGE_REPO = 'gcr.io/hybrid-vertex'
TRAIN_IMAGE_URI = f'{_TRAIN_IMAGE_REPO}/train-2212v16-vertex-merlin-tf-2tower-jtv34'

### data source
* TODO: update these variables to point to the GCS location where the processed training data is stored

In [26]:
# gcs bucket
OUTPUT_BUCKET = 'jt-merlin-scaling'

# data and schema from nvtabular pipes
DATA_DIR = 'gs://jt-merlin-scaling/nvt-last5-latest-12/nvt-processed'

TRAIN_DIR = f'{DATA_DIR}/train'
VALID_DIR = f'{DATA_DIR}/valid'

WORKFLOW_DIR = 'gs://jt-merlin-scaling/nvt-last5-latest-12/nvt-processed/workflow'

# Stores pipeline executions for each run
PIPELINE_ROOT_PATH = f'gs://{OUTPUT_BUCKET}/{EXPERIMENT_NAME}/{RUN_NAME}/pipeline_root' # TODO - parametrize

print(f"DATA_DIR: {DATA_DIR}")
print(f"TRAIN_DIR: {TRAIN_DIR}")
print(f"VALID_DIR: {VALID_DIR}")
print(f"WORKFLOW_DIR: {WORKFLOW_DIR}")
print(f"PIPELINE_ROOT_PATH: {PIPELINE_ROOT_PATH}")

DATA_DIR: gs://jt-merlin-scaling/nvt-last5-latest-12/nvt-processed
TRAIN_DIR: gs://jt-merlin-scaling/nvt-last5-latest-12/nvt-processed/train
VALID_DIR: gs://jt-merlin-scaling/nvt-last5-latest-12/nvt-processed/valid
WORKFLOW_DIR: gs://jt-merlin-scaling/nvt-last5-latest-12/nvt-processed/workflow
PIPELINE_ROOT_PATH: gs://jt-merlin-scaling/test-e2e-pipe-v11/run-20230321-100615/pipeline_root


### train params

In [27]:
from utils import train_utils
from pprint import pprint

# ---- data size ----
train_sample_cnt = 8_205_265
valid_samples_cnt = 82_959

# ---- train config ----
NUM_EPOCHS = 60
BATCH_SIZE = 4096 * 4
LEARNING_RATE = 0.001
VALID_FREQUENCY = 20
CHECKPOINT_FREQ = 'epoch'
# whole batches per split (floor division drops any partial batch)
EPOCH_STEPS = train_sample_cnt // BATCH_SIZE
VALID_STEPS = valid_samples_cnt // BATCH_SIZE

# ---- model ----
LAYERS = "[512, 256, 128]"

# Shell entrypoint for the training container. Each trailing backslash is a
# Python line continuation, so the command is a single line; LAYERS is wrapped
# in single quotes so the shell passes it as one argument.
# Optional flags, currently disabled: --profiler, --tb_name=<TB_RESOURCE_NAME>
_TRAIN_TASK_SHELL_CMD = f'''pip freeze && python -m trainer.train_task \
    --per_gpu_batch_size={BATCH_SIZE} \
    --train_output_bucket={OUTPUT_BUCKET} \
    --train_dir={TRAIN_DIR} \
    --valid_dir={VALID_DIR} \
    --workflow_dir={WORKFLOW_DIR} \
    --num_epochs={NUM_EPOCHS} \
    --learning_rate={LEARNING_RATE} \
    --distribute={DISTRIBUTE_STRATEGY} \
    --experiment_name={EXPERIMENT_NAME} \
    --experiment_run={RUN_NAME} \
    --project={PROJECT_ID} \
    --location={LOCATION} \
    --layer_sizes=\'{LAYERS}\' \
    --valid_frequency={VALID_FREQUENCY} \
    --epoch_steps={EPOCH_STEPS} \
    --valid_steps={VALID_STEPS} \
    --chkpt_freq={CHECKPOINT_FREQ} \
    --write_embeddings'''
WORKER_CMD = ['sh', '-euc', _TRAIN_TASK_SHELL_CMD]

WORKER_POOL_SPECS = train_utils.prepare_worker_pool_specs(
    image_uri=TRAIN_IMAGE_URI,
    cmd=WORKER_CMD,
    replica_count=REPLICA_COUNT,
    machine_type=WORKER_MACHINE_TYPE,
    accelerator_type=ACCELERATOR_TYPE,
    accelerator_count=PER_MACHINE_ACCELERATOR_COUNT,
    reduction_server_count=REDUCTION_SERVER_COUNT,
    reduction_server_machine_type=REDUCTION_SERVER_MACHINE_TYPE,
)

pprint(WORKER_POOL_SPECS)

[{'container_spec': {'command': ['sh',
                                 '-euc',
                                 'pip freeze && python -m '
                                 'trainer.train_task     '
                                 '--per_gpu_batch_size=16384     '
                                 '--train_output_bucket=jt-merlin-scaling     '
                                 '--train_dir=gs://jt-merlin-scaling/nvt-last5-latest-12/nvt-processed/train     '
                                 '--valid_dir=gs://jt-merlin-scaling/nvt-last5-latest-12/nvt-processed/valid     '
                                 '--workflow_dir=gs://jt-merlin-scaling/nvt-last5-latest-12/nvt-processed/workflow     '
                                 '--num_epochs=60     '
                                 '--learning_rate=0.001     '
                                 '--distribute=single     '
                                 '--experiment_name=test-e2e-pipe-v11     '
                                 '--experiment_ru

# Build & Compile Pipeline

### pipe configs

In [42]:
# version tag for the pipeline code itself (independent of the model VERSION)
PIPELINE_VERSION = 'jtv05'

PIPELINE_TAG = 'merlin-{}-e2e-pipe-{}'.format(VERSION, PIPELINE_VERSION)
# keep the pipeline name hyphenated: any underscores become hyphens
PIPELINE_NAME = PIPELINE_TAG.replace('_', '-')

print("PIPELINE_TAG:", PIPELINE_TAG)
print("PIPELINE_NAME:", PIPELINE_NAME)

PIPELINE_TAG: merlin-v11-e2e-pipe-jtv05
PIPELINE_NAME: merlin-v11-e2e-pipe-jtv05


## Build pipeline

In [43]:
from src.train_pipes import (
    train_merlin, build_custom_image, upload_custom_model,
    create_ann_index, create_brute_force_index, create_ann_index_endpoint_vpc,
    create_brute_index_endpoint_vpc, deploy_ann_index, deploy_brute_index, 
    create_tensorboard, pipeline_config, test_deployed_query_model_v7, test_model_index_endpoint_v5
)

@kfp.v2.dsl.pipeline(
    name=PIPELINE_NAME.replace('_', '-')
)
def pipeline(
    project: str,
    project_number: str,
    location: str,
    version: str,
    vpc_network_name: str,
    pipe_gcs_path: str,
    train_output_gcs_bucket: str,
    training_image_uri: str,
    serving_image_uri: str,
    train_docker_name: str,
    serving_docker_name: str,
    tb_resource: str,
    batch_size: int,
    train_epochs: int,
    train_dir: str,
    valid_dir: str,
    workflow_dir: str,
    experiment_name: str,
    experiment_run: str,
    service_account: str,
    embeddings_dim: int,
    layer_sizes: str,
    worker_pool_specs: dict,
    test_instance_gcs_blob_name: str,
):
    """End-to-end Merlin two-tower pipeline.

    The pipeline runs these stages:

    1. Build (or reuse) the train image and create a managed TensorBoard.
    2. Train the towers.
    3. Build (or reuse) the serving image, register the query tower, deploy it
       to a new endpoint, and smoke-test the endpoint.
    4. Create ANN and brute-force Matching Engine indexes from the candidate
       embeddings, create their VPC index endpoints, and deploy both indexes.
    5. Test the deployed model together with both indexes.

    ``tb_resource`` and ``layer_sizes`` are accepted but not used here.
    Training receives the TensorBoard created by the "Managed TB" step instead
    of ``tb_resource``.
    """
    from kfp.v2.components import importer_node
    from google_cloud_pipeline_components.types import artifact_types
    from google_cloud_pipeline_components.experimental.evaluation import \
        GetVertexModelOp

    # ========================================================================
    # TODO: data processing steps
    # ========================================================================

    # ========================================================================
    # Build TRAIN Image
    # ========================================================================
    build_custom_train_image_op = (
        build_custom_image.build_custom_image(
            project=project,
            artifact_gcs_path=pipe_gcs_path,
            app_dir_name='trainer',
            docker_name=train_docker_name,
            custom_image_uri=training_image_uri,
            use_existing='True',
        )
        .set_display_name("Build Train Image")
        .set_caching_options(True)
    )

    # ========================================================================
    # Managed TB
    # ========================================================================
    create_managed_tensorboard_op = (
        create_tensorboard.create_tensorboard(
            project=project,
            location=location,
            model_version=version,
            pipeline_version=version,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
        )
        .set_display_name("Managed TB")
        .set_caching_options(True)
    )

    # ========================================================================
    # Train Merlin Towers
    # ========================================================================
    train_merlin_op = (
        train_merlin.train_merlin(
            project=project,
            location=location,
            version=version,
            train_image_uri=build_custom_train_image_op.outputs['custom_image_uri'],
            train_output_gcs_bucket=train_output_gcs_bucket,
            # TensorBoard created above, rather than the `tb_resource` pipeline input
            tb_resource=create_managed_tensorboard_op.outputs['tensorboard_resource_name'],
            batch_size=batch_size,
            train_epochs=train_epochs,
            train_dir=train_dir,
            valid_dir=valid_dir,
            workflow_dir=workflow_dir,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            service_account=service_account,
            worker_pool_specs=worker_pool_specs,
        )
        .set_display_name("Train Merlin Towers")
        .set_caching_options(True)
    )

    # ========================================================================
    # Build SERVING Image
    # ========================================================================
    build_custom_serving_image_op = (
        build_custom_image.build_custom_image(
            project=project,
            artifact_gcs_path=pipe_gcs_path,
            app_dir_name='serving',
            docker_name=serving_docker_name,
            custom_image_uri=serving_image_uri,
            use_existing='True',
        )
        .set_display_name("Build Serving Image")
        .set_caching_options(True)
    )

    # ========================================================================
    # Import Trained Towers to Pipeline DAG
    # ========================================================================
    import_query_model_task = (
        importer_node.importer(
            artifact_uri=train_merlin_op.outputs['query_tower_gcs_dir'],
            artifact_class=artifact_types.UnmanagedContainerModel,
        )
        .set_display_name("Import Query Tower")
        .set_caching_options(True)
    )

    # NOTE(review): nothing downstream in this pipeline consumes this artifact
    # (only the query tower is registered and deployed).
    import_candidate_model_task = (
        importer_node.importer(
            artifact_uri=train_merlin_op.outputs['candidate_tower_gcs_uri'],
            artifact_class=artifact_types.UnmanagedContainerModel,
        )
        .set_display_name("Import Candidate Tower")
        .set_caching_options(True)
    )

    # ========================================================================
    # Upload Models to Vertex AI Model Registry
    # ========================================================================
    query_model_upload_op = (
        upload_custom_model.upload_custom_model(
            project=project,
            location=location,
            version=version,
            display_name=f'merlin-query-tower-{version}',
            artifact_uri=train_merlin_op.outputs["query_tower_gcs_dir"],
            unmanaged_container_model=import_query_model_task.outputs["artifact"],
            serving_container_image_uri=build_custom_serving_image_op.outputs["custom_image_uri"],
        )
        .set_display_name("Register Query Tower")
        .set_caching_options(True)
    )

    model = (
        GetVertexModelOp(
            model_resource_name=query_model_upload_op.outputs['model_resource_name'],
        )
        .set_display_name("Get Vertex Model")
    )

    # ========================================================================
    # Deploy Model to Endpoint
    # ========================================================================
    endpoint_create_op = (
        gcc_aip.EndpointCreateOp(
            project=project,
            # without this the op uses its default region, which may not match
            # the region the query tower was registered in
            location=location,
            display_name=f'query-tower-endpoint-{version}'
        )
        .after(query_model_upload_op)
        .set_display_name("Create Query Endpoint")
        .set_caching_options(True)
    )

    model_deploy_op = (
        gcc_aip.ModelDeployOp(
            endpoint=endpoint_create_op.outputs['endpoint'],
            model=model.outputs['model'],
            deployed_model_display_name=f'deployed-qtower-{version}',
            dedicated_resources_machine_type="n1-standard-4",
            dedicated_resources_accelerator_type="NVIDIA_TESLA_T4",
            dedicated_resources_accelerator_count=1,
            dedicated_resources_max_replica_count=1,
            dedicated_resources_min_replica_count=1,
            service_account=service_account,
        )
        .set_display_name("Deploy Query Tower")
        .set_caching_options(True)
    )

    test_deployed_query_model_op = (
        test_deployed_query_model_v7.test_deployed_query_model(
            project=project,
            location=location,
            version=version,
            deployed_endpoint=model_deploy_op.outputs["gcp_resources"],
            data_dir_bucket_name=train_output_gcs_bucket,
            test_instance_gcs_blob_name=test_instance_gcs_blob_name
        )
        .set_display_name("test query model deploy")
        # always re-run: this step checks live endpoint behavior
        .set_caching_options(False)
    )

    # ========================================================================
    # Create ME indexes
    # ========================================================================
    create_ann_index_op = (
        create_ann_index.create_ann_index(
            project=project,
            location=location,
            version=version,
            vpc_network_name=vpc_network_name,
            emb_index_gcs_uri=train_merlin_op.outputs['candidate_embeddings_gcs_uri'],
            dimensions=embeddings_dim,
            ann_index_display_name=f'v1_ann_index_{version}'.replace('-', '_'),
            approximate_neighbors_count=50,
            distance_measure_type="DOT_PRODUCT_DISTANCE",
            leaf_node_embedding_count=500,
            leaf_nodes_to_search_percent=7,
            ann_index_description="testing ann index for Merlin deployment",
        )
        .set_display_name("Create ANN Index")
        .set_caching_options(True)
    )

    create_brute_force_index_op = (
        create_brute_force_index.create_brute_force_index(
            project=project,
            location=location,
            version=version,
            vpc_network_name=vpc_network_name,
            emb_index_gcs_uri=train_merlin_op.outputs['candidate_embeddings_gcs_uri'],
            dimensions=embeddings_dim,
            brute_force_index_display_name=f'v1_bf_index_{version}'.replace('-', '_'),
            approximate_neighbors_count=50,
            distance_measure_type="DOT_PRODUCT_DISTANCE",
            brute_force_index_description="testing bf index for Merlin deployment",
        )
        .set_display_name("Create BF Index")
        .set_caching_options(True)
    )

    # ========================================================================
    # Create ME index endpoints
    # ========================================================================
    create_ann_index_endpoint_vpc_op = (
        create_ann_index_endpoint_vpc.create_ann_index_endpoint_vpc(
            ann_index_artifact=create_ann_index_op.outputs['ann_index'],
            project=project,
            project_number=project_number,
            version=version,
            location=location,
            vpc_network_name=vpc_network_name,
            ann_index_endpoint_display_name=f'ann-endpoint_{version}'.replace('-', '_'),
            ann_index_endpoint_description='endpoint for ann index',
        )
        .set_display_name("Create ANN Index Endpoint")
    )

    create_brute_index_endpoint_vpc_op = (
        create_brute_index_endpoint_vpc.create_brute_index_endpoint_vpc(
            bf_index_artifact=create_brute_force_index_op.outputs['brute_force_index'],
            project=project,
            project_number=project_number,
            version=version,
            location=location,
            vpc_network_name=vpc_network_name,
            brute_index_endpoint_display_name=f'bf-endpoint_{version}'.replace('-', '_'),
            brute_index_endpoint_description='endpoint for brute force index',
        )
        .set_display_name("Create BF Index Endpoint")
    )

    # ========================================================================
    # Deploy Indexes
    # ========================================================================
    deploy_ann_index_op = (
        deploy_ann_index.deploy_ann_index(
            project=project,
            location=location,
            version=version,
            deployed_ann_index_name=f'deployed_ann_{version}'.replace('-', '_'),
            ann_index_resource_uri=create_ann_index_op.outputs['ann_index_resource_uri'],
            index_endpoint_resource_uri=create_ann_index_endpoint_vpc_op.outputs['ann_index_endpoint_resource_uri'],
        )
        .set_display_name("Deploy ANN Index")
        .set_caching_options(True)
    )

    deploy_brute_index_op = (
        deploy_brute_index.deploy_brute_index(
            project=project,
            location=location,
            version=version,
            deployed_brute_force_index_name=f'deployed_bf_{version}'.replace('-', '_'),
            brute_force_index_resource_uri=create_brute_force_index_op.outputs['brute_force_index_resource_uri'],
            index_endpoint_resource_uri=create_brute_index_endpoint_vpc_op.outputs['brute_index_endpoint_resource_uri'],
        )
        .set_display_name("Deploy BF Index")
        .set_caching_options(True)
    )

    # ========================================================================
    # Test deployed indexes
    # ========================================================================
    test_model_index_endpoint_op = (
        test_model_index_endpoint_v5.test_model_index_endpoint(
            project=project,
            location=location,
            version=version,
            data_dir_bucket_name=train_output_gcs_bucket,
            test_instance_gcs_blob_name=test_instance_gcs_blob_name,
            ann_index_endpoint_resource_uri=deploy_ann_index_op.outputs['index_endpoint_resource_uri'],
            brute_index_endpoint_resource_uri=deploy_brute_index_op.outputs['index_endpoint_resource_uri'],
            endpoint=model_deploy_op.outputs['gcp_resources']
        )
        .set_display_name("Test deployed model & indexes")
        .set_caching_options(False)
    )

## Compile Pipeline

In [44]:
PIPELINE_JSON_SPEC_LOCAL = "custom_pipeline_spec.json"

! rm -f $PIPELINE_JSON_SPEC_LOCAL

kfp.v2.compiler.Compiler().compile(
    pipeline_func=pipeline, package_path=PIPELINE_JSON_SPEC_LOCAL,
)

In [45]:
PIPELINES_FILEPATH = f'{PIPELINE_ROOT_PATH}/pipeline_spec.json'
print("PIPELINES_FILEPATH:", PIPELINES_FILEPATH)

!gsutil cp $PIPELINE_JSON_SPEC_LOCAL $PIPELINES_FILEPATH

PIPELINES_FILEPATH: gs://jt-merlin-scaling/test-e2e-pipe-v11/run-20230321-100615/pipeline_root/pipeline_spec.json
Copying file://custom_pipeline_spec.json [Content-Type=application/json]...
/ [1 files][106.1 KiB/106.1 KiB]                                                
Operation completed over 1 objects/106.1 KiB.                                    


## Upload test instances

In [46]:
import pickle as pkl

# query features
LOCAL_INSTANCE_FILE = f'merlin_last5_test_instance_{VERSION}.pkl'

! rm -f $LOCAL_INSTANCE_FILE

In [47]:
# Query-side feature dict passed (via GCS) to the pipeline's "Test deployed model & indexes" step.
# Each list-valued *_pl feature has 5 entries (the "last5" in the filename);
# NOTE(review): the lists look position-aligned per track — confirm against the feature schema.
TEST_INSTANCE = {
    'collaborative': 'false',
    'album_name_pl': [
        "There's Really A Wolf", 'Late Nights: The Album', 'American Teen', 'Crazy In Love', 'Pony'
    ],
    'artist_genres_pl': [
        "'hawaiian hip hop', 'rap'",
        "'chicago rap', 'dance pop', 'pop', 'pop rap', 'r&b', 'southern hip hop', 'trap', 'urban contemporary'",
        "'pop', 'pop r&b'",
        "'dance pop', 'pop', 'r&b'",
        "'chill r&b', 'pop', 'pop r&b', 'r&b', 'urban contemporary'"
    ],
    'artist_name_pl': [
        # '\u00e9' is "é". The previous literal '\xc3\xa9' was the UTF-8 *byte* repr
        # pasted into a str, which in Python 3 yields the mojibake "BeyoncÃ©".
        'Russ', 'Jeremih', 'Khalid', 'Beyonc\u00e9', 'William Singe'
    ],
    'artist_pop_can': 82.0,
    'description_pl': '',
    'duration_ms_songs_pl': [
        237506.0, 217200.0, 219080.0, 226400.0, 121739.0
    ],
    'n_songs_pl': 8.0,
    'name': 'Lit Tunes ',
    'num_albums_pl': 8.0,
    'num_artists_pl': 8.0,
    'track_name_pl': [
        'Losin Control', 'Paradise', 'Location', 'Crazy In Love - Remix', 'Pony'
    ],
    'track_pop_pl': [
        79.0, 58.0, 83.0, 71.0, 57.0
    ],
    'duration_ms_seed_pl': 51023.1,
    'pid': 1,
    'track_uri_pl': [
        'spotify:track:4cxMGhkinTocPSVVKWIw0d',
        'spotify:track:1wNEBPo3nsbGCZRryI832I',
        'spotify:track:152lZdxL1OR0ZMW6KquMif',
        'spotify:track:2f4IuijXLxYOeBncS60GUD',
        'spotify:track:4Lj8paMFwyKTGfILLELVxt'
    ]
}

# object path inside OUTPUT_BUCKET; passed to the pipeline as test_instance_gcs_blob_name
INSTANCE_FILE_GCS_OBJ = f'{EXPERIMENT_NAME}/{RUN_NAME}/{LOCAL_INSTANCE_FILE}'

# pickle locally; the context manager closes the handle even if dump() raises
with open(LOCAL_INSTANCE_FILE, 'wb') as filehandler:
    pkl.dump(TEST_INSTANCE, filehandler)

# upload to GCS
bucket_client = storage_client.bucket(OUTPUT_BUCKET)
blob = bucket_client.blob(INSTANCE_FILE_GCS_OBJ)
blob.upload_from_filename(LOCAL_INSTANCE_FILE)

print("INSTANCE_FILE_GCS_OBJ:", INSTANCE_FILE_GCS_OBJ)
print(f"INSTANCE_FILE GCS URI: gs://{OUTPUT_BUCKET}/{INSTANCE_FILE_GCS_OBJ}")

INSTANCE_FILE_GCS_OBJ: test-e2e-pipe-v11/run-20230321-100615/merlin_last5_test_instance_v11.pkl
INSTANCE_FILE GCS URI: gs://jt-merlin-scaling/test-e2e-pipe-v11/run-20230321-100615/merlin_last5_test_instance_v11.pkl


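### Copy artifacts & repo to GCS

> Storing the Dockerfiles and source code alongside this run's pipeline root makes it easy to track exactly what each run used.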

In [48]:
# Dockerfile suffixes (Dockerfile.<name>) and the matching source sub-directories
# under REPO_DOCKER_PATH_PREFIX, for the serving and the training images
SERVING_DOCKERNAME, SERVING_SUB_DIR = 'mm-query-serve', 'serving'
TRAIN_DOCKERNAME, TRAIN_SUB_DIR = 'train', 'trainer'

In [49]:
# Copy the serving and training Dockerfiles into this run's pipeline root
# so the run keeps a record of the build definitions it was paired with.
!gsutil cp ./$REPO_DOCKER_PATH_PREFIX/Dockerfile.$SERVING_DOCKERNAME $PIPELINE_ROOT_PATH/
!gsutil cp ./$REPO_DOCKER_PATH_PREFIX/Dockerfile.$TRAIN_DOCKERNAME $PIPELINE_ROOT_PATH/

# Recursively copy the serving and trainer source trees (-m: parallel multi-object copy).
!gsutil -m cp -r ./$REPO_DOCKER_PATH_PREFIX/$SERVING_SUB_DIR $PIPELINE_ROOT_PATH/
!gsutil -m cp -r ./$REPO_DOCKER_PATH_PREFIX/$TRAIN_SUB_DIR $PIPELINE_ROOT_PATH/

Copying file://./src/Dockerfile.mm-query-serve [Content-Type=application/octet-stream]...
/ [1 files][  862.0 B/  862.0 B]                                                
Operation completed over 1 objects/862.0 B.                                      
Copying file://./src/Dockerfile.train [Content-Type=application/octet-stream]...
/ [1 files][  319.0 B/  319.0 B]                                                
Operation completed over 1 objects/319.0 B.                                      
Copying file://./src/serving/app/requirements.txt [Content-Type=text/plain]...
Copying file://./src/serving/app/dataset_to_tensors.py [Content-Type=text/x-python]...
Copying file://./src/serving/app/__init__.py [Content-Type=text/x-python]...    
Copying file://./src/serving/app/instances.json [Content-Type=application/json]...
Copying file://./src/serving/app/predictor.py [Content-Type=text/x-python]...   
Copying file://./src/serving/app/main.py [Content-Type=text/x-python]...        
Copying fil

In [50]:
# List the pipeline root to confirm the spec, Dockerfiles and source dirs were uploaded.
!gsutil ls $PIPELINE_ROOT_PATH

gs://jt-merlin-scaling/test-e2e-pipe-v11/run-20230321-100615/pipeline_root/Dockerfile.mm-query-serve
gs://jt-merlin-scaling/test-e2e-pipe-v11/run-20230321-100615/pipeline_root/Dockerfile.train
gs://jt-merlin-scaling/test-e2e-pipe-v11/run-20230321-100615/pipeline_root/pipeline_spec.json
gs://jt-merlin-scaling/test-e2e-pipe-v11/run-20230321-100615/pipeline_root/serving/
gs://jt-merlin-scaling/test-e2e-pipe-v11/run-20230321-100615/pipeline_root/trainer/


## Submit Pipeline Job

### Pipeline arguments

In [51]:
# IMAGES
# TRAIN_IMAGE_URI is intentionally NOT set here; it must already be defined by an
# earlier cell. Fail with an explicit message on a fresh kernel rather than a bare
# NameError at the print below.
if 'TRAIN_IMAGE_URI' not in globals():
    raise NameError(
        "TRAIN_IMAGE_URI is undefined; run the earlier cell that defines the training image URI first"
    )
SERVING_IMAGE_URI='gcr.io/hybrid-vertex/mm2t-vertex-serv-v28'

# DOCKERFILES (names relative to the copies uploaded to PIPELINE_ROOT_PATH)
TRAIN_DOCKERFILE_NAME = f'Dockerfile.{TRAIN_DOCKERNAME}'
SERVE_DOCKERFILE_NAME = f'Dockerfile.{SERVING_DOCKERNAME}'

print(f"TRAIN_DOCKERFILE_NAME: {TRAIN_DOCKERFILE_NAME}")
print(f"SERVE_DOCKERFILE_NAME: {SERVE_DOCKERFILE_NAME}\n")
print(f"TRAIN_IMAGE_URI: {TRAIN_IMAGE_URI}")
print(f"SERVING_IMAGE_URI: {SERVING_IMAGE_URI}\n")

TRAIN_DOCKERFILE_NAME: Dockerfile.train
SERVE_DOCKERFILE_NAME: Dockerfile.mm-query-serve

TRAIN_IMAGE_URI: gcr.io/hybrid-vertex/train-2212v16-vertex-merlin-tf-2tower-jtv34
SERVING_IMAGE_URI: gcr.io/hybrid-vertex/mm2t-vertex-serv-v28



In [52]:
vpc_network_name = 'ucaip-haystack-vpc-network'

# Compute Engine default service account, derived from PROJECT_NUM (computed at the
# top of the notebook) instead of hard-coding this project's number.
VERTEX_SA = f'{PROJECT_NUM}-compute@developer.gserviceaccount.com'

job = vertex_ai.PipelineJob(
    display_name=PIPELINE_NAME,
    template_path=PIPELINES_FILEPATH,
    pipeline_root=PIPELINE_ROOT_PATH,
    failure_policy='fast', # slow | fast
    # enable_caching is left unset so the per-task .set_caching_options(...)
    # choices compiled into the pipeline spec apply.
    parameter_values={
        'project': PROJECT_ID,
        'project_number': PROJECT_NUM,
        'location': LOCATION,
        'vpc_network_name': vpc_network_name,
        'version': VERSION,
        'train_output_gcs_bucket': OUTPUT_BUCKET,
        'pipe_gcs_path': PIPELINE_ROOT_PATH,
        'training_image_uri': TRAIN_IMAGE_URI,
        'serving_image_uri': SERVING_IMAGE_URI,
        'train_docker_name': TRAIN_DOCKERFILE_NAME,
        'serving_docker_name': SERVE_DOCKERFILE_NAME,
        'tb_resource': '',
        'batch_size': BATCH_SIZE,
        'train_epochs': NUM_EPOCHS,
        'train_dir': TRAIN_DIR,
        'valid_dir': VALID_DIR,
        'workflow_dir': WORKFLOW_DIR,
        'experiment_name': EXPERIMENT_NAME,
        'experiment_run': RUN_NAME,
        'service_account': VERTEX_SA,
        'embeddings_dim': 128,
        'layer_sizes': LAYERS,
        'worker_pool_specs': WORKER_POOL_SPECS,
        'test_instance_gcs_blob_name': INSTANCE_FILE_GCS_OBJ,
    },
)

# Submit asynchronously (sync=False returns immediately). The network must be given
# as projects/{project_number}/global/networks/{name} for VPC-peered execution.
job.run(
    sync=False,
    service_account=VERTEX_SA,
    network=f'projects/{PROJECT_NUM}/global/networks/{vpc_network_name}'
)