In [2]:
import numpy as np
import pandas as pd
import plotly.express as px
from pprint import pprint
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
import inspect
import abc
import numpy as np

import tensorflow as tf

from tf_agents.agents import tf_agent
from tf_agents.drivers import driver
from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.policies import tf_policy
from tf_agents.specs import array_spec
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory
from tf_agents.trajectories import policy_step
# Imports for example.
from tf_agents.bandits.agents import lin_ucb_agent
from tf_agents.bandits.environments import stationary_stochastic_py_environment as sspe
from tf_agents.bandits.metrics import tf_metrics
from tf_agents.drivers import dynamic_step_driver
from tf_agents.replay_buffers import tf_uniform_replay_buffer

nest = tf.nest
tf.compat.v1.enable_v2_behavior()
sess = tf.compat.v1.InteractiveSession()



## Environment

In [6]:
batch_size = 2 # @param
arm0_param = [-3, 0, 1, -2] # @param
arm1_param = [1, -2, 3, 0] # @param
arm2_param = [0, 0, 1, 1] # @param
def context_sampling_fn(batch_size):
  """Contexts from [-10, 10]^4."""
  def _context_sampling_fn():
    return np.random.randint(-10, 10, [batch_size, 4]).astype(np.float32)
  return _context_sampling_fn

class LinearNormalReward(object):
  """A class that acts as linear reward function when called."""
  def __init__(self, theta, sigma):
    self.theta = theta
    self.sigma = sigma
  def __call__(self, x):
    mu = np.dot(x, self.theta)
    return np.random.normal(mu, self.sigma)

arm0_reward_fn = LinearNormalReward(arm0_param, 1)
arm1_reward_fn = LinearNormalReward(arm1_param, 1)
arm2_reward_fn = LinearNormalReward(arm2_param, 1)

environment = tf_py_environment.TFPyEnvironment(
    sspe.StationaryStochasticPyEnvironment(
        context_sampling_fn(batch_size),
        [arm0_reward_fn, arm1_reward_fn, arm2_reward_fn],
        batch_size=batch_size))


observation_spec = tensor_spec.TensorSpec([4], tf.float32)
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = tensor_spec.BoundedTensorSpec(
    dtype=tf.int32, shape=(), minimum=0, maximum=2)

agent = lin_ucb_agent.LinearUCBAgent(time_step_spec=time_step_spec,
                                     action_spec=action_spec)

def compute_optimal_reward(observation):
  expected_reward_for_arms = [
      tf.linalg.matvec(observation, tf.cast(arm0_param, dtype=tf.float32)),
      tf.linalg.matvec(observation, tf.cast(arm1_param, dtype=tf.float32)),
      tf.linalg.matvec(observation, tf.cast(arm2_param, dtype=tf.float32))]
  optimal_action_reward = tf.reduce_max(expected_reward_for_arms, axis=0)
  return optimal_action_reward

regret_metric = tf_metrics.RegretMetric(compute_optimal_reward)

num_iterations = 90 # @param
steps_per_loop = 1 # @param

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.policy.trajectory_spec,
    batch_size=batch_size,
    max_length=steps_per_loop)

observers = [replay_buffer.add_batch, regret_metric]

driver = dynamic_step_driver.DynamicStepDriver(
    env=environment,
    policy=agent.collect_policy,
    num_steps=steps_per_loop * batch_size,
    observers=observers)

regret_values = []

for _ in range(num_iterations):
  driver.run()
  loss_info = agent.train(replay_buffer.gather_all())
  replay_buffer.clear()
  regret_values.append(regret_metric.result())



## Agent

In [7]:
from tf_agents.bandits.agents import bernoulli_thompson_sampling_agent as bern_ts_agent

agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
    time_step_spec=environment.time_step_spec(),
    action_spec=environment.action_spec(),
    dtype=tf.float64,
    batch_size=batch_size,
)

## Observers (Reward)

In [8]:
print(environment.pyenv._means)

[0.1, 0.2, 0.3, 0.45, 0.5]


In [9]:
from tf_agents.metrics import tf_metrics
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics

def optimal_reward_fn(unused_observation):
    return np.max(environment.pyenv._means)

def optimal_action_fn(unused_observation):
    return np.int32(np.argmax(environment.pyenv._means))


observers = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.AverageEpisodeLengthMetric(batch_size=environment.batch_size),
    tf_metrics.AverageReturnMetric(batch_size=environment.batch_size),
    tf_bandit_metrics.RegretMetric(optimal_reward_fn),
    tf_bandit_metrics.SuboptimalArmsMetric(optimal_action_fn)
  ]

## Driver

In [10]:
from tf_agents.drivers import dynamic_step_driver
steps_per_loop = 1
driver = dynamic_step_driver.DynamicStepDriver(
  env=environment,
  policy=agent.collect_policy,
  num_steps=steps_per_loop * environment.batch_size,
  observers=observers,
)
data_spec = agent.policy.trajectory_spec

In [11]:
pprint(driver.env.)

SyntaxError: invalid syntax (3215892952.py, line 1)

## Replay buffer

In [12]:
from tf_agents.bandits.replay_buffers import bandit_replay_buffer
replay_buffer = bandit_replay_buffer.BanditReplayBuffer(
      data_spec=data_spec,
      batch_size=batch_size,
      max_length=steps_per_loop
  )

## Training

In [13]:
from tf_agents.eval import metric_utils
from tf_agents.metrics import export_utils
from io import StringIO
import logging
log_stream = StringIO()    
logging.basicConfig(stream=log_stream, level=logging.NOTSET)

def _export_metrics_and_summaries(step, metrics):
    """Exports metrics and tf summaries."""
    metric_utils.log_metrics(metrics)
    export_utils.export_metrics(step=step, metrics=metrics)
    for metric in metrics:
      metric.tf_summaries(train_step=step)

In [14]:
starting_loop = 0
training_loops = 1000

In [15]:
driver.run()
dataset_it = iter(
    replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=100,
        single_deterministic_pass=True,
    )
)



In [24]:
driver.observers[1].tf_summaries(train_step=0)[0].numpy()

False

In [100]:
log_stream.

<_io.StringIO at 0x142ffa050>

In [None]:
def training_loop(train_step, metrics):
    """Returns a function that runs a single training loop and logs metrics."""
    driver.run()
    _export_metrics_and_summaries(
      step=train_step, metrics=metrics
    )
    batch_size = driver.env.batch_size
    dataset_it = iter(
        replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=100,
            single_deterministic_pass=True,
        )
    )
    experience, unused_buffer_info = dataset_it.get_next()
    set_expected_shape(experience, steps)
    loss_info = agent.train(experience)
    export_utils.export_metrics(
      step=train_step * async_steps_per_loop + batch_id,
      metrics=[],
      loss_info=loss_info,
    )
    
    replay_buffer.clear()