In [1]:
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
from torchtext import data
from torchtext import datasets
from torchtext.vocab import GloVe

# Field definitions: sentences are lower-cased and padded/truncated to 50 tokens,
# labels are treated as single (non-sequential) values.
TEXT = data.Field(lower=True, fix_length=50, batch_first=True)
LABEL = data.Field(sequential=False)

# Read the three TSV splits from the local SST-2 directory (header row skipped).
train, dev, test = data.TabularDataset.splits(
    path='SST-2',
    train='train.tsv',
    validation='dev.tsv',
    test='test.tsv',
    format='tsv',
    skip_header=True,
    fields=[('text', TEXT), ('label', LABEL)],
)
print("the size of train: {}, dev:{}, test:{}".format(
    len(train.examples), len(dev.examples), len(test.examples)))

# Vocabularies are built from the training split only; the text vocabulary is
# capped at 25000 entries and paired with 50-dimensional GloVe 6B vectors.
TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=50), max_size=25000)
LABEL.build_vocab(train)

print("train.fields:", train.fields, TEXT.vocab.vectors.shape)

# Batches of 32 for every split, bucketed and sorted by sentence length.
train_iter, dev_iter, test_iter = data.BucketIterator.splits(
    (train, dev, test),
    batch_sizes=(32, 32, 32),
    sort_key=lambda example: len(example.text),
    sort_within_batch=True,
    repeat=False,
)
train_iter.repeat = False
test_iter.repeat = False

the size of train: 65328, dev:872, test:2021
train.fields: {'text': <torchtext.data.field.Field object at 0x000001EC22D89FD0>, 'label': <torchtext.data.field.Field object at 0x000001EC22D89F70>} torch.Size([14795, 50])


In [2]:
# Text-CNN hyper-parameters
sequence_length = 50  # tokens per sentence; same value as TEXT's fix_length
# The pretrained vector table is (vocabulary size, embedding dimension).
vocab_size, embedding_size = TEXT.vocab.vectors.shape
num_classes = 2  # 0 or 1
filter_sizes = [2, 3, 5]  # n-gram window heights of the convolution branches
num_filters = 2  # output channels per convolution branch
lr = 1e-3  # optimizer learning rate

In [3]:
import tenseal as ts

# Build the client-side TenSEAL context using the CKKS scheme.
context_client = ts.context(
    ts.SCHEME_TYPE.CKKS,
    16384,
    # 58-bit first and last primes with eight 40-bit primes in between
    coeff_mod_bit_sizes=[58] + [40] * 8 + [58],
)
# Global scale used when encoding values: 2^40.
context_client.global_scale = 2 ** 40
# Galois keys are required for rotations on ciphertext vectors.
context_client.generate_galois_keys()

In [4]:
import time

class Sigmoid(nn.Module):
    """Cubic polynomial activation ``0.5 + 0.197*x - 0.004*x^3``.

    Same coefficients as the ``sigmoid`` helper that evaluates the polynomial
    on an encrypted vector.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the polynomial element-wise to tensor ``x`` and return the result."""
        linear_term = 0.197 * x
        cubic_term = 0.004 * torch.pow(x, 3)
        return 0.5 + linear_term - cubic_term

class Softmax(nn.Module):
    """Quadratic polynomial activation ``1 + x + 0.5*x^2``.

    Applied element-wise; no normalisation across classes is performed.
    Same coefficients as the ``softmax`` helper used on encrypted vectors.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the polynomial element-wise to tensor ``x`` and return the result."""
        squared_term = 0.5 * torch.pow(x, 2)
        return 1 + x + squared_term

class Swish(nn.Module):
    """Quartic polynomial activation ``0.1198 + 0.5*x + 0.1473*x^2 - 0.002012*x^4``.

    Same coefficients as the ``swish`` helper used on encrypted vectors.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the polynomial element-wise to tensor ``x`` and return the result."""
        linear_term = 0.5 * x
        quadratic_term = 0.1473 * torch.pow(x, 2)
        quartic_term = 0.002012 * torch.pow(x, 4)
        return 0.1198 + linear_term + quadratic_term - quartic_term

def sigmoid(ckks_vec):
    """Evaluate the ``Sigmoid`` module's polynomial on ``ckks_vec`` via its ``polyval`` method.

    The coefficient list mirrors ``0.5 + 0.197*x - 0.004*x^3`` with the
    constant term first and a zero for the missing quadratic term.
    """
    coefficients = [0.5, 0.197, 0, -0.004]
    return ckks_vec.polyval(coefficients)
    
