# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
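"""Utilities for saving and loading Mixture-of-Experts (MoE) checkpoints.

Expert (per-rank) and shared (replicated) parameters are checkpointed
separately; these helpers split model/optimizer state at save time and merge
it back at load time, including re-stitching expert shards when the current
data-parallel world size differs from the one that wrote the checkpoint, and
unflattening FSDP-flattened expert optimizer tensors.

A rough sketch of the intended flow (the exact call sites live in fairseq's
trainer/checkpoint code, not in this module):

    # save: one shared checkpoint plus one expert checkpoint per rank
    shared, expert = split_shared_and_expert_states(model, optimizer)

    # load: stitch expert shards if needed, then merge with the shared state
    expert_state = load_expert_state(local_path)
    state = merge_expert_and_shared_state(expert_state, shared_state)
"""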
import logging
import re
from collections import OrderedDict, defaultdict
from glob import glob
from typing import Dict, List

import numpy as np
import torch
from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor

from fairseq import distributed_utils
from fairseq.file_io import torch_load_cpu

OPT_KEY = 'last_optimizer_state'

logger = logging.getLogger(__name__)


def merge_expert_and_shared_state(expert_state, shared_state):
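    """Merge an expert checkpoint with its shared-parameter counterpart.

    Metadata ('cfg', 'args', 'extra_state', 'optimizer_history') comes from
    the expert checkpoint; 'model' is the union of both state dicts. If
    optimizer state is present, it is merged flat, except in the FSDP case
    (signalled by 'param_id_map'), where flattened expert tensors are first
    unflattened via _unflat_expert_tensor_state().
    """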
    state = {}
    for key in ['cfg', 'args', 'extra_state', 'optimizer_history']:
        state[key] = expert_state[key]
    state['model'] = {**expert_state['model'], **shared_state['model']}
    if OPT_KEY in expert_state:
        state[OPT_KEY] = {}
        for key in ['loss_scale', 'param_groups']:
            if key in expert_state[OPT_KEY]:
                state[OPT_KEY][key] = expert_state[OPT_KEY][key]
        if 'param_id_map' in shared_state[OPT_KEY]:  # FSDP
            unflat_expert_state = _unflat_expert_tensor_state(
                expert_state[OPT_KEY], shared_state[OPT_KEY]
            )
            state[OPT_KEY]['state'] = {
                **shared_state[OPT_KEY]['state'],
                **unflat_expert_state,
            }
            state[OPT_KEY].update(
                {k: v for k, v in shared_state[OPT_KEY].items()
                 if k not in state[OPT_KEY]}
            )
        else:
            state[OPT_KEY]['state'] = {
                **expert_state[OPT_KEY]['state'],
                **shared_state[OPT_KEY]['state'],
            }
    return state


def split_shared_and_expert_states(model, optimizer):
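    """Split model and optimizer state into (shared, expert) halves.

    Model parameters are routed by name: keys containing 'expert' (but not
    'expert_centroids') are expert-local. Optimizer state is routed via the
    `expert` / `base_expert` attributes on the parameter objects, using the
    same packed parameter ids that torch.optim.Optimizer.state_dict() emits.

    Returns a pair of (model_state_dict, optimizer_state_dict) tuples:
    shared first, expert second.
    """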
    model_state_dict = model.state_dict()
    shared_model_state_dict = OrderedDict()
    expert_model_state_dict = OrderedDict()
    for name, value in model_state_dict.items():
        # TODO: this is a bit hacky - find a better way to determine expert params
        if 'expert' in name and 'expert_centroids' not in name:
            expert_model_state_dict[name] = value
        else:
            shared_model_state_dict[name] = value

    shared_optimizer_state_dict = {}
    expert_optimizer_state_dict = {}
    optimizer_state_dict = optimizer.state_dict()
    for key in ['param_groups', 'loss_scale']:
        if key in optimizer_state_dict:
            expert_optimizer_state_dict[key] = optimizer_state_dict[key]
            shared_optimizer_state_dict[key] = optimizer_state_dict[key]

    # Recompute packed parameter ids the same way
    # torch.optim.Optimizer.state_dict() does, so the ids here line up with
    # the keys of optimizer_state_dict['state'].
    param_mappings = {}
    param_id_to_is_expert = {}
    start_index = 0
    for group in optimizer.param_groups:
        packed = {k: v for k, v in group.items() if k != 'params'}
        for i, p in enumerate(group['params'], start_index):
            if id(p) not in param_mappings:
                param_mappings.update({id(p): i})
                param_id_to_is_expert[i] = hasattr(p, 'expert') or hasattr(p, 'base_expert')
        packed['params'] = [param_mappings[id(p)] for p in group['params']]
        start_index += len(packed['params'])

    expert_optimizer_state_dict['state'] = {
        k: v for k, v in optimizer_state_dict['state'].items()
        if param_id_to_is_expert[k]
    }
    shared_optimizer_state_dict['state'] = {
        k: v for k, v in optimizer_state_dict['state'].items()
        if not param_id_to_is_expert[k]
    }
    return (
        (shared_model_state_dict, shared_optimizer_state_dict),
        (expert_model_state_dict, expert_optimizer_state_dict),
    )


def merge_multi_local_expert_states(expert_states: List[Dict]) -> Dict:
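    """Stitch expert states from several checkpoint shards into one.

    Keys of the form 'experts.<local_id>' in each shard are renumbered to
    globally unique expert ids, assuming every shard holds the same number
    of local experts. Optimizer state is deliberately not stitched (see the
    warning below), so the merged state is for inference only.
    """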
    merged_expert_state = {}
    for key in ['cfg', 'args', 'extra_state', 'optimizer_history']:
        merged_expert_state[key] = expert_states[0][key]

    if OPT_KEY in expert_states[0]:
        logger.warning(
            "Not stitching last optimizer state while merging experts. "
            "This is okay for inference but not for continued training. "
        )

    model_state_dict = {}
    for expert_group_id, expert_state in enumerate(expert_states):
        num_local_experts_in_chkpt = 1
        for key in expert_state['model']:
            match = re.search(r"experts.([1-9][0-9]*)", key)
            if match and int(match.groups()[0]) + 1 > num_local_experts_in_chkpt:
                num_local_experts_in_chkpt = int(match.groups()[0]) + 1
        logger.info(
            f"found {num_local_experts_in_chkpt} local experts"
            f" in expert_group_id={expert_group_id}"
        )
        for key, val in expert_state['model'].items():
            match = re.search(r"experts.([0-9][0-9]*)", key)
            assert match is not None, f'"experts.([0-9][0-9]*)" pattern expected in key {key}'
            local_chkpt_expert_id = int(match.groups()[0])
            target_expert_id = expert_group_id * num_local_experts_in_chkpt + local_chkpt_expert_id
            key = key.replace(f"experts.{local_chkpt_expert_id}", f"experts.{target_expert_id}")
            model_state_dict[key] = val
    merged_expert_state['model'] = model_state_dict
    return merged_expert_state


def load_expert_state(local_path):
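    """Load the expert checkpoint(s) belonging to the current rank.

    Globs sibling 'rank-*' files next to `local_path`; if the checkpoint was
    written with more data-parallel ranks than are available now, each rank
    loads checkpoint_files_count // world_size consecutive shards and merges
    them with merge_multi_local_expert_states().
    """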
    checkpoint_files_count = len(glob(re.sub('rank-[0-9]+', 'rank-*', local_path)))
    world_size = distributed_utils.get_data_parallel_world_size()
    rank = distributed_utils.get_data_parallel_rank()
    if world_size < checkpoint_files_count:
        assert checkpoint_files_count % world_size == 0
        logger.info(
            f"Found total {checkpoint_files_count} expert files and"
            f" current distributed world size: {world_size}."
            " Stitching experts to be able to load on current world size."
        )
        local_expert_count = checkpoint_files_count // world_size
        start_rank = local_expert_count * rank
        expert_states = []
        for expert_rank in range(start_rank, start_rank + local_expert_count):
            fname = re.sub(
                'rank-[0-9]+',
                f'rank-{expert_rank}',
                local_path,
            )
            expert_states.append(torch_load_cpu(fname))
        expert_state = merge_multi_local_expert_states(expert_states)
    else:
        expert_state = torch_load_cpu(local_path)
    return expert_state


def assert_equal(a, b, msg=''):
    assert a == b, f"{msg}{a} != {b}"


def _unflat_expert_tensor_state(expert, shared) -> Dict:
"""called from merge_expert_and_shared_state, for FSDP only."""
local_to_globals = defaultdict(list)
for global_id, local_id in shared['param_id_map'].items():
if local_id in shared['uncollected_local_ids']:
local_to_globals[local_id].append(global_id)
flat_expert_state = expert['state']
unflat_state = {}
for local_id, global_ids in local_to_globals.items():
global_ids = sorted(global_ids)
unflat_state.update({g: {} for g in global_ids})
already_unflat = {k: v for k, v in flat_expert_state[local_id].items() if not torch.is_tensor(v) or is_singleton_tensor(v)}
for buffer_name, flat_param in flat_expert_state[local_id].items():
if torch.is_tensor(flat_param) and not is_singleton_tensor(flat_param):
unflat_shapes = [shared['state'][g][buffer_name].shape for g in global_ids]
numels = [np.prod(s) for s in unflat_shapes]
unflat = zip(global_ids, (t.view(s) for (t, s) in zip(flat_param.split(numels), unflat_shapes)))
for gid, t in unflat:
unflat_state[gid][buffer_name] = t
unflat_state[gid].update(already_unflat)
return unflat_state