## Private Prediction Client

In [None]:
import numpy as np
import tensorflow as tf
import tf_encrypted as tfe

from tensorflow.keras.datasets import mnist

In [None]:
# Height and width, in pixels, of every MNIST image.
img_rows = 28
img_cols = 28

# Per-sample layout: rows x cols x a single channel.
# (`input_shape` is rebound later in this notebook when the client is set up.)
input_shape = (img_rows, img_cols, 1)

# Load the dataset, which arrives already divided into train and test splits.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add the explicit channel axis; -1 keeps the sample count of each split.
x_train = x_train.reshape((-1,) + input_shape)
x_test = x_test.reshape((-1,) + input_shape)

# Cast to float32 and divide by 255.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

## Set up `tfe.serving.QueueClient`

In [None]:
# Load the remote configuration from the file at /tmp/tfe.config.
# NOTE(review): presumably this file is produced by the serving side — confirm.
config = tfe.RemoteConfig.load("/tmp/tfe.config")

# Register the loaded configuration with TF Encrypted, then select SecureNN
# as the protocol. Both calls change library-wide state, so this cell has to
# run before the client and session cells.
tfe.set_config(config)
tfe.set_protocol(tfe.protocol.SecureNN())

In [None]:
# Shapes passed to the queue client constructed in the next cell.
# This rebinds `input_shape`, replacing the (img_rows, img_cols, 1) value
# assigned during data preparation.
# NOTE(review): (1, 784) looks like a single flattened 28x28 image and
# (1, 10) like one score per digit class — confirm against the served model.
input_shape = (1, 784)
output_shape = (1, 10)

In [None]:
# Build the queue client with the input/output shapes declared above; the
# query loop later calls `client.run` on it once per image.
client = tfe.serving.QueueClient(
    input_shape=input_shape,
    output_shape=output_shape)

In [None]:
# Create a TF Encrypted session from the loaded configuration; it is handed
# to `client.run` for every query.
sess = tfe.Session(config=config)

## Query Model

In [None]:
# User-adjustable: how many test samples to send to the model.
num_tests = 3

# Take the first `num_tests` test images together with their true labels.
images = x_test[:num_tests]
expected_labels = y_test[:num_tests]

In [None]:
# Send each selected image to the model and report how it was classified.
for image, expected_label in zip(images, expected_labels):
    # Flatten the image into the single-row layout passed to the client.
    flat_image = image.reshape(1, 784)
    res = client.run(sess, flat_image)

    # Index of the largest value in the returned result.
    predicted_label = np.argmax(res)

    if expected_label == predicted_label:
        verdict = "correctly"
    else:
        verdict = "wrongly"

    message = "The image had label {} and was {} classified as {}".format(
        expected_label, verdict, predicted_label)
    print(message)