In [1]:
import torch
import torch.nn as nn
import numpy as np
from numpy.random import default_rng


def set_seed(seed=42069):
    """Seed PyTorch's random number generators and make cuDNN deterministic.

    Args:
        seed: Integer seed handed to ``torch.manual_seed`` and, when CUDA is
            available, to ``torch.cuda.manual_seed_all``. The default of
            42069 is the value this function used to apply unconditionally,
            so calling ``set_seed()`` without arguments behaves as before.
    """
    # Prefer reproducible cuDNN kernels over autotuned (possibly varying) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

In [2]:
def init_weight(random=False):
    """Build the three float32 weight matrices of the 2 -> 3 -> 3 -> 2 network.

    Args:
        random: When true, the matrices are drawn from a standard normal
            distribution using a generator that is always seeded with 0, so
            the "random" weights are identical on every call. When false, a
            fixed hand-written set of weights is returned.

    Returns:
        Tuple ``(w1, w2, w3)`` of NumPy arrays with shapes (3, 2), (3, 3) and
        (2, 3), i.e. (out_features, in_features) for each layer.
    """
    if not random:
        fixed_rows = (
            # layer 1: y0 (1, 2) -> y1 (1, 3)
            [[0.2, 0.3],
             [0.4, 0.2],
             [0.3, 0.4]],
            # layer 2: y1 (1, 3) -> y2 (1, 3)
            [[0.2, 0.3, 0.4],
             [0.4, 0.2, 0.3],
             [0.3, 0.4, 0.2]],
            # layer 3: y2 (1, 3) -> y3 (1, 2)
            [[0.2, 0.3, 0.4],
             [0.4, 0.2, 0.3]],
        )
        w1, w2, w3 = [np.array(rows, dtype='f') for rows in fixed_rows]
        return w1, w2, w3

    generator = default_rng(0)
    # Draw order matters for reproducibility: layer 1, then 2, then 3.
    layer_shapes = ((3, 2), (3, 3), (2, 3))
    w1, w2, w3 = [generator.standard_normal(shape, dtype='f')
                  for shape in layer_shapes]
    return w1, w2, w3


class CMlp(nn.Module):
    """Bias-free 2 -> 3 -> 3 -> 2 linear network with per-layer weight noise.

    Each layer's weight matrix (from ``init_weight``) is multiplied by a
    scalar noise factor. With the default ``r=(1., 1., 1.)`` every factor is
    1 and the network is the plain ("plaintext") model.

    Attributes set by ``forward``:
        y2: output of the second layer.
        alpha: sum of all elements of ``y2``.
        y3: network output (also the return value of ``forward``).
    """

    def __init__(self, r=(1., 1., 1.), weight_random=False):
        """Create the three layers and scale their weights.

        Args:
            r: Three cumulative noise factors. Layer k is scaled by
                ``r[k] / r[k-1]`` (layer 0 by ``r[0]``), so ``r[0]`` and
                ``r[1]`` must be non-zero.
            weight_random: Forwarded to ``init_weight`` to choose between the
                seeded random weights and the fixed hand-written weights.
        """
        super(CMlp, self).__init__()
        w1, w2, w3 = init_weight(weight_random)
        # Ratios of consecutive r values: multiplying the first k factors
        # together gives r[k-1] (up to floating-point rounding).
        self.noise = [r[0], r[1] / r[0], r[2] / r[1]]
        # Both settings of weight_random build the layers identically; only
        # the matrices returned by init_weight differ.
        self.fc1 = self._scaled_linear(2, 3, w1, self.noise[0])
        self.fc2 = self._scaled_linear(3, 3, w2, self.noise[1])
        self.fc3 = self._scaled_linear(3, 2, w3, self.noise[2])
        self.y2 = None
        self.y3 = None
        self.alpha = None

    @staticmethod
    def _scaled_linear(in_features, out_features, weight, scale):
        """Return a bias-free ``nn.Linear`` whose weight is ``weight * scale``.

        ``weight`` is a NumPy array of shape (out_features, in_features). The
        tensor shares memory with that array, so the in-place scaling also
        modifies the array passed in.
        """
        layer = nn.Linear(in_features, out_features, False)
        layer.weight.data = torch.from_numpy(weight)
        layer.weight.data.mul_(scale)
        return layer

    def forward(self, x):
        """Run the three layers and cache the intermediate results.

        Args:
            x: Input tensor whose last dimension is 2.

        Returns:
            The output of the third layer, also stored as ``self.y3``.
        """
        y1 = self.fc1(x)
        self.y2 = self.fc2(y1)
        self.alpha = self.y2.sum()
        self.y3 = self.fc3(self.y2)
        # Keep the gradient of this non-leaf output so callers can read
        # y.grad after backward(). retain_grad() raises on tensors that do
        # not require grad (e.g. under torch.no_grad()), so guard the call.
        if self.y3.requires_grad:
            self.y3.retain_grad()
        return self.y3

