In [1]:
import numpy as np
import json
from keras.models import Model
from keras.layers import Input
from keras.layers.convolutional import Conv2D
from keras.layers.normalization import BatchNormalization
from keras import backend as K
from collections import OrderedDict

Using TensorFlow backend.


In [2]:
def format_decimal(arr, places=6):
    """Round each number in ``arr`` to ``places`` decimal places.

    Each value is scaled up by ``10 ** places``, rounded to the nearest
    integer with the built-in ``round``, and scaled back down.

    Args:
        arr: iterable of numbers, e.g. the result of ``ndarray.ravel().tolist()``.
        places: number of decimal digits to keep (default 6).

    Returns:
        A new list holding one rounded value per element of ``arr``.
    """
    rounded = []
    for value in arr:
        rounded.append(round(value * 10 ** places) / 10 ** places)
    return rounded

In [3]:
# Collects the test fixtures built in the cells below, keyed by fixture name
# (e.g. 'pipeline_03'); OrderedDict keeps insertion order in the dumped JSON.
DATA = OrderedDict()

### pipeline 3: Conv2D followed by BatchNormalization

In [4]:
# Model under test: Conv2D (4 filters, 3x3 kernel, 'valid' padding, ReLU, with
# bias) feeding a BatchNormalization over the last (channel) axis.
data_in_shape = (8, 8, 2)

conv_0 = Conv2D(4, (3,3), activation='relu', padding='valid', strides=(1,1), data_format='channels_last', use_bias=True)
bn_0 = BatchNormalization(epsilon=1e-05, axis=-1, center=True, scale=True)

input_layer = Input(shape=data_in_shape)
x = conv_0(input_layer)
output_layer = bn_0(x)
model = Model(inputs=input_layer, outputs=output_layer)

# Input sample: uniform values in [-1, 1).
# NOTE: the order of the seed calls and random draws in this cell fixes the
# exact numbers written to the JSON fixture; reordering them changes the
# expected test data.
np.random.seed(4000)
data_in = 2 * np.random.random(data_in_shape) - 1

# set weights to random (use seed for reproducibility)
# Each weight array is drawn with its own seed: 4000 + its index in
# model.get_weights().
weights = []
for i, w in enumerate(model.get_weights()):
    np.random.seed(4000 + i)
    if i == 5:
        # NOTE(review): presumably index 5 is BatchNormalization's moving
        # variance (after Conv2D kernel/bias and BN gamma/beta/moving mean).
        # A variance must be non-negative, so it is drawn from [0, 1) instead
        # of [-1, 1). Confirm against the Keras weight ordering.
        weights.append(np.random.random(w.shape))
    else:
        weights.append(2 * np.random.random(w.shape) - 1)
model.set_weights(weights)

# Forward pass on a batch of one sample; keep the per-sample output shape and
# the flattened input/output data rounded by format_decimal (6 places default).
result = model.predict(np.array([data_in]))
data_out_shape = result[0].shape
data_in_formatted = format_decimal(data_in.ravel().tolist())
data_out_formatted = format_decimal(result[0].ravel().tolist())

# Fixture entry: the input, every weight array (in model.get_weights() order),
# and the expected model output, each stored as flattened data plus its shape.
DATA['pipeline_03'] = {
    'input': {'data': data_in_formatted, 'shape': data_in_shape},
    'weights': [{'data': format_decimal(w.ravel().tolist()), 'shape': w.shape} for w in weights],
    'expected': {'data': data_out_formatted, 'shape': data_out_shape}
}

### export for Keras.js tests

In [5]:
import os

# Write the collected fixtures to the JSON file used by the Keras.js tests.
filename = '../../test/data/pipeline/03.json'
# makedirs(exist_ok=True) replaces the racy exists-then-create check; skip it
# when filename has no directory part, since os.makedirs('') raises.
dirname = os.path.dirname(filename)
if dirname:
    os.makedirs(dirname, exist_ok=True)
with open(filename, 'w') as f:
    json.dump(DATA, f)

In [6]:
# Echo the same fixtures that were written to disk, as one JSON string.
exported_json = json.dumps(DATA)
print(exported_json)

{"pipeline_03": {"input": {"data": [0.317596, 0.688515, -0.688309, -0.48247, 0.387223, -0.718263, 0.281673, -0.106311, 0.576861, -0.083926, 0.631691, 0.92647, 0.579655, -0.024215, -0.805793, -0.842947, -0.955415, 0.656415, 0.44667, 0.633739, 0.701525, 0.917507, -0.185671, -0.105247, 0.303949, -0.07507, -0.662442, 0.747404, 0.794057, 0.708841, 0.661954, 0.774677, 0.661308, 0.866776, -0.127209, -0.43449, 0.499357, 0.267042, -0.892782, 0.015889, -0.723892, -0.23083, -0.911621, -0.082763, -0.876862, 0.194631, 0.465433, 0.467187, -0.71095, -0.07916, -0.192007, 0.170016, -0.208802, -0.700623, -0.949676, -0.965666, 0.883237, -0.651265, 0.897732, 0.923142, -0.535299, 0.268894, -0.493426, -0.570641, -0.785673, 0.872584, 0.632657, -0.570894, 0.481942, 0.089347, 0.2348, 0.67846, -0.788192, 0.033911, 0.768885, -0.526456, 0.596875, -0.256875, 0.797246, 0.628544, -0.036388, 0.193046, -0.159633, -0.103758, -0.502594, 0.909868, 0.597109, -0.83716, 0.107518, 0.96916, -0.936404, 0.25101, -0.257324, -0.4