In [2]:
import inspect
import re

import tenseal as ts
from torch import nn
from torch.nn import functional as F

In [3]:
def create_ctx(poly_mod_degree=8192, coeff_mod_bit_sizes=None, global_scale=2 ** 21):
    """Create the TenSEAL CKKS context used for all encryption in this notebook.

    The defaults reproduce the original hard-coded setup:
        - Polynomial modulus degree: 8192.
        - Coefficient modulus bit sizes: [40, 21, 21, 21, 21, 21, 21, 40].
        - Global scale: 2 ** 21.

    Args:
        poly_mod_degree: CKKS polynomial modulus degree.
        coeff_mod_bit_sizes: bit sizes of the coefficient-modulus primes. ``None``
            selects the default list above (a fresh list is built on every call,
            so callers never share a mutable default).
        global_scale: value assigned to ``ctx.global_scale``.

    Galois keys and relinearization keys are both generated on the returned
    context before it is returned.

    Returns:
        The configured TenSEAL context.
    """
    if coeff_mod_bit_sizes is None:
        coeff_mod_bit_sizes = [40, 21, 21, 21, 21, 21, 21, 40]
    # plain_modulus is a BFV-only parameter; CKKS ignores it, hence the -1 placeholder.
    ctx = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
    ctx.global_scale = global_scale
    ctx.generate_galois_keys()
    ctx.generate_relin_keys()
    return ctx

