In [None]:
import os
from pathlib import Path

In [1]:
from secretnote.utils.version import assert_version
from secretnote.typing.spu import SPUClusterDef, SPUNode, spu_proto
from secretnote.typing.secretflow.device.driver import SFClusterConfig, SFClusterParty

In [2]:
import pydantic
import ray

import secretflow
import spu

# Check each installed package against the version range this notebook
# expects (assert_version presumably raises on a mismatch — confirm).
# Note: at this point `spu` is still the imported module, not the SPU device.
for _package, _spec in (
    (ray, '~=2.2.0'),
    (pydantic, '>=1.10, <2'),
    (secretflow, '~=1.0'),
    (spu, '>=0.4.1'),
):
    assert_version(_package, _package.__version__, _spec)

In [None]:
# Which party this process plays; taken from the SELF_PARTY env var, default "alice".
self_party = os.environ.get("SELF_PARTY", "alice")

In [None]:
# Initialize SecretFlow with a two-party cluster config (alice, bob).
# `self_party` (read from SELF_PARTY earlier) selects which of the two parties
# this process runs as.
# NOTE(review): the address "127.0.0.1:6379" (presumably a local Ray head) and
# the party addresses are hard-coded; they must match the deployment, e.g.
# hostnames "alice"/"bob" — confirm.
secretflow.init(
    address="127.0.0.1:6379",
    cluster_config=SFClusterConfig(
        parties={
            "alice": SFClusterParty(address="alice:8080"),
            "bob": SFClusterParty(address="bob:8080"),
        },
        self_party=self_party,
    ).dict(),  # pass a plain dict (pydantic v1 `.dict()`), not the model itself
    log_to_driver=True,
)

# Build the SPU device spanning alice and bob. Its node addresses use port
# 8081, distinct from the 8080 cluster-config addresses above.
# NOTE(review): this rebinds the name `spu`, shadowing the `spu` module imported
# in an earlier cell. Later cells use `spu` as the device (`spu(dot)`,
# `.to(spu)`). Re-running the import cell after this one rebinds `spu` to the
# module again, and those later cells then break.
spu = secretflow.SPU(
    cluster_def=SPUClusterDef(
        nodes=[
            SPUNode(party="alice", address="alice:8081"),
            SPUNode(party="bob", address="bob:8081"),
        ],
        # Runtime settings: SEMI2K protocol, FM128 field, SIGMOID_REAL sigmoid mode.
        # See the SPU RuntimeConfig docs for what each option implies.
        runtime_config=spu_proto.RuntimeConfig(
            protocol=spu_proto.SEMI2K,
            field=spu_proto.FM128,
            sigmoid_mode=spu_proto.RuntimeConfig.SIGMOID_REAL,
        ),
    ).dict()
)

In [None]:
# One PYU device per party; each party's file is read on its own device.
alice = secretflow.PYU('alice')
bob = secretflow.PYU('bob')

In [None]:
import jax.numpy as jnp

In [None]:
def read_numbers(filename):
    """Parse a text file of comma-separated integers into a JAX array.

    Surrounding whitespace (including newlines) on each item is ignored.
    Empty items, such as one produced by a trailing comma or an empty file,
    are skipped instead of crashing in ``int('')``.

    Args:
        filename: Path of the text file to read (UTF-8).

    Returns:
        A 1-D ``jnp`` array holding the parsed integers in file order. It is
        empty if the file contains no numbers.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        tokens = (item.strip() for item in f.read().split(','))
        # Drop empty tokens so trailing separators do not raise ValueError.
        return jnp.array([int(token) for token in tokens if token])

In [None]:
# Alice's private input: A000045.txt (presumably OEIS A000045, the Fibonacci
# numbers), resolved against the working directory and read on alice's PYU.
fibonacci = alice(read_numbers)(
    str(Path('A000045.txt').resolve())
)

In [None]:
# Bob's private input: A007318.txt (presumably OEIS A007318, Pascal's
# triangle read by rows), resolved against the working directory and read on
# bob's PYU.
pascal = bob(read_numbers)(
    str(Path('A007318.txt').resolve())
)

In [None]:
def dot(x, y):
    """Dot product of two 1-D arrays over their common leading length.

    Both operands are cut to the shorter length, so inputs of different
    sizes can be combined. Slicing the shorter one is a no-op.
    """
    common = min(x.shape[0], y.shape[0])
    return jnp.dot(x[:common], y[:common])

In [None]:
# Move both parties' inputs onto the SPU device and compute `dot` there.
# `result` is an SPU object; it is only made visible by `secretflow.reveal`.
result = spu(dot)(
    fibonacci.to(spu),
    pascal.to(spu),
)

In [None]:
# Reveal the dot product computed on the SPU. This is the cell's last
# expression, so the revealed value is displayed as the cell output.
secretflow.reveal(result)