Private Predictions for MNIST
This example illustrates how TF Encrypted can be used to perform private predictions using a simple neural network on the MNIST data set. It also shows how to integrate with ordinary TensorFlow, seamlessly linking local computations with secure computations.
Our scenario is split into two phases.
In the first phase, a model owner trains a model locally and sends encryptions of the resulting weights to the three compute servers used by the default Pond secure computation protocol. The training is done using an ordinary TensorFlow computation, and the encrypted weights are cached on the servers for repeated use.
In the second phase, a prediction client sends an encryption of its input to the servers, who perform a secure computation over the weights and input to arrive at an encrypted prediction, which is finally sent back to the client and decrypted. The client also uses ordinary TensorFlow computations to apply pre- and post-processing.
The goal of the example is to show that the computations done by the servers can be performed entirely on encrypted data, without the servers ever being able to decrypt any of the values. For this reason, we can treat the weights as a private input from the model owner and the prediction input as a private input from the prediction client.
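The claim that servers can compute on values they cannot read rests on secret sharing. As a rough, hypothetical illustration in plain Python (not TF Encrypted's API; the `Z_{2^64}` ring and the helpers `share` and `reconstruct` are assumptions made for this sketch), additive secret sharing splits each value into random-looking shares that only reveal it when combined, yet still allow additions to be carried out share by share:

```python
import secrets
from typing import Tuple

MODULUS = 2**64  # toy ring Z_{2^64}; an assumption for this sketch


def share(value: int) -> Tuple[int, int]:
    """Split an integer into two additive shares that sum to it mod MODULUS."""
    r = secrets.randbelow(MODULUS)  # uniformly random, so each share alone reveals nothing
    return r, (value - r) % MODULUS


def reconstruct(s0: int, s1: int) -> int:
    """Recombine two shares into the plaintext value."""
    return (s0 + s1) % MODULUS


w0, w1 = share(7)   # e.g. a weight from the model owner
x0, x1 = share(35)  # e.g. an input value from the prediction client

# Addition needs no communication: each server adds the shares it holds.
z0 = (w0 + x0) % MODULUS  # computed by the first server
z1 = (w1 + x1) % MODULUS  # computed by the second server

print(reconstruct(z0, z1))  # → 42
```

Neither server ever sees 7, 35, or 42; only whoever collects both result shares learns the sum.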
The code is structured around two classes, ModelOwner and PredictionClient, one for each party that provides a private input.
ModelOwner builds a data pipeline for training data using the TensorFlow Dataset API and performs training using TensorFlow's built-in components.
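```python
import tf_encrypted as tfe


class ModelOwner:

    @tfe.local_computation
    def provide_weights(self):
        """Train a model locally and return the resulting weights."""
        # training with ordinary TensorFlow (helper methods omitted here)
        training_data = self._build_data_pipeline()
        weights = self._build_training_graph(training_data)
        return weights
```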
PredictionClient likewise builds a data pipeline for prediction inputs but also a post-processing computation that applies an argmax on the decrypted result before printing to the screen.
```python
class PredictionClient:

    @tfe.local_computation
    def provide_input(self) -> tf.Tensor:
        """Prepare input data for prediction."""
        prediction_input, expected_result = self._build_data_pipeline().get_next()
        prediction_input = tf.reshape(
            prediction_input, shape=(self.BATCH_SIZE, ModelOwner.FLATTENED_DIM))
        return prediction_input

    @tfe.local_computation
    def receive_output(self, logits: tf.Tensor) -> tf.Operation:
        prediction = tf.argmax(logits, axis=1)
        op = tf.print("Result", prediction, summarize=self.BATCH_SIZE)
        return op
```
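The client's pre- and post-processing is ordinary tensor manipulation done in the clear. A minimal plain-Python sketch of the same two steps, assuming 28x28 MNIST images flattened to 784 values (our assumption for ModelOwner.FLATTENED_DIM) and using hypothetical helpers `flatten_batch` and `argmax_rows` as stand-ins for `tf.reshape` and `tf.argmax`:

```python
from typing import List

FLATTENED_DIM = 28 * 28  # assumption: one 28x28 MNIST image flattened


def flatten_batch(images: List[List[List[float]]]) -> List[List[float]]:
    """Pre-processing: turn a batch of 28x28 images into rows of 784 values."""
    return [[pixel for row in image for pixel in row] for image in images]


def argmax_rows(logits: List[List[float]]) -> List[int]:
    """Post-processing: index of the largest logit in each row (the predicted digit)."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


batch = [[[0.0] * 28 for _ in range(28)] for _ in range(2)]  # two blank images
flat = flatten_batch(batch)
print(len(flat), len(flat[0]))  # → 2 784

print(argmax_rows([[0.1, 2.5, -1.0], [3.0, 0.0, 0.2]]))  # → [1, 0]
```

Because these steps run locally on the client, they can use any ordinary code; only the values between `provide_input` and `receive_output` stay encrypted.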
Instances of ModelOwner and PredictionClient are then linked together in a secure computation performing a prediction, treating both the weights and the prediction input as private values.
```python
model_owner = ModelOwner(player_name="model-owner")
prediction_client = PredictionClient(player_name="prediction-client")

# get model weights from model owner
params = model_owner.provide_weights()

# get prediction input from client
x = prediction_client.provide_input()

with tfe.protocol.SecureNN():
    model = tfe.keras.Sequential()
    model.add(tfe.keras.layers.Dense(512, batch_input_shape=x.shape))
    model.add(tfe.keras.layers.Activation('relu'))
    model.add(tfe.keras.layers.Dense(10))
    model.set_weights(params)
    logits = model(x)

# send prediction output back to client
prediction_op = prediction_client.receive_output(logits)
```
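Additions on shared values are local, but the dense layers also multiply private weights by private inputs, and that needs help from an extra party. As a toy, hypothetical sketch in plain Python integers (not TF Encrypted code; the helper names and the `Z_{2^64}` ring are assumptions), the classic Beaver-triple technique lets two share holders multiply using a random triple supplied by a third party. We understand the third server in Pond to play a comparable triple-producing role, though the real protocol differs in detail.

```python
import secrets
from typing import List, Tuple

MODULUS = 2**64  # toy ring Z_{2^64}; an assumption for this sketch
Shares = Tuple[int, int]


def share(value: int) -> Shares:
    """Split an integer into two additive shares modulo MODULUS."""
    r = secrets.randbelow(MODULUS)
    return r, (value - r) % MODULUS


def reconstruct(s0: int, s1: int) -> int:
    return (s0 + s1) % MODULUS


def beaver_triple() -> Tuple[Shares, Shares, Shares]:
    """Hypothetical third party: shares of random a, b and of c = a * b."""
    a = secrets.randbelow(MODULUS)
    b = secrets.randbelow(MODULUS)
    return share(a), share(b), share(a * b % MODULUS)


def mul(x: Shares, y: Shares) -> Shares:
    """Multiply two shared values with one Beaver triple."""
    (a0, a1), (b0, b1), (c0, c1) = beaver_triple()
    # The two servers exchange shares to open e = x - a and f = y - b.
    # These are safe to reveal because a and b are uniformly random masks.
    e = reconstruct((x[0] - a0) % MODULUS, (x[1] - a1) % MODULUS)
    f = reconstruct((y[0] - b0) % MODULUS, (y[1] - b1) % MODULUS)
    # x * y = c + e*b + f*a + e*f (mod MODULUS); only server 0 adds e*f.
    z0 = (c0 + e * b0 + f * a0 + e * f) % MODULUS
    z1 = (c1 + e * b1 + f * a1) % MODULUS
    return z0, z1


def dot(xs: List[Shares], ys: List[Shares]) -> Shares:
    """Shared dot product: one mul per pair, then local additions."""
    z0, z1 = 0, 0
    for x, y in zip(xs, ys):
        p0, p1 = mul(x, y)
        z0, z1 = (z0 + p0) % MODULUS, (z1 + p1) % MODULUS
    return z0, z1


w_sh = [share(w) for w in [2, 3, 4]]  # model owner's private weights
x_sh = [share(v) for v in [5, 6, 7]]  # client's private input
print(reconstruct(*dot(w_sh, x_sh)))  # → 56
```

A dense layer is many such dot products. Real weights and inputs are floats, so an actual protocol must also encode them (for example as fixed-point integers), and non-linear layers such as ReLU need further sub-protocols; both are beyond this sketch.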
Finally, the secure prediction is executed using a tfe.Session, following the typical TensorFlow pattern:

```python
with tfe.Session() as sess:
    sess.run(prediction_op)
```
Make sure to have the training and test data sets downloaded before running the example:
which will place the converted files in the
To then run locally use:
or remotely using:
```sh
python3 examples/mnist/run.py config.json
```
See more details in the documentation.