In [1]:
# !pip install tf-agents --user -q

In [76]:
!pip freeze | grep tf-agents

tf-agents==0.17.0


In [77]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

In [356]:
import functools
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tf_agents.bandits.agents import lin_ucb_agent
from tf_agents.bandits.environments import stationary_stochastic_per_arm_py_environment as p_a_env
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import tf_py_environment
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts
import tensorflow_datasets as tfds
from pprint import pprint

nest = tf.nest

### movies data

In [357]:
movies = tfds.load("movielens/100k-movies", split="train")

for x in movies.batch(1).take(1):
    pprint(x)

{'movie_genres': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[4]])>,
 'movie_id': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'1681'], dtype=object)>,
 'movie_title': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'You So Crazy (1994)'], dtype=object)>}


### user and ratings data

In [358]:
ratings = tfds.load("movielens/100k-ratings", split="train")

for x in ratings.batch(1).take(1):
    pprint(x)

{'bucketized_user_age': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([45.], dtype=float32)>,
 'movie_genres': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[7]])>,
 'movie_id': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'357'], dtype=object)>,
 'movie_title': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b"One Flew Over the Cuckoo's Nest (1975)"], dtype=object)>,
 'raw_user_age': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([46.], dtype=float32)>,
 'timestamp': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([879024327])>,
 'user_gender': <tf.Tensor: shape=(1,), dtype=bool, numpy=array([ True])>,
 'user_id': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'138'], dtype=object)>,
 'user_occupation_label': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
 'user_occupation_text': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'doctor'], dtype=object)>,
 'user_rating': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([4.], dtype=float32)>,

#### Let's make this simple and load up movielens that has features
We will only consider for this example
1) The movie genere as an Arm feature (we will concatenate multiple genres)
2) The user occupation and age bucket labels for the overall context features

We need to load the data, get the ratings - light EDA for us to get cardnality of the dataset as well as lookups for the 

In [359]:
# Get the unique movies and users
unique_movie_ids = ratings.map(lambda x: x["movie_id"])
unique_movie_ids = np.unique([x.numpy() for x in unique_movie_ids])
MOVIELENS_NUM_MOVIES = len(unique_movie_ids)


print(f"len(unique_movie_ids) : {len(unique_movie_ids)}")
print(f"unique_movie_ids      : {unique_movie_ids[:2]}")

len(unique_movie_ids) : 1682
unique_movie_ids      : [b'1' b'10']


In [360]:
unique_user_ids = ratings.map(lambda x: x["user_id"])
unique_user_ids = np.unique([x.numpy() for x in unique_user_ids])
MOVIELENS_NUM_USERS = len(unique_user_ids)


print(f"len(unique_user_ids) : {len(unique_user_ids)}")
print(f"unique_user_ids      : {unique_user_ids[:2]}")

len(unique_user_ids) : 943
unique_user_ids      : [b'1' b'10']


In [361]:
## Get the unnique set of user buckets and create a lookup table

In [362]:
from typing import Dict

def get_dictionary_lookup_by_tf_data_key(key: str) -> Dict:
    tensor = ratings.map(lambda x: x[key])
    unique_elems = set()
    for x in tensor:
        val = x.numpy()
        if type(val) is np.ndarray: # if multi dimesnional only grab first one
            val = val[0]
        unique_elems.add(val)
    
    #return a dictionary of keys by integer values for the feature space
    return {val: i for i, val in enumerate(unique_elems)}


In [363]:
user_age_lookup = get_dictionary_lookup_by_tf_data_key('bucketized_user_age')
user_age_dim = len(user_age_lookup)

In [364]:
user_age_lookup

{1.0: 0, 35.0: 1, 45.0: 2, 18.0: 3, 50.0: 4, 56.0: 5, 25.0: 6}

In [365]:
user_occ_lookup = get_dictionary_lookup_by_tf_data_key('user_occupation_text')
user_occ_dim = len(user_occ_lookup)

In [366]:
user_occ_lookup

{b'salesman': 0,
 b'healthcare': 1,
 b'programmer': 2,
 b'lawyer': 3,
 b'marketing': 4,
 b'technician': 5,
 b'engineer': 6,
 b'entertainment': 7,
 b'student': 8,
 b'other': 9,
 b'homemaker': 10,
 b'retired': 11,
 b'administrator': 12,
 b'writer': 13,
 b'executive': 14,
 b'librarian': 15,
 b'scientist': 16,
 b'educator': 17,
 b'none': 18,
 b'artist': 19,
 b'doctor': 20}

