In [1]:
'''
Evaluate trained PredNet on KITTI sequences.
Calculates mean-squared error and plots predictions.
'''

import os
import numpy as np
from six.moves import cPickle
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
%matplotlib inline

from keras import backend as K
from keras.models import Model, model_from_json
from keras.layers import Input, Dense, Flatten, UpSampling3D

from prednet import PredNet
from data_utils import SequenceGenerator
from kitti_settings import *


# Evaluation settings.
n_plot = 40  # presumably the number of sequences to plot
batch_size = 10
nt = 10  # presumably time steps per sequence; should agree with the model's time dimension

# Paths. WEIGHTS_DIR and DATA_DIR are expected to come from the kitti_settings star import.
weights_file = os.path.join(WEIGHTS_DIR, 'tensorflow_weights/prednet_kitti_weights.hdf5')
json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model.json')
test_file = os.path.join(DATA_DIR, 'X_test.hkl')
test_sources = os.path.join(DATA_DIR, 'sources_test.hkl')

# Load trained model: architecture from the JSON file, then weights from the HDF5 file.
# `with` guarantees the JSON file is closed even if reading it fails.
with open(json_file, 'r') as f:
    json_string = f.read()
# PredNet is a custom layer, so it must be registered for deserialization.
train_model = model_from_json(json_string, custom_objects={'PredNet': PredNet})
train_model.load_weights(weights_file)

Using TensorFlow backend.


Instructions for updating:
Colocations handled automatically by placer.


In [2]:
# Create the evaluation PredNet: a fresh layer built from the trained layer's
# config and initialised with its trained weights.
# NOTE: `output_mode` is kept as saved in the training config. With it unchanged,
# the resulting model's output is (None, 10, 4) rather than predicted frames.
# Set OUTPUT_MODE_OVERRIDE = 'prediction' to make the layer emit predicted frames.
OUTPUT_MODE_OVERRIDE = None
trained_prednet = train_model.layers[1]
layer_config = trained_prednet.get_config()
if OUTPUT_MODE_OVERRIDE is not None:
    layer_config['output_mode'] = OUTPUT_MODE_OVERRIDE
# Fall back to 'dim_ordering' for configs saved without a 'data_format' key.
data_format = layer_config['data_format'] if 'data_format' in layer_config else layer_config['dim_ordering']
test_prednet = PredNet(weights=trained_prednet.get_weights(), **layer_config)
# Per-sample input shape of the trained model (batch dimension dropped).
input_shape = list(train_model.layers[0].batch_input_shape[1:])


In [3]:
# The trained model's per-sample input shape, time axis first.
# `nt` (presumably used to size test sequences) must agree with that time axis.
assert input_shape[0] == nt, 'nt=%d does not match model time dimension %d' % (nt, input_shape[0])
input_shape

[10, 3, 128, 160]

In [4]:
# Single-frame input that is repeated along the time axis, so PredNet sees a
# sequence of input_shape[0] identical frames. Shapes are derived from the
# trained model instead of hard-coded, so they stay in sync with the weights.
# Layout is (time=1, *per-frame shape), e.g. (1, 3, 128, 160) for input_shape [10, 3, 128, 160].
n_repeats = input_shape[0]
inputs = Input(shape=(1,) + tuple(input_shape[1:]))
# data_format="channels_last" is deliberate. It makes UpSampling3D treat axes 1-3
# as the upsampled axes and the last axis as channels. With size=(n_repeats, 1, 1),
# only axis 1 (time) is repeated; every other axis keeps its length.
copied_inputs = UpSampling3D(size=(n_repeats, 1, 1), data_format="channels_last")(inputs)

In [5]:
# Run the tiled sequence through the evaluation PredNet, then wrap the
# input-to-output graph as a standalone Keras model.
predictions = test_prednet(copied_inputs)
test_model = Model(outputs=predictions, inputs=inputs)

In [6]:
# Print each layer's output shape and parameter count for the evaluation model.
test_model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 1, 3, 128, 160)    0         
_________________________________________________________________
up_sampling3d_1 (UpSampling3 (None, 10, 3, 128, 160)   0         
_________________________________________________________________
prednet_1 (PredNet)          (None, 10, 4)             6915948   
Total params: 6,915,948
Trainable params: 6,915,948
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Symbolic scalar taken from the evaluation model's output: element [0, 0, 0]
# of the (None, 10, 4) output tensor. The axes are presumably (batch, time, layer),
# i.e. PredNet's error output; confirm.
# `model` is never defined in this notebook; the model built here is `test_model`.
# `Model.outputs` is a *list* of tensors, so select the tensor first, then slice it.
Loss = test_model.outputs[0][0, 0, 0]