In [1]:
import torch
import torch.nn as nn
import numpy as np
from numpy.random import default_rng


def set_seed(seed=42069):
    """Seed torch (CPU and, when available, every CUDA device) and make cuDNN deterministic.

    Args:
        seed: value passed to ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.
            Defaults to 42069, the value this helper always used before. Previously the
            argument was overwritten inside the function, so a caller-supplied seed was
            silently ignored; it is now honoured.
    """
    torch.backends.cudnn.deterministic = True
    # Disable cuDNN autotuning, which may select different kernels from run to run.
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

In [2]:
def init_weight(random=False):
    """Return the weight matrices ``(w1, w2, w3)`` of the bias-free 2 -> 3 -> 3 -> 2 MLP.

    Shapes follow the ``nn.Linear`` weight layout (out_features, in_features):
    w1 is (3, 2), w2 is (3, 3), w3 is (2, 3); every array is float32.

    Args:
        random: if True, draw every entry from a standard normal using a
            generator seeded with 0 (reproducible); otherwise return fixed,
            hand-written matrices.
    """
    if not random:
        fixed = (
            # layer 1: input y0 (1, 2) -> y1 (1, 3)
            [[0.2, 0.3],
             [0.4, 0.2],
             [0.3, 0.4]],
            # layer 2: y1 (1, 3) -> y2 (1, 3)
            [[0.2, 0.3, 0.4],
             [0.4, 0.2, 0.3],
             [0.3, 0.4, 0.2]],
            # layer 3: y2 (1, 3) -> y3 (1, 2)
            [[0.2, 0.3, 0.4],
             [0.4, 0.2, 0.3]],
        )
        return tuple(np.array(rows, dtype='f') for rows in fixed)

    gen = default_rng(0)
    # Draw order matters: w1, then w2, then w3 consume one shared stream.
    shapes = ((3, 2), (3, 3), (2, 3))
    return tuple(gen.standard_normal(shape, dtype='f') for shape in shapes)


class CMlp(nn.Module):
    """Bias-free 2 -> 3 -> 3 -> 2 ReLU MLP with optional multiplicative weight "encryption".

    With ``encrypt=True`` the weights from ``init_weight`` are scaled element-wise by
    noise vectors ``r1`` (3, 1), ``r2`` (3, 1) and ``r3`` (2, 1):

        W1_hat = W1 * r1
        W2_hat = W2 * r2 / r1.T
        W3_hat = W3 * r3 / r2.T

    The noise entries are absolute values of normal draws, hence non-negative. ReLU is
    positively homogeneous (relu(c * x) = c * relu(x) for c > 0), so each layer's noise
    cancels in the next one. The encrypted output is the plaintext output scaled
    element-wise by ``r3``.

    Args:
        encrypt: apply the noise above (drawn from a generator seeded with 1234) and
            expose it as the NumPy attributes ``r1``, ``r2`` and ``r3``.
        weight_random: forwarded to ``init_weight(random=...)``.

    Attributes written by ``forward`` (kept for inspection after a pass):
        y2: output of ``fc2`` (before ``relu2``).
        alpha: ``y2.sum()``.
        y3: network output; its ``.grad`` is retained when autograd is recording.
    """

    def __init__(self, encrypt=False, weight_random=False):
        super().__init__()
        w1, w2, w3 = init_weight(random=weight_random)
        self.fc1 = nn.Linear(2, 3, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(3, 3, bias=False)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(3, 2, bias=False)

        if encrypt:
            noise_rng = default_rng(1234)
            # abs() keeps the noise non-negative so that it commutes with ReLU.
            self.r1 = np.absolute(noise_rng.standard_normal((3, 1), dtype='f'))
            self.r2 = np.absolute(noise_rng.standard_normal((3, 1), dtype='f'))
            self.r3 = np.absolute(noise_rng.standard_normal((2, 1), dtype='f'))
            weights = (
                w1 * self.r1,
                w2 * self.r2 / self.r1.T,
                w3 * self.r3 / self.r2.T,
            )
        else:
            weights = (w1, w2, w3)

        # Overwrite the default nn.Linear initialisation in place. This avoids the
        # discouraged `.data = ...` assignment, which bypasses autograd's bookkeeping.
        with torch.no_grad():
            for layer, w in zip((self.fc1, self.fc2, self.fc3), weights):
                layer.weight.copy_(torch.from_numpy(w))

        self.y2 = None
        self.y3 = None
        self.alpha = None

    def forward(self, x):
        y1 = self.fc1(x)
        self.y2 = self.fc2(self.relu1(y1))
        self.alpha = self.y2.sum()
        self.y3 = self.fc3(self.relu2(self.y2))
        # Keep dLoss/dy3 on the output for the gradient analysis. retain_grad() raises on a
        # tensor that does not require grad (e.g. under torch.no_grad()), so skip it then.
        if self.y3.requires_grad:
            self.y3.retain_grad()
        return self.y3

In [3]:
# Run the plaintext network and its noise-encrypted twin on the same input,
# and record outputs and gradients for the analysis below.
# Pinned to CPU; use `'cuda' if torch.cuda.is_available() else 'cpu'` to run on a GPU.
set_seed()
device = 'cpu'

x = torch.tensor([[0.2, 0.3]], device=device)
y_hat = torch.tensor([[0.5, 0.5]], device=device)  # regression target

net = CMlp(weight_random=True).to(device)
print('----------- plaintext weight ---------------')
for p in net.parameters():
    print(p.data)
y = net(x)
print('y: ', y)
criterion = nn.MSELoss()
loss = criterion(y, y_hat)
loss.backward(retain_graph=True)
print('----------- plaintext grad -----------------')
for p in net.parameters():
    print(p.grad)
# Snapshot (not alias) the grads, matching c_w_gradlist below.
w_gradlist = [p.grad.detach().clone() for p in net.parameters()]
print('----------- ciphertext weight ---------------')
net_c = CMlp(encrypt=True, weight_random=True).to(device)
for p in net_c.parameters():
    print(p.data)
y_c = net_c(x)
c_loss = criterion(y_c, y_hat)
c_loss.backward(retain_graph=True)
print('----------- ciphertext grad ---------------')
for p in net_c.parameters():
    print(p.grad)
c_w_gradlist = [p.grad.detach().clone() for p in net_c.parameters()]

print('Get yc: ', y_c)
r3 = torch.from_numpy(net_c.r3.transpose()).to(device)
print('Get yc from y: ', y * r3)
# The encrypted output must equal the plaintext output scaled by the output noise r3.
assert torch.allclose(y_c, y * r3, rtol=1e-4, atol=1e-6), 'ciphertext output != r3 * plaintext output'
print('Ly derivative')
print(y - y_hat)
print(y.grad)
y_grad = y.grad.clone()
print('Lhaty derivative')
print(y_c - y_hat)
print(y_c.grad)
# Clone, like y_grad: later y_c.backward(...) calls accumulate into y_c.grad.
y_c_grad = y_c.grad.clone()

----------- plaintext weight ---------------
tensor([[ 1.1176, -1.3871],
        [-0.4266, -0.8036],
        [ 0.6014, -0.0750]])
tensor([[ 0.0597, -0.0320, -0.1855],
        [ 1.2048,  0.7775, -1.3583],
        [ 0.7698, -0.8702,  1.0998]])
tensor([[-0.9585, -1.2749, -1.3653],
        [-1.4743,  0.4335, -0.3281]])
y:  tensor([[-0.1468, -0.0353]], grad_fn=<MmBackward>)
----------- plaintext grad -----------------
tensor([[0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.2329, 0.3493]])
tensor([[0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.1035]])
tensor([[-0.0000, -0.0000, -0.0696],
        [-0.0000, -0.0000, -0.0576]])
----------- ciphertext weight ---------------
tensor([[ 2.1577, -2.6780],
        [-1.1628, -2.1905],
        [ 0.9163, -0.1143]])
tensor([[ 0.0345, -0.0131, -0.1358],
        [ 0.2053,  0.0938, -0.2933],
        [ 0.5553, -0.4446,  1.0052]])
tensor([[-0.1470, -0.6633, -0.1678],
        [-0.4237,  0.4225, -0.0755]])
--------

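The ciphertext output is $\widehat{y}^{(L)} = r^L \cdot y^{(L)}$. `nn.MSELoss` averages over the $N = 2$ outputs, so its gradient factor $2/N$ equals 1 and $\frac{\partial L}{\partial y^{(L)}} = y^{(L)} - \hat{y}$. Hence $\frac{\partial \widehat{L}}{\partial \widehat{y}^{(L)}} = r^L \cdot y^{(L)} - \hat{y} = \frac{\partial L}{\partial y^{(L)}} + (r^L - 1) \cdot y^{(L)}$, and by the chain rule:

$ \frac{\partial \widehat{L}}{\partial \widehat{W}^{(l)}} = ( \frac{\partial L}{\partial y^{(L)}} + (r^L - 1) \cdot y^{L}) \frac{\partial \widehat{y}^{(L)}}{\partial \widehat{W}^{(l)}} $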

In [4]:
# Bare torch.optim.Optimizer instances, used here only as handles for zero_grad()
# on each model's parameters (this notebook never calls step()).
optim, optim_c = (torch.optim.Optimizer(model.parameters(), {}) for model in (net, net_c))
# Clear the ciphertext grads left by c_loss.backward() before the next backward pass.
optim_c.zero_grad()

In [5]:
# Rebuild dL̂/dŷ from the plaintext output gradient, dL/dy + (r^L - 1) * y,
# and back-propagate it through the ciphertext network.
r_L = r3.clone()
# Detach y: the upstream gradient passed to backward() should be a plain tensor,
# not one still attached to the plaintext autograd graph.
r_L_1y = (r_L - 1) * y.detach()
print(r_L)
t = y_grad + r_L_1y
# Start from clean parameter grads so that re-running this cell does not accumulate.
optim_c.zero_grad()
y_c.backward(t, retain_graph=True)

tensor([[0.1711, 0.3206]])


### Check: ciphertext gradient $\frac{\partial \widehat{L}}{\partial \widehat{W}^{(l)}}$ rebuilt from the plaintext output gradient

In [6]:
# Parameter grads obtained by back-propagating dL/dy + (r^L - 1) * y.
for p in net_c.parameters():
    print(p.grad)

print('Assert cipher grad')
print(*c_w_gradlist, sep='\n')
# They must match the grads produced directly by c_loss.backward().
for p, expected in zip(net_c.parameters(), c_w_gradlist):
    assert torch.allclose(p.grad, expected, rtol=1e-4, atol=1e-6), 'rebuilt ciphertext grad mismatch'

tensor([[0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0255, 0.0382]])
tensor([[0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0189]])
tensor([[ 0.0000,  0.0000, -0.0786],
        [ 0.0000,  0.0000, -0.0766]])
Assert cipher grad
tensor([[0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0255, 0.0382]])
tensor([[0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0189]])
tensor([[-0.0000, -0.0000, -0.0786],
        [-0.0000, -0.0000, -0.0766]])


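# Relating ciphertext and plaintext weight gradients

Here $R_l$ is the element-wise noise on layer $l$'s weights, $\widehat{W}^{(l)} = R_l \odot W^{(l)}$. In `CMlp` these are $R_1 = r_1$, $R_2 = r_2 / r_1^{\top}$ and $R_3 = r_3 / r_2^{\top}$. $r^L$ is the output noise $r_3$.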

$ \frac{\partial \widehat{L}}{\partial \widehat{W}^{(l)}} = \frac{r^L}{R_l} \frac{\partial L}{\partial {W}^{(l)}}+ (r^L - 1) \cdot y^{L} \frac{\partial \widehat{y}^{(L)}}{\partial \widehat{W}^{(l)}} $

$ R_l \cdot (\frac{1}{r^L} \cdot \frac{\partial \widehat{L}}{\partial \widehat{y}^{(L)}} \cdot \frac{\partial \widehat{y}^{(L)}}{\partial \widehat{W}^{(l)}} - \frac{1}{r^L} \cdot (r^L - 1) \cdot y^{L} \frac{\partial \widehat{y}^{(L)}}{\partial \widehat{W}^{(l)}}) = \frac{\partial L}{\partial {W}^{(l)}} $

In [7]:
# Element-wise factors 1 / R_l that undo the weight noise of each layer:
# layer 1 -> 1 / r1, layer 2 -> r1.T / r2, layer 3 -> r2.T / r3.
denoise = [1 / net_c.r1] + [
    prev.T / cur
    for prev, cur in ((net_c.r1, net_c.r2), (net_c.r2, net_c.r3))
]
print(*denoise, sep='\n')  # noise for w_ij in each layer

[[0.51797825]
 [0.3668524 ]
 [0.65637   ]]
[[1.7304813 2.443358  1.3656197]
 [5.8691945 8.287025  4.6317096]
 [1.3863815 1.9575051 1.0940711]]
[[6.5187154 1.921987  8.136661 ]
 [3.4799373 1.0260295 4.3436575]]


# Compute the left-hand side of the equation

lvar: $ \frac{1}{r^L} \cdot \frac{\partial \widehat{L}}{\partial \widehat{y}^{(L)}} $, back-propagated through the ciphertext network.

rvar: $ \frac{1}{r^L} \cdot (r^L - 1) \cdot y^{L} $, back-propagated through the ciphertext network.

In [8]:
optim_c.zero_grad()  # fresh parameter grads for the left-hand term
frac_r_L = 1 / r_L
lvar = frac_r_L * y_c_grad  # (1 / r^L) * dL̂/dŷ
print(f'lvar: {lvar}')
y_c.backward(lvar, retain_graph=True)
left_result_list = [param.grad.clone() for param in net_c.parameters()]
print(*left_result_list, sep='\n')

lvar: tensor([[-3.0684, -1.5949]])
tensor([[0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.1277, 0.1916]])
tensor([[0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0947]])
tensor([[ 0.0000,  0.0000, -0.4595],
        [ 0.0000,  0.0000, -0.2388]])


In [9]:
optim_c.zero_grad()  # fresh parameter grads for the right-hand term
rvar = frac_r_L * r_L_1y  # (1 / r^L) * (r^L - 1) * y^L
print(f'rvar: {rvar}')
y_c.backward(rvar, retain_graph=True)
right_result_list = [param.grad.clone() for param in net_c.parameters()]
print(*right_result_list, sep='\n')

rvar: tensor([[0.7111, 0.0748]], grad_fn=<MulBackward0>)
tensor([[ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [-0.0251, -0.0377]])
tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0186]])
tensor([[0.0000, 0.0000, 0.1065],
        [0.0000, 0.0000, 0.0112]])


# Recover the plaintext gradients

In [10]:
# Subtract the two back-propagated terms, then undo each layer's weight noise
# (divide by 1 / R_l, i.e. multiply by R_l) to recover the plaintext dL/dW.
recovered_grads = [
    (term1 - term2).div(torch.from_numpy(den).to(device))
    for term1, term2, den in zip(left_result_list, right_result_list, denoise)
]
print(*recovered_grads, sep='\n')
# The recovered gradients must equal the plaintext grads from loss.backward().
for recovered, p in zip(recovered_grads, net.parameters()):
    assert torch.allclose(recovered, p.grad, rtol=1e-4, atol=1e-6), 'recovered grad != plaintext grad'

tensor([[0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.2329, 0.3493]])
tensor([[0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.1035]])
tensor([[ 0.0000,  0.0000, -0.0696],
        [ 0.0000,  0.0000, -0.0576]])
