<a href="https://colab.research.google.com/github/self-improving-efms/self-improving-efms.github.io/blob/gpt5-version/pointmass_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
import os
import io
import pickle
from typing import Optional, Any, NamedTuple, Callable, Mapping, Sequence, Dict, Tuple
import dataclasses
import enum
import time
from collections import defaultdict
from pprint import pprint

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import matplotlib.patches as patches
from PIL import Image
import mediapy as media
from IPython.display import clear_output

import jax
import jax.numpy as jnp
import optax
import haiku as hk

import tensorflow as tf
import tensorflow.data as tf_data
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions

import dm_env
from dm_env import specs

from moviepy import ImageSequenceClip, display_in_notebook

# Create the pointmass environment

In [None]:
# [min, max] extent of the arena along each axis. Used for sampling start and
# goal positions and for the plot limits in Point2D.render.
BOUNDS_X = np.array([-1., 1.], dtype=np.float32)
BOUNDS_Y = np.array([-1., 1.], dtype=np.float32)

# Figure resolution and side length (in inches) used when rendering frames.
DPI = 200
RENDER_HEIGHT_INCHES = 5

class Point2D(dm_env.Environment):
  """2D point-mass environment with a goal position.

  The action is added to the velocity and the velocity is added to the
  position, repeated `_physics_substeps` times per call to `step`. The
  observation is a dict with float32 `cur_pos`, `cur_vel` and `goal_pos`
  entries of shape (2,). An episode is successful once the point is within
  `_success_radius` of the goal.
  """

  def __init__(self):
    self._cur_pos = np.zeros(2, dtype=np.float32)
    self._goal_pos = np.zeros(2, dtype=np.float32)
    self._cur_vel = np.zeros(2, dtype=np.float32)
    # Positions visited in the current episode; used by `render`.
    self._cur_episode_traj = []
    self._physics_substeps = 10
    self._success_radius = 0.15

  def sample_goal(self):
    """Samples a goal uniformly inside the arena, 5% away from each edge."""
    border_x = (BOUNDS_X[1] - BOUNDS_X[0]) * 0.05
    border_y = (BOUNDS_Y[1] - BOUNDS_Y[0]) * 0.05
    goal_x = np.random.uniform(
        BOUNDS_X[0] + border_x, BOUNDS_X[1] - border_x)
    goal_y = np.random.uniform(
        BOUNDS_Y[0] + border_y, BOUNDS_Y[1] - border_y)
    return np.array([goal_x, goal_y], dtype=np.float32)

  def set_goal(self, goal_pos):
    """Replaces the current goal without resetting the point's state."""
    # Store a private float32 copy so that later mutation of the caller's
    # array cannot silently move the goal.
    self._goal_pos = np.array(goal_pos, dtype=np.float32)

  def reset(self):
    """Samples a new goal and start position and zeroes the velocity.

    Returns:
      A `dm_env.TimeStep` with step type FIRST and no reward/discount.
    """
    self._goal_pos = self.sample_goal()
    cur_x = np.random.uniform(BOUNDS_X[0], BOUNDS_X[1])
    cur_y = np.random.uniform(BOUNDS_Y[0], BOUNDS_Y[1])
    self._cur_pos = np.array([cur_x, cur_y], dtype=np.float32)
    cur_pos_copy = self._cur_pos.copy()

    self._cur_vel = np.zeros(2, dtype=np.float32)
    cur_vel_copy = self._cur_vel.copy()

    obs = {
        'cur_pos': cur_pos_copy,
        'cur_vel': cur_vel_copy,
        'goal_pos': self._goal_pos.copy()}
    ts = dm_env.TimeStep(
        step_type=dm_env.StepType.FIRST,
        reward=None,
        discount=None,
        observation=obs,)

    self._cur_episode_traj = [cur_pos_copy]
    return ts

  def step(self, action):
    """Applies `action` for `_physics_substeps` substeps.

    Returns:
      A `dm_env.TimeStep` whose reward is the negative Euclidean distance to
      the goal. The step type is LAST when the point is within the success
      radius of the goal, and MID otherwise.
    """
    for _ in range(self._physics_substeps):
      self._cur_vel += action
      self._cur_pos += self._cur_vel

    # Copies are handed out because the state arrays are updated in place.
    cur_pos_copy = self._cur_pos.copy()
    cur_vel_copy = self._cur_vel.copy()
    obs = {
        'cur_pos': cur_pos_copy,
        'cur_vel': cur_vel_copy,
        'goal_pos': self._goal_pos.copy()}

    if self.success():
      step_type = dm_env.StepType.LAST
    else:
      step_type = dm_env.StepType.MID
    ts = dm_env.TimeStep(
        step_type=step_type,
        reward=-1. * np.linalg.norm(self._cur_pos - self._goal_pos),
        discount=1.,
        observation=obs,)

    self._cur_episode_traj.append(cur_pos_copy)
    return ts

  def success(self, waypoint: Optional[np.ndarray] = None):
    """Returns whether the point is within the success radius of a target.

    Args:
      waypoint: Optional target position. The episode goal is used when None.
    """
    if waypoint is not None:
      goal_pos = waypoint
    else:
      goal_pos = self._goal_pos
    return np.linalg.norm(self._cur_pos - goal_pos) < self._success_radius

  def observation_spec(self):
    return {
        'cur_pos': specs.Array((2,), dtype=np.float32),
        'cur_vel': specs.Array((2,), dtype=np.float32),
        'goal_pos': specs.Array((2,), dtype=np.float32),}

  def action_spec(self):
    return specs.Array((2,), dtype=np.float32)

  def render(
      self,
      title: str = '',
      points: Optional[np.ndarray] = None,
      goal_pos: Optional[np.ndarray] = None):
    """Renders a trajectory, the goal and its success circle to an RGB image.

    Args:
      title: Optional plot title; no title is drawn when empty.
      points: Optional array of positions to draw, indexed as `points[:, 0]`
        and `points[:, 1]`. The last row is drawn as the current position.
        Defaults to the trajectory of the current episode.
      goal_pos: Optional goal to draw. Defaults to the environment's goal.

    Returns:
      A uint8 array of shape (height, width, 3).
    """
    fig, ax = plt.subplots(
        figsize=(RENDER_HEIGHT_INCHES, RENDER_HEIGHT_INCHES), dpi=DPI)
    ax.set_xlim(BOUNDS_X[0], BOUNDS_X[1])
    ax.set_ylim(BOUNDS_Y[0], BOUNDS_Y[1])
    ax.set_aspect('equal')

    if points is None:
      points = np.array(self._cur_episode_traj)
      cur_pos = self._cur_pos
    else:
      cur_pos = points[-1]

    if goal_pos is None:
      goal_pos = self._goal_pos

    ax.plot(points[:, 0], points[:, 1], marker='.', color='blue', markersize=16, linewidth=4)
    ax.scatter(
        goal_pos[0], goal_pos[1], marker='*', s=200, color='orange', linewidths=8)
    ax.scatter(
        cur_pos[0], cur_pos[1], marker='o', s=100, color='red', linewidths=8)

    # Add a dashed circle around the star
    circle = patches.Circle(
        (goal_pos[0], goal_pos[1]),  # Center of the circle
        self._success_radius,  # Radius of the circle
        edgecolor='green',  # Color of the circle
        linestyle='--',  # Dashed line
        linewidth=4,  # Thickness of the circle line
        fill=False  # Ensure it's just an outline
    )
    ax.add_patch(circle)  # Add the circle to the plot

    # Make the axes lines thicker
    for spine in ax.spines.values():
        spine.set_linewidth(4)  # Adjust the thickness here

    if title != '':
      ax.set_title(title, fontsize=18, fontweight='bold')

    ax.set_xticks([])
    ax.set_yticks([])
    plt.tight_layout()

    # Render the plot using FigureCanvasAgg. The canvas must be drawn (once)
    # before its pixel buffer is read.
    canvas = FigureCanvas(fig)
    canvas.draw()

    # Read the RGBA buffer (tostring_rgb() is unavailable in matplotlib 3.10)
    # and drop the alpha channel. The copy gives the caller an array that
    # owns its data instead of a view into the canvas buffer.
    width, height = canvas.get_width_height()
    image = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
    image = image.reshape(height, width, 4)
    image = image[:, :, :3].copy()

    plt.close(fig)
    return image


# Create the PD controller for data generation

In [None]:
def pd_controller(cur_pos, cur_vel, goal_pos, Kp=0.0002, Kd=0.0125):
  """Returns a PD control action that drives the point towards `goal_pos`.

  Args:
    cur_pos: Current position.
    cur_vel: Current velocity.
    goal_pos: Target position.
    Kp: Proportional gain applied to the position error.
    Kd: Derivative gain applied to the (negated) velocity.

  Returns:
    `Kp * (goal_pos - cur_pos) + Kd * (-cur_vel)`.
  """
  act = Kp * (goal_pos - cur_pos) + Kd * (-1. * cur_vel)
  return act

# Visualize Trajectories from the PD controller

In [None]:
# Roll out the PD controller towards 7 consecutive goals and record one frame
# per environment step.
env = Point2D()

imgs = []
ts = env.reset()
# Track the active goal explicitly: after `set_goal` the last TimeStep still
# carries the previous goal in its observation, which would send the first
# control step of each new segment towards the old goal.
goal_pos = ts.observation['goal_pos']
for eps in range(7):
  while not env.success():
    cur_pos = ts.observation['cur_pos']
    cur_vel = ts.observation['cur_vel']
    act = pd_controller(cur_pos, cur_vel, goal_pos)
    ts = env.step(act)
    imgs.append(env.render(title=""))

  goal_pos = env.sample_goal()
  env.set_goal(goal_pos)
imgs.append(env.render(title=""))

clip = ImageSequenceClip(imgs, fps=10)
display_in_notebook(clip, fps=10, width=512)

# Generate a Dataset
In the dataset, in each episode, the pointmass will go to various random waypoints, and eventually go to the desired goal position for the episode.

In [None]:
# Alias for an array or a nested container of arrays.
NestedArray = Any

num_episodes = 10000 #@param {type: "number"}
num_waypoints_per_episode = 5 #@param {type: "number"}
episode_len_discard_thresh = 10 #@param {type: "number"}

