In [None]:
import os
import sys
import json
from pathlib import Path

In [None]:
from spu import spu_pb2 as spu_constants

In [None]:
# Cluster-level addresses for each party (port 8080); merged with
# `self_party` and handed to `sf.init` as `cluster_config`.
network_conf = dict(
    parties=dict(
        alice=dict(address='alice:8080'),
        bob=dict(address='bob:8080'),
    ),
)

# SPU device definition passed to `sf.SPU(cluster_def=...)`: one SPU node per
# party on a separate port (8081), plus runtime settings taken from the
# `spu_pb2` enums (protocol, field, sigmoid mode).
spu_conf = dict(
    nodes=[
        dict(party='alice', address='alice:8081'),
        dict(party='bob', address='bob:8081'),
    ],
    runtime_config=dict(
        protocol=spu_constants.SEMI2K,
        field=spu_constants.FM128,
        sigmoid_mode=spu_constants.RuntimeConfig.SIGMOID_REAL,
    ),
)

In [None]:
import secretflow as sf

In [None]:
# In Docker container, set via environment. Each container exports SELF_PARTY
# naming the party it runs as, so the same notebook runs unchanged on every node.
self_party = os.getenv('SELF_PARTY')
# Fail fast: an unset or misspelled SELF_PARTY would otherwise be forwarded to
# sf.init as cluster_config['self_party'] and surface as an obscure error later.
if self_party not in network_conf['parties']:
    raise ValueError(
        f"SELF_PARTY must be one of {list(network_conf['parties'])}, "
        f"got {self_party!r}"
    )

In [None]:
# Join the SecretFlow cluster as `self_party`. `cluster_config` is a shallow
# copy of network_conf with the local party name added; the address is
# presumably this container's local Ray head (default port 6379) — confirm
# against the deployment.
sf.init(
    address='127.0.0.1:6379',
    cluster_config=dict(network_conf, self_party=self_party),
    log_to_driver=True,
)

In [None]:
# One PYU device per party; the names match the keys of network_conf['parties'].
alice = sf.PYU('alice')
bob = sf.PYU('bob')

In [None]:
# SPU device built from spu_conf (alice and bob as SPU nodes). The cells below
# move data onto it with `.to(spu)` and run `dot` on it via `spu(dot)`.
spu = sf.SPU(cluster_def=spu_conf)

In [None]:
import jax.numpy as jnp

In [None]:
def read_numbers(filename):
    """Parse a comma-separated list of integers from a text file.

    Whitespace around each entry (including newlines) is ignored. Empty
    entries, such as those produced by a trailing comma or an empty file, are
    skipped instead of raising ``ValueError`` from ``int('')``.

    Args:
        filename: Path of the text file to read.

    Returns:
        A 1-D ``jnp`` array of the parsed integers in file order, using JAX's
        default integer dtype. An empty file yields a length-0 integer array.

    Raises:
        ValueError: If a non-empty entry is not a valid integer literal.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        tokens = (token.strip() for token in f.read().split(','))
        numbers = [int(token) for token in tokens if token]
    # dtype=int keeps the empty case integer-typed (jnp.array([]) would be float).
    return jnp.array(numbers, dtype=int)

In [None]:
# The absolute path is resolved here in the driver, then read_numbers runs via
# alice's PYU. The file name looks like an OEIS sequence id.
fibonacci = alice(read_numbers)(
    str(Path('A000045.txt').resolve())
)

In [None]:
# The absolute path is resolved here in the driver, then read_numbers runs via
# bob's PYU. The file name looks like an OEIS sequence id.
pascal = bob(read_numbers)(
    str(Path('A007318.txt').resolve())
)

In [None]:
def dot(x, y):
    """Inner product of two arrays after trimming both to the shorter length.

    Only the first axis is trimmed, so inputs of different lengths (e.g. the
    two parties' sequences) can still be combined with ``jnp.dot``.
    """
    common = min(x.shape[0], y.shape[0])
    return jnp.dot(x[:common], y[:common])

In [None]:
# Transfer each party's PYU-held array to the SPU device and evaluate `dot`
# there. Presumably the inputs are secret-shared on SPU, so neither party sees
# the other's plaintext (per SecretFlow's SPU semantics — confirm in its docs).
result = spu(dot)(fibonacci.to(spu), pascal.to(spu))

In [None]:
# Bring the SPU result back to the driver as a plain value via SecretFlow's
# `sf.reveal`. As the cell's last expression, it is what the notebook displays.
sf.reveal(result)