# MPII Human Pose with TensorFlow Datasets

This notebook explores the [MPII Human Pose](http://human-pose.mpi-inf.mpg.de) dataset using `tensorflow_datasets`.

In [0]:
!pip install tensorflow_datasets matplotlib scipy
# scipy is required for dataset preprocessing

In [0]:
%matplotlib inline
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_datasets.human_pose.mpii import skeleton
import matplotlib.pyplot as plt
import matplotlib.patches as patches

tf.compat.v1.enable_eager_execution()

## Loading the dataset

`tensorflow_datasets` will automatically download the relevant files, serialize them to disk and gather statistics. For `mpii_human_pose`, this is a relatively long process, though it is a one-time cost. Once it completes successfully, the serialized files stay on disk and future runs load them immediately.

Progress may appear to hang displaying `0 examples [00:00, ? examples/s]`. It may take several minutes.

Due to the space requirements (the images are >12 GB compressed, and the serialized form takes up about as much space again), you may have problems running this in a standard Colab notebook. Check out [this post](https://research.google.com/colaboratory/local-runtimes.html) for instructions on how to run locally.

In [0]:
datasets, info = tfds.load('mpii_human_pose', with_info=True)

The `info` object contains helpful information including (but not limited to) an overview of features.

In [0]:
print(info.features)

Next we'll set up some basic visualization functions. Feel free to skip the details and refer back when necessary.

Box color codes:
* Blue: rough bounding boxes based on center and scale
* Green: head boxes with corresponding joint annotations
* Red (rare): head boxes without corresponding joint annotations

In [0]:
parent_indices = skeleton.s16.parent_indices
num_joints = skeleton.s16.num_joints
joint_names = skeleton.s16.joints
colors = [
    'red' if n.startswith('r_') else 'blue' if n.startswith('l_') else 'black'
    for n in joint_names]


def setup_fig(n=3):
  fig, axes = plt.subplots(1, n, figsize=(n*6, 6),
                           subplot_kw=dict(xticks=[], yticks=[]))
  if n == 1:
    axes = axes,
  return fig, axes


def label(key, value):
  return 'unknown' if value == -1 else info.features[key].int2str(value)
  

def vis_skeleton(single_joints, single_visible, ax):
  for child_index in range(num_joints):
    parent_index = parent_indices[child_index]
    if parent_index is None:
      # root joint
      continue
    if single_visible[child_index] and single_visible[parent_index]:
      child = single_joints[child_index]
      parent = single_joints[parent_index]
      ax.plot(
          [child[0], parent[0]], [child[1], parent[1]],
          color=colors[child_index])

      
def add_box(ymin, xmin, ymax, xmax, ax, **kwargs):
  width = xmax - xmin
  height = ymax - ymin
  rect = patches.Rectangle(
    (xmin, ymin), width, height,
    linewidth=1, facecolor='none', **kwargs)
  ax.add_patch(rect)


def example_vis(
    example, ax, show_heads=True, show_joints=True, show_scale_boxes=True,
    multiple_targets=True):
  
  if not multiple_targets:
    # add a leading dimension to simulate multiple targets
    for k in 'head_box',:
      if k in example:
        example[k] = tf.expand_dims(example[k], axis=0)
    targets = example['targets']
    for k in 'joints', 'visible', 'center', 'scale':
      targets[k] = tf.expand_dims(targets[k], axis=0)
  
  # print some metadata
  if 'category' in example:
    print('category: %s' % label('category', example['category']))
    print('activity: %s' % label('activity', example['activity']))
  if 'youtube_id' in example:
    # string tensors come back as bytes; decode before comparing or formatting
    youtube_id = example['youtube_id'].numpy().decode('utf-8')
    if youtube_id == 'UNK':
      youtube_link = youtube_id
    else:
      youtube_link = 'https://www.youtube.com/watch?v=%s' % youtube_id
    print('youtube_link: %s' % youtube_link)
  if 'frame_sec' in example:
    print('frame_sec: %d' % example['frame_sec'])

  image = example['image'].numpy()
  targets = example['targets']
  joints = targets['joints']
  visible = targets['visible']  # not all joints are visible
  
  joints = joints.numpy()
  visible = visible.numpy()
  ax.imshow(image)
  
  if show_joints:
    visible_joints = joints[visible]
    ax.scatter(visible_joints[..., 0], visible_joints[..., 1])
    for single_joints, single_vis in zip(joints, visible):
      vis_skeleton(single_joints, single_vis, ax)
  
  if show_heads:
    head_boxes = example['head_box'].numpy()
    num_skeletons = joints.shape[0]
    # most examples have a head box for each skeleton, but some have more.
    # heads with missing bodies are always at the end of the array
    for box in head_boxes[:num_skeletons]:
      add_box(*box, ax=ax, edgecolor='green')  # these heads have bodies

    for box in head_boxes[num_skeletons:]:
      add_box(*box, ax=ax, edgecolor='red')   # these don't - there may not be any
  
  if show_scale_boxes:
    centers = targets['center'].numpy()
    scales = targets['scale'].numpy()
    ax.scatter(centers[:, 0], centers[:, 1], marker='+')
  
    # scale is defined rather arbitrarily, albeit consistently
    for center, scale in zip(centers, scales):
      xmin, ymin = center - 100 * scale
      xmax, ymax = center + 100 * scale
      add_box(ymin, xmin, ymax, xmax, ax=ax, edgecolor='blue')
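
The blue boxes drawn above follow directly from the center/scale annotations. As a minimal sketch, assuming `center` is given as `(x, y)` in pixels and using the same factor of 100 as the visualization code, the box corners can be computed as follows. The function name `scale_box` and the `pixels_per_unit` parameter are made up for this illustration.

```python
def scale_box(center, scale, pixels_per_unit=100):
  """Return (ymin, xmin, ymax, xmax) for a center/scale annotation.

  `center` is (x, y) in pixels. `pixels_per_unit` mirrors the factor of 100
  used in `example_vis`; the parameter name is hypothetical.
  """
  x, y = center
  half_side = pixels_per_unit * scale
  return (y - half_side, x - half_side, y + half_side, x + half_side)


# A person centered at x=100, y=200 with scale 1.5 gets a 300x300 pixel box.
box = scale_box(center=(100.0, 200.0), scale=1.5)
print(box)
```

The box may extend beyond the image bounds, as it does here with a negative `xmin`.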


Let's have a quick look at what the basic dataset contains. Note the training dataset is shuffled by default during `load` so you can re-run the following cell for random results each time.

In [0]:
fig, axes = setup_fig()

for example, ax in zip(datasets['train'], axes):
  example_vis(example, ax)

The test dataset also contains center/scale annotations, but no joint/visibility/head boxes.

In [0]:
fig, axes = setup_fig()

for example, ax in zip(datasets['test'], axes):
  example_vis(example, ax, show_heads=False, show_joints=False)

## Single Target Inference

Many models focus on single-target human pose estimation. Rather than reprocess the dataset (which would take time and use up even more space), we can use [tf.data.Dataset.flat_map](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#flat_map). In this case, we'll only consider sufficiently separated individuals. We'll [filter](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#filter) out those examples without any separated individuals before applying the `flat_map`.

Alternatively, we could randomly sample the features associated with one skeleton per image using [tf.data.Dataset.map](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map). This would lead to smaller epochs, as each image would only be used once. During training this would be fine, but it would make our test set incomplete.
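
The difference in epoch size between the two approaches can be illustrated with a pure-Python analogy. This is only a sketch: the toy dictionaries and the names `flat_map_style` and `sample_style` are made up here and stand in for the `tf.data` operations rather than calling them.

```python
import random

# Toy stand-ins for dataset examples: each "image" lists its separated people.
examples = [
    {'image': 'a', 'people': ['a0', 'a1']},
    {'image': 'b', 'people': ['b0']},
    {'image': 'c', 'people': []},
]


def flat_map_style(examples):
  """Yield one element per person, like `filter` followed by `flat_map`."""
  for example in examples:
    for person in example['people']:
      yield example['image'], person


def sample_style(examples, rng):
  """Yield one randomly chosen person per image, like a sampling `map`."""
  for example in examples:
    if example['people']:
      yield example['image'], rng.choice(example['people'])


flat = list(flat_map_style(examples))
sampled = list(sample_style(examples, random.Random(0)))
print(len(flat), len(sampled))
```

The flattened version visits every person, while the sampled version visits each non-empty image once and skips some people in any given epoch.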

In [0]:
def crop_and_resize(example, target_width):
  image = example['image']
  targets = example['targets']

  joints = targets['joints']
  visible = targets['visible']
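  # joints whose coordinates are all -1 have no annotation: mark them not visible
  missing = tf.reduce_all(tf.equal(joints, -1), axis=-1)
  visible = tf.logical_and(visible, tf.logical_not(missing))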
  center = targets['center']
  scale = targets['scale']

  indices = example['separated_individuals']
  num_targets = tf.size(indices)
  joints = tf.gather(joints, indices)
  visible = tf.gather(visible, indices)
  center = tf.gather(center, indices)
  scale = tf.gather(scale, indices)

  radius = scale * 100
  center = tf.cast(center, tf.float32)

  orig_dims = tf.cast(tf.shape(image)[:2], tf.float32)
  expanded_radius = tf.expand_dims(radius, axis=-1)

  center_yx = tf.reverse(center, [-1])

  mins = center_yx - expanded_radius
  maxs = center_yx + expanded_radius

  boxes = tf.stack([mins, maxs], axis=-2)
  boxes = boxes / orig_dims
  boxes = tf.reshape(boxes, (-1, 4))

  box_ind = tf.zeros(shape=(num_targets,), dtype=tf.int32)
  image = tf.expand_dims(image, axis=0)

  target_shape = tf.constant(
      [target_width, target_width], dtype=tf.int32)

  image = tf.image.crop_and_resize(
    image, boxes, box_ind, target_shape)

  new_center = tf.tile([[target_width // 2]], (num_targets, 2))
  new_scale = tf.fill((num_targets,), target_width / 200)

  target_width_f = tf.cast(target_width, tf.float32)
  resize_factor = target_width_f / (2 * radius)

  # Don't forget to adjust joints!
  joints = tf.cast(joints, tf.float32)
  joints = joints - tf.expand_dims(center, axis=-2)
  joints = joints * tf.reshape(resize_factor, (-1, 1, 1))
  joints = joints + tf.expand_dims(tf.cast(new_center, tf.float32), axis=-2)

  return dict(
    image=image,
    targets=dict(
      joints=joints,
      visible=visible,
      scale=new_scale,
      center=new_center,
    ),
  )


def flat_map_fn(example):
  return tf.data.Dataset.from_tensor_slices(crop_and_resize(example, 128))


filtered_dataset = datasets['train'].filter(
    lambda x: tf.size(x['separated_individuals']) > 0)

single_dataset = filtered_dataset.flat_map(flat_map_fn)

fig, axes = setup_fig()

for example, ax in zip(single_dataset, axes):
  example['image'] /= 255
  example_vis(example, ax, multiple_targets=False, show_heads=False)
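
The geometry inside `crop_and_resize` may be easier to follow for a single person with concrete numbers. The following is a NumPy sketch of the same arithmetic, not the TensorFlow pipeline itself; the function name `crop_box_and_joints` and its argument names are made up for this illustration.

```python
import numpy as np


def crop_box_and_joints(center, scale, joints, image_hw, target_width):
  """NumPy sketch of the box and joint arithmetic for one person.

  center: (x, y) in pixels; joints: (num_joints, 2) as (x, y);
  image_hw: (height, width) of the original image.
  """
  center = np.asarray(center, dtype=np.float32)
  joints = np.asarray(joints, dtype=np.float32)
  radius = 100.0 * scale
  center_yx = center[::-1]
  # Box as [ymin, xmin, ymax, xmax], divided by (height, width) as above.
  box = np.concatenate([center_yx - radius, center_yx + radius])
  box = box / np.tile(np.asarray(image_hw, dtype=np.float32), 2)
  # Shift joints so the crop center is the origin, rescale, then re-center.
  resize_factor = target_width / (2.0 * radius)
  new_center = target_width // 2
  new_joints = (joints - center) * resize_factor + new_center
  return box, new_joints


box, new_joints = crop_box_and_joints(
    center=(300, 200), scale=1.0, joints=[[300, 200], [350, 100]],
    image_hw=(400, 600), target_width=128)
print(box)         # normalized crop box
print(new_joints)  # joint positions inside the 128x128 crop
```

A joint at the original center lands at the middle of the crop, and a joint on the box edge lands on the crop edge.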

## Data Augmentation

Finally, let's do some random color adjustments to augment our dataset and normalize the color values using [tf.data.Dataset.map](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map).

In [0]:
def preprocess_image(example, perturb=True):
  image = example['image']
  
  if perturb:
    # augment dataset by perturbing colors
    image = tf.image.random_brightness(image, max_delta=63. / 255.)
    image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
    image = tf.image.random_contrast(image, lower=0.2, upper=1.8)

  example['image'] = tf.image.per_image_standardization(image)
  return example


fig, axes = setup_fig()
  
single_dataset = single_dataset.map(
    preprocess_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
for example, ax in zip(single_dataset, axes):
  image = example['image']
  # renormalize for visualization only
  image = image - tf.reduce_min(image, axis=(0, 1))
  image = image / tf.reduce_max(image, axis=(0, 1))
  # write the renormalized image back so that `example_vis` displays it
  example['image'] = image
  example_vis(
      example, ax, multiple_targets=False, show_heads=False)
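
To make the normalization step concrete, here is a NumPy sketch of what `tf.image.per_image_standardization` computes (zero mean and unit variance, with the standard deviation floored at `1 / sqrt(num_elements)`), followed by the per-channel rescaling used for display. The helper names `standardize` and `to_display_range` are made up for this illustration.

```python
import numpy as np


def standardize(image):
  """NumPy sketch of per-image standardization."""
  image = np.asarray(image, dtype=np.float32)
  # Floor the stddev to avoid dividing by ~0 on uniform images.
  adjusted_std = max(image.std(), 1.0 / np.sqrt(image.size))
  return (image - image.mean()) / adjusted_std


def to_display_range(image):
  """Rescale each channel to [0, 1] so matplotlib can display it."""
  image = image - image.min(axis=(0, 1))
  return image / image.max(axis=(0, 1))


rng = np.random.default_rng(0)
image = rng.uniform(0, 255, size=(8, 8, 3))
standardized = standardize(image)
display = to_display_range(standardized)
print(standardized.mean(), standardized.std())
print(display.min(), display.max())
```

Standardized images contain negative values, which is why they are rescaled before being passed to `imshow`.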

Exercise: random data augmentation can improve network performance. Above we randomly modified image properties. Modify the `crop_and_resize` function above to randomly perturb the center/scale. If you're feeling brave, you can also try:
* flipping the image left-right. You'll also have to change the `x`-values of your joints and reorder indices (left hands look different to right hands!). Check out `tensorflow_datasets.human_pose.skeleton.Skeleton.flip_left_right_indices`.
* using a third-party image manipulation library like `cv2` along with `tf.py_function` to randomly _rotate_ your examples by a small angle. You'll also have to rotate your joints.
* targeting heatmaps rather than single points, as some models do. Transform `joints` into a heatmap by setting values on a grid equal to a Gaussian centered at each joint position.
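
As a starting point for the heatmap exercise, here is a minimal NumPy sketch. It assumes joints are given as `(x, y)` pixel coordinates within a square crop; the function name `joints_to_heatmaps` and the default `sigma` are illustrative choices, not values from this notebook.

```python
import numpy as np


def joints_to_heatmaps(joints, visible, size, sigma=2.0):
  """Return a (num_joints, size, size) array with one Gaussian per joint.

  joints: (num_joints, 2) as (x, y) in pixel coordinates of the crop.
  visible: (num_joints,) booleans; invisible joints get an all-zero map.
  """
  joints = np.asarray(joints, dtype=np.float32)
  # ys varies along rows, xs along columns.
  ys, xs = np.mgrid[0:size, 0:size].astype(np.float32)
  heatmaps = np.zeros((len(joints), size, size), dtype=np.float32)
  for i, ((x, y), vis) in enumerate(zip(joints, visible)):
    if vis:
      sq_dist = (xs - x) ** 2 + (ys - y) ** 2
      heatmaps[i] = np.exp(-sq_dist / (2 * sigma ** 2))
  return heatmaps


heatmaps = joints_to_heatmaps(
    joints=[[10, 20], [5, 5]], visible=[True, False], size=32)
peak = np.unravel_index(heatmaps[0].argmax(), heatmaps[0].shape)
print(peak)  # (row, col) of the peak for the first joint
```

To use this inside a `tf.data` pipeline you would need to rewrite it with TensorFlow ops or wrap it with `tf.py_function`.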
 

## Make Something Awesome!

Once you're happy with your input examples, use standard `tf.data.Dataset` tools, wire up a model, choose a loss function and away you go! Note `tfds.load` automatically applies `tf.data.Dataset.prefetch` to the datasets it returns. The cell below prefetches once more at the end of the pipeline so that whole batches are prepared in the background.

In [0]:
single_dataset = single_dataset.repeat().shuffle(1024).batch(32)
single_dataset = single_dataset.prefetch(tf.data.experimental.AUTOTUNE)

# model = make_model(...)
# model.fit(...)

In [0]:
# The complete pipeline for single-target inference for convenience
datasets = tfds.load('mpii_human_pose')
AUTOTUNE = tf.data.experimental.AUTOTUNE
train_dataset = datasets['train'].flat_map(
    flat_map_fn).map(preprocess_image, num_parallel_calls=AUTOTUNE)
train_dataset = train_dataset.repeat().shuffle(1024).batch(32).prefetch(
    AUTOTUNE)
    
test_dataset = datasets['test'].flat_map(flat_map_fn).map(
    lambda example: preprocess_image(example, perturb=False),
    num_parallel_calls=AUTOTUNE).batch(32)

In [0]:
plt.show()  # for users who download the `.py` file