# Scaling bandit training with Vertex AI 

**prerequisites:**
* build the training image in the `04b-build-training-image` notebook

In [1]:
# Print the installed Vertex AI SDK (google-cloud-aiplatform) version so the run can be reproduced.
! python3 -c "import google.cloud.aiplatform; print('aiplatform SDK version: {}'.format(google.cloud.aiplatform.__version__))"

aiplatform SDK version: 1.26.0


## setup notebook environment

In [2]:
# Show the working directory; the relative `sys.path.append("..")` used further down depends on it.
!pwd

/home/jupyter/tf_vertex_agents/04-perarm-features-bandit


### Load env config
* use the prefix from `00-env-setup` notebook

In [3]:
# Resource-name prefix; must match the one chosen in the `00-env-setup` notebook.
PREFIX = "mabv1"

In [4]:
# staging GCS
# Resolve the active gcloud project; the shell capture is a list of output lines,
# so the project id is taken from the first line.
GCP_PROJECTS             = !gcloud config get-value project
PROJECT_ID               = GCP_PROJECTS[0]

# GCS bucket and paths
BUCKET_NAME              = f'{PREFIX}-{PROJECT_ID}-bucket'
BUCKET_URI               = f'gs://{BUCKET_NAME}'

# Fetch the shared environment file from the bucket, echo it, then execute it so
# that its assignments (PROJECT_ID, REGION, VERTEX_SA, ...) land in this notebook's
# namespace. `config.n` is the captured output joined with newlines.
# NOTE(review): `exec` runs whatever text the bucket object contains — only point
# this at a bucket you trust. If `gsutil cat` fails, the captured text is presumably
# the error message rather than Python, and `exec` will raise; confirm before relying on it.
config = !gsutil cat {BUCKET_URI}/config/notebook_env.py
print(config.n)
exec(config.n)


PROJECT_ID               = "hybrid-vertex"
PROJECT_NUM              = "934903580331"
LOCATION                 = "us-central1"

REGION                   = "us-central1"
BQ_LOCATION              = "US"
VPC_NETWORK_NAME         = "ucaip-haystack-vpc-network"

VERTEX_SA                = "934903580331-compute@developer.gserviceaccount.com"

PREFIX                   = "mabv1"
VERSION                  = "v1"

BUCKET_NAME              = "mabv1-hybrid-vertex-bucket"
BUCKET_URI               = "gs://mabv1-hybrid-vertex-bucket"
DATA_GCS_PREFIX          = "data"
DATA_PATH                = "gs://mabv1-hybrid-vertex-bucket/data"
VOCAB_SUBDIR             = "vocabs"
VOCAB_FILENAME           = "vocab_dict.pkl"

VPC_NETWORK_FULL         = "projects/934903580331/global/networks/ucaip-haystack-vpc-network"

BIGQUERY_DATASET_ID      = "hybrid-vertex.movielens_dataset_mabv1"
BIGQUERY_TABLE_ID        = "hybrid-vertex.movielens_dataset_mabv1.training_dataset"

REPO_DOCKER_PATH_PREFIX  = "src"
RL_SUB_DIR     

In [5]:
# ! gsutil iam ch serviceAccount:{VERTEX_SA}:roles/storage.objects.get $BUCKET_URI
# ! gsutil iam ch serviceAccount:{VERTEX_SA}:roles/storage.objects.get $BUCKET_URI

### imports

In [6]:
import os

# Lower TensorFlow's C++ log verbosity; set here because TensorFlow is imported in a later cell.
os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "2"})

In [8]:
import json
from datetime import datetime
from time import time
import pandas as pd
import numpy as np

# `logging.disable(logging.WARNING)` below suppresses every log record at
# WARNING level and below (DEBUG, INFO and WARNING) for the whole process.
import logging
# NOTE: `import time` rebinds the name `time` to the module, shadowing the
# `from time import time` function imported above. Later cells rely on the
# module form (`time.strftime`).
import time
from pprint import pprint
import pickle as pkl

logging.disable(logging.WARNING)

from google.cloud import aiplatform as vertex_ai
from google.cloud import storage

In [9]:
import sys
# Put the parent directory on the import path so the local `src` package resolves
# (assumes the notebook is run from a sub-directory of the repository — verify with `!pwd`).
sys.path.append("..")