In [367]:
movie_gen_lookup = get_dictionary_lookup_by_tf_data_key('movie_genres')
movie_gen_dim = len(movie_gen_lookup)

In [368]:
movie_gen_lookup

{0: 0,
 1: 1,
 2: 2,
 3: 3,
 4: 4,
 5: 5,
 6: 6,
 7: 7,
 8: 8,
 9: 9,
 10: 10,
 12: 11,
 13: 12,
 14: 13,
 15: 14,
 16: 15,
 17: 16,
 18: 17,
 19: 18}

In [369]:
# REFACTOR BELOW
 #from https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/environments/dataset_utilities.py#L153
    
# def load_movielens_data(data_file, delimiter=','):
#     """Loads the movielens data and returns the ratings matrix."""
#     ratings_matrix = np.zeros([MOVIELENS_NUM_USERS, MOVIELENS_NUM_MOVIES])
#     with tf.io.gfile.GFile(data_file, 'r') as infile:
#     # The file is a csv with rows containing:
#     # user id | item id | rating | timestamp
#     reader = csv.reader(infile, delimiter=delimiter)
#     for row in reader:
#         user_id, item_id, rating, _ = row
#         ratings_matrix[int(user_id) - 1, int(item_id) - 1] = float(rating)
#     return ratings_matrix



def load_movielens_data(ratings_dataset):
    # ratings = tfds.load("movielens/100k-ratings", split="train")
    ratings_matrix = np.zeros([MOVIELENS_NUM_USERS, MOVIELENS_NUM_MOVIES])
    local_data = ratings_dataset.map(lambda x: {'user_id': x['user_id']
                                                 ,'movie_id':  x['movie_id']
                                                 ,'user_rating':  x['user_rating']
                                                 ,'bucketized_user_age': x['bucketized_user_age']
                                                 ,'user_occupation_text': x['user_occupation_text']
                                                 ,'movie_genres': x['movie_genres'][0]
                                               }
                                                                         )
    user_age_int = []
    user_occ_int = []
    mov_gen_int = []
    for row in local_data:
        ratings_matrix[int(row['user_id'].numpy()) - 1, int(row['movie_id'].numpy()) - 1] = float(row['user_rating'].numpy())
        user_age_int.append(user_age_lookup[row['bucketized_user_age'].numpy()])
        user_occ_int.append(user_occ_lookup[row['user_occupation_text'].numpy()])
        mov_gen_int.append(movie_gen_lookup[row['movie_genres'].numpy()])
    return tf.convert_to_tensor(ratings_matrix, dtype=tf.float32), tf.convert_to_tensor(np.array(user_age_int), dtype=tf.float32), tf.convert_to_tensor(np.array(user_occ_int), dtype=tf.float32), tf.convert_to_tensor(np.array(mov_gen_int), dtype=tf.float32)
    

In [370]:
ratings_matrix, user_age_int, user_occ_int, mov_gen_int = load_movielens_data(ratings)

In [371]:
sampled_user_indices_np = np.random.randint(
            1000, size=8)
sampled_user_indices = tf.convert_to_tensor(sampled_user_indices_np, dtype=tf.int32)
sampled_user_indices = np.expand_dims(sampled_user_indices,axis=-1) #expand out to individual indicies to match sizes for slicing

In [372]:
sampled_user_indices

array([[959],
       [986],
       [489],
       [534],
       [802],
       [395],
       [750],
       [446]], dtype=int32)

In [373]:
sampled_user_ages = tf.gather_nd(indices=sampled_user_indices
                                         , params=user_age_int
                                         , batch_dims=0)

### Now do the same with the movies

In [374]:
sampled_movie_indices_np = np.array([
            random.sample(range(1000), 5)
            for _ in range(8)
        ])
sampled_movie_indices = tf.convert_to_tensor(sampled_movie_indices_np, dtype=tf.int32)
sampled_movie_indices

<tf.Tensor: shape=(8, 5), dtype=int32, numpy=
array([[281, 227, 331, 570,  93],
       [263, 406, 514, 217, 621],
       [983, 473, 308, 602,  31],
       [ 31, 251, 250,  20, 685],
       [744, 456, 195, 947, 526],
       [196, 141, 955, 861, 235],
       [361, 408, 563, 213, 392],
       [776, 393, 782, 622, 287]], dtype=int32)>

