Skip to content

Commit

Permalink
small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mortendahl committed Jul 11, 2019
1 parent ba89373 commit 700d973
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 21 deletions.
2 changes: 1 addition & 1 deletion examples/federated-learning/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def update_model(self, *grads):
print_loss = tf.print("loss", loss)
print_expected = tf.print("expect", y, summarize=50)
print_result = tf.print("result", y_hat, summarize=50)
return print_loss, print_expected, print_result
return tf.group(print_loss, print_expected, print_result)


class DataOwner:
Expand Down
1 change: 0 additions & 1 deletion examples/logistic/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Provide classes to perform private training and private prediction with
logistic regression"""
import tensorflow as tf
# pylint: disable=redefined-outer-name
import tf_encrypted as tfe


Expand Down
2 changes: 1 addition & 1 deletion examples/logistic/prediction_joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
result_receiver = prediction_client_0

x_0 = prediction_client_0.provide_input()
x_1 = prediction_client_0.provide_input()
x_1 = prediction_client_1.provide_input()
x = tfe.concat([x_0, x_1], axis=1)

y = model.forward(x)
Expand Down
13 changes: 9 additions & 4 deletions examples/simple-average/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,26 @@
tfe.set_config(config)
tfe.set_protocol(tfe.protocol.Pond())

@tfe.local_computation
@tfe.local_computation(name_scope='provide_input')
def provide_input() -> tf.Tensor:
# pick random tensor to be averaged
return tf.random_normal(shape=(10,))

@tfe.local_computation('result-receiver')
@tfe.local_computation('result-receiver', name_scope='receive_output')
def receive_output(average: tf.Tensor) -> tf.Operation:
# simply print average
return tf.print("Average:", average)


if __name__ == '__main__':
# get input from inputters as private values
inputs = [provide_input(player_name="inputter-{}".format(i)) # pylint: disable=unexpected-keyword-arg
for i in range(5)]
inputs = [
provide_input(player_name='inputter-0'), # pylint: disable=unexpected-keyword-arg
provide_input(player_name='inputter-1'), # pylint: disable=unexpected-keyword-arg
provide_input(player_name='inputter-2'), # pylint: disable=unexpected-keyword-arg
provide_input(player_name='inputter-3'), # pylint: disable=unexpected-keyword-arg
provide_input(player_name='inputter-4'), # pylint: disable=unexpected-keyword-arg
]

# sum all inputs and divide by count
result = tfe.add_n(inputs) / len(inputs)
Expand Down
27 changes: 13 additions & 14 deletions tf_encrypted/protocol/pond/pond.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,8 @@ def helper(v: tf.Tensor) -> "PondPublicTensor":
def local_computation(
self,
player_name=None,
**kwargs):
**kwargs
):
"""Annotate a function `compute_func` for local computation.
This decorator can be used to pin a function's code to a specific player's
Expand Down Expand Up @@ -558,7 +559,7 @@ def define_local_computation(
computation_fn,
arguments=None,
apply_scaling=True,
name=None,
name_scope=None,
masked=False,
factory=None,
):
Expand All @@ -568,7 +569,7 @@ def define_local_computation(
:param player: Who performs the computation and gets to see the values in
plaintext.
:param apply_scaling: Whether or not to scale the outputs.
:param name: Optional name to give to this node in the graph.
:param name_scope: Optional name to give to this node in the graph.
:param masked: Whether or not to produce masked outputs.
:param factory: Backing tensor type to use for outputs.
""" # noqa:E501
Expand All @@ -582,7 +583,7 @@ def define_local_computation(
def share_output(v: tf.Tensor):
assert v.shape.is_fully_defined(), ("Shape of return value '{}' on '{}' "
"not fully defined").format(
name if name else "",
name_scope if name_scope else "",
player.name,
)

Expand Down Expand Up @@ -624,7 +625,7 @@ def reconstruct_input(x):
raise TypeError(("Don't know how to process input argument "
"of type {}").format(type(x)))

with tf.name_scope(name if name else "local-computation"):
with tf.name_scope(name_scope if name_scope else "local-computation"):

with tf.device(player.device_name):
if arguments is None:
Expand Down Expand Up @@ -655,7 +656,7 @@ def define_private_input(
player,
inputter_fn,
apply_scaling: bool = True,
name: Optional[str] = None,
name_scope: Optional[str] = None,
masked: bool = False,
factory: Optional[AbstractFactory] = None,
):
Expand All @@ -667,19 +668,17 @@ def define_private_input(
:param Union[str,Player] player: Which player owns this input.
:param bool apply_scaling: Whether or not to scale the value.
:param str name: What name to give to this node in the graph.
:param str name_scope: What name to give to this node in the graph.
:param bool masked: Whether or not to mask the input.
:param AbstractFactory factory: Which backing type to use for this input
(e.g. `int100` or `int64`).
"""
suffix = "-" + name if name else ""

return self.define_local_computation(
player=player,
computation_fn=inputter_fn,
arguments=[],
apply_scaling=apply_scaling,
name="private-input{}".format(suffix),
name_scope=name_scope if name_scope else "private-input",
masked=masked,
factory=factory,
)
Expand All @@ -689,7 +688,7 @@ def define_output(
player,
arguments,
outputter_fn,
name=None,
name_scope=None,
):
"""
Define an output for this graph.
Expand All @@ -699,15 +698,15 @@ def define_output(

def result_wrapper(*args):
op = outputter_fn(*args)
# wrap in tf.group to prevent sending back any tensors (which might hence
# be leaked)
# wrap in tf.group to prevent sending back any tensors
# (which might hence be leaked)
return tf.group(op)

return self.define_local_computation(
player=player,
computation_fn=result_wrapper,
arguments=arguments,
name="output{}".format("-" + name if name else ""),
name_scope=name_scope if name_scope else "output",
)

@property
Expand Down

0 comments on commit 700d973

Please sign in to comment.