In [2]:
import torch
from torch.autograd import Variable, Function
import inspect
import random
import copy

# import functools
# TODO: Use functools.wrap to get original function/method dir attributes

In [3]:
BASE = 10
KAPPA = 3 # ~29 bits

PRECISION_INTEGRAL = 2
PRECISION_FRACTIONAL = 5
PRECISION = PRECISION_INTEGRAL + PRECISION_FRACTIONAL
BOUND = BASE**PRECISION

# Q field
Q = 2**41 # < 64 bits
#Q = 2147483648
Q_MAXDEGREE = 1
#assert Q > BASE**(PRECISION * Q_MAXDEGREE) # supported multiplication degree (without truncation)
#assert Q > 2*BOUND * BASE**KAPPA # supported kappa when in positive range 

# P field
P = 1802216888453791673313287943102424579859887305661122324585863735744776691801009887 # < 270 bits
P_MAXDEGREE = 9
#assert P > Q
#assert P > BASE**(PRECISION * P_MAXDEGREE)


In [16]:
def encode(rational,field=Q,precision_fractional=PRECISION_FRACTIONAL):
    upscaled = (rational * BASE**precision_fractional).long()
    upscaled.remainder_(field)
    return upscaled
def decode(field_element,field=Q,precision_fractional=PRECISION_FRACTIONAL):
    field_element = field_element.data
    neg_values = field_element.gt(field)
    #pos_values = field_element.le(field)
    #upscaled = field_element*(neg_valuese+pos_values)
    field_element[neg_values] = field-field_element[neg_values]
    rational = field_element.float()/ BASE**precision_fractional
    return rational
def share(secret,field=Q):
    first = torch.LongTensor(secret.shape).random_(field)
    second = (secret - first)% field
    return [first,second]
def reconstruct(shares ,field=Q):
    return sum(shares)%field
    

In [22]:
def send_shrare(value):
    raise NotImplementedError()
def receive_share():
    raise NotImplementedError()
def swap_shares(share,party):
    if (party == 0):
        send_share(share)
        share_other = receive_share()
    elif (party == 1):
        share_other = receive_share()
        send_share(share)
    return share_other

In [35]:
def public_add(x,y,party):
    if (party ==0):
        return x+y
    elif(party == 1):
        return x

In [34]:
def spdz_mul(x,y,party,field=Q):
    if a.shape != b.shape:
        raise ValueError()
    m,n = a.shape
    triple = generate_mul_triple_communication(m,n,party,field)
    a,b,c = triple
    d = x - a
    e = y - b
    
    d_other = swap_shares(d,party)
    e_other = swap_shares(e,party)
    delta = reconstruct([d,d_other],field)
    epsilon = reconstruct([e,e_other],field)
    r = delta * epsilon
    s = a * epsilon
    t = b * delta
    share = s + t + c
    share = public_add(share,r,party)
    share = truncate(share)
    return share

In [36]:
def spdz_matmul(x,y,party,field=Q):
    x_height = x.shape[0]
    x_width = x.shape[1]
    
    y_height = y.shape[0]
    y_width = y.shape[1]
    
    assert x_width == y_height
    
    r, s, t = generate_matmul_triple_communication(x_height,y_width,x_width, party,field)

    rho_local = x - r
    sigma_local = y - s
    
    # Communication
    rho_other = swap_shares(rho_local, party, socket)
    sigma_other = swap_shares(sigma_local, party, socket)
    
    # They both add up the shares locally
    rho = reconstruct([rho_local, rho_other],field)
    sigma = reconstruct([sigma_local, sigma_other],field)

    r_sigma = r @ sigma    
    rho_s = rho @ s 

    share =  r_sigma + rho_s + t 
    
    rs = rho @ sigma

    share = add_public(share, rs)
    share = truncate(share)   
    return share

In [30]:
def generate_mul_triple(m,n,field=Q):
    r = torch.LongTensor(m,n).random_(field)
    s = torch.LongTensor(m,m).random_(field)
    t = r * s 
    return r,s,t

