# Building embedding models that preprocess global & per-arm features

**Use this notebook to better understand how the embedding preprocessing functions work:**
* the dimensions produced at each step
* working with tensors (e.g., concat)

The preprocessing layers will ultimately feed the two sampling functions described below. These sampling functions will be used to create [trajectories](https://github.com/tensorflow/agents/blob/master/tf_agents/trajectories/trajectory.py#L36) (i.e., the training examples for our model).

`global_context_sampling_fn`: 
* A function that outputs a random 1d array or list of ints or floats
* This output is the global context. Its shape and type must be consistent across calls.

`arm_context_sampling_fn`: 
* A function that outputs a random 1d array or list of ints or floats (same type as the output of `global_context_sampling_fn`).
* This output is the per-arm context. Its shape must be consistent across calls.

## Notebook config

In [None]:
PREFIX = 'mabv1'

In [None]:
# staging GCS
# `!cmd` assignment captures the command's stdout as a list of lines;
# the first line is used as the project ID.
GCP_PROJECTS             = !gcloud config get-value project
PROJECT_ID               = GCP_PROJECTS[0]

# GCS bucket and paths
BUCKET_NAME              = f'{PREFIX}-{PROJECT_ID}-bucket'
BUCKET_URI               = f'gs://{BUCKET_NAME}'

# Fetch the shared settings file from the bucket, echo it, then execute it in this kernel.
# NOTE(review): `exec` runs whatever text is stored at that GCS path with full kernel
# privileges -- only use with a bucket you control.
# Presumably this defines names used further down that are never assigned in this notebook
# (e.g. LOCATION, DATA_GCS_PREFIX, VOCAB_FILENAME) -- confirm against notebook_env.py.
config = !gsutil cat {BUCKET_URI}/config/notebook_env.py
print(config.n)
exec(config.n)

## Imports

In [None]:
# stdlib
import functools
import gc
import os
import pickle as pkl
import time
from collections import defaultdict
from datetime import datetime
from pprint import pprint
from typing import Callable, Dict, List, Optional, Text, TypeVar

# set before TensorFlow is first imported so the log level takes effect
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# third-party
import numpy as np
import tensorflow as tf

# google cloud
from google.cloud import aiplatform, storage

# TF-agents
# from tf_agents.bandits.agents import neural_epsilon_greedy_agent
# from tf_agents.bandits.agents import neural_linucb_agent
# from tf_agents.bandits.networks import global_and_arm_feature_network
from tf_agents.bandits.policies import policy_utilities
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

from tf_agents.bandits.specs import utils as bandit_spec_utils
from tf_agents.trajectories import trajectory

# GPU
from numba import cuda

# Guard: everything below is written against the TensorFlow 2.x API.
_tf_major = tf.__version__[0]
if _tf_major != "2":
    raise Exception("The trainer only runs with TensorFlow version 2.")

# Generic type variable for annotations.
T = TypeVar("T")

In [None]:
# TODO

import sys
sys.path.append("..")

# this repo
# from src.per_arm_rl import data_utils
# from src.per_arm_rl import data_config
# from src.per_arm_rl import train_utils as train_utils

# this repo
# from src.utils import movielens_ds_utils
from src.data import data_utils as data_utils
from src.data import data_config as data_config

# from src.per_arm_rl import train_utils as train_utils
from src import train_utils as train_utils
from src.trainer import eval_perarm as eval_perarm
from src.trainer import train_perarm as train_perarm

# from src.perarm_features import emb_feature as emb_features
from src.networks import encoding_network as emb_features

from src.perarm_features import agent_factory as agent_factory
from src.perarm_features import reward_factory as reward_factory

In [None]:
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

In [None]:
device = cuda.get_current_device()
device.reset()
gc.collect()

In [None]:
# cloud storage client
storage_client = storage.Client(project=PROJECT_ID)

# Vertex client
aiplatform.init(project=PROJECT_ID, location=LOCATION)

## Data 

In [None]:
# Dataset family and the specific variant to read from the bucket.
DATA_SET = "movielens"
DATA_TAG = f"{DATA_SET}/movielens-1m" # movielens-100k | movielens-1m

print(f"DATA_TAG: {DATA_TAG}")

# List the objects for this dataset variant.
# NOTE(review): DATA_PATH is not assigned anywhere in this notebook; presumably it comes
# from the exec'd notebook_env.py -- confirm (later cells use DATA_GCS_PREFIX instead).
! gsutil ls $DATA_PATH/$DATA_TAG

In [None]:
# tf.data options with the auto-shard policy set to AUTO.
# NOTE(review): `options` is not attached to any dataset in the cells shown here
# (there is no `with_options(options)` call) -- confirm whether it is still needed.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO

In [None]:
SPLIT = "train"  # one of: "train" | "val"

# Walk the objects stored under this split and keep only the TFRecord shards,
# rewriting each public HTTPS URL into its gs:// form.
_split_prefix = f'{DATA_GCS_PREFIX}/{DATA_TAG}/{SPLIT}'
train_files = []
for blob in storage_client.list_blobs(f"{BUCKET_NAME}", prefix=_split_prefix):
    if '.tfrecord' not in blob.name:
        continue
    _gcs_uri = blob.public_url.replace("https://storage.googleapis.com/", "gs://")
    train_files.append(_gcs_uri)

# Keep at most the first three shards so the exploration cells stay fast.
train_files = train_files[:3]
train_files

In [None]:
# Read the selected TFRecord shards and parse each serialized record.
# NOTE(review): `movielens_ds_utils` is only imported in a commented-out line in the
# imports section (`# from src.utils import movielens_ds_utils`), so this raises
# NameError on a fresh kernel -- confirm which module now provides `parse_tfrecord`.
train_dataset = tf.data.TFRecordDataset(train_files)
train_dataset = train_dataset.map(movielens_ds_utils.parse_tfrecord)

# Peek at one parsed example (batch of 1) to inspect the feature keys and values.
for x in train_dataset.batch(1).take(1):
    pprint(x)

### get vocabulary

In [None]:
GENERATE_VOCABS = False

print(f"GENERATE_VOCABS: {GENERATE_VOCABS}")

In [None]:
if not GENERATE_VOCABS:
    # Reuse the vocabulary pickle that already exists in the bucket instead of regenerating it.
    EXISTING_VOCAB_FILE = f'gs://{BUCKET_NAME}/{DATA_GCS_PREFIX}/{DATA_TAG}/{VOCAB_FILENAME}'
    print(f"Downloading vocab...")

    # Fail loudly if the copy did not succeed; otherwise a stale local file (or none at all)
    # would be unpickled below.
    _copy_status = os.system(f'gsutil -q cp {EXISTING_VOCAB_FILE} .')
    if _copy_status != 0:
        raise RuntimeError(
            f"gsutil could not copy {EXISTING_VOCAB_FILE} (exit status {_copy_status})"
        )
    print(f"Downloaded vocab from: {EXISTING_VOCAB_FILE}\n")

    # NOTE(review): unpickling executes code embedded in the file -- only load vocab
    # files from a bucket you trust.
    # The context manager closes the handle even if unpickling raises.
    with open(VOCAB_FILENAME, 'rb') as filehandler:
        vocab_dict = pkl.load(filehandler)

    # Show which features have a vocabulary.
    for key in vocab_dict.keys():
        pprint(key)

In [None]:
# Pull one single-example batch; `data` is the sample input used by the cells below.
# (The original wrapped this in a one-iteration `for i in range(1)` loop, which added nothing.)
data = next(iter(train_dataset.batch(1)))

data

In [None]:
NUM_OOV_BUCKETS        = 1
GLOBAL_EMBEDDING_SIZE  = 16
MV_EMBEDDING_SIZE      = 32 #32

## check your embedding / encoding network model

In [None]:
embs = emb_features.EmbeddingModel(
    vocab_dict = vocab_dict,
    num_oov_buckets = NUM_OOV_BUCKETS,
    global_emb_size = GLOBAL_EMBEDDING_SIZE,
    mv_emb_size = MV_EMBEDDING_SIZE,
)

embs

In [None]:
data

In [None]:
global_features = embs._get_global_context_features(data)
global_features

In [None]:
arm_features = embs._get_per_arm_features(data)
arm_features

In [None]:
arm_features = train_utils._add_outer_dimension(arm_features)
arm_features

In [None]:
# # padded_batch(2, padded_shapes=5)

# for x in train_dataset.padded_batch(HPARAMS['batch_size']).take(1):
#     print(x)

## Understanding tensor shapes, rank and how to manipulate

### Check the differences in these:

In [None]:
test_globals= [
    0.05226095, -0.04546688, -0.05914654,  0.03443705, -0.04011744,
   -0.05921736,  0.05578206, -0.02147666,  0.00166732,  0.04055796,
    0.06458487, -0.05492309, -0.06472961, -0.00705546, -0.05592869,
   -0.01938318,  0.03898788, -0.04043241, -0.0182637 , -0.0499408 ,
   -0.05968586,  0.06301413,  0.00032848,  0.06395795,  0.01845439,
    0.04108731, -0.05026846,  0.01969895, -0.02506991,  0.02361025,
    0.00762446, -0.00464374,  0.01902852,  0.03852094, -0.04125774,
   -0.04153034, -0.03931752,  0.05585755, -0.03481127, -0.04961544,
   -0.04787084,  0.06189156, -0.04888101, -0.07491934, -0.07062666,
   -0.02748476, -0.01719889, -0.06808205
]

In [None]:
# Build a two-row sequence from the 48-value vector above: [row, 2 * row].
# `test_globals` is a plain Python list at this point, so its elements are floats and have
# no `.numpy()` method; convert the whole list with NumPy instead. `reshape(-1)` flattens
# the result, so it is also a flat row if `test_globals` holds a (1, 48) tensor.
_test_row = np.asarray(test_globals, dtype=np.float32).reshape(-1)
test_list_seq = [_test_row, _test_row * 2]
test_list_seq

In [None]:
reduce_test_v1 = tf.reduce_mean(test_list_seq, axis=[0,1])
reduce_test_v1

In [None]:
reduce_test_v2 = tf.reduce_mean(test_list_seq, axis=[0])
reduce_test_v2

In [None]:
reduce_test_v3 = tf.reduce_mean(test_list_seq, axis=[1])
reduce_test_v3

### reshape tensors

In [None]:
# a = [[1, 2, 3,4,5,6,7,8,9]]
a = [1, 2, 3,4,5,6,7,8,9]

b = tf.reshape(a, [-1, 9, 1])
b

In [None]:
tf.rank(b)

In [None]:
test_globals= [
    0.05226095, -0.04546688, -0.05914654,  0.03443705, -0.04011744,
   -0.05921736,  0.05578206, -0.02147666,  0.00166732,  0.04055796,
    0.06458487, -0.05492309, -0.06472961, -0.00705546, -0.05592869,
   -0.01938318,  0.03898788, -0.04043241, -0.0182637 , -0.0499408 ,
   -0.05968586,  0.06301413,  0.00032848,  0.06395795,  0.01845439,
    0.04108731, -0.05026846,  0.01969895, -0.02506991,  0.02361025,
    0.00762446, -0.00464374,  0.01902852,  0.03852094, -0.04125774,
   -0.04153034, -0.03931752,  0.05585755, -0.03481127, -0.04961544,
   -0.04787084,  0.06189156, -0.04888101, -0.07491934, -0.07062666,
   -0.02748476, -0.01719889, -0.06808205
]

# print(test_globals)
print(tf.rank(test_globals))

### don't forget about the batch dimension

In [None]:
# Values used by the movie-tags test model built in the next cell.
# NOTE(review): MV_EMBEDDING_SIZE was set to 32 in the earlier constants cell; this
# reassignment to 16 stays in effect for every later cell that reads it (including the
# later EmbeddingModel construction) -- confirm that is intended.
NUM_OOV_BUCKETS = 1
MV_EMBEDDING_SIZE=16

In [None]:
# Build a small functional Keras model for the `movie_tags` feature:
# string input -> TextVectorization -> Embedding -> Reshape -> average pooling.
TAG_MAX_LENGTH=10
MAX_VECT_LEN = 10

# vectorize_layer = tf.keras.layers.TextVectorization(
#  max_tokens=max_features,
#  output_mode='int',
#  output_sequence_length=max_len)

# String input of shape (TAG_MAX_LENGTH, 1) per example (the batch axis is implicit).
mv_tags_input_layer = tf.keras.Input(
    name="movie_tags",
    shape=(TAG_MAX_LENGTH,1),
    # shape=(1,),
    dtype=tf.string,
    # ragged=True
)
# Map each tag string to integer token ids using the precomputed vocabulary.
mv_tags_text = tf.keras.layers.TextVectorization(
    # max_tokens=max_tokens, 
    ngrams=2, 
    vocabulary=vocab_dict['movie_tags'],
    output_mode='int',
    output_sequence_length=MAX_VECT_LEN,
)(mv_tags_input_layer)
# NOTE(review): in 'int' mode TextVectorization reserves ids for padding and OOV in addition
# to the supplied vocabulary -- confirm that len(vocab) + NUM_OOV_BUCKETS covers the largest
# id the layer can emit (e.g. compare against the layer's `vocabulary_size()`).
mv_tags_embedding = tf.keras.layers.Embedding(
    # Let's use the explicit vocabulary lookup.
    input_dim=len(vocab_dict['movie_tags']) + NUM_OOV_BUCKETS,
    output_dim=MV_EMBEDDING_SIZE
)(mv_tags_text)

# mv_avg_pooling = tf.reduce_mean(mv_tags_embedding, axis=[-1])
# mv_avg_pooling = tf.reduce_sum(mv_tags_embedding, axis=-2)

# Collapse every non-batch axis except the embedding axis into one sequence axis, then
# average over that axis so each example yields a single MV_EMBEDDING_SIZE vector.
mv_tags_pooling_v1 = tf.keras.layers.Reshape([-1, MV_EMBEDDING_SIZE])(mv_tags_embedding)
# mv_avg_pooling = tf.keras.layers.GlobalAveragePooling2D()(mv_tags_pooling_v1)
mv_avg_pooling = tf.keras.layers.GlobalAveragePooling1D()(mv_tags_pooling_v1)

# mv_avg_pooling = tf.keras.layers.GlobalAveragePooling2D()(mv_tags_embedding)
# mv_avg_pooling = tf.keras.layers.GlobalAveragePooling1D()(mv_tags_pooling_v1)

# Wrap the graph so it can be called directly on a batch of tag strings.
test_mv_tags_model = tf.keras.Model(
    inputs=mv_tags_input_layer, outputs=mv_avg_pooling
)

test_mv_tags_model

In [None]:
mv_tags_text

In [None]:
BATCH_SIZE_T = 2  # also tried: 1

# Run the tags model on the first batch only. The tags are reshaped to
# (batch, 10, 1) so they match the model's (TAG_MAX_LENGTH, 1) input.
_first_batch_ds = train_dataset.batch(BATCH_SIZE_T).take(1)
for x in _first_batch_ds:
    reshaped_tensor = tf.reshape(x['movie_tags'], [-1, 10, 1])
    test_value = test_mv_tags_model(reshaped_tensor)

In [None]:
x['movie_tags']

In [None]:
x['movie_tags'][0]

In [None]:
BATCH_SIZE_tt=1
TAG_LENGTH_1=10
# tf.reshape(x['movie_tags'], [BATCH_SIZE_tt, TAG_LENGTH_1, 1])
# Flatten back to one row of tags per example. The trailing axis must be the tag length
# (10), not the embedding size: the element count of x['movie_tags'] is a multiple of the
# tag length, and reshaping to MV_EMBEDDING_SIZE (16) columns fails whenever it is not
# also a multiple of 16.
# assumes x['movie_tags'] holds TAG_LENGTH_1 tags per example -- TODO confirm
reshaped_tensor = tf.reshape(x['movie_tags'], [-1, TAG_LENGTH_1])
reshaped_tensor

In [None]:
test_value

In [None]:
# TENSOR_TEST_T = x['movie_tags']
TENSOR_TEST_T = x['movie_tags'][0]

reshaped_tensor_v2 = tf.reshape(TENSOR_TEST_T, [-1, 10, 1])
reshaped_tensor_v2

In [None]:
test_value

In [None]:
data

In [None]:
data['movie_tags']

In [None]:
# reshaped_tensor = tf.reshape(data['movie_tags'], [-1])
# reshaped_tensor

In [None]:
# reshaped_tensor = tf.reshape(data['movie_tags'], [-1, 10])
# reshaped_tensor

In [None]:
# reshaped_tensor = tf.reshape(data['movie_tags'], [-1, 10, 1])
# reshaped_tensor

In [None]:
BATCH_SIZE_T = 2  # also tried: 1

# Same check as before, but the reshape dimensions are read from the batch itself
# instead of being hard-coded.
_first_batch_ds = train_dataset.batch(BATCH_SIZE_T).take(1)
for x in _first_batch_ds:
    _tags_shape = x["movie_tags"].shape
    BATCH_SIZE_tt = _tags_shape[0]
    TAG_LENGTH_1 = _tags_shape[1]
    print(f"BATCH_SIZE_tt : {BATCH_SIZE_tt}")
    print(f"TAG_LENGTH_1  : {TAG_LENGTH_1}")
    # Add the trailing unit axis expected by the model input.
    reshaped_tensor = tf.reshape(x['movie_tags'], [BATCH_SIZE_tt, TAG_LENGTH_1, 1])
    test_value = test_mv_tags_model(reshaped_tensor)

test_value

In [None]:
reshaped_tensor

In [None]:
# TENSOR_TEST_T = x['movie_tags']
TENSOR_TEST_T = x['movie_tags'][0]

reshaped_tensor_v2 = tf.reshape(TENSOR_TEST_T, [-1, 10, 1])
reshaped_tensor_v2

In [None]:
# Two copies of the 48-value vector, used as input for the reshape in the next cell.
# `test_globals` was re-declared above as a plain Python list, so its elements are floats
# without a `.numpy()` method; convert the whole list with NumPy instead. `reshape(-1)`
# flattens the result, so it is also a flat row if `test_globals` holds a (1, 48) tensor.
_test_row = np.asarray(test_globals, dtype=np.float32).reshape(-1)
test_list_seq = [_test_row, _test_row]

test_list_seq

In [None]:
reshape_tensor_test = tf.reshape(test_list_seq, [-1, 48, 1])
reshape_tensor_test

In [None]:
# padded_inputs = tf.keras.utils.pad_sequences(test_list_seq, maxlen=20,truncating='post')
# padded_inputs

In [None]:
for x in train_dataset.padded_batch(HPARAMS['batch_size']).take(1):
    print(x)

#### tmp - delete START

In [None]:
class MovieModel(tf.keras.Model):
  """Scratch model that embeds movie titles two ways and concatenates the results.

  One branch looks up each whole title string in a fixed vocabulary and embeds it; the
  other tokenizes the title text, embeds the tokens, and average-pools them.

  NOTE(review): `unique_movie_titles` and `movies` are not defined anywhere in this
  notebook, so the class cannot be instantiated as-is.
  """

  def __init__(self):
    super().__init__()

    # Equals 1,000,000 (the underscores are only digit separators).
    max_tokens = 10_000_00

    # Branch 1: whole-title string -> vocabulary index -> 32-d embedding.
    self.title_embedding = tf.keras.Sequential([
      tf.keras.layers.StringLookup(
          vocabulary=unique_movie_titles,mask_token=None),
      tf.keras.layers.Embedding(len(unique_movie_titles) + 1, 32)
    ])

    # Branch 2: title text -> up to 4 token ids -> 32-d embedding per token.
    self.title_vectorizer = tf.keras.layers.TextVectorization(
        max_tokens=max_tokens,output_sequence_length = 4)

    self.title_text_embedding = tf.keras.Sequential(
        [
            self.title_vectorizer,
            tf.keras.layers.Embedding(max_tokens, 32, mask_zero=True),
        ]
    )
    # Learn the token vocabulary from the `movies` data.
    self.title_vectorizer.adapt(movies)

  def call(self, titles, pool_size):
    """Return the concatenation (on axis 3) of title-lookup and pooled token embeddings.

    Args:
      titles: string tensor; it is reshaped to (len_titles, 5, 1), so it must hold
        exactly 5 titles per leading-axis entry.
      pool_size: pool size passed to AveragePooling2D for the token-embedding branch.
    """
    # A new pooling layer is built on every call because `pool_size` is a call argument.
    avg_layer = tf.keras.layers.AveragePooling2D(
        pool_size=pool_size,strides=1,padding='valid',
    )
    len_titles=tf.shape(titles)[0]

    # return avg_layer(self.title_text_embedding(titles))
    return tf.concat(
        [
            self.title_embedding(tf.reshape(titles,[len_titles,5,1])),

            avg_layer(
                self.title_text_embedding(
                    tf.reshape(
                        titles,[len_titles,5,1]
                    )
                )
            ),
        ], 
        axis=3
    )

class MovielensModel(tfrs.models.Model):
  """Scratch ranking model: query tower + candidate tower + dense rating head.

  NOTE(review): `tfrs`, `tfr`, `QueryModel` and `CandidateModel` are not imported or
  defined anywhere in this notebook, so the class cannot be created as-is.
  """

  def __init__(self, layer_sizes):
    """Build both towers from `layer_sizes`, the rating head, and the ranking task."""
    super().__init__()
    self.query_model = QueryModel(layer_sizes)
    self.candidate_model = CandidateModel(layer_sizes)
    # Dense head that maps the concatenated embeddings to a single score.
    self.rating_model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    # self.rating_task: tf.keras.layers.Layer = tfrs.tasks.Ranking(
    #     loss=tf.keras.losses.MeanSquaredError(),
    #     metrics=[tf.keras.metrics.RootMeanSquaredError()],
    # )
    # Listwise ranking task (ListMLE loss) tracked with NDCG and RMSE.
    self.rating_task: tf.keras.layers.Layer = tfrs.tasks.Ranking(
        loss=tfr.keras.losses.ListMLELoss(),
        metrics=[tfr.keras.metrics.NDCGMetric(name="ndcg_metric"),
                 tf.keras.metrics.RootMeanSquaredError()],
    )

  def call(self, features: Dict[Text, tf.Tensor]) -> tf.Tensor:
    """Score every movie in each example's list against that example's user.

    Reads `features["user_id"]` and `features["movie_title"]`. Intermediate tensors are
    stored on `self` as a side effect.

    Returns:
      A 3-tuple (query_embeddings, movie_embeddings, rating_model output).
      NOTE(review): the `-> tf.Tensor` annotation does not match this tuple.
    """
    # We pick out the user features and pass them into the user model.
    # print(len(features["movie_title"]))
    # for x in features["movie_title"]:
    #   print(x)
    self.query_embeddings = self.query_model({
        "user_id": features["user_id"],
        # "timestamp": features["timestamp"],
    })
    self.movie_embeddings = self.candidate_model(features["movie_title"],pool_size=(1,4))

    # Number of movies per example, taken from axis 1 of the title tensor.
    list_length = features["movie_title"].shape[1]

    # Insert two unit axes after the batch axis, then repeat along axis 1 so there is
    # one copy of the query embedding per movie in the list.
    self.query_embeddings_repeated = tf.repeat(
        tf.expand_dims(
            tf.expand_dims(
                self.query_embeddings, 1
            ), 1
        ), [list_length], axis=1
    )
    # Join user and movie embeddings along axis 3.
    self.embd_concat=tf.concat(
        [
            self.query_embeddings_repeated, 
            self.movie_embeddings
        ], 3
    )
    return (
        self.query_embeddings,
        self.movie_embeddings,
        # We apply the multi-layered rating model to a concatentation of
        # user and movie embeddings.
        self.rating_model(
            self.embd_concat
        ),
    )

In [None]:
    # Excerpt copied from MovieModel.call, kept for reference only. It is commented out
    # because, as a standalone cell, the indented fragment with a bare `return` is a
    # syntax error and stops "Run All".
    #
    # len_titles=tf.shape(titles)[0]
    #
    # # return avg_layer(self.title_text_embedding(titles))
    # return tf.concat(
    #     [
    #         self.title_embedding(tf.reshape(titles,[len_titles,5,1])),
    #
    #         avg_layer(
    #             self.title_text_embedding(
    #                 tf.reshape(
    #                     titles,[len_titles,5,1]
    #                 )
    #             )
    #         ),
    #     ],
    #     axis=3
    # )

In [None]:
    # Excerpt copied from MovielensModel.call, kept for reference only. It is commented
    # out because, as a standalone cell, the indented fragment raises IndentationError
    # (and references `self` / `features`, which do not exist at notebook level).
    #
    # list_length = features["movie_title"].shape[1]
    #
    # self.query_embeddings_repeated = tf.repeat(
    #     tf.expand_dims(
    #         tf.expand_dims(
    #             self.query_embeddings, 1
    #         ), 1
    #     ), [list_length], axis=1
    # )
    # self.embd_concat=tf.concat(
    #     [
    #         self.query_embeddings_repeated,
    #         self.movie_embeddings
    #     ], 3
    # )

#### tmp - delete END

## global context (user) features

#### user ID

In [None]:
# Preprocessing graph for the `user_id` feature: string -> vocabulary index -> embedding.
# One string per example (the batch axis is implicit).
user_id_input_layer = tf.keras.Input(
    name="user_id",
    shape=(1,),
    dtype=tf.string
)

# Map each id to its index in the precomputed vocabulary; ids outside the vocabulary
# fall into one of NUM_OOV_BUCKETS out-of-vocabulary slots.
user_id_lookup = tf.keras.layers.StringLookup(
    max_tokens=len(vocab_dict['user_id']) + NUM_OOV_BUCKETS,
    num_oov_indices=NUM_OOV_BUCKETS,
    mask_token=None,
    vocabulary=vocab_dict['user_id'],
)(user_id_input_layer)

user_id_embedding = tf.keras.layers.Embedding(
    # Let's use the explicit vocabulary lookup.
    input_dim=len(vocab_dict['user_id']) + NUM_OOV_BUCKETS,
    output_dim=GLOBAL_EMBEDDING_SIZE
)(user_id_lookup)

# Sum over the second-to-last axis (length 1 given the (1,) input shape), leaving one
# GLOBAL_EMBEDDING_SIZE vector per example.
user_id_embedding = tf.reduce_sum(user_id_embedding, axis=-2)

# global_inputs.append(user_id_input_layer)
# global_features.append(user_id_embedding)

In [None]:
test_user_id_model = tf.keras.Model(inputs=user_id_input_layer, outputs=user_id_embedding)

# for x in train_dataset.batch(1).take(1):
#     print(x["user_id"])
#     print(test_user_id_model(x["user_id"]))

#### user AGE

In [None]:
# user_age_input_layer = tf.keras.Input(
#     name="bucketized_user_age",
#     shape=(1,),
#     dtype=tf.float32
# )

# user_age_lookup = tf.keras.layers.IntegerLookup(
#     vocabulary=vocab_dict['bucketized_user_age'],
#     num_oov_indices=NUM_OOV_BUCKETS,
#     oov_value=0,
# )(user_age_input_layer)

# user_age_embedding = tf.keras.layers.Embedding(
#     # Let's use the explicit vocabulary lookup.
#     input_dim=len(vocab_dict['bucketized_user_age']) + NUM_OOV_BUCKETS,
#     output_dim=GLOBAL_EMBEDDING_SIZE
# )(user_age_lookup)

# user_age_embedding = tf.reduce_sum(user_age_embedding, axis=-2)

# # global_inputs.append(user_age_input_layer)
# # global_features.append(user_age_embedding)

In [None]:
# test_user_age_model = tf.keras.Model(inputs=user_age_input_layer, outputs=user_age_embedding)

# # for x in train_dataset.batch(1).take(1):
# #     print(x["bucketized_user_age"])
# #     print(test_user_age_model(x["bucketized_user_age"]))

#### user OCC

In [None]:
# user_occ_input_layer = tf.keras.Input(
#     name="user_occupation_text",
#     shape=(1,),
#     dtype=tf.string
# )

# user_occ_lookup = tf.keras.layers.StringLookup(
#     max_tokens=len(vocab_dict['user_occupation_text']) + NUM_OOV_BUCKETS,
#     num_oov_indices=NUM_OOV_BUCKETS,
#     mask_token=None,
#     vocabulary=vocab_dict['user_occupation_text'],
# )(user_occ_input_layer)

# user_occ_embedding = tf.keras.layers.Embedding(
#     # Let's use the explicit vocabulary lookup.
#     input_dim=len(vocab_dict['user_occupation_text']) + NUM_OOV_BUCKETS,
#     output_dim=GLOBAL_EMBEDDING_SIZE
# )(user_occ_lookup)

# user_occ_embedding = tf.reduce_sum(user_occ_embedding, axis=-2)

# # global_inputs.append(user_occ_input_layer)
# # global_features.append(user_occ_embedding)

In [None]:
# test_user_occ_model = tf.keras.Model(inputs=user_occ_input_layer, outputs=user_occ_embedding)

# # for x in train_dataset.batch(1).take(1):
# #     print(x["user_occupation_text"])
# #     print(test_user_occ_model(x["user_occupation_text"]))

#### user Timestamp

In [None]:
# user_ts_input_layer = tf.keras.Input(
#     name="timestamp",
#     shape=(1,),
#     dtype=tf.int64
# )

# user_ts_lookup = tf.keras.layers.Discretization(
#     vocab_dict['timestamp_buckets'].tolist()
# )(user_ts_input_layer)

# user_ts_embedding = tf.keras.layers.Embedding(
#     # Let's use the explicit vocabulary lookup.
#     input_dim=len(vocab_dict['timestamp_buckets'].tolist()) + NUM_OOV_BUCKETS,
#     output_dim=GLOBAL_EMBEDDING_SIZE
# )(user_ts_lookup)

# user_ts_embedding = tf.reduce_sum(user_ts_embedding, axis=-2)

# # global_inputs.append(user_ts_input_layer)
# # global_features.append(user_ts_embedding)

In [None]:
# test_user_ts_model = tf.keras.Model(inputs=user_ts_input_layer, outputs=user_ts_embedding)

# # for x in train_dataset.batch(1).take(1):
# #     print(x["timestamp"])
# #     print(test_user_ts_model(x["timestamp"]))

### define global sampling function

In [None]:
# def _get_global_context_features(x):
#     """
#     This function generates a single global observation vector.
#     """
#     user_id_value = x['user_id']
#     user_age_value = x['bucketized_user_age']
#     user_occ_value = x['user_occupation_text']
#     user_ts_value = x['timestamp']

#     _id = test_user_id_model(user_id_value) # input_tensor=tf.Tensor(shape=(4,), dtype=float32)
#     _age = test_user_age_model(user_age_value)
#     _occ = test_user_occ_model(user_occ_value)
#     _ts = test_user_ts_model(user_ts_value)

#     # # tmp - insepct numpy() values
#     # print(_id.numpy()) #[0])
#     # print(_age.numpy()) #[0])
#     # print(_occ.numpy()) #[0])
#     # print(_ts.numpy()) #[0])

#     # to numpy array
#     _id = np.array(_id.numpy())
#     _age = np.array(_age.numpy())
#     _occ = np.array(_occ.numpy())
#     _ts = np.array(_ts.numpy())

#     concat = np.concatenate(
#         [_id, _age, _occ, _ts], axis=-1 # -1
#     ).astype(np.float32)

#     return concat

In [None]:
# Width of the global-context feature vector (axis 1 of the model output).
# The notebook-level `_get_global_context_features` helper is commented out in the cell
# above, so calling it raises NameError; use the EmbeddingModel method of the same name.
GLOBAL_DIM = embs._get_global_context_features(data).shape[1]
print(f"GLOBAL_DIM: {GLOBAL_DIM}")

## arm preprocessing layers

#### movie ID

In [None]:
# mv_id_input_layer = tf.keras.Input(
#     name="movie_id",
#     shape=(1,),
#     dtype=tf.string
# )

# mv_id_lookup = tf.keras.layers.StringLookup(
#     max_tokens=len(vocab_dict['movie_id']) + NUM_OOV_BUCKETS,
#     num_oov_indices=NUM_OOV_BUCKETS,
#     mask_token=None,
#     vocabulary=vocab_dict['movie_id'],
# )(mv_id_input_layer)

# mv_id_embedding = tf.keras.layers.Embedding(
#     # Let's use the explicit vocabulary lookup.
#     input_dim=len(vocab_dict['movie_id']) + NUM_OOV_BUCKETS,
#     output_dim=MV_EMBEDDING_SIZE
# )(mv_id_lookup)

# mv_id_embedding = tf.reduce_sum(mv_id_embedding, axis=-2)

# # arm_inputs.append(mv_id_input_layer)
# # arm_features.append(mv_id_embedding)

In [None]:
# test_mv_id_model = tf.keras.Model(inputs=mv_id_input_layer, outputs=mv_id_embedding)

# # for x in train_dataset.batch(1).take(1):
# #     print(x["movie_id"])
# #     print(test_mv_id_model(x["movie_id"]))

#### movie genre

In [None]:
# mv_genre_input_layer = tf.keras.Input(
#     name="movie_genres",
#     shape=(1,),
#     dtype=tf.float32
# )

# mv_genre_lookup = tf.keras.layers.IntegerLookup(
#     vocabulary=vocab_dict['movie_genres'],
#     num_oov_indices=NUM_OOV_BUCKETS,
#     oov_value=0,
# )(mv_genre_input_layer)

# mv_genre_embedding = tf.keras.layers.Embedding(
#     # Let's use the explicit vocabulary lookup.
#     input_dim=len(vocab_dict['movie_genres']) + NUM_OOV_BUCKETS,
#     output_dim=MV_EMBEDDING_SIZE
# )(mv_genre_lookup)

# mv_genre_embedding = tf.reduce_sum(mv_genre_embedding, axis=-2)

# # arm_inputs.append(mv_genre_input_layer)
# # arm_features.append(mv_genre_embedding)

In [None]:
# test_mv_gen_model = tf.keras.Model(inputs=mv_genre_input_layer, outputs=mv_genre_embedding)

# for x in train_dataset.batch(1).take(1):
#     print(x["movie_genres"])
#     print(test_mv_gen_model(x["movie_genres"]))

### define sampling function

In [None]:
# def _get_per_arm_features(x):
#     """
#     This function generates a single per-arm observation vector
#     """
#     mv_id_value = x['movie_id']
#     mv_gen_value = x['movie_genres']

#     _mid = test_mv_id_model(mv_id_value)
#     _mgen = test_mv_gen_model(mv_gen_value)

#     # to numpy array
#     _mid = np.array(_mid.numpy())
#     _mgen = np.array(_mgen.numpy())


#     concat = np.concatenate(
#         [_mid, _mgen], axis=-1 # -1
#     ).astype(np.float32)
#     # concat = tf.concat([_mid, _mgen], axis=-1).astype(np.float32)

#     return concat #this is special to this example - there is only one action dimensions

In [None]:
# Width of the per-arm feature vector (axis 1 of the model output).
# The notebook-level `_get_per_arm_features` helper is commented out in the cell above,
# so calling it raises NameError; use the EmbeddingModel method of the same name.
PER_ARM_DIM = embs._get_per_arm_features(data).shape[1]
print(f"PER_ARM_DIM: {PER_ARM_DIM}")

## Global & Per-Arm feature embedding models 

> all these dimensions should match the class output below

In [None]:
# Rebuild the embedding model with the current constants, using the `emb_features` alias
# bound in the imports section (`src.networks.encoding_network`).
# The mid-notebook `from src.perarm_features import emb_features` was dropped: the imports
# section shows that module path was already replaced, and re-importing here rebinds the
# alias partway through the notebook.
# NOTE(review): confirm `src.perarm_features.emb_features` is retired.
embs = emb_features.EmbeddingModel(
    vocab_dict = vocab_dict,
    num_oov_buckets = NUM_OOV_BUCKETS,
    global_emb_size = GLOBAL_EMBEDDING_SIZE,
    mv_emb_size = MV_EMBEDDING_SIZE,
)

embs

In [None]:
# Compute the global-context features for the sample batch.
# Note: this rebinds `test_globals`, which earlier cells define as a plain Python list.
test_globals = embs._get_global_context_features(data)

GLOBAL_DIM = test_globals.shape[1]            
# Axis 1 is taken as the global feature width.
# NOTE(review): the previous comment here described (batch_dim, nactions, arm feats),
# which reads like the per-arm layout -- confirm the shape of the global features.
print(f"GLOBAL_DIM: {GLOBAL_DIM}")

test_globals

In [None]:
# Compute the per-arm features for the sample batch.
test_arms = embs._get_per_arm_features(data)

PER_ARM_DIM = test_arms.shape[1]            
# Axis 1 is taken as the per-arm feature width.
# NOTE(review): the previous comment described the layout as (batch_dim, nactions, arm
# feats); if the output really is rank 3, axis 1 would be nactions -- confirm.
print(f"PER_ARM_DIM: {PER_ARM_DIM}")

test_arms