In [None]:
from secretnote.compat.secretflow.device.driver import SFConfigSimulationFullyManaged

# Simulation config for two parties, "alice" and "bob"; its .dict() is later
# unpacked into secretflow.init().
secretflow_config = SFConfigSimulationFullyManaged(parties=["alice", "bob"])

In [None]:
import secretflow

# Shut down before initializing — presumably so this cell can be re-run on a
# kernel where secretflow is already initialized (TODO confirm).
secretflow.shutdown()
# Initialize secretflow with the keyword arguments produced by the config object.
secretflow.init(**secretflow_config.dict())

In [None]:
# One PYU device handle per party; these are used below to run breast_cancer
# for each party and to place the initial model parameters.
alice = secretflow.PYU("alice")
bob = secretflow.PYU("bob")

In [None]:
from secretnote.compat.spu import (
    SPUConfig,
    SPUClusterDef,
    SPUNode,
    SPUProtocolKind,
    SPUFieldType,
    SPURuntimeConfig,
)

# SPU cluster definition: one node per party on localhost (ports 32767 and
# 32768), using the SEMI2K protocol with the FM128 field type.
spu_config = SPUConfig(
    cluster_def=SPUClusterDef(
        nodes=[
            SPUNode(party="alice", address="localhost:32767"),
            SPUNode(party="bob", address="localhost:32768"),
        ],
        runtime_config=SPURuntimeConfig(
            protocol=SPUProtocolKind.SEMI2K,
            field=SPUFieldType.FM128,
        ),
    ),
)

In [None]:
# Create the SPU device from the cluster/runtime configuration defined above.
spu = secretflow.SPU(**spu_config.dict())

In [None]:
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

from secretnote.instrumentation import ProfilingInstrumentor, MermaidExporter

# Exporter instance kept in a variable so its graph can be printed at the end
# of the notebook via mermaid.graph().
mermaid = MermaidExporter()

# Tracer provider tagged with service name "simulation".
resource = Resource(attributes={SERVICE_NAME: "simulation"})
provider = TracerProvider(resource=resource)
# Spans are batched to two destinations: an OTLP gRPC endpoint at
# localhost:4317 (insecure channel) ...
provider.add_span_processor(
    BatchSpanProcessor(OTLPSpanExporter(endpoint="localhost:4317", insecure=True)),
)
# ... and the MermaidExporter created above.
provider.add_span_processor(
    BatchSpanProcessor(mermaid),
)
# Register the provider as the global OpenTelemetry tracer provider.
trace.set_tracer_provider(provider)

In [None]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer


def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):
    """Load the scikit-learn breast-cancer data, scaled and split 80/20.

    The feature matrix is min-max scaled using a single global minimum and
    maximum (taken over the whole matrix, not per column) and then split with
    ``test_size=0.2`` and ``random_state=42``.

    Args:
        party_id: Selects a vertical slice of the training features.
            ``1`` yields the first 15 columns without labels; any other
            truthy value yields the remaining columns together with the
            labels; a falsy value (``None``, ``0``) yields all columns with
            the labels. Ignored when ``train`` is false.
        train: When true return training data, otherwise the test split.

    Returns:
        A 2-tuple ``(features, labels)``. For ``party_id == 1`` the second
        element is ``None`` because that party holds no labels.
    """
    # NOTE(review): the tuple-literal return annotation is kept as-is on
    # purpose. secretflow's PYU wrapper appears to inspect it to decide how
    # many remote values the call returns — confirm before switching to
    # typing.Tuple or similar.
    x, y = load_breast_cancer(return_X_y=True)
    x_min, x_max = np.min(x), np.max(x)
    x = (x - x_min) / (x_max - x_min)
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )

    if not train:
        return x_test, y_test
    if not party_id:
        return x_train, y_train
    if party_id == 1:
        # Explicit placeholder: the original returned the bare name `_`,
        # which is not defined in this function and only resolved through
        # IPython's last-output variable (hidden kernel state).
        return x_train[:, :15], None
    return x_train[:, 15:], y_train

In [None]:
import jax.numpy as jnp