# This dataset has the format (obs, act, time_to_success, next_obs)
# It's inefficient to store next obs but makes life easier
# Will save two versions: 1) trajectories kept separate,
# 2) just tuples

class DataTuple(NamedTuple):
  """One transition, or a stacked batch of transitions, of the dataset."""
  # Observation before the action was applied.
  observation: NestedArray
  # Action that was applied.
  action: NestedArray
  # Number of remaining steps until the end of the episode.
  time_to_success: NestedArray
  reward: NestedArray
  discount: NestedArray
  # Observation after the action was applied.
  next_observation: NestedArray


# Collect `num_episodes` PD-controller episodes. In each episode the controller
# visits randomly sampled waypoints before heading for the episode goal; the
# episode ends as soon as the point is within the success radius of the goal.
episodes = []
while len(episodes) < num_episodes:
  ep_num = len(episodes)
  traj = []

  ts = env.reset()
  cur_obs = ts.observation
  succ = env.success()

  waypoint_idx = 0
  if num_waypoints_per_episode == 0:
    cur_waypoint = cur_obs['goal_pos']
  else:
    cur_waypoint = env.sample_goal()
  waypoint_succ = env.success(waypoint=cur_waypoint)

  while not succ:
    # Once the current waypoint is reached, move on to the next one; after
    # `num_waypoints_per_episode` waypoints the target becomes the goal.
    if waypoint_succ:
      waypoint_idx += 1
      waypoint_idx = min(waypoint_idx, num_waypoints_per_episode)
      if waypoint_idx == num_waypoints_per_episode:
        cur_waypoint = cur_obs['goal_pos']
      else:
        cur_waypoint = env.sample_goal()

    act = pd_controller(
        cur_obs['cur_pos'], cur_obs['cur_vel'], cur_waypoint)
    ts = env.step(act)
    next_obs = ts.observation

    # time_to_success is a placeholder here; it is filled in below once the
    # episode length is known.
    data_tuple = DataTuple(
        observation=cur_obs,
        action=act,
        time_to_success=0.,
        reward=ts.reward if ts.reward is not None else 0.,
        discount=1.,
        next_observation=next_obs,)
    traj.append(data_tuple)

    cur_obs = next_obs
    succ = env.success()
    waypoint_succ = env.success(waypoint=cur_waypoint)

  # add the last timestep
  # (a self-transition at the final observation, stored with discount 0)
  act = pd_controller(
      cur_obs['cur_pos'], cur_obs['cur_vel'], cur_waypoint)
  data_tuple = DataTuple(
      observation=cur_obs,
      action=act,
      time_to_success=0.,
      reward=ts.reward if ts.reward is not None else 0.,
      discount=0.,
      next_observation=cur_obs,)
  traj.append(data_tuple)

  # discard if episode is too short
  traj_len = len(traj)
  if traj_len < episode_len_discard_thresh:
    continue

  # stack the traj arrays
  new_traj = jax.tree.map(
      lambda *xs: np.stack(xs, dtype=np.float32), *traj)

  # label traj arrays with time to success
  # (counts down from traj_len - 1 at the first step to 0 at the last)
  new_traj = new_traj._replace(
      time_to_success=np.arange(
          traj_len - 1, -1, -1, dtype=np.float32))
  episodes.append(new_traj)

# Flatten all episodes into one long DataTuple of transitions.
def _concat_float32(*leaves):
  return np.concatenate(leaves, dtype=np.float32)

all_tuples = jax.tree.map(_concat_float32, *episodes)

# Report shapes and dtypes of the first episodes and of the flat dataset.
def _shape_and_dtype(leaf):
  return (leaf.shape, leaf.dtype)

print('\nEpisodes 0-2:')
pprint(jax.tree.map(_shape_and_dtype, episodes[:2]))

print('\nData Tuples:')
pprint(jax.tree.map(_shape_and_dtype, all_tuples))

# Episode length statistics.
episode_lens = np.array(
    [episode.observation['cur_pos'].shape[0] for episode in episodes])
print('\n')
print(f'Num Episodes: {episode_lens.shape[0]}')
print(f'Episode Lens: {np.mean(episode_lens)} +/- {np.std(episode_lens)}')
print(f'Max Episode Len: {np.max(episode_lens)}')
print(f'Min Episode Len: {np.min(episode_lens)}')

In [None]:
# Distribution of episode lengths (in steps) over the generated dataset.
plt.figure()
plt.hist(episode_lens, bins=30)
plt.title('Histogram of Episode Lengths in the Dataset')

In [None]:
# Replay one stored episode frame by frame as a visual sanity check.
episode_num = 7
debug_episode = episodes[episode_num]
points = debug_episode.observation['cur_pos']
goal_pos = debug_episode.observation['goal_pos'][0]
# Frame k shows the first k positions of the episode.
imgs = [
    env.render(
        title=f'Dataset Episode {episode_num}',
        points=points[:num_shown],
        goal_pos=goal_pos)
    for num_shown in range(1, points.shape[0] + 1)
]

clip = ImageSequenceClip(imgs, fps=10)
display_in_notebook(clip, fps=10, width=512)

In [None]:
save_datasets = True #@param {type: "boolean"}
save_path = 'pointmass_dataset.pkl' #@param {type: "string"}

if save_datasets:
  # Two pickles are written next to `save_path`: "<name>_trajs<ext>" holds the
  # per-episode trajectories, "<name>_tuples<ext>" the flat data tuples.
  head, tail = os.path.splitext(save_path)
  for _suffix, _payload in (('trajs', episodes), ('tuples', all_tuples)):
    with open(f'{head}_{_suffix}{tail}', 'wb') as fp:
      pickle.dump(_payload, fp)

# Create the networks

In [None]:
# Loose type aliases used only to make the signatures below self-describing.
Params = Any
PRNGKey = Any
NetworkOutput = Any
Entropy = Any
ActDistParams = Params
FeedForwardPolicyWithExtra = Any
LogProbFn = Any
SampleFn = Any
Observation = Any
Action = Any
DistanceToSuccessDistParams = Params
# Signature of the entropy functions: (distribution params, key) -> entropy.
EntropyFn = Callable[
    [Params, PRNGKey], Entropy]


@dataclasses.dataclass
class FeedForwardNetwork:
  """Holds a pair of pure functions defining a feed-forward network.

  Attributes:
    init: A Jax pure function: ``params = init(rng, *a, **k)``
    apply: A Jax pure function: ``out = apply(params, rng, *a, **k)``
  """
  # Initializes and returns the networks parameters.
  init: Callable[..., Params]
  # Computes and returns the outputs of a forward pass.
  # NOTE(review): the networks built in this file wrap `apply` with
  # `hk.without_apply_rng`, so they are called as `apply(params, x)`.
  apply: Callable[..., NetworkOutput]

In [None]:
# Lower bound added to the softplus-activated action scale of the policy head.
MIN_ACT_SCALE = 1e-2 #@param {type: "number"}


class MVNDiagParams(NamedTuple):
  """Parameters for a diagonal multi-variate normal distribution."""
  # Mean of the distribution.
  loc: jnp.ndarray
  # Per-dimension scale (diagonal of the scale matrix).
  scale_diag: jnp.ndarray


class CategoricalParams(NamedTuple):
  """Parameters for a categorical distribution."""
  # Unnormalized log-probabilities, one per category.
  logits: jnp.ndarray


class TIMERNetworkOutput(NamedTuple):
  """Output of the TIMER network: one parameter set per prediction head."""
  # Parameters of the action distribution.
  act_dist_params: ActDistParams
  # Parameters of the distance-to-success distribution.
  dist_to_succ_dist_params: DistanceToSuccessDistParams


@dataclasses.dataclass
class TIMERNetworks:
  """Network and pure functions for the TIMER agent.

  Attributes:
    network: network whose `apply` outputs a TIMERNetworkOutput
    act_log_prob: log probability of an action
    act_entropy: optional method for entropy of an action distribution
    sample_act: samples an action given [ActDistParams, PRNGKey]
    sample_act_mode: optional separate action sampling procedure
    dist_log_prob: log probability of a distance
    dist_entropy: optional method for entropy of a distance distribution
    sample_dist: samples a distance given
      [DistanceToSuccessDistParams, PRNGKey]
    sample_dist_mode: optional separate distance sampling procedure
  """
  network: FeedForwardNetwork
  act_log_prob: LogProbFn
  sample_act: SampleFn
  dist_log_prob: LogProbFn
  sample_dist: SampleFn
  act_entropy: Optional[EntropyFn] = None
  sample_act_mode: Optional[SampleFn] = None
  dist_entropy: Optional[EntropyFn] = None
  sample_dist_mode: Optional[SampleFn] = None


def make_policy_fn(
    timer_networks: TIMERNetworks,
    evaluation: bool) -> FeedForwardPolicyWithExtra:
  """Returns a policy function for the TIMER agent.

  Args:
    timer_networks: Networks and sampling functions of the agent.
    evaluation: If True, actions are drawn with `sample_act_mode` when it is
      provided; otherwise (or when it is None) `sample_act` is used.

  Returns:
    A function `(params, key, observations) -> (actions, extras)` where
    `extras` is an empty dict.
  """
  # `sample_act_mode` is an optional field, so fall back to the regular
  # sampler when no separate evaluation sampler is available.
  sample_fn = timer_networks.sample_act
  if evaluation and timer_networks.sample_act_mode is not None:
    sample_fn = timer_networks.sample_act_mode

  def _policy_fn(
      params: Params,
      key: PRNGKey,
      observations: Observation,
  ):
    timer_network_output = timer_networks.network.apply(
        params, observations)
    actions = sample_fn(timer_network_output.act_dist_params, key)
    return actions, {}

  return _policy_fn


