# Waymo Open Dataset Motion Tutorial

- Website: https://waymo.com/open
- GitHub: https://github.com/waymo-research/waymo-open-dataset

This tutorial demonstrates:
- How to decode and interpret the data.
- How to train a simple extrapolation model with TensorFlow.

Visit the [Waymo Open Dataset Website](https://waymo.com/open) to download the full dataset.

To use, open this notebook in [Colab](https://colab.research.google.com).

Uncheck the box "Reset all runtimes before running" if you run this colab directly from the remote kernel. Alternatively, you can make a copy before trying to run it by following "File > Save copy in Drive ...".

In [1]:
import math
import os
import uuid

from matplotlib import cm
import matplotlib.animation as animation

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
import itertools
import tensorflow.compat.v1 as tf

tf.enable_eager_execution()

# Visualize TF Example sample

In [48]:
# Roadgraph features. Both groups share the same per-point fields and differ
# only in the number of points: 20000 for samples, 2000 for segments.
roadgraph_features = {
    '%s/%s' % (group, field):
        tf.io.FixedLenFeature([num_points, depth], dtype, default_value=None)
    for group, num_points in (
        ('roadgraph_samples', 20000),
        ('roadgraph_segments', 2000),
    )
    for field, depth, dtype in (
        ('dir', 3, tf.float32),
        ('id', 1, tf.int64),
        ('type', 1, tf.int64),
        ('valid', 1, tf.int64),
        ('xyz', 3, tf.float32),
    )
}

# Features for the autonomous vehicle.
# 'sdc/id' is a single value; every other field exists once per time window
# with one entry per step: 1 current step, 80 future steps and 10 past steps.
sdc_features = {
    'sdc/id': tf.io.FixedLenFeature([1], tf.float32, default_value=None),
}
sdc_features.update({
    'sdc/%s/%s' % (period, field):
        tf.io.FixedLenFeature([num_steps], dtype, default_value=None)
    for period, num_steps in (('current', 1), ('future', 80), ('past', 10))
    for field, dtype in (
        ('bbox_yaw', tf.float32),
        ('height', tf.float32),
        ('length', tf.float32),
        ('timestamp_micros', tf.int64),
        ('valid', tf.int64),
        ('vel_yaw', tf.float32),
        ('velocity_x', tf.float32),
        ('velocity_y', tf.float32),
        ('width', tf.float32),
        ('x', tf.float32),
        ('y', tf.float32),
        ('z', tf.float32),
    )
})

# Features of other agents.
# Every entry has a leading dimension of 128 (one slot per agent). The
# per-window fields are [128, num_steps] with 1 current step, 80 future steps
# and 10 past steps.
state_features = {
    'state/id':
        tf.io.FixedLenFeature([128], tf.float32, default_value=None),
    'state/type':
        tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
}
state_features.update({
    'state/%s/%s' % (period, field):
        tf.io.FixedLenFeature([128, num_steps], dtype, default_value=None)
    for period, num_steps in (('current', 1), ('future', 80), ('past', 10))
    for field, dtype in (
        ('bbox_yaw', tf.float32),
        ('height', tf.float32),
        ('length', tf.float32),
        ('timestamp_micros', tf.int64),
        ('valid', tf.int64),
        ('vel_yaw', tf.float32),
        ('velocity_x', tf.float32),
        ('velocity_y', tf.float32),
        ('width', tf.float32),
        ('x', tf.float32),
        ('y', tf.float32),
        ('z', tf.float32),
    )
})

# Traffic light features, shaped [num_steps, 16] for the current step (1) and
# the past steps (10).
traffic_light_features = {
    'traffic_light_state/%s/%s' % (period, field):
        tf.io.FixedLenFeature([num_steps, 16], dtype, default_value=None)
    for period, num_steps in (('current', 1), ('past', 10))
    for field, dtype in (
        ('state', tf.int64),
        ('valid', tf.int64),
        ('x', tf.float32),
        ('y', tf.float32),
        ('z', tf.float32),
    )
}

# Single parsing spec that combines all feature groups, in the order
# roadgraph, sdc, state, traffic light.
features_description = dict(
    itertools.chain(
        roadgraph_features.items(),
        sdc_features.items(),
        state_features.items(),
        traffic_light_features.items(),
    ))

In [50]:
# A tfrecord containing tf.Example protos as downloaded from the Waymo dataset
# webpage.

# Replace this path with your own tfrecords.
#FILENAME = '/content/waymo-od-motion/tutorial/.../tfexample.tfrecord'
FILENAME = '/tmp/test_tfexample.tfrecord'


# Read the file as uncompressed records (compression_type='').
dataset = tf.data.TFRecordDataset(FILENAME, compression_type='')
# Only the first serialized record of the file is used for visualization.
data = next(dataset.as_numpy_iterator())
# Decode it into a dict of tensors keyed by the names in
# `features_description`.
parsed = tf.io.parse_single_example(data, features_description)

In [51]:
def create_figure_and_axes(size_pixels):
  """Creates a fresh square figure with a single white-background axes.

  Args:
    size_pixels: side length of the figure in pixels (rendered at 100 dpi).

  Returns:
    A (fig, ax) tuple. The figure is registered under a random UUID so each
    call yields a new figure.
  """
  # Pixel size is expressed through inches at a fixed dpi.
  dpi = 100
  side_inches = size_pixels / dpi

  fig, ax = plt.subplots(1, 1, num=uuid.uuid4())
  fig.set_size_inches([side_inches, side_inches])
  fig.set_dpi(dpi)
  fig.set_facecolor('white')
  fig.set_tight_layout(True)

  ax.set_facecolor('white')
  # Black labels and ticks on both axes.
  for axis_name, axis in (('x', ax.xaxis), ('y', ax.yaxis)):
    axis.label.set_color('black')
    ax.tick_params(axis=axis_name, colors='black')
  ax.grid(False)
  return fig, ax


def fig_canvas_image(fig):
  """Draws `fig` and returns its canvas as a [H, W, 3] uint8 np.array image.

  Args:
    fig: the matplotlib figure to rasterize. Its subplot margins are adjusted
      in place before drawing.

  Returns:
    A [H, W, 3] uint8 np.array holding the RGB pixels of the drawn canvas.
  """
  # Just enough margin in the figure to display xticks and yticks.
  fig.subplots_adjust(
      left=0.08, bottom=0.08, right=0.98, top=0.98, wspace=0.0, hspace=0.0)
  canvas = fig.canvas
  canvas.draw()
  if hasattr(canvas, 'tostring_rgb'):
    # Legacy Matplotlib API: a flat RGB byte string that must be reshaped.
    data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    return data.reshape(canvas.get_width_height()[::-1] + (3,))
  # `tostring_rgb` is no longer available in recent Matplotlib releases. Read
  # the RGBA buffer instead and drop the alpha channel.
  rgba = np.asarray(canvas.buffer_rgba())
  return np.ascontiguousarray(rgba[..., :3])


def get_colormap(num_agents):
  """Compute a color map array of shape [num_agents, 4].

  Args:
    num_agents: number of distinct colors to sample from the 'jet' colormap.

  Returns:
    An array with one RGBA row per agent. The rows are shuffled in place with
    `np.random.shuffle`, so the agent-to-color assignment is random.
  """
  # `matplotlib.cm.get_cmap` is deprecated and absent from recent Matplotlib
  # releases. `plt.get_cmap` accepts the same (name, lut) arguments.
  cmap = plt.get_cmap('jet', num_agents)
  colors = cmap(range(num_agents))
  np.random.shuffle(colors)
  return colors


def get_viewport(all_states, all_states_mask):
  """Computes the center and extent of all valid agent positions.

  Args:
    all_states: states of agents as an array of shape [num_agents, num_steps,
      2], with x at index 0 and y at index 1 of the last axis.
    all_states_mask: binary mask of shape [num_agents, num_steps] for
      `all_states`.

  Returns:
    center_y: float. Midpoint between the smallest and largest valid y.
    center_x: float. Midpoint between the smallest and largest valid x.
    width: float. The larger of the x range and the y range.
  """
  # Boolean indexing keeps only the valid (x, y) pairs.
  xy = all_states[all_states_mask]
  xs, ys = xy[..., 0], xy[..., 1]

  y_hi, y_lo = np.max(ys), np.min(ys)
  x_hi, x_lo = np.max(xs), np.min(xs)

  center_y = (y_hi + y_lo) / 2
  center_x = (x_hi + x_lo) / 2

  # Side length of a square window that covers both ranges.
  width = max(np.ptp(ys), np.ptp(xs))

  return center_y, center_x, width


def visualize_one_step(states,
                       mask,
                       roadgraph,
                       title,
                       center_y,
                       center_x,
                       width,
                       color_map,
                       size_pixels=1000):
  """Renders the roadgraph and the valid agents of one time step.

  Args:
    states: per-agent positions, indexed as [agent, 0] for x and [agent, 1]
      for y.
    mask: boolean array selecting the agents to draw.
    roadgraph: roadgraph points; only the first two columns (x, y) are drawn.
    title: title placed on the axes.
    center_y: y coordinate of the viewport center.
    center_x: x coordinate of the viewport center.
    width: requested viewport side length.
    color_map: per-agent colors, indexed with `mask`.
    size_pixels: side length in pixels of the output image.

  Returns:
    The rendered figure as returned by `fig_canvas_image`.
  """
  fig, ax = create_figure_and_axes(size_pixels=size_pixels)

  # Roadgraph points as small black dots.
  ax.plot(roadgraph[:, 0], roadgraph[:, 1], 'k.', alpha=1, ms=2)

  # Agent positions at this step, one color per agent.
  ax.scatter(
      states[:, 0][mask],
      states[:, 1][mask],
      marker='o',
      linewidths=3,
      color=color_map[mask],
  )

  ax.set_title(title)

  # Square viewport centered on (center_x, center_y). It is at least 10 units
  # on a side and otherwise exactly `width` wide.
  half_size = max(10, width * 1.0) / 2
  ax.axis([
      center_x - half_size, center_x + half_size, center_y - half_size,
      center_y + half_size
  ])
  ax.set_aspect('equal')

  # Rasterize, then close the figure.
  image = fig_canvas_image(fig)
  plt.close(fig)
  return image


def visualize_all_agents_smooth(
    decoded_example,
    size_pixels=1000,
):
  """Renders every time step of a decoded example as one image per step.

  Args:
    decoded_example: Dictionary containing agent info about all modeled agents.
    size_pixels: The size in pixels of the output image.

  Returns:
    A list of images as produced by `visualize_one_step`, ordered past steps
    first, then the current step, then the future steps.
  """

  def load_xy_and_mask(x_key, y_key, valid_key):
    """Stacks x/y into a [..., 2] array and builds the boolean valid mask."""
    xy = tf.stack([decoded_example[x_key], decoded_example[y_key]], -1).numpy()
    valid = decoded_example[valid_key].numpy() > 0.0
    return xy, valid

  # [num_agents, num_past_steps, 2] float32 and matching mask.
  past_states, past_states_mask = load_xy_and_mask(
      'state/past/x', 'state/past/y', 'state/past/valid')
  # [num_agents, 1, 2] float32 and matching mask.
  current_states, current_states_mask = load_xy_and_mask(
      'state/current/x', 'state/current/y', 'state/current/valid')
  # [num_agents, num_future_steps, 2] float32 and matching mask.
  future_states, future_states_mask = load_xy_and_mask(
      'state/future/x', 'state/future/y', 'state/future/valid')

  # [num_points, 3] float32.
  roadgraph_xyz = decoded_example['roadgraph_samples/xyz'].numpy()

  num_agents, num_past_steps, _ = past_states.shape
  num_future_steps = future_states.shape[1]

  color_map = get_colormap(num_agents)

  # Compute the viewport over all steps so every frame shares the same view.
  # Concatenated along the step axis: past, then current, then future.
  center_y, center_x, width = get_viewport(
      np.concatenate([past_states, current_states, future_states], 1),
      np.concatenate(
          [past_states_mask, current_states_mask, future_states_mask], 1))

  def render(step_states, step_mask, title):
    """Draws one frame with the shared roadgraph, viewport and colors."""
    return visualize_one_step(step_states, step_mask, roadgraph_xyz, title,
                              center_y, center_x, width, color_map,
                              size_pixels)

  images = []

  # Past steps, titled with a countdown from num_past_steps to 1.
  past_frames = zip(
      np.split(past_states, num_past_steps, 1),
      np.split(past_states_mask, num_past_steps, 1))
  for i, (s, m) in enumerate(past_frames):
    images.append(
        render(s[:, 0], m[:, 0], 'past: %d' % (num_past_steps - i)))

  # The single current step.
  images.append(
      render(current_states[:, 0], current_states_mask[:, 0], 'current'))

  # Future steps, titled 1..num_future_steps.
  future_frames = zip(
      np.split(future_states, num_future_steps, 1),
      np.split(future_states_mask, num_future_steps, 1))
  for i, (s, m) in enumerate(future_frames):
    images.append(render(s[:, 0], m[:, 0], 'future: %d' % (i + 1)))

  return images


# Render one image per time step (past, current and future) for the parsed
# example.
images = visualize_all_agents_smooth(parsed)

In [56]:
def create_animation(images):
  """Creates a Matplotlib animation of the given images.

  Args:
    images: A list of numpy arrays representing the images.

  Returns:
    A matplotlib.animation.Animation. Only the first `len(images) // 2`
    images are used as frames, shown at 100 ms intervals.

  Usage:
    anim = create_animation(images)
    anim.save('/tmp/animation.avi')
    HTML(anim.to_html5_video())
  """

  # Interactive mode is switched off while the figure is created and switched
  # back on afterwards.
  plt.ioff()
  fig, ax = plt.subplots()
  dpi = 100
  size_inches = 1000 / dpi
  fig.set_size_inches([size_inches, size_inches])
  plt.ion()

  def animate_func(i):
    """Draws frame `i` without ticks or grid lines."""
    ax.imshow(images[i])
    ax.set_xticks([])
    ax.set_yticks([])
    # Use a real boolean: the string 'off' is truthy and would turn the grid
    # on in current Matplotlib versions.
    ax.grid(False)

  # NOTE(review): only half of the images become frames. This looks
  # intentional (shorter video), but confirm before changing it.
  anim = animation.FuncAnimation(
      fig, animate_func, frames=len(images) // 2, interval=100)
  plt.close(fig)
  return anim


# Build the animation and embed it in the notebook output as an HTML5 video.
anim = create_animation(images)
HTML(anim.to_html5_video())

# Train a simple extrapolator with TF

In [None]:
def _parse(value):
  """Decodes one serialized tf.Example into model inputs.

  Args:
    value: a serialized example, parsed with `features_description`.

  Returns:
    A dict with:
      'states': past and current (x, y) positions concatenated along axis 1.
      'future_states': future (x, y) positions.
  """
  example = tf.io.parse_single_example(value, features_description)

  def stack_xy(x_key, y_key):
    """Stacks the x and y features along a new trailing axis."""
    return tf.stack([example[x_key], example[y_key]], -1)

  # Observed trajectory: past steps followed by the current step.
  observed = tf.concat([
      stack_xy('state/past/x', 'state/past/y'),
      stack_xy('state/current/x', 'state/current/y'),
  ], 1)

  return {
      'states': observed,
      'future_states': stack_xy('state/future/x', 'state/future/y'),
  }


class ExtrapolationModel(tf.keras.Model):
  """Linear model that regresses future (x, y) positions from observed ones.

  A single Dense layer maps each row of `num_states_steps * 2` flattened
  input values to `num_future_steps * 2` outputs. The mean squared error
  against `inputs['future_states']` is registered with `add_loss` on every
  call.
  """

  def __init__(self, num_agents, num_states_steps, num_future_steps):
    """Initializes the model.

    Args:
      num_agents: number of agents per example, used to reshape predictions.
      num_states_steps: number of observed steps per agent in
        `inputs['states']`.
      num_future_steps: number of future steps predicted per agent.
    """
    super(ExtrapolationModel, self).__init__()
    self._num_agents = num_agents
    self._num_states_steps = num_states_steps
    self._num_future_steps = num_future_steps
    self.regressor = tf.keras.layers.Dense(num_future_steps * 2)

  def call(self, inputs):
    """Predicts future positions and registers the training loss.

    Args:
      inputs: dict with 'states' and 'future_states' tensors.

    Returns:
      Predictions reshaped to [-1, num_agents, num_future_steps, 2].
    """
    states = inputs['states']
    # Flatten to rows of `num_states_steps * 2` values.
    # Assumes `states` is [batch, num_agents, num_states_steps, 2], so each
    # row is one agent's observed trajectory -- TODO confirm against caller.
    states = tf.reshape(states, (-1, self._num_states_steps * 2))
    pred = self.regressor(states)
    pred = tf.reshape(pred, (-1, self._num_agents, self._num_future_steps, 2))

    # Keras losses take (y_true, y_pred).
    # NOTE(review): no validity mask is applied, so every entry of
    # `future_states` contributes to the loss.
    loss = tf.keras.losses.MeanSquaredError()(inputs['future_states'], pred)
    self.add_loss(loss)
    return pred


# Stream serialized examples from the tfrecord and decode them into the
# {'states', 'future_states'} dicts produced by `_parse`.
dataset = tf.data.TFRecordDataset(FILENAME)
dataset = dataset.map(_parse)
batch = dataset.batch(1)

# The model sizes must match `features_description`: 128 agents, 10 past + 1
# current observed steps, and 80 future steps. Other values make the reshapes
# and the loss inside the model fail.
model = ExtrapolationModel(128, 11, 80)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

model.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
# `batch` is already a batched tf.data.Dataset, so `batch_size` must not be
# passed to `fit`.
model.fit(batch, epochs=2)

# TODO(chaiy): add eval once metrics are ready.