In [1]:
import torch
import torch.nn as nn
import numpy as np
from numpy.random import default_rng

def set_seed(seed=0):
    seed = 42069  # set a random seed for reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)



In [2]:

def init_weight():
    # y0 dim: (1, 2)
    w1 = np.array([[0.2, 0.3],
                   [0.4, 0.2],
                   [0.3, 0.4]], dtype='f')
    # y1 dim: (1, 3)
    w2 = np.array([[0.2, 0.3, 0.4],
                   [0.4, 0.2, 0.3],
                   [0.3, 0.4, 0.2]], dtype='f')
    # y2 dim: (1, 3)
    w3 = np.array([[0.2, 0.3, 0.4],
                   [0.4, 0.2, 0.3]], dtype='f')
    # y3 dim: (1, 2)
    return w1, w2, w3


class CMlp(nn.Module):

    def __init__(self, encrypt=False):
        super(CMlp, self).__init__()
        w1, w2, w3 = init_weight()

        self.fc1 = nn.Linear(2, 3, False)
        self.fc1.weight.data = torch.from_numpy(w1)
        self.fc2 = nn.Linear(3, 3, False)
        self.fc2.weight.data = torch.from_numpy(w2)
        self.fc3 = nn.Linear(3, 2, False)
        self.fc3.weight.data = torch.from_numpy(w3)
        if encrypt:
            rng = default_rng(0)
            self.r1 = rng.standard_normal((3, 1), dtype='f')
            self.r2 = rng.standard_normal((3, 1), dtype='f')
            self.r3 = rng.standard_normal((2, 1), dtype='f')
            self.fc1.weight.data = torch.from_numpy(w1 * self.r1)
            self.fc2.weight.data = torch.from_numpy(w2 * self.r2 / self.r1.transpose())
            self.fc3.weight.data = torch.from_numpy(w3 / self.r2.transpose())
            self.fc3.weight.data.add_(self.r3)
        self.y2 = None
        self.y3 = None
        self.alpha = None
    def forward(self, x):
        y1 = self.fc1(x)
        self.y2 = self.fc2(y1)
        self.alpha = self.y2.sum()
        self.y3 = self.fc3(self.y2)
        self.y3.retain_grad()
        return self.y3

In [3]:
# setup gpu or cpu
device = 'cuda' if torch.cuda.is_available() else 'cpu'

r = (0.2, 0.4, 0.8)
x = torch.tensor([[0.2, 0.3]], device=device)
y_hat = torch.tensor([[0.5, 0.5]], device=device)

net = CMlp().to(device)
print('----------- plaintext weight ---------------')
for p in net.parameters():
    print(p.data)
y = net(x)
print('y: ', y)
criterion = nn.MSELoss()
loss = criterion(y, y_hat)
loss.backward(retain_graph=True)
print('----------- plaintext grad -----------------')
for p in net.parameters():
    print(p.grad)
w_gradlist = [p.grad for p in net.parameters()]
print('----------- ciphertext weight ---------------')
net_c = CMlp(encrypt=True).to(device)
for p in net_c.parameters():
    print(p.data)
y_c = net_c(x)
c_loss = criterion(y_c, y_hat)
c_loss.backward(retain_graph=True)
print('----------- ciphertext grad ---------------')
for p in net_c.parameters():
    print(p.grad)
c_w_gradlist = [p.grad for p in net_c.parameters()]

print('Get yc: ', y_c)
print('Get yc from y: ', y + net_c.alpha * 0.8)
print('Ly derivative')
print(y - y_hat)
print(y.grad)
print('Lhaty derivative')
print(y_c - y_hat)
print(y_c.grad)
y_c_grad = y_c.grad

----------- plaintext weight ---------------
tensor([[0.2000, 0.3000],
        [0.4000, 0.2000],
        [0.3000, 0.4000]])
tensor([[0.2000, 0.3000, 0.4000],
        [0.4000, 0.2000, 0.3000],
        [0.3000, 0.4000, 0.2000]])
tensor([[0.2000, 0.3000, 0.4000],
        [0.4000, 0.2000, 0.3000]])
y:  tensor([[0.1206, 0.1221]], grad_fn=<MmBackward>)
----------- plaintext grad -----------------
tensor([[-0.0401, -0.0602],
        [-0.0424, -0.0636],
        [-0.0401, -0.0602]])
tensor([[-0.0295, -0.0318, -0.0409],
        [-0.0246, -0.0265, -0.0341],
        [-0.0345, -0.0371, -0.0477]])
tensor([[-0.0531, -0.0508, -0.0497],
        [-0.0529, -0.0506, -0.0495]])
----------- ciphertext weight ---------------


TypeError: add_(): argument 'other' (position 1) must be Tensor, not numpy.ndarray

$ \frac{\partial \widehat{L}}{\partial \widehat{W}^{(l)}}  $

In [5]:
optim = torch.optim.Optimizer(net.parameters(), {})
optim_c = torch.optim.Optimizer(net_c.parameters(), {})

### set grad to zero

In [6]:
for p in net_c.parameters():
    print(p.grad)
