In [1]:
import numpy as np
import json
from keras.models import Model
from keras.layers import Input
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, BatchNormalization, Average
from keras import backend as K
from collections import OrderedDict

Using TensorFlow backend.


In [2]:
def format_decimal(arr, places=6):
    """Round every number in ``arr`` to ``places`` digits after the decimal point.

    Each value is multiplied by ``10**places``, rounded to an integer with the
    built-in ``round``, and divided back down by the same factor.

    Args:
        arr: iterable of numbers (in this notebook, ``ndarray.ravel().tolist()``).
        places: number of decimal digits to keep (default 6).

    Returns:
        A new list holding one rounded value per element of ``arr``.
    """
    scale = 10 ** places
    rounded = []
    for value in arr:
        rounded.append(round(value * scale) / scale)
    return rounded

In [3]:
# Collects every test case built in this notebook, keyed by case name (e.g. 'graph_02').
# OrderedDict keeps insertion order, so cases are serialized in the order they are added.
DATA = OrderedDict()

### graph 2: two inputs, each through its own Conv2D branch, merged with Average

In [4]:
# Build the 'graph_02' fixture: two independent (8, 8, 2) inputs, each passed through
# its own 3x3 'valid' Conv2D (4 filters, ReLU), so each branch yields a (6, 6, 4) map.
# The two branch outputs are then merged element-wise by an Average layer.
random_seed = 10002
data_in_shape = (8, 8, 2)

input_layer_0 = Input(shape=data_in_shape)
branch_0 = Conv2D(4, (3,3), activation='relu', padding='valid', strides=(1,1), data_format='channels_last', use_bias=True)(input_layer_0)

input_layer_1 = Input(shape=data_in_shape)
branch_1 = Conv2D(4, (3,3), activation='relu', padding='valid', strides=(1,1), data_format='channels_last', use_bias=True)(input_layer_1)

output_layer = Average()([branch_0, branch_1])
model = Model(inputs=[input_layer_0, input_layer_1], outputs=output_layer)

# One array per model input, with values uniform in [-1, 1) and a leading batch axis
# of size 1. Each input is seeded separately (random_seed + i) for reproducibility.
data_in = []
for i in range(2):
    np.random.seed(random_seed + i)
    data_in.append(np.expand_dims(2 * np.random.random(data_in_shape) - 1, axis=0))

# set weights to random (use seed for reproducibility)
# Every weight array (in model.get_weights() order) is replaced by values uniform in
# [-1, 1), drawn after reseeding with random_seed + i.
# NOTE(review): these seeds collide with the ones used for data_in above (i = 0, 1).
# As a result, the first weight array repeats the leading values of data_in[0], and the
# second repeats those of data_in[1]. Left as-is because changing the seeds would change
# the exported fixture; consider a distinct seed offset if the fixture is regenerated.
weights = []
for i, w in enumerate(model.get_weights()):
    np.random.seed(random_seed + i)
    weights.append(2 * np.random.random(w.shape) - 1)
model.set_weights(weights)

# predict() returns a batch of size 1; result[0] is that single sample (batch axis dropped).
result = model.predict(data_in)
data_out_shape = result[0].shape
# Flatten and round (6 decimal places by default) so the JSON fixture stays compact.
data_in_formatted = [format_decimal(data_in[i].ravel().tolist()) for i in range(2)]
data_out_formatted = format_decimal(result[0].ravel().tolist())

# Fixture entry. Every tensor is stored as flat 'data' plus its 'shape'. Weights keep
# the order in which model.get_weights() returned them, which set_weights() expects.
DATA['graph_02'] = {
    'inputs': [{'data': data_in_formatted[i], 'shape': data_in_shape} for i in range(2)],
    'weights': [{'data': format_decimal(w.ravel().tolist()), 'shape': w.shape} for w in weights],
    'expected': {'data': data_out_formatted, 'shape': data_out_shape}
}

### export for Keras.js tests

In [5]:
import os

# Export the collected fixtures for the Keras.js tests. The path is relative, so it
# resolves against the notebook's current working directory.
filename = '../../test/data/graph/02.json'
# makedirs(exist_ok=True) creates any missing parent directories and is a no-op when
# they already exist. This avoids the separate exists()-then-create race. The guard
# skips makedirs('') for a bare filename, which would otherwise raise.
output_dir = os.path.dirname(filename)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
with open(filename, 'w') as f:
    json.dump(DATA, f)

In [6]:
# Echo all collected fixtures as a single JSON string for quick visual inspection.
print(json.dumps(DATA))

{"graph_02": {"inputs": [{"data": [-0.096224, 0.673607, -0.009657, 0.52201, -0.470713, 0.710414, -0.62915, 0.836049, 0.958218, 0.675545, 0.87071, -0.765923, 0.270454, -0.605797, -0.781502, -0.158478, 0.942449, -0.098396, 0.479279, -0.663835, -0.216723, -0.551944, 0.672006, 0.615437, 0.936846, 0.765668, -0.725233, -0.59916, -0.233731, 0.031473, 0.927414, 0.759562, 0.97423, 0.419831, 0.318504, -0.031154, 0.959492, -0.136413, 0.062065, -0.31904, -0.460297, -0.691724, 0.576339, 0.606164, 0.677099, -0.821884, 0.706775, 0.598279, -0.373949, -0.663068, 0.974704, -0.157164, -0.934793, 0.745087, -0.871081, -0.580079, -0.015164, -0.319471, -0.336323, 0.227711, 0.345044, 0.021435, 0.742563, 0.859598, -0.887057, -0.354838, 0.668705, -0.308794, 0.971958, -0.477421, 0.436958, 0.606519, -0.24108, 0.81307, -0.945765, -0.34327, 0.715052, -0.497423, 0.816045, 0.822065, 0.506868, -0.851311, 0.738795, 0.67809, -0.644936, -0.587803, -0.59148, -0.156544, 0.353301, 0.907141, -0.404002, 0.865169, 0.93593, -0.