In [None]:
from secretnote.formal.locations import (
    SPUFieldType,
    SPUProtocolKind,
    SymbolicPYU,
    SymbolicSPU,
    SymbolicWorld,
)
from secretflow.device import SPUCompilerNumReturnsPolicy

# Describe the deployment symbolically: a two-party world (alice, bob), one
# PYU per party, and an SPU shared by both parties. These symbolic objects are
# turned into live devices by the `.reify()` calls in the next cell.
sym_world = SymbolicWorld(world=frozenset(("alice", "bob")))
sym_alice = SymbolicPYU("alice")
sym_bob = SymbolicPYU("bob")
# SPU jointly run by alice and bob with the SEMI2K protocol over the FM128 field.
sym_spu = SymbolicSPU(
    world=frozenset(("alice", "bob")),
    protocol=SPUProtocolKind.SEMI2K,
    field=SPUFieldType.FM128,
    # NOTE(review): presumably 0 means "use SPU's default fraction bits for this
    # field" rather than literally zero fractional bits (which would make the
    # float-valued logistic regression below meaningless) — confirm.
    fxp_fraction_bits=0,
)

In [None]:
import ray
from secretnote.formal.locations import SFConfigSimulation, PortBinding

# Shut down any existing Ray runtime first, presumably so that re-running this
# cell starts from a clean state instead of an already-initialized cluster.
ray.shutdown()

# Bring up the world in simulation mode.
sym_world.reify(SFConfigSimulation())

# Materialize the live devices from their symbolic descriptions.
alice = sym_alice.reify()
bob = sym_bob.reify()
# Each SPU party gets a local port binding; `announced_as` is presumably the
# address the other party uses to reach it — confirm against PortBinding docs.
spu = sym_spu.reify(
    alice=PortBinding(announced_as="127.0.0.1:32767"),
    bob=PortBinding(announced_as="127.0.0.1:32768"),
)

In [None]:
from secretnote.instrumentation.sdk import create_profiler, setup_tracing

# Enable tracing, then start the profiler. It is stopped and its run is
# visualized in the final cell of this notebook.
setup_tracing()

profiler = create_profiler()
profiler.start()

In [None]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split


def _split_breast_cancer(rand_key: int):
    """Load, scale and split the breast-cancer dataset with a seeded 80/20 split.

    Shared by both loaders so that, for the same ``rand_key``, the training and
    testing loaders always see complementary halves of one identical split.

    Returns:
        ``[x_train, x_test, y_train, y_test]`` as produced by ``train_test_split``.
    """
    x, y = load_breast_cancer(return_X_y=True)
    # Min-max scale using the global min/max of the whole feature matrix
    # (not per feature).
    x = (x - np.min(x)) / (np.max(x) - np.min(x))
    return train_test_split(x, y, test_size=0.2, random_state=rand_key)


def load_dataset_for_training(rand_key: int, party_id: int):
    """Return one party's vertical slice of the training set.

    Args:
        rand_key: Seed for the train/test split; must match across parties.
        party_id: ``0`` selects the first 15 feature columns and no labels;
            any other value selects the remaining columns plus the labels.

    Returns:
        ``(features, labels)``; ``labels`` is ``None`` for party 0.
    """
    x_train, _x_test, y_train, _y_test = _split_breast_cancer(rand_key)
    if party_id == 0:
        return x_train[:, :15], None
    return x_train[:, 15:], y_train


def load_dataset_for_testing(rand_key: int):
    """Return the full-width held-out test features and labels.

    Args:
        rand_key: Seed for the train/test split; use the same value as for
            training so the test rows are disjoint from the training rows.

    Returns:
        ``(x_test, y_test)``.
    """
    _x_train, x_test, _y_train, y_test = _split_breast_cancer(rand_key)
    return x_test, y_test

In [None]:
import jax
from sklearn.metrics import roc_auc_score


def sigmoid(x):
    """Elementwise logistic function ``1 / (1 + exp(-x))`` via ``jax.numpy``."""
    denominator = 1 + jax.numpy.exp(-x)
    return 1 / denominator


def predict(W, b, inputs):
    """Outputs probability of a label being true.

    Computes ``sigmoid(inputs @ W + b)``.
    """
    logits = jax.numpy.dot(inputs, W) + b
    return sigmoid(logits)


def loss(W, b, inputs, targets):
    """Training loss is the negative log-likelihood of the training examples.

    Returns the mean over examples of ``-log p(target)``.
    """
    probs = predict(W, b, inputs)
    # For 0/1 targets this selects p where target == 1 and (1 - p) where
    # target == 0, i.e. the probability the model assigns to the true label.
    likelihood = probs * targets + (1 - probs) * (1 - targets)
    log_likelihood = jax.numpy.log(likelihood)
    return -jax.numpy.mean(log_likelihood)


