In [155]:
import re

import tenseal as ts
import torch
from torch import nn
from torch.nn import functional as F

In [156]:
def create_ctx(poly_mod_degree=8192,
               coeff_mod_bit_sizes=(40, 21, 21, 21, 21, 21, 21, 40),
               global_scale=2 ** 21):
    """Helper for creating the CKKS context.

    Defaults (identical to the previous hard-coded setup):
        - Polynomial degree: 8192.
        - Coefficient modulus size: [40, 21, 21, 21, 21, 21, 21, 40].
        - Scale: 2 ** 21.
    The setup requires the Galois keys for evaluating the convolutions;
    relinearization keys are generated as well.

    Args:
        poly_mod_degree: polynomial modulus degree passed to ``ts.context``.
        coeff_mod_bit_sizes: bit sizes of the coefficient modulus primes; any
            iterable of ints (copied into a fresh list, so the immutable
            default is never shared or mutated).
        global_scale: value assigned to ``ctx.global_scale``.

    Returns:
        The configured TenSEAL context with Galois and relin keys generated.
    """
    ctx = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_mod_degree,
        -1,  # plain_modulus placeholder: CKKS has no plaintext modulus
        list(coeff_mod_bit_sizes),
    )
    ctx.global_scale = global_scale
    ctx.generate_galois_keys()
    ctx.generate_relin_keys()
    return ctx

