-
Notifications
You must be signed in to change notification settings - Fork 128
/
Copy pathserialization.py
107 lines (81 loc) · 2.99 KB
/
serialization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from functools import wraps
from copy import deepcopy
import inspect
import torch.nn as nn
def serialize(init):
parameters = list(inspect.signature(init).parameters)
@wraps(init)
def new_init(self, *args, **kwargs):
params = deepcopy(kwargs)
for pname, value in zip(parameters[1:], args):
params[pname] = value
config = {
'class': get_classname(self.__class__),
'params': dict()
}
specified_params = set(params.keys())
for pname, param in get_default_params(self.__class__).items():
if pname not in params:
params[pname] = param.default
for name, value in list(params.items()):
param_type = 'builtin'
if inspect.isclass(value):
param_type = 'class'
value = get_classname(value)
config['params'][name] = {
'type': param_type,
'value': value,
'specified': name in specified_params
}
setattr(self, '_config', config)
init(self, *args, **kwargs)
return new_init
def load_model(config, **kwargs):
model_class = get_class_from_str(config['class'])
model_default_params = get_default_params(model_class)
model_args = dict()
for pname, param in config['params'].items():
value = param['value']
if param['type'] == 'class':
value = get_class_from_str(value)
if pname not in model_default_params and not param['specified']:
continue
assert pname in model_default_params
if not param['specified'] and model_default_params[pname].default == value:
continue
model_args[pname] = value
model_args.update(kwargs)
return model_class(**model_args)
def get_config_repr(config):
config_str = f'Model: {config["class"]}\n'
for pname, param in config['params'].items():
value = param["value"]
if param['type'] == 'class':
value = value.split('.')[-1]
param_str = f'{pname:<22} = {str(value):<12}'
if not param['specified']:
param_str += ' (default)'
config_str += param_str + '\n'
return config_str
def get_default_params(some_class):
params = dict()
for mclass in some_class.mro():
if mclass is nn.Module or mclass is object:
continue
mclass_params = inspect.signature(mclass.__init__).parameters
for pname, param in mclass_params.items():
if param.default != param.empty and pname not in params:
params[pname] = param
return params
def get_classname(cls):
module = cls.__module__
name = cls.__qualname__
if module is not None and module != "__builtin__":
name = module + "." + name
return name
def get_class_from_str(class_str):
components = class_str.split('.')
mod = __import__('.'.join(components[:-1]))
for comp in components[1:]:
mod = getattr(mod, comp)
return mod