In [1]:
from secretnote.compat.secretflow.device.driver import SFConfigSimulationFullyManaged

# Simulation-mode SecretFlow config for the two parties used throughout this notebook.
# The same names are later used for the PYU devices and the SPU cluster nodes.
secretflow_config = SFConfigSimulationFullyManaged(
    parties=["alice", "bob"],
)


In [2]:
import secretflow

# Tear down any runtime left from an earlier execution before initialising again,
# so this cell can be re-run; init kwargs are taken directly from the config model.
secretflow.shutdown()
secretflow.init(
    **secretflow_config.dict(),
)


2023-10-09 12:55:38,057	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2023-10-09 12:55:40,393	INFO worker.py:1612 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265


In [3]:
# One PYU device per party; the names match `parties` in `secretflow_config`
# and the parties of the SPU cluster nodes.
alice, bob = (secretflow.PYU(party) for party in ("alice", "bob"))


In [4]:
from secretnote.compat.spu import (
    SPUConfig,
    SPUClusterDef,
    SPUNode,
    SPUProtocolKind,
    SPUFieldType,
    SPURuntimeConfig,
)

# Two-node SPU cluster: one node per party, each reachable at its own local address.
# Runtime uses the SEMI2K protocol with the FM128 field (presumably a 128-bit ring,
# per SPU naming — confirm against the SPU runtime docs).
spu_config = SPUConfig(
    cluster_def=SPUClusterDef(
        nodes=[
            SPUNode(party=party, address=address)
            for party, address in (
                ("alice", "localhost:32767"),
                ("bob", "localhost:32768"),
            )
        ],
        runtime_config=SPURuntimeConfig(
            protocol=SPUProtocolKind.SEMI2K,
            field=SPUFieldType.FM128,
        ),
    ),
)


In [5]:
# Create the SPU device from the cluster definition in `spu_config`.
spu = secretflow.SPU(
    **spu_config.dict(),
)


In [6]:
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor

from secretnote.instrumentation import Profiler
from secretnote.instrumentation.models import ProfilerRule


# Route OpenTelemetry spans to an OTLP collector over plaintext gRPC at localhost:4317,
# tagged with service name "simulation".
#
# `trace.set_tracer_provider` only takes effect once per process: a second call is
# ignored with an "Overriding of current TracerProvider is not allowed" warning.
# Re-running this cell would therefore build a fresh provider + exporter that is never
# used. Reuse the already-installed SDK provider instead, so the cell is idempotent.
if isinstance(trace.get_tracer_provider(), TracerProvider):
    provider = trace.get_tracer_provider()
else:
    resource = Resource(attributes={SERVICE_NAME: "simulation"})
    provider = TracerProvider(resource=resource)
    # SimpleSpanProcessor (per the OTel SDK docs) exports each span as soon as it ends,
    # rather than batching like BatchSpanProcessor.
    provider.add_span_processor(
        SimpleSpanProcessor(OTLPSpanExporter(endpoint="localhost:4317", insecure=True)),
    )
    trace.set_tracer_provider(provider)


In [7]:
import jax.numpy as jnp
import re


def dot(x, y):
    """Return ``jnp.dot(x, y)``.

    Written as a plain JAX function so it can be wrapped with ``spu(dot)`` and
    executed on the SPU device in the profiled run below.
    """
    product = jnp.dot(x, y)
    return product


# Profile every function (any name) whose source file matches ``^secretflow/``.
# NOTE(review): assumes ProfilerRule.file is matched against a package-relative path —
# confirm against secretnote.instrumentation.
rules = [
    ProfilerRule(
        file=re.compile(r"^secretflow/.*"),
        func_name=re.compile(r".*"),
    ),
]

# Under the profiler:
# - place a plaintext vector on each party's PYU;
# - move both onto the SPU and compute their dot product there;
# - move the result to alice and reveal it.
with Profiler(rules):
    x = secretflow.to(alice, jnp.asarray([1, 2, 3]))
    y = secretflow.to(bob, jnp.asarray([1, 2, 3]))
    z = x.to(spu)
    w = y.to(spu)
    r = spu(dot)(z, w)
    s = r.to(alice)
    revealed = secretflow.reveal(s)

# Show the revealed value as the cell output; for [1, 2, 3]·[1, 2, 3] it should be 14.
revealed


INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
INFO:jax._src.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
[2m[36m(_run pid=45181)[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
[2m[36m(_run pid=45181)[0m INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
[2m[36m(_run pid=45181)[0m INFO:jax._s

[2m[36m(SPURuntime pid=45199)[0m 2023-10-09 12:55:42.340 [info] [default_brpc_retry_policy.cc:DoRetry:52] socket error, sleep=1000000us and retry
[2m[36m(SPURuntime pid=45199)[0m 2023-10-09 12:55:43.341 [info] [default_brpc_retry_policy.cc:LogHttpDetail:33] cntl ErrorCode '64', http status code '200', response header '', error msg '[E61]Fail to connect Socket{id=0 addr=127.0.0.1:32768} (0x0x138118000): Connection refused [R1][E64]Not connected to 127.0.0.1:32768 yet, server_id=0'
[2m[36m(SPURuntime pid=45199)[0m 2023-10-09 12:55:43.341 [info] [default_brpc_retry_policy.cc:DoRetry:75] aggressive retry, sleep=1000000us and retry
[2m[36m(SPURuntime pid=45201)[0m 2023-10-09 12:55:44.345 [info] [default_brpc_retry_policy.cc:DoRetry:71] not retry for reached rcp timeout, ErrorCode '1008', error msg '[E1008]Reached timeout=2000ms @127.0.0.1:32767'
[2m[36m(SPURuntime pid=45199)[0m 2023-10-09 12:55:44.341 [info] [default_brpc_retry_policy.cc:LogHttpDetail:33] cntl ErrorCode '64',