def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), applied via jax.numpy."""
    denominator = 1 + jnp.exp(-x)
    return 1 / denominator


def predict(W, b, inputs):
    """Output the probability of a label being true.

    Computes ``sigmoid(inputs . W + b)``.
    """
    logits = jnp.dot(inputs, W) + b
    return sigmoid(logits)


def loss(W, b, inputs, targets):
    """Training loss: negative mean log-likelihood of the training examples.

    With 0/1 targets, the per-example likelihood is the predicted probability
    where the target is 1 and its complement where the target is 0.
    """
    probabilities = predict(W, b, inputs)
    when_true = probabilities * targets
    when_false = (1 - probabilities) * (1 - targets)
    likelihood = when_true + when_false
    return -jnp.mean(jnp.log(likelihood))

In [None]:
from sklearn.metrics import roc_auc_score


def validate_model(W, b, X_test, y_test):
    """Return the ROC AUC of the model's predictions on the given test data.

    Predictions come from ``predict(W, b, X_test)`` and are scored against
    ``y_test`` with ``roc_auc_score``.
    """
    scores = predict(W, b, X_test)
    auc = roc_auc_score(y_test, scores)
    return auc

In [None]:
from jax import grad


def train_step(W, b, x1, x2, y, learning_rate):
    """Perform one gradient-descent update of ``W`` and ``b``.

    ``x1`` and ``x2`` are concatenated along axis 1 to form the feature
    matrix; the gradient of ``loss`` with respect to its first two arguments
    (``W`` and ``b``) is then scaled by ``learning_rate`` and subtracted.

    Returns:
        The updated ``(W, b)`` pair.
    """
    features = jnp.concatenate([x1, x2], axis=1)
    loss_gradient = grad(loss, (0, 1))
    grad_W, grad_b = loss_gradient(W, b, features, y)
    W -= learning_rate * grad_W
    b -= learning_rate * grad_b
    return W, b

In [None]:
def fit(W, b, x1, x2, y, epochs=1, learning_rate=1e-2):
    """Apply ``train_step`` ``epochs`` times and return the final ``(W, b)``.

    Every iteration uses the same ``x1``, ``x2``, ``y`` and ``learning_rate``;
    only the parameters carry over from one step to the next.
    """
    params = (W, b)
    for _epoch in range(epochs):
        params = train_step(*params, x1, x2, y, learning_rate=learning_rate)
    return params

In [None]:
# Start profiling instrumentation; it is stopped again near the end of the
# notebook with instrumenter.stop().
instrumenter = ProfilingInstrumentor()
instrumenter.start()

In [None]:
# Each party runs breast_cancer on its own PYU. alice (party_id=1) gets the
# first 15 feature columns; her second return value is only a placeholder and
# is discarded. bob (party_id=2) gets the remaining columns plus the labels.
x1, _ = alice(breast_cancer)(party_id=1)
x2, y = bob(breast_cancer)(party_id=2)

In [None]:
# Device on which training will run.
device = spu

# Initial parameters: a zero weight vector of length 30 and a zero bias.
# NOTE(review): 30 presumably equals the total number of feature columns once
# x1 and x2 are concatenated — confirm against the dataset.
W = jnp.zeros((30,))
b = 0.0

# W and b are first placed on alice and then moved to the device; x1, x2 and y
# are moved to the device from the PYU objects produced earlier.
W_, b_, x1_, x2_, y_ = (
    secretflow.to(alice, W).to(device),
    secretflow.to(alice, b).to(device),
    x1.to(device),
    x2.to(device),
    y.to(device),
)

In [None]:
# Run fit on the device for 10 epochs with learning rate 1e-2.
# `epochs` is declared as a static argument name, and the number of return
# values is given explicitly (FROM_USER policy, 2 returns: W and b).
W_, b_ = device(
    fit,
    static_argnames=["epochs"],
    num_returns_policy=secretflow.device.SPUCompilerNumReturnsPolicy.FROM_USER,
    user_specified_num_returns=2,
)(W_, b_, x1_, x2_, y_, epochs=10, learning_rate=1e-2)

In [None]:
# Load the test split directly in this process (not through a PYU), then score
# the trained parameters after passing them through secretflow.reveal().
X_test, y_test = breast_cancer(train=False)
auc = validate_model(secretflow.reveal(W_), secretflow.reveal(b_), X_test, y_test)
print(f"auc={auc}")

In [None]:
# Stop the profiling instrumentation started before the data-loading cell.
instrumenter.stop()

In [None]:
# Print the graph produced by the MermaidExporter that was registered as a
# span exporter above.
print(mermaid.graph())