# Neural Network On PPU

>  如果你还没有阅读《Logistic Regression On PPU》，请先阅读该tutorial。

在《Logistic Regression On PPU》中，我们已经展示如何利用Secretflow/PPU将一个明文机器学习训练任务转化为隐私保护任务。

在这个tutorial中，我们将继续展示对于一个神经网络模型做类似的转换。

我们继续使用相同的数据集[Breast Cancer](https://archive.ics.uci.edu/ml/datasets/breast+cancer+wisconsin+(diagnostic))和相同的预处理函数。

第一步，我们依然先构建一个明文的NN模型。

## 明文NN模型

### 产生数据

这里用的函数和《Logistic Regression On PPU》是相同的。

In [1]:
import numpy as np
from sklearn.datasets import load_breast_cancer


def load_dataset(return_label=False) -> (np.ndarray, np.ndarray):
    features, label = load_breast_cancer(return_X_y=True)

    if return_label:
        return features[:, 15:], label
    else:
        return features[:, :15], None

from sklearn.preprocessing import StandardScaler

def transform(data):
    scaler = StandardScaler()
    return scaler.fit_transform(data)

### 模型训练

#### 模型定义

我们这里选用是一个[MLP](https://en.wikipedia.org/wiki/Multilayer_perceptron)模型。

In [2]:
from typing import Sequence
import flax.linen as nn

class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        return nn.Dense(self.features[-1])(x)

#### 模型训练

以下为模型训练函数，我们将会训练一个三层的MLP模型。

In [3]:
from typing import Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp

def fit_auto_grad(x1,x2,y, n_epochs=10, n_batch=10, step_size=0.01):
    x = jax.numpy.concatenate((x1, x2), axis=1)
    n_iters=10
    xs = jnp.array_split(x, n_iters, axis=0)
    ys = jnp.array_split(y, n_iters, axis=0)
    features=[30, 8, 1]
    model = MLP(features)

    params = model.init(jax.random.PRNGKey(1),
                                      jax.numpy.ones((n_batch, features[0])))

    def loss_func(params, x, y):
        pred = model.apply(params, x)

        def mse(y, pred):
            def squared_error(y, y_pred):
                return jnp.multiply(y - y_pred, y - y_pred) / 2.0
            return jnp.mean(squared_error(y, pred))

        return mse(y, pred)

    def body_fun(_, state):
        params = state
        for (x, y) in zip(xs, ys):
            _, grads = jax.value_and_grad(loss_func)(params, x, y)
            params = jax.tree_multimap(lambda p, g: p - step_size * g,
                                        params, grads)
        return params

    params = jax.lax.fori_loop(0, n_epochs, body_fun, params)
    return params


#### 验证模型

我们将计算模型的loss。

In [4]:
def compute_loss(params, x1, x2, y):
    x= jax.numpy.concatenate((x1, x2), axis=1)
    mlp=MLP([30, 8, 1])
    pred = mlp.apply(params, x)

    def mse(y, pred):
        def squared_error(y, y_pred):
            return jnp.multiply(y - y_pred, y - y_pred) / 2.0

        return jnp.mean(squared_error(y, pred))

    return mse(y, pred)

### BUILD Together

以下是完整的模型训练过程

In [8]:
import jax

x1,_ = load_dataset(False)
x2, y = load_dataset(True)

x1 = transform(x1)
x2 = transform(x2)

params = fit_auto_grad(x1, x2, y)

print(compute_loss(params, x1, x2, y))


0.14214484


第二步，我们将以上训练任务转化为隐私保护任务。

## 隐私保护NN模型

In [14]:
import secretflow as sf

sf.init(['alice', 'bob'], num_cpus=8, log_to_driver=True)

alice, bob = sf.PYU('alice'), sf.PYU('bob')
ppu = sf.PPU(sf.utils.testing.cluster_def(['alice', 'bob']))

x1_private, x2_private, y_private = sf.to(alice, x1), sf.to(bob, x2), sf.to(bob, y)

x1_ppu = x1_private.to(ppu)
x2_ppu = x2_private.to(ppu)
y_ppu = y_private.to(ppu)

params_ppu = ppu(fit_auto_grad)(x1_ppu, x2_ppu, y_ppu)

[2m[36m(PPURuntime pid=111570)[0m I0303 17:03:40.176662 111570 external/com_github_brpc_brpc/src/brpc/server.cpp:1046] Server[ppu::link::internal::ReceiverServiceImpl] is serving on port=28788.
[2m[36m(PPURuntime pid=111570)[0m I0303 17:03:40.176756 111570 external/com_github_brpc_brpc/src/brpc/server.cpp:1049] Check out http://k69b13338.eu95sqa:28788 in web browser.
[2m[36m(PPURuntime pid=111568)[0m I0303 17:03:40.200518 111568 external/com_github_brpc_brpc/src/brpc/server.cpp:1046] Server[ppu::link::internal::ReceiverServiceImpl] is serving on port=53280.
[2m[36m(PPURuntime pid=111568)[0m I0303 17:03:40.200621 111568 external/com_github_brpc_brpc/src/brpc/server.cpp:1049] Check out http://k69b13338.eu95sqa:53280 in web browser.
[2m[36m(PPURuntime pid=111570)[0m I0303 17:03:40.277556 111980 external/com_github_brpc_brpc/src/brpc/socket.cpp:2202] Checking Socket{id=0 addr=127.0.0.1:53280} (0x7f8e8d9bdb80)
[2m[36m(PPURuntime pid=111570)[0m I0303 17:03:40.277735 112010 

[2m[36m(PPURuntime pid=111570)[0m [2022-03-03 17:03:40.176] [info] [context.cc:58] connecting to mesh, id=root, self=0
[2m[36m(PPURuntime pid=111570)[0m [2022-03-03 17:03:40.193] [info] [context.cc:83] try_connect to rank 1 not succeed, sleep_for 1000ms and retry.
[2m[36m(PPURuntime pid=111568)[0m [2022-03-03 17:03:40.200] [info] [context.cc:58] connecting to mesh, id=root, self=1
[2m[36m(PPURuntime pid=111570)[0m [2022-03-03 17:03:41.194] [info] [context.cc:111] connected to mesh, id=root, self=0
[2m[36m(PPURuntime pid=111568)[0m [2022-03-03 17:03:41.194] [info] [context.cc:111] connected to mesh, id=root, self=1


我们来检验一下MPC训练出来的params。

In [16]:
loss = compute_loss(sf.reveal(params_ppu),x1, x2,y)

print(loss)

0.14270556


以上为tutorial的全部内容，你可以发现你只需要很少的改动就可以将一个普通的NN模型训练任务转化为隐私保护版本的训练任务。

In [13]:
import ray

ray.shutdown()