# Deploying a trained agent to a Vertex AI endpoint

> **TODO**

## Create custom prediction container

As with training, create a custom prediction container. This container handles the TF-Agents-specific logic that a regular TensorFlow model does not need: it uses a trained policy to find the predicted action. The associated source code is in `src/prediction/`.
For other options, see [getting predictions from Vertex AI](https://cloud.google.com/vertex-ai/docs/predictions/getting-predictions).

#### Serve predictions:
- Use [`tensorflow.saved_model.load`](https://www.tensorflow.org/agents/api_docs/python/tf_agents/policies/PolicySaver#usage), instead of [`tf_agents.policies.policy_loader.load`](https://github.com/tensorflow/agents/blob/r0.8.0/tf_agents/policies/policy_loader.py#L26), to load the trained policy, because the latter produces an object of type [`SavedModelPyTFEagerPolicy`](https://github.com/tensorflow/agents/blob/402b8aa81ca1b578ec1f687725d4ccb4115386d2/tf_agents/policies/py_tf_eager_policy.py#L137) whose `action()` is not compatible for use here.
- Prediction requests contain only observation data, not reward. Each prediction is a standalone request that doesn't require prior knowledge of the system state, and end users only know what they observe at the moment. Reward arrives only after an action has been taken, so end users cannot supply it. To handle a prediction request, create a [`TimeStep`](https://www.tensorflow.org/agents/api_docs/python/tf_agents/trajectories/TimeStep) object (consisting of `observation`, `reward`, `discount`, and `step_type`) with the [`restart()`](https://www.tensorflow.org/agents/api_docs/python/tf_agents/trajectories/restart) function, which takes an `observation`. This function creates the *first* `TimeStep` in a trajectory, where reward is 0, discount is 1, and `step_type` marks the first time step. In other words, each prediction request forms the first `TimeStep` in a brand new trajectory.
- In the prediction response, avoid NumPy-typed values. Convert them to native Python values with [`tolist()`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.tolist.html) rather than `list()`, because `list()` leaves the individual elements as NumPy scalars.
- `src/prediction` contains a prestart script, which the FastAPI base image runs before starting the server. The script sets the `PORT` environment variable to `AIP_HTTP_PORT` so that FastAPI listens on the port Vertex AI expects.
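
The `tolist()` versus `list()` point above can be illustrated with a minimal sketch. The `actions` array is a hypothetical stand-in for `policy_step.action.numpy()`, and the standard `json` module stands in for the server's response encoding, so treat this as an illustration rather than the server's exact behavior.

```python
import json

import numpy as np

# Hypothetical stand-in for `policy_step.action.numpy()` on a batch of two.
actions = np.array([3, 1], dtype=np.int32)

# list() only unwraps the outer array; the elements stay NumPy scalars.
as_list = list(actions)
# tolist() converts recursively to built-in Python ints.
as_native = actions.tolist()

print(type(as_list[0]).__name__)    # int32
print(type(as_native[0]).__name__)  # int

try:
    json.dumps({"predictions": as_list})
    list_is_serializable = True
except TypeError:
    # The json module does not know how to encode numpy.int32.
    list_is_serializable = False

# The tolist() result serializes without any custom encoder.
payload = json.dumps({"predictions": as_native})
print(list_is_serializable, payload)
```

Calling `tolist()` on the array also handles nested (multi-dimensional) actions, which `list()` would leave as a list of NumPy arrays.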

In [None]:
# PRED_SUBFOLDER = 'prediction'

In [None]:
# Make the prediction subfolder
# ! rm -rf {REPO_DOCKER_PATH_PREFIX}/{PRED_SUBFOLDER}
# ! mkdir {REPO_DOCKER_PATH_PREFIX}/{PRED_SUBFOLDER}

In [None]:
# %%writefile {REPO_DOCKER_PATH_PREFIX}/{PRED_SUBFOLDER}/main.py
# # Copyright 2021 Google LLC
# #
# # Licensed under the Apache License, Version 2.0 (the "License");
# # you may not use this file except in compliance with the License.
# # You may obtain a copy of the License at
# #
# #      http://www.apache.org/licenses/LICENSE-2.0
# #
# # Unless required by applicable law or agreed to in writing, software
# # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# # See the License for the specific language governing permissions and
# # limitations under the License.

# """Prediction server that uses a trained policy to give predicted actions."""
# import os

# from fastapi import FastAPI
# from fastapi import Request

# import tensorflow as tf
# import tf_agents


# app = FastAPI()
# _model = tf.compat.v2.saved_model.load(os.environ["AIP_STORAGE_URI"])


# @app.get(os.environ["AIP_HEALTH_ROUTE"], status_code=200)
# def health():
#     """
#     Handles server health check requests.

#     Returns:
#       An empty dict.
#     """
#     return {}


# @app.post(os.environ["AIP_PREDICT_ROUTE"])
# async def predict(request: Request):
#     """
#     Handles prediction requests.

#     Unpacks observations in prediction requests and queries the trained policy for
#     predicted actions.

#     Args:
#       request: Incoming prediction requests that contain observations.

#     Returns:
#       A dict with the key `predictions` mapping to a list of predicted actions
#       corresponding to each observation in the prediction request.
#     """
#     body = await request.json()
#     instances = body["instances"]

#     predictions = []
#     for index, instance in enumerate(instances):
#         # Unpack request body and reconstruct TimeStep. Rewards default to 0.
#         batch_size = len(instance["observation"])
        
#         time_step = tf_agents.trajectories.restart(
#             observation=instance["observation"],
#             batch_size=tf.convert_to_tensor([batch_size]),
#         )
#         policy_step = _model.action(time_step)

#         predictions.append(
#             {f"PolicyStep {index}": policy_step.action.numpy().tolist()}
#         )

#     return {
#         "predictions": predictions
#     }
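
The request and response shapes handled by `predict()` above can be sketched from the client side. This is a minimal illustration under assumptions: `build_request` and `extract_actions` are hypothetical helper names, and the observation (a batch of one three-element feature vector) is made up, since the real observation structure depends on the trained policy.

```python
import json


def build_request(observations):
    """Wrap each batched observation as one instance, as `predict()` reads it."""
    return {"instances": [{"observation": obs} for obs in observations]}


def extract_actions(response):
    """Collect the actions from the per-instance dicts keyed 'PolicyStep <index>'."""
    actions = []
    for index, prediction in enumerate(response["predictions"]):
        actions.append(prediction[f"PolicyStep {index}"])
    return actions


# Hypothetical observation: one instance holding a batch of one feature vector.
request_body = build_request([[[0.1, 0.2, 0.3]]])
print(json.dumps(request_body))

# Hypothetical response in the shape the server returns for that request.
sample_response = {"predictions": [{"PolicyStep 0": [2]}]}
actions = extract_actions(sample_response)
print(actions)  # [[2]]
```

Because each prediction is keyed by its position in `instances`, the client needs the index to look up the action. Returning a plain list of actions would be simpler to consume, but this sketch follows the server code as written.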

In [None]:
# %%writefile {REPO_DOCKER_PATH_PREFIX}/{PRED_SUBFOLDER}/prestart.sh
# #!/bin/bash
# export PORT=$AIP_HTTP_PORT

In [None]:
# %%writefile pred_requirements.txt
# tf-agents==0.17.0
# tensorflow==2.12.0
# numpy
# six
# typing-extensions
# pillow

In [None]:
# DOCKERNAME = 'pred'

In [None]:
# %%writefile Dockerfile_{DOCKERNAME}

# FROM tiangolo/uvicorn-gunicorn-fastapi:python3.10

# COPY src/prediction /app
# COPY pred_requirements.txt /app/requirements.txt

# RUN pip3 install -r /app/requirements.txt

In [None]:
# PREDICTION_CONTAINER = "prediction-custom-container"

# # Docker definitions for prediction
# PRED_IMAGE_URI = f'gcr.io/{PROJECT_ID}/{PREDICTION_CONTAINER}'
# MACHINE_TYPE = 'e2-highcpu-32'
# FILE_LOCATION = './'

# print(f"export DOCKERNAME={DOCKERNAME}")
# print(f"export PRED_IMAGE_URI={PRED_IMAGE_URI}")
# print(f"export FILE_LOCATION={FILE_LOCATION}")
# print(f"export MACHINE_TYPE={MACHINE_TYPE}")
# print(f"export ARTIFACTS_DIR={ARTIFACTS_DIR}")

In [None]:
# ! gcloud builds submit --config cloudbuild.yaml \
#     --substitutions _DOCKERNAME=$DOCKERNAME,_IMAGE_URI=$PRED_IMAGE_URI,_FILE_LOCATION=$FILE_LOCATION,_ARTIFACTS_DIR=$ARTIFACTS_DIR \
#     --timeout=2h \
#     --machine-type=$MACHINE_TYPE