In [375]:
# movie_index_vector = tf.reshape(sampled_movie_indices, shape=[-1])
movie_index_vector = tf.convert_to_tensor(sampled_movie_indices, dtype=tf.int32)
movie_index_vector = tf.expand_dims(tf.reshape(movie_index_vector, shape=[-1]), axis=-1)
# flat_genre_list = self._mov_gen_int[movie_index_vector] #shape of 1
movie_index_vector

<tf.Tensor: shape=(40, 1), dtype=int32, numpy=
array([[281],
       [227],
       [331],
       [570],
       [ 93],
       [263],
       [406],
       [514],
       [217],
       [621],
       [983],
       [473],
       [308],
       [602],
       [ 31],
       [ 31],
       [251],
       [250],
       [ 20],
       [685],
       [744],
       [456],
       [195],
       [947],
       [526],
       [196],
       [141],
       [955],
       [861],
       [235],
       [361],
       [408],
       [563],
       [213],
       [392],
       [776],
       [393],
       [782],
       [622],
       [287]], dtype=int32)>

In [376]:
mov_gen_int

<tf.Tensor: shape=(100000,), dtype=float32, numpy=array([ 7.,  4.,  4., ..., 10.,  0.,  4.], dtype=float32)>

In [377]:
flat_genre_list = tf.gather_nd(indices=movie_index_vector
                               , params=mov_gen_int
                               , batch_dims=0) #shape of 1
flat_genre_list

<tf.Tensor: shape=(40,), dtype=float32, numpy=
array([10.,  7.,  0.,  4.,  7.,  3.,  7.,  3.,  1.,  4.,  5.,  2.,  0.,
        0.,  0.,  0.,  0.,  0.,  5.,  5.,  0.,  1.,  7.,  2.,  2., 12.,
        4.,  2.,  9.,  4.,  7.,  7.,  9.,  0.,  4.,  0.,  0.,  7.,  4.,
        4.], dtype=float32)>

In [378]:
#### Tf SVD

In [379]:
s, u, vh = tf.linalg.svd(ratings_matrix, full_matrices=False)

rank_k = 4

# Keep only the largest singular values.
u_hat = u[:, :rank_k]
s_hat = s[:rank_k]
v_hat = tf.transpose(vh[:rank_k])

## Replicate an agent using the above data

https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/environments/movielens_per_arm_py_environment.py

Create an arm spec from this utility function
https://www.tensorflow.org/agents/api_docs/python/tf_agents/specs/bandit_spec_utils/create_per_arm_observation_spec

#### NOT Used but helpful to create an obs spec:

```python
# Example observation spec from above
# There are 20 user occupations and 7 age buckets. This makes our global dimension 27
# There are 19 genres, and that will be the arm dimension for this example

from tf_agents.specs.bandit_spec_utils import create_per_arm_observation_spec as create_obs_spec
create_obs_spec(
    global_dim = 1,
    per_arm_dim = 2,
    max_num_actions = 10,
    add_num_actions_feature = False
) 
```

In [392]:
"""Class implementation of the per-arm MovieLens Bandit environment."""
from __future__ import absolute_import

import random
from typing import Optional, Text
import gin
import numpy as np

from tf_agents.bandits.environments import bandit_py_environment
from tf_agents.bandits.environments import dataset_utilities
from tf_agents.bandits.specs import utils as bandit_spec_utils
from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts


GLOBAL_KEY = bandit_spec_utils.GLOBAL_FEATURE_KEY
PER_ARM_KEY = bandit_spec_utils.PER_ARM_FEATURE_KEY


