
Private Predictions for MNIST

This example illustrates how TF Encrypted can be used to perform private predictions using a simple neural network on the MNIST data set. It also shows how to integrate with ordinary TensorFlow, seamlessly linking local computations with secure computations.

Our scenario is split into two phases.

In the first phase, a model owner trains a model locally and sends encryptions of the resulting weights to three compute servers as used by the default Pond secure computation protocol. The training is done using an ordinary TensorFlow computation and the encrypted weights are cached on the servers for repeated use.

In the second phase, a prediction client sends an encryption of its input to the servers, which perform a secure computation over the weights and input to arrive at an encrypted prediction. This prediction is finally sent back to the client and decrypted. The client also uses ordinary TensorFlow computations to apply pre- and post-processing.

The goal of the example is to show that the computations done by the servers can be performed entirely on encrypted data, with the servers at no point being able to decrypt any of the values. For this reason we can see the weights as a private input from the model owner and the prediction input as a private input from the prediction client.
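
To give some intuition for how servers can compute on values they cannot read, here is a minimal sketch of additive secret sharing, the kind of technique that protocols such as Pond build on. It is not TF Encrypted's implementation: the modulus `Q` and the helper names `share`, `reconstruct` and `add_shared` are assumptions made for this illustration, and real protocols also handle fixed-point encoding and multiplication.

```python
import secrets

Q = 2**61 - 1  # hypothetical modulus, chosen only for this illustration


def share(value):
    """Split an integer into two random-looking shares that sum to it modulo Q."""
    share0 = secrets.randbelow(Q)
    share1 = (value - share0) % Q
    return share0, share1


def reconstruct(share0, share1):
    """Recombine both shares; only a party holding both of them can do this."""
    return (share0 + share1) % Q


def add_shared(x, y):
    """Each server adds the shares it holds locally, without seeing x or y."""
    return (x[0] + y[0]) % Q, (x[1] + y[1]) % Q


weight = share(25)  # stands in for a private input from the model owner
pixel = share(17)   # stands in for a private input from the prediction client
total = add_shared(weight, pixel)
print(reconstruct(*total))  # prints 42
```

Each share on its own is a uniformly random number, so a server holding only one of them learns nothing about the underlying value.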


The code is structured around two classes, ModelOwner and PredictionClient.

ModelOwner builds a data pipeline for training data using the TensorFlow Dataset API and performs training using TensorFlow's built-in components.

class ModelOwner:

    def provide_weights(self) -> tf.Tensor:
        """Train the model locally and return the resulting weights."""
        # training is done with ordinary TensorFlow
        training_data = self._build_data_pipeline()
        weights = self._build_training_graph(training_data)
        return weights

PredictionClient likewise builds a data pipeline for prediction inputs, and additionally defines a post-processing computation that applies an argmax to the decrypted result before printing it to the screen.

class PredictionClient:

    def provide_input(self) -> tf.Tensor:
        """Prepare input data for prediction."""
        prediction_input, expected_result = self._build_data_pipeline().get_next()
        prediction_input = tf.reshape(
            prediction_input, shape=(self.BATCH_SIZE, ModelOwner.FLATTENED_DIM))
        return prediction_input

    def receive_output(self, logits: tf.Tensor) -> tf.Operation:
        """Post-process the decrypted logits and print the predicted classes."""
        prediction = tf.argmax(logits, axis=1)
        op = tf.print("Result", prediction, summarize=self.BATCH_SIZE)
        return op
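
To make the post-processing step concrete, the following plain-Python sketch mirrors what an argmax along axis 1 computes for a batch of logits. The helper name `argmax_per_row` and the logit values are made up for illustration; the example itself uses `tf.argmax` as shown above.

```python
def argmax_per_row(logits):
    """Return, for each row of scores, the index of the largest score."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


# made-up logits for a batch of two images and ten digit classes
logits = [
    [0.1, 0.3, 2.5, 0.0, -1.2, 0.4, 0.0, 0.2, 0.1, 0.3],
    [1.9, 0.2, 0.1, 0.0, 0.3, 0.1, 0.0, 4.2, 0.5, 0.6],
]
print(argmax_per_row(logits))  # prints [2, 7]
```

Each index is the digit class with the highest score for that image, which is what the client prints as its prediction.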

Instances of these are then linked together in a secure computation performing a prediction, treating both the weights and the prediction input as private values.

  model_owner = ModelOwner(player_name="model-owner")
  prediction_client = PredictionClient(player_name="prediction-client")

  # get model weights from model owner
  params = model_owner.provide_weights()
  # get prediction input from client
  x = prediction_client.provide_input()

  with tfe.protocol.SecureNN():
    model = tfe.keras.Sequential()
    model.add(tfe.keras.layers.Dense(512, batch_input_shape=x.shape))

    logits = model(x)

  # send prediction output back to client
  prediction_op = prediction_client.receive_output(logits)
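
As a rough picture of what a secure computation over private weights and a private input involves, the sketch below computes a dot product on additively shared integers using multiplication triples supplied by a helper party. This is a simplified illustration and not TF Encrypted's code: the modulus `Q`, all function names, and the toy values are assumptions, and fixed-point encoding, activations and networking are left out.

```python
import secrets

Q = 2**61 - 1  # hypothetical modulus, for illustration only


def share(value):
    """Split an integer into two additive shares modulo Q."""
    s0 = secrets.randbelow(Q)
    return s0, (value - s0) % Q


def reconstruct(shares):
    """Recombine a pair of shares into the underlying value."""
    return sum(shares) % Q


def make_triple():
    """Helper party: hand out shares of random a, b and of c = a * b."""
    a, b = secrets.randbelow(Q), secrets.randbelow(Q)
    return share(a), share(b), share(a * b % Q)


def mul_shared(x, y):
    """Multiply two shared values; only the masked values d and e are opened."""
    a, b, c = make_triple()
    d = reconstruct(((x[0] - a[0]) % Q, (x[1] - a[1]) % Q))  # x - a
    e = reconstruct(((y[0] - b[0]) % Q, (y[1] - b[1]) % Q))  # y - b
    # x * y = c + d * b + e * a + d * e, split across the two servers
    z0 = (c[0] + d * b[0] + e * a[0] + d * e) % Q
    z1 = (c[1] + d * b[1] + e * a[1]) % Q
    return z0, z1


def dot_shared(xs, ws):
    """Dot product of two lists of shared values, accumulated share-wise."""
    acc = (0, 0)
    for x, w in zip(xs, ws):
        p = mul_shared(x, w)
        acc = ((acc[0] + p[0]) % Q, (acc[1] + p[1]) % Q)
    return acc


weights = [share(w) for w in [2, 3, 4]]  # private input from the model owner
inputs = [share(x) for x in [5, 6, 7]]   # private input from the client
logit = dot_shared(inputs, weights)
print(reconstruct(logit))  # prints 56
```

The helper party that generates the triples is one way to see why a setup with two computing servers may still involve a third machine.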

Finally, the computation is executed using a tfe.Session following the typical TensorFlow pattern:

with tfe.Session() as sess:
    # run the prediction op built above; receive_output prints the result
    sess.run(prediction_op)


Before running the example, download the training and test data sets using:

python3 examples/mnist/

which will place the converted files in the ./data subdirectory.

To then run the example locally, use:

python3 examples/mnist/

or remotely using:

python3 examples/mnist/ config.json

See more details in the documentation.
