/
module.py
114 lines (85 loc) · 3.74 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import copy
from collections import deque
from collections.abc import Generator
import ujson
# NOTE: For the time being (a temporary decision), it is important that named_parameters keeps
# behaving differently from named_sub_modules.
class BaseModule:
    """Base class providing parameter discovery, sub-module traversal, copying and state (de)serialization."""

    def __init__(self):
        """Create an empty module; attributes are attached by subclasses or callers."""

    def named_parameters(self):
        """Collect the parameters reachable from this module's instance attributes.

        Unlike PyTorch, this also handles (non-recursive) lists, tuples and dicts
        of parameters.

        Naming scheme, where ``attr`` is the attribute name on ``self``:

        * a ``Parameter`` held directly:             ``attr``
        * a parameter of a ``dspy.Module`` attribute: ``attr.<sub_name>``
        * a ``Parameter`` inside a list/tuple:        ``attr[idx]``
        * a ``Parameter`` stored as a dict value:     ``attr['key']``

        Containers are inspected one level deep only: an element that is not
        itself a ``Parameter`` (e.g. a nested list or a module) is ignored.
        Each parameter object is reported once, under the first name found.

        Returns:
            A list of ``(name, parameter)`` tuples in attribute order.
        """
        # Imported at call time rather than module level (presumably to avoid a
        # circular import with the dspy package -- confirm before hoisting).
        import dspy
        from dspy.predict.parameter import Parameter

        visited = set()  # id()s of parameters already recorded
        named_parameters = []

        def add_parameter(param_name, param_value):
            # Silently skips non-Parameter values and already-recorded objects.
            if isinstance(param_value, Parameter) and id(param_value) not in visited:
                visited.add(id(param_value))
                named_parameters.append((param_name, param_value))

        for name, value in self.__dict__.items():
            if isinstance(value, Parameter):
                add_parameter(name, value)

            elif isinstance(value, dspy.Module):
                # When a sub-module is pre-compiled, keep it frozen.
                if not getattr(value, "_compiled", False):
                    for sub_name, param in value.named_parameters():
                        add_parameter(f"{name}.{sub_name}", param)

            elif isinstance(value, (list, tuple)):
                for idx, item in enumerate(value):
                    add_parameter(f"{name}[{idx}]", item)

            elif isinstance(value, dict):
                for key, item in value.items():
                    add_parameter(f"{name}['{key}']", item)

        return named_parameters

    def named_sub_modules(self, type_=None, skip_compiled=False) -> Generator[tuple[str, "BaseModule"], None, None]:
        """Find all sub-modules in the module, as well as their names.

        The traversal is breadth-first, starts at this module (named ``"self"``)
        and descends through module attributes, list/tuple elements and dict
        values.

        Say ``self.children[4]['key'].sub_module`` is a sub-module. Then its name
        will be ``'self.children[4][key].sub_module'`` (dict keys are not quoted).
        If a sub-module is accessible at different paths, only the first path
        reached is returned.

        Args:
            type_: Only objects that are instances of this type are yielded.
                Defaults to ``BaseModule``.
            skip_compiled: When true, a module whose ``_compiled`` attribute is
                truthy is still yielded but its attributes are not explored.

        Yields:
            ``(name, module)`` pairs.
        """
        if type_ is None:
            type_ = BaseModule

        queue = deque([("self", self)])
        seen = {id(self)}  # id()s of every object already queued

        def add_to_queue(name, item):
            if id(item) not in seen:
                seen.add(id(item))
                queue.append((name, item))

        while queue:
            name, item = queue.popleft()

            if isinstance(item, type_):
                yield name, item

            if isinstance(item, BaseModule):
                if skip_compiled and getattr(item, "_compiled", False):
                    continue
                for sub_name, sub_item in item.__dict__.items():
                    add_to_queue(f"{name}.{sub_name}", sub_item)

            elif isinstance(item, (list, tuple)):
                for i, sub_item in enumerate(item):
                    add_to_queue(f"{name}[{i}]", sub_item)

            elif isinstance(item, dict):
                for key, sub_item in item.items():
                    add_to_queue(f"{name}[{key}]", sub_item)

    def parameters(self):
        """Return the parameters from ``named_parameters()`` without their names."""
        return [param for _, param in self.named_parameters()]

    def deepcopy(self):
        """Return a deep copy of this module (via ``copy.deepcopy``)."""
        return copy.deepcopy(self)

    def reset_copy(self):
        """Return a deep copy of this module after calling ``reset()`` on each of the copy's parameters."""
        obj = copy.deepcopy(self)

        for param in obj.parameters():
            param.reset()

        return obj

    def dump_state(self):
        """Return a dict mapping each parameter name to that parameter's ``dump_state()`` result."""
        return {name: param.dump_state() for name, param in self.named_parameters()}

    def load_state(self, state):
        """Pass ``state[name]`` to ``load_state`` of every parameter of this module.

        Args:
            state: Mapping keyed by the names produced by ``named_parameters()``.

        Raises:
            KeyError: If ``state`` has no entry for one of the parameter names.
        """
        for name, param in self.named_parameters():
            param.load_state(state[name])

    def save(self, path):
        """Write ``dump_state()`` to ``path`` as indented JSON.

        The file is always written as UTF-8 so that it does not depend on the
        locale of the machine that produced it.
        """
        with open(path, "w", encoding="utf-8") as f:
            f.write(ujson.dumps(self.dump_state(), indent=2))

    def load(self, path):
        """Read the JSON file at ``path`` and apply it with ``load_state``.

        The file is always decoded as UTF-8 (the JSON interchange encoding)
        instead of the platform's locale encoding.
        """
        with open(path, encoding="utf-8") as f:
            self.load_state(ujson.loads(f.read()))