def build_continuous_act_discrete_dist_v0(
    layer_sizes: Sequence[int],
    act_dim: int,
    num_dist_bins: int,
    dummy_input,
) -> TIMERNetworks:
  """Builds TIMERNetworks for continuous action and discrete distance.

  The action head outputs a diagonal multivariate normal over `act_dim`
  dimensions; the distance head outputs categorical logits over
  `num_dist_bins` bins. Each head has its own MLP torso.

  Args:
    layer_sizes: Hidden layer sizes of each MLP torso.
    act_dim: Dimensionality of the action.
    num_dist_bins: Number of categorical bins for the distance head.
    dummy_input: Example network input used to initialize the parameters.
  """

  def _torso(x):
    return hk.nets.MLP(
        output_sizes=layer_sizes,
        activation=jax.nn.relu,
        activate_final=True,)(x)

  def _small_init_linear(h):
    return hk.Linear(
        act_dim,
        w_init=hk.initializers.VarianceScaling(1e-4),
        b_init=hk.initializers.Constant(0.))(h)

  def _network(
      x: Observation) -> TIMERNetworkOutput:
    # Action head: MLP torso followed by separate loc and scale projections.
    act_features = _torso(x)
    loc = _small_init_linear(act_features)
    raw_scale = _small_init_linear(act_features)
    # Softplus keeps the scale positive; MIN_ACT_SCALE bounds it from below.
    scale = jax.nn.softplus(raw_scale) + MIN_ACT_SCALE

    # Distance head: its own MLP torso followed by a bias-free projection.
    dist_features = _torso(x)
    logits = hk.Linear(num_dist_bins, with_bias=False)(dist_features)

    return TIMERNetworkOutput(
        act_dist_params=MVNDiagParams(loc=loc, scale_diag=scale),
        dist_to_succ_dist_params=CategoricalParams(logits=logits),)

  transformed_network = hk.without_apply_rng(hk.transform(_network))
  network = FeedForwardNetwork(
      init=lambda rng: transformed_network.init(rng, dummy_input),
      apply=transformed_network.apply,)

  def _mvn(params: MVNDiagParams):
    return tfd.MultivariateNormalDiag(
        loc=params.loc, scale_diag=params.scale_diag)

  def _categorical(params: CategoricalParams):
    return tfd.Categorical(logits=params.logits)

  def act_log_prob(params: MVNDiagParams, action):
    return _mvn(params).log_prob(action)

  def act_entropy(params: MVNDiagParams, key: PRNGKey) -> Entropy:
    del key  # The entropy is computed in closed form.
    return _mvn(params).entropy()

  def sample_act(params: MVNDiagParams, key: PRNGKey):
    return _mvn(params).sample(seed=key)

  def sample_act_mode(params: MVNDiagParams, key: PRNGKey):
    del key  # The mode is deterministic.
    return _mvn(params).mode()

  def dist_log_prob(params: CategoricalParams, action):
    return _categorical(params).log_prob(action)

  def dist_entropy(params: CategoricalParams, key: PRNGKey) -> Entropy:
    del key  # The entropy is computed in closed form.
    return _categorical(params).entropy()

  def sample_dist(params: CategoricalParams, key: PRNGKey):
    return _categorical(params).sample(seed=key)

  def sample_dist_mode(params: CategoricalParams, key: PRNGKey):
    del key  # The mode is deterministic.
    return _categorical(params).mode()

  return TIMERNetworks(
      network=network,
      act_log_prob=act_log_prob,
      sample_act=sample_act,
      dist_log_prob=dist_log_prob,
      sample_dist=sample_dist,
      act_entropy=act_entropy,
      sample_act_mode=sample_act_mode,
      dist_entropy=dist_entropy,
      sample_dist_mode=sample_dist_mode,
  )


In [None]:
# Sanity checks: build a small network and print the shapes produced by every
# function in the TIMERNetworks bundle for a batch of 4 six-dimensional inputs.
_dummy_obs = np.ones((4, 6), dtype=np.float32)
_sanity_key = jax.random.PRNGKey(42)

timer_networks = build_continuous_act_discrete_dist_v0(
    (64, 64), 2, 10, _dummy_obs)
params = timer_networks.network.init(_sanity_key)
pprint(jax.tree.map(lambda x: x.shape, params))
x = timer_networks.network.apply(params, _dummy_obs)
pprint(jax.tree.map(lambda x: x.shape, x))

_dummy_act = np.ones((4, 2), dtype=np.float32)
_dummy_bins = np.array([3, 1, 5, 0], dtype=np.int32)
print(timer_networks.act_log_prob(x.act_dist_params, _dummy_act).shape)
print(timer_networks.sample_act(x.act_dist_params, _sanity_key).shape)
print(timer_networks.dist_log_prob(x.dist_to_succ_dist_params, _dummy_bins).shape)
print(timer_networks.sample_dist(x.dist_to_succ_dist_params, _sanity_key).shape)
print(timer_networks.act_entropy(x.act_dist_params, _sanity_key).shape)
print(timer_networks.sample_act_mode(x.act_dist_params, _sanity_key).shape)
print(timer_networks.dist_entropy(x.dist_to_succ_dist_params, _sanity_key).shape)
print(timer_networks.sample_dist_mode(x.dist_to_succ_dist_params, _sanity_key).shape)

# Timestep prediction converters between discrete predictions and continuous values

In [None]:
class DistanceConverters(NamedTuple):
  """Pair of functions converting between distances and network format."""
  # Maps distances to the representation used as the network's target.
  distance_to_network_format: Callable[
      [jnp.ndarray], NetworkOutput]
  # Maps network outputs back to distances.
  network_format_to_distance: Callable[
      [NetworkOutput], jnp.ndarray]


def build_discrete_distance_converter(
    min_distance: float,
    max_distance: float,
    num_bins: int = 100) -> DistanceConverters:
  """Builds converters between scalar distances and `num_bins` equal bins.

  Args:
    min_distance: Lower edge of the first bin.
    max_distance: Upper edge of the last bin.
    num_bins: Number of equally sized bins.

  Returns:
    DistanceConverters whose functions both operate on batches (they are
    vmapped over the leading axis). `distance_to_network_format` maps each
    distance to its bin index; `network_format_to_distance` maps logits to the
    expected distance under their softmax, using the lower edge of each bin as
    the bin's value.
  """
  bin_size = (max_distance - min_distance) / num_bins

  @jax.vmap
  def distance_to_network_format(d: float):
    # Clipping the upper end to half a bin below `max_distance` keeps the
    # resulting index at most `num_bins - 1`.
    clipped = jnp.clip(d, min_distance, max_distance - bin_size / 2.)
    return jnp.floor_divide(clipped - min_distance, bin_size)

  # Lower edges of the bins: num_bins + 1 edges with the last one dropped.
  dist_vals = jnp.linspace(
      min_distance,
      max_distance,
      num_bins + 1,
      endpoint=True, dtype=jnp.float32)[:-1]

  @jax.vmap
  def network_format_to_distance(logits: NetworkOutput):
    return jnp.sum(dist_vals * jax.nn.softmax(logits))

  return DistanceConverters(
      distance_to_network_format=distance_to_network_format,
      network_format_to_distance=network_format_to_distance,)


In [None]:
# Sanity check: 50 bins over [0, 55]. Distances beyond the range are clipped
# into the last bin; uniform logits give the mean of the bin values.
dc = build_discrete_distance_converter(
    min_distance=0, max_distance=55, num_bins=50)
_probe_distances = np.arange(60, dtype=np.float32)
print(dc.distance_to_network_format(_probe_distances))
_uniform_logits = np.ones((4, 50))
print(dc.network_format_to_distance(_uniform_logits))

# Implementations for Stage 1: Supervised Fine-Tuning

In [None]:
def get_from_first_device(x, as_numpy=False):
  """Returns index 0 along the leading axis of every leaf of `x`.

  Args:
    x: A pytree whose leaves have a leading (per-device) axis.
    as_numpy: If True, the result is additionally passed through
      `jax.device_get`.
  """
  first_slice = jax.tree.map(lambda leaf: leaf[0], x)
  return jax.device_get(first_slice) if as_numpy else first_slice

In [None]:
# Alias for the parameters of the TIMER network.
TIMERParams = Params


class TrainingState(NamedTuple):
  """Training state for the TIMER learner."""
  # Network parameters.
  params: TIMERParams
  # State of the optax optimizer.
  opt_state: optax.OptState
  # PRNG key carried alongside the parameters.
  random_key: PRNGKey


