From ed3e3ffc2c0e903feb1dd7faaadce32aeedd7fa9 Mon Sep 17 00:00:00 2001 From: "Zhicong Huang (Zico)" Date: Tue, 30 Aug 2022 15:47:36 +0800 Subject: [PATCH] Fix compatibility (#864) * fix shape concatenation for tf 1.13.2 * rm tf-big dep Co-authored-by: Zhicong (Zico) Huang --- setup.py | 2 +- tf_encrypted/protocol/aby3/aby3.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index af450797..5ee842d6 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ "tensorflow >=1.12.0, <2", "numpy >=1.14.0", "pyyaml >=5.1", - "tf-big ~=0.1.0", + # "tf-big ~=0.1.0", ], extras_require={ "tf": ["tensorflow>=1.12.0,<2"], diff --git a/tf_encrypted/protocol/aby3/aby3.py b/tf_encrypted/protocol/aby3/aby3.py index 54427474..e0ee91f5 100644 --- a/tf_encrypted/protocol/aby3/aby3.py +++ b/tf_encrypted/protocol/aby3/aby3.py @@ -921,7 +921,10 @@ def _share(self, secret: AbstractTensor, share_type: str, player=None): with tf.name_scope("share"): if share_type == ShareType.ARITHMETIC or share_type == ShareType.BOOLEAN: - randoms = secret.factory.sample_uniform([2] + secret.shape) + secret_shape = secret.shape + if isinstance(secret_shape, tf.TensorShape): + secret_shape = secret_shape.as_list() + randoms = secret.factory.sample_uniform([2] + secret_shape) share0 = randoms[0] share1 = randoms[1] if share_type == ShareType.ARITHMETIC: