In [None]:
import time

import numpy as np
from tensorflow.keras.models import Sequential,Model,load_model
from tensorflow.keras.layers import Input, Activation, Flatten, Dense, Dropout,Reshape
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.optimizers.schedules import InverseTimeDecay

In [None]:
from tensorflow.keras.datasets import mnist

In [None]:
# Load MNIST as two (images, labels) splits, then unpack each split into named arrays.
train_split, test_split = mnist.load_data()
img_train, label_train = train_split
img_test, label_test = test_split

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
# Flatten every 2-D image into one row vector (28*28 = 784 values for MNIST).
# The row count comes from the data rather than being hard-coded (60000 / 10000),
# so the cell keeps working when only a subset of the images is loaded.
x_train = img_train.reshape(len(img_train), -1)
x_test = img_test.reshape(len(img_test), -1)

In [None]:
# Scale raw uint8 pixel intensities [0, 255] to float32 values in [0, 1].
# Guarded on dtype so that re-running this cell does not divide the
# already-scaled arrays by 255 a second time (the original was not idempotent).
# NOTE(review): non-uint8 input is left untouched. Rerun the flatten cell first
# if the arrays need to be rebuilt.
if x_train.dtype == np.uint8:
    x_train = x_train.astype('float32') / 255
if x_test.dtype == np.uint8:
    x_test = x_test.astype('float32') / 255

In [None]:
# One-hot encode the digit labels. The class count is pinned to 10 (digits 0-9,
# matching the model's 10-unit softmax output) instead of being inferred as
# max(label) + 1. Inference would produce fewer columns for a label subset that
# lacks the highest digit.
num_classes = 10
y_train = to_categorical(label_train, num_classes=num_classes)
y_test = to_categorical(label_test, num_classes=num_classes)

In [None]:
# Side length, in pixels, of the square MNIST digit images.
img_size: int = 28

In [None]:
# Network: Dense(500) -> Dense(784) -> reshape to an image -> small CNN -> softmax.
# The first Dense layer is the one whose weights the server/client scheme below
# protects. Every layer after it is later reused by the server as `model_`.
n_pixels = img_size * img_size

model = Sequential()
# Declare the input with an explicit Input rather than passing `input_shape=` to
# the first Dense layer; that layer keyword is deprecated in recent Keras.
# Sequential does not list the InputLayer in `model.layers`, so
# model.layers[0] is still the first Dense layer.
img_input = Input(shape=(n_pixels,))
model.add(img_input)

model.add(Dense(500, activation='relu'))
model.add(Dense(n_pixels, activation='relu'))
model.add(Reshape(target_shape=(img_size, img_size, 1)))
model.add(Conv2D(32, (3, 3), activation='relu', padding='valid'))
model.add(MaxPooling2D(pool_size=(2, 2), padding='valid'))
model.add(Dropout(0.25))

model.add(Conv2D(16, (3, 3), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2), padding='same'))
model.add(Dropout(0.2))
model.add(Flatten())
model.add(Dense(32, activation='relu'))
model.add(Dense(10, activation='softmax'))

In [None]:
# SGD with Nesterov momentum and inverse-time learning-rate decay.
# The original `lr=` / `decay=` keywords are legacy Keras arguments that the
# newer optimizers no longer accept. The same schedule,
#     lr_t = 0.01 / (1 + 1e-6 * step),
# is expressed with InverseTimeDecay (decay_steps=1 gives exactly that form).
lr_schedule = InverseTimeDecay(initial_learning_rate=0.01, decay_steps=1, decay_rate=1e-6)
sgdx = SGD(learning_rate=lr_schedule, momentum=0.9, nesterov=True)

# Number of training epochs.
ntrain = 5

In [None]:
# Compile and train the model, then report the elapsed training time.
# time.perf_counter() is the monotonic high-resolution clock meant for measuring
# durations. time.time() follows the wall clock and can jump if it is adjusted.
t_start = time.perf_counter()
model.compile(optimizer=sgdx,
              loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])
H = model.fit(x_train, y_train, epochs=ntrain)
print('time: ', time.perf_counter() - t_start)

In [None]:
# Print the layer-by-layer architecture (output shapes and parameter counts).
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 500)               392500    
_________________________________________________________________
dense_1 (Dense)              (None, 784)               392784    
_________________________________________________________________
reshape (Reshape)            (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
_________________________________________________________________
dropout (Dropout)            (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 13, 13, 16)        4

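## Server

The server splits off the first `Dense` layer, pads each of its weight rows with a random row, and mixes every pair with a random $2\times2$ key `k`. The mixed weights `kw` are what the client receives.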

In [None]:
t_start = time.perf_counter()

# Split the trained network. The server sets aside the first Dense layer's kernel
# w0, shape (inputs, units), and bias b0, shape (units,). It rebuilds the
# remaining layers as `model_`, which reuses the trained layer objects of `model`.
w0, b0 = model.layers[0].get_weights()
model_ = Sequential()
for layer in model.layers[1:]:
    model_.add(layer)
# The sub-model's input is the first layer's output, (batch, units). Taking the
# width from b0 avoids relying on Layer.output_shape, which is not available in
# every Keras version.
model_.build(input_shape=(None, b0.shape[0]))

# w[i] stacks neuron i's weight row (row i of w0.T) on top of a random padding
# row of the same length, so w has shape (units, 2, inputs).
# np.random.rand(*shape) draws values in the same order as the original
# row-by-row scalar loop, so the same random stream is consumed.
w_rows = w0.T
w = np.stack([w_rows, np.random.rand(*w_rows.shape)], axis=1)

# Random 2x2 mixing key k and its inverse kd (used later to undo the mixing).
# NOTE(review): np.random is not a cryptographically secure generator, and a
# single 2x2 linear mix is not a vetted encryption scheme. Treat this as a demo.
k = np.random.rand(2, 2)
kd = np.linalg.inv(k)

# kw[i] = k @ w[i] for every neuron. matmul broadcasts k over the leading axis.
kw = k @ w
print('\nTime: ', time.perf_counter() - t_start)


Time:  0.27098965644836426


In [None]:
def Relu(x):
    """Rectified linear unit, applied element-wise with NumPy.

    Parameters
    ----------
    x : array_like
        Pre-activation values.

    Returns
    -------
    numpy.ndarray or numpy scalar
        ``x`` with every negative entry replaced by 0 (NaN is propagated).
    """
    floor = 0
    rectified = np.maximum(x, floor)
    return rectified

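## Client

The client pads its input `x_` with a random row, multiplies it by the mixed weights `kw`, and returns the resulting $2\times2$ blocks `wx` to the server.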

In [None]:
# Client-side input: one test image, kept 2-D as a batch of one row.
sample_idx = 100
x_ = x_test[sample_idx:sample_idx + 1]
x_.shape

(1, 784)

In [None]:
t_start = time.perf_counter()

# Pad each input row with a random row of the same length. x_stack[n] is
# [x_[n]; x_plus[n]], shape (2, inputs), so x_stack is (samples, 2, inputs).
# np.random.rand(*x_.shape) draws values in the same order as the original
# per-sample scalar loop.
x_plus = np.random.rand(*x_.shape)
x_stack = np.stack([x_, x_plus], axis=1)

# wx[n, i] = kw[i] @ x_stack[n].T is one 2x2 block per (sample, neuron), so wx
# has shape (samples, units, 2, 2). A single einsum replaces the per-neuron loop.
wx = np.einsum('ijk,nlk->nijl', kw, x_stack)
print('\nTime: ', time.perf_counter() - t_start)


Time:  0.0036079883575439453


In [None]:
# Shape of the client's reply: (samples, units, 2, 2).
np.shape(wx)

(1, 500, 2, 2)

## Server
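### Sau khi nhận được wx từ Client (after receiving `wx` from the client)

The server multiplies each block by $k^{-1}$ (`kd`) and keeps the `[0, 0]` entry, which is the plain first-layer pre-activation. It then adds the bias `b0`, applies ReLU, and runs the remaining layers `model_`.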

In [None]:
t_start = time.perf_counter()

# Undo the key. kd @ wx[n, i] recovers w[i] @ X.T, where X = [x; x_plus] is the
# client's padded input. Its [0, 0] entry is the plain first-layer
# pre-activation for neuron i. matmul broadcasts kd over (samples, units).
# Despite its name, wx_encode holds the *decoded* values, shape (samples, units).
wx_encode = (kd @ wx)[..., 0, 0]

# Finish the first Dense layer (bias + ReLU), then run the remaining layers.
WX_relu = Relu(wx_encode + b0)
# Keep the prediction; the original cell computed it and then discarded it.
y_pred_private = model_.predict(WX_relu)
print('\nTime: ', time.perf_counter() - t_start)

# Predicted digit for each sample.
y_pred_private.argmax(axis=1)


Time:  0.04667329788208008


In [None]:
# Baseline: run the unsplit model directly on the same input, for a timing comparison.
t_start = time.perf_counter()
y_pred_plain = model.predict(x_)
print('\nTime: ', time.perf_counter() - t_start)

# Sanity check: the split, key-mixed pipeline must reproduce the plain model's
# output for the same input, up to float32 rounding.
np.testing.assert_allclose(model_.predict(WX_relu), y_pred_plain, rtol=1e-3, atol=1e-4)

# Predicted digit for each sample.
y_pred_plain.argmax(axis=1)


Time:  0.04452943801879883
