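# Secure Aggregation

Secure aggregation can be stated as follows: several parties each hold their own data and cooperate to compute an aggregate value (a sum, for example) without revealing their private data to one another.

Secure aggregation is an important concept in federated learning and has been studied extensively in academia. SecretFlow already uses it for gradient/weight aggregation in horizontal federated learning and for data statistics (for example, data exploration and preprocessing).

The following sections introduce the secure aggregation schemes used by SecretFlow.


## Preparation

Initialize SecretFlow.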

In [1]:
import secretflow as sf

sf.init(['alice', 'bob'], num_cpus=8, log_to_driver=True)



Prepare some test data.

In [2]:
import numpy as np

arr0, arr1 = np.random.rand(2, 3), np.random.rand(2, 3)
print('arr0:\n', arr0, '\narr1:\n', arr1)

print('Sum:\n', np.sum([arr0, arr1], axis=0))
print('Average:\n', np.average([arr0, arr1], axis=0))
print('Min:\n', np.min([arr0, arr1], axis=0))
print('Max:\n', np.max([arr0, arr1], axis=0))

arr0:
 [[0.64818372 0.33600055 0.43811926]
 [0.7835069  0.25554061 0.71970086]] 
arr1:
 [[0.11274985 0.99172087 0.56520836]
 [0.48992353 0.22364359 0.40473672]]
Sum:
 [[0.76093357 1.32772142 1.00332762]
 [1.27343043 0.47918421 1.12443758]]
Average:
 [[0.38046678 0.66386071 0.50166381]
 [0.63671522 0.2395921  0.56221879]]
Min:
 [[0.11274985 0.33600055 0.43811926]
 [0.48992353 0.22364359 0.40473672]]
Max:
 [[0.64818372 0.99172087 0.56520836]
 [0.7835069  0.25554061 0.71970086]]


Create the parties used in the rest of the demonstration.

In [3]:
alice, bob = sf.PYU('alice'), sf.PYU('bob')

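## Aggregation

SecretFlow provides several `Aggregator` implementations to choose from; each one supports sum and average.

### PPU-based secure aggregation

[PPU](../development/ppu.md) is a secure device in SecretFlow, built on [MPC](https://en.wikipedia.org/wiki/Secure_multi-party_computation). SecretFlow implements secure aggregation on top of the PPU; its usage is shown below.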
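
To give a feel for how an MPC-style sum can work, here is a minimal sketch of additive secret sharing in plain NumPy. It is only an illustration of one common MPC building block and is not necessarily the protocol the PPU runs. The modulus `R`, the fixed-point scale `SCALE`, and the helper names `share` and `reconstruct` are assumptions made for this example.

```python
import numpy as np

R = 2**32       # modulus; an illustrative choice
SCALE = 2**16   # fixed-point scale for encoding floats as integers; also illustrative


def share(x, n_shares, rng):
    """Split a float array into n_shares additive shares modulo R."""
    encoded = np.round(x * SCALE).astype(np.int64) % R
    shares = [rng.integers(0, R, size=x.shape, dtype=np.int64)
              for _ in range(n_shares - 1)]
    # The last share is chosen so that all shares add up to the encoded input.
    last = (encoded - sum(shares)) % R
    return shares + [last]


def reconstruct(shares):
    """Add the shares modulo R and decode the fixed-point result."""
    total = sum(shares) % R
    signed = np.where(total >= R // 2, total - R, total)
    return signed / SCALE


rng = np.random.default_rng(0)
x_alice, x_bob = rng.random((2, 3)), rng.random((2, 3))

alice_shares = share(x_alice, 2, rng)
bob_shares = share(x_bob, 2, rng)

# Each compute node only ever sees one random-looking share per input
# and adds the shares it holds locally.
node0 = (alice_shares[0] + bob_shares[0]) % R
node1 = (alice_shares[1] + bob_shares[1]) % R

result = reconstruct([node0, node1])
print(result)
print(x_alice + x_bob)
```

The reconstructed sum matches the plaintext sum up to the fixed-point rounding error, while neither node can recover an individual input from its own shares alone.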

In [4]:
# Create a PPU device
ppu = sf.PPU(sf.utils.testing.cluster_def(['alice', 'bob']))

# Create an aggregator instance backed by this PPU.
ppu_aggr = sf.security.aggregation.PPUAggregator(ppu)


In [5]:
# Simulate alice and bob each holding their own data
a = alice(lambda: arr0)()
b = bob(lambda: arr1)()

In [6]:
# Sum
ppu_aggr.sum([a, b], axis=0)

array([[0.7609336, 1.3277214, 1.0033276],
       [1.2734303, 0.4791842, 1.1244376]], dtype=float32)


In [7]:
# Average
ppu_aggr.average([a, b], axis=0)

array([[0.38046676, 0.6638607 , 0.5016638 ],
       [0.6367152 , 0.23959209, 0.5622188 ]], dtype=float32)


### Masking with One-Time Pads

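The idea behind `Masking with One-Time Pads` is that each party negotiates a secret with every other party and then uses those secrets to mask its input $x$. Each party $u$ outputs:

$$ y_u = x_u + \sum_{v:\,u < v} s_{u,v} - \sum_{v:\,u > v} s_{v,u} \pmod{R} $$

After aggregation the secrets cancel each other out, which yields the correct result:

$$ \sum_u y_u \equiv \sum_u x_u \pmod{R} $$


For example, suppose the parties Alice, Bob, and Carol hold $x_1, x_2, x_3$ respectively and negotiate the secrets $s_{a,b}, s_{a,c}, s_{b,c}$. They then output (omitting the modulus for brevity):  
$y_1 = x_1 + s_{a,b} + s_{a,c}$  
$y_2 = x_2 - s_{a,b} + s_{b,c}$  
$y_3 = x_3 - s_{a,c} - s_{b,c}$  
It is then easy to see that $$ y_1 + y_2 + y_3 = x_1 + s_{a,b} + s_{a,c} + x_2 - s_{a,b} + s_{b,c} + x_3 - s_{a,c} - s_{b,c} = x_1 + x_2 + x_3 $$

Note that `Masking with One-Time Pads` relies on the semi-honest assumption and does not tolerate parties dropping out. For more details, see [Practical Secure Aggregation for Privacy-Preserving Machine Learning](https://eprint.iacr.org/2017/281.pdf).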
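
The masking formula can be sketched in plain NumPy as follows. This is a simplified illustration, not SecretFlow's implementation: the modulus `R`, the fixed-point scale `SCALE`, and the helper names are assumptions, and the pairwise secrets are drawn from a single local random generator, whereas in a real deployment each pair of parties would agree on its secret privately.

```python
import numpy as np

R = 2**32       # modulus; an illustrative choice
SCALE = 2**16   # fixed-point scale for encoding floats as integers; also illustrative


def encode(x):
    """Encode a float array as integers in [0, R)."""
    return np.round(x * SCALE).astype(np.int64) % R


def decode(y):
    """Map residues in [0, R) back to signed fixed-point values, then to floats."""
    signed = np.where(y >= R // 2, y - R, y)
    return signed / SCALE


def mask_inputs(inputs, rng):
    """Return y_u for every party u, with one shared secret per pair (u, v), u < v."""
    n = len(inputs)
    shape = inputs[0].shape
    secrets = {
        (u, v): rng.integers(0, R, size=shape, dtype=np.int64)
        for u in range(n)
        for v in range(u + 1, n)
    }
    masked = []
    for u in range(n):
        y = encode(inputs[u])
        for v in range(n):
            if u < v:
                y = (y + secrets[(u, v)]) % R   # add s_{u,v}
            elif u > v:
                y = (y - secrets[(v, u)]) % R   # subtract s_{v,u}
        masked.append(y)
    return masked


def aggregate(masked):
    """Sum the masked outputs modulo R; the pairwise secrets cancel."""
    total = np.zeros_like(masked[0])
    for y in masked:
        total = (total + y) % R
    return decode(total)


rng = np.random.default_rng(0)
xs = [rng.random((2, 3)) for _ in range(3)]   # three parties: Alice, Bob, Carol
masked = mask_inputs(xs, rng)

print(aggregate(masked))      # sum recovered from the masked outputs
print(np.sum(xs, axis=0))     # plaintext sum, for comparison
```

The two printed arrays agree up to the fixed-point rounding error. If any single party withheld its masked output, the remaining secrets would no longer cancel, which is why this basic scheme does not tolerate dropouts.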

In [8]:
# Create a secure aggregation instance: alice and bob are the participants, and alice performs the aggregation.
secure_aggr = sf.security.aggregation.SecureAggregator(device=alice, participants=[alice, bob])

In [9]:
# Sum
secure_aggr.sum([a, b], axis=0)

array([[0.7609335, 1.3277213, 1.0033275],
       [1.2734303, 0.4791841, 1.1244375]])

In [10]:
# Average
secure_aggr.average([a, b], axis=0)

array([[0.38046675, 0.66386065, 0.50166375],
       [0.63671515, 0.23959205, 0.56221875]])

### Plaintext aggregation (not recommended for production)

To make local simulation easier, SecretFlow also provides a plaintext aggregator.

In [11]:
# Create a plaintext aggregation instance: alice performs the aggregation.
plain_aggr = sf.security.aggregation.PlainAggregator(alice)

In [12]:
# Sum
plain_aggr.sum([a, b], axis=0)

array([[0.7609336, 1.3277214, 1.0033276],
       [1.2734303, 0.4791842, 1.1244376]], dtype=float32)


In [13]:
# Average
plain_aggr.average([a, b], axis=0)

array([[0.38046676, 0.6638607 , 0.5016638 ],
       [0.6367152 , 0.23959209, 0.5622188 ]], dtype=float32)


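## Comparison

The `Aggregator` implementations above mainly provide aggregation operations such as sum and average.
In addition, SecretFlow provides several `Comparator` implementations that offer operations such as max and min.
For example, when data is horizontally partitioned, secure comparison can produce a global value without exposing any party's private information.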
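
As a plain NumPy illustration of the horizontally partitioned case: each party holds different rows of the same table, so the global column-wise minimum (or maximum) equals the element-wise minimum (or maximum) of the parties' local results. The sketch below performs the final combination in plaintext only to show the arithmetic; in a secure setting that step is what a `Comparator` would carry out. The variable names and row counts are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Each party holds different rows of a table with the same three columns.
alice_rows = rng.random((4, 3))
bob_rows = rng.random((5, 3))

# Step 1: each party computes its local column-wise min and max.
local_mins = [alice_rows.min(axis=0), bob_rows.min(axis=0)]
local_maxs = [alice_rows.max(axis=0), bob_rows.max(axis=0)]

# Step 2: combine the local results element-wise to get the global values.
global_min = np.min(local_mins, axis=0)
global_max = np.max(local_maxs, axis=0)

# Reference: the same statistics computed on the pooled data.
pooled = np.vstack([alice_rows, bob_rows])
print(global_min, pooled.min(axis=0))
print(global_max, pooled.max(axis=0))
```

Only the small per-party summaries enter the comparison, never the raw rows.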


### PPU-based secure comparison

SecretFlow implements secure comparison on top of the PPU; its usage is shown below.

In [16]:
# Create a secure comparison instance.
ppu_com = sf.security.compare.PPUComparator(ppu)

In [17]:
# Minimum
sf.reveal(ppu_com.min([a, b], axis=0))

array([[0.11274984, 0.33600056, 0.43811923],
       [0.48992354, 0.22364359, 0.4047367 ]], dtype=float32)


In [18]:
# Maximum
ppu_com.max([a, b], axis=0)

array([[0.6481837 , 0.99172086, 0.5652083 ],
       [0.7835069 , 0.2555406 , 0.7197009 ]], dtype=float32)


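### Plaintext comparison (not recommended for production)

To make local simulation easier, SecretFlow also provides a plaintext comparator.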

In [19]:
# Create a plaintext comparator: alice performs the comparison.
plain_com = sf.security.compare.PlainComparator(alice)

In [20]:
# Minimum
plain_com.min([a, b], axis=0)

array([[0.11274985, 0.33600056, 0.43811926],
       [0.48992354, 0.22364359, 0.40473673]], dtype=float32)

In [21]:
# Maximum
plain_com.max([a, b], axis=0)

array([[0.6481837 , 0.99172086, 0.5652084 ],
       [0.7835069 , 0.2555406 , 0.7197009 ]], dtype=float32)


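## Summary

This tutorial introduced SecretFlow's secure aggregation schemes. SecretFlow offers several of them, so users can apply the security strategy that fits their needs.
Use the plaintext aggregation schemes with caution in production.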