In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tenseal as ts
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np
from time import time


## FashionMNIST dataset

In [2]:
torch.manual_seed(22)

train_data = datasets.FashionMNIST('data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.FashionMNIST('data', train=False, download=True, transform=transforms.ToTensor())

batch_size = 64

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:02<00:00, 11255510.58it/s]


Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 191490.66it/s]


Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3523271.00it/s]


Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 5707712.66it/s]

Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw






## CryptoNets for training

In [3]:
class CryptoNet(nn.Module):
    def __init__(self):
        super(CryptoNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 2, 1)
        self.pool1 = nn.AvgPool2d(3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(20, 50, 5, (2, 2), 0)
        self.fc1 = nn.Linear(1250, 100)
        self.fc2 = nn.Linear(100, 10)
        nn.init.kaiming_normal_(self.conv1.weight)
        nn.init.kaiming_normal_(self.conv2.weight)

    def forward(self, x):
        x = self.conv1(x) # (64,20,13,13) 3380
        x = x * x  # 使用平方激活函数
        x = self.pool1(x) # (64,20,13,13) 3380
        x = self.conv2(x) # (64,50,5,5) 1250
        x = self.pool1(x) # (64,50,5,5) 1250
        x = torch.flatten(x, 1) # 1250
        x = self.fc1(x)
        x = x * x  # 使用平方激活函数
        x = self.fc2(x)
        return x


In [4]:
# cpu
def train(model, train_loader, criterion, optimizer, n_epochs=30):
    # 将模型设为训练模式
    model.train()
    for epoch in range(1, n_epochs+1):
        train_loss = 0.0
        correct_new = 0
        total_new = 0
        
        t1 = time()
        for data, target in train_loader:
            # 将梯度置零
            optimizer.zero_grad()
            # 前向传播
            output = model(data)
            # 计算损失
            loss = criterion(output, target)
            # 反向传播
            loss.backward()
            # 更新权重
            optimizer.step()
            # 计算损失
            train_loss += loss.item()
            # 比较预测值和真实值
            _, predicted = torch.max(output.data, 1)
            total_new += target.size(0)
            correct_new += (predicted == target).sum().item()
        t2 = time()

        # 计算平均损失
        train_loss = train_loss / len(train_loader)
        acc_new = 100 * correct_new / total_new

        print('Epoch: {} \tTraining Loss: {:.6f} \tCorrect: {} \tAccuracy: {:.3f} \ttime: {:.3f}'.format(epoch, train_loss, correct_new, acc_new, t2 - t1))

    print('Finished Training')

    return model

# 创建模型实例
model = CryptoNet()
# 定义损失函数
criterion = torch.nn.CrossEntropyLoss()
# 定义优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 训练模型
model = train(model, train_loader, criterion, optimizer, 15)

Epoch: 1 	Training Loss: 0.715167 	Correct: 45414 	Accuracy: 75.690 	time: 10.508
Epoch: 2 	Training Loss: 0.459788 	Correct: 50210 	Accuracy: 83.683 	time: 10.573
Epoch: 3 	Training Loss: 0.399890 	Correct: 51357 	Accuracy: 85.595 	time: 10.314
Epoch: 4 	Training Loss: 0.366117 	Correct: 52084 	Accuracy: 86.807 	time: 10.475
Epoch: 5 	Training Loss: 0.348206 	Correct: 52432 	Accuracy: 87.387 	time: 10.230
Epoch: 6 	Training Loss: 0.331562 	Correct: 52703 	Accuracy: 87.838 	time: 10.422
Epoch: 7 	Training Loss: 0.316892 	Correct: 52961 	Accuracy: 88.268 	time: 10.276
Epoch: 8 	Training Loss: 0.305246 	Correct: 53232 	Accuracy: 88.720 	time: 10.359
Epoch: 9 	Training Loss: 0.296608 	Correct: 53390 	Accuracy: 88.983 	time: 10.353
Epoch: 10 	Training Loss: 0.290693 	Correct: 53532 	Accuracy: 89.220 	time: 10.445
Epoch: 11 	Training Loss: 0.277052 	Correct: 53860 	Accuracy: 89.767 	time: 10.401
Epoch: 12 	Training Loss: 0.269662 	Correct: 53956 	Accuracy: 89.927 	time: 10.567
Epoch: 13 	Tr

In [5]:
# 保存整个模型
PATH = './model/MNIST_CryptoNet.pth'
torch.save(model, PATH)

# 只保存模型权重
PATH = './model/MNIST_WeightCryptoNet.pth'
torch.save(model.state_dict(), PATH)

In [6]:
# 测试模型
def test(model, test_loader, criterion):
    # 初始化测试损失
    test_loss = 0.0
    # 初始化每个类别的正确预测数
    class_correct = list(0. for i in range(10))
    # 初始化每个类别的总数
    class_total = list(0. for i in range(10))

    # 将模型设为评估模式
    model.eval()
    
    for data, target in test_loader:
        
        # 前向传播，得到模型输出
        output = model(data)
        # 计算损失
        loss = criterion(output, target)
        # 累积测试损失
        test_loss += loss.item()
        # 将输出概率转换为预测类别
        _, pred = torch.max(output, 1)
        # 将预测与真实标签比较，得到每个样本是否预测正确的结果
        correct = np.squeeze(pred.eq(target.data.view_as(pred)))
        # 计算每个对象类别的测试准确率
        for i in range(len(target)):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

    # 计算平均测试损失
    test_loss = test_loss/len(test_loader)
    print(f'Test Loss: {test_loss:.6f}\n')

    for label in range(10):
        print(
            f'Test Accuracy of {label}: {int(100 * class_correct[label] / class_total[label])}% '
            f'({int(np.sum(class_correct[label]))}/{int(np.sum(class_total[label]))})'
        )

    print(
        f'\nTest Accuracy (Overall): {int(100 * np.sum(class_correct) / np.sum(class_total))}% '
        f'({int(np.sum(class_correct))}/{int(np.sum(class_total))})'
    )

# 加载模型
PATH = './model/MNIST_CryptoNet.pth'
model = torch.load(PATH)

t1 = time()
test(model.to('cpu'), test_loader, criterion)
t2 = time()
print("test time=", t2 - t1)

Test Loss: 0.334308

Test Accuracy of 0: 84% (847/1000)
Test Accuracy of 1: 97% (973/1000)
Test Accuracy of 2: 77% (771/1000)
Test Accuracy of 3: 87% (875/1000)
Test Accuracy of 4: 89% (893/1000)
Test Accuracy of 5: 97% (971/1000)
Test Accuracy of 6: 65% (655/1000)
Test Accuracy of 7: 94% (948/1000)
Test Accuracy of 8: 96% (965/1000)
Test Accuracy of 9: 95% (956/1000)

Test Accuracy (Overall): 88% (8854/10000)
test time= 1.324462652206421


## Collapsing linear layers

In [7]:
'''
引入本层的作用主要是将卷积层和线性层合二为一，在增加网络层数的情况下，减小同态加密运算的开销，折叠层的原理是所有线性操作的组合仍然是线性的
'''

model1 = torch.load('./model/MNIST_CryptoNet.pth')
# 输入图片大小
in_channels = 20
h = 13
w = h

#卷积层参数
out_channels = 50
kernel = 5
stride = 2

#线性层参数
in_features=int (out_channels * (((h - kernel) / 2) + 1)**2)
out_features = 100

#conv2 = torch.nn.Conv2d (in_channels, out_channels, kernel, stride)
#fc1 = torch.nn.Linear (in_features, out_features)
pool1 = nn.AvgPool2d(3, stride=1, padding=1)

conv2 = model1.conv2
fc1 = model1.fc1

# 创建折叠层的偏置
bias = fc1 (torch.flatten (pool1(conv2 (pool1(torch.zeros (1, in_channels, h, w))))))

# 创建折叠层的权重
n_pixels = in_channels * h * w 
pixel_batch = torch.eye (n_pixels).reshape (n_pixels, in_channels, h, w)
weight = (fc1 (torch.flatten (pool1(conv2 (pool1(pixel_batch))), 1)) - bias).T

# 创建折叠层
fcnew = torch.nn.Linear (n_pixels, out_features)  

# 复制权重和偏置
with torch.no_grad():
  _ = fcnew.weight.copy_ (weight)
  _ = fcnew.bias.copy_ (bias)


In [8]:
# 保存折叠模型
PATH = './model/MNIST_WeightCollapsed.pth'
torch.save(model1.state_dict(), PATH)

In [9]:
# CryptoNets和 fastCryptoNets压缩过后模型结构一样！
class CollapsedCryptoNets(nn.Module):
    def __init__(self):
        super(CollapsedCryptoNets, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=2, padding=1)
        self.pool1 = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.fcnew = nn.Linear(3380, 100)
        self.fc2 = nn.Linear(100, 10)
        
        # 
        with torch.no_grad():
          _ = self.fcnew.weight.copy_ (weight)
          _ = self.fcnew.bias.copy_ (bias)

    def forward(self, x):
        x = self.conv1(x)
        x = x * x  # 使用平方激活函数
        x = torch.flatten(x, 1)
        x = self.fcnew(x)
        x = x * x  # 使用平方激活函数
        x = self.fc2(x)
        return x

#创建折叠模型
model2 = CollapsedCryptoNets()

# 折叠模型读取训练权重
PATH = './model/MNIST_WeightCollapsed.pth'
model2.load_state_dict(torch.load(PATH), strict=False)
torch.save(model2, './model/MNIST_CollapsedNet.pth')



## Collapsed CryptoNets for testing

In [10]:
# cpu
# 此处直接用collapsed CryptoNets的训练网络进行测试
def test(model, test_loader, criterion):
    test_loss = 0.0
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))

    # 将模型设为评估模式
    model.eval()
    
    for data, target in test_loader:
        
        # 前向传播，得到模型输出
        output = model(data)
        # 计算损失
        loss = criterion(output, target)
        # 累积测试损失
        test_loss += loss.item()
        # 将输出概率转换为预测类别
        _, pred = torch.max(output, 1)
        # 将预测与真实标签比较，得到每个样本是否预测正确的结果
        correct = np.squeeze(pred.eq(target.data.view_as(pred)))
        # 计算每个对象类别的测试准确率
        for i in range(len(target)):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

    # 计算平均测试损失
    test_loss = test_loss/len(test_loader)
    print(f'Test Loss: {test_loss:.6f}\n')

    for label in range(10):
        print(
            f'Test Accuracy of {label}: {int(100 * class_correct[label] / class_total[label])}% '
            f'({int(np.sum(class_correct[label]))}/{int(np.sum(class_total[label]))})'
        )

    print(
        f'\nTest Accuracy (Overall): {int(100 * np.sum(class_correct) / np.sum(class_total))}% '
        f'({int(np.sum(class_correct))}/{int(np.sum(class_total))})'
    )