def softmax(ckks_vec):
    """Evaluate the ``Softmax`` module's polynomial on ``ckks_vec`` via its ``polyval`` method.

    The coefficient list mirrors ``1 + x + 0.5*x^2`` with the constant term first.
    """
    coefficients = [1, 1, 0.5]
    return ckks_vec.polyval(coefficients)

def swish(ckks_vec):
    """Evaluate the ``Swish`` module's polynomial on ``ckks_vec`` via its ``polyval`` method.

    The coefficient list mirrors ``0.1198 + 0.5*x + 0.1473*x^2 - 0.002012*x^4``
    with the constant term first and a zero for the missing cubic term.
    """
    coefficients = [0.1198, 0.5, 0.1473, 0, -0.002012]
    return ckks_vec.polyval(coefficients)

class HETextCNN(nn.Module):
    """Text-CNN evaluated in plaintext and, sample by sample, on TenSEAL-encrypted inputs.

    Plain architecture: embedding -> one ``Conv2d`` + ``Swish`` + ``AvgPool2d``
    branch per entry of ``filter_sizes`` -> concatenation -> linear layer.

    * ``forward_plain`` returns the raw linear-layer scores (used for training).
    * ``forward`` returns the polynomial ``Softmax`` of those scores and, as a
      side effect, replays the network on encrypted inputs and prints the
      decrypted intermediate values together with the relative deviation from
      the plaintext result.

    The model reads the notebook-level settings ``vocab_size``,
    ``embedding_size``, ``sequence_length``, ``filter_sizes``, ``num_filters``,
    ``num_classes`` and the TenSEAL ``context_client``.
    """

    # For every supported n-gram height, the two row ranges of the embedded
    # sentence that are encrypted as separate ciphertexts. The ranges overlap by
    # ``kernel - 1`` rows so that, for a 50-row sentence, every convolution
    # window is covered exactly once.
    _CIPHER_ROW_SPLITS = {
        2: (slice(None, 26), slice(25, None)),  # 25 + 24 windows
        3: (slice(None, 27), slice(25, None)),  # 25 + 23 windows
        5: (slice(None, 27), slice(23, None)),  # 23 + 23 windows
    }

    def __init__(self):
        super(HETextCNN, self).__init__()

        self.num_filters_total = num_filters * len(filter_sizes)
        self.embedding = nn.Embedding(vocab_size, embedding_size)

        # One branch per n-gram height; the pooling window spans every
        # convolution output row of a ``sequence_length``-token sentence.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(1, num_filters, (kernel, embedding_size), bias=False),
                Swish(),
                nn.AvgPool2d((sequence_length - kernel + 1, 1))
            ) for kernel in filter_sizes])

        self.fc = nn.Linear(self.num_filters_total, num_classes)
        self.sm = Softmax()

    def _plain_scores(self, X):
        """Return ``(embedded, scores)`` for token ids ``X`` using plaintext layers.

        ``embedded`` is the embedding output with a channel axis inserted at
        dim 1; ``scores`` is the linear-layer output, one row per sample.
        """
        embedded_chars = self.embedding(X).unsqueeze(1)
        pooled = torch.cat([conv(embedded_chars) for conv in self.convs], dim=1)
        pooled = pooled.view(embedded_chars.size(0), -1)
        return embedded_chars, self.fc(pooled)

    @classmethod
    def _row_splits(cls, kernel_size):
        """Return the two encrypted row ranges for ``kernel_size``.

        Raises:
            ValueError: if no split is defined for this kernel height (the
                encrypted path would otherwise produce a feature vector of the
                wrong length).
        """
        try:
            return cls._CIPHER_ROW_SPLITS[kernel_size]
        except KeyError:
            raise ValueError(
                "encrypted forward supports kernel sizes {}, got {}".format(
                    sorted(cls._CIPHER_ROW_SPLITS), kernel_size)
            ) from None

    @staticmethod
    def _encrypted_conv_sum(encoded_halves, kernel):
        """Convolve ``kernel`` over both encrypted halves, apply ``swish``, and sum.

        ``encoded_halves`` is a sequence of ``(ciphertext, windows_nb)`` pairs as
        returned by ``ts.im2col_encoding``. The result is the encrypted sum of
        the activated convolution outputs (the average-pool divisor is folded
        into the linear-layer weights by the caller).
        """
        (enc_a, windows_a), (enc_b, windows_b) = encoded_halves
        h1 = swish(enc_a.conv2d_im2col(kernel, windows_a))
        h2 = swish(enc_b.conv2d_im2col(kernel, windows_b))
        return h1.sum() + h2.sum()

    def forward(self, X):
        """Run the plaintext network and replay it on encrypted inputs.

        Args:
            X: token ids; the embedding output is expected to have
                ``sequence_length`` rows per sample.

        Returns:
            The polynomial ``Softmax`` of the plaintext linear-layer scores.
            Encrypted results are only printed, not returned.
        """
        # Plain forward
        embedded_chars, plain_out = self._plain_scores(X)
        print(plain_out)
        plain_logit = self.sm(plain_out)

        # Plaintext weights used by the cipher path. They do not depend on the
        # sample, so they are extracted once instead of once per sentence.
        conv_kernels = [conv[0].weight.tolist() for conv in self.convs]
        fc_weight = self.fc.weight.detach().clone().T  # rows = pooled features
        for idx, weights in enumerate(conv_kernels):
            # Fold the average pooling (sum / number of windows) into the
            # linear-layer rows fed by this branch.
            windows = sequence_length - len(weights[0][0]) + 1
            fc_weight[idx * num_filters:(idx + 1) * num_filters] /= windows
        fc_weight = fc_weight.tolist()
        fc_bias = self.fc.bias.tolist()

        # Cipher forward, one sentence at a time
        cipher_logits = []
        for single in embedded_chars:
            rows = single[0]  # drop the channel axis
            cipher_out = []

            for weights in conv_kernels:
                kernel_size = len(weights[0][0])
                # Encrypt and encode both halves once per kernel height; the
                # ciphertexts are shared by all filters of that height.
                encoded_halves = [
                    ts.im2col_encoding(
                        context_client, rows[part].tolist(),
                        kernel_size, embedding_size, 1)
                    for part in self._row_splits(kernel_size)
                ]
                for channel in range(num_filters):
                    ap = self._encrypted_conv_sum(encoded_halves, weights[channel][0])
                    cipher_out.append(ap)
                    print("Conv out: ", ap.decrypt())

            packed = ts.pack_vectors(cipher_out)
            cipher_fc = packed.mm_(fc_weight) + fc_bias  # one value per class

            print("FC out: ", cipher_fc.decrypt())
            cipher_logit = softmax(cipher_fc)

            decrypted_logit = cipher_logit.decrypt()
            cipher_logits.append(decrypted_logit)
            print("Softmax out: ", decrypted_logit)

        # Sum of relative deviations between decrypted and plaintext outputs.
        plain_values = np.array(plain_logit.tolist())
        acc_loss = (np.abs(np.array(cipher_logits) - plain_values) / plain_values).sum()

        print("Batch acc loss: ", acc_loss)

        return plain_logit

    def forward_plain(self, X):
        """Return the raw linear-layer scores for token ids ``X`` (plaintext only).

        The polynomial ``Softmax`` is intentionally not applied here; the
        returned scores are what the training loop feeds to its loss.
        """
        _, out = self._plain_scores(X)
        return out

