In [1]:
import torch
import torch.nn as nn
import numpy as np


def set_seed(seed=42069):
    """Seed the NumPy and PyTorch random number generators for reproducibility.

    Also switches cuDNN to deterministic mode and disables its auto-tuner.

    Args:
        seed: Integer seed applied to NumPy, PyTorch (CPU) and, when CUDA is
            available, every CUDA device. The default 42069 is the value the
            function previously hard-coded, so a bare ``set_seed()`` call
            produces the same random streams as before.
    """
    # The seed argument used to be overwritten here, so it was silently ignored.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)



In [2]:

def init_weight():
    """Build the three fixed float32 weight matrices used by the toy MLP.

    Returns:
        A tuple ``(w1, w2, w3)`` of freshly allocated NumPy arrays with
        shapes (3, 2), (3, 3) and (2, 3).
    """
    # Activation sizes along the chain: y0 (1, 2) -> y1 (1, 3) -> y2 (1, 3) -> y3 (1, 2)
    layer_rows = (
        # w1: applied to y0
        [[0.2, 0.3],
         [0.4, 0.2],
         [0.3, 0.4]],
        # w2: applied to y1
        [[0.2, 0.3, 0.4],
         [0.4, 0.2, 0.3],
         [0.3, 0.4, 0.2]],
        # w3: applied to y2
        [[0.2, 0.3, 0.4],
         [0.4, 0.2, 0.3]],
    )
    return tuple(np.array(rows, dtype=np.float32) for rows in layer_rows)


class CMlp(nn.Module):
    """Three-layer, bias-free MLP (2 -> 3 -> 3 -> 2) with fixed, scaled weights.

    The base weights come from ``init_weight()``. Each layer's weight matrix
    is multiplied by a per-layer noise factor derived from ``r``; with the
    default ``r=(1., 1., 1.)`` every factor is 1 and the weights are unscaled.

    Args:
        r: Sequence of three non-zero numbers. The per-layer factors are
            ``[r[0], r[1] / r[0], r[2] / r[1]]`` and are stored in
            ``self.noise``.

    Attributes set by ``forward``:
        y2: Output of the second layer.
        alpha: Sum of all entries of ``y2``.
        y3: Output of the third layer (also the return value).
    """

    def __init__(self, r=(1., 1., 1.)):
        super().__init__()
        w1, w2, w3 = init_weight()
        # Per-layer factors; their running product is r[0], r[1], r[2].
        self.noise = [r[0], r[1] / r[0], r[2] / r[1]]

        self.fc1 = self._scaled_linear(w1, self.noise[0])
        self.fc2 = self._scaled_linear(w2, self.noise[1])
        self.fc3 = self._scaled_linear(w3, self.noise[2])

        self.y2 = None
        self.y3 = None
        self.alpha = None

    @staticmethod
    def _scaled_linear(weight, scale):
        """Return a bias-free ``nn.Linear`` whose weight is ``weight * scale``."""
        out_features, in_features = weight.shape
        layer = nn.Linear(in_features, out_features, False)
        # Overwrite the random initialisation without recording it in autograd.
        with torch.no_grad():
            layer.weight.copy_(torch.from_numpy(weight) * scale)
        return layer

    def forward(self, x):
        """Run the three layers on ``x`` and cache ``y2``, ``alpha`` and ``y3``.

        Returns:
            ``self.y3``, the output of the last layer.
        """
        y1 = self.fc1(x)
        self.y2 = self.fc2(y1)
        self.alpha = self.y2.sum()
        self.y3 = self.fc3(self.y2)
        # y3 is a non-leaf tensor, so its .grad is only kept if requested.
        # retain_grad() raises when the tensor does not require grad (e.g.
        # inside torch.no_grad()), so only call it when autograd is tracking.
        if self.y3.requires_grad:
            self.y3.retain_grad()
        return self.y3

In [3]:
# setup gpu or cpu
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Cumulative noise factors for the three layers; r[-1] scales the final output.
r = (0.2, 0.4, 0.8)
x = torch.tensor([[0.2, 0.3]], device=device)
y_hat = torch.tensor([[0.5, 0.5]], device=device)

# --- plaintext network: unscaled weights ---
net = CMlp().to(device)
print('----------- plaintext weight ---------------')
for p in net.parameters():
    print(p.data)
y = net(x)
print('y: ', y)
criterion = nn.MSELoss()
loss = criterion(y, y_hat)
loss.backward(retain_graph=True)
print('----------- plaintext grad -----------------')
for p in net.parameters():
    print(p.grad)
# Snapshot the gradients so later backward passes cannot change them.
w_gradlist = [p.grad.detach().clone() for p in net.parameters()]

# --- ciphertext network: weights scaled by the noise derived from r ---
print('----------- ciphertext weight ---------------')
net_c = CMlp(r).to(device)
for p in net_c.parameters():
    print(p.data)
y_c = net_c(x)
c_loss = criterion(y_c, y_hat)
# The graph is kept because y_c.backward() is called again in later cells.
c_loss.backward(retain_graph=True)
print('----------- ciphertext grad ---------------')
for p in net_c.parameters():
    print(p.grad)
c_w_gradlist = [p.grad.detach().clone() for p in net_c.parameters()]

print('Get yc: ', y_c)
print('Get yc from y: ', y  * r[-1])
print('Ly derivative')
print(y - y_hat)
print(y.grad)
y_grad = y.grad.clone()
print('Lhaty derivative')
print(y_c - y_hat)
print(y_c.grad)
# Clone: y_c retains its grad, so the later y_c.backward() calls accumulate
# into y_c.grad and would silently change an un-cloned reference.
y_c_grad = y_c.grad.clone()

----------- plaintext weight ---------------
tensor([[0.2000, 0.3000],
        [0.4000, 0.2000],
        [0.3000, 0.4000]], device='cuda:0')
tensor([[0.2000, 0.3000, 0.4000],
        [0.4000, 0.2000, 0.3000],
        [0.3000, 0.4000, 0.2000]], device='cuda:0')
tensor([[0.2000, 0.3000, 0.4000],
        [0.4000, 0.2000, 0.3000]], device='cuda:0')
y:  tensor([[0.1206, 0.1221]], device='cuda:0', grad_fn=<MmBackward>)
----------- plaintext grad -----------------
tensor([[-0.0401, -0.0602],
        [-0.0424, -0.0636],
        [-0.0401, -0.0602]], device='cuda:0')
tensor([[-0.0295, -0.0318, -0.0409],
        [-0.0246, -0.0265, -0.0341],
        [-0.0345, -0.0371, -0.0477]], device='cuda:0')
tensor([[-0.0531, -0.0508, -0.0497],
        [-0.0529, -0.0506, -0.0495]], device='cuda:0')
----------- ciphertext weight ---------------
tensor([[0.0400, 0.0600],
        [0.0800, 0.0400],
        [0.0600, 0.0800]], device='cuda:0')
tensor([[0.4000, 0.6000, 0.8000],
        [0.8000, 0.4000, 0.6000],
     

$ \frac{\partial \widehat{L}}{\partial \widehat{W}^{(l)}} = \left( \frac{\partial L}{\partial y^{(L)}} + (r^L - 1) \cdot y^{(L)} \right) \frac{\partial \widehat{y}^{(L)}}{\partial \widehat{W}^{(l)}} $

In [4]:
# Bare Optimizer wrappers with no hyper-parameters; optim_c is used here and
# later for zero_grad() on the ciphertext network's parameters.
optim = torch.optim.Optimizer(params=net.parameters(), defaults={})
optim_c = torch.optim.Optimizer(params=net_c.parameters(), defaults={})
optim_c.zero_grad()

In [5]:
# Output-layer noise factor, one entry per output unit.
n3 = torch.tensor([[.8, .8]], device=device)
# Correction term (r^L - 1) * y added to the plaintext output gradient.
r_L = y * (n3 - 1)
print(r_L)
t = r_L + y_grad
# Back-propagate the corrected output gradient through the ciphertext network.
y_c.backward(gradient=t, retain_graph=True)

tensor([[-0.0241, -0.0244]], device='cuda:0', grad_fn=<MulBackward0>)


### Get the gradient $\frac{\partial \widehat{L}}{\partial \widehat{W}^{(l)}}$

In [6]:
# Show the gradient currently stored on every ciphertext-network parameter.
print(*(param.grad for param in net_c.parameters()), sep='\n')

tensor([[-0.1709, -0.2563],
        [-0.1805, -0.2708],
        [-0.1708, -0.2562]], device='cuda:0')
tensor([[-0.0126, -0.0135, -0.0174],
        [-0.0105, -0.0113, -0.0145],
        [-0.0147, -0.0158, -0.0203]], device='cuda:0')
tensor([[-0.0226, -0.0216, -0.0211],
        [-0.0225, -0.0216, -0.0211]], device='cuda:0')


In [7]:
# Reference ciphertext gradients saved right after c_loss.backward().
print('\n'.join(str(saved_grad) for saved_grad in c_w_gradlist))

tensor([[-0.1709, -0.2563],
        [-0.1805, -0.2708],
        [-0.1708, -0.2562]], device='cuda:0')
tensor([[-0.0126, -0.0135, -0.0174],
        [-0.0105, -0.0113, -0.0145],
        [-0.0147, -0.0158, -0.0203]], device='cuda:0')
tensor([[-0.0226, -0.0216, -0.0211],
        [-0.0225, -0.0216, -0.0211]], device='cuda:0')


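$ \frac{\partial \widehat{L}}{\partial \widehat{W}^{(l)}} = \frac{r^L}{R_l} \frac{\partial L}{\partial y^{(L)}} \frac{\partial y^{(L)}}{\partial W^{(l)}} + (r^L - 1) \cdot y^{(L)} \frac{\partial \widehat{y}^{(L)}}{\partial \widehat{W}^{(l)}} $

Here $R_l$ is the noise factor applied to the weights of layer $l$ (`net_c.noise[l]`), so the first term is the plaintext gradient $\frac{\partial L}{\partial W^{(l)}}$ rescaled by $\frac{r^L}{R_l}$.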

In [8]:
print(net_c.noise) # noise for w_ij in each layer

# left of addition: plaintext gradients rescaled by r^L / (layer noise), r^L = 0.8.
# A comprehension keeps the loop variables local; the previous for-loop
# rebound the notebook-level name `x` (the network input) to a gradient tensor.
l_tensor_list = [
    plain_grad.mul(1/layer_noise).mul(0.8)
    for plain_grad, layer_noise in zip(w_gradlist, net_c.noise)
]
# Clear the accumulated parameter gradients so only the correction term remains.
optim_c.zero_grad()
y_c.backward(r_L, retain_graph=True)

# right of addition
r_tensor_list = [p.grad.detach().clone() for p in net_c.parameters()]

[0.2, 2.0, 2.0]


### Restore the ciphertext gradient (left-hand side of the equation)

In [9]:
# Sum the two terms per layer. A generator expression keeps the loop
# variables local; the previous for-loop rebound the notebook-level name `r`
# (the noise tuple) to a tensor.
print(*(left_term + right_term
        for left_term, right_term in zip(l_tensor_list, r_tensor_list)),
      sep='\n')

tensor([[-0.1709, -0.2563],
        [-0.1805, -0.2708],
        [-0.1708, -0.2562]], device='cuda:0')
tensor([[-0.0126, -0.0135, -0.0174],
        [-0.0105, -0.0113, -0.0145],
        [-0.0147, -0.0158, -0.0203]], device='cuda:0')
tensor([[-0.0226, -0.0216, -0.0211],
        [-0.0225, -0.0216, -0.0211]], device='cuda:0')
