In [1]:
from tensorflow.keras.layers import Input, Conv2D, AveragePooling2D, Flatten, Softmax, Dense, Lambda, BatchNormalization
from tensorflow.keras import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam, SGD
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

In [2]:
# Load the MNIST train/test splits; each split is an (images, labels) pair.
mnist_train, mnist_test = mnist.load_data()
X_train, y_train = mnist_train
X_test, y_test = mnist_test

In [3]:
# One-hot encode the digit labels, (N,) -> (N, 10) float32, with a single
# vectorised to_categorical call per split instead of one call per sample.
# The ndim guard keeps the cell idempotent: re-running it leaves labels that
# are already one-hot untouched instead of encoding them a second time.
NUM_CLASSES = 10
if y_train.ndim == 1:
    y_train = to_categorical(y_train, num_classes=NUM_CLASSES)
if y_test.ndim == 1:
    y_test = to_categorical(y_test, num_classes=NUM_CLASSES)

In [4]:
# Append the single channel axis Conv2D expects: (N, 28, 28) -> (N, 28, 28, 1).
X_train = X_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)

In [5]:
# Small CNN whose only non-linearity is the polynomial x**2 + x, so every
# operation is a sum/product that can be reproduced with integer arithmetic.
# Layer order (and therefore model.layers[i] as indexed by the export cells)
# is: input, lambda, conv, bn, poly, pool, conv, bn, poly, pool, flatten,
# dense, softmax.
SEED = 42
tf.random.set_seed(SEED)  # make weight initialisation (and the exported JSON) repeatable


def _poly_activation(x):
    """Polynomial activation x**2 + x, used in place of ReLU."""
    return x**2 + x