In [157]:
class BaseModel(nn.Module):
    """Two-layer fully connected network with a sigmoid after each layer.

    Feature flow: ``input_size -> hidden_size -> output_size`` (the sizes
    passed to the two ``nn.Linear`` layers).

    NOTE(review): later cells read the *source text* of ``forward`` via
    ``inspect.getsource`` and parse it line by line with regexes (first line =
    ``def`` header, last line = ``return``, every line in between an
    assignment). Keep ``forward`` to one single-line statement per line with
    no docstring or comments, otherwise that parsing breaks.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super(BaseModel, self).__init__()
        # input_size -> hidden_size
        self.fc = nn.Linear(input_size, hidden_size)
        # hidden_size -> output_size
        self.fc2 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = self.fc(x)
        x = F.sigmoid(x)
        x = self.fc2(x)
        x = F.sigmoid(x)
        return x

In [158]:
def sigmoid(enc_x):
    """Approximate the logistic sigmoid on an encrypted tensor.

    Evaluates the cubic ``0.5 + 0.197*x - 0.004*x**3`` through
    ``enc_x.polyval`` (coefficients ordered from the constant term upward).
    NOTE(review): such low-degree approximations are presumably only accurate
    for small |x| -- confirm the expected input range.
    """
    coefficients = [0.5, 0.197, 0, -0.004]  # c0, c1, c2, c3
    return enc_x.polyval(coefficients)

def encript_param(state_dict: "dict[str, torch.Tensor]", ctx: "ts.Context") -> "dict[str, ts.CKKSTensor]":
    """Encrypt every entry of a model ``state_dict`` as a CKKS tensor.

    Args:
        state_dict: mapping of parameter name -> tensor, e.g.
            ``model.state_dict()``. Key order is preserved in the result.
        ctx: TenSEAL CKKS context used for encryption.

    Returns:
        A new dict with the same keys, each value being
        ``ts.ckks_tensor(ctx, param)``.
    """
    # String annotations name the actual types without evaluating them at
    # definition time (``ts.context`` is a factory function, not a type).
    return {name: ts.ckks_tensor(ctx, param) for name, param in state_dict.items()}

In [159]:
# Build the CKKS context (Galois + relin keys included) used for encryption below.
ctx = create_ctx()

In [160]:
# nn.Linear initialises its weights randomly; seed so the state_dict shown
# below (and everything derived from it) is reproducible across runs.
SEED = 0
torch.manual_seed(SEED)
model = BaseModel(input_size=2, hidden_size=2, output_size=1)

In [161]:
# Show the plaintext weights and biases that are encrypted in a later cell.
model.state_dict()

OrderedDict([('fc.weight',
              tensor([[-0.0408,  0.5956],
                      [-0.0229, -0.2333]])),
             ('fc.bias', tensor([-0.3405, -0.6082])),
             ('fc2.weight', tensor([[-0.0193, -0.5482]])),
             ('fc2.bias', tensor([0.2462]))])

In [162]:
# Keep the encrypted parameters so later cells can use them, then display them.
encrypted_state_dict = encript_param(model.state_dict(), ctx)
encrypted_state_dict

{'fc.weight': <tenseal.tensors.ckkstensor.CKKSTensor at 0x1be2f15fd60>,
 'fc.bias': <tenseal.tensors.ckkstensor.CKKSTensor at 0x1be27f02170>,
 'fc2.weight': <tenseal.tensors.ckkstensor.CKKSTensor at 0x1be27f019c0>,
 'fc2.bias': <tenseal.tensors.ckkstensor.CKKSTensor at 0x1be27f006a0>}

In [163]:
import inspect


# Bound forward method of the model (BaseModel.forward).
forward_method = model.forward

# Source text of forward; the following cells parse these lines.
forward_source = inspect.getsource(forward_method)

# Show the source.
print(forward_source)

    def forward(self, x):
        x = self.fc(x)
        x = F.sigmoid(x)
        x = self.fc2(x)
        x = F.sigmoid(x)
        return x



In [166]:
def remove_oneindent(text):
    """Drop exactly one 4-space indent from the start of each line.

    Lines that do not start with four spaces (tab-indented, shallower, or
    empty) are returned unchanged; line breaks are preserved.
    """
    one_indent = '    '
    return '\n'.join(line.removeprefix(one_indent) for line in text.split('\n'))

def remove_type_hint(text):
    """Remove ``: annotation`` parts that follow a parameter/target name.

    An annotation is stripped when its colon follows an identifier at the
    start of ``text`` or right after ``,``/``(``; it runs up to (not
    including) the next ``,``, ``=``, ``)``, newline or the end. Examples:
    ``"x:int"`` -> ``"x"``, ``"self, x: int, y: str = 1"`` ->
    ``"self, x, y = 1"``, ``"x: int = self.fc(x)"`` -> ``"x = self.fc(x)"``.
    Text without such a colon (e.g. ``"y = x[1:]"``) is returned unchanged.

    NOTE: annotations that themselves contain ``,``/``=``/``)`` (e.g.
    ``Dict[str, int]``) are not handled.
    """
    return re.sub(
        r'(^|[,(])(\s*\*{0,2}[A-Za-z_]\w*)\s*:[^,=)\n]*?(?=[ \t]*(?:[,=)\n]|$))',
        r'\1\2',
        text,
    )

def remove_return(text):
    """Strip a leading ``return`` keyword (and its indentation) from each line.

    ``"    return x"`` -> ``"x"``. Any amount of leading spaces/tabs is
    accepted, so this also works on source that was dedented zero or several
    times. Lines that do not start with a ``return`` statement (e.g.
    ``"    returned = 1"``) are left unchanged.
    """
    return re.sub(r'^[ \t]*return[ \t]+', '', text, flags=re.MULTILINE)

def get_midparam(text):
    """Return the assignment target of a single statement line.

    The target is everything before the *first* ``=`` with spaces and ``+``
    signs removed, so ``"x = self.fc(x)"`` -> ``"x"``, ``"x += 1"`` -> ``"x"``
    and ``"y = f(x, dim=1)"`` -> ``"y"`` (keyword arguments or ``==`` on the
    right-hand side no longer leak into the result).

    Raises:
        IndexError: if the line contains no ``=`` at all.
    """
    # [^=\n]* stops at the first '=' instead of greedily running to the last one
    target = re.findall(r'^[^=\n]*=', text)[0]
    return target.replace(" ", "").replace("=", "").replace("+", "")

def get_content_in_brackets(text):
    """Return the contents of every top-level parenthesised group in ``text``.

    Parentheses are matched by depth, so nested groups stay intact:
    ``"f(g(x), y)"`` -> ``["g(x), y"]`` and ``"f(a)(b)"`` -> ``["a", "b"]``.
    Unclosed groups are ignored and stray closing parentheses are skipped.
    Parentheses inside string literals are not special-cased.
    """
    contents = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char == '(':
            if depth == 0:
                start = index + 1  # content begins right after the outer '('
            depth += 1
        elif char == ')' and depth > 0:
            depth -= 1
            if depth == 0:
                contents.append(text[start:index])
    return contents

def get_input_params(forward_sourcelist: list[str]) -> list[str]:
    """Extract the argument names (excluding ``self``) from a ``def`` line.

    Args:
        forward_sourcelist: source lines of the method; element 0 must be the
            complete ``def name(...):`` header on a single line.

    Returns:
        Parameter names in declaration order with defaults (``x=None`` ->
        ``x``) and annotations (``x: Tensor`` -> ``x``) stripped. Empty
        entries (trailing comma) and the bare ``*`` / ``/`` markers are
        dropped. Annotations containing commas (``Dict[str, int]``) are not
        supported.

    Raises:
        IndexError: if the first line contains no parenthesised group.
    """
    header_args = get_content_in_brackets(forward_sourcelist[0])[0]
    input_params = []
    for raw_param in header_args.split(","):
        # drop the default first, then the annotation: "x: int = 1" -> "x"
        name = raw_param.split("=", 1)[0].split(":", 1)[0].strip()
        if name and name not in ("*", "/"):
            input_params.append(name)
    return input_params[1:]  # exclude self

def get_output_params(forward_sourcelist: list[str]) -> str:
    """Return what the method returns, as source text.

    Takes the last line (expected to be the ``return`` statement) and strips
    the ``return`` keyword from it with ``remove_return``.
    """
    return_line = forward_sourcelist[-1]
    return remove_return(return_line)

def get_midput_params(forward_sourcelist: list[str]) -> list[str]:
    """Collect the variable names assigned in the method body.

    Looks at every line strictly between the ``def`` header (first element)
    and the ``return`` line (last element) and returns each assignment target
    once, in order of first appearance. The order is deterministic, unlike a
    ``set``, whose iteration order for strings can change between kernel
    restarts.

    Raises:
        IndexError: propagated from ``get_midparam`` for a body line that
            contains no ``=``.
    """
    body_lines = [remove_oneindent(line) for line in forward_sourcelist[1:-1]]
    # blank and comment-only lines carry no assignment; skip them rather than
    # letting get_midparam fail on them
    statements = [
        line for line in body_lines
        if line.strip() and not line.lstrip().startswith("#")
    ]
    targets = [get_midparam(line) for line in statements]
    # dict.fromkeys de-duplicates while keeping first-seen order
    return list(dict.fromkeys(targets))

In [167]:
# Re-derive from the method itself so re-running this cell does not strip a
# second indent level off an already-dedented string.
forward_source = remove_oneindent(inspect.getsource(model.forward))
print(forward_source)

def forward(self, x):
    x = self.fc(x)
    x = F.sigmoid(x)
    x = self.fc2(x)
    x = F.sigmoid(x)
    return x



In [168]:
# Keep only lines that carry code: empty, whitespace-only and comment-only
# lines cannot be parsed by the regex-based helpers.
forward_sourcelist = [
    line for line in forward_source.split('\n')
    if line.strip() and not line.lstrip().startswith('#')
]
forward_sourcelist

['def forward(self, x):',
 '    x = self.fc(x)',
 '    x = F.sigmoid(x)',
 '    x = self.fc2(x)',
 '    x = F.sigmoid(x)',
 '    return x']

In [169]:
# Argument names of forward, excluding self.
print(input_param := get_input_params(forward_sourcelist))

['x']


In [170]:
# Variable names assigned inside forward's body.
print(midput_param := get_midput_params(forward_sourcelist))

['x']


In [175]:
# Attributes BaseModel adds on top of a bare nn.Module (e.g. fc / fc2).
attrs = set(dir(model)) - set(dir(nn.Module()))
# Keep callables that are not bound methods (the old `type(...) != "method"`
# compared a type with a string and was always True). Sorted so the dict
# order is stable across runs.
instance_params_dict = {
    f"self.{attr}": value
    for attr, value in ((name, getattr(model, name)) for name in sorted(attrs))
    if callable(value) and not inspect.ismethod(value)
}
instance_params_dict

{'self.fc': Linear(in_features=2, out_features=2, bias=True),
 'self.fc2': Linear(in_features=2, out_features=1, bias=True)}

In [171]:
# Expression returned by forward.
print(output_param := get_output_params(forward_sourcelist))

x


In [172]:
# show process
process = forward_sourcelist[1:-1]
process = [remove_oneindent(remove_type_hint(prc)) for prc in process]
for prc in process:
    print(prc)

x = self.fc(x)
x = F.sigmoid(x)
x = self.fc2(x)
x = F.sigmoid(x)
