# Building embedding models with preprocessed global & per-arm features

**Use this notebook to better understand how the embedding preprocessing functions work:**
* the dimensions produced at each step
* working with tensors (e.g., concat)

The preprocessing layers will ultimately feed the two sampling functions described below. These sampling functions will be used to create [trajectories](https://github.com/tensorflow/agents/blob/master/tf_agents/trajectories/trajectory.py#L36) (i.e., the training examples for our model).

`global_context_sampling_fn`: 
* A function that outputs a random 1d array or list of ints or floats
* This output is the global context. Its shape and type must be consistent across calls.

`arm_context_sampling_fn`: 
* A function that outputs a random 1d array or list of ints or floats (same type as the output of `global_context_sampling_fn`).
* This output is the per-arm context. Its shape must be consistent across calls.

## Notebook config

In [None]:
# Naming prefix; combined with the project ID to form the GCS bucket name.
PREFIX = "mabv1"

In [None]:
# staging GCS
# Resolve the active GCP project via the gcloud CLI. The IPython `!` capture
# returns a list of output lines; the first line is taken as the project ID.
GCP_PROJECTS             = !gcloud config get-value project
PROJECT_ID               = GCP_PROJECTS[0]

# GCS bucket and paths
BUCKET_NAME              = f'{PREFIX}-{PROJECT_ID}-bucket'
BUCKET_URI               = f'gs://{BUCKET_NAME}'

# Fetch the shared notebook environment file from the bucket, show it, and
# execute it in this namespace (`config.n` joins the captured lines with "\n").
# NOTE(review): `exec` runs whatever this GCS object contains with full notebook
# privileges -- only use a bucket you control. Presumably this file defines
# names used later (e.g. LOCATION, DATA_GCS_PREFIX, VOCAB_SUBDIR,
# VOCAB_FILENAME) -- confirm against notebook_env.py.
config = !gsutil cat {BUCKET_URI}/config/notebook_env.py
print(config.n)
exec(config.n)

## Imports

In [None]:
import functools
from collections import defaultdict
from typing import Callable, Dict, List, Optional, TypeVar
from datetime import datetime
import time
from pprint import pprint
import pickle as pkl
import numpy as np

# google cloud
from google.cloud import aiplatform, storage

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# TF-agents
# from tf_agents.bandits.agents import neural_epsilon_greedy_agent
# from tf_agents.bandits.agents import neural_linucb_agent
# from tf_agents.bandits.networks import global_and_arm_feature_network
from tf_agents.bandits.policies import policy_utilities
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

from tf_agents.bandits.specs import utils as bandit_spec_utils
from tf_agents.trajectories import trajectory

# GPU
from numba import cuda 
import gc

# TensorFlow must be imported explicitly: `tf` is used throughout this notebook
# (version check, tf.config, tf.data, tf.keras), and the tf_agents imports
# above do not bind the name `tf`.
import tensorflow as tf

# tf exceptions and vars
# Compare only the major version component (e.g. "2" from "2.13.0").
if tf.__version__.split(".")[0] != "2":
    raise Exception("The trainer only runs with TensorFlow version 2.")

# Generic type variable available for type hints.
T = TypeVar("T")

In [None]:
import sys

# Make the parent directory importable so the `src.*` modules below resolve.
# Guarded so re-running this cell does not keep appending duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

# this repo
from src.per_arm_rl import data_utils
from src.per_arm_rl import data_config
from src.per_arm_rl import train_utils as train_utils

In [None]:
# Report how many GPUs TensorFlow can see.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

In [None]:
# Reset the current CUDA device's context via numba (presumably to free GPU
# memory left over from earlier work), then run Python's garbage collector.
# Guarded so the cell also runs on CPU-only machines, where
# `cuda.get_current_device()` would raise.
if cuda.is_available():
    device = cuda.get_current_device()
    device.reset()
gc.collect()

In [None]:
# cloud storage client
storage_client = storage.Client(project=PROJECT_ID)

# Vertex client
aiplatform.init(project=PROJECT_ID, location=LOCATION)

## Data 

In [None]:
# tf.data options requesting the AUTO sharding policy for distributed input.
# NOTE(review): `options` is not applied (e.g. via `.with_options`) to any
# dataset in the visible cells.
shard_policy = tf.data.experimental.AutoShardPolicy.AUTO
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = shard_policy

In [None]:
SPLIT = "val"

# Collect gs:// URIs for every TFRecord object under {DATA_GCS_PREFIX}/{SPLIT}.
# Build each URI from the bucket and the raw object name. `blob.public_url` is
# percent-encoded (and follows the client's endpoint), so string-replacing its
# https prefix does not reliably produce a valid gs:// path.
val_files = [
    f"gs://{BUCKET_NAME}/{blob.name}"
    for blob in storage_client.list_blobs(BUCKET_NAME, prefix=f'{DATA_GCS_PREFIX}/{SPLIT}')
    if '.tfrecord' in blob.name
]

# Fail loudly here instead of silently building an empty dataset that only
# surfaces later as a StopIteration.
if not val_files:
    raise ValueError(
        f"No .tfrecord files found under gs://{BUCKET_NAME}/{DATA_GCS_PREFIX}/{SPLIT}"
    )

# Parse each serialized record with the project's parser.
val_dataset = tf.data.TFRecordDataset(val_files)
val_dataset = val_dataset.map(data_utils.parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)

### get vocabulary

In [None]:
EXISTING_VOCAB_FILE = f'gs://{BUCKET_NAME}/{VOCAB_SUBDIR}/{VOCAB_FILENAME}'
print("Downloading vocab...")

# Copy the pickled vocab into the working directory. Check the exit status:
# otherwise a failed copy goes unnoticed, and a stale local VOCAB_FILENAME
# (if one exists) would be loaded instead.
exit_status = os.system(f'gsutil -q cp {EXISTING_VOCAB_FILE} .')
if exit_status != 0:
    raise RuntimeError(
        f"gsutil cp failed (status {exit_status}) for: {EXISTING_VOCAB_FILE}"
    )
print(f"Downloaded vocab from: {EXISTING_VOCAB_FILE}\n")

# NOTE: pickle.load can execute arbitrary code -- only load vocab files from a
# bucket you trust. `with` guarantees the file is closed even if loading fails.
with open(VOCAB_FILENAME, 'rb') as filehandler:
    vocab_dict = pkl.load(filehandler)

# Show which feature vocabularies are available.
for key in vocab_dict.keys():
    pprint(key)

In [None]:
# Pull a single batch (batch size 1) as a sample input for the preprocessing
# layers below. `val_dataset` is the only dataset built in this notebook.
data = next(iter(val_dataset.batch(1)))

data

In [None]:
# Embedding / lookup hyperparameters
NUM_OOV_BUCKETS = 1          # out-of-vocabulary indices reserved per lookup layer
GLOBAL_EMBEDDING_SIZE = 16   # output dim of each global (user) feature embedding
MV_EMBEDDING_SIZE = 32       # output dim of each per-arm (movie) feature embedding

## global context (user) features

#### user ID

In [None]:
# user_id path: raw string -> integer index (StringLookup) -> dense embedding.
user_id_vocab = vocab_dict['user_id']
user_id_vocab_size = len(user_id_vocab) + NUM_OOV_BUCKETS  # vocab entries + OOV slots

user_id_input_layer = tf.keras.Input(name="user_id", shape=(1,), dtype=tf.string)

user_id_lookup = tf.keras.layers.StringLookup(
    vocabulary=user_id_vocab,
    mask_token=None,
    num_oov_indices=NUM_OOV_BUCKETS,
    max_tokens=user_id_vocab_size,
)(user_id_input_layer)

# Embedding table sized to match the explicit vocabulary lookup above.
user_id_embedding = tf.keras.layers.Embedding(
    output_dim=GLOBAL_EMBEDDING_SIZE,
    input_dim=user_id_vocab_size,
)(user_id_lookup)

# Collapse the length-1 feature axis:
# (batch, 1, GLOBAL_EMBEDDING_SIZE) -> (batch, GLOBAL_EMBEDDING_SIZE)
user_id_embedding = tf.reduce_sum(user_id_embedding, axis=-2)

# global_inputs.append(user_id_input_layer)
# global_features.append(user_id_embedding)

In [None]:
# Standalone Keras model exposing only the user_id preprocessing path
# (string input -> summed embedding) so it can be inspected in isolation.
test_user_id_model = tf.keras.Model(
    inputs=user_id_input_layer,
    outputs=user_id_embedding,
)

# for x in train_dataset.batch(1).take(1):
#     print(x["user_id"])
#     print(test_user_id_model(x["user_id"]))

#### user AGE

In [None]:
# user_age_input_layer = tf.keras.Input(
#     name="bucketized_user_age",
#     shape=(1,),
#     dtype=tf.float32
# )

# user_age_lookup = tf.keras.layers.IntegerLookup(
#     vocabulary=vocab_dict['bucketized_user_age'],
#     num_oov_indices=NUM_OOV_BUCKETS,
#     oov_value=0,
# )(user_age_input_layer)

# user_age_embedding = tf.keras.layers.Embedding(
#     # Let's use the explicit vocabulary lookup.
#     input_dim=len(vocab_dict['bucketized_user_age']) + NUM_OOV_BUCKETS,
#     output_dim=GLOBAL_EMBEDDING_SIZE
# )(user_age_lookup)

# user_age_embedding = tf.reduce_sum(user_age_embedding, axis=-2)

# # global_inputs.append(user_age_input_layer)
# # global_features.append(user_age_embedding)

In [None]:
# test_user_age_model = tf.keras.Model(inputs=user_age_input_layer, outputs=user_age_embedding)

# # for x in train_dataset.batch(1).take(1):
# #     print(x["bucketized_user_age"])
# #     print(test_user_age_model(x["bucketized_user_age"]))

#### user OCC

In [None]:
# user_occ_input_layer = tf.keras.Input(
#     name="user_occupation_text",
#     shape=(1,),
#     dtype=tf.string
# )

# user_occ_lookup = tf.keras.layers.StringLookup(
#     max_tokens=len(vocab_dict['user_occupation_text']) + NUM_OOV_BUCKETS,
#     num_oov_indices=NUM_OOV_BUCKETS,
#     mask_token=None,
#     vocabulary=vocab_dict['user_occupation_text'],
# )(user_occ_input_layer)

# user_occ_embedding = tf.keras.layers.Embedding(
#     # Let's use the explicit vocabulary lookup.
#     input_dim=len(vocab_dict['user_occupation_text']) + NUM_OOV_BUCKETS,
#     output_dim=GLOBAL_EMBEDDING_SIZE
# )(user_occ_lookup)

# user_occ_embedding = tf.reduce_sum(user_occ_embedding, axis=-2)

# # global_inputs.append(user_occ_input_layer)
# # global_features.append(user_occ_embedding)

In [None]:
# test_user_occ_model = tf.keras.Model(inputs=user_occ_input_layer, outputs=user_occ_embedding)

# # for x in train_dataset.batch(1).take(1):
# #     print(x["user_occupation_text"])
# #     print(test_user_occ_model(x["user_occupation_text"]))

#### user Timestamp

In [None]:
# user_ts_input_layer = tf.keras.Input(
#     name="timestamp",
#     shape=(1,),
#     dtype=tf.int64
# )

# user_ts_lookup = tf.keras.layers.Discretization(
#     vocab_dict['timestamp_buckets'].tolist()
# )(user_ts_input_layer)

# user_ts_embedding = tf.keras.layers.Embedding(
#     # Let's use the explicit vocabulary lookup.
#     input_dim=len(vocab_dict['timestamp_buckets'].tolist()) + NUM_OOV_BUCKETS,
#     output_dim=GLOBAL_EMBEDDING_SIZE
# )(user_ts_lookup)

# user_ts_embedding = tf.reduce_sum(user_ts_embedding, axis=-2)

# # global_inputs.append(user_ts_input_layer)
# # global_features.append(user_ts_embedding)

In [None]:
# test_user_ts_model = tf.keras.Model(inputs=user_ts_input_layer, outputs=user_ts_embedding)

# # for x in train_dataset.batch(1).take(1):
# #     print(x["timestamp"])
# #     print(test_user_ts_model(x["timestamp"]))

### define global sampling function

In [None]:
# def _get_global_context_features(x):
#     """
#     This function generates a single global observation vector.
#     """
#     user_id_value = x['user_id']
#     user_age_value = x['bucketized_user_age']
#     user_occ_value = x['user_occupation_text']
#     user_ts_value = x['timestamp']

#     _id = test_user_id_model(user_id_value) # input_tensor=tf.Tensor(shape=(4,), dtype=float32)
#     _age = test_user_age_model(user_age_value)
#     _occ = test_user_occ_model(user_occ_value)
#     _ts = test_user_ts_model(user_ts_value)

#     # # tmp - insepct numpy() values
#     # print(_id.numpy()) #[0])
#     # print(_age.numpy()) #[0])
#     # print(_occ.numpy()) #[0])
#     # print(_ts.numpy()) #[0])

#     # to numpy array
#     _id = np.array(_id.numpy())
#     _age = np.array(_age.numpy())
#     _occ = np.array(_occ.numpy())
#     _ts = np.array(_ts.numpy())

#     concat = np.concatenate(
#         [_id, _age, _occ, _ts], axis=-1 # -1
#     ).astype(np.float32)

#     return concat

In [None]:
# `_get_global_context_features` only exists if the commented-out cell above is
# re-enabled. Skip instead of raising NameError on Restart & Run All; the
# EmbeddingModel section below computes GLOBAL_DIM either way.
if "_get_global_context_features" in globals():
    # Last axis = feature dimension, regardless of any leading batch axes.
    GLOBAL_DIM = _get_global_context_features(data).shape[-1]
    print(f"GLOBAL_DIM: {GLOBAL_DIM}")

## arm preprocessing layers

#### movie ID

In [None]:
# mv_id_input_layer = tf.keras.Input(
#     name="movie_id",
#     shape=(1,),
#     dtype=tf.string
# )

# mv_id_lookup = tf.keras.layers.StringLookup(
#     max_tokens=len(vocab_dict['movie_id']) + NUM_OOV_BUCKETS,
#     num_oov_indices=NUM_OOV_BUCKETS,
#     mask_token=None,
#     vocabulary=vocab_dict['movie_id'],
# )(mv_id_input_layer)

# mv_id_embedding = tf.keras.layers.Embedding(
#     # Let's use the explicit vocabulary lookup.
#     input_dim=len(vocab_dict['movie_id']) + NUM_OOV_BUCKETS,
#     output_dim=MV_EMBEDDING_SIZE
# )(mv_id_lookup)

# mv_id_embedding = tf.reduce_sum(mv_id_embedding, axis=-2)

# # arm_inputs.append(mv_id_input_layer)
# # arm_features.append(mv_id_embedding)

In [None]:
# test_mv_id_model = tf.keras.Model(inputs=mv_id_input_layer, outputs=mv_id_embedding)

# # for x in train_dataset.batch(1).take(1):
# #     print(x["movie_id"])
# #     print(test_mv_id_model(x["movie_id"]))

#### movie genre

In [None]:
# mv_genre_input_layer = tf.keras.Input(
#     name="movie_genres",
#     shape=(1,),
#     dtype=tf.float32
# )

# mv_genre_lookup = tf.keras.layers.IntegerLookup(
#     vocabulary=vocab_dict['movie_genres'],
#     num_oov_indices=NUM_OOV_BUCKETS,
#     oov_value=0,
# )(mv_genre_input_layer)

# mv_genre_embedding = tf.keras.layers.Embedding(
#     # Let's use the explicit vocabulary lookup.
#     input_dim=len(vocab_dict['movie_genres']) + NUM_OOV_BUCKETS,
#     output_dim=MV_EMBEDDING_SIZE
# )(mv_genre_lookup)

# mv_genre_embedding = tf.reduce_sum(mv_genre_embedding, axis=-2)

# # arm_inputs.append(mv_genre_input_layer)
# # arm_features.append(mv_genre_embedding)

In [None]:
# test_mv_gen_model = tf.keras.Model(inputs=mv_genre_input_layer, outputs=mv_genre_embedding)

# for x in train_dataset.batch(1).take(1):
#     print(x["movie_genres"])
#     print(test_mv_gen_model(x["movie_genres"]))

### define sampling function

In [None]:
# def _get_per_arm_features(x):
#     """
#     This function generates a single per-arm observation vector
#     """
#     mv_id_value = x['movie_id']
#     mv_gen_value = x['movie_genres']

#     _mid = test_mv_id_model(mv_id_value)
#     _mgen = test_mv_gen_model(mv_gen_value)

#     # to numpy array
#     _mid = np.array(_mid.numpy())
#     _mgen = np.array(_mgen.numpy())


#     concat = np.concatenate(
#         [_mid, _mgen], axis=-1 # -1
#     ).astype(np.float32)
#     # concat = tf.concat([_mid, _mgen], axis=-1).astype(np.float32)

#     return concat #this is special to this example - there is only one action dimensions

In [None]:
# `_get_per_arm_features` only exists if the commented-out cell above is
# re-enabled. Skip instead of raising NameError on Restart & Run All; the
# EmbeddingModel section below computes PER_ARM_DIM either way.
if "_get_per_arm_features" in globals():
    # Last axis = per-arm feature dimension, whether the output is
    # (batch, arm_feats) or (batch, num_actions, arm_feats).
    PER_ARM_DIM = _get_per_arm_features(data).shape[-1]
    print(f"PER_ARM_DIM: {PER_ARM_DIM}")

## Global & Per-Arm feature embedding models 

> All of the dimensions computed above should match the outputs of the `EmbeddingModel` class below.

In [None]:
from src.perarm_features import emb_features

# Project embedding model configured with the same vocab, OOV count and
# embedding sizes used above (presumably covering both the global/user and
# per-arm/movie features -- see emb_features.EmbeddingModel).
embs = emb_features.EmbeddingModel(
    vocab_dict=vocab_dict,
    num_oov_buckets=NUM_OOV_BUCKETS,
    global_emb_size=GLOBAL_EMBEDDING_SIZE,
    mv_emb_size=MV_EMBEDDING_SIZE,
)

embs

In [None]:
test_globals = embs._get_global_context_features(data)

# Global context feature dimension. Read the last axis so the value is the
# feature size whether the output is (batch, global_feats) or carries extra
# leading axes.
GLOBAL_DIM = test_globals.shape[-1]
print(f"GLOBAL_DIM: {GLOBAL_DIM}")

test_globals

In [None]:
test_arms = embs._get_per_arm_features(data)

# Per-arm feature dimension. Read the last axis so this is the feature size for
# both (batch, arm_feats) and (batch, num_actions, arm_feats) outputs; for the
# latter, shape[1] would be num_actions instead.
PER_ARM_DIM = test_arms.shape[-1]
print(f"PER_ARM_DIM: {PER_ARM_DIM}")

test_arms