model2 = torch.load('./model/MNIST_CollapsedNet.pth')

t1 = time()
test(model2, test_loader, criterion)
t2 = time()
print("test time=", t2 - t1)

Test Loss: 0.334714

Test Accuracy of 0: 84% (847/1000)
Test Accuracy of 1: 97% (973/1000)
Test Accuracy of 2: 77% (771/1000)
Test Accuracy of 3: 87% (875/1000)
Test Accuracy of 4: 89% (893/1000)
Test Accuracy of 5: 97% (971/1000)
Test Accuracy of 6: 65% (655/1000)
Test Accuracy of 7: 94% (948/1000)
Test Accuracy of 8: 96% (965/1000)
Test Accuracy of 9: 95% (956/1000)

Test Accuracy (Overall): 88% (8854/10000)
test time= 1.2042276859283447


## Encrypted Collapsed CryptoNets

In [17]:

#创建加密模型
class EncCryptoNet:
    def __init__(self, torch_nn):
        self.conv1_weight = torch_nn.conv1.weight.data.view(
            torch_nn.conv1.out_channels, torch_nn.conv1.kernel_size[0],
            torch_nn.conv1.kernel_size[1]
        ).tolist()
        self.conv1_bias = torch_nn.conv1.bias.data.tolist()

        self.fcnew_weight = torch_nn.fcnew.weight.T.data.tolist()
        self.fcnew_bias = torch_nn.fcnew.bias.data.tolist()

        self.fc2_weight = torch_nn.fc2.weight.T.data.tolist()
        self.fc2_bias = torch_nn.fc2.bias.data.tolist()

    def forward(self, enc_x, windows_nb):
        # conv layer
        enc_channels = []  # 存储每个通道的加密结果
        for kernel, bias in zip(self.conv1_weight, self.conv1_bias):
            # 执行加密的卷积操作并添加偏差
            y = enc_x.conv2d_im2col(kernel, windows_nb) + bias
            enc_channels.append(y)
        # print("1.encx.size: " + enc_x.size().__str__())
        
        # pack all channels into a single flattened vector
        enc_x = ts.CKKSVector.pack_vectors(enc_channels)
        # print("2.encx.size: " + enc_x.size().__str__())
        
        enc_x.square_()  # 对加密向量进行平方操作（激活函数）
        # print("3.encx.size: " + enc_x.size().__str__())
        
        # fcnew layer  
        # print("windows_nb.size: " + str(windows_nb))
        # print("fcnew.size: " + len(self.fcnew_weight).__str__())
        enc_x = enc_x.mm(self.fcnew_weight) + self.fcnew_bias
        enc_x.square_()

        # fc2 layer
        enc_x = enc_x.mm(self.fc2_weight) + self.fc2_bias

        return enc_x

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

