In [1]:
import spu.utils.distributed as ppd

# Initialize the distributed environment from ppd's SAMPLE_* node and device
# definitions. Per the outputs of the following cells, this yields three nodes,
# two PYU devices (P1, P2) and one SPU device hosted jointly by all three nodes.
ppd.init(ppd.SAMPLE_NODES_DEF, ppd.SAMPLE_DEVICES_DEF)

In [2]:
# Node id -> "host:port" address for every node in the current environment.
ppd.current().nodes_def

{'node:0': '127.0.0.1:9327',
 'node:1': '127.0.0.1:9328',
 'node:2': '127.0.0.1:9329'}

In [3]:
# Devices in the current environment: the SPU is hosted by all nodes together,
# while each PYU (P1, P2) is hosted by a single node.
ppd.current().devices

{'SPU': SPU(SPU) hosted by: ['127.0.0.1:9327', '127.0.0.1:9328', '127.0.0.1:9329'],
 'P1': PYU(P1) hosted by: 127.0.0.1:9327,
 'P2': PYU(P2) hosted by: 127.0.0.1:9328}

In [4]:
# Show the SPU configuration: hosting/internal addresses, protocol, field and
# profiling flag.
print(ppd.device('SPU').details())

name: SPU
hosted by: ['127.0.0.1:9327', '127.0.0.1:9328', '127.0.0.1:9329']
internal addrs: ['127.0.0.1:9437', '127.0.0.1:9438', '127.0.0.1:9439']
protocol: ABY3
field: FM128
enable_pphlo_profile: true



In [5]:
import numpy as np
import jax.numpy as jnp

def make_rand(high=100, size=(1, )):
    """Draw random integers in ``[0, high)`` from a freshly reseeded global RNG.

    Called with no arguments it behaves exactly as before: a single-element
    integer array with a value in ``[0, 100)``.

    Args:
        high: Exclusive upper bound of the drawn integers (default 100).
        size: Shape of the returned array (default ``(1,)``).

    Returns:
        A numpy integer array of shape ``size``.
    """
    # seed() with no argument pulls fresh entropy from the OS.
    # NOTE(review): presumably this stops each device process (P1/P2) from
    # drawing identical values out of an inherited RNG state. Confirm before
    # removing.
    np.random.seed()
    return np.random.randint(high, size=size)

def greater(x, y):
    """Return the element-wise comparison ``x > y`` computed by ``jnp.greater``.

    This is a plain jax function, so the notebook can call it directly on local
    values or dispatch it to the SPU device for secure evaluation.
    """
    is_greater = jnp.greater(x, y)
    return is_greater

# Plaintext baseline: both values are drawn and compared locally, in the clear.
x, y = make_rand(), make_rand()
ans = greater(x, y)

print(f"x = {x}", f"y = {y}", f"x>y = {ans}", sep="\n")

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


x = [8]
y = [1]
x>y = [ True]


In [6]:
# Run make_rand on P1; the value is visible to P1 only.
x = ppd.device("P1")(make_rand)()

# Run make_rand on P2; the value is visible to P2 only.
y = ppd.device("P2")(make_rand)()

# Run greater on SPU: it automatically fetches x/y from P1/P2 (as ciphertext)
# and computes the result securely. x, y and ans are DeviceObject handles
# (see the next cell), not plaintext values.
ans = ppd.device("SPU")(greater)(x, y)

In [7]:
# Inspect the handles: each is a DeviceObject tagged with the device holding it.
x, y, ans

(DeviceObject(140665754876848 at P1),
 DeviceObject(140665754877328 at P2),
 DeviceObject(140665755360128 at SPU))

In [8]:
"x>y = ", ppd.get(ans)

('x>y = ', array([False]))

In [9]:
# Sanity check: pull the private inputs back in the clear and recompute the
# comparison with numpy. NOTE: this discloses x and y, so it is for demo only.
x_revealed, y_revealed = map(ppd.get, (x, y))
x_revealed, y_revealed, np.greater(x_revealed, y_revealed)

(array([31]), array([86]), array([False]))