# @gin.configurable
class MovieLensPerArmPyEnvironment(bandit_py_environment.BanditPyEnvironment):
    """Implements the per-arm version of the MovieLens Bandit environment.

    This environment implements the MovieLens 100K dataset, available at:
    https://www.kaggle.com/prajitdatta/movielens-100k-dataset

    This dataset contains 100K ratings from 943 users on 1682 items.
    This csv list of:
    user id | item id | rating | timestamp.
    This environment computes a low-rank matrix factorization (using SVD) of the
    data matrix `A`, such that: `A ~= U * Sigma * V^T`.

    The environment uses the rows of `U` as global (or user) features, and the
    rows of `V` as per-arm (or movie) features.

    The reward of recommending movie `v` to user `u` is `u * Sigma * v^T`.
    """

    def __init__(self,
               dataset = ratings,
               rank_k: int = 2,
               batch_size: int = 10,
               num_actions: int = 100,
               name: Optional[Text] = 'movielens_per_arm'):
        """Initializes the Per-arm MovieLens Bandit environment.

        Args:
          data_dir: (string) Directory where the data lies (in text form).
          rank_k : (int) Which rank to use in the matrix factorization. This will
            also be the feature dimension of both the user and the movie features.
          batch_size: (int) Number of observations generated per call.
          num_actions: (int) How many movies to choose from per round.
          csv_delimiter: (string) The delimiter to use in loading the data csv file.
          name: (string) The name of this environment instance.
        """
        self._batch_size = batch_size
        self._num_actions = num_actions
        self.rank_k = rank_k

        # Compute the matrix factorization.
        # self._data_matrix = dataset_utilities.load_movielens_data(
        #     data_dir, delimiter=csv_delimiter)

        self._data_matrix, self._user_age_int, self._user_occ_int, self._mov_gen_int = load_movielens_data(ratings)
        self._num_users, self._num_movies = self._data_matrix.shape

        # Compute the SVD.
        s, u, vh = tf.linalg.svd(self._data_matrix, full_matrices=False)

        # Keep only the largest singular values.
        self._u_hat = u[:, :rank_k]#.astype(tf.float32)
        self._s_hat = s[:rank_k]#.astype(tf.float32)
        self._v_hat = vh[:, :rank_k]#.astype(tf.float32)

        self._approx_ratings_matrix = tf.matmul(self._u_hat * self._s_hat,
                                                tf.transpose(self._v_hat))

        self._action_spec = array_spec.BoundedArraySpec(
            shape=(),
            dtype=np.int32,
            minimum=0,
            maximum=num_actions - 1,
            name='action')
        observation_spec = {
            GLOBAL_KEY:
                array_spec.ArraySpec(shape=[rank_k+2], dtype=np.float32), #creating +space for user age and occupation
            PER_ARM_KEY:
                array_spec.ArraySpec(
                    shape=[num_actions, rank_k+1], dtype=np.float32), #creating +1 space for movie genre
        }
        self._time_step_spec = ts.time_step_spec(observation_spec)

        self._current_user_indices = tf.zeros(batch_size, dtype=np.int32)
        self._previous_user_indices = tf.zeros(batch_size, dtype=np.int32)

        self._current_movie_indices = tf.zeros([batch_size, num_actions],
                                               dtype=np.int32)
        self._previous_movie_indices = tf.zeros([batch_size, num_actions],
                                                dtype=np.int32)

        self._observation = {
            GLOBAL_KEY:
                tf.zeros([batch_size, rank_k+2], dtype=np.int32), #making space like above for dimensions
            PER_ARM_KEY:
                tf.zeros([batch_size, num_actions, rank_k+1], dtype=np.int32),
        }

        super(MovieLensPerArmPyEnvironment, self).__init__(
            observation_spec, self._action_spec, name=name)

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def batched(self):
        return True

    def _observe(self):
        
        #user section - random sample users
        sampled_user_indices_np = np.random.randint(
            self._num_users, size=self._batch_size)
        sampled_user_indices_1d = tf.convert_to_tensor(sampled_user_indices_np, dtype=tf.int32)
        sampled_user_indices = tf.expand_dims(sampled_user_indices_1d, axis=-1)
        self._previous_user_indices = self._current_user_indices
        self._current_user_indices = sampled_user_indices
        
        #sample feature values
        sampled_user_ages = tf.gather_nd(indices=sampled_user_indices, params=self._user_age_int)
        sampled_user_occ = tf.gather_nd(indices=sampled_user_indices, params=self._user_occ_int)
        latent_user_features = tf.gather_nd(indices=sampled_user_indices, params=self._u_hat)
        combined_user_features = tf.concat([latent_user_features
                                                 , tf.expand_dims(sampled_user_ages, axis=-1)
                                                 , tf.expand_dims(sampled_user_occ, axis=-1)], axis=1)
        ###movie section

        sampled_movie_indices_np = np.array([
            random.sample(range(self._num_movies), self._num_actions)
            for _ in range(self._batch_size)
        ])
        sampled_movie_indices = tf.convert_to_tensor(sampled_movie_indices_np, dtype=tf.int32)
        print(sampled_movie_indices)
        
        
        
        self.movie_index_vector = tf.expand_dims(tf.reshape(sampled_movie_indices, shape=[-1]), axis=-1)
        flat_genre_list = tf.gather_nd(indices=self.movie_index_vector, params=self._mov_gen_int) #shape of 1
        reshaped_genre_features = tf.reshape(flat_genre_list, shape = [self._batch_size, self._num_actions])
        latent_movie_features = tf.gather_nd(indices=movie_index_vector, params=self._v_hat) #shape of 2
        # shape[0] = [40,20]
        latent_movie_features = tf.reshape(latent_movie_features, shape=[self._batch_size, self._num_actions])
        combined_movie_features = tf.concat([latent_movie_features
                                             , reshaped_genre_features], axis=1)
        current_movies = tf.reshape(combined_movie_features
                                    , shape = [self._batch_size, self._num_actions, self.rank_k+1]
                                   )

        self._previous_movie_indices = self._current_movie_indices
        self._current_movie_indices = sampled_movie_indices

        batched_observations = {
            GLOBAL_KEY:
                combined_user_features,
            PER_ARM_KEY:
                current_movies,
        }
        return batched_observations

    def _apply_action(self, action):
        chosen_arm_indices = tf.gather_nd(indices=self._current_movie_indices, params=action)
        return self._approx_ratings_matrix[self._current_user_indices,
                                           chosen_arm_indices]

    def _rewards_for_all_actions(self):
        rewards_matrix = self._approx_ratings_matrix[
            tf.expand_dims(self._previous_user_indices, axis=-1),
            self._previous_movie_indices]
        return rewards_matrix

    def compute_optimal_action(self):
        return np.argmax(self._rewards_for_all_actions(), axis=-1)

    def compute_optimal_reward(self):
        return np.max(self._rewards_for_all_actions(), axis=-1)

