Skip to content
Cannot retrieve contributors at this time
# coding=utf-8
# Copyright 2021 The Tensor2Tensor Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
"""SV2P: Stochastic Variational Video Prediction.
based on the following paper:
by Mohammad Babaeizadeh, Chelsea Finn, Dumitru Erhan,
Roy H. Campbell and Sergey Levine
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import common_video
from tensor2tensor.layers import discretization
from import base
from import base_vae
from tensor2tensor.utils import contrib
from tensor2tensor.utils import registry
import tensorflow.compat.v1 as tf
tfl = tf.layers
tfcl = contrib.layers()
class NextFrameSv2p(base.NextFrameBase, base_vae.NextFrameBaseVae):
"""Stochastic Variational Video Prediction From Basic Model!"""
def is_recurrent_model(self):
return True
def tinyify(self, array):
return common_video.tinyify(
array, self.hparams.tiny_mode, self.hparams.small_mode)
def bottom_part_tower(self, input_image, input_reward, action, latent,
lstm_state, lstm_size, conv_size, concat_latent=False):
"""The bottom part of predictive towers.
With the current (early) design, the main prediction tower and
the reward prediction tower share the same arcitecture. TF Scope can be
adjusted as required to either share or not share the weights between
the two towers.
input_image: the current image.
input_reward: the current reward.
action: the action taken by the agent.
latent: the latent vector.
lstm_state: the current internal states of conv lstms.
lstm_size: the size of lstms.
conv_size: the size of convolutions.
concat_latent: whether or not to concatenate the latent at every step.
- the output of the partial network.
- intermidate outputs for skip connections.
lstm_func = common_video.conv_lstm_2d
tile_and_concat = common_video.tile_and_concat
input_image = common_layers.make_even_size(input_image)
concat_input_image = tile_and_concat(
input_image, latent, concat_latent=concat_latent)
layer_id = 0
enc0 = tfl.conv2d(
conv_size[0], [5, 5],
strides=(2, 2),
enc0 = tfcl.layer_norm(enc0, scope="layer_norm1")
hidden1, lstm_state[layer_id] = lstm_func(
enc0, lstm_state[layer_id], lstm_size[layer_id], name="state1")
hidden1 = tile_and_concat(hidden1, latent, concat_latent=concat_latent)
hidden1 = tfcl.layer_norm(hidden1, scope="layer_norm2")
layer_id += 1
hidden2, lstm_state[layer_id] = lstm_func(
hidden1, lstm_state[layer_id], lstm_size[layer_id], name="state2")
hidden2 = tfcl.layer_norm(hidden2, scope="layer_norm3")
hidden2 = common_layers.make_even_size(hidden2)
enc1 = tfl.conv2d(hidden2, hidden2.get_shape()[3], [3, 3], strides=(2, 2),
padding="SAME", activation=tf.nn.relu, name="conv2")
enc1 = tile_and_concat(enc1, latent, concat_latent=concat_latent)
layer_id += 1
if self.hparams.small_mode:
hidden4, enc2 = hidden2, enc1
hidden3, lstm_state[layer_id] = lstm_func(
enc1, lstm_state[layer_id], lstm_size[layer_id], name="state3")
hidden3 = tile_and_concat(hidden3, latent, concat_latent=concat_latent)
hidden3 = tfcl.layer_norm(hidden3, scope="layer_norm4")
layer_id += 1
hidden4, lstm_state[layer_id] = lstm_func(
hidden3, lstm_state[layer_id], lstm_size[layer_id], name="state4")
hidden4 = tile_and_concat(hidden4, latent, concat_latent=concat_latent)
hidden4 = tfcl.layer_norm(hidden4, scope="layer_norm5")
hidden4 = common_layers.make_even_size(hidden4)
enc2 = tfl.conv2d(hidden4, hidden4.get_shape()[3], [3, 3], strides=(2, 2),
padding="SAME", activation=tf.nn.relu, name="conv3")
layer_id += 1
if action is not None:
enc2 = common_video.inject_additional_input(
enc2, action, "action_enc", self.hparams.action_injection)
if input_reward is not None:
enc2 = common_video.inject_additional_input(
enc2, input_reward, "reward_enc")
if latent is not None and not concat_latent:
with tf.control_dependencies([latent]):
enc2 = tf.concat([enc2, latent], axis=3)
enc3 = tfl.conv2d(enc2, hidden4.get_shape()[3], [1, 1], strides=(1, 1),
padding="SAME", activation=tf.nn.relu, name="conv4")
hidden5, lstm_state[layer_id] = lstm_func(
enc3, lstm_state[layer_id], lstm_size[layer_id], name="state5")
hidden5 = tfcl.layer_norm(hidden5, scope="layer_norm6")
hidden5 = tile_and_concat(hidden5, latent, concat_latent=concat_latent)
layer_id += 1
return hidden5, (enc0, enc1), layer_id
def reward_prediction(self, *args, **kwargs):
model = self.hparams.reward_model
if model == "basic":
return self.reward_prediction_basic(*args, **kwargs)
elif model == "big":
return self.reward_prediction_big(*args, **kwargs)
elif model == "mid":
return self.reward_prediction_mid(*args, **kwargs)
raise ValueError("Unknown reward model %s" % model)
def reward_prediction_basic(
self, input_images, input_reward, action, latent, mid_outputs):
del input_reward, action, latent, mid_outputs
x = input_images
x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
x = tfl.dense(x, 128, activation=tf.nn.relu, name="reward_pred")
x = tf.expand_dims(x, axis=3)
return x
def reward_prediction_mid(
self, input_images, input_reward, action, latent, mid_outputs):
"""Builds a reward prediction network from intermediate layers."""
encoded = []
for i, output in enumerate(mid_outputs):
enc = output
enc = tfl.conv2d(enc, 64, [3, 3], strides=(1, 1), activation=tf.nn.relu)
enc = tfl.conv2d(enc, 32, [3, 3], strides=(2, 2), activation=tf.nn.relu)
enc = tfl.conv2d(enc, 16, [3, 3], strides=(2, 2), activation=tf.nn.relu)
enc = tfl.flatten(enc)
enc = tfl.dense(enc, 64, activation=tf.nn.relu, name="rew_enc_%d" % i)
x = encoded
x = tf.stack(x, axis=1)
x = tfl.flatten(x)
x = tfl.dense(x, 256, activation=tf.nn.relu, name="rew_dense1")
x = tfl.dense(x, 128, activation=tf.nn.relu, name="rew_dense2")
return x
def reward_prediction_big(
self, input_images, input_reward, action, latent, mid_outputs):
"""Builds a reward prediction network."""
del mid_outputs
conv_size = self.tinyify([32, 32, 16, 8])
with tf.variable_scope("reward_pred", reuse=tf.AUTO_REUSE):
x = tf.concat(input_images, axis=3)
x = tfcl.layer_norm(x)
if not self.hparams.small_mode:
x = tfl.conv2d(x, conv_size[1], [3, 3], strides=(2, 2),
activation=tf.nn.relu, name="reward_conv1")
x = tfcl.layer_norm(x)
# Inject additional inputs
if action is not None:
x = common_video.inject_additional_input(
x, action, "action_enc", self.hparams.action_injection)
if input_reward is not None:
x = common_video.inject_additional_input(x, input_reward, "reward_enc")
if latent is not None:
latent = tfl.flatten(latent)
latent = tf.expand_dims(latent, axis=1)
latent = tf.expand_dims(latent, axis=1)
x = common_video.inject_additional_input(x, latent, "latent_enc")
x = tfl.conv2d(x, conv_size[2], [3, 3], strides=(2, 2),
activation=tf.nn.relu, name="reward_conv2")
x = tfcl.layer_norm(x)
x = tfl.conv2d(x, conv_size[3], [3, 3], strides=(2, 2),
activation=tf.nn.relu, name="reward_conv3")
return x
def get_extra_loss(self,
latent_means=None, latent_stds=None,
true_frames=None, gen_frames=None):
"""Losses in addition to the default modality losses."""
del true_frames, gen_frames
return self.get_kl_loss(latent_means, latent_stds)
def construct_predictive_tower(
self, input_image, input_reward, action, lstm_state, latent,
# Main tower
lstm_func = common_video.conv_lstm_2d
frame_shape = common_layers.shape_list(input_image)
batch_size, img_height, img_width, color_channels = frame_shape
# the number of different pixel motion predictions
# and the number of masks for each of those predictions
num_masks = self.hparams.num_masks
upsample_method = self.hparams.upsample_method
tile_and_concat = common_video.tile_and_concat
lstm_size = self.tinyify([32, 32, 64, 64, 128, 64, 32])
conv_size = self.tinyify([32])
with tf.variable_scope("main", reuse=tf.AUTO_REUSE):
hidden5, skips, layer_id = self.bottom_part_tower(
input_image, input_reward, action, latent,
lstm_state, lstm_size, conv_size, concat_latent=concat_latent)
enc0, enc1 = skips
with tf.variable_scope("upsample1", reuse=tf.AUTO_REUSE):
enc4 = common_layers.cyclegan_upsample(
hidden5, num_outputs=hidden5.shape.as_list()[-1],
stride=[2, 2], method=upsample_method)
enc1_shape = common_layers.shape_list(enc1)
enc4 = enc4[:, :enc1_shape[1], :enc1_shape[2], :] # Cut to shape.
enc4 = tile_and_concat(enc4, latent, concat_latent=concat_latent)
hidden6, lstm_state[layer_id] = lstm_func(
enc4, lstm_state[layer_id], lstm_size[5], name="state6",
spatial_dims=enc1_shape[1:-1]) # 16x16
hidden6 = tile_and_concat(hidden6, latent, concat_latent=concat_latent)
hidden6 = tfcl.layer_norm(hidden6, scope="layer_norm7")
# Skip connection.
hidden6 = tf.concat(axis=3, values=[hidden6, enc1]) # both 16x16
layer_id += 1
with tf.variable_scope("upsample2", reuse=tf.AUTO_REUSE):
enc5 = common_layers.cyclegan_upsample(
hidden6, num_outputs=hidden6.shape.as_list()[-1],
stride=[2, 2], method=upsample_method)
enc0_shape = common_layers.shape_list(enc0)
enc5 = enc5[:, :enc0_shape[1], :enc0_shape[2], :] # Cut to shape.
enc5 = tile_and_concat(enc5, latent, concat_latent=concat_latent)
hidden7, lstm_state[layer_id] = lstm_func(
enc5, lstm_state[layer_id], lstm_size[6], name="state7",
spatial_dims=enc0_shape[1:-1]) # 32x32
hidden7 = tfcl.layer_norm(hidden7, scope="layer_norm8")
layer_id += 1
# Skip connection.
hidden7 = tf.concat(axis=3, values=[hidden7, enc0]) # both 32x32
with tf.variable_scope("upsample3", reuse=tf.AUTO_REUSE):
enc6 = common_layers.cyclegan_upsample(
hidden7, num_outputs=hidden7.shape.as_list()[-1],
stride=[2, 2], method=upsample_method)
enc6 = tfcl.layer_norm(enc6, scope="layer_norm9")
enc6 = tile_and_concat(enc6, latent, concat_latent=concat_latent)
if self.hparams.model_options == "DNA":
# Using largest hidden state for predicting untied conv kernels.
enc7 = tfl.conv2d_transpose(
[1, 1],
strides=(1, 1),
# Using largest hidden state for predicting a new image layer.
enc7 = tfl.conv2d_transpose(
[1, 1],
strides=(1, 1),
# This allows the network to also generate one image from scratch,
# which is useful when regions of the image become unoccluded.
transformed = [tf.nn.sigmoid(enc7)]
if self.hparams.model_options == "CDNA":
# cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
cdna_input = tfl.flatten(hidden5)
transformed += common_video.cdna_transformation(
input_image, cdna_input, num_masks, int(color_channels),
self.hparams.dna_kernel_size, self.hparams.relu_shift)
elif self.hparams.model_options == "DNA":
# Only one mask is supported (more should be unnecessary).
if num_masks != 1:
raise ValueError("Only one mask is supported for DNA model.")
transformed = [
input_image, enc7,
self.hparams.dna_kernel_size, self.hparams.relu_shift)]
masks = tfl.conv2d(
enc6, filters=num_masks + 1, kernel_size=[1, 1],
strides=(1, 1), name="convt7", padding="SAME")
masks = masks[:, :img_height, :img_width, ...]
masks = tf.reshape(
tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])),
int(img_width), num_masks + 1])
mask_list = tf.split(
axis=3, num_or_size_splits=num_masks + 1, value=masks)
output = mask_list[0] * input_image
for layer, mask in zip(transformed, mask_list[1:]):
# TODO(mbz): take another look at this logic and verify.
output = output[:, :img_height, :img_width, :]
layer = layer[:, :img_height, :img_width, :]
output += layer * mask
# Map to softmax digits
if self.is_per_pixel_softmax:
output = tf.layers.dense(
output, self.hparams.problem.num_channels * 256, name="logits")
mid_outputs = [enc0, enc1, enc4, enc5, enc6]
return output, lstm_state, mid_outputs
def video_features(
self, all_frames, all_actions, all_rewards, all_raw_frames):
"""Video wide latent."""
del all_actions, all_rewards, all_raw_frames
if not self.hparams.stochastic_model:
return None, None, None
frames = tf.stack(all_frames, axis=1)
mean, std = self.construct_latent_tower(frames, time_axis=1)
latent = common_video.get_gaussian_tensor(mean, std)
return [latent, mean, std]
def next_frame(self, frames, actions, rewards, target_frame,
internal_states, video_features):
del target_frame
if self.has_policies or self.has_values:
raise NotImplementedError("Parameter sharing with policy not supported.")
latent, latent_mean, latent_std = video_features
frames, actions, rewards = frames[0], actions[0], rewards[0]
extra_loss = 0.0
if internal_states is None:
internal_states = [None] * (5 if self.hparams.small_mode else 7)
if latent_mean is not None:
extra_loss = self.get_extra_loss([latent_mean], [latent_std])
pred_image, internal_states, mid_outputs = self.construct_predictive_tower(
frames, None, actions, internal_states, latent)
if not self.has_rewards:
return pred_image, None, None, None, extra_loss, internal_states
pred_reward = self.reward_prediction(
pred_image, actions, rewards, latent, mid_outputs)
return pred_image, pred_reward, None, None, extra_loss, internal_states
class NextFrameSv2pDiscrete(NextFrameSv2p):
"""SV2P with discrete latent."""
def video_features(
self, all_frames, all_actions, all_rewards, all_raw_frames):
"""No video wide latent."""
del all_frames, all_actions, all_rewards, all_raw_frames
return None
def basic_conv_net(self, images, conv_size, scope):
"""Simple multi conv ln relu."""
conv_size = self.tinyify(conv_size)
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
x = images
for i, c in enumerate(conv_size):
if i > 0:
x = tf.nn.relu(x)
x = common_layers.make_even_size(x)
x = tfl.conv2d(x, c, [3, 3], strides=(2, 2),
activation=None, padding="SAME", name="conv%d" % i)
x = tfcl.layer_norm(x)
return x
def simple_discrete_latent_tower(self, input_image, target_image):
hparams = self.hparams
if self.is_predicting:
batch_size = common_layers.shape_list(input_image)[0]
rand = tf.random_uniform([batch_size, hparams.bottleneck_bits])
bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
return bits
conv_size = self.tinyify([64, 32, 32, 1])
pair = tf.concat([input_image, target_image], axis=-1)
posterior_enc = self.basic_conv_net(pair, conv_size, "posterior_enc")
posterior_enc = tfl.flatten(posterior_enc)
bits, _ = discretization.tanh_discrete_bottleneck(
return bits
def next_frame(self, frames, actions, rewards, target_frame,
internal_states, video_features):
del video_features
if self.has_pred_actions or self.has_values:
raise NotImplementedError("Parameter sharing with policy not supported.")
frames, actions, rewards = frames[0], actions[0], rewards[0]
if internal_states is None:
internal_states = [None] * (5 if self.hparams.small_mode else 7)
extra_loss = 0.0
latent = self.simple_discrete_latent_tower(frames, target_frame)
pred_image, internal_states, _ = self.construct_predictive_tower(
frames, None, actions, internal_states, latent, True)
if not self.has_rewards:
return pred_image, None, extra_loss, internal_states
pred_reward = self.reward_prediction(
pred_image, actions, rewards, latent)
return pred_image, pred_reward, None, None, extra_loss, internal_states
class NextFrameSv2pAtari(NextFrameSv2p):
"""SV2P with specific changes for atari pipeline."""
def init_internal_states(self):
# Hardcoded LSTM-CONV shapes.
# These sizes are calculated based on original atari frames.
# TODO(mbz): find a cleaner way of doing this maybe?!
batch_size = self.hparams.batch_size
shapes = [(batch_size, 53, 40, 8),
(batch_size, 53, 40, 8),
(batch_size, 27, 20, 16),
(batch_size, 27, 20, 16),
(batch_size, 53, 40, 8)]
with tf.variable_scope("clean_scope"):
# Initialize conv-lstm states with zeros
init = tf.zeros_initializer()
states = []
for i, shape in enumerate(shapes):
# every lstm-conv state has two variables named c and h.
c = tf.get_variable("c%d" % i, shape, trainable=False, initializer=init)
h = tf.get_variable("h%d" % i, shape, trainable=False, initializer=init)
states.append((c, h))
return states
def reset_internal_states_ops(self):
zeros = [(tf.zeros_like(c), tf.zeros_like(h))
for c, h in self.internal_states]
return self.save_internal_states_ops(zeros)
def load_internal_states_ops(self):
ops = [(c.read_value(), h.read_value()) for c, h in self.internal_states]
return ops
def save_internal_states_ops(self, internal_states):
ops = [[tf.assign(x[0], y[0]), tf.assign(x[1], y[1])]
for x, y in zip(self.internal_states, internal_states)]
return ops
class NextFrameSv2pLegacy(NextFrameSv2p):
"""Old SV2P code. Only for legacy reasons."""
def visualize_predictions(self, real_frames, gen_frames, actions=None):
def concat_on_y_axis(x):
x = tf.unstack(x, axis=1)
x = tf.concat(x, axis=1)
return x
frames_gd = common_video.swap_time_and_batch_axes(real_frames)
frames_pd = common_video.swap_time_and_batch_axes(gen_frames)
if actions is not None:
actions = common_video.swap_time_and_batch_axes(actions)
if self.is_per_pixel_softmax:
frames_pd_shape = common_layers.shape_list(frames_pd)
frames_pd = tf.reshape(frames_pd, [-1, 256])
frames_pd = tf.to_float(tf.argmax(frames_pd, axis=-1))
frames_pd = tf.reshape(frames_pd, frames_pd_shape[:-1] + [3])
frames_gd = concat_on_y_axis(frames_gd)
frames_pd = concat_on_y_axis(frames_pd)
if actions is not None:
actions = tf.clip_by_value(actions, 0, 1)
summary("action_vid", tf.cast(actions * 255, tf.uint8))
actions = concat_on_y_axis(actions)
side_by_side_video = tf.concat([frames_gd, frames_pd, actions], axis=2)
side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
tf.summary.image("full_video", side_by_side_video)
def get_input_if_exists(self, features, key, batch_size, num_frames):
if key in features:
x = features[key]
x = tf.zeros((batch_size, num_frames, 1, self.hparams.hidden_size))
return common_video.swap_time_and_batch_axes(x)
def construct_model(self,
"""Build convolutional lstm video predictor using CDNA, or DNA.
images: list of tensors of ground truth image sequences
there should be a 4D image ?xWxHxC for each timestep
actions: list of action tensors
each action should be in the shape ?x1xZ
rewards: list of reward tensors
each reward should be in the shape ?x1xZ
gen_images: predicted future image frames
gen_rewards: predicted future rewards
latent_mean: mean of approximated posterior
latent_std: std of approximated posterior
ValueError: if more than 1 mask specified for DNA model.
context_frames = self.hparams.video_num_input_frames
buffer_size = self.hparams.reward_prediction_buffer_size
if buffer_size == 0:
buffer_size = context_frames
if buffer_size > context_frames:
raise ValueError("Buffer size is bigger than context frames %d %d." %
(buffer_size, context_frames))
batch_size = common_layers.shape_list(images[0])[0]
ss_func = self.get_scheduled_sample_func(batch_size)
def process_single_frame(prev_outputs, inputs):
"""Process a single frame of the video."""
cur_image, input_reward, action = inputs
time_step, prev_image, prev_reward, frame_buf, lstm_states = prev_outputs
# sample from softmax (by argmax). this is noop for non-softmax loss.
prev_image = self.get_sampled_frame(prev_image)
generated_items = [prev_image]
groundtruth_items = [cur_image]
done_warm_start = tf.greater(time_step, context_frames - 1)
input_image, = self.get_scheduled_sample_inputs(
done_warm_start, groundtruth_items, generated_items, ss_func)
# Prediction
pred_image, lstm_states, _ = self.construct_predictive_tower(
input_image, None, action, lstm_states, latent)
if self.hparams.reward_prediction:
reward_input_image = self.get_sampled_frame(pred_image)
if self.hparams.reward_prediction_stop_gradient:
reward_input_image = tf.stop_gradient(reward_input_image)
with tf.control_dependencies([time_step]):
frame_buf = [reward_input_image] + frame_buf[:-1]
pred_reward = self.reward_prediction(frame_buf, None, action, latent)
pred_reward = common_video.decode_to_shape(
pred_reward, common_layers.shape_list(input_reward), "reward_dec")
pred_reward = prev_reward
time_step += 1
outputs = (time_step, pred_image, pred_reward, frame_buf, lstm_states)
return outputs
# Latent tower
latent = None
if self.hparams.stochastic_model:
latent_mean, latent_std = self.construct_latent_tower(images, time_axis=0)
latent = common_video.get_gaussian_tensor(latent_mean, latent_std)
# HACK: Do first step outside to initialize all the variables
lstm_states = [None] * (5 if self.hparams.small_mode else 7)
frame_buffer = [tf.zeros_like(images[0])] * buffer_size
inputs = images[0], rewards[0], actions[0]
init_image_shape = common_layers.shape_list(images[0])
if self.is_per_pixel_softmax:
init_image_shape[-1] *= 256
init_image = tf.zeros(init_image_shape, dtype=images.dtype)
prev_outputs = (tf.constant(0),
initializers = process_single_frame(prev_outputs, inputs)
first_gen_images = tf.expand_dims(initializers[1], axis=0)
first_gen_rewards = tf.expand_dims(initializers[2], axis=0)
inputs = (images[1:-1], rewards[1:-1], actions[1:-1])
outputs = tf.scan(process_single_frame, inputs, initializers)
gen_images, gen_rewards = outputs[1:3]
gen_images = tf.concat((first_gen_images, gen_images), axis=0)
gen_rewards = tf.concat((first_gen_rewards, gen_rewards), axis=0)
if self.hparams.stochastic_model:
return gen_images, gen_rewards, [latent_mean], [latent_std]
return gen_images, gen_rewards, None, None
def infer(self, features, *args, **kwargs):
"""Produce predictions from the model by running it."""
del args, kwargs
if "targets" not in features:
if "infer_targets" in features:
targets_shape = common_layers.shape_list(features["infer_targets"])
elif "inputs" in features:
targets_shape = common_layers.shape_list(features["inputs"])
targets_shape[1] = self.hparams.video_num_target_frames
raise ValueError("no inputs are given.")
features["targets"] = tf.zeros(targets_shape, dtype=tf.float32)
output, _ = self(features) # pylint: disable=not-callable
if not isinstance(output, dict):
output = {"targets": output}
x = output["targets"]
if self.is_per_pixel_softmax:
x_shape = common_layers.shape_list(x)
x = tf.reshape(x, [-1, x_shape[-1]])
x = tf.argmax(x, axis=-1)
x = tf.reshape(x, x_shape[:-1])
x = tf.squeeze(x, axis=-1)
x = tf.to_int64(tf.round(x))
output["targets"] = x
if self.hparams.reward_prediction:
output["target_reward"] = tf.argmax(output["target_reward"], axis=-1)
# only required for decoding.
output["outputs"] = output["targets"]
output["scores"] = output["targets"]
return output
def body(self, features):
hparams = self.hparams
batch_size = common_layers.shape_list(features["inputs"])[0]
# Swap time and batch axes.
input_frames = common_video.swap_time_and_batch_axes(features["inputs"])
target_frames = common_video.swap_time_and_batch_axes(features["targets"])
# Get actions if exist otherwise use zeros
input_actions = self.get_input_if_exists(
features, "input_action", batch_size, hparams.video_num_input_frames)
target_actions = self.get_input_if_exists(
features, "target_action", batch_size, hparams.video_num_target_frames)
# Get rewards if exist otherwise use zeros
input_rewards = self.get_input_if_exists(
features, "input_reward", batch_size, hparams.video_num_input_frames)
target_rewards = self.get_input_if_exists(
features, "target_reward", batch_size, hparams.video_num_target_frames)
all_actions = tf.concat([input_actions, target_actions], axis=0)
all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
all_frames = tf.concat([input_frames, target_frames], axis=0)
# Each image is being used twice, in latent tower and main tower.
# This is to make sure we are using the *same* image for both, ...
# ... given how TF queues work.
# NOT sure if this is required at all. Doesn"t hurt though! :)
all_frames = tf.identity(all_frames)
gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
extra_loss = self.get_extra_loss(
# Visualize predictions in Tensorboard
if self.is_training:
self.visualize_predictions(all_frames[1:], gen_images)
# Ignore the predictions from the input frames.
# This is NOT the same as original paper/implementation.
predictions = gen_images[hparams.video_num_input_frames-1:]
reward_pred = gen_rewards[hparams.video_num_input_frames-1:]
reward_pred = tf.squeeze(reward_pred, axis=2) # Remove extra dimension.
# Swap back time and batch axes.
predictions = common_video.swap_time_and_batch_axes(predictions)
reward_pred = common_video.swap_time_and_batch_axes(reward_pred)
if self.is_training and hparams.internal_loss:
# add the loss for input frames as well.
extra_gts = all_frames[1:hparams.video_num_input_frames]
extra_gts = common_video.swap_time_and_batch_axes(extra_gts)
extra_pds = gen_images[:hparams.video_num_input_frames-1]
extra_pds = common_video.swap_time_and_batch_axes(extra_pds)
extra_raw_gts = features["inputs_raw"][:, 1:]
recon_loss = self.get_extra_internal_loss(
extra_raw_gts, extra_gts, extra_pds)
extra_loss += recon_loss
return_targets = predictions
if hparams.reward_prediction:
return_targets = {"targets": predictions, "target_reward": reward_pred}
return return_targets, extra_loss
class NextFrameSv2pTwoFrames(NextFrameSv2pLegacy):
"""Stochastic next-frame model with 2 frames posterior."""
def construct_model(self, images, actions, rewards):
images = tf.unstack(images, axis=0)
actions = tf.unstack(actions, axis=0)
rewards = tf.unstack(rewards, axis=0)
batch_size = common_layers.shape_list(images[0])[0]
context_frames = self.hparams.video_num_input_frames
# Predicted images and rewards.
gen_rewards, gen_images, latent_means, latent_stds = [], [], [], []
# LSTM states.
lstm_state = [None] * 7
# Create scheduled sampling function
ss_func = self.get_scheduled_sample_func(batch_size)
pred_image = tf.zeros_like(images[0])
pred_reward = tf.zeros_like(rewards[0])
latent = None
for timestep, image, action, reward in zip(
range(len(images)-1), images[:-1], actions[:-1], rewards[:-1]):
# Scheduled Sampling
done_warm_start = timestep > context_frames - 1
groundtruth_items = [image, reward]
generated_items = [pred_image, pred_reward]
input_image, input_reward = self.get_scheduled_sample_inputs(
done_warm_start, groundtruth_items, generated_items, ss_func)
# Latent
# TODO(mbz): should we use input_image iunstead of image?
latent_images = tf.stack([image, images[timestep+1]], axis=0)
latent_mean, latent_std = self.construct_latent_tower(
latent_images, time_axis=0)
latent = common_video.get_gaussian_tensor(latent_mean, latent_std)
# Prediction
pred_image, lstm_state, _ = self.construct_predictive_tower(
input_image, input_reward, action, lstm_state, latent)
if self.hparams.reward_prediction:
pred_reward = self.reward_prediction(
pred_image, input_reward, action, latent)
pred_reward = common_video.decode_to_shape(
pred_reward, common_layers.shape_list(input_reward), "reward_dec")
pred_reward = input_reward
gen_images = tf.stack(gen_images, axis=0)
gen_rewards = tf.stack(gen_rewards, axis=0)
return gen_images, gen_rewards, latent_means, latent_stds