In [5]:
def binary_acc(preds, y):
    """Return the fraction of entries of ``preds`` equal to ``y`` as a 0-dim float tensor.

    The denominator is the length of the first dimension of the comparison result.
    """
    hits = (preds == y).float()
    n_items = len(hits)
    return hits.sum() / n_items

train_set_size = 128  # maximum number of mini-batches used per training epoch


def train(model, optimizer, criterion):
    """Train ``model`` for one epoch on at most ``train_set_size`` mini-batches.

    Batches are drawn from the notebook-level ``train_iter``; predictions come
    from ``model.forward_plain``.

    Args:
        model: module exposing ``forward_plain(text)``.
        optimizer: optimizer stepping the model's parameters.
        criterion: loss called as ``criterion(predicted, labels)``.

    Returns:
        The unweighted mean of the per-batch accuracies seen in this epoch.
    """
    avg_acc = []
    model.train()
    for batch_idx, batch in enumerate(train_iter):
        if batch_idx >= train_set_size:
            # Budget reached: stop here rather than iterating over (and
            # discarding) every remaining batch of the training set.
            break
        # Label ids are shifted down by one before being used as class targets
        # (presumably because the label vocabulary reserves index 0 - confirm).
        text, labels = batch.text, batch.label - 1
        predicted = model.forward_plain(text)

        acc = binary_acc(torch.max(predicted, dim=1)[1], labels)
        avg_acc.append(acc)
        # The loss is deliberately not stored: keeping the tensor would keep
        # its whole autograd graph alive for the rest of the epoch.
        loss = criterion(predicted, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return np.array(avg_acc).mean()

def evaluate(model, criterion):
    """Evaluate ``model`` on the notebook-level ``dev_iter`` using ``model(text)``.

    Args:
        model: module called directly on each batch's text.
        criterion: unused; kept so the signature matches ``train``.

    Returns:
        The unweighted mean of the per-batch accuracies (a smaller final batch
        therefore counts as much as a full one).
    """
    avg_acc = []
    model.eval()
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for batch in dev_iter:
            # Label ids are shifted down by one before comparison with the
            # predicted class indices.
            text, labels = batch.text, batch.label - 1
            predicted = model(text)

            acc = binary_acc(torch.max(predicted, dim=1)[1], labels)
            avg_acc.append(acc)

    return np.array(avg_acc).mean()

def evaluate_plain(model, criterion):
    """Evaluate ``model`` on the notebook-level ``dev_iter`` using ``model.forward_plain``.

    Args:
        model: module exposing ``forward_plain(text)``.
        criterion: unused; kept so the signature matches ``train``.

    Returns:
        The unweighted mean of the per-batch accuracies (a smaller final batch
        therefore counts as much as a full one).
    """
    avg_acc = []
    model.eval()
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for batch in dev_iter:
            # Label ids are shifted down by one before comparison with the
            # predicted class indices.
            text, labels = batch.text, batch.label - 1
            predicted = model.forward_plain(text)

            acc = binary_acc(torch.max(predicted, dim=1)[1], labels)
            avg_acc.append(acc)

    return np.array(avg_acc).mean()

In [6]:
model = HETextCNN()
print(model)

# Start from the pretrained vectors attached to the text vocabulary.
pretrained_embedding = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embedding)

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# Per-epoch accuracy history (the "test" figures come from evaluate_plain,
# which iterates dev_iter).
train_accs, test_accs = [], []

