In [6]:
import pickle

import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torchvision
import torchvision.transforms as transforms

from models import *
from utils import *


# Model
# Build the pretrained ResNet18 and load its checkpoint onto the available device.
# To use the VGG16 baseline instead, swap in:
#     original_model = VGG('VGG16')
#     CHECKPOINT_PATH = './checkpoint/vgg16_92.24.pth'
CHECKPOINT_PATH = './checkpoint/resnet18_93.36.pth'

print('==> Building model..')
original_model = ResNet18()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# map_location remaps stored tensors onto `device`, so a checkpoint saved from GPU
# tensors still loads on the CPU fallback selected above.
# Named `checkpoint_state` rather than `dict` to avoid shadowing the builtin.
checkpoint_state = torch.load(CHECKPOINT_PATH, map_location=device)

original_model = original_model.to(device)
# load_state_dict is called after wrapping, so the checkpoint is expected to use
# DataParallel's 'module.'-prefixed parameter names.
original_model = torch.nn.DataParallel(original_model)
original_model.load_state_dict(checkpoint_state)

# CIFAR-10 per-channel normalization statistics; presumably these must match the
# values used when the checkpoint was trained — confirm before changing.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

==> Building model..
Files already downloaded and verified


In [10]:
# Encrypt the model, then measure the encrypted model's accuracy on the CIFAR-10 test set.
# NOTE(review): enc_layers_num=-1 appears to select every conv layer (the printed key
# lists one entry per conv weight) — confirm against EncryptModel.
encrypted_model, authorization_key = EncryptModel(original_model, enc_layers_num=-1)
encrypted_accuracy = easy_test(testloader, encrypted_model)
# '%%' emits the literal percent sign right after the number. The previous form appended
# '%' after the trailing '\n', which pushed the sign onto its own output line.
print('The inference accuracy of the encrypted model is  %.2f%%\n' % encrypted_accuracy)

# NOTE(review): this writes the authorization key into the saved notebook output;
# clear outputs before sharing if the key is meant to stay private.
print('The authorization key:')
for item in authorization_key:
    print(item)


The inference accuracy of the encrypted model is  10.00
%
The authorization key:
('module.conv1.weight', [32, 32])
('module.layer1.0.conv1.weight', [16, 16, 4, 1, 1, 8, 4, 2, 2, 2, 2, 2, 4])
('module.layer1.0.conv2.weight', [8, 16, 8, 8, 4, 1, 4, 1, 1, 8, 2, 2, 1])
('module.layer1.1.conv1.weight', [8, 32, 1, 16, 1, 1, 1, 1, 2, 1])
('module.layer1.1.conv2.weight', [1, 16, 4, 16, 4, 1, 1, 1, 8, 4, 8])
('module.layer2.0.conv1.weight', [16, 1, 32, 2, 1, 8, 2, 1, 1])
('module.layer2.0.conv2.weight', [64, 8, 4, 4, 16, 4, 2, 16, 2, 2, 4, 2])
('module.layer2.1.conv1.weight', [16, 4, 16, 32, 4, 2, 8, 1, 1, 8, 1, 8, 16, 1, 4, 2, 1, 1, 2])
('module.layer2.1.conv2.weight', [4, 1, 1, 32, 64, 2, 2, 8, 8, 2, 1, 2, 1])
('module.layer3.0.conv1.weight', [8, 4, 64, 2, 32, 1, 16, 1])
('module.layer3.0.conv2.weight', [64, 2, 32, 128, 8, 1, 2, 8, 4, 4, 1, 1, 1])
('module.layer3.1.conv1.weight', [4, 8, 16, 1, 32, 128, 32, 4, 4, 2, 2, 4, 2, 8, 8, 1])
('module.layer3.1.conv2.weight', [64, 4, 1, 64, 4, 64, 2, 3

In [22]:
# Compare the storage cost of the encrypted model against the artifacts distributed
# alongside it (authorization key, per-conv-layer hashes). Sizes are pickled byte
# counts reported in decimal megabytes (10**6 bytes).
# NOTE(review): pickling the whole module object may differ from the on-disk size of
# torch.save(state_dict); confirm which measure the comparison intends.


def report_pickled_size(label: str, obj) -> int:
    """Print the pickled size of ``obj`` as '<label> size: X MB' and return its byte count.

    Args:
        label: Human-readable name used as the prefix of the printed line.
        obj: Any picklable object.

    Returns:
        Number of bytes produced by ``pickle.dumps(obj)``.
    """
    nbytes = len(pickle.dumps(obj))
    print(f"{label} size: {nbytes / (10**6):.6f} MB")
    return nbytes


len_encrypted_model = report_pickled_size("Encrypted model", encrypted_model)
len_authorization_key = report_pickled_size("Authorization key", authorization_key)

# One hash per conv weight tensor of the encrypted model, in search_conv order.
conv_names = prepare.search_conv(encrypted_model)
enc_dic = encrypted_model.state_dict()
hash_values = [hashn(enc_dic[weight_name]) for weight_name in conv_names]

len_hash_values = report_pickled_size("Hash values", hash_values)

Encrypted model size: 44.787467 MB
Authorization key size: 0.001218 MB
Hash values size: 0.000671 MB
