# Deploy T4Rec model on Triton-based Vertex endpoint

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

## Setup

### get project vars

In [3]:
GCP_PROJECTS = !gcloud config get-value project
PROJECT_ID = GCP_PROJECTS[0]
PROJECT_NUM = !gcloud projects list --filter="$PROJECT_ID" --format="value(PROJECT_NUMBER)"
PROJECT_NUM = PROJECT_NUM[0]
LOCATION = 'us-central1'
REGION = "us-central1"

# VERTEX_SA = '934903580331-compute@developer.gserviceaccount.com'
VERTEX_SA = 'jt-vertex-sa@hybrid-vertex.iam.gserviceaccount.com'

print(f"PROJECT_ID: {PROJECT_ID}")
print(f"PROJECT_NUM: {PROJECT_NUM}")
print(f"LOCATION: {LOCATION}")
print(f"REGION: {REGION}")
print(f"VERTEX_SA: {VERTEX_SA}")

PROJECT_ID: hybrid-vertex
PROJECT_NUM: 934903580331
LOCATION: us-central1
REGION: us-central1
VERTEX_SA: jt-vertex-sa@hybrid-vertex.iam.gserviceaccount.com


### get workspace vars

In [4]:
# INPUT_DATA_DIR = os.environ.get("INPUT_DATA_DIR", "/workspace/data")
# OUTPUT_DIR = os.environ.get("OUTPUT_DIR", f"{INPUT_DATA_DIR}/sessions_by_day")
# model_path= os.environ.get("model_path", f"{INPUT_DATA_DIR}/saved_model")

REPO_WORKSPACE = 'workspace'

DATA_DIR = 'data'
INPUT_DATA_DIR=f'{REPO_WORKSPACE}/{DATA_DIR}'
TRANSFORMED_WORKFLOW=f'{INPUT_DATA_DIR}/processed_nvt'
OUTPUT_DIR=f'{INPUT_DATA_DIR}/sessions_by_day'
MODEL_PATH = f'{INPUT_DATA_DIR}/saved_model'
ENSEMBLE_MODEL_PATH = f'{INPUT_DATA_DIR}/models'

print(f"INPUT_DATA_DIR: {INPUT_DATA_DIR}")
print(f"TRANSFORMED_WORKFLOW: {TRANSFORMED_WORKFLOW}")
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
print(f"MODEL_PATH: {MODEL_PATH}")
print(f"ENSEMBLE_MODEL_PATH: {ENSEMBLE_MODEL_PATH}")

INPUT_DATA_DIR: workspace/data
TRANSFORMED_WORKFLOW: workspace/data/processed_nvt
OUTPUT_DIR: workspace/data/sessions_by_day
MODEL_PATH: workspace/data/saved_model
ENSEMBLE_MODEL_PATH: workspace/data/models


In [5]:
!tree $ENSEMBLE_MODEL_PATH