class PretrainLearner():
  """Supervised learner for the TIMER networks (behavior cloning + distance).

  The loss is the sum of the negative log-likelihood of the dataset actions
  and the negative log-likelihood of the discretized time-to-success. Updates
  are pmapped over all devices and `num_minibatches` SGD steps are scanned per
  call to `step`.

  Each leaf of a batch passed to `step` / `compute_loss` must have shape
  `[num_devices, num_minibatches * per_device_minibatch_size, ...]`, where
  `per_device_minibatch_size = global_minibatch_size // num_devices`.
  """

  def __init__(
      self,
      timer_networks: TIMERNetworks,
      distance_converters: DistanceConverters,
      optimizer: optax.GradientTransformation,
      random_key: PRNGKey,
      global_minibatch_size: int,
      num_minibatches: int,):
    """Builds the pmapped update/eval functions and the initial state.

    Args:
      timer_networks: Networks and log-prob functions to train.
      distance_converters: Converts time-to-success values into the bin
        indices used as targets for the distance head.
      optimizer: Optax gradient transformation.
      random_key: PRNG key used for parameter initialization.
      global_minibatch_size: Minibatch size summed over all devices.
      num_minibatches: Number of minibatch steps scanned per call.
    """

    self.local_learner_devices = jax.local_devices()
    self.num_local_learner_devices = jax.local_device_count()
    self.learner_devices = jax.devices()
    per_device_minibatch_size = (
        global_minibatch_size // len(self.learner_devices))

    self._num_full_update_steps = 0
    self.global_minibatch_size = global_minibatch_size
    self.num_minibatches = num_minibatches

    def pretrain_loss(params, minibatch: DataTuple):
      obs = minibatch.observation
      obs = jnp.concatenate(
          [obs['cur_pos'], obs['cur_vel'], obs['goal_pos']], axis=-1)
      acts = minibatch.action
      # Discretize the regression target with the converter passed to this
      # learner (not a notebook-level global). The bin index is cast to an
      # integer label for the categorical log-prob.
      dist_idx = distance_converters.distance_to_network_format(
          minibatch.time_to_success)
      dist_idx = jnp.asarray(dist_idx, dtype=jnp.int32)

      preds = timer_networks.network.apply(params, obs)
      act_dist_params = preds.act_dist_params
      dist_to_succ_dist_params = preds.dist_to_succ_dist_params

      # bc loss
      act_log_prob = jnp.mean(
          timer_networks.act_log_prob(act_dist_params, acts))
      bc_loss = -1.0 * act_log_prob

      # Distance to success loss
      dist_log_prob = jnp.mean(timer_networks.dist_log_prob(
          dist_to_succ_dist_params, dist_idx))
      dist_loss = -1.0 * dist_log_prob

      total_loss = bc_loss + dist_loss

      return total_loss, {
          'pretrain_loss': total_loss,
          'act_log_prob': act_log_prob,
          'dist_log_prob': dist_log_prob,}

    pretrain_grad = jax.grad(pretrain_loss, has_aux=True)

    def per_device_pretrain_step(
        state: TrainingState,
        minibatch: DataTuple,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      pretrain_loss_grad, metrics = pretrain_grad(state.params, minibatch)
      # Average gradients across devices so all replicas stay in sync.
      pretrain_loss_grad = jax.lax.pmean(
          pretrain_loss_grad, axis_name='devices'
      )
      updates, new_opt_state = optimizer.update(
          pretrain_loss_grad, state.opt_state
      )
      new_params = optax.apply_updates(state.params, updates)
      state = state._replace(params=new_params, opt_state=new_opt_state)
      return state, metrics

    def per_device_compute_loss(
        state: TrainingState,
        minibatch: DataTuple,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      _, metrics = pretrain_loss(state.params, minibatch)
      return state, metrics

    def make_scanned(per_minibatch_fn):
      """Wraps a per-minibatch function into a scan over `num_minibatches`."""

      def scanned(state: TrainingState, batch: DataTuple):
        def reshape_for_scan(x):
          new_shape = [
              num_minibatches,
              per_device_minibatch_size,
          ] + list(x.shape[1:])
          return jnp.reshape(x, new_shape)

        minibatches = jax.tree.map(reshape_for_scan, batch)
        state, metrics = jax.lax.scan(
            per_minibatch_fn, state, minibatches, length=num_minibatches)
        # Average the per-minibatch metrics over the scanned steps.
        metrics = jax.tree.map(jnp.mean, metrics)

        return state, metrics

      return scanned

    self._pmapped_scanned_pretrain_step = jax.pmap(
        make_scanned(per_device_pretrain_step),
        axis_name='devices',
        devices=self.learner_devices)

    self._pmapped_scanned_compute_loss = jax.pmap(
        make_scanned(per_device_compute_loss),
        axis_name='devices',
        devices=self.learner_devices)

    def make_initial_state(random_key: PRNGKey) -> TrainingState:
      all_keys = jax.random.split(
          random_key, num=self.num_local_learner_devices + 1)
      key_init, key_state = all_keys[0], all_keys[1:]
      key_state = [key_state[i] for i in range(self.num_local_learner_devices)]
      key_state = jax.device_put_sharded(key_state, self.local_learner_devices)

      initial_params = timer_networks.network.init(key_init)
      initial_opt_state = optimizer.init(initial_params)

      initial_params = jax.device_put_replicated(initial_params,
                                                 self.local_learner_devices)
      initial_opt_state = jax.device_put_replicated(initial_opt_state,
                                                    self.local_learner_devices)

      return TrainingState(
          params=initial_params,
          opt_state=initial_opt_state,
          random_key=key_state,)

    # Initialise training state (parameters and optimizer state).
    self._state = make_initial_state(random_key)

  def step(self, batch: DataTuple):
    """Runs `num_minibatches` SGD updates on `batch` and returns mean metrics."""
    self._state, results = self._pmapped_scanned_pretrain_step(
        self._state, batch)

    self._num_full_update_steps += self.num_minibatches

    results = jax.tree.map(jnp.mean, results)
    return results

  def compute_loss(self, batch: DataTuple):
    """Evaluates the loss metrics on `batch` without updating parameters."""
    # Use the evaluation-only function: the training step would apply
    # gradient updates between minibatches while the metrics are measured.
    _, results = self._pmapped_scanned_compute_loss(
        self._state, batch)
    results = jax.tree.map(jnp.mean, results)
    return results

  def get_state(self):
    """Returns the first device's copy of the training state."""
    return get_from_first_device(self._state, as_numpy=True)

  def restore(self, state: TrainingState):
    """Replicates an unreplicated `state` onto all local devices.

    The state's random key is split so that each device gets its own key.
    """
    random_key = state.random_key
    random_key = jax.random.split(
        random_key, num=self.num_local_learner_devices)
    random_key = jax.device_put_sharded(
        [random_key[i] for i in range(self.num_local_learner_devices)],
        self.local_learner_devices)

    state = jax.device_put_replicated(state, self.local_learner_devices)
    state = state._replace(random_key=random_key)
    self._state = state


# sanity checks
# Build a small learner (global minibatch size 64, 4 minibatches per call) on
# top of the sanity-check networks and converter defined in earlier cells.
optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.scale_by_adam(eps=1e-7),
    optax.scale(-3e-4))
learner = PretrainLearner(
    timer_networks,
    dc,
    optimizer,
    jax.random.PRNGKey(0),
    64,
    4,)
pprint(jax.tree.map(lambda x: x.shape, learner.get_state()))
# 256 = 64 (global minibatch size) * 4 (minibatches) transitions.
batch = jax.tree.map(lambda x: x[:256], all_tuples)
def reshape_for_pmap(x):
  """Reshapes `x` to [num_devices, per-device batch, ...] for the sanity check.

  The per-device batch is (64 // num_devices) * 4, matching the minibatch size
  (64) and number of minibatches (4) of the sanity-check learner.
  """
  num_devices = jax.device_count()
  per_device_batch = (64 // num_devices) * 4
  return jnp.reshape(x, [num_devices, per_device_batch, *x.shape[1:]])
# Add the leading device axis, then run one training call and one loss
# evaluation on the same batch.
batch = jax.tree.map(lambda x: reshape_for_pmap(x), batch)
print(learner.step(batch))
print(learner.compute_loss(batch))

In [None]:
# Per-dimension mean and standard deviation of the raw (unnormalized) data.
for _raw_field in (
    all_tuples.observation['cur_pos'],
    all_tuples.observation['cur_vel'],
    all_tuples.observation['goal_pos'],
    all_tuples.action,
):
  print(jnp.mean(_raw_field, axis=0))
  print(jnp.std(_raw_field, axis=0))

In [None]:
# Dataset-wide normalization statistics. keepdims=True keeps a leading axis of
# size 1 so the statistics broadcast against batched arrays.
cur_pos_mean = jnp.mean(all_tuples.observation['cur_pos'], axis=0, keepdims=True)
cur_pos_std = jnp.std(all_tuples.observation['cur_pos'], axis=0, keepdims=True)
cur_vel_mean = jnp.mean(all_tuples.observation['cur_vel'], axis=0, keepdims=True)
cur_vel_std = jnp.std(all_tuples.observation['cur_vel'], axis=0, keepdims=True)
act_mean = jnp.mean(all_tuples.action, axis=0, keepdims=True)
act_std = jnp.std(all_tuples.action, axis=0, keepdims=True)

def normalize_obs(obs):
  """Standardizes an observation dict with the dataset statistics.

  `goal_pos` is normalized with the `cur_pos` statistics, so positions and
  goals share one coordinate scaling.
  """
  def _standardize(value, mean, std):
    return (value - mean) / std

  return {
      'cur_pos': _standardize(obs['cur_pos'], cur_pos_mean, cur_pos_std),
      'cur_vel': _standardize(obs['cur_vel'], cur_vel_mean, cur_vel_std),
      'goal_pos': _standardize(obs['goal_pos'], cur_pos_mean, cur_pos_std),
  }

def normalize_action(action):
  """Standardizes an action with the dataset action mean and std."""
  centered = action - act_mean
  return centered / act_std

def unnormalize_action(action):
  """Maps a standardized action back to environment units."""
  rescaled = action * act_std
  return rescaled + act_mean

# Normalized copy of the dataset. Only `observation` and `action` are
# replaced; the remaining fields (including next_observation) stay as in
# `all_tuples`.
normalized_all_tuples = all_tuples._replace(
    observation=normalize_obs(all_tuples.observation),
    action=normalize_action(all_tuples.action))

In [None]:
# Per-dimension mean and standard deviation after normalization.
for _normalized_field in (
    normalized_all_tuples.observation['cur_pos'],
    normalized_all_tuples.observation['cur_vel'],
    normalized_all_tuples.observation['goal_pos'],
    normalized_all_tuples.action,
):
  print(jnp.mean(_normalized_field, axis=0))
  print(jnp.std(_normalized_field, axis=0))

# Train Stage 1: Supervised Fine-Tuning

In [None]:
# Make the data loader
global_minibatch_size = 256 #@param {type: "number"}
# num_minibatches is the number of SGD update steps we do per call to the
# learner. The update steps are scanned using jax.lax.scan for efficiency.
num_minibatches = 128 #@param {type: "number"}

# Number of transitions consumed by one call to the learner.
batch_size = global_minibatch_size * num_minibatches

num_learners = jax.device_count()
def reshape_for_pmap(x):
  """Splits the leading batch axis of a tensor across the learner devices.

  Returns a tensor of shape [num_learners, per-learner batch, ...] where the
  per-learner batch is (global_minibatch_size * num_minibatches) //
  num_learners.
  """
  per_learner_batch = (
      (global_minibatch_size * num_minibatches) // num_learners)
  trailing_dims = list(x.shape[1:])
  return tf.reshape(x, [num_learners, per_learner_batch] + trailing_dims)

def make_dataset_from_tuples(data_tuples):
  """Builds an endless, shuffled, batched numpy iterator over `data_tuples`.

  Args:
    data_tuples: A DataTuple of arrays sharing the same leading axis.

  Returns:
    A numpy iterator yielding batches of `batch_size` transitions, reshaped
    by `reshape_for_pmap` to carry a leading device axis.
  """
  # Size the shuffle buffer from the data actually passed in rather than from
  # the notebook-level `all_tuples`, so that the function does not depend on
  # a global and does not request a larger buffer than the split needs.
  num_examples = data_tuples.observation['cur_pos'].shape[0]
  dataset = tf_data.Dataset.from_tensor_slices(data_tuples).cache()
  dataset = dataset.shuffle(num_examples, reshuffle_each_iteration=True)
  # Repeat before batching so that dropping the remainder never discards the
  # same examples every epoch.
  dataset = dataset.repeat().batch(batch_size, drop_remainder=True)
  dataset = dataset.map(lambda x: jax.tree.map(reshape_for_pmap, x))
  dataset = dataset.prefetch(tf_data.experimental.AUTOTUNE)
  dataset = dataset.as_numpy_iterator()
  return dataset

# Split the normalized transitions: the first 90% are used for training and
# the remaining 10% for validation.
normalized_all_tuples_size = normalized_all_tuples.observation['cur_pos'].shape[0]
train_set_ratio = 0.9
train_set_size = int(normalized_all_tuples_size * train_set_ratio)


def _slice_normalized_tuples(start, stop):
  return jax.tree.map(
      lambda x: np.array(x[start:stop]), normalized_all_tuples)


train_dataset = make_dataset_from_tuples(
    _slice_normalized_tuples(None, train_set_size))
val_dataset = make_dataset_from_tuples(
    _slice_normalized_tuples(train_set_size, None))

# Print the batch shapes produced by each iterator (consumes one batch each).
for _dataset in (train_dataset, val_dataset):
  print(jax.tree.map(lambda x: x.shape, next(_dataset)))

In [None]:
# Make the distance converters
# Time-to-success values are discretized into `num_distance_bins` equal bins
# over [min_distance, max_distance].
min_distance = 0 #@param {type: "number"}
max_distance = 140 #@param {type: "number"}
num_distance_bins = 50 #@param {type: "number"}

distance_converter = build_discrete_distance_converter(
    min_distance, max_distance, num_distance_bins)

In [None]:
# Make the networks
layer_sizes = (256, 256, 256) #@param

# The dummy input has 6 features: cur_pos, cur_vel and goal_pos (2 each) are
# concatenated before being fed to the network.
timer_networks = build_continuous_act_discrete_dist_v0(
    layer_sizes,
    env.action_spec().shape[0],
    num_distance_bins,
    np.ones((4, 6), dtype=np.float32))

In [None]:
# Make the optimizer
learning_rate = 3e-4 #@param {type: "number"}
# NOTE: gradient clipping is currently disabled (see the commented-out line
# below), so `global_norm_clip` has no effect.
global_norm_clip = 1.0 #@param {type: "number"}

optimizer = optax.chain(
    # optax.clip_by_global_norm(global_norm_clip),
    optax.scale_by_adam(eps=1e-7),
    optax.scale(-1. * learning_rate))

In [None]:
# Make the learner
learner = PretrainLearner(
    timer_networks,
    distance_converter,
    optimizer,
    jax.random.PRNGKey(42),
    global_minibatch_size,
    num_minibatches,)

# Training metric histories; one entry is appended per learner call.
losses = []
act_log_probs = []
dist_log_probs = []

# Validation metric histories; one entry is appended per evaluation.
val_losses = []
val_act_log_probs = []
val_dist_log_probs = []

In [None]:
# Stage 1 Supervised Fine-Tuning Train Loop
# You can keep rerunning this cell if you would like to continue the training
# for more iterations.

# Number of SGD steps to perform
num_steps = 32768 #@param
assert num_steps % num_minibatches == 0
display_rate = 1


def _plot_metric_panel(panel_index, title, train_values, val_values):
  """Draws train (blue) and val (red) curves of one metric against SGD steps."""
  plt.subplot(1, 3, panel_index)
  plt.title(title)
  plt.plot(
      np.arange(len(losses)) * num_minibatches,
      train_values,
      color='blue',
      label='train',)
  plt.plot(
      np.arange(len(val_losses)) * num_minibatches * display_rate,
      val_values,
      color='red',
      label='val',)
  plt.legend()


for i in range(num_steps // num_minibatches):
  # One learner call performs `num_minibatches` SGD steps.
  batch = next(train_dataset)
  results = learner.step(batch)
  cur_loss = results['pretrain_loss'].item()
  losses.append(cur_loss)
  act_log_probs.append(results['act_log_prob'].item())
  dist_log_probs.append(results['dist_log_prob'].item())

  if i % display_rate != 0:
    continue

  # Evaluate on one validation batch and refresh the live plot.
  val_batch = next(val_dataset)
  val_results = learner.compute_loss(val_batch)
  val_cur_loss = val_results['pretrain_loss'].item()
  val_losses.append(val_cur_loss)
  val_act_log_probs.append(val_results['act_log_prob'].item())
  val_dist_log_probs.append(val_results['dist_log_prob'].item())

  clear_output(wait=True)
  plt.figure(figsize=[4*3,4*1])
  _plot_metric_panel(1, 'Stage 1 SFT Loss', losses, val_losses)
  _plot_metric_panel(2, 'Action Log Prob', act_log_probs, val_act_log_probs)
  _plot_metric_panel(
      3, 'Timestep Log Prob', dist_log_probs, val_dist_log_probs)
  plt.show()

pretrain_cpu_state = learner.get_state()
print(f'\nTotal Steps: {len(losses) * num_minibatches}')


In [None]:
# Per-dimension statistics of one training batch, reduced over both the
# device axis and the batch axis.
batch = next(train_dataset)
for _batch_field in (
    batch.observation['cur_pos'],
    batch.observation['cur_vel'],
    batch.observation['goal_pos'],
    batch.action,
):
  print(jnp.mean(_batch_field, axis=(0, 1)))
  print(jnp.std(_batch_field, axis=(0, 1)))

# Visualize Policies after the Stage 1 Supervised Fine-Tuning process

In [None]:
# Get the learner state and compile a CPU policy
# `get_state` returns the first device's copy of the state.
cpu_state = learner.get_state()
cpu_params = cpu_state.params

def _policy(obs, rng):
  obs = jax.tree.map(lambda x: x[None], obs)
  obs = normalize_obs(obs)
  obs = jnp.concatenate(
      [obs['cur_pos'], obs['cur_vel'], obs['goal_pos']], axis=-1)
  pred = timer_networks.network.apply(cpu_params, obs)
  act = timer_networks.sample_act(pred.act_dist_params, rng)
  # print(act)
  act = unnormalize_action(act)
  # print(act)
  dist_to_succ = distance_converter.network_format_to_distance(
      pred.dist_to_succ_dist_params.logits)
  extras = {
      'pred_dist_to_succ': dist_to_succ,
      'pred_dist_to_succ_dist_params': pred.dist_to_succ_dist_params,}
  return act[0], extras

policy = jax.jit(_policy, backend='cpu')

In [None]:
def _style_distance_axes(ax, xlabel, ylabel):
  """Applies the bold, thick-lined look shared by both panels of the plot."""
  # Thick tick marks with large labels.
  ax.tick_params(axis='both', which='major', width=2, length=6, labelsize=14)
  ax.tick_params(axis='both', which='minor', width=1.5, length=4)
  for label in ax.get_xticklabels() + ax.get_yticklabels():
    label.set_fontweight('bold')

  ax.set_xlabel(xlabel, fontsize=16, fontweight='bold')
  ax.set_ylabel(ylabel, fontsize=16, fontweight='bold')

  # Thick frame around the axes.
  for spine in ax.spines.values():
    spine.set_linewidth(4)

def get_distance_plot(distances, logits):
  """Renders the steps-to-go predictions for one video frame.

  The left panel plots the expected steps-to-go over the episode so far; the
  right panel shows the predicted steps-to-go distribution at the current
  frame as a bar chart over [min_distance, max_distance).

  Args:
    distances: Sequence of expected steps-to-go values, one per episode step
      up to and including the current frame.
    logits: 1-D array of unnormalized log-probabilities over steps-to-go bins
      for the current frame.

  Returns:
    A contiguous `height x width x 3` uint8 RGB image of the rendered figure.
  """
  fig, (trace_ax, hist_ax) = plt.subplots(
      1, 2, figsize=(2 * RENDER_HEIGHT_INCHES, RENDER_HEIGHT_INCHES), dpi=DPI)

  trace_ax.plot(distances, color='blue', linewidth=4)
  trace_ax.set_ylim(min_distance, max_distance)
  trace_ax.set_title('E[steps-to-go]', fontsize=18, fontweight='bold')
  _style_distance_axes(trace_ax, 'Episode Step', 'E[steps-to-go]')

  probs = jax.nn.softmax(logits, axis=-1)
  num_bins = probs.shape[0]
  # Left edge of every bin: drop the final point of the (num_bins + 1)-point
  # grid so that there is exactly one x position per probability.
  bin_starts = np.linspace(
      min_distance, max_distance, num_bins + 1, endpoint=True)[:-1]
  hist_ax.bar(
      bin_starts,
      probs,
      width=(max_distance - min_distance) / num_bins * 0.8,
      color='blue')
  hist_ax.set_ylim(0., 1.)
  hist_ax.set_title(
      r'p(steps-to-go) at Curr Frame', fontsize=18, fontweight='bold')
  _style_distance_axes(hist_ax, 'Step-to-go', 'Probability')

  fig.tight_layout()

  # Rasterize once with the Agg canvas. `buffer_rgba()` is used because
  # `tostring_rgb()` is not available in matplotlib 3.10.
  canvas = FigureCanvas(fig)
  canvas.draw()
  width, height = canvas.get_width_height()
  rgba = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
  rgba = rgba.reshape(height, width, 4)
  # Copy out the RGB channels so the result does not keep the renderer's RGBA
  # buffer alive after the figure is closed.
  image = np.ascontiguousarray(rgba[:, :, :3])

  plt.close(fig)

  return image

def generate_distance_plots(all_extras):
  """Renders one steps-to-go plot per frame of a trajectory.

  Args:
    all_extras: Stacked per-frame policy extras. 'pred_dist_to_succ' is indexed
      by frame along its first axis, and 'pred_dist_to_succ_dist_params'
      exposes a `logits` attribute indexed the same way.

  Returns:
    A list with one image per frame. Frame `k` plots the predictions for
    frames `0..k` together with the logits of frame `k`.
  """
  dist_preds = all_extras['pred_dist_to_succ']
  all_logits = all_extras['pred_dist_to_succ_dist_params'].logits
  return [
      get_distance_plot(dist_preds[:end], all_logits[end - 1])
      for end in range(1, dist_preds.shape[0] + 1)
  ]

def generate_policy_traj(policy, title):
  """Rolls out `policy` for one episode and records frames and predictions.

  The episode ends on success or after `max_distance` steps.

  Args:
    policy: Callable `(obs, rng) -> (action, extras)` where every leaf of
      `extras` carries a leading batch axis of size 1.
    title: Title passed to `env.render` for every frame.

  Returns:
    A tuple `(imgs, all_extras)`. `imgs` is the list of rendered frames: the
    reset frame plus one per environment step. `all_extras` is the per-frame
    policy extras stacked along a new leading axis of the same length as
    `imgs`.
  """
  imgs = []
  all_extras = []
  ts = env.reset()
  t = 0
  succ = env.success()
  key = jax.random.PRNGKey(42)
  imgs.append(env.render(title=title))

  while (not succ) and t < max_distance:
    obs = ts.observation
    key, sub_key = jax.random.split(key)
    act, extras = policy(obs, sub_key)
    all_extras.append(jax.tree.map(lambda x: x[0], extras))
    ts = env.step(act)
    imgs.append(env.render(title=title))
    succ = env.success()
    t += 1

  if all_extras:
    # There is one more frame than policy queries, so repeat the last
    # prediction to keep `all_extras` aligned with `imgs`.
    all_extras.append(all_extras[-1])
  else:
    # The episode was already successful at reset, so the loop never ran.
    # Query the policy once (without stepping the env) so the single frame
    # still has a prediction instead of failing on an empty list.
    key, sub_key = jax.random.split(key)
    _, extras = policy(ts.observation, sub_key)
    all_extras.append(jax.tree.map(lambda x: x[0], extras))

  all_extras = jax.tree.map(lambda *xs: np.stack(xs), *all_extras)

  return imgs, all_extras

def get_trajectory_visualization(policy, title, return_sub_images: bool = False):
  """Builds side-by-side video frames (environment | steps-to-go plot).

  Args:
    policy: Policy forwarded to `generate_policy_traj`.
    title: Title rendered on every environment frame.
    return_sub_images: When True, also return the unjoined frame lists.

  Returns:
    The list of joined frames, or `(joined_frames, (env_frames, plot_frames))`
    when `return_sub_images` is True.
  """
  env_frames, all_extras = generate_policy_traj(policy, title)
  plot_frames = generate_distance_plots(all_extras)
  # Place each environment frame to the left of its matching plot frame.
  video_imgs = [
      np.concatenate([env_frame, plot_frame], axis=1)
      for env_frame, plot_frame in zip(env_frames, plot_frames)
  ]

  if not return_sub_images:
    return video_imgs
  return video_imgs, (env_frames, plot_frames)


# Roll out the Stage 1 policy several times and play all episodes back to back
# as a single video.
num_sft_trajs_vis = 10
sft_trajs = []
for i in range(num_sft_trajs_vis):
  print(f'\nGenerating traj {i + 1} of {num_sft_trajs_vis}...')
  sft_trajs += get_trajectory_visualization(policy, 'Stage 1 SFT Policy')

print('Creating video...')
clip = ImageSequenceClip(sft_trajs, fps=10)
display_in_notebook(clip, fps=10, width=3 * 512, maxduration=1_000_000)

# RL Utilities

In [None]:
def _timer_rollout_policy(params, normalized_obs, rng):
  """Samples a normalized action and predicts steps-to-go for a batch.

  Args:
    params: Network parameters.
    normalized_obs: Mapping with already-normalized 'cur_pos', 'cur_vel' and
      'goal_pos' entries.
    rng: PRNG key used to sample the action.

  Returns:
    A tuple `(normalized_action, extras)` where `extras['pred_dist_to_succ']`
    is the steps-to-go prediction decoded from the network logits.
  """
  net_input = jnp.concatenate(
      [normalized_obs[name] for name in ('cur_pos', 'cur_vel', 'goal_pos')],
      axis=-1)
  preds = timer_networks.network.apply(params, net_input)
  normalized_act = timer_networks.sample_act(preds.act_dist_params, rng)
  steps_to_go = distance_converter.network_format_to_distance(
      preds.dist_to_succ_dist_params.logits)
  return normalized_act, {'pred_dist_to_succ': steps_to_go}

# Compiled for the CPU backend. Swap in the un-jitted function when debugging:
# timer_rollout_policy = _timer_rollout_policy
timer_rollout_policy = jax.jit(_timer_rollout_policy, backend='cpu')

def evaluate_timer_rollout_policy(env, params, num_episodes):
  """Rolls out the policy for several episodes and prints summary statistics.

  An episode stops on success or once the number of recorded timesteps
  (including the reset timestep) reaches `max_distance`. The reported episode
  length counts the reset timestep as well.

  Args:
    env: Environment exposing `reset()`, `step(action)` and `success()`.
    params: Network parameters passed to `timer_rollout_policy`.
    num_episodes: Number of episodes to roll out.

  Returns:
    None. Success rate, returns and episode lengths are printed.
  """
  stats = []
  # Create the key once, outside the episode loop. Re-seeding per episode made
  # every episode replay the exact same sequence of action-sampling keys.
  key = jax.random.PRNGKey(42)

  for _ in range(num_episodes):
    ts = env.reset()
    # Give the reset timestep a zero reward so the return can be summed over
    # every recorded timestep.
    ts = ts._replace(reward=0.)
    timesteps = [ts]

    while (not env.success()) and len(timesteps) < max_distance:
      key, sub_key = jax.random.split(key)
      # Add a leading batch axis of size 1 before normalizing.
      cur_obs = normalize_obs(
          jax.tree.map(lambda x: x[None], ts.observation))
      norm_act, _ = timer_rollout_policy(params, cur_obs, sub_key)
      unnorm_act = unnormalize_action(norm_act)
      ts = env.step(unnorm_act[0])
      timesteps.append(ts)

    stats.append({
        'success': env.success(),
        'return': sum(x.reward for x in timesteps),
        'len': len(timesteps),
    })

  stats = jax.tree.map(lambda *xs: np.stack(xs), *stats)
  print(f'Success Rate: {np.mean(stats["success"])}')
  print(f'Returns: {np.mean(stats["return"]):.2f} +/- {np.std(stats["return"]):.2f}')
  print(f'Episode Lengths: {np.mean(stats["len"]):.2f} +/- {np.std(stats["len"]):.2f}')
  print(f'Max Episode Length: {np.max(stats["len"])}')
  print(f'Min Episode Length: {np.min(stats["len"])}')

evaluate_timer_rollout_policy(env, cpu_params, num_episodes=25)

In [None]:
class REINFORCETuple(NamedTuple):
  """Per-step training data for the Stage 2 REINFORCE update.

  Instances are built by `generate_timer_reinforce_dataset` and consumed by
  `reinforce_loss`.
  """
  # Normalized observation(s); `reinforce_loss` reads the 'cur_pos', 'cur_vel'
  # and 'goal_pos' entries.
  observation: NestedArray
  # Normalized action(s) sampled by the rollout policy.
  action: NestedArray
  # Discounted sum of future rewards; multiplies the action log-probability in
  # `reinforce_loss`.
  weight: NestedArray

def _discounted_returns_to_go(rews, gamma):
  """Returns, for every step, the discounted sum of rewards from that step on."""
  returns = []
  running = 0.
  # Accumulate from the last step backwards: G_t = r_t + gamma * G_{t+1}.
  for rew in rews[::-1]:
    running = rew + gamma * running
    returns.append(running)
  return np.array(returns[::-1], dtype=np.float32)

def generate_timer_reinforce_dataset(env, params, num_steps, gamma, seed=42):
  """Rolls out the policy and builds REINFORCE training tuples.

  Whole episodes are collected until at least `num_steps` environment steps
  have been gathered, so the result may hold more than `num_steps` rows. The
  reward for a step is the decrease in the policy's own steps-to-go prediction
  across that step, and each step's weight is the discounted sum of those
  rewards from that step to the end of the episode.

  Args:
    env: Environment exposing `reset()`, `step(action)` and `success()`.
    params: Network parameters passed to `timer_rollout_policy`.
    num_steps: Minimum number of environment steps to collect.
    gamma: Discount factor used for the per-step weights.
    seed: Seed of the PRNG key used for action sampling. The default keeps the
      previously hard-coded value.

  Returns:
    A tuple `(data_tuples, data_stats)`. `data_tuples` is a `REINFORCETuple`
    with all episodes concatenated along the first axis. `data_stats` is a
    dict of per-episode 'success', 'return' and 'len' arrays.
  """
  def _normalized_obs(timestep):
    # Add a leading batch axis of size 1 before normalizing.
    return normalize_obs(
        jax.tree.map(lambda x: x[None], timestep.observation))

  total_steps = 0
  data_tuples = []
  data_stats = []
  key = jax.random.PRNGKey(seed)

  while total_steps < num_steps:
    traj_acts = []
    traj_dist_preds = []
    episode_steps = 0
    episode_return = 0.

    cur_obs = _normalized_obs(env.reset())
    traj_obs = [cur_obs]

    while (not env.success()) and episode_steps < max_distance:
      sub_key, key = jax.random.split(key)
      norm_act, extras = timer_rollout_policy(params, cur_obs, sub_key)
      traj_acts.append(norm_act)
      traj_dist_preds.append(extras['pred_dist_to_succ'])

      ts = env.step(unnormalize_action(norm_act)[0])
      episode_return += ts.reward
      cur_obs = _normalized_obs(ts)
      traj_obs.append(cur_obs)

      episode_steps += 1

    if episode_steps < 1:
      # The episode was already successful at reset: nothing to learn from.
      continue

    total_steps += episode_steps

    # Query the policy on the final observation as well, so the last step has
    # a "next" steps-to-go prediction to compute its reward from.
    sub_key, key = jax.random.split(key)
    _, extras = timer_rollout_policy(params, cur_obs, sub_key)
    traj_dist_preds.append(extras['pred_dist_to_succ'])

    traj_obs = jax.tree.map(lambda *xs: np.concatenate(xs), *traj_obs)
    traj_acts = jax.tree.map(
        lambda *xs: np.concatenate(xs), *traj_acts)
    traj_dist_preds = jax.tree.map(
        lambda *xs: np.concatenate(xs), *traj_dist_preds)

    # Reward = how much the predicted steps-to-go dropped over the step.
    rews = -1. * (traj_dist_preds[1:] - traj_dist_preds[:-1])

    data_tuples.append(REINFORCETuple(
        # Drop the final observation: no action was taken from it.
        observation=jax.tree.map(lambda x: x[:-1], traj_obs),
        action=traj_acts,
        weight=_discounted_returns_to_go(rews, gamma),
    ))
    data_stats.append({
        'success': env.success(),
        'return': episode_return,
        'len': episode_steps,
    })

  data_tuples = jax.tree.map(
      lambda *xs: np.concatenate(xs), *data_tuples)
  data_stats = jax.tree.map(lambda *xs: np.stack(xs), *data_stats)
  return data_tuples, data_stats


# Sanity check: time one round of data collection, then print the shapes and
# the per-dimension mean / std of every field.
tick = time.time()
reinforce_data, data_stats = generate_timer_reinforce_dataset(
    env, cpu_params, num_steps=2048, gamma=0.9)
print(f'Took {time.time() - tick:.2f} seconds')
pprint(jax.tree.map(lambda x: x.shape, reinforce_data))
for obs_key in ('cur_pos', 'cur_vel', 'goal_pos'):
  obs_field = reinforce_data.observation[obs_key]
  print(jnp.mean(obs_field, axis=0))
  print(jnp.std(obs_field, axis=0))
for data_field in (reinforce_data.action, reinforce_data.weight):
  print(jnp.mean(data_field, axis=0))
  print(jnp.std(data_field, axis=0))

# Per-episode summary of the rollouts that produced the data.
stats = data_stats
ep_returns = stats["return"]
ep_lens = stats["len"]
print(f'Success Rate: {np.mean(stats["success"])}')
print(f'Returns: {np.mean(ep_returns):.2f} +/- {np.std(ep_returns):.2f}')
print(f'Episode Lengths: {np.mean(ep_lens):.2f} +/- {np.std(ep_lens):.2f}')
print(f'Max Episode Length: {np.max(ep_lens)}')
print(f'Min Episode Length: {np.min(ep_lens)}')

# Train Stage 2 Self-Improvement

In [None]:
reinforce_global_minibatch_size = 64  # @param {type: "number"}
reinforce_global_batch_size = 2048  # @param {type: "number"}
num_devices = jax.device_count()

# Number of minibatch updates performed per collected batch.
reinforce_num_minibatches = (
    reinforce_global_batch_size // reinforce_global_minibatch_size)
# Per-device share of the global minibatch and batch sizes.
per_device_reinforce_minibatch_size = (
    reinforce_global_minibatch_size // num_devices)
per_device_reinforce_batch_size = reinforce_global_batch_size // num_devices

def reinforce_loss(params, minibatch: REINFORCETuple):
  """Computes the weighted negative log-likelihood REINFORCE loss.

  Args:
    params: Network parameters.
    minibatch: Observations, actions and per-step weights.

  Returns:
    A tuple `(loss, metrics)` where `metrics['reinforce_loss']` is the same
    scalar as `loss`.
  """
  net_input = jnp.concatenate(
      [minibatch.observation[name]
       for name in ('cur_pos', 'cur_vel', 'goal_pos')],
      axis=-1)
  preds = timer_networks.network.apply(params, net_input)
  log_probs = timer_networks.act_log_prob(
      preds.act_dist_params, minibatch.action)
  # Rescale the weights by the maximum steps-to-go value.
  scaled_weights = minibatch.weight / float(max_distance)
  loss = -1. * jnp.mean(scaled_weights * log_probs)
  return loss, {'reinforce_loss': loss}

# Gradient of the loss w.r.t. `params`; the metrics dict is passed through.
reinforce_loss_grad = jax.grad(reinforce_loss, has_aux=True)

def per_device_reinforce_step(
    state: TrainingState,
    minibatch: REINFORCETuple,) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
  """Applies one optimizer update from a single minibatch on one device.

  Must run inside a mapped context that defines the 'devices' axis, because
  the gradients are averaged across that axis before the update.

  Args:
    state: Current training state (`params` and `opt_state` are updated).
    minibatch: This device's share of the minibatch.

  Returns:
    The updated training state and the loss metrics.
  """
  grads, metrics = reinforce_loss_grad(state.params, minibatch)
  grads = jax.lax.pmean(grads, axis_name='devices')
  updates, next_opt_state = optimizer.update(grads, state.opt_state)
  next_params = optax.apply_updates(state.params, updates)
  return state._replace(params=next_params, opt_state=next_opt_state), metrics

def scanned_per_device_reinforce_step(state: TrainingState, batch: REINFORCETuple):
  """Runs all minibatch updates for one device's share of a batch.

  The per-device batch is split along its first axis into
  `reinforce_num_minibatches` minibatches of
  `per_device_reinforce_minibatch_size` rows, and
  `per_device_reinforce_step` is scanned over them.

  Args:
    state: Current training state.
    batch: This device's share of the REINFORCE batch.

  Returns:
    The training state after the last minibatch and the metrics averaged over
    all minibatches.
  """
  def reshape_for_scan(x):
    return jnp.reshape(
        x,
        [reinforce_num_minibatches, per_device_reinforce_minibatch_size]
        + list(x.shape[1:]))

  minibatches = jax.tree.map(reshape_for_scan, batch)
  state, metrics = jax.lax.scan(
      per_device_reinforce_step,
      state,
      minibatches,
      length=reinforce_num_minibatches)
  return state, jax.tree.map(jnp.mean, metrics)

# One replica per device; 'devices' is the axis the gradient pmean reduces over.
pmapped_scanned_per_device_reinforce_step = jax.pmap(
    scanned_per_device_reinforce_step,
    axis_name='devices',
    devices=jax.devices())

def _full_reinforce_step(state: TrainingState, batch: REINFORCETuple):
  """Shards a global batch across devices and runs the pmapped update.

  Args:
    state: Device-replicated training state.
    batch: Global batch whose leaves have exactly
      `reinforce_global_batch_size` rows.

  Returns:
    The updated training state and the metrics averaged over devices.
  """
  rows_per_device = reinforce_global_batch_size // num_devices
  # Give every leaf a leading device axis.
  sharded_batch = jax.tree.map(
      lambda x: jnp.reshape(
          x, [num_devices, rows_per_device] + list(x.shape[1:])),
      batch)
  state, metrics = pmapped_scanned_per_device_reinforce_step(
      state, sharded_batch)
  return state, jax.tree.map(jnp.mean, metrics)

full_reinforce_step = _full_reinforce_step

def restore_reinforce_state(state: TrainingState):
  """Places a training state on all devices for the pmapped update.

  Every field is replicated across devices except `random_key`, which is split
  so that each device receives its own key.

  Args:
    state: Training state to distribute.

  Returns:
    The device-replicated training state with per-device random keys.
  """
  devices = jax.devices()
  per_device_keys = jax.random.split(
      state.random_key, num=jax.device_count())
  sharded_keys = jax.device_put_sharded(list(per_device_keys), devices)

  replicated_state = jax.device_put_replicated(state, devices)
  return replicated_state._replace(random_key=sharded_keys)

# Sanity check: run one full REINFORCE update on the rollouts collected above
# to confirm the pmapped step executes end to end.
reinforce_state = restore_reinforce_state(pretrain_cpu_state)
reinforce_state, metrics = full_reinforce_step(
    reinforce_state,
    # Keep exactly `reinforce_global_batch_size` rows: the reshape inside the
    # step splits that many rows across devices.
    jax.tree.map(lambda x: x[:reinforce_global_batch_size], reinforce_data))

In [None]:
# Start Stage 2 from the Stage 1 snapshot, discarding the update applied by
# the sanity check above.
reinforce_state = restore_reinforce_state(pretrain_cpu_state)
# Metric histories, created outside the training-loop cell: that cell may be
# rerun to continue training and appends to them.
reinforce_metrics = {}
reinforce_metrics_list_form = defaultdict(list)

In [None]:
# Stage 2 Self-Improvement Train Loop
# This cell can be rerun to continue training for more iterations: the
# training state and the metric histories live outside of it.

# Number of REINFORCE steps to perform
num_reinforce_sgd_steps = 2048  # @param {type: "number"}
assert num_reinforce_sgd_steps % reinforce_num_minibatches == 0

env_steps_per_batch = reinforce_global_batch_size


for i in range(num_reinforce_sgd_steps // reinforce_num_minibatches):
  # Collect fresh on-policy data with the current parameters.
  # gamma=0.9 is the default that we use in our work.
  # gamma=1. does not work, as expected.
  # gamma=0. works, but we know this is not a good idea because it's almost
  # equivalent to single-step RL.
  rollout_params = get_from_first_device(
      reinforce_state, as_numpy=True).params
  reinforce_data, data_stats = generate_timer_reinforce_dataset(
      env, rollout_params, num_steps=env_steps_per_batch, gamma=0.9)

  # One pass of minibatch updates over the collected batch.
  reinforce_state, metrics = full_reinforce_step(
      reinforce_state,
      jax.tree.map(lambda x: x[:reinforce_global_batch_size], reinforce_data))

  # Append this iteration's statistics to the running histories.
  ep_returns = data_stats['return']
  ep_lens = data_stats['len']
  iteration_stats = {
      'reinforce_loss': metrics['reinforce_loss'].item(),
      'success_rate': np.mean(data_stats['success']),
      'return_mean': np.mean(ep_returns),
      'return_std': np.std(ep_returns),
      'len_mean': np.mean(ep_lens),
      'len_std': np.std(ep_lens),
      'max_len': np.max(ep_lens),
      'min_len': np.min(ep_lens),
  }
  for stat_name, stat_value in iteration_stats.items():
    reinforce_metrics_list_form[stat_name].append(stat_value)
  reinforce_metrics = {
      name: np.array(history)
      for name, history in reinforce_metrics_list_form.items()
  }

  # Redraw the dashboard in place.
  clear_output(wait=True)

  fig, axs = plt.subplots(1, 4, figsize=(20,5))

  # x-axis: cumulative number of REINFORCE steps.
  X = np.arange(len(reinforce_metrics['reinforce_loss'])) * reinforce_num_minibatches

  # Single-curve panels.
  for ax, metric_name, panel_title in (
      (axs[0], 'reinforce_loss', 'REINFORCE Loss'),
      (axs[1], 'success_rate', 'Success Rate'),
  ):
    ax.plot(X, reinforce_metrics[metric_name])
    ax.set_title(panel_title)
    ax.set_xlabel('REINFORCE Steps')

  # Mean curve with a +/- one std band.
  for ax, mean_name, std_name, panel_title, band_label in (
      (axs[2], 'return_mean', 'return_std',
       'Return Mean and Std', 'Std deviation'),
      (axs[3], 'len_mean', 'len_std',
       'Episode Length Mean and Std', 'Std'),
  ):
    mean_curve = reinforce_metrics[mean_name]
    std_curve = reinforce_metrics[std_name]
    ax.plot(X, mean_curve, label='Mean')
    ax.fill_between(X,
                    mean_curve - std_curve,
                    mean_curve + std_curve,
                    color='b', alpha=.1, label=band_label)
    ax.set_title(panel_title)
    ax.set_xlabel('REINFORCE Steps')
    ax.legend()

  plt.show()


# Visualize policies after the Stage 2 Self-Improvement process

In [None]:
# This duplicates the Stage 1 policy on purpose: the visualization helpers
# expect an `(obs, rng)` policy, so the Stage 2 parameters are captured here.
reinforce_cpu_params = get_from_first_device(
    reinforce_state, as_numpy=True).params

def _timer_eval_policy(obs, rng):
  """Samples one action from the Stage 2 policy for a single observation.

  Args:
    obs: Mapping with 'cur_pos', 'cur_vel' and 'goal_pos' entries.
    rng: PRNG key used to sample the action.

  Returns:
    A tuple `(action, extras)`. `action` is the unnormalized action with the
    batch axis removed. `extras` holds the predicted steps-to-go
    ('pred_dist_to_succ') and the parameters of its predictive distribution
    ('pred_dist_to_succ_dist_params'); both still carry the batch axis.
  """
  # The network works on batches, so add a leading axis of size 1.
  batched_obs = normalize_obs(jax.tree.map(lambda leaf: leaf[None], obs))
  net_input = jnp.concatenate(
      [batched_obs[name] for name in ('cur_pos', 'cur_vel', 'goal_pos')],
      axis=-1)
  pred = timer_networks.network.apply(reinforce_cpu_params, net_input)
  steps_to_go_params = pred.dist_to_succ_dist_params

  normalized_act = timer_networks.sample_act(pred.act_dist_params, rng)
  extras = {
      'pred_dist_to_succ': distance_converter.network_format_to_distance(
          steps_to_go_params.logits),
      'pred_dist_to_succ_dist_params': steps_to_go_params,
  }
  return unnormalize_action(normalized_act)[0], extras

# Compile the Stage 2 policy for the CPU backend.
timer_eval_policy = jax.jit(_timer_eval_policy, backend='cpu')

# Roll out the Stage 2 policy several times and play all episodes back to back
# as a single video.
num_self_improvement_trajs_vis = 5
self_improvement_trajs = []
for i in range(num_self_improvement_trajs_vis):
  print(f'\nGenerating traj {i + 1} of {num_self_improvement_trajs_vis}...')
  self_improvement_trajs += get_trajectory_visualization(
      timer_eval_policy, 'Stage 2 Self-Improvement Policy')

print('Creating video...')
clip = ImageSequenceClip(self_improvement_trajs, fps=10)
display_in_notebook(clip, fps=10, width=3 * 512, maxduration=1_000_000)


# Generating Paper Figures and Website Videos

In [None]:
# Number of episodes collected for each row of the figures built below
# (dataset, Stage 1 policy, Stage 2 policy).
num_figure_episodes = 10

In [None]:
# Collect untitled Stage 1 episodes, keeping the joined frames as well as the
# separate environment and plot frames of every episode.
sft_full_imgs, sft_env_imgs, sft_plot_imgs = [], [], []
for _ in range(num_figure_episodes):
  full_img, (env_img, plot_img) = get_trajectory_visualization(
      policy, '', return_sub_images=True)
  sft_full_imgs.append(full_img)
  sft_env_imgs.append(env_img)
  sft_plot_imgs.append(plot_img)

In [None]:
# Play the environment frames of all Stage 1 episodes back to back.
eps = [frame for ep in sft_env_imgs for frame in ep]
clip = ImageSequenceClip(eps, fps=10)
display_in_notebook(clip, fps=10, width=512, maxduration=1_000_000)

In [None]:
# Concatenate all the env images horizontally
figure_sft_env_imgs = []
# sft_env_imgs_to_use = [sft_env_imgs[i] for i in [0, 1, 2, 4, 5, 7, 8]]
sft_env_imgs_to_use = sft_env_imgs

for i, eps in enumerate(sft_env_imgs_to_use):
  img = eps[-1]
  if i == 0:
    img = img[:, :-30]
  elif i == len(sft_env_imgs_to_use) - 1:
    img = img[:, 30:]
  else:
    img = img[:, 30:-30]
  figure_sft_env_imgs.append(img)

figure_sft_env_imgs = np.concatenate(figure_sft_env_imgs, axis=1)
print(figure_sft_env_imgs.shape)

# Save image at full resolution
plt.imsave('ten_tight_pointmass_sft_env_imgs.png', figure_sft_env_imgs)

In [None]:
# Show the shape of the last cropped frame left over from the loop above.
img.shape

In [None]:
# Collect untitled Stage 2 episodes, keeping the joined frames as well as the
# separate environment and plot frames of every episode.
(self_improvement_full_imgs,
 self_improvement_env_imgs,
 self_improvement_plot_imgs) = [], [], []
for _ in range(num_figure_episodes):
  full_imgs, (env_img, plot_img) = get_trajectory_visualization(
      timer_eval_policy, '', return_sub_images=True)
  self_improvement_full_imgs.append(full_imgs)
  self_improvement_env_imgs.append(env_img)
  self_improvement_plot_imgs.append(plot_img)

In [None]:
# Play the environment frames of all Stage 2 episodes back to back.
eps = [frame for ep in self_improvement_env_imgs for frame in ep]
clip = ImageSequenceClip(eps, fps=10)
display_in_notebook(clip, fps=10, width=512, maxduration=1_000_000)

In [None]:
# Concatenate all the env images horizontally
figure_self_improvement_env_imgs = []
# self_improvement_env_imgs_to_use = [self_improvement_env_imgs[i] for i in [0, 1, 3, 5, 6, 7, 8]]
self_improvement_env_imgs_to_use = self_improvement_env_imgs

for i, eps in enumerate(self_improvement_env_imgs_to_use):
  img = eps[-1]
  if i == 0:
    img = img[:, :-30]
  elif i == len(self_improvement_env_imgs_to_use) - 1:
    img = img[:, 30:]
  else:
    img = img[:, 30:-30]
  figure_self_improvement_env_imgs.append(img)

figure_self_improvement_env_imgs = np.concatenate(figure_self_improvement_env_imgs, axis=1)
print(figure_self_improvement_env_imgs.shape)

# Save image at full resolution
plt.imsave('ten_tight_pointmass_self_improvement_env_imgs.png', figure_self_improvement_env_imgs)

In [None]:
# Render a random sample of dataset episodes, one frame per recorded position.
data_env_imgs = []
inds = np.random.choice(len(episodes), size=num_figure_episodes, replace=False)

for i in inds:
  episode_obs = episodes[i].observation
  points = episode_obs['cur_pos']
  # The goal is read from the first step of the episode.
  goal_pos = episode_obs['goal_pos'][0]
  # Frame t shows the path travelled up to and including step t.
  imgs = [
      env.render(title=f'', points=points[:t+1], goal_pos=goal_pos)
      for t in range(points.shape[0])
  ]
  data_env_imgs.append(imgs)

In [None]:
# Play the rendered dataset episodes back to back.
eps = [frame for ep in data_env_imgs for frame in ep]
clip = ImageSequenceClip(eps, fps=10)
display_in_notebook(clip, fps=10, width=512, maxduration=1_000_000)

In [None]:
# Concatenate all the env images horizontally
figure_data_env_imgs = []
# data_env_imgs_to_use = [data_env_imgs[i] for i in [0, 1, 5, 6, 7, 8, 9]]
data_env_imgs_to_use = data_env_imgs

for i, eps in enumerate(data_env_imgs_to_use):
  img = eps[-1]
  if i == 0:
    img = img[:, :-30]
  elif i == len(data_env_imgs_to_use) - 1:
    img = img[:, 30:]
  else:
    img = img[:, 30:-30]
  figure_data_env_imgs.append(img)

figure_data_env_imgs = np.concatenate(figure_data_env_imgs, axis=1)
print(figure_data_env_imgs.shape)

# Save image at full resolution
plt.imsave('ten_tight_pointmass_data_env_imgs.png', figure_data_env_imgs)

In [None]:
# Stack the images for data, sft, and self-improvement on top of each other into a single image
buffer = 255 * np.ones((300, figure_data_env_imgs.shape[1], 3), dtype=figure_data_env_imgs.dtype)
imgs_to_concat = [buffer, figure_data_env_imgs, buffer, figure_sft_env_imgs, buffer, figure_self_improvement_env_imgs]
figure_env_imgs = np.concatenate(imgs_to_concat, axis=0)
print(figure_env_imgs.shape)

# Save image at full resolution
plt.imsave('ten_tight_pointmass_figure_env_imgs.png', figure_env_imgs)