for epoch in range(50):

    train_acc = train(model, optimizer, criterion)
    train_accs.append(train_acc)
    print('epoch={},训练准确率={}'.format(epoch, train_acc))
    test_acc = evaluate_plain(model, criterion)
    test_accs.append(test_acc)
    print("epoch={},测试准确率={}".format(epoch, test_acc))
# Final pass through model(text), i.e. the forward that also runs the encrypted
# computation and prints its decrypted intermediate values.
test_acc = evaluate(model, criterion)

HETextCNN(
  (embedding): Embedding(14795, 50)
  (convs): ModuleList(
    (0): Sequential(
      (0): Conv2d(1, 2, kernel_size=(2, 50), stride=(1, 1), bias=False)
      (1): Swish()
      (2): AvgPool2d(kernel_size=(49, 1), stride=(49, 1), padding=0)
    )
    (1): Sequential(
      (0): Conv2d(1, 2, kernel_size=(3, 50), stride=(1, 1), bias=False)
      (1): Swish()
      (2): AvgPool2d(kernel_size=(48, 1), stride=(48, 1), padding=0)
    )
    (2): Sequential(
      (0): Conv2d(1, 2, kernel_size=(5, 50), stride=(1, 1), bias=False)
      (1): Swish()
      (2): AvgPool2d(kernel_size=(46, 1), stride=(46, 1), padding=0)
    )
  )
  (fc): Linear(in_features=6, out_features=2, bias=True)
  (sm): Softmax()
)
epoch=0,训练准确率=0.54296875
epoch=0,测试准确率=0.5122767686843872
epoch=1,训练准确率=0.56103515625
epoch=1,测试准确率=0.5982142686843872
epoch=2,训练准确率=0.6494140625
epoch=2,测试准确率=0.6875
epoch=3,训练准确率=0.705322265625
epoch=3,测试准确率=0.7399553656578064
epoch=4,训练准确率=0.770263671875
epoch=4,测试准确率=0.74776786565780

KeyboardInterrupt: 