In [3]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Cumulative noise factors. Because every layer is linear and bias-free, the
# ciphertext network's output is the plaintext output scaled by r[-1].
r = (0.2, 0.4, 0.8)
x = torch.tensor([[0.2, 0.3]], device=device)
y_hat = torch.tensor([[0.5, 0.5]], device=device)  # regression target

# Plaintext network: default r=(1, 1, 1), i.e. unscaled weights.
net = CMlp(weight_random=True).to(device)
print('----------- plaintext weight ---------------')
for p in net.parameters():
    print(p.data)
y = net(x)
print('y: ', y)
criterion = nn.MSELoss()
loss = criterion(y, y_hat)
# retain_graph=True: y is reused by later cells that reference this graph.
loss.backward(retain_graph=True)
print('----------- plaintext grad -----------------')
for p in net.parameters():
    print(p.grad)
# Snapshot (not alias) the gradients so a later zero_grad()/backward() on
# `net` cannot silently change the saved values.
w_gradlist = [p.grad.detach().clone() for p in net.parameters()]
print('----------- ciphertext weight ---------------')
# Ciphertext network: same base weights, scaled by the noise derived from r.
net_c = CMlp(r, weight_random=True).to(device)
for p in net_c.parameters():
    print(p.data)
y_c = net_c(x)
c_loss = criterion(y_c, y_hat)
# retain_graph=True: later cells call y_c.backward(...) on this same graph.
c_loss.backward(retain_graph=True)
print('----------- ciphertext grad ---------------')
for p in net_c.parameters():
    print(p.grad)
c_w_gradlist = [p.grad.detach().clone() for p in net_c.parameters()]

print('Get yc: ', y_c)
# Scale by the last noise factor instead of repeating its literal value.
print('Get yc from y: ', y * r[-1])
print('Ly derivative')
print(y - y_hat)
print(y.grad)
y_grad = y.grad.clone()
print('Lhaty derivative')
print(y_c - y_hat)
print(y_c.grad)
# Clone for the same reason as y_grad: later backward passes through y_c
# accumulate into y_c.grad.
y_c_grad = y_c.grad.clone()

----------- plaintext weight ---------------
tensor([[ 1.1176, -1.3871],
        [-0.4266, -0.8036],
        [ 0.6014, -0.0750]])
tensor([[ 0.0597, -0.0320, -0.1855],
        [ 1.2048,  0.7775, -1.3583],
        [ 0.7698, -0.8702,  1.0998]])
tensor([[-0.9585, -1.2749, -1.3653],
        [-1.4743,  0.4335, -0.3281]])
y:  tensor([[ 0.4749, -0.3197]], grad_fn=<MmBackward>)
----------- plaintext grad -----------------
tensor([[-0.0165, -0.0248],
        [-0.1109, -0.1664],
        [ 0.1088,  0.1632]])
tensor([[-0.2374, -0.4023,  0.1205],
        [ 0.0623,  0.1056, -0.0316],
        [-0.0584, -0.0990,  0.0297]])
tensor([[ 4.8124e-04,  1.5515e-02, -6.1016e-03],
        [ 1.5730e-02,  5.0711e-01, -1.9944e-01]])
----------- ciphertext weight ---------------
tensor([[ 0.2235, -0.2774],
        [-0.0853, -0.1607],
        [ 0.1203, -0.0150]])
tensor([[ 0.1194, -0.0640, -0.3710],
        [ 2.4095,  1.5549, -2.7166],
        [ 1.5397, -1.7405,  2.1995]])