optim_c.zero_grad()
for p in net_c.parameters():
    print(p.grad)

tensor([[-0.2755, -0.4132],
        [-0.2829, -0.4244],
        [-0.2754, -0.4131]], device='cuda:0')
tensor([[-0.0201, -0.0216, -0.0278],
        [-0.0185, -0.0199, -0.0256],
        [-0.0217, -0.0234, -0.0300]], device='cuda:0')
tensor([[-0.0140, -0.0134, -0.0131],
        [-0.0139, -0.0133, -0.0130]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')


## get $\frac{\partial \alpha}{\partial \widehat{w}^{(l)}}$

In [7]:
net_c.alpha.backward(retain_graph=True)
alpha_gradlist = [p.grad.detach().clone() for p in net_c.parameters()]
for p in net_c.parameters():
    print(p.grad)

tensor([[0.3600, 0.5400],
        [0.3600, 0.5400],
        [0.3600, 0.5400]], device='cuda:0')
tensor([[0.0260, 0.0280, 0.0360],
        [0.0260, 0.0280, 0.0360],
        [0.0260, 0.0280, 0.0360]], device='cuda:0')
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')


### set grad to zero and get $\mathbf{r}^t \frac{\partial \widehat{y}^{L}}{\partial \widehat{w}^{(l)}}$

In [8]:
optim_c.zero_grad()
r = torch.tensor([[0.8, 0.8]], device=device)
for p in net_c.parameters():
    print(p.grad)
y_c.backward(r, retain_graph=True)

tensor([[0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')


In [9]:
c_yw_gradlist = [p.grad for p in net_c.parameters()]

In [10]:
print(c_w_gradlist[2])

tensor([[0.0448, 0.0429, 0.0419],
        [0.0448, 0.0429, 0.0419]], device='cuda:0')


In [11]:
print(w_gradlist[2])
print(c_yw_gradlist[2])

tensor([[-0.0531, -0.0508, -0.0497],
        [-0.0529, -0.0506, -0.0495]], device='cuda:0')
tensor([[0.0448, 0.0429, 0.0419],
        [0.0448, 0.0429, 0.0419]], device='cuda:0')


Compute $\frac{\partial \widehat{L}}{\partial \widehat{W}}$

\begin{equation}
\frac{\partial \widehat{L}}{\partial \widehat{W}} = \frac{1}{R^{(l)}} \circ \frac{\partial L}{\partial W} + r^T \cdot \alpha \frac{\partial \widehat{y}^{(L)}}{\partial \widehat{W}} + r^T \cdot (\frac{ \partial \widehat{L}}{\partial \widehat{y}^{(L)}})^{T} \frac{\partial \alpha}{\partial \widehat{W}} - r^T r \alpha \frac{\partial \alpha}{\partial \widehat{W}}
\end{equation}

## Layer3 c_w3 grad with simple computing

\begin{equation}
\frac{\partial \widehat{L}}{\partial \widehat{W}} = \frac{1}{R^{(l)}} \circ \frac{\partial L}{\partial W} + r^T \cdot \alpha \frac{\partial \widehat{y}^{(L)}}{\partial \widehat{W}}
\end{equation}

In [12]:
print(w_gradlist[2] * 0.4 + c_yw_gradlist[2] * net_c.alpha)

tensor([[-0.0140, -0.0134, -0.0131],
        [-0.0139, -0.0133, -0.0130]], device='cuda:0', grad_fn=<AddBackward0>)


In [13]:
print(y_c_grad.reshape(1, -1))
t = r.matmul(y_c_grad.reshape(1, -1).t())

tensor([[-0.2498, -0.2483]], device='cuda:0')


In [14]:
print(w_gradlist[1])
print(alpha_gradlist[1])

tensor([[-0.0295, -0.0318, -0.0409],
        [-0.0246, -0.0265, -0.0341],
        [-0.0345, -0.0371, -0.0477]], device='cuda:0')
tensor([[0.0260, 0.0280, 0.0360],
        [0.0260, 0.0280, 0.0360],
        [0.0260, 0.0280, 0.0360]], device='cuda:0')


## Layer2 c_w2 grad

In [15]:
print(w_gradlist[1] * 0.5 + c_yw_gradlist[1] * net_c.alpha + alpha_gradlist[1] * t - r.matmul(r.t()) * net_c.alpha * alpha_gradlist[1])

tensor([[-0.0201, -0.0216, -0.0278],
        [-0.0185, -0.0199, -0.0256],
        [-0.0217, -0.0234, -0.0300]], device='cuda:0', grad_fn=<SubBackward0>)


## Layer1 c_w1 grad

In [16]:
print(w_gradlist[0] * 5 + c_yw_gradlist[0] * net_c.alpha + alpha_gradlist[0] * t - r.matmul(r.t()) * net_c.alpha * alpha_gradlist[0])

tensor([[-0.2755, -0.4132],
        [-0.2829, -0.4244],
        [-0.2754, -0.4131]], device='cuda:0', grad_fn=<SubBackward0>)
