# Baseline ranking example

> See how the pieces of a TF-Agents ranking bandit (environment, agent, and trainer) fit together, using synthetic data.

## trainer.py

In [2]:
"""Generic TF-Agents training function for bandits."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

from typing import Callable, Dict, List, Optional, TypeVar

from absl import logging
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
from tf_agents.bandits.replay_buffers import bandit_replay_buffer
from tf_agents.drivers import dynamic_step_driver
from tf_agents.eval import metric_utils
# from tf_agents.google.metrics import export_utils
from tf_agents.metrics import export_utils
from tf_agents.metrics import tf_metrics
from tf_agents.policies import policy_saver

tf = tf.compat.v2

AGENT_CHECKPOINT_NAME = 'agent'
STEP_CHECKPOINT_NAME = 'step'
CHECKPOINT_FILE_PREFIX = 'ckpt'

# GPU
from numba import cuda 
import gc

# logging
import logging
logging.disable(logging.WARNING)

import warnings
warnings.filterwarnings('ignore')

# tf exceptions and vars
# Fail fast when TensorFlow is not a 2.x release. Only the first character of
# the version string is inspected.
if tf.__version__[0] != "2":
    raise Exception("The trainer only runs with TensorFlow version 2.")

# Generic type variable. NOTE(review): not referenced in the cells visible
# here — confirm it is still needed.
T = TypeVar("T")

In [3]:
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  1


In [5]:
# Reset the current CUDA device through numba, then run a garbage-collection
# pass (the cell output is the return value of `gc.collect()`).
# NOTE(review): presumably meant to release GPU memory left over from an
# earlier run; resetting the device while TensorFlow holds a context on it
# may be unsafe — confirm.
device = cuda.get_current_device()
device.reset()
gc.collect()

14

In [6]:
def _get_replay_buffer(
    data_spec, batch_size, steps_per_loop, async_steps_per_loop
):
    """Create the default `BanditReplayBuffer` used by `train`.

    Args:
      data_spec: spec of the items that will be added to the buffer.
      batch_size: batch size forwarded to the buffer.
      steps_per_loop: number of driver steps per driver run.
      async_steps_per_loop: number of driver runs per training loop.

    Returns:
      A `BanditReplayBuffer` whose `max_length` is
      `steps_per_loop * async_steps_per_loop`.
    """
    capacity = steps_per_loop * async_steps_per_loop
    buffer_kwargs = dict(
        data_spec=data_spec,
        batch_size=batch_size,
        max_length=capacity,
    )
    return bandit_replay_buffer.BanditReplayBuffer(**buffer_kwargs)


def set_expected_shape(experience, num_steps):
    """Set dimension 1 of every tensor in `experience` to `num_steps`.

    Args:
      experience: a (possibly nested) structure of tensors; each tensor's
        static shape is updated in place via `set_shape`.
      num_steps: the value written into dimension 1 of each static shape.

    Raises:
      ValueError: if any tensor has fewer than two dimensions.
    """

    def _pin_second_dim(tensor):
        dims = tensor.shape.as_list()
        if len(dims) >= 2:
            dims[1] = num_steps
            tensor.set_shape(dims)
            return
        raise ValueError(
            'input_tensor is expected to be of rank-2, but found otherwise: '
            f'input_tensor={tensor}, tensor_shape={dims}'
        )

    tf.nest.map_structure(_pin_second_dim, experience)


def _get_training_loop(
    driver, replay_buffer, agent, steps, async_steps_per_loop
):
    """Builds the function that runs one collect-then-train loop.

    Args:
      driver: an instance of `Driver`.
      replay_buffer: an instance of `ReplayBuffer`.
      agent: an instance of `TFAgent`.
      steps: an integer indicating how many driver steps should be executed
        and presented to the trainer during each training loop.
      async_steps_per_loop: an integer. In each training loop, the driver runs
        this many times, and then the agent gets trained over this many
        batches sampled from the replay buffer.

    Returns:
      A plain Python callable `training_loop(train_step, metrics)`. It is not
      wrapped in `tf.function` here, and it calls `.numpy()` on the loss, so
      it has to run eagerly.
    """

    def _export_metrics_and_summaries(step, metrics):
        """Logs `metrics`, exports them and writes their tf summaries."""
        metric_utils.log_metrics(metrics)
        export_utils.export_metrics(step=step, metrics=metrics)
        for metric in metrics:
            metric.tf_summaries(train_step=step)

    def training_loop(train_step, metrics):
        """Runs a single training loop and logs metrics.

        Args:
          train_step: index of the current training loop; used to derive the
            step under which metrics and losses are exported.
          metrics: metric objects to log after every driver run.
        """
        # Collection phase: fill the replay buffer.
        for batch_id in range(async_steps_per_loop):
            driver.run()
            _export_metrics_and_summaries(
                step=train_step * async_steps_per_loop + batch_id,
                metrics=metrics,
            )

        # Training phase: one deterministic pass over what was just collected.
        batch_size = driver.env.batch_size
        dataset_it = iter(
            replay_buffer.as_dataset(
                sample_batch_size=batch_size,
                num_steps=steps,
                single_deterministic_pass=True,
            )
        )
        for batch_id in range(async_steps_per_loop):
            experience, _ = dataset_it.get_next()
            set_expected_shape(experience, steps)
            loss_info = agent.train(experience)
            export_utils.export_metrics(
                step=train_step * async_steps_per_loop + batch_id,
                metrics=[],
                loss_info=loss_info,
            )
            if train_step % 100 == 0:
                # Convert to a Python float before rounding: rounding the
                # float32 value directly printed e.g. 0.6299999952316284
                # instead of 0.63.
                loss = float(loss_info.loss.numpy())
                print(
                    f'step = {train_step}: train loss = {round(loss, 2)}'
                )

        # The buffer only ever holds the data of the current loop.
        replay_buffer.clear()

    return training_loop

In [7]:
def restore_and_get_checkpoint_manager(root_dir, agent, metrics, step_metric):
    """Restore the latest checkpoint in `root_dir`, if any, and return its manager.

    The checkpoint tracks every metric (keyed by its `name`), the agent and
    the step metric.

    Args:
      root_dir: directory the `CheckpointManager` reads from and writes to.
      agent: the agent to track.
      metrics: iterable of metric objects to track.
      step_metric: the step-counting metric to track.

    Returns:
      The `tf.train.CheckpointManager` (keeps at most 5 checkpoints).
    """
    tracked = dict((m.name, m) for m in metrics)
    tracked.update(
        {AGENT_CHECKPOINT_NAME: agent, STEP_CHECKPOINT_NAME: step_metric}
    )
    ckpt = tf.train.Checkpoint(**tracked)
    manager = tf.train.CheckpointManager(
        checkpoint=ckpt, directory=root_dir, max_to_keep=5
    )

    latest_path = manager.latest_checkpoint
    if latest_path is None:
        logging.info(
            'Did not find a pre-existing checkpoint. Starting from scratch.'
        )
        return manager

    logging.info('Restoring checkpoint from %s.', latest_path)
    ckpt.restore(latest_path)
    logging.info('Successfully restored to step %s.', step_metric.result())
    return manager

In [8]:
def train(
    root_dir,
    agent,
    environment,
    training_loops,
    steps_per_loop,
    async_steps_per_loop=None,
    additional_metrics=(),
    get_replay_buffer_fn=None,
    get_training_loop_fn=None,
    training_data_spec_transformation_fn=None,
    save_policy=True,
    resume_training_loops=False,
):
    """Perform `training_loops` iterations of training, checkpointing each one.

    Args:
      root_dir: path to the directory where checkpoints, summaries and saved
        policies will be written.
      agent: an instance of `TFAgent`.
      environment: an instance of `TFEnvironment`.
      training_loops: an integer indicating how many training loops should be
        run.
      steps_per_loop: an integer indicating how many driver steps should be
        executed in a single driver run.
      async_steps_per_loop: an optional integer for simulating offline or
        asynchronous training: In each training loop iteration, the driver
        runs this many times, each executing `steps_per_loop` driver steps,
        and then the agent gets trained over this many batches sampled from
        the replay buffer. When unset or set to 1, the function performs
        synchronous training, where the agent gets trained on a single batch
        immediately after the driver runs.
      additional_metrics: Tuple of metric objects to log, in addition to the
        default metrics: `NumberOfEpisodes` and `AverageEpisodeLengthMetric`
        always; `AverageReturnMultiMetric` when the reward spec is a dict or
        is not a scalar; `AverageReturnMetric` when the reward spec is not a
        dict.
      get_replay_buffer_fn: An optional function that creates a replay buffer
        by taking a data_spec, batch size, the number of driver steps per
        loop, and the number of asynchronous training steps per loop. The
        returned replay buffer is passed to `get_training_loop_fn` below. If
        `None`, the `_get_replay_buffer` function defined in this module is
        used.
      get_training_loop_fn: An optional function that constructs the training
        loop function executing a single train step. This function takes a
        driver, a replay buffer, an agent, the number of driver steps per
        loop, and the number of asynchronous training steps per loop. If
        `None`, the `_get_training_loop` function defined in this module is
        used.
      training_data_spec_transformation_fn: Optional function that transforms
        the data items before they get to the replay buffer. It is also
        applied to the policy's trajectory spec to obtain the buffer's data
        spec.
      save_policy: (bool) whether to save the policy or not. When enabled, the
        policy is saved on every loop whose index is a multiple of 100.
      resume_training_loops: A boolean flag indicating whether
        `training_loops` should be enforced relatively to the initial (True)
        or the last (False) checkpoint. When True, the loop index starts at
        the number of loops already covered by the restored step metric.

    Raises:
      ValueError: if `resume_training_loops` is set and the restored step
        count is not a multiple of
        `steps_per_loop * batch_size * async_steps_per_loop`.
    """

    # TODO(b/127641485): create evaluation loop with configurable metrics.
    if training_data_spec_transformation_fn is None:
        data_spec = agent.policy.trajectory_spec
    else:
        data_spec = training_data_spec_transformation_fn(
            agent.policy.trajectory_spec
        )
    if async_steps_per_loop is None:
        async_steps_per_loop = 1
    if get_replay_buffer_fn is None:
        get_replay_buffer_fn = _get_replay_buffer
    replay_buffer = get_replay_buffer_fn(
        data_spec, environment.batch_size, steps_per_loop, async_steps_per_loop
    )

    # `step_metric` records the number of individual rounds of bandit
    # interaction; that is, (number of trajectories) * batch_size.
    step_metric = tf_metrics.EnvironmentSteps()
    metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.AverageEpisodeLengthMetric(batch_size=environment.batch_size),
    ] + list(additional_metrics)

    # If the reward is anything other than a single scalar, add the
    # multi-metric average reward.
    reward_spec = environment.reward_spec()
    reward_is_dict = isinstance(reward_spec, dict)
    if reward_is_dict or reward_spec.shape != tf.TensorShape(()):
        metrics += [
            tf_metrics.AverageReturnMultiMetric(
                reward_spec=reward_spec,
                batch_size=environment.batch_size,
            )
        ]
    if not reward_is_dict:
        metrics += [
            tf_metrics.AverageReturnMetric(batch_size=environment.batch_size)
        ]

    if training_data_spec_transformation_fn is not None:
        # Kept as a lambda so that the value returned by `add_batch` is still
        # handed back to the driver.
        add_batch_fn = lambda data: replay_buffer.add_batch(  # pylint: disable=g-long-lambda
            training_data_spec_transformation_fn(data)
        )
    else:
        add_batch_fn = replay_buffer.add_batch

    observers = [add_batch_fn, step_metric] + metrics

    driver = dynamic_step_driver.DynamicStepDriver(
        env=environment,
        policy=agent.collect_policy,
        num_steps=steps_per_loop * environment.batch_size,
        observers=observers,
    )

    if get_training_loop_fn is None:
        get_training_loop_fn = _get_training_loop
    training_loop = get_training_loop_fn(
        driver, replay_buffer, agent, steps_per_loop, async_steps_per_loop
    )
    checkpoint_manager = restore_and_get_checkpoint_manager(
        root_dir, agent, metrics, step_metric
    )
    train_step_counter = tf.compat.v1.train.get_or_create_global_step()
    if save_policy:
        saver = policy_saver.PolicySaver(
            agent.policy, train_step=train_step_counter
        )

    summary_writer = tf.summary.create_file_writer(root_dir)
    summary_writer.set_as_default()

    if resume_training_loops:
        train_step_count_per_loop = (
            steps_per_loop * environment.batch_size * async_steps_per_loop
        )
        last_checkpointed_step = step_metric.result().numpy()
        if last_checkpointed_step % train_step_count_per_loop != 0:
            raise ValueError(
                'Last checkpointed step is expected to be a multiple of '
                'steps_per_loop * batch_size * async_steps_per_loop, but found '
                f'otherwise: last checkpointed step: {last_checkpointed_step}, '
                f'steps_per_loop: {steps_per_loop}, batch_size: '
                f'{environment.batch_size}, async_steps_per_loop: '
                f'{async_steps_per_loop}'
            )
        starting_loop = last_checkpointed_step // train_step_count_per_loop
    else:
        starting_loop = 0

    for i in range(starting_loop, training_loops):
        training_loop(train_step=i, metrics=metrics)
        checkpoint_manager.save()
        # Logical `and`, matching the `if save_policy:` guard that created
        # `saver`; a bitwise `&` would mishandle truthy non-bool values.
        if save_policy and i % 100 == 0:
            saver.save(os.path.join(root_dir, 'policy_%d' % step_metric.result()))

## train_eval_ranking

In [9]:
"""End-to-end test for ranking."""

import os
from absl import app
from absl import flags
import numpy as np
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
from tf_agents.bandits.agents import ranking_agent
from tf_agents.bandits.agents.examples.v2 import trainer
from tf_agents.bandits.environments import ranking_environment
from tf_agents.bandits.networks import global_and_arm_feature_network
from tf_agents.environments import tf_py_environment
from tf_agents.specs import bandit_spec_utils
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import trajectory

In [10]:


### CHANGE ME

# Directory passed to `train()` as `root_dir`; checkpoints and TensorBoard
# summaries are written here.
root_dir           = "gs://rec-bandits-v1-hybrid-vertex-bucket/tmp-example-rootdir-v3"


# Ranking policy: "cosine_distance" | "no_penalty" | "descending_scores".
policy_type        = "cosine_distance"
ff_feedback_model  = "cascading"                  # "score_vector" | "cascading"
# Click model of the base environment: "ghost_actions" | "distance_based".
click_model        = "ghost_actions"
# Forwarded to the base environment as `distance_threshold`.
distance_threshold = 5.0
# Environment flavour: "base" | "exp_pos_bias".
env_type           = "base"
# NOTE(review): `bias_type` is not forwarded anywhere in the visible cells;
# the agent is built with `positional_bias_type=None`.
bias_type          = ""
# Forwarded to the agent as `positional_bias_severity`.
bias_severity      = 1.0
# Forwarded to the agent as `positional_bias_positive_only`.
bias_positive_only = False

In [14]:
# Environment and driver parameters.

BATCH_SIZE = 64   # `batch_size` of the ranking environment
NUM_ITEMS = 20    # items offered per observation (`num_items`)
NUM_SLOTS = 3     # items recommended per action (`num_slots`)
GLOBAL_DIM = 32   # length of the sampled global feature vector
ITEM_DIM = 64     # length of each sampled item feature vector

TRAINING_LOOPS = 300  # `training_loops` passed to `train()`
STEPS_PER_LOOP = 2    # `steps_per_loop` passed to `train()`

LR = 0.05  # learning rate of the Adam optimizer given to the agent

In [15]:
min_dim = min(GLOBAL_DIM, ITEM_DIM)
min_dim

32

In [16]:
def _global_sampling_fn():
    """Sample a float32 global feature vector of length `GLOBAL_DIM`.

    Entries come from `np.random.randint(-1, 1)`, whose upper bound is
    exclusive, so every entry is either -1 or 0.
    """
    raw = np.random.randint(low=-1, high=1, size=[GLOBAL_DIM])
    return raw.astype(np.float32)

def _item_sampling_fn():
    """Sample a unit-norm float32 item feature vector of length `ITEM_DIM`.

    Entries are drawn as integers in [-2, 2] and the vector is then divided
    by its Euclidean norm.

    Returns:
      A float32 array of shape `(ITEM_DIM,)`.
    """
    while True:
        unnormalized = np.random.randint(-2, 3, [ITEM_DIM]).astype(np.float32)
        # An all-zero draw has norm 0 and would normalize to NaNs, so draw
        # again. An empty vector (ITEM_DIM == 0) is returned as is.
        if unnormalized.size == 0 or unnormalized.any():
            return unnormalized / np.linalg.norm(unnormalized)

def _relevance_fn(global_obs, item_obs):
    """Return the sigmoid of the dot product of the two feature vectors.

    Only the first `min(GLOBAL_DIM, ITEM_DIM)` entries of each vector take
    part in the dot product; the excess entries are ignored.
    """
    shared = min(GLOBAL_DIM, ITEM_DIM)
    global_part = global_obs[:shared]
    item_part = item_obs[:shared]
    logit = np.dot(global_part, item_part).astype(np.float32)
    return 1 / (1 + np.exp(-logit))

In [17]:
global_obs_test = _global_sampling_fn()

# GLOBAL_DIM = global_obs_test[0]
print(f"global_obs_test.shape: {global_obs_test.shape}")

global_obs_test

global_obs_test.shape: (32,)


array([-1., -1.,  0.,  0., -1., -1.,  0.,  0., -1., -1.,  0.,  0., -1.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -1., -1.,  0.,
       -1.,  0.,  0., -1.,  0., -1.], dtype=float32)

In [18]:
item_obs_test = _item_sampling_fn()

# PER_ARM_DIM = item_obs_test.shape
print(f"item_obs_test shape: {item_obs_test.shape}")

item_obs_test

item_obs_test shape: (64,)


array([ 0.18898225, -0.09449112, -0.09449112,  0.        ,  0.09449112,
        0.        ,  0.09449112,  0.09449112, -0.09449112, -0.18898225,
       -0.18898225,  0.09449112,  0.        , -0.09449112,  0.        ,
        0.        ,  0.18898225,  0.18898225,  0.18898225,  0.18898225,
        0.18898225, -0.18898225,  0.18898225, -0.09449112,  0.09449112,
       -0.09449112,  0.09449112,  0.        , -0.18898225,  0.        ,
        0.09449112,  0.09449112, -0.18898225,  0.        ,  0.18898225,
       -0.09449112,  0.18898225,  0.        ,  0.        , -0.18898225,
        0.        ,  0.09449112,  0.09449112,  0.        ,  0.        ,
        0.09449112, -0.09449112, -0.09449112, -0.09449112, -0.18898225,
       -0.09449112, -0.09449112,  0.18898225, -0.18898225,  0.09449112,
        0.09449112, -0.18898225,  0.18898225,  0.        , -0.18898225,
        0.        ,  0.09449112,  0.09449112,  0.09449112], dtype=float32)

In [19]:
rel_test = _relevance_fn(global_obs_test, item_obs_test)
rel_test

0.4763947549717827

In [20]:
click_model

'ghost_actions'

In [21]:
ranking_agent.FeedbackModel.SCORE_VECTOR

<FeedbackModel.SCORE_VECTOR: 2>

In [22]:
# Build the Python ranking environment (`tf_env`) and pick the agent's
# `feedback_model` according to `env_type`.
if env_type == 'exp_pos_bias':
    positional_biases = list(1.0 / np.arange(1, NUM_SLOTS + 1) ** 1.3)

    # Bound to `tf_env` because the cells below wrap `tf_env`.
    tf_env = ranking_environment.ExplicitPositionalBiasRankingEnvironment(
        _global_sampling_fn,
        _item_sampling_fn,
        _relevance_fn,
        NUM_ITEMS,
        positional_biases,
        batch_size=BATCH_SIZE,
    )

    feedback_model = ranking_agent.FeedbackModel.SCORE_VECTOR

elif env_type == 'base':
    # Inner product with the excess dimensions ignored.
    scores_weight_matrix = np.eye(ITEM_DIM, GLOBAL_DIM, dtype=np.float32)

    if ff_feedback_model == 'cascading':
        feedback_model = ranking_agent.FeedbackModel.CASCADING
    elif ff_feedback_model == 'score_vector':
        feedback_model = ranking_agent.FeedbackModel.SCORE_VECTOR
    else:
        # Report the requested string, not a previously resolved enum.
        raise NotImplementedError(
            'Feedback model {} not implemented'.format(ff_feedback_model)
        )

    # `click_model` is rebound from its string name to the enum member. A
    # value that is already a member is accepted unchanged so that the cell
    # can be re-run.
    _CLICK_MODELS = {
        'ghost_actions': ranking_environment.ClickModel.GHOST_ACTIONS,
        'distance_based': ranking_environment.ClickModel.DISTANCE_BASED,
    }
    if click_model in _CLICK_MODELS:
        click_model = _CLICK_MODELS[click_model]
    elif click_model not in _CLICK_MODELS.values():
        raise NotImplementedError(
            'Click model {} not implemented'.format(click_model)
        )

    tf_env = ranking_environment.RankingPyEnvironment(
        _global_sampling_fn,
        _item_sampling_fn,
        num_items=NUM_ITEMS,
        num_slots=NUM_SLOTS,
        scores_weight_matrix=scores_weight_matrix,
        # TODO(b/247995883): Merge the two feedback model enums from the agent
        # and the environment.
        feedback_model=feedback_model.value,
        click_model=click_model,
        distance_threshold=distance_threshold,
        batch_size=BATCH_SIZE,
    )

else:
    # Fail here instead of with a NameError on `tf_env` further down.
    raise NotImplementedError(
        'Environment type {} not implemented'.format(env_type)
    )

In [23]:
tf_env.name

'ranking_environment'

In [24]:
print(f"scores_weight_matrix.shape: {scores_weight_matrix.shape}")
scores_weight_matrix

scores_weight_matrix.shape: (64, 32)


array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [25]:
environment = tf_py_environment.TFPyEnvironment(tf_env)
environment

<tf_agents.environments.tf_py_environment.TFPyEnvironment at 0x7fa73d3d0730>

In [26]:
obs_spec = environment.observation_spec()
obs_spec

{'global': TensorSpec(shape=(32,), dtype=tf.float32, name=None),
 'per_arm': TensorSpec(shape=(20, 64), dtype=tf.float32, name=None)}

In [27]:
environment.action_spec()

BoundedTensorSpec(shape=(3,), dtype=tf.int32, name='action', minimum=array(0, dtype=int32), maximum=array(19, dtype=int32))

In [28]:
environment.reward_spec()

{'chosen_index': BoundedTensorSpec(shape=(), dtype=tf.int32, name='chosen_index', minimum=array(0, dtype=int32), maximum=array(3, dtype=int32)),
 'chosen_value': TensorSpec(shape=(), dtype=tf.float32, name='chosen_value')}

In [29]:
environment.time_step_spec()

TimeStep(
{'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'observation': {'global': TensorSpec(shape=(32,), dtype=tf.float32, name=None),
                 'per_arm': TensorSpec(shape=(20, 64), dtype=tf.float32, name=None)},
 'reward': {'chosen_index': BoundedTensorSpec(shape=(), dtype=tf.int32, name='chosen_index', minimum=array(0, dtype=int32), maximum=array(3, dtype=int32)),
            'chosen_value': TensorSpec(shape=(), dtype=tf.float32, name='chosen_value')},
 'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type')})

In [30]:
# GLOBAL_LAYERS   = [GLOBAL_DIM, int(GLOBAL_DIM / 2)]
# ARM_LAYERS      = [PER_ARM_DIM, int(PER_ARM_DIM / 2), int(PER_ARM_DIM / 4)]
# COMMON_LAYERS   = [16, 8]


# Scoring network handed to the ranking agent as `scoring_network`, built
# from the environment's observation spec.
# NOTE(review): the layer tuples are presumably hidden-layer widths of the
# global, per-arm and common towers — confirm against the TF-Agents docs.
network = (
  global_and_arm_feature_network.create_feed_forward_common_tower_network(
      observation_spec = obs_spec, 
      global_layers = (20, 10), 
      arm_layers = (20, 10),
      common_layers = (10, 5),
  )
)
network

<tf_agents.bandits.networks.global_and_arm_feature_network.GlobalAndArmCommonTowerNetwork at 0x7fa7e6225c30>

In [31]:
policy_type

'cosine_distance'

In [32]:
# Resolve `policy_type` from its string name to the `RankingPolicyType`
# member. A value that is already a member is accepted unchanged, so
# re-running this cell does not raise.
_POLICY_TYPES = {
    'cosine_distance': ranking_agent.RankingPolicyType.COSINE_DISTANCE,
    'no_penalty': ranking_agent.RankingPolicyType.NO_PENALTY,
    'descending_scores': ranking_agent.RankingPolicyType.DESCENDING_SCORES,
}
if policy_type in _POLICY_TYPES:
    policy_type = _POLICY_TYPES[policy_type]
elif policy_type not in _POLICY_TYPES.values():
    raise NotImplementedError(
        'Policy type {} is not implemented'.format(policy_type)
    )

policy_type

<RankingPolicyType.COSINE_DISTANCE: 1>

In [33]:
bias_type

''

In [34]:
# Positional bias is switched off: the agent receives `None` here rather
# than the `bias_type` setting (use `bias_type or None` to forward it).
positional_bias_type = None
# Ranking agent built on the environment's specs, the scoring network and an
# Adam optimizer with learning rate `LR`.
agent = ranking_agent.RankingAgent(
    time_step_spec=environment.time_step_spec(),
    action_spec=environment.action_spec(),
    scoring_network=network,
    optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=LR),
    policy_type=policy_type,
    feedback_model=feedback_model,
    positional_bias_type=positional_bias_type,
    positional_bias_severity=bias_severity,
    positional_bias_positive_only=bias_positive_only,
    summarize_grads_and_vars=True,
)
agent.name

'ranking_agent'

In [35]:
agent.action_spec

BoundedTensorSpec(shape=(3,), dtype=tf.int32, name='action', minimum=array(0, dtype=int32), maximum=array(19, dtype=int32))

In [36]:
agent.time_step_spec

_TupleWrapper(TimeStep(
{'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'observation': DictWrapper({'global': TensorSpec(shape=(32,), dtype=tf.float32, name=None), 'per_arm': TensorSpec(shape=(20, 64), dtype=tf.float32, name=None)}),
 'reward': DictWrapper({'chosen_index': BoundedTensorSpec(shape=(), dtype=tf.int32, name='chosen_index', minimum=array(0, dtype=int32), maximum=array(3, dtype=int32)), 'chosen_value': TensorSpec(shape=(), dtype=tf.float32, name='chosen_value')}),
 'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type')}))

In [37]:
agent.training_data_spec

_TupleWrapper(Trajectory(
{'action': (),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'observation': DictWrapper({'global': TensorSpec(shape=(32,), dtype=tf.float32, name=None), 'per_arm': TensorSpec(shape=(3, 64), dtype=tf.float32, name=None)}),
 'policy_info': (),
 'reward': DictWrapper({'chosen_index': BoundedTensorSpec(shape=(), dtype=tf.int32, name='chosen_index', minimum=array(0, dtype=int32), maximum=array(3, dtype=int32)), 'chosen_value': TensorSpec(shape=(), dtype=tf.float32, name='chosen_value')}),
 'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type')}))

In [38]:
def order_items_from_action_fn(orig_trajectory):
    """Replace the per-arm observation with the selected items, in slot order.

    Used so that, at training time, the item observation holds only the
    features of the items the policy selected, ordered as they were
    recommended. Features of unselected items are dropped, and `action` and
    `policy_info` are emptied.

    The function works on specs as well as on data: when the per-arm entry
    is a `TensorSpec`, a spec of shape `[NUM_SLOTS, feature_dim]` is produced
    instead of gathering.

    Args:
      orig_trajectory: The trajectory (or trajectory spec) as output by the
        policy.

    Returns:
      The modified trajectory that contains slotted item features.
    """
    arm_key = bandit_spec_utils.PER_ARM_FEATURE_KEY
    global_key = bandit_spec_utils.GLOBAL_FEATURE_KEY
    observation = orig_trajectory.observation
    item_obs = observation[arm_key]

    if isinstance(item_obs, tensor_spec.TensorSpec):
        # Spec path: keep the dtype and feature width, shrink to NUM_SLOTS rows.
        slotted = tensor_spec.TensorSpec(
            dtype=item_obs.dtype, shape=[NUM_SLOTS, item_obs.shape[-1]]
        )
    else:
        # Data path: pick the chosen items per batch entry, in action order.
        slotted = tf.gather(item_obs, orig_trajectory.action, batch_dims=1)

    new_observation = {
        global_key: observation[global_key],
        arm_key: slotted,
    }
    return trajectory.Trajectory(
        step_type=orig_trajectory.step_type,
        observation=new_observation,
        action=(),
        policy_info=(),
        next_step_type=orig_trajectory.next_step_type,
        reward=orig_trajectory.reward,
        discount=orig_trajectory.discount,
    )

In [39]:
agent.policy.trajectory_spec

_TupleWrapper(Trajectory(
{'action': BoundedTensorSpec(shape=(3,), dtype=tf.int32, name=None, minimum=array(0, dtype=int32), maximum=array(19, dtype=int32)),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'observation': DictWrapper({'global': TensorSpec(shape=(32,), dtype=tf.float32, name=None), 'per_arm': TensorSpec(shape=(20, 64), dtype=tf.float32, name=None)}),
 'policy_info': PolicyInfo(log_probability=(), predicted_rewards_mean=TensorSpec(shape=(3,), dtype=tf.float32, name=None), multiobjective_scalarized_predicted_rewards_mean=(), predicted_rewards_optimistic=(), predicted_rewards_sampled=(), bandit_policy_type=()),
 'reward': DictWrapper({'chosen_index': BoundedTensorSpec(shape=(), dtype=tf.int32, name='chosen_index', minimum=array(0, dtype=int32), maximum=array(3, dtype=int32)), 'chosen_value': TensorSpec(s

In [40]:
data_spec = order_items_from_action_fn(
    agent.policy.trajectory_spec
)
data_spec

Trajectory(
{'action': (),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'observation': {'global': TensorSpec(shape=(32,), dtype=tf.float32, name=None),
                 'per_arm': TensorSpec(shape=(3, 64), dtype=tf.float32, name=None)},
 'policy_info': (),
 'reward': DictWrapper({'chosen_index': BoundedTensorSpec(shape=(), dtype=tf.int32, name='chosen_index', minimum=array(0, dtype=int32), maximum=array(3, dtype=int32)), 'chosen_value': TensorSpec(shape=(), dtype=tf.float32, name='chosen_value')}),
 'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type')})

## Trainer

In [41]:
import time

print(f"BATCH_SIZE     : {BATCH_SIZE}")
print(f"NUM_ITEMS      : {NUM_ITEMS}")
print(f"NUM_SLOTS      : {NUM_SLOTS}")
print(f"GLOBAL_DIM     : {GLOBAL_DIM}")
print(f"ITEM_DIM       : {ITEM_DIM}")
print(f"TRAINING_LOOPS : {TRAINING_LOOPS}")
print(f"STEPS_PER_LOOP : {STEPS_PER_LOOP}")
print(f"LR             : {LR}")

BATCH_SIZE     : 64
NUM_ITEMS      : 20
NUM_SLOTS      : 3
GLOBAL_DIM     : 32
ITEM_DIM       : 64
TRAINING_LOOPS : 300
STEPS_PER_LOOP : 2
LR             : 0.05


In [42]:
# Run the full training job and report how long it took.
start_time = time.time()

train(
    root_dir=root_dir,
    agent=agent,
    environment=environment,
    training_loops=TRAINING_LOOPS,
    steps_per_loop=STEPS_PER_LOOP,
    # Keeps only the selected items, in slot order, before data reaches the
    # replay buffer.
    training_data_spec_transformation_fn=order_items_from_action_fn,
    # No policy is exported; checkpoints are still written every loop.
    save_policy=False,
)

# Wall-clock duration, truncated to whole minutes.
runtime_mins = int((time.time() - start_time) / 60)
print(f"train runtime_mins: {runtime_mins}")

step = 0: train loss = 0.6299999952316284
step = 100: train loss = 0.03999999910593033
step = 200: train loss = 0.05000000074505806
train runtime_mins: 23


In [43]:
# history

In [44]:
# Reuse `root_dir` so the listing and TensorBoard always point at the
# directory `train()` wrote to, even after `root_dir` is edited.
LOG_DIR = root_dir

! gsutil ls $LOG_DIR

gs://rec-bandits-v1-hybrid-vertex-bucket/tmp-example-rootdir-v3/
gs://rec-bandits-v1-hybrid-vertex-bucket/tmp-example-rootdir-v3/checkpoint
gs://rec-bandits-v1-hybrid-vertex-bucket/tmp-example-rootdir-v3/ckpt-296.data-00000-of-00001
gs://rec-bandits-v1-hybrid-vertex-bucket/tmp-example-rootdir-v3/ckpt-296.index
gs://rec-bandits-v1-hybrid-vertex-bucket/tmp-example-rootdir-v3/ckpt-297.data-00000-of-00001
gs://rec-bandits-v1-hybrid-vertex-bucket/tmp-example-rootdir-v3/ckpt-297.index
gs://rec-bandits-v1-hybrid-vertex-bucket/tmp-example-rootdir-v3/ckpt-298.data-00000-of-00001
gs://rec-bandits-v1-hybrid-vertex-bucket/tmp-example-rootdir-v3/ckpt-298.index
gs://rec-bandits-v1-hybrid-vertex-bucket/tmp-example-rootdir-v3/ckpt-299.data-00000-of-00001
gs://rec-bandits-v1-hybrid-vertex-bucket/tmp-example-rootdir-v3/ckpt-299.index
gs://rec-bandits-v1-hybrid-vertex-bucket/tmp-example-rootdir-v3/ckpt-300.data-00000-of-00001
gs://rec-bandits-v1-hybrid-vertex-bucket/tmp-example-rootdir-v3/ckpt-300.index


In [45]:
# %load_ext tensorboard
%reload_ext tensorboard

In [46]:
%tensorboard --logdir=$LOG_DIR 

# Finished