[01;34mworkspace/data/models[00m
├── [01;34m0_predictpytorchtriton[00m
│   ├── [01;34m1[00m
│   │   └── model.pt
│   └── config.pbtxt
└── [01;34mensemble_model[00m
    ├── [01;34m1[00m
    └── config.pbtxt

4 directories, 3 files


### set deployment version

In [6]:
VERSION='jvt02'
MODEL_VERSION='v02'

In [11]:
BUCKET_NAME=f'merlin-transformers4rec-{VERSION}'
BUCKET_URI=f'gs://{BUCKET_NAME}'

MODEL_ARTIFACTS_REPO_GCS = f"{BUCKET_URI}/{MODEL_VERSION}/workspace/data/models"
WORKFLOW_REPO_GCS = f"{BUCKET_URI}/{MODEL_VERSION}/workspace/data/workflow_etl"

print(f"BUCKET_URI: {BUCKET_URI}")
print(f"MODEL_ARTIFACTS_REPO_GCS: {MODEL_ARTIFACTS_REPO_GCS}")
print(f"WORKFLOW_REPO_GCS: {WORKFLOW_REPO_GCS}")

MODEL_ARTIFACTS_REPO_GCS: gs://merlin-transformers4rec-jvt02/v02/workspace/data/models
WORKFLOW_REPO_GCS: gs://merlin-transformers4rec-jvt02/v02/workspace/data/workflow_etl


### triton credentials

* see [model_repository user guide](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_repository.md#cloud-storage-with-environment-variables) re: GCS environment variables and `TRITON_CLOUD_CREDENTIAL_PATH`

In [None]:
from google.oauth2 import service_account

# t4rec-nvidia-docs/credentials.json
credentials = service_account.Credentials.from_service_account_file('credentials.json')

# Test Endpoint

In [14]:
ENDPOINT_ID='7747893403077050368'
endpoint_name = f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{ENDPOINT_ID}"
endpoint_name

'projects/hybrid-vertex/locations/us-central1/endpoints/7747893403077050368'

In [15]:
import json
import os
from pathlib import Path

import numpy as np
import requests
from google.api import httpbody_pb2
from google.cloud import aiplatform_v1 as gapic

import nvtabular as nvt

from merlin.io import Dataset
from merlin.core.dispatch import make_df  # noqa
from transformers4rec.torch.utils.data_utils import MerlinDataLoader

import tritonclient.http as triton_http # client
import nvtabular.inference.triton as nvt_triton
import tritonclient.grpc as grpcclient

import torch 

  from .autonotebook import tqdm as notebook_tqdm


## get nvtabular schema

In [16]:
from merlin_standard_lib import Schema
# SCHEMA_PATH = os.environ.get("INPUT_SCHEMA_PATH", "/workspace/data/processed_nvt/schema.pbtxt")

SCHEMA_PATH = f'{TRANSFORMED_WORKFLOW}/schema.pbtxt'
schema = Schema().from_proto_text(SCHEMA_PATH)

In [232]:
schema.column_names

['age_days-list', 'weekday_sin-list', 'item_id-list', 'category-list']

### get test data

In [18]:
TEST_PATHS=f'{OUTPUT_DIR}/1/test.parquet'
dataset = Dataset(TEST_PATHS)

In [19]:
sparse_max = {
    'age_days-list': 20,
    'weekday_sin-list': 20,
    'item_id-list': 20,
    'category-list': 20
}

from transformers4rec.torch.utils.data_utils import MerlinDataLoader

def generate_dataloader(schema, dataset, batch_size=128, seq_length=20):
    loader = MerlinDataLoader.from_schema(
            schema,
            dataset,
            batch_size=batch_size,
            max_sequence_length=seq_length,
            shuffle=False,
            sparse_as_dense=True,
            sparse_max=sparse_max
        )
    return loader

loader = generate_dataloader(schema, dataset)
test_dict = next(iter(loader))
# test_dict[0]

### load model locally

In [20]:
import cloudpickle
loaded_model = cloudpickle.load(
    open(os.path.join(MODEL_PATH, "t4rec_model_class.pkl"), "rb")
)

model = loaded_model.cuda()
# model.eval()

In [21]:
traced_model = torch.jit.trace(model, test_dict[0], strict=True)

input_schema = model.input_schema
output_schema = model.output_schema



## create test instances payload

**TODO**
* update and parameterize

In [None]:
# projects/934903580331/locations/us-central1/endpoints/7747893403077050368
# v03 'projects/hybrid-vertex/locations/us-central1/endpoints/7043924486323699712'

In [23]:
from pprint import pprint

a = test_dict[0]['age_days-list'].cuda().cpu().clone().numpy()
b = test_dict[0]['weekday_sin-list'].cuda().cpu().clone().numpy()
c = test_dict[0]['category-list'].cuda().cpu().clone().numpy()
d = test_dict[0]['item_id-list'].cuda().cpu().clone().numpy()

# payload_dict = {}
# payload_dict['age_days-list']=a
# payload_dict['weekday_sin-list']=b
# payload_dict['category-list']=c
# payload_dict['item_id-list']=d

a2 = a[0:32]
b2 = b[0:32]
c2 = c[0:32]
d2 = d[0:32]

In [45]:
len(a2[0])

20

### small payload

In [27]:
SINGLE_PAYLOAD_LOCAL_FILENAME='single_t4rec_payload.json'

smaller_payload = {
    # 'id': '1',
    "inputs":[
        {
            "name": "age_days-list",
            "shape": [1,20],
            "datatype": 'FP32', #"TYPE_FP32",
            "data":a2[0].tolist()
        },
        {
            "name": "weekday_sin-list",
            "shape": [1,20],
            "datatype": 'FP32', #"TYPE_FP32",
            "data":b2[0].tolist() #b[0].tolist() #[float(s) for s in b[0].tolist()]
        },
        {
            "name": "item_id-list",
            "shape": [1,20],
            "datatype": 'INT64', #"TYPE_INT64",
            "data":c2[0].tolist() #c[0].tolist() #[int(s) for s in c[0].tolist()]
        },
        {
            "name": "category-list",
            "shape": [1,20],
            "datatype": 'INT64',# "TYPE_INT64",
            "data":d2[0].tolist() #d[0].tolist() #[int(s) for s in d[0].tolist()]
        },
    ],
    "outputs": [
        {
            "name": "next-item",
            "parameters":{"binary_data":False}
        }
    ]
}

with open(SINGLE_PAYLOAD_LOCAL_FILENAME, 'w') as f:
    json.dump(smaller_payload, f)

In [29]:
! cat ./$SINGLE_PAYLOAD_LOCAL_FILENAME

{"inputs": [{"name": "age_days-list", "shape": [1, 20], "datatype": "FP32", "data": [0.7424173355102539, 0.7332683205604553, 0.09815418720245361, 0.4467461109161377, 0.6173868775367737, 0.11097356677055359, 0.4968095123767853, 0.4523400366306305, 0.3517908453941345, 0.29413312673568726, 0.7820895314216614, 0.21377669274806976, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}, {"name": "weekday_sin-list", "shape": [1, 20], "datatype": "FP32", "data": [0.27244293689727783, 0.469150185585022, 0.7858142256736755, 0.45554742217063904, 0.49801334738731384, 0.9712715744972229, 0.16286295652389526, 0.6930072903633118, 0.6746281385421753, 0.18748827278614044, 0.36374929547309875, 0.34085774421691895, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}, {"name": "item_id-list", "shape": [1, 20], "datatype": "INT64", "data": [7, 2, 4, 1, 4, 2, 10, 25, 2, 2, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0]}, {"name": "category-list", "shape": [1, 20], "datatype": "INT64", "data": [26, 9, 16, 3, 15, 7, 47, 117, 10, 10, 1, 7, 0, 0, 0, 0

### larger payload

> this doesnt work, figure out best way to send *batches* of prediction requests

In [49]:
# BATCH_PAYLOAD_LOCAL_FILENAME='t4rec_payload.json'

# payload = {
#     # 'id': '1',
#     "inputs":[
#         {
#             "name": "age_days-list",
#             "shape": [32,20],
#             "datatype": 'FP32', #"TYPE_FP32",
#             "data":a2.tolist()
#         },
#         {
#             "name": "weekday_sin-list",
#             "shape": [32,20],
#             "datatype": 'FP32', #"TYPE_FP32",
#             "data":b2.tolist() #b[0].tolist() #[float(s) for s in b[0].tolist()]
#         },
#         {
#             "name": "item_id-list",
#             "shape": [32,20],
#             "datatype": 'INT64', #"TYPE_INT64",
#             "data":c2.tolist() #c[0].tolist() #[int(s) for s in c[0].tolist()]
#         },
#         {
#             "name": "category-list",
#             "shape": [32,20],
#             "datatype": 'INT64',# "TYPE_INT64",
#             "data":d2.tolist() #d[0].tolist() #[int(s) for s in d[0].tolist()]
#         },
#     ],
#     "outputs": [
#         {
#             "name": "next-item",
#             "parameters":{"binary_data":False}
#         }
#     ]
# }

# with open(BATCH_PAYLOAD_LOCAL_FILENAME, 'w') as f:
#     json.dump(payload, f)

In [32]:
# ! cat ./$BATCH_PAYLOAD_LOCAL_FILENAME

## create endpoint prediction requests

In [66]:
import time

def get_triton_prediction_vertex(
    model_name, 
    endpoint_name,
    schema,
    instances_dict,
    local_payload_json,
    api_endpoint=f"{REGION}-aiplatform.googleapis.com",
    headers=None,
    seq_length=20,
    eval_batch_size=32,
):
    # set up vertex ai prediction client
    client_options = {"api_endpoint": api_endpoint}
    gapic_client = gapic.PredictionServiceClient(client_options=client_options)
    
    # to provide json of instances
    with open(local_payload_json) as f: # single_t4rec_payload | t4rec_payload
        http_body = httpbody_pb2.HttpBody(
            data=f.read().encode("utf-8"),
            content_type="application/json",
        )
    
    # print(f"http_body: {http_body}")
    # print(f"data: {data}")
    
    # submit inference request
    request = gapic.RawPredictRequest(
        endpoint=endpoint_name,
        http_body=http_body
    )
    
    start=time.time()
    # print(f"request: {request}")
    response = gapic_client.raw_predict(
        request=request, 
        metadata=headers,
        # endpoint=endpoint_name
    )
    # capture elapsed time
    end = time.time()
    elapsed = end - start
    elapsed = round(elapsed, 4)
    print(f'inference latency: {elapsed} seconds')
    
    # get result as json
    result_http = json.loads(response.data.decode('utf-8'))
    print(f"response: {result_http['outputs'][0]['data']}")

In [67]:
# ENDPOINT_ID='7747893403077050368' # v02 
# ENDPOINT_ID='7043924486323699712' # v03
ENDPOINT_ID='2776763839390154752' # v03-v2
endpoint_name = f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{ENDPOINT_ID}"
endpoint_name

'projects/hybrid-vertex/locations/us-central1/endpoints/2776763839390154752'

In [68]:
for model_name in ["0_predictpytorchtriton"]:
    print(f"Predictions from model: {model_name}")
    headers =  {"x-vertex-ai-triton-redirect": f"v2/models/{model_name}/infer"}
    get_triton_prediction_vertex(
        model_name, 
        endpoint_name,
        schema,
        test_dict,
        local_payload_json=SINGLE_PAYLOAD_LOCAL_FILENAME, #BATCH_PAYLOAD_LOCAL_FILENAME
        headers=tuple(
            headers.items()
        )
    )
    print("-"*5)

Predictions from model: 0_predictpytorchtriton
inference latency: 0.1161 seconds
response: [-9.414483070373535, -3.3946754932403564, -3.5153329372406006, -3.325957775115967, -3.4902701377868652, -3.414525032043457, -3.5283923149108887, -3.5137453079223633, -3.4412429332733154, -3.624793529510498, -3.677654504776001, -3.5000057220458984, -3.650045871734619, -3.683887481689453, -3.886939764022827, -3.709595203399658, -3.735921621322632, -3.9285502433776855, -4.010989189147949, -3.936932325363159, -4.215701580047607, -4.093976020812988, -4.144990921020508, -3.9316093921661377, -4.249298095703125, -4.173276901245117, -4.337774276733398, -4.512059211730957, -4.223840236663818, -4.338855266571045, -4.346428394317627, -4.538280963897705, -4.724963188171387, -4.690615653991699, -4.761604309082031, -4.527646064758301, -4.78475284576416, -4.612557411193848, -4.800440788269043, -4.953672885894775, -4.790022850036621, -5.047253608703613, -4.920400619506836, -4.958305835723877, -4.863715648651123, 

# WIP

### and local testing

In [69]:
# WORKFLOW_DIR = 'gs://merlin-transformers4rec-jvt02/v02/workspace/data/workflow_etl'
# workflow = nvt.Workflow.load(WORKFLOW_DIR)
# workflow.output_schema

#### [1] calling `rawPredict` using Vertex AI SDK

In [55]:
# initialize service client
client_options = {"api_endpoint": f"{REGION}-aiplatform.googleapis.com"}
prediction_client = gapic.PredictionServiceClient(client_options=client_options)

In [61]:
# TRAIN_PATHS=f'{OUTPUT_DIR}/1/train.parquet'
# dataset = Dataset(TRAIN_PATHS)

In [66]:
# import cloudpickle
# loaded_model = cloudpickle.load(
#     open(os.path.join(MODEL_PATH, "t4rec_model_class.pkl"), "rb")
# )

In [68]:
# model = loaded_model.cuda()
# model.eval()

In [70]:
# traced_model = torch.jit.trace(model, train_dict[0], strict=True)

In [71]:
# input_schema = model.input_schema
# output_schema = model.output_schema

In [72]:
# input_schema

In [73]:
# df_cols = {}
# for name, tensor in train_dict[0].items():
#     if name in input_schema.column_names:
#         df_cols[name] = tensor.cpu().numpy()
#         if len(tensor.shape) > 1:
#             df_cols[name] = list(df_cols[name])
            
# df = make_df(df_cols)
# print(df.shape)
# df.head()

In [74]:
# batch = df.iloc[:50,:]
# batch.head()

In [76]:
# from merlin.systems.triton.utils import send_triton_request
# response = send_triton_request(input_schema, df[input_schema.column_names], output_schema.column_names)
# print(response)

{'next-item': array([[ -9.414465 ,  -3.3946402,  -3.5153313, ...,  -9.280334 ,
         -9.484825 , -10.178284 ],
       [ -9.414481 ,  -3.3946645,  -3.5153518, ...,  -9.280359 ,
         -9.484777 , -10.178347 ],
       [ -9.414499 ,  -3.3946452,  -3.5154536, ...,  -9.280445 ,
         -9.484086 , -10.179337 ],
       ...,
       [ -9.414593 ,  -3.394685 ,  -3.5153685, ...,  -9.280357 ,
         -9.484976 , -10.178085 ],
       [ -9.414569 ,  -3.3946962,  -3.515378 , ...,  -9.280379 ,
         -9.484813 , -10.178307 ],
       [ -9.414568 ,  -3.3946936,  -3.5153277, ...,  -9.280327 ,
         -9.485282 , -10.177661 ]], dtype=float32)}


In [198]:
# inputs = nvt_triton.convert_df_to_triton_input(df.columns, df, grpcclient.InferInput)
# inputs

[<tritonclient.grpc.InferInput at 0x7fe0bd7b61f0>,
 <tritonclient.grpc.InferInput at 0x7fe0bc98e1c0>,
 <tritonclient.grpc.InferInput at 0x7fe0b57beaf0>,
 <tritonclient.grpc.InferInput at 0x7fe0b57c69a0>,
 <tritonclient.grpc.InferInput at 0x7fe0b57c6d00>,
 <tritonclient.grpc.InferInput at 0x7fe0b57c6bb0>,
 <tritonclient.grpc.InferInput at 0x7fe0b57fec70>,
 <tritonclient.grpc.InferInput at 0x7fe0b577b3a0>]

In [None]:
# # format payload
# http_body = httpbody_pb2.HttpBody(
#     data=open(payload_file).read().encode("utf-8"),
#     content_type="application/json",
# )

# # Initialize request argument(s)
# request = gapic.RawPredictRequest(endpoint=endpoint_name, http_body=http_body)

## curl

In [81]:
# INPUT_FILE="./instances.json"

In [82]:
! curl \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json" \
--data-binary @$INPUT_FILE \
https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint.name}:rawPredict

{"error":"failed to parse the request JSON buffer: Invalid value. at 19125"}

In [75]:
uri = f'https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint.name}:rawPredict'

! curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json"  \
{uri} \
-d @t4rec_payload.json

## payload

* TODO: fix

In [302]:
# a = train_dict[0]['age_days-list'].cuda().cpu().clone().numpy()
# b = train_dict[0]['weekday_sin-list'].cuda().cpu().clone().numpy()
# c = train_dict[0]['category-list'].cuda().cpu().clone().numpy()
# d = train_dict[0]['item_id-list'].cuda().cpu().clone().numpy()

In [109]:
# from pprint import pprint

# # pprint(train_dict[0])

# payload_dict = {}
# payload_dict['age_days-list']=a
# payload_dict['weekday_sin-list']=b
# payload_dict['category-list']=c
# payload_dict['item_id-list']=d

# a2 = a[0:32]
# b2 = b[0:32]
# c2 = c[0:32]
# d2 = d[0:32]

In [110]:
payload_dict

{'age_days-list': array([[0.48491085, 0.86484003, 0.13697912, 0.62527317, 0.6846858 ,
         0.05689086, 0.95180523, 0.31664136, 0.81627715, 0.12000247,
         0.5980053 , 0.3159475 , 0.45607874, 0.9074921 , 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ]],
       dtype=float32),
 'weekday_sin-list': array([[0.01403622, 0.3720214 , 0.5400261 , 0.6242002 , 0.8294813 ,
         0.93150216, 0.04658586, 0.79152614, 0.5543418 , 0.45621887,
         0.56755877, 0.2791969 , 0.8535643 , 0.56864387, 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ]],
       dtype=float32),
 'category-list': array([[10,  8, 26,  3, 10,  4,  1, 23, 20,  6,  5,  4,  1,  4,  0,  0,
          0,  0,  0,  0]]),
 'item_id-list': array([[ 49,  36, 130,  20,  48,  15,   3, 126,  96,  31,  24,  17,   1,
          18,   0,   0,   0,   0,   0,   0]])}

#### generate some payloads

In [180]:
payload_file = "payload_bash_ensemble.json"

In [181]:
%%bash -s $payload_file

payload_file=$1
echo "Writing $payload_file"

cat <<EOF > $payload_file
{
    'id': '1',
    "inputs":[
        {
            "name": "age_days-list",
            "shape": [32,20],
            "datatype": 'FP32', #"TYPE_FP32",
            "data":a2.tolist()
        },
        {
            "name": "weekday_sin-list",
            "shape": [32,20],
            "datatype": 'FP32', #"TYPE_FP32",
            "data":b2.tolist() #b[0].tolist() #[float(s) for s in b[0].tolist()]
        },
        {
            "name": "item_id-list",
            "shape": [32,20],
            "datatype": 'INT64', #"TYPE_INT64",
            "data":c2.tolist() #c[0].tolist() #[int(s) for s in c[0].tolist()]
        },
        {
            "name": "category-list",
            "shape": [32,20],
            "datatype": 'INT64',# "TYPE_INT64",
            "data":d2.tolist() #d[0].tolist() #[int(s) for s in d[0].tolist()]
        }
    ]
}
EOF

Writing payload_ensemble.json


In [182]:
# format payload
http_body = httpbody_pb2.HttpBody(
    data=open(payload_file).read().encode("utf-8"),
    content_type="application/json",
)

print(http_body)

# Initialize request argument(s)
request = gapic.RawPredictRequest(endpoint=endpoint_name, http_body=http_body)

content_type: "application/json"
data: "{\n    \'id\': \'1\',\n    \"inputs\":[\n        {\n            \"name\": \"age_days-list\",\n            \"shape\": [32,20],\n            \"datatype\": \'FP32\', #\"TYPE_FP32\",\n            \"data\":a2.tolist()\n        },\n        {\n            \"name\": \"weekday_sin-list\",\n            \"shape\": [32,20],\n            \"datatype\": \'FP32\', #\"TYPE_FP32\",\n            \"data\":b2.tolist() #b[0].tolist() #[float(s) for s in b[0].tolist()]\n        },\n        {\n            \"name\": \"item_id-list\",\n            \"shape\": [32,20],\n            \"datatype\": \'INT64\', #\"TYPE_INT64\",\n            \"data\":c2.tolist() #c[0].tolist() #[int(s) for s in c[0].tolist()]\n        },\n        {\n            \"name\": \"category-list\",\n            \"shape\": [32,20],\n            \"datatype\": \'INT64\',# \"TYPE_INT64\",\n            \"data\":d2.tolist() #d[0].tolist() #[int(s) for s in d[0].tolist()]\n        }\n    ]\n}\n"



In [42]:
# # Make the prediction request
# response = prediction_client.raw_predict(request=request)
# result = json.loads(response.data)

# print(result)

In [None]:
def get_triton_prediction_vertex(
    model_name,
    endpoint_name,
    api_endpoint=f"{REGION}-aiplatform.googleapis.com",
    headers=None,
):
    # set up vertex ai prediction client
    client_options = {"api_endpoint": api_endpoint}
    gapic_client = gapic.PredictionServiceClient(client_options=client_options)

    # generate example data to classify
    features = 4
    samples = 1
    data = np.random.rand(samples, features).astype("float32")

    # payload configuration defining input and output names
    payload_config = {
        "0_predictpytorchtriton": {
            "input": "age_days-list",
            "input": "weekday_sin-list",
            "input": "item_id-list",
            "input": "category-list",
            "output": "next-item"
        },
    }

    # get input and output names based on model name
    input_name = payload_config[model_name]["input"]
    output_name = payload_config[model_name]["output"]

    # set up Triton input and output objects for HTTP
    triton_input_http = triton_http.InferInput(input_name, (samples, features), "FP32")
    triton_input_http.set_data_from_numpy(data, binary_data=False)

    if isinstance(output_name, list):
        triton_output_http = [
            triton_http.InferRequestedOutput(output, binary_data=False)
            for output in output_name
        ]

    else:
        triton_output_http = [
            triton_http.InferRequestedOutput(output_name, binary_data=False)
        ]

    # create inference request
    _data, _ = triton_http._get_inference_request(
        inputs=[triton_input_http],
        outputs=triton_output_http,
        request_id="",
        sequence_id=0,
        sequence_start=False,
        sequence_end=False,
        priority=0,
        timeout=None,
    )
    http_body = httpbody_pb2.HttpBody(
        data=_data.encode("utf-8"), content_type="application/json"
    )
    print(f"request: {data}")
    # submit inference request
    request = gapic.RawPredictRequest(endpoint=endpoint_name, http_body=http_body)
    response = gapic_client.raw_predict(request=request, metadata=headers)
    # get result as json
    result_http = json.loads(response.data.decode("utf-8"))
    print(f"response: {result_http['outputs'][0]['data']}")