In [33]:
def generate_mul_triple_communication(m,n,party,field=Q):
    if (party==0):
        r,s,t = generate_mul_triple(m,n,field)
        
        r_alice, r_bob = share(r)
        s_alice, s_bob = share(s)
        t_alice, t_bob = share(t)
        
        reponse_r = swap_shares(r_bob,party)
        reponse_s = swap_shares(s_bob,party)
        reponse_t = swap_shares(t_bob,party)
        
        triple_alice = [r_alice,s_alice,t_alice]
        return triple_alice
    elif (party == 1):
        r_bob = swap_shares(torch.LongTensor(m,n).zeros(),party)
        s_bob = swap_shares(torch.LongTensor(m,n).zeros(),party)
        t_bob = swap_shares(torch.LongTensor(m,n).zeros(),party)
        triple_bob = [r_bob,s_bob,t_bob]
        return triple_bob

In [40]:
def generate_matmul_triple(m,n,k,field=Q):
    r = torch.LongTensor(m,k).random_(field)
    s = torch.LongTensor(k,n).random_(field)
    t = (r @ s) % field
    return r, s, t

In [5]:
class EncryptedAdd(Function):
    
    @staticmethod
    def forward(ctx, a, b,field=Q):
        return a+b % field
        # compute a + b on encrypted data - they are regular PyTorch tensors
        
    @staticmethod
    def backward(ctx, grad_out):
        grad_out = VariableProxy(grad_out.data)
        return grad_out.var,grad_out.var
        # not grad_out operators are overloaded
    

In [6]:
class EncryptedMult(Function):
    
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a,b)       
        return spdz_mul(a,b)
        # compute a * b on encrypted data - they are regular PyTorch tensors
        
    @staticmethod
    def backward(ctx, grad_out):
        a,b = ctx.saved_tensors
        grad_out = grad_out
        return Variable(spdz_mul(grad_out.data,b)),Variable(spdz_mul(grad_out.data,a))
        # not grad_out operators are overloaded

In [24]:
def spdz_matmul(a,b):
    return a @ b

In [26]:
class EncryptedMatmul(Function):
    
    @staticmethod
    def forward(ctx,a,b):
        ctx.save_for_backward(a,b)
        return spdz_matmul(a,b)
    def backward(ctx,grad_out):
        a,b = ctx.saved_tensors
        raise NotImplementedError()

In [28]:
class VariableProxy(object):
    
    def __init__(self, var, field=Q, requires_grad=True):
        self.var = Variable(var,requires_grad=requires_grad)
        self.var = self.var

    def __add__(self, other):
        return (EncryptedAdd.apply(self.var, other.var))
    
    def __mul__(self,other):
        return (EncryptedMult.apply(self.var, other.var))
    def __matmul__(self,other):
        retunn (EncryptedMatmul.apply(self.var,other.var))
    
    def grad(self):
        return self.var.grad
    

In [17]:
x = VariableProxy(encode(torch.FloatTensor([1,1,1])),requires_grad=True)
y = VariableProxy(encode(torch.FloatTensor([2,3,4])),requires_grad=True)

In [18]:
z = x + y

In [19]:
z

Variable containing:
 3.0000e+05
 4.0000e+05
 5.0000e+05
[torch.LongTensor of size 3]

In [20]:
decode(z)


 3
 4
 5
[torch.FloatTensor of size 3]

In [223]:
z.backward(torch.FloatTensor([1]))

In [224]:
x.grad()

Variable containing:
 1
[torch.FloatTensor of size 1]

In [42]:
generate_matmul_triple(3,4,5,10)

(
  4  6  6  1  4
  6  9  1  9  1
  1  7  0  8  1
 [torch.IntTensor of size 3x5], 
  3  1  6  3
  4  0  4  4
  5  3  3  1
  0  8  1  8
  3  0  7  4
 [torch.IntTensor of size 5x4], 
  8  0  5  6
  2  1  1  1
  4  5  9  9
 [torch.IntTensor of size 3x4])

In [215]:
ds =  torch.FloatTensor([1,2.5,3])
print(ds)
dsa = encode(ds)
print(dsa)
dsq = decode(dsa)
print(dsq)
alice,bob = share(dsa)
print(alice)
print(bob)
dsr = reconstruct([alice,bob])
print(dsr)


 1.0000
 2.5000
 3.0000
[torch.FloatTensor of size 3]


 1.0000e+05
 2.5000e+05
 3.0000e+05
[torch.LongTensor of size 3]


 1.0000
 2.5000
 3.0000
[torch.FloatTensor of size 3]


 1.7385e+12
 1.3528e+12
 1.1383e+12
[torch.LongTensor of size 3]


-1.7385e+12
-1.3528e+12
-1.1382e+12
[torch.LongTensor of size 3]


 1.0000e+05
 2.5000e+05
 3.0000e+05
[torch.LongTensor of size 3]