In [4]:
class BaseModel(nn.Module):
    """Two-layer MLP: Linear (no bias) -> sigmoid -> Linear -> sigmoid.

    NOTE: the source text of ``forward`` is read with ``inspect.getsource`` and
    parsed line by line by the regex helpers in this notebook. Every body line is
    treated as a ``name = call(args)`` assignment and the last line as
    ``return <name>``, so do not add comments, docstrings or other statements
    inside ``forward``.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super(BaseModel, self).__init__()
        # Hidden layer without bias: its state_dict entry is only "fc.weight".
        self.fc = nn.Linear(input_size, hidden_size, bias=False)
        self.fc2 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = self.fc(x)
        x = F.sigmoid(x)
        x = self.fc2(x)
        x = F.sigmoid(x)
        return x

In [5]:
def sigmoid(enc_x):
    """Evaluate a degree-3 polynomial approximation of the sigmoid on ``enc_x``.

    sigmoid(x) ~= 0.5 + 0.197 * x - 0.004 * x**3, with coefficients listed from
    degree 0 upward. The approximation is only usable near zero: at x = 5 it
    gives ~0.985 (true value ~0.993), but at x = 10 it already returns ~-1.53.
    """
    coefficients = (0.5, 0.197, 0, -0.004)
    return enc_x.polyval(list(coefficients))

def encript_param(state_dict: dict, ctx: "ts.Context") -> dict:
    """Encrypt every tensor of a PyTorch ``state_dict`` as a CKKS tensor.

    Args:
        state_dict: mapping of parameter name -> tensor, e.g. ``model.state_dict()``.
        ctx: TenSEAL CKKS context, e.g. the one returned by ``create_ctx()``.

    Returns:
        A new dict with the same keys, in the same order, whose values are
        ``ts.ckks_tensor(ctx, param)``.
    """
    return {name: ts.ckks_tensor(ctx, param) for name, param in state_dict.items()}

In [6]:
# CKKS context (with Galois and relinearization keys) shared by every encryption below.
ctx = create_ctx()

In [7]:
# Toy model: 2 inputs -> 2 hidden units -> 1 output.
# NOTE(review): no random seed is set, so the initial weights shown below differ between runs.
model = BaseModel(2, 2, 1)

In [8]:
# Plaintext parameters; `fc` has no bias entry because it was built with bias=False.
model.state_dict()

OrderedDict([('fc.weight',
              tensor([[-0.1827, -0.4086],
                      [-0.2119, -0.4258]])),
             ('fc2.weight', tensor([[ 0.1569, -0.2480]])),
             ('fc2.bias', tensor([0.6575]))])

In [9]:
# Encrypt every parameter and display the result (not stored; recomputed later as `encript_param_dict`).
encript_param(model.state_dict(), ctx)

{'fc.weight': <tenseal.tensors.ckkstensor.CKKSTensor at 0x28369421b70>,
 'fc2.weight': <tenseal.tensors.ckkstensor.CKKSTensor at 0x2836a779b10>,
 'fc2.bias': <tenseal.tensors.ckkstensor.CKKSTensor at 0x2836a77b460>}

In [10]:
# Fetch the source text of BaseModel.forward through the model instance.
# The following cells parse this text to recover the layer-by-layer computation.
forward_method = model.forward
forward_source = inspect.getsource(forward_method)

# Print the raw source (still indented one level, as it sits inside the class).
print(forward_source)

    def forward(self, x):
        x = self.fc(x)
        x = F.sigmoid(x)
        x = self.fc2(x)
        x = F.sigmoid(x)
        return x



In [11]:
def remove_oneindent(text):
    """Drop exactly one indentation level (four spaces) from the start of every line.

    Lines that do not start with four spaces are returned unchanged.
    """
    indent = "    "
    return "\n".join(
        line[len(indent):] if line.startswith(indent) else line
        for line in text.split("\n")
    )

def remove_type_hint(text):
    """Remove ``: annotation`` parts from a parameter or a parameter list.

    Examples:
        ``"x:int"`` -> ``"x"``
        ``"a: int, b: str"`` -> ``"a, b"``
        ``"x:int=3"`` -> ``"x=3"``
        ``"x: dict[str, int]"`` -> ``"x"`` (one level of ``[...]`` is understood)

    An annotation ends at the next ``,``, ``=``, ``)`` or end of string. Text
    without such a colon, e.g. ``"x = a[1:2]"``, is returned unchanged.
    """
    return re.sub(
        r'\s*:\s*[^,=\[\](){}]*?(?:\[[^\]]*\][^,=\[\](){}]*?)*(?=\s*(?:[,=)]|$))',
        '',
        text,
    )

def remove_return(text):
    """Strip a leading ``return`` keyword, and any indentation before it, from each line.

    Works per line (``re.MULTILINE``), so it handles a single line such as
    ``"    return x"`` as well as a multi-line block. Any indentation depth is
    accepted. ``return`` must be a whole word, so ``"returned = 1"`` is left untouched.
    """
    return re.sub(r'^[ \t]*return\b[ \t]*', '', text, flags=re.MULTILINE)

def remove_brackets(text):
    """Delete every ``(...)`` group from ``text``.

    Matching is non-greedy and not nesting-aware, for example:
        ``"self.fc(x)"`` -> ``"self.fc"``
        ``"f(g(x))"`` -> ``"f)"``
    """
    bracket_group = re.compile(r'\((.*?)\)')
    stripped = bracket_group.sub('', text)
    return stripped

def get_midparam(text):
    """Return the assignment target of one source line, with all spaces removed.

    Examples:
        ``"x = self.fc(x)"`` -> ``"x"``
        ``"x += 1"`` / ``"x -= 1"`` -> ``"x"``
        ``"out, h = rnn(x)"`` -> ``"out,h"``

    Only the first assignment ``=`` (optionally preceded by an augmented operator)
    is considered. Keyword arguments on the right-hand side, such as
    ``f(x, y=3)``, therefore do not leak into the result.

    Raises:
        ValueError: if the line contains no assignment.
    """
    match = re.match(r'\s*([^=]*?)\s*(?:\*\*|//|>>|<<|[-+*/%@&|^])?=(?!=)', text)
    if match is None:
        raise ValueError(f"not an assignment statement: {text!r}")
    return match.group(1).replace(" ", "")

def get_midprocess(text):
    """Return the right-hand side of one assignment line, with all spaces removed.

    Examples:
        ``"x = self.fc(x)"`` -> ``"self.fc(x)"``
        ``"x += 1"`` -> ``"1"``
        ``"x = f(a, y=3)"`` -> ``"f(a,y=3)"``

    Only the assignment ``=`` (optionally preceded by an augmented operator) is
    consumed. ``=`` characters inside the expression, such as keyword arguments,
    are preserved.

    Raises:
        ValueError: if the line contains no assignment.
    """
    match = re.match(r'[^=]*?(?:\*\*|//|>>|<<|[-+*/%@&|^])?=(?!=)(.*)', text)
    if match is None:
        raise ValueError(f"not an assignment statement: {text!r}")
    return match.group(1).replace(" ", "")

def get_content_in_brackets(text):
    """Return the text inside each ``(...)`` group, from left to right.

    Matching is non-greedy and not nesting-aware, so ``"f(g(x))"`` yields ``["g(x"]``.
    """
    return [found.group(1) for found in re.finditer(r'\((.*?)\)', text)]

def get_input_params(forward_sourcelist: list[str]) -> list[str]:
    """Return the parameter names on the ``def`` line, excluding the first one.

    Assumes ``forward_sourcelist[0]`` is the ``def`` line and that its first
    parameter is ``self``. Spaces are removed, and every parameter is passed
    through ``remove_type_hint``.
    """
    signature = get_content_in_brackets(forward_sourcelist[0])[0]
    names = signature.replace(" ", "").split(",")
    cleaned = []
    for name in names:
        cleaned.append(remove_type_hint(name))
    return cleaned[1:]  # drop the leading `self`

def get_output_params(forward_sourcelist: list[str]) -> str:
    """Return the last line of ``forward`` with its ``return`` prefix stripped by ``remove_return``."""
    return_line = forward_sourcelist[-1]
    return remove_return(return_line)

def get_midput_params(forward_sourcelist: list[str]) -> list[str]:
    """Collect the distinct assignment targets in the body of ``forward``.

    The body is every line between the ``def`` line and the final ``return``
    line. Each body line is de-indented once and its target extracted with
    ``get_midparam``.

    Returns:
        Unique target names, in order of first appearance.
    """
    body_lines = [remove_oneindent(line) for line in forward_sourcelist[1:-1]]
    targets = [get_midparam(line) for line in body_lines]
    # dict.fromkeys de-duplicates while keeping first-appearance order. A set's
    # iteration order for strings changes between interpreter runs (hash randomization).
    return list(dict.fromkeys(targets))

def get_midput_processes(forward_sourcelist: list[str]) -> list[str]:
    """Return the right-hand side of each body line of ``forward``, in source order.

    The body is every line between the ``def`` line and the final ``return``
    line. Each line is de-indented once and passed to ``get_midprocess``.
    """
    return [
        get_midprocess(remove_oneindent(body_line))
        for body_line in forward_sourcelist[1:-1]
    ]

In [12]:
# Re-derive from the method source instead of rewriting `forward_source` in place,
# so re-running this cell does not strip a second level of indentation.
forward_source = remove_oneindent(inspect.getsource(model.forward))
print(forward_source)

def forward(self, x):
    x = self.fc(x)
    x = F.sigmoid(x)
    x = self.fc2(x)
    x = F.sigmoid(x)
    return x



In [13]:
# One entry per source line, skipping empty and whitespace-only lines: every body
# line is later fed to get_midparam / get_midprocess, which expect an assignment.
forward_sourcelist = [line for line in forward_source.split('\n') if line.strip()]
forward_sourcelist

['def forward(self, x):',
 '    x = self.fc(x)',
 '    x = F.sigmoid(x)',
 '    x = self.fc2(x)',
 '    x = F.sigmoid(x)',
 '    return x']

In [14]:
# Parameter names of forward (excluding self), parsed from its def line.
input_params = get_input_params(forward_sourcelist)
print(input_params)

['x']


In [15]:
# Distinct variable names assigned inside forward's body.
midput_params = get_midput_params(forward_sourcelist)
print(midput_params)

['x']


In [16]:
# Split each right-hand side of forward's body into the callee and its argument list.
midput_processes_base = get_midput_processes(forward_sourcelist)
midput_processes_input = list(map(get_content_in_brackets, midput_processes_base))
midput_processes = list(map(remove_brackets, midput_processes_base))
print(midput_processes)
print(midput_processes_input)

['self.fc', 'F.sigmoid', 'self.fc2', 'F.sigmoid']
[['x'], ['x'], ['x'], ['x']]


In [17]:
# The model's own callable attributes that are not bound methods. Attributes shared
# with every nn.Module are excluded by diffing against a bare nn.Module().
# For BaseModel this yields the layers `fc` and `fc2`.
attrs = set(dir(model)) - set(dir(nn.Module()))
instance_params_dict = {
    attr: getattr(model, attr)
    # sorted(): iterating a set of strings gives a different order on every run.
    for attr in sorted(attrs)
    # inspect.ismethod replaces `type(obj) != "method"`, which compared a type
    # object to a string and was therefore always True.
    if callable(getattr(model, attr)) and not inspect.ismethod(getattr(model, attr))
}
instance_params_dict

{'fc2': Linear(in_features=2, out_features=1, bias=True),
 'fc': Linear(in_features=2, out_features=2, bias=False)}

In [18]:
# Encrypt the model parameters once and keep them for the encrypted layers below.
encript_param_dict = encript_param(model.state_dict(), ctx)
encript_param_dict

{'fc.weight': <tenseal.tensors.ckkstensor.CKKSTensor at 0x28312c42d40>,
 'fc2.weight': <tenseal.tensors.ckkstensor.CKKSTensor at 0x28312c41b40>,
 'fc2.bias': <tenseal.tensors.ckkstensor.CKKSTensor at 0x28312c42ce0>}

In [19]:
def get_model_encript_params(
    paramname: str,
    encript_param_dict: dict,
    instance_params: "dict | None" = None,
) -> "tuple[ts.CKKSTensor | None, ts.CKKSTensor | None] | None":
    """Return the encrypted ``(weight, bias)`` pair of an ``nn.Linear`` attribute.

    Args:
        paramname: attribute name of the layer on the model, e.g. ``"fc"``.
        encript_param_dict: output of ``encript_param`` (keys like ``"fc.weight"``).
        instance_params: mapping of attribute name -> module object. Defaults to
            the notebook-level ``instance_params_dict`` for backward compatibility.
            Pass it explicitly to avoid depending on hidden notebook state.

    Returns:
        ``(weight, bias)`` when ``paramname`` refers to an ``nn.Linear``. ``bias``
        is ``None`` when the layer has no bias (its ``.bias`` key is absent).
        Returns ``None`` when ``paramname`` is not an ``nn.Linear``.
    """
    if instance_params is None:
        instance_params = instance_params_dict
    if not isinstance(instance_params.get(paramname), nn.Linear):
        return None
    return (
        encript_param_dict.get(f"{paramname}.weight"),
        encript_param_dict.get(f"{paramname}.bias"),
    )

In [20]:
# `fc` was built with bias=False, so the bias element is expected to be None.
get_model_encript_params("fc", encript_param_dict)

[2, 2]

In [21]:
# `fc2` has both an encrypted weight and an encrypted bias.
get_model_encript_params("fc2", encript_param_dict)

(<tenseal.tensors.ckkstensor.CKKSTensor at 0x28312c41b40>,
 <tenseal.tensors.ckkstensor.CKKSTensor at 0x28312c42ce0>)

In [22]:
class encript_foward_linear:
    """Encrypted stand-in for ``nn.Linear.forward``, built from encrypted parameters.

    The weight and bias are looked up in ``encript_param_dict`` under
    ``"{paramname}.weight"`` and ``"{paramname}.bias"`` (the key layout of
    ``model.state_dict()``). ``nn.Linear`` stores its weight as
    ``(out_features, in_features)`` and computes ``y = x @ W.T + b``, so the
    weight is transposed before the matrix product.
    """

    def __init__(self, paramname, encript_param_dict: dict):
        self.weight = encript_param_dict.get(f"{paramname}.weight")
        # None when the key is absent, e.g. for a layer built with bias=False.
        self.bias = encript_param_dict.get(f"{paramname}.bias")

    def forward(self, enc_x):
        """Return ``enc_x @ weight.T``, plus ``bias`` when the layer has one.

        ``enc_x`` is presumably a 2-D (batch, in_features) encrypted tensor,
        since it is combined with the weight via ``mm``.
        """
        out = enc_x.mm(self.weight.transpose())
        if self.bias is not None:
            # NOTE(review): relies on TenSEAL broadcasting the (out_features,) bias
            # across the batch dimension — confirm for batch sizes > 1.
            out = out + self.bias
        return out

# Encrypted counterpart of the model's first layer (`fc`, which has no bias).
fc = encript_foward_linear("fc", encript_param_dict)

In [None]:
# Encrypt a batch of one sample (shape 1x2) rather than a flat vector: the layer
# multiplies the input with `mm`, which works on matrices.
# NOTE(review): TenSEAL's CKKSTensor.mm appears to reject 1-D operands — confirm.
fc.forward(ts.ckks_tensor(ctx, [[1, 2]]))

: 

In [None]:
# Name returned by forward, parsed from its last (return) line.
output_params = get_output_params(forward_sourcelist)
print(output_params)

x


In [None]:
# Show the body of forward: every line between the def line and the return line.
process = [
    remove_oneindent(remove_type_hint(source_line))
    for source_line in forward_sourcelist[1:-1]
]
for prc in process:
    print(prc)

x = self.fc(x)
x = F.sigmoid(x)
x = self.fc2(x)
x = F.sigmoid(x)
