In [1]:
import keras
import keras.backend as K
import numpy as np
from prednet import PredNet

Using TensorFlow backend.


In [2]:

# Model parameters: input geometry, per-layer sizes passed to PredNet, and loss weights.
n_channels = 3
im_height = 128
im_width = 160

# Place the channel axis where the active Keras image data format expects it.
if K.image_data_format() == 'channels_first':
    input_shape = (n_channels, im_height, im_width)
else:
    input_shape = (im_height, im_width, n_channels)

# The first stack size is the number of image channels.
stack_sizes = (n_channels, 48, 96, 192)
R_stack_sizes = stack_sizes
A_filt_sizes = (3, 3, 3)
Ahat_filt_sizes = (3, 3, 3, 3)
R_filt_sizes = (3, 3, 3, 3)

# Weighting for each layer in the final loss;
# "L_0" model: [1, 0, 0, 0], "L_all": [1, 0.1, 0.1, 0.1].
# Stored as a column vector of shape (4, 1).
layer_loss_weights = np.array([1., 0., 0., 0.])[:, np.newaxis]

# Number of timesteps used for sequences in training.
nt = 10

# Equal weight 1 / (nt - 1) for every timestep, then zero out the first one.
time_loss_weights = np.full((nt, 1), 1. / (nt - 1))
time_loss_weights[0] = 0

In [3]:
# Build the PredNet layer so that it emits its error values for every timestep
# (output_mode='error', return_sequences=True) rather than predicted frames.
prednet = PredNet(stack_sizes, R_stack_sizes,
                  A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
                  output_mode='error', return_sequences=True)

# One training sample is a sequence of nt frames, each of shape input_shape.
inputs = keras.layers.Input(shape=(nt,) + input_shape)
errors = prednet(inputs)  # errors will be (batch_size, nt, nb_layers)

# Calculate the weighted error by layer: a frozen Dense(1) applied at each
# timestep, with layer_loss_weights as its kernel and a zero bias.
# TimeDistributed is taken from the public keras.layers namespace instead of the
# keras.layers.wrappers submodule path, which is not available in later Keras
# releases; both names refer to the same class in Keras 2.x.
errors_by_time = keras.layers.TimeDistributed(
    keras.layers.Dense(1, trainable=False),
    weights=[layer_loss_weights, np.zeros(1)],
    trainable=False)(errors)
errors_by_time = keras.layers.Flatten()(errors_by_time)  # will be (batch_size, nt)

# Weight errors by time: a second frozen Dense(1) with time_loss_weights as its
# kernel and a zero bias, reducing each sequence to a single value.
final_errors = keras.layers.Dense(
    1, weights=[time_loss_weights, np.zeros(1)], trainable=False)(errors_by_time)

# model
# The output is already an aggregated error, so the loss compares it with the
# training targets using mean absolute error.
# NOTE(review): targets are presumably all zeros - confirm in the training cell.
model = keras.models.Model(inputs=inputs, outputs=final_errors)
model.compile(loss='mean_absolute_error', optimizer='adam')
model.summary()


Instructions for updating:
Colocations handled automatically by placer.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 10, 128, 160, 3)   0         
_________________________________________________________________
pred_net_1 (PredNet)         (None, 10, 4)             6915948   
_________________________________________________________________
time_distributed_1 (TimeDist (None, 10, 1)             5         
_________________________________________________________________
flatten_1 (Flatten)          (None, 10)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 11        
Total params: 6,915,964
Trainable params: 6,915,948
Non-trainable params: 16
_________________________________________________________________


In [6]:
# Display the model's symbolic input tensor to check its shape.
model.input

<tf.Tensor 'input_1:0' shape=(?, 10, 128, 160, 3) dtype=float32>

In [7]:
# Display the model's symbolic output tensor (the final time-weighting Dense layer) to check its shape.
model.output

<tf.Tensor 'dense_2/BiasAdd:0' shape=(?, 1) dtype=float32>