Skip to content

Commit

Permalink
Merge pull request #4 from Chen-Junbao/versa
Browse files Browse the repository at this point in the history
feature: verification
  • Loading branch information
Chen-Junbao committed Aug 9, 2022
2 parents a259d30 + 58ab803 commit 2984ff3
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 31 deletions.
55 changes: 39 additions & 16 deletions entities/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def handle(self) -> None:
class MaskingRequestHandler(socketserver.BaseRequestHandler):
U_2_num = 0
masked_gradients_list = []
verification_gradients_list = []
U_3 = []

def handle(self) -> None:
Expand All @@ -71,10 +72,12 @@ def handle(self) -> None:

self.U_3.append(msg[0])
self.masked_gradients_list.append(msg[1])
self.verification_gradients_list.append(msg[2])

received_num = len(self.U_3)

logging.info("[%d/%d] | received user %s's masked gradients", received_num, self.U_2_num, id)
logging.info("[%d/%d] | received user %s's masked gradients and verification gradients",
received_num, self.U_2_num, id)


class ConsistencyRequestHandler(socketserver.BaseRequestHandler):
Expand Down Expand Up @@ -193,16 +196,18 @@ def send(self, msg: bytes, host: str, port: int):

def unmask(self, shape: tuple) -> np.ndarray:
"""Unmasks gradients by reconstructing random vectors and private mask vectors.
Then, generates verification gradients by reconstructing random vectors and private mask vectors.
Args:
shape (tuple): the shape of the raw gradients.
Returns:
np.ndarray: the sum of the raw gradients.
Tuple[np.ndarray, np.ndarray]: the sum of the raw gradients and verification gradients.
"""

# reconstruct random vectors p_v_u
recon_random_vec_list = []
# reconstruct random vectors p_v_u_0 and p_u_v_1
recon_random_vec_0_list = []
recon_random_vec_1_list = []
for u in SecretShareRequestHandler.U_2:
if u not in MaskingRequestHandler.U_3:
# the user drops out, reconstruct its private keys and then generate the corresponding random vectors
Expand All @@ -211,26 +216,44 @@ def unmask(self, shape: tuple) -> np.ndarray:
shared_key = KA.agree(priv_key, SignatureRequestHandler.ka_pub_keys_map[v]["s_pk"])

random.seed(shared_key)
rs = np.random.RandomState(random.randint(0, 2**32 - 1))
s_u_v = random.randint(0, 2**32 - 1)

# expand s_u_v into two random vectors
rs = np.random.RandomState(s_u_v | 0)
p_u_v_0 = rs.random(shape)
rs = np.random.RandomState(s_u_v | 1)
p_u_v_1 = rs.random(shape)

if int(u) > int(v):
recon_random_vec_list.append(rs.random(shape))
recon_random_vec_0_list.append(p_u_v_0)
recon_random_vec_1_list.append(p_u_v_1)
else:
recon_random_vec_list.append(-rs.random(shape))
recon_random_vec_0_list.append(-p_u_v_0)
recon_random_vec_1_list.append(-p_u_v_1)

# reconstruct private mask vectors p_u
recon_priv_vec_list = []
# reconstruct private mask vectors p_u_0 and p_u_1
recon_priv_vec_0_list = []
recon_priv_vec_1_list = []
for u in MaskingRequestHandler.U_3:
random_seed = SS.recon(UnmaskingRequestHandler.random_seed_shares_map[u])
rs = np.random.RandomState(random_seed)
priv_mask_vec = rs.random(shape)
rs = np.random.RandomState(random_seed | 0)
priv_mask_vec_0 = rs.random(shape)
rs = np.random.RandomState(random_seed | 1)
priv_mask_vec_1 = rs.random(shape)

recon_priv_vec_list.append(priv_mask_vec)
recon_priv_vec_0_list.append(priv_mask_vec_0)
recon_priv_vec_1_list.append(priv_mask_vec_1)

masked_gradients = np.sum(np.array(MaskingRequestHandler.masked_gradients_list), axis=0)
recon_priv_vec = np.sum(np.array(recon_priv_vec_list), axis=0)
recon_random_vec = np.sum(np.array(recon_random_vec_list), axis=0)
recon_priv_vec_0 = np.sum(np.array(recon_priv_vec_0_list), axis=0)
recon_random_vec_0 = np.sum(np.array(recon_random_vec_0_list), axis=0)

output = masked_gradients - recon_priv_vec_0 + recon_random_vec_0

verification_gradients = np.sum(np.array(MaskingRequestHandler.verification_gradients_list), axis=0)
recon_priv_vec_1 = np.sum(np.array(recon_priv_vec_1_list), axis=0)
recon_random_vec_1 = np.sum(np.array(recon_random_vec_1_list), axis=0)

output = masked_gradients - recon_priv_vec + recon_random_vec
verification = verification_gradients - recon_priv_vec_1 + recon_random_vec_1

return output
return output, verification
50 changes: 39 additions & 11 deletions entities/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def listen_ciphertexts(self):
sock.close()

