-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathsave_load.py
87 lines (63 loc) · 2.65 KB
/
save_load.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import pickle
from functools import wraps
from pathlib import Path
from packaging import version as packaging_version
import torch
from torch.nn import Module
from beartype import beartype
from beartype.typing import Optional
# helpers
def exists(v):
    """Report whether ``v`` holds a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``''`` or ``[]`` still count as existing.
    """
    is_missing = v is None
    return not is_missing
@beartype
def save_load(
    save_method_name = 'save',
    load_method_name = 'load',
    config_instance_var_name = '_config',
    init_and_load_classmethod_name = 'init_and_load',
    version: Optional[str] = None
):
    """Return a class decorator that adds checkpoint helpers to a ``torch.nn.Module`` subclass.

    The decorated class receives:

    * a wrapped ``__init__`` that pickles the constructor's ``(args, kwargs)``
      and stores the bytes on the instance under ``config_instance_var_name``;
    * an instance method ``<save_method_name>(path, overwrite = True)`` that
      writes ``dict(model = state_dict, config = <pickled args>, version = version)``
      with ``torch.save``;
    * an instance method ``<load_method_name>(path, strict = True)`` that reads
      such a file (mapped to CPU) and applies its ``'model'`` entry via
      ``load_state_dict(..., strict = strict)``;
    * a classmethod ``<init_and_load_classmethod_name>(path, strict = True)``
      that re-instantiates the class from the stored constructor arguments and
      then loads the stored weights into it.

    Args:
        save_method_name: attribute name under which the save method is installed.
        load_method_name: attribute name under which the load method is installed.
        config_instance_var_name: instance attribute that holds the pickled
            constructor arguments.
        init_and_load_classmethod_name: attribute name under which the
            classmethod is installed.
        version: optional version string written into every checkpoint. On load,
            if both this and the stored version exist and differ, a message is
            printed; loading still proceeds.

    Raises:
        AssertionError: if the decorated class is not an ``nn.Module`` subclass,
            when saving onto an existing path with ``overwrite = False``, when the
            checkpoint path does not exist, or when a checkpoint has no ``'config'``
            entry (``init_and_load`` only).
    """

    def _save_load(klass):
        assert issubclass(klass, Module), 'save_load should decorate a subclass of torch.nn.Module'

        _orig_init = klass.__init__

        @wraps(_orig_init)
        def __init__(self, *args, **kwargs):
            # pickle before running the original __init__: this captures the
            # arguments as passed, and unpicklable arguments fail immediately
            _config = pickle.dumps((args, kwargs))
            _orig_init(self, *args, **kwargs)
            # assign *after* the original __init__: if a decorated class is
            # subclassed by another decorated class, the parent's wrapper runs
            # inside super().__init__ and would otherwise overwrite the child's
            # config with the parent's arguments
            setattr(self, config_instance_var_name, _config)

        def _read_pkg(path):
            # read a checkpoint dict from disk, mapping all tensors to CPU
            path = Path(path)
            assert path.exists(), f'checkpoint not found at {str(path)}'
            return torch.load(str(path), map_location = 'cpu')

        def _load_pkg(model, pkg, strict):
            # apply an already-read checkpoint dict to `model`; shared by load and
            # init_and_load so the file is only read from disk once
            saved_version = pkg.get('version')  # checkpoints without a version key are tolerated
            if exists(version) and exists(saved_version) and packaging_version.parse(version) != packaging_version.parse(saved_version):
                print(f'loading saved model at version {saved_version}, but current package version is {version}')

            model.load_state_dict(pkg['model'], strict = strict)

        def _save(self, path, overwrite = True):
            path = Path(path)
            assert overwrite or not path.exists(), f'{str(path)} already exists; pass overwrite = True to replace it'

            pkg = dict(
                model = self.state_dict(),
                config = getattr(self, config_instance_var_name),
                version = version,
            )

            torch.save(pkg, str(path))

        def _load(self, path, strict = True):
            pkg = _read_pkg(path)
            _load_pkg(self, pkg, strict)

        # init and load from
        # looks for a `config` key in the stored checkpoint, instantiating the model as well as loading the state dict

        @classmethod
        def _init_and_load_from(cls, path, strict = True):
            pkg = _read_pkg(path)

            assert 'config' in pkg, 'model configs were not found in this saved checkpoint'

            # NOTE(security): pickle.loads can execute arbitrary code embedded in the
            # checkpoint; only call this on checkpoints from trusted sources
            args, kwargs = pickle.loads(pkg['config'])

            model = cls(*args, **kwargs)
            _load_pkg(model, pkg, strict)
            return model

        # set decorated init as well as save, load, and init_and_load

        klass.__init__ = __init__
        setattr(klass, save_method_name, _save)
        setattr(klass, load_method_name, _load)
        setattr(klass, init_and_load_classmethod_name, _init_and_load_from)

        return klass

    return _save_load