tensor([[-1.9169, -2.5499, -2.7307],
       

$ \frac{\partial \widehat{L}}{\partial \widehat{W}^{(l)}} = \left( \frac{\partial L}{\partial y^{(L)}} + (r^L - 1) \cdot y^{(L)} \right) \frac{\partial \widehat{y}^{(L)}}{\partial \widehat{W}^{(l)}} $

In [4]:
# Base-class optimizers with empty defaults: one per network, plaintext first.
# No update rule is configured; optim_c is used to clear net_c's gradients.
optim, optim_c = (
    torch.optim.Optimizer(model.parameters(), {}) for model in (net, net_c)
)
optim_c.zero_grad()

In [5]:
# Output-layer noise factor r^L, taken from `r` and shaped like the network
# output instead of being repeated as a literal.
n3 = torch.full_like(y, r[-1])
# Correction term (r^L - 1) * y^(L) from the formula above.
r_L = (n3 - 1) * y
print(r_L)
# Upstream gradient for the ciphertext output: dL/dy + (r^L - 1) * y.
t = y_grad + r_L
# retain_graph=True: the same graph is back-propagated again in a later cell.
y_c.backward(t, retain_graph=True)

tensor([[-0.0950,  0.0639]], grad_fn=<MulBackward0>)


### Get the gradient $\frac{\partial \widehat{L}}{\partial \widehat{W}^{(l)}}$

In [6]:
# Gradients of net_c's weights after the backward pass through y_c,
# printed one layer at a time.
for weight in net_c.parameters():
    print(weight.grad)

tensor([[ 0.1441,  0.2162],
        [-0.4268, -0.6402],
        [ 0.3697,  0.5545]])
tensor([[-0.0947, -0.1605,  0.0481],
        [ 0.0135,  0.0228, -0.0068],
        [-0.0317, -0.0538,  0.0161]])
tensor([[ 0.0009,  0.0297, -0.0117],
        [ 0.0058,  0.1870, -0.0736]])


In [7]:
# Reference values: the ciphertext gradients saved right after c_loss.backward().
print('\n'.join(str(saved_grad) for saved_grad in c_w_gradlist))

tensor([[ 0.1441,  0.2162],
        [-0.4268, -0.6402],
        [ 0.3697,  0.5545]])
tensor([[-0.0947, -0.1605,  0.0481],
        [ 0.0135,  0.0228, -0.0068],
        [-0.0317, -0.0538,  0.0161]])
tensor([[ 0.0009,  0.0297, -0.0117],
        [ 0.0058,  0.1870, -0.0736]])


$ \frac{\partial \widehat{L}}{\partial \widehat{W}^{(l)}} = \frac{r^L}{R_l} \frac{\partial L}{\partial W^{(l)}} + (r^L - 1) \cdot y^{(L)} \frac{\partial \widehat{y}^{(L)}}{\partial \widehat{W}^{(l)}} $

In [8]:
print(net_c.noise) # noise for w_ij in each layer

# left of addition: (r^L / R_l) * dL/dW^(l), built from the plaintext gradients.
# A comprehension keeps the loop names local, so the global input tensor `x`
# is no longer overwritten by this cell.
l_tensor_list = [
    plain_grad.mul(1 / layer_noise).mul(r[-1])
    for plain_grad, layer_noise in zip(w_gradlist, net_c.noise)
]
# Clear net_c's gradients so the next backward pass yields only the right term.
optim_c.zero_grad()
y_c.backward(r_L, retain_graph=True)

# right of addition: (r^L - 1) * y^(L) back-propagated through the ciphertext net.
r_tensor_list = [p.grad.detach().clone() for p in net_c.parameters()]

[0.2, 2.0, 2.0]


### Restore the ciphertext gradient (the left-hand side of the equation)

In [9]:
# Add the two terms per layer; the sums should reproduce the ciphertext
# gradients printed earlier. Descriptive loop names avoid overwriting the
# global noise tuple `r`.
for left_term, right_term in zip(l_tensor_list, r_tensor_list):
    print(left_term + right_term)

tensor([[ 0.1441,  0.2162],
        [-0.4268, -0.6402],
        [ 0.3697,  0.5545]])
tensor([[-0.0947, -0.1605,  0.0481],
        [ 0.0135,  0.0228, -0.0068],
        [-0.0317, -0.0538,  0.0161]])
tensor([[ 0.0009,  0.0297, -0.0117],
        [ 0.0058,  0.1870, -0.0736]])