In [393]:
# env = MovieLensPerArmPyEnvironment()

In [394]:
# print('observation spec: ', env.observation_spec())
# print('\nAn observation: ', env.reset().observation)

### Now that the environment is created, let's optimize

Taken from here
https://github.com/tensorflow/agents/blob/5e5915b0a3650a15e82e77af6e37f41a6c744689/tf_agents/bandits/agents/examples/v2/train_eval_movielens.py#L84

In [395]:
import functools
import os
from absl import app
from absl import flags

import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
from tf_agents.bandits.agents import dropout_thompson_sampling_agent as dropout_ts_agent
from tf_agents.bandits.agents import lin_ucb_agent
from tf_agents.bandits.agents import linear_thompson_sampling_agent as lin_ts_agent
from tf_agents.bandits.agents import neural_epsilon_greedy_agent as eps_greedy_agent
from tf_agents.bandits.agents.examples.v2 import trainer
from tf_agents.bandits.environments import environment_utilities
from tf_agents.bandits.environments import movielens_per_arm_py_environment
from tf_agents.bandits.environments import movielens_py_environment
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
from tf_agents.bandits.networks import global_and_arm_feature_network
from tf_agents.environments import tf_py_environment
from tf_agents.networks import q_network

BATCH_SIZE = 8
TRAINING_LOOPS = 20000
STEPS_PER_LOOP = 2

RANK_K = 20
NUM_ACTIONS = 20

# LinUCB agent constants.

AGENT_ALPHA = 10.0

# epsilon Greedy constants.

EPSILON = 0.05
LAYERS = (50, 50, 50)
LR = 0.005

# Dropout TS constants.
DROPOUT_RATE = 0.2

In [396]:
tf.compat.v1.enable_v2_behavior()

In [397]:
env = MovieLensPerArmPyEnvironment(
        rank_k=RANK_K,
        batch_size=BATCH_SIZE,
        num_actions=NUM_ACTIONS,
)
environment = tf_py_environment.TFPyEnvironment(env)

### Note we will be using the reward function with this utility function

```python
@gin.configurable
def compute_optimal_reward_with_movielens_environment(observation, environment):
  """Helper function for gin configurable Regret metric."""
  del observation
  return tf.py_function(environment.compute_optimal_reward, [], tf.float32)

@gin.configurable
def compute_optimal_action_with_movielens_environment(observation,
                                                      environment,
                                                      action_dtype=tf.int32):
  """Helper function for gin configurable SuboptimalArms metric."""
  del observation
  return tf.py_function(environment.compute_optimal_action, [], action_dtype)
```

In [398]:
optimal_reward_fn = functools.partial(
      environment_utilities.compute_optimal_reward_with_movielens_environment,
      environment=environment)