def _poly_conv_stage(x, filters):
    """Conv2D(3x3, no bias) -> BatchNormalization -> x**2 + x -> 2x2 AveragePooling2D."""
    x = Conv2D(filters, 3, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Lambda(_poly_activation)(x)
    return AveragePooling2D()(x)


inputs = Input(shape=(28, 28, 1))
# Pixels are 0..255; dividing by 100 means the raw integer pixel equals the
# model input at fixed-point scale 10**2.
out = Lambda(lambda x: x / 100)(inputs)
out = _poly_conv_stage(out, 4)
out = _poly_conv_stage(out, 8)
out = Flatten()(out)
out = Dense(10, activation=None)(out)
out = Softmax()(out)
model = Model(inputs, out)

In [6]:
# Print the layer-by-layer architecture. The export cells below index layers
# by position (model.layers[2], [3], [6], [7], [11]), so this listing is the
# reference for which layer each index refers to.
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
lambda (Lambda)              (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 26, 26, 4)         36        
_________________________________________________________________
batch_normalization (BatchNo (None, 26, 26, 4)         16        
_________________________________________________________________
lambda_1 (Lambda)            (None, 26, 26, 4)         0         
_________________________________________________________________
average_pooling2d (AveragePo (None, 13, 13, 4)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 8)         288   

In [7]:
# Plain SGD with momentum; 'acc' is reported alongside the cross-entropy loss
# (labels are one-hot, hence categorical rather than sparse cross-entropy).
sgd_optimizer = SGD(learning_rate=0.01, momentum=0.9)
model.compile(optimizer=sgd_optimizer, loss='categorical_crossentropy', metrics=['acc'])

In [8]:
# Train for 15 epochs, reporting metrics on the test split after each epoch.
model.fit(
    x=X_train,
    y=y_train,
    batch_size=32,
    epochs=15,
    validation_data=(X_test, y_test),
)

Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


<tensorflow.python.keras.callbacks.History at 0x2889cf1c0>

In [9]:
# The first test image is the sample exported as the circuit input below;
# display its shape and raw pixel range.
X = X_test[0, ...]
(X.shape, X.min(), X.max())

((28, 28, 1), 0, 255)

In [10]:
# Same network truncated before the final Softmax: outputs the Dense logits.
model2 = Model(inputs=model.input, outputs=model.layers[-2].output)

In [11]:
# Expected pre-softmax output for the exported sample X, with the Dense bias
# subtracted to match the circuit (whose dense_bias input is all zeros).
# Predicting on X itself, rather than re-indexing X_test, guarantees this
# expected output always corresponds to the exported input.
dense_bias = model.layers[-2].weights[1].numpy()
y = model2.predict(X[np.newaxis]) - dense_bias
y

array([[ -2.9271724,  -5.1743965,   4.3266096,   8.861395 , -10.59499  ,
         -6.1087027, -14.946421 ,  20.252651 ,   1.7481048,   6.408617 ]],
      dtype=float32)

In [12]:
# List how many weight tensors the model has, then each tensor's shape.
all_weights = model.weights
print(len(all_weights))
for tensor in all_weights:
    print(tensor.shape)

12
(3, 3, 1, 4)
(4,)
(4,)
(4,)
(4,)
(3, 3, 4, 8)
(8,)
(8,)
(8,)
(8,)
(200, 10)
(10,)


In [13]:
# Pull the learned parameters and epsilon of the first BatchNormalization layer.
bn1_layer = model.layers[3]
gamma, beta, moving_mean, moving_var = bn1_layer.get_weights()
epsilon = bn1_layer.epsilon

In [14]:
# Fold the first BatchNormalization (moving statistics) into an affine map
#   bn(x) = a1 * x + b1,
#   a1 = gamma / sqrt(var + eps),  b1 = beta - gamma * mean / sqrt(var + eps).
# Parameters are read straight from the layer instead of the shared
# gamma/beta/moving_* globals, which the BN-2 parameter cell overwrites;
# re-running this cell after that one would otherwise fold the wrong layer.
bn_1 = model.layers[3]
assert isinstance(bn_1, BatchNormalization), bn_1
gamma_1, beta_1, mean_1, var_1 = bn_1.get_weights()
std_1 = (var_1 + bn_1.epsilon) ** .5
a1 = gamma_1 / std_1
b1 = beta_1 - gamma_1 * mean_1 / std_1
a1, b1

(array([0.7930184 , 0.86564696, 0.80535203, 0.7487028 ], dtype=float32),
 array([-0.6174586,  2.336494 , -0.046653 , -1.3925927], dtype=float32))

In [15]:
# Pull the learned parameters and epsilon of the second BatchNormalization layer.
bn2_layer = model.layers[7]
gamma, beta, moving_mean, moving_var = bn2_layer.get_weights()
epsilon = bn2_layer.epsilon

In [16]:
# Fold the second BatchNormalization (moving statistics) into an affine map
#   bn(x) = a2 * x + b2,
#   a2 = gamma / sqrt(var + eps),  b2 = beta - gamma * mean / sqrt(var + eps).
# Parameters are read straight from the layer instead of the shared
# gamma/beta/moving_* globals, so this cell gives the same answer no matter
# which parameter cell ran last.
bn_2 = model.layers[7]
assert isinstance(bn_2, BatchNormalization), bn_2
gamma_2, beta_2, mean_2, var_2 = bn_2.get_weights()
std_2 = (var_2 + bn_2.epsilon) ** .5
a2 = gamma_2 / std_2
b2 = beta_2 - gamma_2 * mean_2 / std_2
a2, b2

(array([0.10420316, 0.1040957 , 0.09440491, 0.12571144, 0.12500723,
        0.11510107, 0.10298571, 0.11865856], dtype=float32),
 array([-0.14729536,  0.84251314, -0.10289041,  0.4494599 , -0.4984228 ,
        -0.4116003 ,  0.7156993 ,  0.08575855], dtype=float32))

In [17]:
# Export the trained network as fixed-point integers for the circom circuit.
# Every weight / BN multiplier is stored as round(w * 10**2). Every additive
# term (bias, BN offset) is stored at the running scale of the activation it
# is added to, so integer value == real value * scale throughout.
#
# Activation scale along the network, following the factors used below:
#   input 10**2 -> conv1 10**4 -> bn1 10**6 -> poly 10**12 -> pool 10**14
#   -> conv2 10**16 -> bn2 10**18 -> poly 10**36 -> pool 10**38 -> dense 10**40
# The "pooling multiplies the scale by 10**2" steps are assumed to match how
# the circuit implements average pooling -- TODO confirm against the circom code.
# The final 10**40 corresponds to `scale = 10**-40` in the expected output.
#
# NOTE(review): 10**2 leaves little precision for small multipliers, e.g.
# a2 ~ 0.104 is stored as 10 (~4% error); consider a larger factor if the
# circuit's outputs drift from the float model.
FIXED_POINT = 10**2
SCALE_BN1 = FIXED_POINT**3  # scale of bn1's output, 10**6
SCALE_BN2 = FIXED_POINT**9  # scale of bn2's output, 10**18


def _to_fixed(values, scale):
    """Return `values * scale`, rounded half-to-even, as a flat list of Python ints."""
    return (values * scale).round().astype(int).flatten().tolist()


conv_1, conv_2, dense = model.layers[2], model.layers[6], model.layers[11]

in_json = {
    # X is the raw uint8 image (0..255). The model's first Lambda divides by
    # 100, so these integers already are the model input at scale 10**2.
    "in": X.astype(int).flatten().tolist(),
    # Kernels keep Keras' (kh, kw, c_in, c_out) layout, flattened in C order.
    "conv2d_1_weights": _to_fixed(conv_1.weights[0].numpy(), FIXED_POINT),
    # Both conv layers are built with use_bias=False, so their biases are zero.
    "conv2d_1_bias": [0] * conv_1.filters,
    "bn_1_a": _to_fixed(a1, FIXED_POINT),
    "bn_1_b": _to_fixed(b1, SCALE_BN1),
    "conv2d_2_weights": _to_fixed(conv_2.weights[0].numpy(), FIXED_POINT),
    "conv2d_2_bias": [0] * conv_2.filters,
    "bn_2_a": _to_fixed(a2, FIXED_POINT),
    "bn_2_b": _to_fixed(b2, SCALE_BN2),
    "dense_weights": _to_fixed(dense.weights[0].numpy(), FIXED_POINT),
    # NOTE(review): the trained Dense bias is deliberately replaced by zeros.
    # Skipping softmax is harmless for argmax (it is monotonic), but dropping
    # the bias is not: argmax(Wx) can differ from argmax(Wx + b). The expected
    # output subtracts the bias as well, so it matches the circuit, but the
    # exported label may disagree with model.predict on some inputs.
    "dense_bias": [0] * dense.units,
}

In [18]:
# Expected circuit result: the bias-free logits `y` and their argmax label.
# `scale` is presumably the factor that maps the circuit's integer outputs back
# to these real-valued logits (the export uses a 10**40 output scale).
out_json = dict(
    scale=10**-40,
    out=y.ravel().tolist(),
    label=int(np.argmax(y)),
)
out_json

{'scale': 1e-40,
 'out': [-2.9271724224090576,
  -5.174396514892578,
  4.3266096115112305,
  8.861394882202148,
  -10.594989776611328,
  -6.108702659606934,
  -14.946420669555664,
  20.25265121459961,
  1.7481048107147217,
  6.40861701965332],
 'label': 7}

In [19]:
import json

In [20]:
# Persist the fixed-point circuit inputs next to the notebook.
with open("mnist_latest_input.json", mode="w") as input_file:
    json.dump(in_json, fp=input_file)

In [22]:
# Persist the expected circuit output next to the notebook.
with open("mnist_latest_output.json", mode="w") as output_file:
    json.dump(out_json, fp=output_file)