# Develop Your First MPC-Application

> The following code is for demonstration only. It is NOT intended for production due to system security concerns; please DO NOT use it directly in production.

This is an introductory secretflow tutorial that covers how to:

* Implement a simple algorithm and run it in plaintext as a baseline.
* Use the simulator to check the **precision loss** and try to fix it.
* Run detailed emulations to report on both **efficiency and correctness**.

We **highly recommend** reading [spu-quickstart](../tutorials/quick_start.ipynb) before continuing with this tutorial; it covers the basic usage of Device and DeviceObject and how to run a program in secret.

## Part 0: Prepare the environment and dataset
1. Environment: To run this tutorial, you should have spu installed in your environment (if not, refer to [this guide](https://www.secretflow.org.cn/docs/spu/en/getting_started/install.html)).
2. Dataset: We use the Breast Cancer Wisconsin dataset, which can be loaded from sklearn, and we only apply a simple min-max transform (to the range [-2, 2]) for preprocessing.
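For reference, `MinMaxScaler(feature_range=(-2, 2))` maps each column linearly so that its minimum becomes -2 and its maximum becomes 2. Below is a minimal NumPy sketch of that per-column formula; the helper `minmax_scale` is hypothetical (not part of sklearn or secretflow) and assumes no column is constant.

```python
import numpy as np

def minmax_scale(x, lo=-2.0, hi=2.0):
    """Scale each column of x linearly into [lo, hi] (hypothetical helper).

    Same formula as sklearn's MinMaxScaler; assumes no column is constant.
    """
    x = np.asarray(x, dtype=np.float64)
    col_min = x.min(axis=0)
    col_max = x.max(axis=0)
    unit = (x - col_min) / (col_max - col_min)  # each column now spans [0, 1]
    return unit * (hi - lo) + lo                # stretch and shift to [lo, hi]

toy = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 40.0]])
print(minmax_scale(toy))
```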

In [1]:
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import MinMaxScaler
import pandas as pd

In [2]:
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
# LR usually works well only when the features have been normalized!
scaler = MinMaxScaler(feature_range=(-2, 2))
cols = X.columns
X = scaler.fit_transform(X)
X = pd.DataFrame(X, columns=cols)

In [3]:
X.head()

| | mean radius | mean texture | mean perimeter | mean area | mean smoothness | mean compactness | mean concavity | mean concave points | mean symmetry | mean fractal dimension | ... | worst radius | worst texture | worst perimeter | worst area | worst smoothness | worst compactness | worst concavity | worst concave points | worst symmetry | worst fractal dimension |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.08415 | -1.909368 | 0.183954 | -0.545069 | 0.375011 | 1.168149 | 0.812559 | 0.924453 | 0.745455 | 0.422072 | ... | 0.483102 | -1.433902 | 0.673241 | -0.197208 | 0.404543 | 0.477166 | 0.274441 | 1.64811 | 0.39385 | -0.324544 |
| 1 | 0.572578 | -0.909706 | 0.463133 | 0.006363 | -0.84048 | -1.272928 | -1.185567 | -0.60497 | -0.480808 | -1.434709 | ... | 0.427606 | -0.785714 | 0.159271 | -0.259143 | -0.609787 | -1.381747 | -1.228115 | 0.556701 | -1.065642 | -1.108487 |
| 2 | 0.405982 | -0.438958 | 0.382973 | -0.202333 | 0.057236 | -0.275934 | -0.149953 | 0.542744 | 0.038384 | -1.155013 | ... | 0.225543 | -0.559701 | 0.033767 | -0.501966 | -0.065641 | -0.458499 | -0.561022 | 1.340206 | -0.385176 | -1.146268 |
| 3 | -1.159638 | -0.556645 | -1.065994 | -1.588378 | 1.245283 | 1.245445 | 0.262418 | 0.091451 | 1.105051 | 2.0 | ... | -1.006759 | -0.45629 | -1.034613 | -1.623968 | 1.66189 | 1.256047 | 0.194569 | 1.539519 | 2.0 | 1.094845 |
| 4 | 0.51957 | -1.37369 | 0.523944 | -0.042842 | -0.278595 | -0.608429 | -0.14433 | 0.073559 | -0.486869 | -1.252738 | ... | 0.078975 | -1.504264 | 0.02779 | -0.6337 | -0.250545 | -1.310339 | -0.722045 | 0.233677 | -1.369998 | -1.429621 |


## Part 1: Implement algorithm in plaintext

[SGD](https://en.wikipedia.org/wiki/Stochastic_gradient_descent) (Stochastic Gradient Descent) is a simple but effective optimization algorithm, and it is commonly used to train models in MPC settings.

[LR](https://en.wikipedia.org/wiki/Logistic_regression) (Logistic Regression) is a widely used linear model, especially in the financial industry. So in this tutorial, as an example, we will implement LR with a modified SGD called [policy-sgd](../development/policy_sgd_insight.rst), which can speed up training in most scenarios.



Here, we list the key equations used in policy-sgd:
- LR computes the gradient as follows (`n` is the batch size):
$$ grad = \frac{1}{n} \sum_{i} (sigmoid(w^T x_i) - y_i) x_i $$
- In the first epoch, policy-sgd computes a scale factor $d_k$ for each batch `k` as follows (`p` is the number of features):
$$ d_k = \frac{1}{\sqrt{\sum_j^{p} grad_j^2} + \epsilon} $$
- Then, the weights are updated as follows (`i` is the epoch index, `k` is the iteration index within the epoch, and $\alpha$ is the learning rate); later epochs reuse the $d_k$ computed in the first epoch:
$$ w_{i,k} = w_{i,k-1} - d_k \cdot \alpha \cdot grad $$
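The update rules above can be sketched in plain NumPy for a single epoch. This is our own illustration, not SecretFlow's code: the helpers `policy_sgd_epoch` and `sigmoid_t1` here are hypothetical, and the t1 sigmoid is written with `np.clip`, which is equivalent to the `jnp.select` form used later. It shows that $d_k$ is computed once per batch in the first epoch and then reused.

```python
import numpy as np

def sigmoid_t1(z):
    # first-order Taylor approximation 0.5 + z / 4, clipped to [0, 1]
    return np.clip(0.5 + 0.25 * z, 0.0, 1.0)

def policy_sgd_epoch(x, y, w, alpha, batch_size, dk_arr=None, eps=1e-6):
    """Run one epoch of policy-sgd (hypothetical helper, not SecretFlow's code).

    x: (m, p) features with a bias column already appended, y: (m, 1), w: (p, 1).
    In the first epoch (dk_arr is None) a d_k is computed for every batch k;
    later epochs pass the returned list back in and reuse it.
    """
    first_epoch = dk_arr is None
    dks = [] if first_epoch else dk_arr
    for k, begin in enumerate(range(0, x.shape[0], batch_size)):
        xb = x[begin:begin + batch_size]
        yb = y[begin:begin + batch_size]
        grad = xb.T @ (sigmoid_t1(xb @ w) - yb) / xb.shape[0]
        if first_epoch:
            dks.append(1.0 / (np.sqrt(np.sum(grad ** 2)) + eps))
        w = w - dks[k] * alpha * grad
    return w, dks

rng = np.random.default_rng(0)
x = np.hstack([rng.normal(size=(32, 3)), np.ones((32, 1))])  # last column is the bias
y = (x[:, :1] > 0).astype(np.float64)                          # toy labels from feature 0
w, dks = policy_sgd_epoch(x, y, np.zeros((4, 1)), alpha=0.1, batch_size=8)
w, _ = policy_sgd_epoch(x, y, w, alpha=0.1, batch_size=8, dk_arr=dks)  # reuses d_k
print(len(dks), w.ravel())
```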

In this part, we first set aside the MPC setting (data split, protocol, ...) and implement the algorithm in plaintext. Secretflow recommends doing this with [Jax](https://jax.readthedocs.io/en/latest/), whose `jax.numpy` provides a familiar NumPy-style API for ease of adoption. If you are familiar with NumPy, you can go through [this blog](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) to learn some caveats, and then write JAX code just like NumPy code.

In [4]:
# import some basic libraries
# use jnp just like np
import jax.numpy as jnp
import jax.lax

import numpy as np
from functools import partial

The original response function of LR is the sigmoid function, which involves ops such as exp and division that are time-consuming in MPC. So it is common to approximate sigmoid with an MPC-friendly function. Here we provide two methods: the first-order Taylor approximation (t1) and the square root approximation (sr).
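To see how the two approximations behave, we can compare them with the exact sigmoid in plain NumPy. This is a sketch for intuition only: the helpers mirror the functions defined below but are re-implemented here (t1 uses `np.clip` instead of `jnp.select`). Note that t1 saturates to exactly 0 or 1 once |x| >= 2, while sr always stays strictly inside (0, 1).

```python
import numpy as np

def sigmoid_exact(z):
    return 1.0 / (1.0 + np.exp(-z))

def sigmoid_t1(z):
    # first-order Taylor expansion around 0, clipped to [0, 1]
    return np.clip(0.5 + 0.25 * z, 0.0, 1.0)

def sigmoid_sr(z):
    # square-root approximation: 0.5 * z / sqrt(1 + z^2) + 0.5
    return 0.5 * z / np.sqrt(1.0 + z * z) + 0.5

z = np.arange(-6.0, 7.0)  # -6, -5, ..., 6
for name, fn in (("exact", sigmoid_exact), ("t1", sigmoid_t1), ("sr", sigmoid_sr)):
    print(f"{name:>5}:", np.round(fn(z), 3))
```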

In [5]:
def sigmoid_t1(x, limit: bool = True):
    '''
    taylor series referenced from:
    https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/
    '''
    T0 = 1.0 / 2
    T1 = 1.0 / 4
    ret = T0 + x * T1
    if limit:
        return jnp.select([ret < 0, ret > 1], [0, 1], ret)
    else:
        return ret


def sigmoid_sr(x):
    """
    https://en.wikipedia.org/wiki/Sigmoid_function#Examples
    Square Root approximation functions:
    F(x) = 0.5 * ( x / ( 1 + x^2 )^0.5 ) + 0.5
    sigmoid_sr fits sigmoid closely when x is outside the range [-3, 3].
    It is highly recommended as the default sigmoid approximation for GBDT.
    """
    return 0.5 * (x / jnp.sqrt(1 + jnp.power(x, 2))) + 0.5


def sigmoid(x, method='t1'):
    if method == 't1':
        return sigmoid_t1(x)
    else:
        return sigmoid_sr(x)

Policy-sgd needs to compute the learning-rate scale factor dk in the first epoch.

In [6]:
# Note: we keep a `method` param in this function for the next part; in plaintext, we rarely need to invoke low-level ops directly.
def compute_dk_func(x, eps=1e-6, method='norm'):
    # As in Adam, add a small eps to avoid division by zero
    if method == 'norm':
        return 1 / (jnp.linalg.norm(x) + eps)
    else:
        # invoke low-level rsqrt op by hand
        return jax.lax.rsqrt(jnp.sum(jnp.square(x)) + eps)

Then, we give a brief implementation of LR with policy-sgd, with an interface similar to (but smaller than) sklearn's.

**Note**: for simplicity, we always fit the intercept in the LR model and omit regularization and other techniques. For the full version of SSLR, refer to `SSRegression` in secretflow.

In [7]:
class SSLRSGDClassifier:
    def __init__(
        self,
        epochs: int,
        learning_rate: float,
        batch_size: int,
        sig_type: str = 't1',
        eps: float = 1e-6,  # eps is the small number used when computing dk
        dk_method: str = 'norm',  # method to compute dk; the default uses jnp.linalg.norm
    ):
        # parameter check.
        assert epochs > 0, "epochs should be > 0"
        assert learning_rate > 0, "learning_rate should be > 0"
        assert batch_size > 0, "batch_size should be > 0"
        assert sig_type in ['t1', 'sr'], "sig_type should be one of ['t1', 'sr']"
        assert eps > 0, "eps should be > 0"
        assert dk_method in [
            'norm',
            'rsqrt',
        ], "dk_method should be one of ['norm', 'rsqrt']"

        self._epochs = epochs
        self._learning_rate = learning_rate
        self._batch_size = batch_size
        self._sig_type = sig_type
        self._eps = eps
        self._dk_method = dk_method

        self._weights = jnp.zeros(())

    def _update_weights(
        self,
        x,  # array-like
        y,  # array-like
        w,  # array-like
        total_batch: int,
        batch_size: int,
        dk_arr,  # array-like
    ):
        num_feat = x.shape[1]
        assert w.shape[0] == num_feat + 1, "w shape does not match x"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        w = w.reshape((w.shape[0], 1))

        compute_dk = False
        if dk_arr is None:
            compute_dk = True
            dk_arr = []

        for idx in range(total_batch):
            begin = idx * batch_size
            end = min((idx + 1) * batch_size, x.shape[0])
            rows = end - begin
            # padding one col for bias in w
            x_slice = jnp.concatenate((x[begin:end, :], jnp.ones((rows, 1))), axis=1)
            y_slice = y[begin:end, :]

            pred = jnp.matmul(x_slice, w)
            pred = sigmoid(pred, method=self._sig_type)

            err = pred - y_slice
            grad = jnp.matmul(jnp.transpose(x_slice), err) / rows

            if compute_dk:
                dk = compute_dk_func(grad, self._eps, self._dk_method)
                dk_arr.append(dk)
            else:
                dk = dk_arr[idx]

            step = self._learning_rate * grad * dk

            w = w - step

        if compute_dk:
            dk_arr = jnp.array(dk_arr)

        return w, dk_arr

    def fit(self, x, y):
        """Fit LR with policy-sgd.

        Parameters
        ----------
        x : array-like of shape (n_samples, n_features)
            Training data.

        y : ndarray of shape (n_samples, 1)
            Target values.

        """
        assert len(x.shape) == 2, f"expect x to be 2 dimension array, got {x.shape}"
        assert len(y.shape) == 2, f"expect y to be 2 dimension array, got {y.shape}"

        num_sample = x.shape[0]
        num_feat = x.shape[1]
        batch_size = min(self._batch_size, num_sample)
        total_batch = (num_sample + batch_size - 1) // batch_size

        # always fit intercept
        weights = jnp.zeros((num_feat + 1, 1))
        dk_arr = None

        # do train
        for _ in range(self._epochs):
            weights, dk_arr = self._update_weights(
                x, y, weights, total_batch, batch_size, dk_arr
            )

        self._weights = weights
        self.dk_arr = dk_arr

        return

    def predict_proba(self, x):
        """Probability estimates.

        Parameters
        ----------
        x : array-like of shape (n_samples, n_features)
            Input data for prediction.

        Returns
        -------
        ndarray of shape (n_samples, 1)
            The predicted probability of the positive class for each sample.
        """
        num_feat = x.shape[1]
        w = self._weights
        assert w.shape[0] == num_feat + 1, f"w shape does not match x={x.shape}"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        # reshape returns a new array, so the result must be assigned back
        w = w.reshape((w.shape[0], 1))

        bias = w[-1, 0]  # the last weight is the intercept
        w = w[:-1, :]  # the first num_feat weights belong to the features

        pred = jnp.matmul(x, w) + bias
        pred = sigmoid(pred, method=self._sig_type)

        return pred

Now, let's try this algorithm in plaintext!

In [8]:
plain_model = SSLRSGDClassifier(
    epochs=3, learning_rate=0.1, batch_size=8, sig_type='t1', eps=1e-6, dk_method='norm'
)



In [9]:
plain_model.fit(X.values, y.values.reshape(-1, 1))  # X, y should be two-dimension array

Things seem to go well. Now let's predict on the dataset and compute the AUC.

In [10]:
predict_prob = plain_model.predict_proba(X.values)

In [11]:
from sklearn.metrics import roc_auc_score

roc_auc_score(y.values, predict_prob)

0.9903083875059459

## Part 2: Run algorithm with simulator

Normally, you can follow something like [LR with spu](https://www.secretflow.org.cn/docs/secretflow/en/tutorial/lr_with_spu.html) to run your program in a secure context: move your dataset to a PYU or SPU, run the program with the SPU you declared, and reveal the information you need (`reveal` is a **very dangerous** op; use it very carefully in real applications).

However, as we will see later, you may come across a large **metric gap** (like AUC in LR) between plaintext and secret execution. It is therefore better if developers can run MPC programs in a simpler way, with the flexibility to adjust hyper-parameters such as the ring size, the fixed-point (fxp) settings, or the underlying MPC protocol.

So in this part, we show how to use the simulator to run our algorithm just like a normal MPC program, and run minimal experiments to locate and verify the pitfalls of the program.
Using the simulator instead of running the program on an SPU device directly has two advantages:

1. **Less Code**: No need to deal with lots of `DeviceObject`s or move data between PYU and SPU.
2. **Quicker Experiments**: No Ray cluster needs to be connected; experiments run end-to-end locally.




In [12]:
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

Here, to run the simulation, we first define a simple simulator with the CHEETAH protocol and a 64-bit ring in the 2PC setting. We will discuss 3PC later.

In [13]:
sim = spsim.Simulator.simple(2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64)

In [14]:
def fit_and_predict(
    x,
    y,
    epochs=3,
    learning_rate=0.1,
    batch_size=8,
    sig_type='t1',
    eps=1e-6,
    dk_method='norm',
):
    model = SSLRSGDClassifier(
        epochs=epochs,
        learning_rate=learning_rate,
        batch_size=batch_size,
        sig_type=sig_type,
        eps=eps,
        dk_method=dk_method,
    )
    model.fit(x, y)
    return model.predict_proba(x)

In [15]:
result = spsim.sim_jax(sim, fit_and_predict)(
    X.values, y.values.reshape(-1, 1)
)  # X, y should be two-dimension array

[2024-02-01 09:24:44.108] [info] [cheetah_dot.cc:295] CheetahDot uses 3@2 modulus 8192 degree for 64 bit ring
[2024-02-01 09:24:44.114] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:44.129] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring
[2024-02-01 09:24:44.130] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring
[2024-02-01 09:24:44.131] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
[2024-02-01 09:24:44.683] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:44.692] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:44.707] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:44.717] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.176 MiB, Response

In [16]:
roc_auc_score(y, result)  # rather poor under cheetah protocol!

0.49056603773584906

Then, we try it in the 3PC setting, i.e., with the ABY3 protocol.

In [17]:
sim_aby = spsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64)

In [18]:
result = spsim.sim_jax(sim_aby, fit_and_predict)(X.values, y.values.reshape(-1, 1))

In [19]:
roc_auc_score(y, result)

0.9707864277786586

In [20]:
# not very stable: if you run the fit procedure multiple times, the AUC will sometimes drop to about 0.97 (vs. 0.99 in plaintext)
result = spsim.sim_jax(sim_aby, fit_and_predict)(X.values, y.values.reshape(-1, 1))
roc_auc_score(y, result)

0.970865704772475

When the program runs in secret without any modification, the AUC may drop dramatically after training for 3 epochs (from 0.990 to 0.490 for Cheetah)!

We will analyze this and try to fix it, first from the application perspective and then, more deeply, from the MPC perspective.


### Application Perspective
Before diving into this question, let's first summarize the differences between policy-sgd and naive SGD:

1. An approximation function is used to compute sigmoid (the default is t1).

2. The learning rate is scaled by dk, computed as defined in `compute_dk_func`.



With some simple math, we notice that the t1 approximation forces pred to 0 when the inner product is less than -2 and to 1 when it is greater than 2. So when we compute the gradient with:
$$ grad = \frac{1}{n} \sum_{i} (sigmoid(w^T x_i) - y_i) x_i $$
if, by coincidence, all elements of grad are very close to 0 (possibly with small errors in MPC), the `dk` computed in the first epoch becomes very large, which may cause training to fail. We can verify this by simply enlarging `batch_size` to 64, which decreases the probability of this all-zero problem.
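The mechanism can be reproduced in plaintext with a hand-picked toy batch (the data below is made up purely for illustration and is not from the dataset). When every inner product falls outside [-2, 2] and every label matches the saturated t1 prediction, the gradient is exactly zero, so $d_k = 1/(0 + \epsilon) = 1/\epsilon$.

```python
import numpy as np

def sigmoid_t1(z):
    return np.clip(0.5 + 0.25 * z, 0.0, 1.0)  # t1 approximation, clipped to [0, 1]

# Made-up batch: every inner product w^T x_i lies outside [-2, 2]
# and every label agrees with the saturated prediction.
x = np.array([[1.0, 2.0], [-1.0, -1.0], [2.0, 1.0], [-2.0, -1.0]])
w = np.array([[1.0], [1.5]])
y = np.array([[1.0], [0.0], [1.0], [0.0]])

pred = sigmoid_t1(x @ w)                 # inner products are 4.0, -2.5, 3.5, -3.5
grad = x.T @ (pred - y) / x.shape[0]     # every error term is exactly 0
dk = 1.0 / (np.linalg.norm(grad) + 1e-6)
print(grad.ravel(), dk)
```

Since `dk_arr` computed in the first epoch is reused for the same batch in later epochs, where the gradient is generally no longer zero, that batch's step can then be scaled by up to 1/eps.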

In [21]:
# use partial to fix batch_size=64
result = spsim.sim_jax(sim, partial(fit_and_predict, batch_size=64))(
    X.values, y.values.reshape(-1, 1)
)

[2024-02-01 09:24:54.880] [info] [cheetah_dot.cc:295] CheetahDot uses 3@2 modulus 8192 degree for 64 bit ring
[2024-02-01 09:24:54.887] [info] [cheetah_dot.cc:423] 1@31x64x1 => 31x64x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:54.897] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring
[2024-02-01 09:24:54.898] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring
[2024-02-01 09:24:54.922] [info] [cheetah_dot.cc:423] 1@64x31x1 => 64x31x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:54.934] [info] [cheetah_dot.cc:423] 1@31x64x1 => 31x64x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:54.952] [info] [cheetah_dot.cc:423] 1@64x31x1 => 64x31x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:54.964] [info] [cheetah_dot.cc:423] 1@31x64x1 => 31x64x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:54.980] [info] [cheetah_dot.cc:423] 1@64x31x

In [22]:
roc_auc_score(y, result)

0.9892315416732731

Requiring a large batch_size is not an appropriate fix in general. The key is to make the scale factor smaller, so we can also fix the problem by enlarging `eps`, e.g., changing it from 1e-6 to 1e-2.


**Note**: `eps` in policy-sgd actually has two effects: one is to prevent division by zero, and the other is to cap the maximum scale factor in the warm-start phase (first epoch).
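The second effect follows directly from the formula: because the norm is non-negative, $d_k \le 1/\epsilon$. A small plaintext check, where `compute_dk` is a hypothetical NumPy restatement of `compute_dk_func` with `method='norm'` and the zero gradient has 31 entries (30 features plus the bias):

```python
import numpy as np

def compute_dk(grad, eps):
    # same formula as compute_dk_func(method='norm'): 1 / (||grad||_2 + eps)
    return 1.0 / (np.linalg.norm(grad) + eps)

zero_grad = np.zeros(31)  # 30 features + bias, as in this dataset
for eps in (1e-6, 1e-2):
    # since ||grad|| >= 0, d_k can never exceed 1 / eps; a zero gradient hits the bound
    print(f"eps={eps:g}: d_k upper bound = {compute_dk(zero_grad, eps):.1f}")
```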

In [23]:
result = spsim.sim_jax(sim, partial(fit_and_predict, eps=1e-2))(
    X.values, y.values.reshape(-1, 1)
)

[2024-02-01 09:24:58.191] [info] [cheetah_dot.cc:295] CheetahDot uses 3@2 modulus 8192 degree for 64 bit ring
[2024-02-01 09:24:58.199] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:58.208] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring
[2024-02-01 09:24:58.208] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring
[2024-02-01 09:24:58.849] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:58.859] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:58.876] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:58.886] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:24:58.903] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1

In [24]:
roc_auc_score(y, result)

0.9884255589028065

The above analysis is based on the t1 sigmoid, whose truncation can produce zero gradients. So we can also switch from t1 to another non-truncating but more costly approximation (e.g., sr).
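A plaintext sketch of why this helps: on a made-up batch whose inner products all fall outside [-2, 2], t1 yields an exactly zero gradient, while sr still produces a nonzero gradient and therefore a moderate $d_k$. The helpers re-implement the two approximations in NumPy for illustration only.

```python
import numpy as np

def sigmoid_t1(z):
    return np.clip(0.5 + 0.25 * z, 0.0, 1.0)    # truncating approximation

def sigmoid_sr(z):
    return 0.5 * z / np.sqrt(1.0 + z * z) + 0.5  # never reaches exactly 0 or 1

# Made-up batch: inner products w^T x_i are 4.0, -2.5, 3.5, -3.5 (all outside [-2, 2])
x = np.array([[1.0, 2.0], [-1.0, -1.0], [2.0, 1.0], [-2.0, -1.0]])
w = np.array([[1.0], [1.5]])
y = np.array([[1.0], [0.0], [1.0], [0.0]])

dk = {}
for name, fn in (("t1", sigmoid_t1), ("sr", sigmoid_sr)):
    grad = x.T @ (fn(x @ w) - y) / x.shape[0]
    dk[name] = 1.0 / (np.linalg.norm(grad) + 1e-6)
    print(f"{name}: ||grad|| = {np.linalg.norm(grad):.4f}, d_k = {dk[name]:.1f}")
```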

In [25]:
result = spsim.sim_jax(sim, partial(fit_and_predict, sig_type='sr'))(
    X.values, y.values.reshape(-1, 1)
)

[2024-02-01 09:25:05.483] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring
[2024-02-01 09:25:05.483] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring
[2024-02-01 09:25:06.124] [info] [cheetah_dot.cc:295] CheetahDot uses 3@2 modulus 8192 degree for 64 bit ring
[2024-02-01 09:25:06.131] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:25:06.146] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:25:06.158] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:25:06.173] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:25:06.184] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:25:06.200] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1

In [26]:
roc_auc_score(y, result)

0.992151577612177

So, considering only the application layer, we get three rules for fixing the problem:

1. Enlarge `batch_size`.

2. Enlarge `eps`.

3. Use a non-truncating sigmoid approximation (e.g., the sr approximation).


### MPC Perspective

In this part, we focus on why such a large error occurs. To this end, we reason about the underlying protocol and use the simulator to run some **experiments** that confirm our hypothesis. You can follow a similar process when developing your own secure applications.

Before diving deeper into the problem, we highly recommend reading:

1. [spu_inside](https://www.secretflow.org.cn/docs/spu/latest/en-US/tutorials/spu_inside#Tracing): introduces how SPU handles floating-point operations internally.

2. [pitfall](https://www.secretflow.org.cn/docs/spu/latest/en-US/development/fxp): SPU implements math functions (like `reciprocal`, `log`, and so on) with approximation algorithms, so precision issues can occur when inputs fall into certain intervals. This page lists some known issues.

3. [protocols](https://www.secretflow.org.cn/docs/spu/latest/en-US/reference/mpc_status): lists all protocols SPU currently implements. Generally speaking, for 2PC it is safe to use Cheetah, while for 3PC, ABY3 is the only choice.

First, define a function similar to `fit_and_predict` that returns `dk_arr`.

In [27]:
def get_dk(
    x,
    y,
    epochs=3,
    learning_rate=0.1,
    batch_size=8,
    sig_type='t1',
    eps=1e-6,
    dk_method='norm',
):
    model = SSLRSGDClassifier(
        epochs=epochs,
        learning_rate=learning_rate,
        batch_size=batch_size,
        sig_type=sig_type,
        eps=eps,
        dk_method=dk_method,
    )
    model.fit(x, y)
    return model.dk_arr

#### 2PC: Cheetah Protocol

Recap:

1. [cheetah](https://eprint.iacr.org/2022/207) is a fast 2PC semi-honest protocol that uses FHE to accelerate computation. However, `mul` and `dot` may introduce an error of 0-2 bits.

2. With a 64-bit ring, SPU uses fixed-point numbers with 18 fractional bits by default, so the smallest positive value SPU can represent is $\frac{1}{2^{18}}$.
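To get a feel for this resolution, here is a toy plaintext model of fixed-point encoding with 18 fractional bits. It assumes round-to-nearest when converting a float to fixed point; SPU's actual encoding may differ in detail, and `encode`/`decode` are hypothetical helpers, not SPU APIs.

```python
import numpy as np

FXP_FRACTION_BITS = 18          # as in the FM64 setting described above
SCALE = 1 << FXP_FRACTION_BITS  # 2**18 = 262144

def encode(x):
    # assumption: round-to-nearest when converting float -> fixed point
    return np.round(np.asarray(x, dtype=np.float64) * SCALE).astype(np.int64)

def decode(v):
    return np.asarray(v, dtype=np.float64) / SCALE

print(decode(1))  # smallest positive value: 2**-18, about 3.8e-06
for value in (1e-2, 1e-5, 1e-6):
    print(value, "->", encode(value), "->", decode(encode(value)))
```

Under this model, `eps=1e-6` is below the resolution and encodes to 0, whereas `eps=1e-2` survives as 2621 / 2^18 (about 0.01).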

We first check the output of `dk_arr` to find the caveat.

In [28]:
result = spsim.sim_jax(sim, get_dk)(X.values, y.values.reshape(-1, 1))

[2024-02-01 09:25:13.156] [info] [cheetah_dot.cc:295] CheetahDot uses 3@2 modulus 8192 degree for 64 bit ring
[2024-02-01 09:25:13.161] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:25:13.173] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring
[2024-02-01 09:25:13.173] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring
[2024-02-01 09:25:13.713] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:25:13.723] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:25:13.738] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:25:13.747] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.176 MiB, Response 0.122 MiB Pack 0 ms
[2024-02-01 09:25:13.761] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1

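Surprisingly, `dk_arr` contains abnormal values: instead of a normal positive scale factor, we may get 0 (as in the output below) or even a very small negative number, which makes the weight update wrong (the opposite direction, with a large scale factor for SGD)!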

In [29]:
result[38]

0.0

However, if we use a larger ring, everything works fine.

In [30]:
# define a simulator with a 128-bit ring
sim128 = spsim.Simulator.simple(
    2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM128
)

In [31]:
result = spsim.sim_jax(sim128, fit_and_predict)(X.values, y.values.reshape(-1, 1))

[2024-02-01 09:25:18.287] [info] [cheetah_dot.cc:295] CheetahDot uses 5@3 modulus 16384 degree for 128 bit ring
[2024-02-01 09:25:18.304] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.601 MiB, Response 0.369 MiB Pack 0 ms
[2024-02-01 09:25:18.327] [info] [cheetah_mul.cc:347] CheetahMul uses 7 modulus for 128 bit input over 128 bit ring
[2024-02-01 09:25:18.328] [info] [cheetah_mul.cc:347] CheetahMul uses 7 modulus for 128 bit input over 128 bit ring
[2024-02-01 09:25:18.641] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1 Recv 0.601 MiB, Response 0.368 MiB Pack 0 ms
[2024-02-01 09:25:18.670] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.601 MiB, Response 0.369 MiB Pack 0 ms
[2024-02-01 09:25:18.704] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1 Recv 0.601 MiB, Response 0.368 MiB Pack 0 ms
[2024-02-01 09:25:18.730] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.601 MiB, Response 0.369 MiB Pack 0 ms
[2024-02-01 09:25:18.764] [info] [cheetah_dot.cc:423] 1@8x31x1 => 

In [32]:
roc_auc_score(y, result)  # auc is just like plaintext

0.9903083875059457

In [33]:
result = spsim.sim_jax(sim128, get_dk)(X.values, y.values.reshape(-1, 1))[38]

[2024-02-01 09:25:34.239] [info] [cheetah_dot.cc:295] CheetahDot uses 5@3 modulus 16384 degree for 128 bit ring
[2024-02-01 09:25:34.255] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.601 MiB, Response 0.369 MiB Pack 0 ms
[2024-02-01 09:25:34.277] [info] [cheetah_mul.cc:347] CheetahMul uses 7 modulus for 128 bit input over 128 bit ring
[2024-02-01 09:25:34.277] [info] [cheetah_mul.cc:347] CheetahMul uses 7 modulus for 128 bit input over 128 bit ring
[2024-02-01 09:25:34.599] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1 Recv 0.601 MiB, Response 0.368 MiB Pack 0 ms
[2024-02-01 09:25:34.626] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.601 MiB, Response 0.369 MiB Pack 0 ms
[2024-02-01 09:25:34.659] [info] [cheetah_dot.cc:423] 1@8x31x1 => 8x31x1 Recv 0.601 MiB, Response 0.368 MiB Pack 0 ms
[2024-02-01 09:25:34.685] [info] [cheetah_dot.cc:423] 1@31x8x1 => 31x8x1 Recv 0.601 MiB, Response 0.369 MiB Pack 0 ms
[2024-02-01 09:25:34.716] [info] [cheetah_dot.cc:423] 1@8x31x1 => 

In [34]:
result

136.70488

From the above outputs, we can hypothesize that when the inputs are close to $\frac{1}{2^{18}}$ and the Cheetah protocol is used, the bit errors introduced by `square` and accumulated by `sum` become significant and can no longer be neglected (recall that `mul` and `dot` have 0-2 bit errors).
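The hypothesis can be illustrated with a toy plaintext model of fixed-point arithmetic. This sketch assumes round-to-nearest encoding with 18 fractional bits and truncation after multiplication, and models Cheetah's small multiplication error as a constant -1 LSB per product; the real error is probabilistic and SPU's implementation differs, so the numbers will not match the simulator output exactly. The point is that the true squares (about 1e-10) are far below the resolution 2^-18, so even a 1-bit error per product dominates the sum and can make it negative.

```python
import numpy as np

FRAC_BITS = 18
SCALE = 1 << FRAC_BITS

def encode(x):
    return np.round(np.asarray(x, dtype=np.float64) * SCALE).astype(np.int64)

def fxp_square(v, err_lsb=0):
    """Fixed-point square with truncation, plus an optional per-product error.

    err_lsb is a hypothetical stand-in for Cheetah's 0-2 bit error; it is not
    SPU's actual error model.
    """
    return ((v * v) >> FRAC_BITS) + err_lsb

x = encode(np.full(10, 1e-5))                         # each element encodes to 3
exact = np.sum(np.full(10, 1e-5) ** 2)                # 1e-9, far below 2**-18
no_err = np.sum(fxp_square(x)) / SCALE                # 9 >> 18 == 0, so the sum is 0
with_err = np.sum(fxp_square(x, err_lsb=-1)) / SCALE  # ten -1 LSB errors add up
print(exact, no_err, with_err)
```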

In [35]:
# Let's test this
def test_square_and_sum_when_x_small(x):
    return jnp.sum(jnp.square(x))

In [36]:
spsim.sim_jax(sim, test_square_and_sum_when_x_small)(np.array([1e-5] * 10))

[2024-02-01 09:25:39.236] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring
[2024-02-01 09:25:39.236] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring


array(-1.9073486e-05, dtype=float32)

In [37]:
def test_norm_when_x_small(x):
    return jnp.sqrt(jnp.sum(jnp.square(x)))

In [38]:
spsim.sim_jax(sim, test_norm_when_x_small)(
    np.array([1e-5] * 10)
)  # for small inputs, the result collapses to 0 instead of the plaintext value ~3.16e-5!

[2024-02-01 09:25:39.330] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring
[2024-02-01 09:25:39.330] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring


array(0., dtype=float32)

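So with `eps=1e-6`, when every element of grad is close to $\frac{1}{2^{18}}$, the sum of squares of grad may become negative and the computed 2-norm collapses to 0, so the denominator of dk is no longer a proper positive number and we get an abnormal (zero or even very small negative) dk.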

This explains the fixes we found at the application layer:
1. Enlarging `batch_size`: the probability that all elements of grad are (close to) zero becomes smaller.

2. Enlarging `eps`: forces the denominator to be a positive number.

#### 3PC: ABY3 Protocol

Again, we check `dk_arr` first.

In [39]:
result = spsim.sim_jax(sim_aby, get_dk)(X.values, y.values.reshape(-1, 1))

Hmm, a strange value appears again: we find **0** in dk_arr!

In [40]:
result[66]

0.0

Compared with a small negative number, 0 is a mild error for our update procedure: since `dk_arr` is reused in every epoch, the update for that batch is simply skipped, so the final AUC drops only a little (from 0.99 to 0.97). You can verify yourself that setting `eps` to 1e-2 makes the result very stable.

Likewise, we check the computation of the 2-norm.

In [41]:
spsim.sim_jax(sim_aby, test_norm_when_x_small)(
    np.array([1e-5] * 10)
)  # getting 0 here is acceptable

array(0., dtype=float32)

In [42]:
test_norm_when_x_small(np.array([1e-5] * 10))  # 2-norm in plaintext

Array(3.1622774e-05, dtype=float32)

Then, check the reciprocal op.

In [43]:
def test_reciprocal_when_x_small(x):
    return 1 / (x + 1e-6)

In [44]:
# we get 0 when the denominator is very small, which is a caveat of reciprocal!
spsim.sim_jax(sim_aby, test_reciprocal_when_x_small)(np.array([0]))

array([0.], dtype=float32)

### Something More

There are some other interesting aspects of SSLR. Due to length limitations, we only give some hints here; readers are encouraged to run more simulations themselves!

#### Rsqrt v.s. Norm

Recall that the `compute_dk_func` function defined in Part 1 has a `method` argument, which we have ignored so far. We can tell the simulator to print more information, as [spu_inside](https://www.secretflow.org.cn/docs/spu/latest/en-US/tutorials/spu_inside#Tracing) does, by turning on **hlo** (High Level Operations) trace and profiling. Then we can see which ops are invoked and how much time each costs.

Here, we list some advantages of using `jax.lax.rsqrt` rather than `jnp.linalg.norm`:

1. Fewer bytes sent and fewer send actions, which reduces communication and the time spent in the dk-specific ops (see the comments in the following cells for details).

2. More stable with the same `eps`: if we denote `compute_dk_func` with `method='norm'` as `f(x)` and with `method='rsqrt'` as `g(x)`, you can run simulations yourself and find that `f(x)` has a higher relative error than `g(x)`.
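Note that the two methods do not compute exactly the same quantity: `norm` adds `eps` after the square root, while `rsqrt` adds it inside. A plaintext NumPy comparison (the helpers `dk_norm` and `dk_rsqrt` are ours) shows that they agree for typical gradients but differ in the zero-gradient limit, where `rsqrt` caps $d_k$ at $1/\sqrt{\epsilon}$ instead of $1/\epsilon$.

```python
import numpy as np

def dk_norm(grad, eps=1e-6):
    # compute_dk_func(method='norm'): 1 / (||grad|| + eps)
    return 1.0 / (np.linalg.norm(grad) + eps)

def dk_rsqrt(grad, eps=1e-6):
    # compute_dk_func(method='rsqrt'): 1 / sqrt(sum(grad^2) + eps)
    return 1.0 / np.sqrt(np.sum(np.square(grad)) + eps)

g = np.arange(1000) / 1000      # same input as the profiling cells below
print(dk_norm(g), dk_rsqrt(g))  # both about 0.0548 for this input

z = np.zeros(31)
print(dk_norm(z), dk_rsqrt(z))  # 1/eps vs 1/sqrt(eps)
```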


In [45]:
# we define a cheetah config with pphlo trace and profile on
config_che = spu.RuntimeConfig(
    protocol=spu_pb2.ProtocolKind.CHEETAH,
    field=spu.FieldType.FM64,
    fxp_fraction_bits=18,
    enable_pphlo_trace=True,
    enable_pphlo_profile=True,
)
simulator_che = spsim.Simulator(2, config_che)

In [46]:
spsim.sim_jax(simulator_che, partial(compute_dk_func, method='norm'))(
    np.arange(1000) / 1000
)

[2024-02-01 09:25:41.508] [info] [pphlo_executor.cc:1324] PPHLO %0 = pphlo.constant dense<9.99999997E-7> : tensor<f32>
[2024-02-01 09:25:41.508] [info] [pphlo_executor.cc:1324] PPHLO %0 = pphlo.constant dense<9.99999997E-7> : tensor<f32>
[2024-02-01 09:25:41.508] [info] [pphlo_executor.cc:1324] PPHLO %1 = pphlo.constant dense<0.000000e+00> : tensor<f32>
[2024-02-01 09:25:41.508] [info] [pphlo_executor.cc:1324] PPHLO %1 = pphlo.constant dense<0.000000e+00> : tensor<f32>
[2024-02-01 09:25:41.508] [info] [pphlo_executor.cc:1324] PPHLO %2 = pphlo.constant dense<1.000000e+00> : tensor<f32>
[2024-02-01 09:25:41.508] [info] [pphlo_executor.cc:1324] PPHLO %2 = pphlo.constant dense<1.000000e+00> : tensor<f32>
[2024-02-01 09:25:41.508] [info] [pphlo_executor.cc:1324] PPHLO %3 = pphlo.multiply %arg0, %arg0 : tensor<1000x!pphlo.secret<f32>>
[2024-02-01 09:25:41.508] [info] [pphlo_executor.cc:1324] PPHLO %3 = pphlo.multiply %arg0, %arg0 : tensor<1000x!pphlo.secret<f32>>
[2024-02-01 09:25:41.516] [i

array(0.05480957, dtype=float32)

i.cc:194] - pphlo.free, executed 8 times, duration 3.407e-06s, send bytes 0
[2024-02-01 09:25:42.099] [info] [api.cc:194] - pphlo.multiply, executed 1 times, duration 0.05489884s, send bytes 792973
[2024-02-01 09:25:42.099] [info] [api.cc:194] - pphlo.reduce, executed 1 times, duration 0.000139253s, send bytes 0
[2024-02-01 09:25:42.099] [info] [api.cc:194] - pphlo.sqrt, executed 1 times, duration 0.262870205s, send bytes 267268
[2024-02-01 09:25:42.099] [info] [api.cc:204] Link details: total send bytes 1061558, send actions 2336
[2024-02-01 09:25:42.099] [info] [pphlo_executor.cc:1324] PPHLO pphlo.free %2 : tensor<f32>


If we invoke rsqrt directly, the number of send actions drops noticeably!

In [47]:
# Note:
# 1. The total time cost of rsqrt may even be larger than that of norm, because CHEETAH uses FHE, which makes multiply far more expensive than other ops.
# 2. time(rsqrt) = 0.005607 < time(sqrt+divide) = 0.00812 + 0.00283
spsim.sim_jax(simulator_che, partial(compute_dk_func, method='rsqrt'))(
    np.arange(1000) / 1000
)

[2024-02-01 09:25:42.179] [info] [pphlo_executor.cc:1324] PPHLO %0 = pphlo.constant dense<9.99999997E-7> : tensor<f32>
[2024-02-01 09:25:42.179] [info] [pphlo_executor.cc:1324] PPHLO %1 = pphlo.constant dense<0.000000e+00> : tensor<f32>
[2024-02-01 09:25:42.179] [info] [pphlo_executor.cc:1324] PPHLO %0 = pphlo.constant dense<9.99999997E-7> : tensor<f32>
[2024-02-01 09:25:42.179] [info] [pphlo_executor.cc:1324] PPHLO %2 = pphlo.multiply %arg0, %arg0 : tensor<1000x!pphlo.secret<f32>>
[2024-02-01 09:25:42.179] [info] [pphlo_executor.cc:1324] PPHLO %1 = pphlo.constant dense<0.000000e+00> : tensor<f32>
[2024-02-01 09:25:42.179] [info] [pphlo_executor.cc:1324] PPHLO %2 = pphlo.multiply %arg0, %arg0 : tensor<1000x!pphlo.secret<f32>>
[2024-02-01 09:25:42.187] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring
[2024-02-01 09:25:42.187] [info] [cheetah_mul.cc:347] CheetahMul uses 4 modulus for 64 bit input over 64 bit ring

array(0.05480957, dtype=float32)

In [48]:
# Similarly, we can define an ABY3 config with PPHLO tracing and profiling enabled
config_aby = spu.RuntimeConfig(
    protocol=spu_pb2.ProtocolKind.ABY3,
    field=spu.FieldType.FM64,
    fxp_fraction_bits=18,
    enable_pphlo_trace=True,
    enable_pphlo_profile=True,
)
simulator_aby = spsim.Simulator(3, config_aby)

In [49]:
spsim.sim_jax(simulator_aby, partial(compute_dk_func, method='norm'))(
    np.arange(1000) / 1000
)

[2024-02-01 09:25:42.587] [info] [pphlo_executor.cc:1324] PPHLO %0 = pphlo.constant dense<9.99999997E-7> : tensor<f32>
[2024-02-01 09:25:42.587] [info] [pphlo_executor.cc:1324] PPHLO %0 = pphlo.constant dense<9.99999997E-7> : tensor<f32>
[2024-02-01 09:25:42.587] [info] [pphlo_executor.cc:1324] PPHLO %0 = pphlo.constant dense<9.99999997E-7> : tensor<f32>
[2024-02-01 09:25:42.587] [info] [pphlo_executor.cc:1324] PPHLO %1 = pphlo.constant dense<0.000000e+00> : tensor<f32>
[2024-02-01 09:25:42.587] [info] [pphlo_executor.cc:1324] PPHLO %1 = pphlo.constant dense<0.000000e+00> : tensor<f32>
[2024-02-01 09:25:42.587] [info] [pphlo_executor.cc:1324] PPHLO %1 = pphlo.constant dense<0.000000e+00> : tensor<f32>
[2024-02-01 09:25:42.587] [info] [pphlo_executor.cc:1324] PPHLO %2 = pphlo.constant dense<1.000000e+00> : tensor<f32>
[2024-02-01 09:25:42.587] [info] [pphlo_executor.cc:1324] PPHLO %2 = pphlo.constant dense<1.000000e+00> : tensor<f32>

array(0.05480957, dtype=float32)

With ABY3, both the number of send actions and the number of send bytes drop significantly when using rsqrt!
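Comparing these numbers by eye gets tedious once you run many configurations. A small helper can pull the per-op statistics and link details out of the profiling log text. This is a hypothetical sketch: the regular expressions simply match the log lines printed in this notebook (`- pphlo.sqrt, executed 1 times, duration ...s, send bytes ...` and `Link details: total send bytes ..., send actions ...`), and the format may change between SPU versions.

```python
import re

# Per-op profiling line, e.g. "- pphlo.sqrt, executed 1 times, duration 0.26s, send bytes 267268"
OP_RE = re.compile(
    r"- (?P<op>[\w.]+), executed (?P<count>\d+) times, "
    r"duration (?P<duration>[\d.eE+-]+)s, send bytes (?P<bytes>\d+)"
)
# Summary line, e.g. "Link details: total send bytes 1061558, send actions 2336"
LINK_RE = re.compile(r"total send bytes (?P<bytes>\d+), send actions (?P<actions>\d+)")


def parse_profile(log_text):
    """Return ({op_name: stats}, link_stats or None) from saved SPU profile logs."""
    ops, link = {}, None
    for line in log_text.splitlines():
        m = OP_RE.search(line)
        if m:
            ops[m['op']] = {
                'count': int(m['count']),
                'duration': float(m['duration']),
                'send_bytes': int(m['bytes']),
            }
            continue
        m = LINK_RE.search(line)
        if m:
            link = {'send_bytes': int(m['bytes']), 'send_actions': int(m['actions'])}
    return ops, link
```

Saving the `norm` and `rsqrt` logs to text and parsing each one lets you compare `link['send_actions']` or per-op `send_bytes` directly.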

In [50]:
# Note:
# 1. With ABY3, the total time of rsqrt is always smaller than that of norm.
# 2. Likewise, time(rsqrt) = 0.003096 < time(sqrt+divide) = 0.003601 + 0.002891

spsim.sim_jax(simulator_aby, partial(compute_dk_func, method='rsqrt'))(
    np.arange(1000) / 1000
)

[2024-02-01 09:25:42.611] [info] [pphlo_executor.cc:1324] PPHLO %0 = pphlo.constant dense<9.99999997E-7> : tensor<f32>
[2024-02-01 09:25:42.611] [info] [pphlo_executor.cc:1324] PPHLO %0 = pphlo.constant dense<9.99999997E-7> : tensor<f32>
[2024-02-01 09:25:42.611] [info] [pphlo_executor.cc:1324] PPHLO %1 = pphlo.constant dense<0.000000e+00> : tensor<f32>
[2024-02-01 09:25:42.611] [info] [pphlo_executor.cc:1324] PPHLO %1 = pphlo.constant dense<0.000000e+00> : tensor<f32>
[2024-02-01 09:25:42.611] [info] [pphlo_executor.cc:1324] PPHLO %2 = pphlo.multiply %arg0, %arg0 : tensor<1000x!pphlo.secret<f32>>
[2024-02-01 09:25:42.611] [info] [pphlo_executor.cc:1324] PPHLO %2 = pphlo.multiply %arg0, %arg0 : tensor<1000x!pphlo.secret<f32>>
[2024-02-01 09:25:42.611] [info] [pphlo_executor.cc:1324] PPHLO %0 = pphlo.constant dense<9.99999997E-7> : tensor<f32>
[2024-02-01 09:25:42.611] [info] [pphlo_executor.cc:1324] PPHLO %1 = pphlo.constant dense<0.000000e+00> : tensor<f32>

array(0.05480576, dtype=float32)

#### Computing Loss

Many ML frameworks report the validation loss during training when a validation dataset is provided. For LR, the loss is straightforward to compute:
$$ loss = -\frac{1}{N} \sum_{i=1}^N [y_i log(p_i) + (1-y_i) log(1-p_i)] \quad  (1) $$
However, when you use the t1 approximation of sigmoid, $p_i$ is not guaranteed to stay strictly inside $(0, 1)$, so you may run into the $log(0)$ problem. Here, we list two potential recipes to alleviate it.

1. **Costly but accurate**: plugging $p_i = \frac{1}{1+e^{-w^Tx_i}}$ into (1) gives:
$$ loss = -\frac{1}{N} \sum_{i=1}^N [y_i w^Tx_i - log(1+e^{w^Tx_i})] \quad  (2) $$
This formula avoids the $log(0)$ problem, but as discussed in [pitfall](https://www.secretflow.org.cn/docs/spu/en/reference/fxp.html), evaluating $e^{w^Tx_i}$ in fixed-point arithmetic produces **huge errors** when $w^Tx_i$ is large. To obtain a stable and accurate formula, note that $log(1+e^{w^T x_i})$ is the well-known *Softplus* function, which satisfies $log(1+e^{x}) = log(1 + e^{-|x|}) + max(0, x)$. Substituting this identity into (2) gives:
$$ loss = -\frac{1}{N} \sum_{i=1}^N [y_i w^Tx_i - log(1+e^{-|w^Tx_i|}) - max(w^T x_i, 0)] \quad (3) $$

2. **Cheap but approximate**: Equation (3) is accurate, but it contains expensive ops ($log$, $exp$). If you only need an approximation of the loss (e.g. for early stopping), you can use a Taylor expansion instead: expanding $log(1+e^{x}) \approx log(2) + \frac{x}{2} + \frac{x^2}{8}$ around $x = 0$ and substituting it into (2) gives:
$$ loss = \frac{1}{N} \sum_{i=1}^N [log(2) - (y_i-0.5)w^T x_i + 0.125 (w^T x_i)^2] $$
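The loss formulas can be compared in plaintext before moving them to SPU. The sketch below uses NumPy, with `z` standing for $w^Tx_i$; the helper names are illustrative. Every call it uses (`abs`, `exp`, `log`, `log1p`, `maximum`, `mean`) also exists in `jax.numpy`, so it should port to `jnp` directly.

```python
import numpy as np


def loss_naive(z, y):
    # Equation (1): cross-entropy on p = sigmoid(z); breaks once p hits 0 or 1
    p = 1.0 / (1.0 + np.exp(-z))
    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))


def loss_stable(z, y):
    # Equation (3): softplus(z) = log(1 + e^{-|z|}) + max(z, 0)
    softplus = np.log1p(np.exp(-np.abs(z))) + np.maximum(z, 0.0)
    return -np.mean(y * z - softplus)


def loss_taylor(z, y):
    # Second-order Taylor approximation around z = 0 (no log/exp on secret data)
    return np.mean(np.log(2.0) - (y - 0.5) * z + 0.125 * z**2)


z = np.array([-1.0, 0.0, 0.5, 2.0])
y = np.array([0.0, 1.0, 1.0, 1.0])
print(loss_naive(z, y), loss_stable(z, y), loss_taylor(z, y))
```

With `z = 50` and `y = 1`, equation (1) evaluates to `nan` in float64 because $p_i$ rounds to exactly 1, while equation (3) stays finite. For small $|z|$, the Taylor form closely tracks equation (3).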

## Part 3: Run elaborated emulations

> Emulation is an **experimental** feature under rapid development, so the sml code is not packaged into spu. You can try this feature from the **source code** and run it with bazel. So far, only the LAN setting (`MULTIPROCESS` mode) is supported. `Docker` mode, which runs programs as if under a WAN setting, will be released in a future version.

Finally, we discuss how to run emulations. Compared to the simulator, the emulator runs with a simple scheduler, as SecretFlow does, and offers some facilities (e.g. generating mock data) that make benchmarking simpler. To this end, spu provides an `Emulator` class with an easy-to-use interface.

Emulations are usually run on larger datasets, so we won't run a full-scale one directly in this tutorial notebook. Instead, we show the big picture of how to design and run emulations for an MPC application step by step.

### Setup: define running function

As in SecretFlow, we first define a Python function that will run in SPU. As an example, we define a very simple function that accepts data from two parties and returns the predicted probabilities after training the model (you can also split the data into training and validation parts and return the probabilities on the validation set).
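If you want to return validation probabilities instead, one simple option is a deterministic row split at a fixed index. The helper `split_train_val` below is hypothetical and not part of sml; it only uses basic slicing with a plain Python index, which behaves the same for NumPy and `jax.numpy` arrays.

```python
import numpy as np


def split_train_val(x, y, val_ratio=0.2):
    """Hold out the last `val_ratio` fraction of rows as the validation set.

    Hypothetical helper: the split index is a plain Python int, so the
    slicing is static and works the same for NumPy and jax.numpy arrays.
    """
    n_val = int(x.shape[0] * val_ratio)
    n_train = x.shape[0] - n_val
    return x[:n_train], y[:n_train], x[n_train:], y[n_train:]


x = np.random.rand(1000, 100)
y = np.random.randint(0, 2, (1000, 1))
x_train, y_train, x_val, y_val = split_train_val(x, y)
print(x_train.shape, x_val.shape)
```

Inside the running function, you would then fit on `x_train, y_train` and return the predicted probabilities for `x_val`. Because the split is positional, shuffle the rows in plaintext before sealing if the data is ordered (for example, sorted by label).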

Taking `SSLRSGDClassifier` as an example, suppose we mainly want to show that policy-sgd outperforms naive-sgd in the MPC setting. We can then design the following experiments:

1. Find the best `dk_method` and `eps` for policy-sgd: compare accuracy and efficiency on all datasets.

2. Compare the accuracy and efficiency when switching `sig_type` for both policy-sgd and naive-sgd.

3. To compare policy-sgd and naive-sgd, we fix `epochs` and test the influence of `learning_rate` and `batch_size`.
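The experiment plan above amounts to a sweep over `SSLRSGDClassifier` hyper-parameters. A minimal sketch of enumerating that grid with `itertools.product` follows. The candidate values for `sig_type`, `dk_method` and `batch_size` come from the comments in this section's `run_model` function; the `eps` values are hypothetical, and since the switch between policy-sgd and naive-sgd is not shown in this section, it is left out.

```python
import itertools

# Values from the comments in `run_model`; EPS_VALUES is hypothetical.
SIG_TYPES = ['t1', 'sr']
DK_METHODS = ['norm', 'rsqrt']
BATCH_SIZES = [1024, 2048, 4096]
EPS_VALUES = [1e-4, 1e-6]


def experiment_grid():
    """Yield one keyword-argument dict per experiment configuration."""
    for sig_type, dk_method, batch_size, eps in itertools.product(
        SIG_TYPES, DK_METHODS, BATCH_SIZES, EPS_VALUES
    ):
        yield dict(
            epochs=10,  # fixed across runs, as in experiment 3
            learning_rate=0.1,
            batch_size=batch_size,
            sig_type=sig_type,
            eps=eps,
            dk_method=dk_method,
        )


configs = list(experiment_grid())
print(len(configs))
```

Each dict mirrors the keyword arguments passed to `SSLRSGDClassifier` in `run_model`, so it could be unpacked as `SSLRSGDClassifier(**cfg)` inside a per-configuration variant of the running function (e.g. built with a closure).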

In [51]:
# import library
import sml.utils.emulation as emulation

In [52]:
# import model impl which has been tested with simulator
# from Model import model


def run_model(x1, x2, y):
    # here, suppose the features are vertically split between two parties
    x = jnp.concatenate((x1, x2), axis=1)
    y = y.reshape((y.shape[0], 1))

    # hyper-parameters to sweep in the experiments, e.g.:
    # for sig_type in ['t1', 'sr']
    # for dk_method in ['norm', 'rsqrt']
    # for batch_size in [1024, 2048, 4096]
    # ...
    model = SSLRSGDClassifier(
        epochs=10,
        learning_rate=0.1,
        batch_size=2048,
        sig_type='t1',
        eps=1e-4,
        dk_method='rsqrt',
    )

    model.fit(x, y)
    return model.predict_proba(x)

### Define running config

After designing the experiments, we prepare the running config. Currently, only `MULTIPROCESS` mode is supported: it uses multiple processes to emulate the parties, which is similar to running in a LAN (`DOCKER` mode, which lets you set `bandwidth` and `latency` to simulate different WAN settings, will be supported in a future version).

Since our goal here is to compare accuracy and efficiency across hyper-parameters, running the program in a LAN is a good choice. In addition, to simulate diverse node deployments, you can configure the number of nodes and the devices yourself. Some example configs can be found in `examples/python/conf/`.

In [53]:
mode = emulation.Mode.MULTIPROCESS  # emulation.Mode.DOCKER (docker mode) is not supported yet
# use the mock config as a sample: it deploys the nodes in an outsourcing manner and uses the ABY3 protocol
# Note: in MULTIPROCESS mode, bandwidth and latency have NO effect!
emulator = emulation.Emulator(
    emulation.CLUSTER_ABY3_3PC, mode, bandwidth=100, latency=10
)
# start up the running processes or containers
emulator.up()

[2024-02-01 09:25:42,688] Start multiprocess cluster...
[2024-02-01 09:25:42,729] [Process-1] Starting grpc server at 127.0.0.1:61920
[2024-02-01 09:25:42,756] [Process-2] Starting grpc server at 127.0.0.1:61921
[2024-02-01 09:25:42,783] [Process-3] Starting grpc server at 127.0.0.1:61922


[2024-02-01 09:25:42,810] [Process-4] Starting grpc server at 127.0.0.1:61923
[2024-02-01 09:25:42,837] [Process-5] Starting grpc server at 127.0.0.1:61924
[2024-02-01 09:25:43,840] [Process-1] Run : builtin_spu_init at node:0
[2024-02-01 09:25:43,840] [Process-2] Run : builtin_spu_init at node:1
[2024-02-01 09:25:43,840] [Process-3] Run : builtin_spu_init at node:2
[2024-02-01 09:25:43,849] [Process-2] spu-runtime (SPU) initialized
[2024-02-01 09:25:43,849] [Process-3] spu-runtime (SPU) initialized
[2024-02-01 09:25:43,849] [Process-1] spu-runtime (SPU) initialized
I0201 09:25:43.846532 1476252 external/com_github_brpc_brpc/src/brpc/server.cpp:1158] Server[yacl::link::transport::internal::ReceiverServiceImpl] is serving on port=61930.
W0201 09:25:43.846553 1476252 external/com_github_brpc_brpc/src/brpc/server.cpp:1164] Builtin services are disabled according to ServerOptions.has_builtin_services

In [54]:
# here, we just use a mock dataset.
# you can prepare your own dataset: always choose data that is close to the real scenario!
# IMPORTANT Note: always "seal" your data before running in SPU, otherwise SPU will treat the data as PUBLIC!!!
# ref: https://www.secretflow.org.cn/docs/spu/en/getting_started/quick_start.html#Move-JAX-program-to-SPU

# step1: prepare your plaintext dataset
x1 = np.random.rand(1000, 40)
y = np.random.randint(0, 2, 1000)
x2 = np.random.rand(1000, 60)

# step2: "seal" the data by calling emulator.seal().
x1, x2, y = emulator.seal(x1, x2, y)

[2024-02-01 09:25:43,864] [Process-4] Run : <lambda> at node:3
[2024-02-01 09:25:43,869] [Process-4] Run : <lambda> at node:3
[2024-02-01 09:25:43,871] [Process-4] Run : <lambda> at node:3
[2024-02-01 09:25:43,890] [Process-4] Run : make_shares at node:3
[2024-02-01 09:25:43,899] [Process-4] RunR: builtin_fetch_meta at node:3
[2024-02-01 09:25:43,902] [Process-4] Run : make_shares at node:3
[2024-02-01 09:25:43,911] [Process-4] RunR: builtin_fetch_meta at node:3
[2024-02-01 09:25:43,913] [Process-4] Run : make_shares at node:3
[2024-02-01 09:25:43,915] [Process-4] RunR: builtin_fetch_meta at node:3
[2024-02-01 09:25:44,071] [Process-1] Run : builtin_spu_run at node:0
[2024-02-01 09:25:44,071] [Process-3] Run : builtin_spu_run at node:2
[2024-02-01 09:25:44,071] [Process-2] Run : builtin_spu_run at node:1
[2024-02-01 09:25:44,074] [Process-4] RunR: builtin_fetch_object at node:3
[2024-02-01 09:25:44,074] [Process-4] RunR: builtin_fetch_object at node:3

In [55]:
# start running program and get results
result = emulator.run(run_model)(x1, x2, y)

# For safety, you can put the program in a try/finally block so that emulator.down() always runs
emulator.down()

[2024-02-01 09:25:44,171] Shutdown multiprocess cluster...


[2024-02-01 09:25:44.162] [info] [api.cc:158] [Profiling] SPU execution run_model completed, input processing took 1.002e-06s, execution took 0.075888372s, output processing took 2.825e-06s, total time 0.075892199s.
[2024-02-01 09:25:44.162] [info] [api.cc:191] HLO profiling: total time 0.07414353899999998
[2024-02-01 09:25:44.162] [info] [api.cc:194] - pphlo.add, executed 13 times, duration 0.000262654s, send bytes 0
[2024-02-01 09:25:44.162] [info] [api.cc:194] - pphlo.broadcast, executed 2 times, duration 9.728e-06s, send bytes 0
[2024-02-01 09:25:44.162] [info] [api.cc:194] - pphlo.concatenate, executed 2 times, duration 0.002212247s, send bytes 0
[2024-02-01 09:25:44.162] [info] [api.cc:194] - pphlo.constant, executed 9 times, duration 7.7076e-05s, send bytes 0
[2024-02-01 09:25:44.162] [info] [api.cc:194] - pphlo.convert, executed 3 times, duration 0.004569468s, send bytes 8008
[2024-02-01 09:25:44.162] [info] [api.cc:194] - pphlo.dot, executed 21 times, duration 0.009945714s

### Put them together

Putting all of these together, we get a simple paradigm for emulation.

In [58]:
###########################################################################
# 0). Import library & define running function used by emulator
###########################################################################
import sml.utils.emulation as emulation


# normally, we will import a Model that we have tested in simulator, e.g:
# from model import Model


def run_model(x1, x2, y):
    x = jnp.concatenate((x1, x2), axis=1)
    y = y.reshape((y.shape[0], 1))

    model = SSLRSGDClassifier(
        epochs=5,
        learning_rate=0.1,
        batch_size=2048,
        sig_type='t1',
        eps=1e-4,
        dk_method='rsqrt',
    )

    model.fit(x, y)
    return model.predict_proba(x)


###########################################################################
# 1). Define running config and run emulations
###########################################################################
# Set mode to MULTIPROCESS for a LAN test.
mode = emulation.Mode.MULTIPROCESS  # emulation.Mode.DOCKER (docker mode) is not supported yet

# bandwidth and latency only work for docker mode
emulator = emulation.Emulator(
    emulation.CLUSTER_ABY3_3PC, mode, bandwidth=100, latency=10
)

# For safety, it's good practice to put the running part in a try/except/finally block
try:
    # start up the running processes
    emulator.up()

    # prepare your dataset here.
    #   a. Normally, you should choose your dataset carefully, e.g. for LR,
    # we need to examine the performance on imbalanced and tail-heavy datasets (or real datasets if possible).
    #   b. If you just want some efficiency numbers, there are mock APIs to produce datasets (check `examples.python.utils.dataset_utils`).
    #   c. IMPORTANT NOTE: MAKE SURE your data has been "sealed" BEFORE running the program! (call emulator.seal())
    x1 = np.random.rand(1000, 40)
    y = np.random.randint(0, 2, 1000)
    x2 = np.random.rand(1000, 60)
    x1, x2, y = emulator.seal(x1, x2, y)

    # magic happens here! run the program in the emulator just like in SPU
    result = emulator.run(run_model)(x1, x2, y)
except Exception as e:
    print(e)
finally:
    emulator.down()

[2024-02-01 09:27:18,572] Start multiprocess cluster...
[2024-02-01 09:27:18,610] [Process-16] Starting grpc server at 127.0.0.1:61920
[2024-02-01 09:27:18,639] [Process-17] Starting grpc server at 127.0.0.1:61921
[2024-02-01 09:27:18,678] [Process-18] Starting grpc server at 127.0.0.1:61922
[2024-02-01 09:27:18,707] [Process-19] Starting grpc server at 127.0.0.1:61923
[2024-02-01 09:27:18,736] [Process-20] Starting grpc server at 127.0.0.1:61924
[2024-02-01 09:27:19,734] [Process-16] Run : builtin_spu_init at node:0
[2024-02-01 09:27:19,735] [Process-17] Run : builtin_spu_init at node:1
[2024-02-01 09:27:19,735] [Process-18] Run : builtin_spu_init at node:2
[2024-02-01 09:27:19,744] [Process-18] spu-runtime (SPU) initialized
[2024-02-01 09:27:19,744] [Process-16] spu-runtime (SPU) initialized
[2024-02-01 09:27:19,744] [Process-17] spu-runtime (SPU) initialized
[2024-02-01 09:27:19,751] [Process-19] Run : <lambda> at node:3
[2024-02-01 09:27:19,756] [Process-19] Run : <lambda> at node:

[2024-02-01 09:27:19.949] [info] [api.cc:158] [Profiling] SPU execution run_model completed, input processing took 1.232e-06s, execution took 0.054217746s, output processing took 3.958e-06s, total time 0.054222936s.
[2024-02-01 09:27:19.949] [info] [api.cc:191] HLO profiling: total time 0.05285722200000001
[2024-02-01 09:27:19.950] [info] [api.cc:194] - pphlo.add, executed 8 times, duration 0.000199415s, send bytes 0
[2024-02-01 09:27:19.950] [info] [api.cc:194] - pphlo.broadcast, executed 2 times, duration 1.2654e-05s, send bytes 0
[2024-02-01 09:27:19.950] [info] [api.cc:194] - pphlo.concatenate, executed 2 times, duration 0.002178695s, send bytes 0
[2024-02-01 09:27:19.950] [info] [api.cc:194] - pphlo.constant, executed 9 times, duration 7.537e-05s, send bytes 0
[2024-02-01 09:27:19.950] [info] [api.cc:194] - pphlo.convert, executed 3 times, duration 0.001068422s, send bytes 8008
[2024-02-01 09:27:19.950] [info] [api.cc:194] - pphlo.dot, executed 11 times, duration 0.007380437s