def train_step(W, b, x1, x2, y, learning_rate):
    """Perform one full-batch gradient-descent step on the concatenated features.

    Args:
        W: Weight vector, one entry per column of ``[x1, x2]``.
        b: Bias.
        x1, x2: Feature blocks joined column-wise (``axis=1``).
        y: Targets.
        learning_rate: Step size.

    Returns:
        ``(W, b)`` after one update against ``loss``.
    """
    x = jax.numpy.concatenate([x1, x2], axis=1)
    # Must be jax.grad: the bare name `grad` in this notebook is the
    # user-defined update step `grad(weights, bias, x, y, *, learning_rate)`,
    # so `grad(loss, (0, 1))` raised TypeError.
    Wb_grad = jax.grad(loss, (0, 1))(W, b, x, y)
    W -= learning_rate * Wb_grad[0]
    b -= learning_rate * Wb_grad[1]
    return W, b


def fit(W, b, x1, x2, y, epochs=1, learning_rate=1e-2):
    """Apply ``train_step`` ``epochs`` times and return the final ``(W, b)``."""
    for _epoch in range(epochs):
        new_W, new_b = train_step(W, b, x1, x2, y, learning_rate=learning_rate)
        W, b = new_W, new_b
    return W, b


def concatenate_samples(x1, x2):
    """Join two feature blocks side by side (column-wise, ``axis=1``)."""
    feature_blocks = (x1, x2)
    return jax.numpy.concatenate(feature_blocks, axis=1)


def grad(weights, bias, x, y, *, learning_rate):
    """One gradient-descent update of ``(weights, bias)`` against ``loss``.

    Despite its name, this returns the *updated parameters*, not the gradient.

    Args:
        weights, bias: Current parameters.
        x, y: Features and targets passed to ``loss``.
        learning_rate: Step size (keyword-only).

    Returns:
        ``(weights, bias)`` after one step.
    """
    loss_grad_fn = jax.grad(loss, argnums=(0, 1))
    d_weights, d_bias = loss_grad_fn(weights, bias, x, y)
    weights -= learning_rate * d_weights
    bias -= learning_rate * d_bias
    return weights, bias


def validate_model(weights, bias, X_test, y_test):
    """Score the model on held-out data; returns the ROC AUC of its predictions."""
    scores = predict(weights, bias, X_test)
    auc_value = roc_auc_score(y_test, scores)
    return auc_value

In [None]:
# Each party loads its own vertical slice of the training data on its PYU:
# alice gets the first 15 feature columns (no labels), bob gets the remaining
# columns plus the labels. Both use the same rand_key so the rows line up.
x1, _ = alice(load_dataset_for_training, num_returns=2)(rand_key=42, party_id=0)
x2, y = bob(load_dataset_for_training, num_returns=2)(rand_key=42, party_id=1)

# Move both feature slices and the labels into the SPU for joint training.
x1 = x1.to(spu)
x2 = x2.to(spu)
y = y.to(spu)

# Initial parameters: one zero weight per feature column of [x1, x2]
# (30 must match their combined column count) and a scalar bias.
weights = jax.numpy.zeros((30,))
bias = 0.0

# Pass the initial values through an identity function on alice's PYU and
# then into the SPU — presumably so they enter the SPU as device objects
# owned by a party rather than as plain driver-side values.
weights = alice(lambda x: x)(weights)
weights = weights.to(spu)

bias = alice(lambda x: x)(bias)
bias = bias.to(spu)

In [None]:
epochs = 10
learning_rate = 1e-2

# Join alice's and bob's feature columns inside the SPU once, before training.
x = spu(concatenate_samples)(x1, x2)

# Build the SPU-wrapped update step once; `grad` returns two values
# (weights, bias), which is stated explicitly to the SPU compiler.
spu_grad_step = spu(
    grad,
    num_returns_policy=SPUCompilerNumReturnsPolicy.FROM_USER,
    user_specified_num_returns=2,
)

# Run `epochs` full-batch gradient-descent steps inside the SPU.
for _ in range(epochs):
    weights, bias = spu_grad_step(
        weights,
        bias,
        x,
        y,
        learning_rate=learning_rate,
    )

In [None]:
from secretflow import reveal

# Move the trained parameters out of the SPU to bob, who evaluates them on the
# held-out test set; only the resulting AUC is revealed to the driver.
weights = weights.to(bob)
bias = bias.to(bob)
x_test, y_test = bob(load_dataset_for_testing, num_returns=2)(rand_key=42)
auc = reveal(bob(validate_model)(weights, bias, x_test, y_test))

# `weights` and `bias` are device object handles held by bob; printing them
# directly only shows the handle's repr, so reveal them for display.
print(f"weights={reveal(weights)!r}")
print(f"bias={reveal(bias)!r}")
print(f"{auc=}")

In [None]:
from secretnote.display import visualize_run

# Stop the profiler started at the top of the notebook and render the run it
# recorded.
profiler.stop()
visualize_run(profiler)