In [1]:
import numpy as np
import json
from keras.models import Model
from keras.layers import Input
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPooling2D, AveragePooling2D
from keras.layers.normalization import BatchNormalization
from keras import backend as K
from collections import OrderedDict

Using TensorFlow backend.


In [2]:
def format_decimal(arr, places=6):
    """Round every number in ``arr`` to ``places`` decimal places.

    Each value is multiplied by ``10 ** places``, rounded to the nearest
    integer with the built-in ``round`` (on Python 3, exact ties go to the
    even integer), and divided back down. This keeps the numbers written to
    the JSON fixtures short.

    Args:
        arr: iterable of real numbers (e.g. the output of ``ndarray.tolist()``).
        places: number of decimal places to keep (default 6).

    Returns:
        A new list of the rounded values; ``arr`` itself is not modified.
        Non-finite values make ``round`` raise.
    """
    scale = 10 ** places
    rounded = []
    for value in arr:
        rounded.append(round(value * scale) / scale)
    return rounded

In [3]:
# Ordered container so fixture entries serialize in insertion order.
DATA = OrderedDict([])

### Pipeline 5: Conv2D (4 filters, 3×3, ReLU, valid) → MaxPooling2D (2×2)

In [4]:
# Shape of a single input sample (the batch dimension is added by Input/predict).
data_in_shape = (8, 8, 2)

# Pipeline under test: 3x3 "valid" ReLU convolution with 4 filters, then 2x2 max pooling.
conv_0 = Conv2D(4, (3,3), activation='relu', padding='valid', strides=(1,1), data_format='channels_last', use_bias=True)
pool_0 = MaxPooling2D(pool_size=(2,2), strides=None, padding='valid', data_format='channels_last')

input_layer = Input(shape=data_in_shape)
x = conv_0(input_layer)
output_layer = pool_0(x)
model = Model(inputs=input_layer, outputs=output_layer)

# Draw inputs and weights from local RandomState generators instead of reseeding the
# global NumPy RNG. RandomState(seed).random_sample(shape) yields exactly the same
# values as `np.random.seed(seed); np.random.random(shape)`, so the exported fixture is
# unchanged, but later cells no longer inherit hidden global RNG state from this one.
# Values are mapped from [0, 1) to [-1, 1).
data_in = 2 * np.random.RandomState(6000).random_sample(data_in_shape) - 1

# One reproducible seed per weight array returned by model.get_weights(): 6000, 6001, ...
# NOTE(review): index 0 reuses seed 6000, the same seed as `data_in`, so the first
# weight array's flattened values equal the leading flattened values of `data_in`.
# Kept as-is to preserve the existing fixture; consider offsetting the weight seeds
# if the fixture is ever regenerated.
weights = [
    2 * np.random.RandomState(6000 + i).random_sample(w.shape) - 1
    for i, w in enumerate(model.get_weights())
]
model.set_weights(weights)

result = model.predict(np.array([data_in]))
data_out_shape = result[0].shape
# Sanity check: 8x8 -> 6x6 after the 3x3 valid conv -> 3x3 after 2x2 pooling, 4 filters.
assert data_out_shape == (3, 3, 4), data_out_shape
data_in_formatted = format_decimal(data_in.ravel().tolist())
data_out_formatted = format_decimal(result[0].ravel().tolist())

# Fixture entry: flattened, rounded data plus shapes so the consumer can reshape.
DATA['pipeline_05'] = {
    'input': {'data': data_in_formatted, 'shape': data_in_shape},
    'weights': [{'data': format_decimal(w.ravel().tolist()), 'shape': w.shape} for w in weights],
    'expected': {'data': data_out_formatted, 'shape': data_out_shape}
}

### Export fixture for Keras.js tests

In [5]:
import os

# Fixture path, resolved relative to the current working directory.
filename = '../../test/data/pipeline/05.json'
# Create the destination directory if it does not exist yet.
out_dir = os.path.dirname(filename)
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
with open(filename, 'w') as out_file:
    json.dump(DATA, out_file)

In [6]:
# Echo the serialized fixture so it is visible in the notebook output.
serialized = json.dumps(DATA)
print(serialized)

{"pipeline_05": {"input": {"data": [-0.337046, -0.76599, 0.43704, -0.590874, 0.523244, 0.08437, -0.46995, -0.1746, -0.787955, 0.479577, 0.35859, -0.194639, -0.042228, 0.439577, 0.330564, 0.176657, 0.032713, 0.965577, -0.145303, -0.428448, -0.989246, -0.657321, 0.964564, 0.539372, 0.031447, 0.174825, -0.019582, 0.356366, 0.039503, 0.291526, -0.869494, 0.479024, -0.113832, -0.671733, -0.276457, -0.562081, -0.38047, -0.894364, 0.767763, 0.607511, 0.083609, 0.202987, -0.654666, 0.054315, 0.290928, 0.742149, 0.557948, -0.011064, 0.056474, -0.440133, 0.533127, 0.757812, 0.911539, 0.940052, -0.177892, -0.017994, -0.280084, 0.527801, -0.047258, 0.605181, 0.609485, 0.078785, -0.768712, -0.06945, 0.570796, -0.518626, 0.855417, 0.330638, 0.613, 0.614581, -0.946572, -0.100723, -0.208673, 0.672064, -0.709129, -0.370401, -0.194798, -0.272388, 0.408158, 0.466825, 0.143578, 0.709094, 0.72072, -0.345373, 0.949003, 0.781196, -0.057882, 0.557938, -0.58276, -0.940891, -0.999619, -0.613315, 0.317712, -0.97