optimal_action_fn = functools.partial(
  environment_utilities.compute_optimal_action_with_movielens_environment,
  environment=environment)

### Below we will try different agents by selecting one of the enumerated types:

```python
flags.DEFINE_enum(
    'agent', 'LinUCB', ['LinUCB', 'LinTS', 'epsGreedy', 'DropoutTS'],
    'Which agent to use. Possible values: `LinUCB`, `LinTS`, `epsGreedy`,'
    ' `DropoutTS`.')
```

In [399]:
AGENT_TYPE = 'LinUCB'

In [400]:
if AGENT_TYPE == 'LinUCB':
    agent = lin_ucb_agent.LinearUCBAgent(
        time_step_spec=environment.time_step_spec(),
        action_spec=environment.action_spec(),
        tikhonov_weight=0.001,
        alpha=AGENT_ALPHA,
        dtype=tf.float32,
        accepts_per_arm_features=True)

elif AGENT_TYPE == 'LinTS':
    agent = lin_ts_agent.LinearThompsonSamplingAgent(
        time_step_spec=environment.time_step_spec(),
        action_spec=environment.action_spec(),
        dtype=tf.float32,
        accepts_per_arm_features=True)

elif AGENT_TYPE == 'epsGreedy':
    network = (
      global_and_arm_feature_network
      .create_feed_forward_dot_product_network(
          environment.time_step_spec().observation,
          global_layers=LAYERS,
          arm_layers=LAYERS))

    agent = eps_greedy_agent.NeuralEpsilonGreedyAgent(
        time_step_spec=environment.time_step_spec(),
        action_spec=environment.action_spec(),
        reward_network=network,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=LR),
        epsilon=EPSILON,
        emit_policy_info='predicted_rewards_mean',
        info_fields_to_inherit_from_greedy=['predicted_rewards_mean'])

elif AGENT_TYPE == 'DropoutTS':
    train_step_counter = tf.compat.v1.train.get_or_create_global_step()

    def dropout_fn():
        return tf.math.maximum(
          tf.math.reciprocal_no_nan(1.01 +
                                    tf.cast(train_step_counter, tf.float32)),
          0.0003)

    agent = dropout_ts_agent.DropoutThompsonSamplingAgent(
        time_step_spec=environment.time_step_spec(),
        action_spec=environment.action_spec(),
        dropout_rate=dropout_fn,
        network_layers=LAYERS,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=LR))

regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward_fn)
suboptimal_arms_metric = tf_bandit_metrics.SuboptimalArmsMetric(
  optimal_action_fn)

### Now train the MAB Agent

Create a local checkpoint folder if you already have not
!mkdir checkpoint

In [401]:
# !mkdir checkpoint

In [402]:
trainer.train(
      root_dir='checkpoint',
      agent=agent,
      environment=environment,
      training_loops=TRAINING_LOOPS,
      steps_per_loop=STEPS_PER_LOOP,
      additional_metrics=[regret_metric, suboptimal_arms_metric])

tf.Tensor(
[[ 991 1142  200 1605 1215  434  208 1051  233  305  942  941 1086 1586
   871 1041 1645  483 1506  180]
 [ 721  182 1240 1114  866  915  332  208 1434  683 1514  587 1259 1665
  1604  426  281 1506 1439  454]
 [ 518 1350 1420   96  618  134  629 1448  896  829  470  600 1465  836
  1559  185 1388 1214 1580 1213]
 [ 493  919  876  632  735  624  850  617  154 1107  492 1621  662   17
   353  717  309  149 1624    4]
 [ 424  230   17 1135  250 1608 1046 1256  305  529  800  475 1457   83
   809  982 1199 1140  147  759]
 [1016  768 1127 1009 1346 1458  541  333 1289 1430  616 1347 1569 1277
  1210  721 1371 1097 1002  562]
 [  91 1419  412  339  844  365  635  988 1189 1467 1293 1590  967  858
   175  483   61    1  244  188]
 [1282 1271  164 1192  844  737  793  463  964   45 1400  801  196  947
   897  152 1174 1156 1007  541]], shape=(8, 20), dtype=int32)


InvalidArgumentError: {{function_node __wrapped__Reshape_device_/job:localhost/replica:0/task:0/device:CPU:0}} Input to reshape is a tensor with 800 values, but the requested shape has 160 [Op:Reshape]