This notebook shows how the runtime performance of the `tfe.protocol.pond` protocol can be improved by splitting the computation into an offline and an online phase. It assumes some understanding of how the protocol works at the cryptographic level, at the very least the notion of triples, and is intended for intermediate to advanced users.

# Setup

In [None]:
import tensorflow as tf
import tf_encrypted as tfe

# Computation

As a first step in defining our computation we have to select which triple source to use for the Pond protocol. By default this is `OnlineTripleSource`, but here we want to use the queued version instead. Below we leave the option to select either.

In [None]:
# triple_source = tfe.protocol.pond.OnlineTripleSource("server2")
# tfe.set_protocol(tfe.protocol.Pond("server0", "server1", triple_source=triple_source))

In [None]:
triple_source = tfe.protocol.pond.QueuedOnlineTripleSource("server0", "server1", "server2", capacity=10)
tfe.set_protocol(tfe.protocol.Pond("server0", "server1", triple_source=triple_source))

For our actual computation we are interested in two values, `y` and `w`, as defined next.

In [None]:
c = tfe.define_private_input("coefficients-provider", lambda: tf.constant([1,2,3,4,5,6,7,8,9,10], shape=[10, 1]))

x = tfe.define_private_input("data-provider", lambda: tf.fill([1, 10], 1))
y = tfe.matmul(x, c).reveal()

v = tfe.define_private_input("data-provider", lambda: tf.fill([1, 10], 2))
w = tfe.matmul(v, c).reveal()
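As a sanity check on what the two revealed values should look like, the same products can be worked out in plain Python with no encryption involved. The names `coefficients`, `ones_row`, `twos_row`, and `dot` are introduced here only for illustration. One would expect the revealed `y` and `w` to be close to these numbers; Pond computes on fixed-point encodings, so the exact output formatting may differ.

```python
# Plain-Python reference for the two matrix products above.
coefficients = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # the 10x1 column held in `c`
ones_row = [1] * 10                              # the 1x10 row held in `x`
twos_row = [2] * 10                              # the 1x10 row held in `v`


def dot(row, column):
    """Multiply a 1xN row by an Nx1 column, giving a single number."""
    return sum(r * k for r, k in zip(row, column))


expected_y = dot(ones_row, coefficients)
expected_w = dot(twos_row, coefficients)
print(expected_y, expected_w)  # 55 110
```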

# Offline and Online Execution

In [None]:
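# Session used by every `sess.run` call in the rest of the notebook
sess = tfe.Session()

def print_triple_status():
    """Print the current size and the capacity of each triple queue."""

    # Only the queued triple source exposes queues
    if not hasattr(triple_source, "queues"):
        return

    for queue in triple_source.queues:
        print("{:40}: {:>2} / {:>2}".format(
            queue.name,
            sess.run(queue.size()),  # size() is a graph op, so evaluate it
            triple_source.capacity)
        )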

With the queued triple source the computation is split into two phases:

- an *offline* phase where triples are generated and distributed by a trusted third-party

- an *online* phase where the actual values are computed by consuming triples previously generated, and which does not involve the trusted third-party

The most important thing is to always maintain a correspondence between these two phases, so that when we run an online computation we consume triples that were made to match every node in the subgraph that defines it.
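The split into two phases can be sketched with the textbook Beaver-triple multiplication on additive shares. This is a minimal illustration, not Pond's implementation: the modulus `Q`, the two-party scalar setting, and the helper names `share`, `reconstruct`, `generate_triple`, and `multiply` are all assumptions made for the example.

```python
import random

Q = 2**61 - 1  # illustrative prime modulus; an assumption, not Pond's ring


def share(value):
    """Split `value` into two additive shares modulo Q."""
    s0 = random.randrange(Q)
    return s0, (value - s0) % Q


def reconstruct(s0, s1):
    """Recombine two additive shares."""
    return (s0 + s1) % Q


def generate_triple():
    """Offline phase: a trusted dealer shares random a, b and their product."""
    a = random.randrange(Q)
    b = random.randrange(Q)
    return share(a), share(b), share(a * b % Q)


def multiply(x_sh, y_sh, triple):
    """Online phase: multiply shared x and y by consuming exactly one triple."""
    a_sh, b_sh, ab_sh = triple
    # The parties open x - a and y - b; a and b act as one-time masks.
    alpha = reconstruct((x_sh[0] - a_sh[0]) % Q, (x_sh[1] - a_sh[1]) % Q)
    beta = reconstruct((y_sh[0] - b_sh[0]) % Q, (y_sh[1] - b_sh[1]) % Q)
    # x*y = ab + alpha*b + beta*a + alpha*beta; only one party adds alpha*beta.
    z0 = (ab_sh[0] + alpha * b_sh[0] + beta * a_sh[0] + alpha * beta) % Q
    z1 = (ab_sh[1] + alpha * b_sh[1] + beta * a_sh[1]) % Q
    return z0, z1


x_sh, y_sh = share(6), share(7)
triple = generate_triple()           # offline: needs the dealer
z_sh = multiply(x_sh, y_sh, triple)  # online: no dealer involved
print(reconstruct(*z_sh))  # 42
```

The result is only correct because the `a`, `b`, and `ab` used online all come from the same triple, which is the correspondence described above.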

To get started, let us first update our cache. To generate triples for this we simply ask the triple source, passing in the fetch we intend to run later. This in itself does not run anything; it only figures out how to generate triples for the particular fetch.

In [None]:
print_triple_status()

Next we run the new fetch to actually generate the triples.

In [None]:
print_triple_status()

And finally we run our online computation, consuming what we just generated.

In [None]:
print_triple_status()

One way of maintaining the correspondence between the offline and online phases is to alternate between them as we have just done, first running the offline computation and then immediately running the online computation.

But we can also work with larger sequences as shown next, generating triples for several fetches before running any online computation. Be careful not to overdo this though, as TensorFlow will block once a queue is full and a further enqueue would exceed its capacity.

In [None]:
y_triples = triple_source.generate_triples(y)
w_triples = triple_source.generate_triples(w)

In [None]:
sess.run(y_triples, tag='y_triples')
sess.run(w_triples, tag='w_triples')

print_triple_status()

In [None]:
res_y = sess.run(y, tag='y')
res_w = sess.run(w, tag='w')
print(res_y, res_w)

print_triple_status()
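The blocking behaviour of a full queue can be illustrated with Python's standard `queue.Queue` as a stand-in for the TensorFlow queues. This is only an analogy: `toy_queue`, `toy_capacity`, and `try_generate` are hypothetical names, and the timeout exists solely so that the sketch terminates instead of waiting forever.

```python
import queue

toy_capacity = 2
toy_queue = queue.Queue(maxsize=toy_capacity)  # bounded FIFO queue


def try_generate(item, timeout=0.1):
    """Try to enqueue `item`; return False if the queue stays full."""
    try:
        toy_queue.put(item, timeout=timeout)
        return True
    except queue.Full:
        return False


print(try_generate("triple-0"))  # True
print(try_generate("triple-1"))  # True
print(try_generate("triple-2"))  # False: the queue is at capacity
toy_queue.get()                  # an online step consumes one triple
print(try_generate("triple-2"))  # True: there is room again
```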

Instead of running a sequence of fetches as above, one might be tempted to use either `sess.run([y_triples, w_triples])` or `sess.run([y, w])`. This is fine as long as it does not introduce non-determinism into the computation, which could break the correspondence we have to maintain.

For instance, if we had used `tfe.cache(c)` instead of `tfe.cache(tfe.mask(c))`, there would be an overlap in the computations of `y` and `w`, and different evaluation orders could then break the correspondence and give wrong results. To force this to happen you can modify the computation earlier and change the evaluation order of `y` and `w` as done below: the offline phase runs the sequence `[y, w]` but the online phase runs `[w, y]`.

In [None]:
# sess.run(y_triples)
# sess.run(w_triples)

# res_w = sess.run(w)
# res_y = sess.run(y)
# print(res_y, res_w)
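Why a changed evaluation order can give wrong results may be easier to see in a toy model. The sketch below ignores secret sharing entirely and only tracks which random values get paired up. It assumes, purely for illustration, that the mask for the shared operand sits in one queue used by both computations while the products sit in per-computation queues; the names `mask_queue`, `a_queues`, `ab_queues`, `generate`, and `online_multiply` are hypothetical and do not reflect the library's internals.

```python
import random
from collections import deque

Q = 2**61 - 1  # illustrative prime modulus
rng = random.Random(0)

mask_queue = deque()                      # masks b for the shared operand
a_queues = {"y": deque(), "w": deque()}   # masks a, one queue per computation
ab_queues = {"y": deque(), "w": deque()}  # products a*b, one queue per computation


def generate(name):
    """Offline: enqueue a fresh (a, b, a*b) spread over the three queues."""
    a, b = rng.randrange(1, Q), rng.randrange(1, Q)
    a_queues[name].append(a)
    mask_queue.append(b)
    ab_queues[name].append(a * b % Q)


def online_multiply(name, left, right):
    """Online: multiply using whatever sits at the front of each queue."""
    a = a_queues[name].popleft()
    b = mask_queue.popleft()
    ab = ab_queues[name].popleft()
    alpha, beta = (left - a) % Q, (right - b) % Q
    return (ab + alpha * b + beta * a + alpha * beta) % Q


# Same order offline and online: results are correct.
generate("y"); generate("w")
print(online_multiply("y", 3, 5), online_multiply("w", 4, 5))  # 15 20

# Offline order [y, w] but online order [w, y]: each product is paired
# with the other computation's mask, so both results come out wrong.
generate("y"); generate("w")
bad_w = online_multiply("w", 4, 5)
bad_y = online_multiply("y", 3, 5)
print(bad_w == 20, bad_y == 15)  # False False
```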

However, we can safely streamline our process and generate triples for the next run while executing the current one. This works because of the `tf.queue.FIFOQueue`s used by the triple source.

In [None]:
sess.run(y_triples)

res, _ = sess.run([y, y_triples], tag='streamlined')
print(res)

res, _ = sess.run([y, y_triples], tag='streamlined')
print(res)

res = sess.run(y)
print(res)
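The role of first-in, first-out ordering in the streamlined runs can be sketched with a plain `deque`. The model below is an assumption-laden simplification with hypothetical names (`run_pipeline`, the `"triple-N"` labels): it mirrors the sequence above of one priming run, two combined runs, and one final online-only run, and shows that the online step always receives the oldest triple regardless of whether the enqueue or the dequeue happens first within a combined run.

```python
from collections import deque


def run_pipeline(enqueue_first):
    """Simulate priming, two combined runs, and one final online-only run."""
    pending = deque(["triple-0"])  # primed by one offline run
    consumed = []
    for step in (1, 2):
        if enqueue_first:
            pending.append(f"triple-{step}")    # offline part of the run
            consumed.append(pending.popleft())  # online part of the run
        else:
            consumed.append(pending.popleft())
            pending.append(f"triple-{step}")
    consumed.append(pending.popleft())  # final run only consumes
    return consumed, len(pending)


print(run_pipeline(enqueue_first=True))
print(run_pipeline(enqueue_first=False))
# Both print (['triple-0', 'triple-1', 'triple-2'], 0)
```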