def enc_test(context, enc_model, test_loader, criterion, kernel_shape, stride):
    # 初始化用于监视测试损失和准确率的列表
    test_loss = 0.0
    class_correct = list(0. for i in range(10))  # 用于每个类别的正确预测数目
    class_total = list(0. for i in range(10))    # 用于每个类别的样本总数

    n = 0
    for data, target in test_loader:
        n += 1
        if n > 50:
            break
        t1 = time()
        
        # 编码和加密  //windows_nb = 144
        # x_enc, windows_nb = ts.im2col_encoding(
        #     context, data.view(28, 28).tolist(), kernel_shape[0],
        #     kernel_shape[1], stride
        # )

        # new coding
        dat = F.pad(data, (1, 1, 1, 1))
        x_enc, windows_nb = ts.im2col_encoding(
            context, dat.view(30, 30).tolist(), kernel_shape[0],
            kernel_shape[1], stride
        )

        # 加密评估
        enc_output = enc_model(x_enc, windows_nb)
        # 解密结果
        output = enc_output.decrypt()
        output = torch.tensor(output).view(1, -1)

        # 计算损失
        loss = criterion(output, target)
        test_loss += loss.item()

        # 将输出概率转换为预测类别
        _, pred = torch.max(output, 1)
        # 将预测与真实标签比较
        correct = np.squeeze(pred.eq(target.data.view_as(pred)))
        # 计算每个对象类别的测试准确率
        label = target.data[0]
        class_correct[label] += correct.item()  
        class_total[label] += 1
        
        t2 = time()
        print("{} round time:{}s [{}/{}] loss:{}".format(n, t2 - t1, n, len(test_loader),loss.item()))

    # 计算并打印平均测试损失
    test_loss = test_loss / sum(class_total)
    print(f'Test Loss: {test_loss:.6f}\n')
    print(class_total)
    for label in range(10):
        print(
            f'Test Accuracy of {label}: {int(100 * class_correct[label] / class_total[label])}% '
            f'({int(np.sum(class_correct[label]))}/{int(np.sum(class_total[label]))})'
        )

    print(
        f'\nTest Accuracy (Overall): {int(100 * np.sum(class_correct) / np.sum(class_total))}% '
        f'({int(np.sum(class_correct))}/{int(np.sum(class_total))})'
    )

