TenSEALの使い方(秘匿計算)
===

In [1]:
%pip install tenseal

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os

import joblib
import tenseal as ts
# NOTE(review): `torch.functional` is not `torch.nn.functional`; it has no `relu`.
from torch import functional as F
from torch import nn

In [3]:
def tenseal_save(enclist, filename="enc-data/main"):
    """Serialize TenSEAL objects into one data file plus a length-index file.

    Each object's ``serialize()`` bytes are written back to back into
    ``{filename}.tso``. The list of their byte lengths is dumped to
    ``{filename}-meta.joblib`` so that ``tenseal_load`` can split the data
    file again. The parent directory of ``filename`` is created if missing.

    Args:
        enclist: iterable of objects exposing ``serialize() -> bytes``
            (e.g. the encrypted CKKS vectors produced in this notebook).
        filename: path prefix shared by both output files.
    """
    parent = os.path.dirname(filename)
    if parent:  # a bare prefix such as "main" has no directory to create
        os.makedirs(parent, exist_ok=True)

    sizes = []
    with open(f"{filename}.tso", "wb") as f:
        for enc in enclist:
            blob = enc.serialize()
            f.write(blob)
            sizes.append(len(blob))
    joblib.dump(sizes, f"{filename}-meta.joblib")

def tenseal_load(filename="enc-data/main"):
    """Load the serialized objects written by ``tenseal_save``.

    Reads the per-object byte lengths from ``{filename}-meta.joblib``. It then
    reads exactly that many bytes for each object from ``{filename}.tso``.

    Args:
        filename: path prefix shared by the data and metadata files.

    Returns:
        list of ``bytes``, one serialized object per entry. They are not
        deserialized here; pass each to e.g. ``ts.ckks_vector_from(ctx, blob)``.

    Raises:
        ValueError: if the data file ends before every recorded object has
            been read (truncated file, or metadata from a different save).
    """
    # NOTE(review): joblib.load unpickles its input; only load trusted files.
    sizes = joblib.load(f"{filename}-meta.joblib")
    blobs = []
    with open(f"{filename}.tso", "rb") as f:
        for index, size in enumerate(sizes):
            blob = f.read(size)
            # f.read returns fewer bytes at EOF instead of raising; fail here
            # rather than later with an opaque deserialization error.
            if len(blob) != size:
                raise ValueError(
                    f"{filename}.tso is truncated: object {index} expected "
                    f"{size} bytes, got {len(blob)}"
                )
            blobs.append(blob)
    return blobs

In [4]:
class BaseModel(nn.Module):
    """Plaintext two-layer MLP whose parameters are exported to ``EncNet``.

    Architecture: ``fc`` (input_size -> hidden_size), ReLU, then ``fc2``
    (hidden_size -> output_size). No activation is applied after ``fc2``,
    so ``forward`` returns raw scores.

    NOTE(review): ``EncNet`` applies a polynomial sigmoid approximation after
    *both* layers. This model uses ReLU in between and nothing at the end, so
    the plaintext and encrypted models do not compute the same function.
    Align the activations before trusting encrypted predictions.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # ReLU lives in torch.nn.functional; the file-level `F` is
        # torch.functional, which has no `relu` (AttributeError at call time).
        x = nn.functional.relu(self.fc(x))
        return self.fc2(x)
    

class EncNet:
    """Encrypted-inference counterpart of a two-layer torch model.

    Copies the ``fc`` / ``fc2`` parameters of ``torch_nn`` into plain Python
    lists. ``forward`` then evaluates the network directly on an encrypted
    vector via its ``mm`` and ``polyval`` methods.

    NOTE(review): after every layer this applies a cubic sigmoid
    approximation. ``BaseModel`` in this notebook instead uses ReLU between
    the layers and no output activation, so results will not match it.
    """

    def __init__(self, torch_nn):
        layer1, layer2 = torch_nn.fc, torch_nn.fc2
        # nn.Linear keeps weight as (out_features, in_features); transposing
        # lets `enc_x.mm(weight)` map an input row to the layer's outputs.
        self.fc_weight = layer1.weight.data.T.tolist()
        self.fc_bias = layer1.bias.data.tolist()
        self.fc2_weight = layer2.weight.data.T.tolist()
        self.fc2_bias = layer2.bias.data.tolist()

    def sigmoid(self, enc_x):
        """Approximate the logistic sigmoid with 0.5 + 0.197*x - 0.004*x**3.

        The cubic only tracks the sigmoid for small |x| and diverges beyond.
        """
        coefficients = [0.5, 0.197, 0, -0.004]  # ascending powers of x
        return enc_x.polyval(coefficients)

    def forward(self, enc_x):
        """Run both layers, each followed by the sigmoid approximation."""
        layers = (
            (self.fc_weight, self.fc_bias),
            (self.fc2_weight, self.fc2_bias),
        )
        for weight, bias in layers:
            enc_x = self.sigmoid(enc_x.mm(weight) + bias)
        return enc_x

In [5]:
# Model dimensions: each input row is 4 extra features plus a 512-wide
# embedding (TODO confirm the feature layout against the data producer).
embedding_size = 512
input_size = embedding_size + 4
hidden_size, output_size = 256, 1

# NOTE(review): BaseModel is freshly initialised and no trained weights are
# loaded before EncNet copies its parameters, so encrypted predictions come
# from random weights. Load a state_dict into `basemodel` first if unintended.
basemodel = BaseModel(input_size, hidden_size, output_size)
model = EncNet(basemodel)

In [6]:
# Load the evaluation context, presumably the public (secret-key-free) one
# per the filename, so this side can compute on ciphertexts without
# decrypting them.
# NOTE(review): joblib.load unpickles its input; only load trusted files.
ctxbin = joblib.load("enc-data/public_context.joblib")
ctx = ts.context_from(ctxbin)
# Check explicitly: a bare `ctx.is_private()` mid-cell discards its result.
if ctx.is_private():
    raise RuntimeError(
        "enc-data/public_context.joblib contains a secret key; refusing to continue"
    )

# Serialized ciphertexts (bytes); deserialized against `ctx` before inference.
enclist = tenseal_load(filename="enc-data/main")

In [7]:
# Deserialize each row against the public context and run the encrypted
# model on it; the outputs remain encrypted.
preds = [
    model.forward(ts.ckks_vector_from(ctx, serialized))
    for serialized in enclist
]

In [None]:
# Persist the encrypted predictions as enc-data/preds.tso + enc-data/preds-meta.joblib.
tenseal_save(preds, filename="enc-data/preds")