from src.per_arm_rl import data_utils
from src.per_arm_rl import train_utils
from src.per_arm_rl import data_config

In [10]:
# GCS client (used later to list data files) and Vertex AI SDK defaults,
# both bound to the project / region loaded from the env config.
storage_client = storage.Client(project=PROJECT_ID)

vertex_ai.init(
    project=PROJECT_ID,
    location=REGION,
)

# Vertex Training Job

## job compute

Set the variable `TRAIN_COMPUTE` to configure the compute resources for the VMs you will use for training.

**Machine Type:**
* `n1-standard`: 3.75GB of memory per vCPU.
* `n1-highmem`: 6.5GB of memory per vCPU
* `n1-highcpu`: 0.9 GB of memory per vCPU
* `vCPUs`: number of `[2, 4, 8, 16, 32, 64, 96 ]`

**Note:** The following is not supported for training:

* `standard`: 2 vCPUs
* `highcpu`: 2, 4 and 8 vCPUs

> Note: You may also use n2 and e2 machine types for training and deployment, but they do not support GPUs.

relevant docs: 
* [Configure compute resources for training](https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types) for more details
* [Machine series comparison](https://cloud.google.com/compute/docs/machine-resource#machine_type_comparison)

In [11]:
# Accelerator choice for the training workers: "a100" | "t4" | None.
# The value is normalised to a string immediately, so `None` becomes "None".
USE_GPU = str("t4")
print(f"USE_GPU: {USE_GPU}")

USE_GPU: t4


In [12]:
# Translate the accelerator choice into worker-pool hardware settings.
#
# USE_GPU is stringified where it is set, so the documented CPU-only choice of
# `None` arrives here as the string "None". The original branch only matched
# "False", which left every variable below undefined for `None`; "False" (and
# the raw None/False values) are still accepted for backward compatibility.
_CPU_ONLY_CHOICES = ("None", "False", None, False)

# Settings that are identical for every configuration.
REPLICA_COUNT = 1
DISTRIBUTE_STRATEGY = 'single'
REDUCTION_SERVER_COUNT = 0
REDUCTION_SERVER_MACHINE_TYPE = "n1-highcpu-16"

if USE_GPU == "a100":
    WORKER_MACHINE_TYPE = 'a2-highgpu-1g'
    ACCELERATOR_TYPE = 'NVIDIA_TESLA_A100'
    PER_MACHINE_ACCELERATOR_COUNT = 1
elif USE_GPU == 't4':
    WORKER_MACHINE_TYPE = 'n1-standard-16'
    ACCELERATOR_TYPE = 'NVIDIA_TESLA_T4' # NVIDIA_TESLA_T4 NVIDIA_TESLA_V100
    PER_MACHINE_ACCELERATOR_COUNT = 1
elif USE_GPU in _CPU_ONLY_CHOICES:
    WORKER_MACHINE_TYPE = 'n2-highmem-32' # 'n1-highmem-96' | 'n2-highmem-92'
    ACCELERATOR_TYPE = None
    PER_MACHINE_ACCELERATOR_COUNT = 0
else:
    # Fail here with a clear message instead of a NameError in the prints below.
    raise ValueError(
        f"Unsupported USE_GPU value {USE_GPU!r}; expected 'a100', 't4' or None"
    )

print(f"WORKER_MACHINE_TYPE            : {WORKER_MACHINE_TYPE}")
print(f"REPLICA_COUNT                  : {REPLICA_COUNT}")
print(f"ACCELERATOR_TYPE               : {ACCELERATOR_TYPE}")
print(f"PER_MACHINE_ACCELERATOR_COUNT  : {PER_MACHINE_ACCELERATOR_COUNT}")
print(f"DISTRIBUTE_STRATEGY            : {DISTRIBUTE_STRATEGY}")
print(f"REDUCTION_SERVER_COUNT         : {REDUCTION_SERVER_COUNT}")
print(f"REDUCTION_SERVER_MACHINE_TYPE  : {REDUCTION_SERVER_MACHINE_TYPE}")

WORKER_MACHINE_TYPE            : n1-standard-16
REPLICA_COUNT                  : 1
ACCELERATOR_TYPE               : NVIDIA_TESLA_T4
PER_MACHINE_ACCELERATOR_COUNT  : 1
DISTRIBUTE_STRATEGY            : single
REDUCTION_SERVER_COUNT         : 0
REDUCTION_SERVER_MACHINE_TYPE  : n1-highcpu-16


## set Vertex AI Experiment

In [13]:
# Experiment and run identifiers. The run name embeds the current timestamp, so
# re-running this cell starts a new run and changes every path derived below.
EXPERIMENT_NAME   = 'scale-paf-v2'
invoke_time       = time.strftime("%Y%m%d-%H%M%S")
RUN_NAME          = f'run-{invoke_time}'

# GCS layout for this run
BASE_OUTPUT_DIR   = f'{BUCKET_URI}/{EXPERIMENT_NAME}/{RUN_NAME}'
LOG_DIR           = f"{BASE_OUTPUT_DIR}/logs"
# root directory for writing logs / summaries / checkpoints
ROOT_DIR          = f"{BASE_OUTPUT_DIR}/root"
# where the trained model will be saved and restored
ARTIFACTS_DIR     = f"{BASE_OUTPUT_DIR}/artifacts"

# bind the SDK to the experiment
vertex_ai.init(project=PROJECT_ID, location=REGION, experiment=EXPERIMENT_NAME)

print(f"EXPERIMENT_NAME   : {EXPERIMENT_NAME}")
print(f"RUN_NAME          : {RUN_NAME}\n")
print(f"BASE_OUTPUT_DIR   : {BASE_OUTPUT_DIR}")
print(f"LOG_DIR           : {LOG_DIR}")
print(f"ROOT_DIR          : {ROOT_DIR}")
print(f"ARTIFACTS_DIR     : {ARTIFACTS_DIR}")

EXPERIMENT_NAME   : scale-paf-v2
RUN_NAME          : run-20230823-220300

BASE_OUTPUT_DIR   : gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-220300
LOG_DIR           : gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-220300/logs
ROOT_DIR          : gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-220300/root
ARTIFACTS_DIR     : gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-220300/artifacts


## Create Tensorboard

In [14]:
# Create a new Vertex AI TensorBoard instance named after the experiment and run.
TENSORBOARD_DISPLAY_NAME = f"{EXPERIMENT_NAME}-{RUN_NAME}"

tensorboard = vertex_ai.Tensorboard.create(
    display_name=TENSORBOARD_DISPLAY_NAME,
    project=PROJECT_ID,
    location=REGION,
)
TB_RESOURCE_NAME = tensorboard.resource_name

# To reuse an existing instance instead, assign its resource name directly:
# TB_RESOURCE_NAME = 'projects/934903580331/locations/us-central1/tensorboards/6924469145035603968'

print(f"TB_RESOURCE_NAME: {TB_RESOURCE_NAME}")
print(f"TB display name: {tensorboard.display_name}")

TB_RESOURCE_NAME: projects/934903580331/locations/us-central1/tensorboards/5580091879707377664
TB display name: scale-paf-v2-run-20230823-220300


## Set training args

In [15]:
# Training image URI used by the worker pool. REMOTE_IMAGE_NAME is not assigned in
# any visible cell of this notebook — presumably it comes from the env config
# exec'd at the top; verify.
print(f"REMOTE_IMAGE_NAME : {REMOTE_IMAGE_NAME}")

REMOTE_IMAGE_NAME : us-central1-docker.pkg.dev/hybrid-vertex/rl-movielens-mabv1/train-perarm-feats-v1


In [16]:
# Hyperparameters and data settings for the training job. Most of these are
# forwarded to the training container through WORKER_ARGS in the next cell.

# vocab
# VOCAB_SUBDIR         = "vocabs"
# VOCAB_FILENAME       = "vocab_dict.pkl"

# Set hyperparameters.
BATCH_SIZE           = 128          # Training and prediction batch size.
TRAINING_LOOPS       = 1000         # Number of training iterations.
STEPS_PER_LOOP       = 1            # Number of driver steps per training iteration.
ASYNC_STEPS_PER_LOOP = 1
LOG_INTERVAL         = 50
LR                   = 0.05

CHKPT_INTERVAL       = 1000
EVAL_BATCH_SIZE      = 1  # batch size used for evaluation
NUM_EVAL_STEPS       = 1000 #10000

# Set MovieLens simulation environment parameters.
RANK_K               = 10      # Rank for matrix factorization in the MovieLens environment; also the observation dimension.
NUM_ACTIONS          = 2       # Number of actions (movie items) to choose from.
PER_ARM              = True    # True -> per-arm features. NOTE(review): the previous comment said "non-per-arm", contradicting the value; confirm intent.

# ================================
# Agent
# ================================
AGENT_TYPE          = 'epsGreedy' # 'LinUCB' | 'LinTS' | 'epsGreedy' | 'NeuralLinUCB'
NETWORK_TYPE        = "commontower" # 'commontower' | 'dotproduct'

# NeuralLinUCB is always paired with the common-tower network here.
if AGENT_TYPE == 'NeuralLinUCB':
    NETWORK_TYPE = 'commontower'

TIKHONOV_WEIGHT     = 0.001   # LinUCB Tikhonov regularization weight.
AGENT_ALPHA         = 0.1     # LinUCB exploration parameter that multiplies the confidence intervals.
EPSILON             = 0.01
ENCODING_DIM        = 1
EPS_PHASE_STEPS     = 1000

# ================================
# network params
# ================================
GLOBAL_LAYERS       = [64, 32, 16]
ARM_LAYERS          = [64, 32, 16]
COMMON_LAYERS       = [16, 8]

# ================================
# data config
# ================================
GLOBAL_DIM             = 64       # 16
PER_ARM_DIM            = 64       # 16
NUM_OOV_BUCKETS        = 1
GLOBAL_EMBEDDING_SIZE  = 16
MV_EMBEDDING_SIZE      = 32       # 32
SPLIT                  = "train"  # TODO - remove
RESUME_TRAINING        = None

# Echo the full configuration so it is recorded in the notebook output.
# NOTE(review): AGENT_ALPHA is printed twice below.
print(f"VOCAB_SUBDIR           : {VOCAB_SUBDIR}")
print(f"VOCAB_FILENAME         : {VOCAB_FILENAME}")
print(f"BATCH_SIZE             : {BATCH_SIZE}")
print(f"TRAINING_LOOPS         : {TRAINING_LOOPS}")
print(f"STEPS_PER_LOOP         : {STEPS_PER_LOOP}")
print(f"ASYNC_STEPS_PER_LOOP   : {ASYNC_STEPS_PER_LOOP}")
print(f"LOG_INTERVAL           : {LOG_INTERVAL}")
print(f"RANK_K                 : {RANK_K}")
print(f"NUM_ACTIONS            : {NUM_ACTIONS}")
print(f"PER_ARM                : {PER_ARM}")
print(f"AGENT_TYPE             : {AGENT_TYPE}")
print(f"NETWORK_TYPE           : {NETWORK_TYPE}")
print(f"TIKHONOV_WEIGHT        : {TIKHONOV_WEIGHT}")
print(f"AGENT_ALPHA            : {AGENT_ALPHA}")
print(f"GLOBAL_DIM             : {GLOBAL_DIM}")
print(f"PER_ARM_DIM            : {PER_ARM_DIM}")
print(f"SPLIT                  : {SPLIT}")
print(f"RESUME_TRAINING        : {RESUME_TRAINING}")
print(f"NUM_OOV_BUCKETS        : {NUM_OOV_BUCKETS}")
print(f"GLOBAL_EMBEDDING_SIZE  : {GLOBAL_EMBEDDING_SIZE}")
print(f"MV_EMBEDDING_SIZE      : {MV_EMBEDDING_SIZE}")
print(f"AGENT_ALPHA            : {AGENT_ALPHA}")
print(f"GLOBAL_LAYERS          : {GLOBAL_LAYERS}")
print(f"ARM_LAYERS             : {ARM_LAYERS}")
print(f"COMMON_LAYERS          : {COMMON_LAYERS}")
print(f"LR                     : {LR}")
print(f"CHKPT_INTERVAL         : {CHKPT_INTERVAL}")
print(f"EVAL_BATCH_SIZE        : {EVAL_BATCH_SIZE}")
print(f"NUM_EVAL_STEPS         : {NUM_EVAL_STEPS}")
print(f"EPSILON                : {EPSILON}")
print(f"ENCODING_DIM           : {ENCODING_DIM}")
print(f"EPS_PHASE_STEPS        : {EPS_PHASE_STEPS}")

VOCAB_SUBDIR           : vocabs
VOCAB_FILENAME         : vocab_dict.pkl
BATCH_SIZE             : 128
TRAINING_LOOPS         : 1000
STEPS_PER_LOOP         : 1
ASYNC_STEPS_PER_LOOP   : 1
LOG_INTERVAL           : 50
RANK_K                 : 10
NUM_ACTIONS            : 2
PER_ARM                : True
AGENT_TYPE             : epsGreedy
NETWORK_TYPE           : commontower
TIKHONOV_WEIGHT        : 0.001
AGENT_ALPHA            : 0.1
GLOBAL_DIM             : 64
PER_ARM_DIM            : 64
SPLIT                  : train
RESUME_TRAINING        : None
NUM_OOV_BUCKETS        : 1
GLOBAL_EMBEDDING_SIZE  : 16
MV_EMBEDDING_SIZE      : 32
AGENT_ALPHA            : 0.1
GLOBAL_LAYERS          : [64, 32, 16]
ARM_LAYERS             : [64, 32, 16]
COMMON_LAYERS          : [16, 8]
LR                     : 0.05
CHKPT_INTERVAL         : 1000
EVAL_BATCH_SIZE        : 1
NUM_EVAL_STEPS         : 1000
EPSILON                : 0.01
ENCODING_DIM           : 1
EPS_PHASE_STEPS        : 1000


In [17]:
# Command-line arguments handed to the training container, followed by the
# worker-pool specification that wraps them for the Vertex AI custom job.
WORKER_ARGS = [
    f"--project={PROJECT_ID}"
    , f"--project_number={PROJECT_NUM}"
    , f"--bucket_name={BUCKET_NAME}"
    , f"--artifacts_dir={ARTIFACTS_DIR}"
    , f"--root_dir={ROOT_DIR}"
    , f"--log_dir={LOG_DIR}"
    , f"--data_dir_prefix_path={DATA_GCS_PREFIX}"
    , f"--vocab_prefix_path={VOCAB_SUBDIR}"
    , f"--vocab_filename={VOCAB_FILENAME}"
    ### job config
    , f"--distribute={DISTRIBUTE_STRATEGY}"
    , f"--experiment_name={EXPERIMENT_NAME}"
    , f"--experiment_run={RUN_NAME}"
    , f"--agent_type={AGENT_TYPE}"
    , f"--network_type={NETWORK_TYPE}"
    ### hparams
    , f"--batch_size={BATCH_SIZE}"
    , f"--eval_batch_size={EVAL_BATCH_SIZE}"
    , f"--training_loops={TRAINING_LOOPS}"
    , f"--steps_per_loop={STEPS_PER_LOOP}"
    , f"--num_eval_steps={NUM_EVAL_STEPS}"
    , f"--rank_k={RANK_K}"
    , f"--num_actions={NUM_ACTIONS}"
    , f"--async_steps_per_loop={ASYNC_STEPS_PER_LOOP}"
    # , f"--resume_training_loops"
    , f"--global_dim={GLOBAL_DIM}"
    , f"--per_arm_dim={PER_ARM_DIM}"
    , f"--split={SPLIT}"
    , f"--log_interval={LOG_INTERVAL}"
    , f"--chkpt_interval={CHKPT_INTERVAL}"
    , f"--num_oov_buckets={NUM_OOV_BUCKETS}"
    , f"--global_emb_size={GLOBAL_EMBEDDING_SIZE}"
    , f"--mv_emb_size={MV_EMBEDDING_SIZE}"
    , f"--agent_alpha={AGENT_ALPHA}"
    , f"--global_layers={GLOBAL_LAYERS}"
    , f"--arm_layers={ARM_LAYERS}"
    , f"--common_layers={COMMON_LAYERS}"
    , f"--learning_rate={LR}"
    , f"--epsilon={EPSILON}"
    , f"--encoding_dim={ENCODING_DIM}"
    , f"--eps_phase_steps={EPS_PHASE_STEPS}"
]

### accelerators & profiling
# Only ask the trainer to use a GPU when the worker pool actually has one
# attached; previously `--use_gpu` was sent even for the CPU-only configuration.
# (assumes `--use_gpu` is a boolean switch that defaults to off when omitted —
# confirm against the training script's argument parser)
if PER_MACHINE_ACCELERATOR_COUNT > 0:
    WORKER_ARGS.append("--use_gpu")
# WORKER_ARGS.append("--use_tpu")
WORKER_ARGS.append("--profiler")

import sys
sys.path.append("..")
from src.per_arm_rl import train_utils

WORKER_POOL_SPECS = train_utils.prepare_worker_pool_specs(
    image_uri=f"{REMOTE_IMAGE_NAME}",
    args=WORKER_ARGS,
    replica_count=REPLICA_COUNT,
    machine_type=WORKER_MACHINE_TYPE,
    accelerator_count=PER_MACHINE_ACCELERATOR_COUNT,
    accelerator_type=ACCELERATOR_TYPE,
    reduction_server_count=REDUCTION_SERVER_COUNT,
    reduction_server_machine_type=REDUCTION_SERVER_MACHINE_TYPE,
)

from pprint import pprint
pprint(WORKER_POOL_SPECS)

[{'container_spec': {'args': ['--project=hybrid-vertex',
                              '--project_number=934903580331',
                              '--bucket_name=mabv1-hybrid-vertex-bucket',
                              '--artifacts_dir=gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-220300/artifacts',
                              '--root_dir=gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-220300/root',
                              '--log_dir=gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-220300/logs',
                              '--data_dir_prefix_path=data',
                              '--vocab_prefix_path=vocabs',
                              '--vocab_filename=vocab_dict.pkl',
                              '--distribute=single',
                              '--experiment_name=scale-paf-v2',
                              '--experiment_run=run-20230823-220300',
                              '--agent_type=epsGreedy',
                          

In [18]:
# !pwd

# Submit training job

In [19]:
# Point the SDK at the experiment again and derive the job name from the run name.
vertex_ai.init(
    project=PROJECT_ID,
    location=REGION,
    experiment=EXPERIMENT_NAME,
    # staging_bucket=ROOT_DIR,
)

JOB_NAME = f"paf-bandit-{RUN_NAME}"
print(f"JOB_NAME: {JOB_NAME}")

JOB_NAME: paf-bandit-run-20230823-220300


In [20]:
# Define the custom training job from the worker-pool specs; outputs go under
# BASE_OUTPUT_DIR and ROOT_DIR is used as the staging location.
my_custom_job = vertex_ai.CustomJob(
    display_name=JOB_NAME,
    project=PROJECT_ID,
    worker_pool_specs=WORKER_POOL_SPECS,
    base_output_dir=BASE_OUTPUT_DIR,
    staging_bucket=ROOT_DIR,
    # location="asia-southeast1",
)

In [21]:
# Submit the job with the run's TensorBoard instance and service account attached.
_run_options = dict(
    tensorboard=TB_RESOURCE_NAME,
    service_account=VERTEX_SA,
    restart_job_on_worker_restart=False,
    enable_web_access=True,
    sync=False,
)
my_custom_job.run(**_run_options)

In [20]:
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf

# TensorBoard reads from the same log directory that was passed to the job as --log_dir.
TB_LOGS_PATH = LOG_DIR
print(f"TB_LOGS_PATH: {TB_LOGS_PATH}")

# TB_LOGS_PATH = f'{BASE_OUTPUT_DIR}/logs' # 
# print(f"TB_LOGS_PATH: {TB_LOGS_PATH}")

TB_LOGS_PATH: gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-153711/logs


In [21]:
# Load the TensorBoard notebook extension, which provides the `%tensorboard` magic.
%load_ext tensorboard
# %reload_ext tensorboard

In [22]:
# Open TensorBoard inline against this run's log directory.
%tensorboard --logdir=$TB_LOGS_PATH

# Making predictions

* When a policy is trained, given a new observation request (i.e. a user vector),
* the policy will infer (produce) actions, which are the recommended movies.
* In TF-Agents, observations are abstracted in a named tuple,

```
TimeStep('step_type', 'reward', 'discount', 'observation')
```

> the policy maps time steps to actions

In [21]:
import tensorflow as tf

In [24]:
# Build the evaluation dataset from the validation TFRecord files in GCS.
# NOTE: this rebinds SPLIT, which the WORKER_ARGS cell forwards as `--split`;
# re-run the hyperparameter cell before building worker args for a new job.
SPLIT = "val"

# Build gs:// URIs from the bucket and object names directly. `blob.public_url`
# percent-encodes the object name, so rewriting it into a gs:// path yields a
# wrong path for names containing characters such as spaces or '='.
val_files = [
    f"gs://{blob.bucket.name}/{blob.name}"
    for blob in storage_client.list_blobs(f"{BUCKET_NAME}", prefix=f'{DATA_GCS_PREFIX}/{SPLIT}')
    if '.tfrecord' in blob.name
]

# An empty file list would silently produce an empty dataset and make the
# inference loop further down do nothing; fail here with a clear message instead.
if not val_files:
    raise FileNotFoundError(
        f"No .tfrecord files found under gs://{BUCKET_NAME}/{DATA_GCS_PREFIX}/{SPLIT}"
    )

val_dataset = tf.data.TFRecordDataset(val_files)
val_dataset = val_dataset.map(data_utils.parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)

# eval dataset: one example per batch
eval_ds = val_dataset.batch(1)

# if NUM_EVAL_STEPS > 0:
#     eval_ds = eval_ds.take(NUM_EVAL_STEPS)

eval_ds

<_BatchDataset element_spec={'bucketized_user_age': TensorSpec(shape=(None,), dtype=tf.float32, name=None), 'movie_genres': TensorSpec(shape=(None, 1), dtype=tf.int64, name=None), 'movie_id': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'timestamp': TensorSpec(shape=(None,), dtype=tf.int64, name=None), 'user_id': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'user_occupation_text': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'user_rating': TensorSpec(shape=(None,), dtype=tf.float32, name=None)}>

In [26]:
# MODEL_DIR = "gs://mabv1-hybrid-vertex-bucket/scale-perarm-hpt/run-20230717-211248/model"

# List the exported policy artifacts for the current ARTIFACTS_DIR.
!gsutil ls $ARTIFACTS_DIR

gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-155908/artifacts/
gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-155908/artifacts/fingerprint.pb
gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-155908/artifacts/policy_specs.pbtxt
gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-155908/artifacts/saved_model.pb
gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-155908/artifacts/assets/
gs://mabv1-hybrid-vertex-bucket/scale-paf-v2/run-20230823-155908/artifacts/variables/


In [27]:
# Load the exported policy SavedModel from ARTIFACTS_DIR; the bare name on the
# last line displays the loaded object.
trained_policy = tf.saved_model.load(ARTIFACTS_DIR)
trained_policy

<tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject at 0x7ff0b6174ee0>

In [28]:
# Run the trained policy on INFER_SIZE validation examples.
INFER_SIZE = 1
dummy_arm = tf.zeros([INFER_SIZE, PER_ARM_DIM], dtype=tf.float32)

# The three feature helpers used in the loop are not defined or imported in any
# visible cell of this notebook, so the loop would die with a bare NameError on
# its first line. Check up front and say what is missing.
# NOTE(review): locate their definitions (presumably an earlier notebook in this
# series or a module under `src.per_arm_rl`) and import them above — TODO confirm.
_required_helpers = ("_get_global_context_features", "_get_per_arm_features", "_get_rewards")
_missing_helpers = [name for name in _required_helpers if name not in globals()]
if _missing_helpers:
    raise NameError(
        f"Define or import {', '.join(_missing_helpers)} before running this cell"
    )

for x in eval_ds.take(INFER_SIZE):
    # get feature tensors
    global_feat_infer = _get_global_context_features(x)
    arm_feat_infer = _get_per_arm_features(x)
    rewards = _get_rewards(x)

    # reshape arm features
    # `HPARAMS` is never defined in this notebook; EVAL_BATCH_SIZE is the
    # notebook-level constant for the evaluation batch size.
    arm_feat_infer = tf.reshape(arm_feat_infer, [EVAL_BATCH_SIZE, PER_ARM_DIM]) # perarm_dim
    # append an all-zero arm after the real one
    concat_arm = tf.concat([arm_feat_infer, dummy_arm], axis=0)

    # flatten global
    flat_global_infer = tf.reshape(global_feat_infer, [GLOBAL_DIM])
    feature = {'global': flat_global_infer, 'per_arm': concat_arm}

    # get actual reward
    actual_reward = rewards.numpy()[0]

    # build trajectory step
    trajectory_step = train_utils._get_eval_step(feature, actual_reward)

    prediction = trained_policy.action(trajectory_step)

NameError: name '_get_global_context_features' is not defined