# 逐个加载元素
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True)
# 编码所需参数
kernel_shape = model2.conv1.kernel_size
stride = model2.conv1.stride[0]


In [18]:
from torch.utils.data import DataLoader, Subset

# 选择缩小后的测试集大小
subset_size = 50

# 生成一个指定大小的子集
indices = torch.arange(subset_size)
test_data_subset = Subset(test_loader, indices)

# 创建DataLoader
test_loader_subset = DataLoader(test_data_subset, batch_size=1, shuffle=True)


# 加密参数

# 控制CKKS方案精度
bits_scale = 26

# 创建上下文
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=16384,
    coeff_mod_bit_sizes=[40, bits_scale, bits_scale, bits_scale, bits_scale, bits_scale, bits_scale, bits_scale, bits_scale, 40]
    # 乘法深度：8次
)
# 设置全局缩放
context.global_scale = pow(2, bits_scale)

# 生成伽罗瓦密钥
context.generate_galois_keys()



# 加密评估
model2 = torch.load('./model/MNIST_CollapsedNet.pth')
model2.eval()

t3 = time()
enc_model = EncCryptoNet(model2)
enc_test(context, enc_model, test_loader, criterion, kernel_shape, stride)
t4 = time()
print("time=", t4 - t3)


1 round time:84.09301710128784s [1/10000] loss:0.0021356174256652594
2 round time:87.19123959541321s [2/10000] loss:0.02821114845573902
3 round time:93.79601526260376s [3/10000] loss:0.0
4 round time:94.88916230201721s [4/10000] loss:0.0009921634336933494
5 round time:94.4293065071106s [5/10000] loss:0.3301824927330017
6 round time:92.89857339859009s [6/10000] loss:0.0
7 round time:93.25124335289001s [7/10000] loss:4.8993817472364753e-05
8 round time:92.71236777305603s [8/10000] loss:0.020519010722637177
9 round time:91.60464882850647s [9/10000] loss:0.005595141556113958
10 round time:91.39043068885803s [10/10000] loss:0.033226463943719864
11 round time:92.60525512695312s [11/10000] loss:0.0
12 round time:93.21839165687561s [12/10000] loss:0.012353217229247093
13 round time:92.95763182640076s [13/10000] loss:0.037195853888988495
14 round time:93.9071455001831s [14/10000] loss:0.0
15 round time:93.00287556648254s [15/10000] loss:0.2759433686733246
16 round time:94.07708239555359s [16/10