def mask_gradients(self, gradients: np.ndarray, host: str, port: int):
"""Masks user's own gradients and sends them to the server.
"""Masks user's own gradients and generates corresponding verification gradients. Then, sends them to the server.
Args:
gradients (np.ndarray): user's raw gradients.
Expand All @@ -168,12 +168,16 @@ def mask_gradients(self, gradients: np.ndarray, host: str, port: int):

U_2 = list(self.ciphertexts.keys())

# generate user's own private mask vector p_u
rs = np.random.RandomState(self.__random_seed)
priv_mask_vec = rs.random(gradients.shape)
# generate user's own private mask vector p_u_0 and p_u_1
rs = np.random.RandomState(self.__random_seed | 0)
priv_mask_vec_0 = rs.random(gradients.shape)
rs = np.random.RandomState(self.__random_seed | 1)
priv_mask_vec_1 = rs.random(gradients.shape)

# generate random vectors p_u_v for each user
random_vec_list = []
# generate random vectors p_u_v_0 and p_u_v_1 for each user
random_vec_0_list = []
random_vec_1_list = []
alpha = 0
for v in U_2:
if v == self.id:
continue
Expand All @@ -182,15 +186,34 @@ def mask_gradients(self, gradients: np.ndarray, host: str, port: int):
shared_key = KA.agree(self.__s_sk, v_s_pk)

random.seed(shared_key)
rs = np.random.RandomState(random.randint(0, 2**32 - 1))
s_u_v = random.randint(0, 2**32 - 1)
alpha = (alpha + s_u_v) % (2 ** 32)

# expand s_u_v into two random vectors
rs = np.random.RandomState(s_u_v | 0)
p_u_v_0 = rs.random(gradients.shape)
rs = np.random.RandomState(s_u_v | 1)
p_u_v_1 = rs.random(gradients.shape)
if int(self.id) > int(v):
random_vec_list.append(rs.random(gradients.shape))
random_vec_0_list.append(p_u_v_0)
random_vec_1_list.append(p_u_v_1)
else:
random_vec_list.append(-rs.random(gradients.shape))
random_vec_0_list.append(-p_u_v_0)
random_vec_1_list.append(-p_u_v_1)

masked_gradients = gradients + priv_mask_vec + np.sum(np.array(random_vec_list), axis=0)
# expand α into two random vectors
alpha = 10000
rs = np.random.RandomState(alpha | 0)
self.__a = rs.random(gradients.shape)
rs = np.random.RandomState(alpha | 1)
self.__b = rs.random(gradients.shape)

msg = pickle.dumps([self.id, masked_gradients])
verification_code = self.__a * gradients + self.__b

masked_gradients = gradients + priv_mask_vec_0 + np.sum(np.array(random_vec_0_list), axis=0)
verification_gradients = verification_code + priv_mask_vec_1 + np.sum(np.array(random_vec_1_list), axis=0)

msg = pickle.dumps([self.id, masked_gradients, verification_gradients])

# send the masked gradients to the server
self.send(msg, host, port)
Expand Down Expand Up @@ -261,3 +284,8 @@ def unmask_gradients(self, host: str, port: str):
msg = pickle.dumps([self.id, priv_key_shares_map, random_seed_shares_map])

self.send(msg, host, port)

def verify(self, output_gradients, verification_gradients, num_U_3):
gradients_prime = self.__a * output_gradients + num_U_3 * self.__b

return ((gradients_prime - verification_gradients) < np.full(output_gradients.shape, 1e-6)).all()
15 changes: 11 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def unmasking(shape: tuple) -> np.ndarray:
shape (tuple): the shape of the raw gradients.
Returns:
np.ndarray: the sum of the raw gradients.
Tuple[np.ndarray, np.ndarray]: the sum of the raw gradients and verification gradients.
"""

server = entities["server"]
Expand All @@ -303,9 +303,9 @@ def unmasking(shape: tuple) -> np.ndarray:
if len(UnmaskingRequestHandler.U_5) >= t:
logging.info("{} users have sent shares".format(len(UnmaskingRequestHandler.U_5)))

output = server.unmask(shape)
output, verification = server.unmask(shape)

return output
return output, verification

else:
# the number of the received messages is less than the threshold value for SecretSharing, abort
Expand Down Expand Up @@ -384,7 +384,7 @@ def unmasking(shape: tuple) -> np.ndarray:

print("{:=^80s}".format("Finish Consistency Check"))

output = unmasking(shape)
output, verification = unmasking(shape)
if output is None:
logging.error("insufficient shares received by the server!")

Expand All @@ -395,3 +395,10 @@ def unmasking(shape: tuple) -> np.ndarray:
assert ((np.sum(np.array(input_gradients), axis=0) - output) < np.full(shape, 1e-6)).all()

print("{:=^80s}".format("Finish Secure Aggregation"))

for u in U_3:
if not entities[u].verify(output, verification, len(U_3)):
logging.error("verification failed!")
sys.exit(1)

print("{:=^80s}".format("Finish Verification"))

0 comments on commit 2984ff3

Please sign in to comment.