-
Notifications
You must be signed in to change notification settings - Fork 46
/
__init__.py
165 lines (147 loc) · 7.63 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from .graph import Graph
from .dependency_graph import build_pruning_dependency_graph
from .subnet_construction import automated_pruning_compression
import os
from .flops.flops import compute_flops
class OTO:
    """Only-Train-Once entry point.

    Wraps a model into a trace-based dependency ``Graph``, partitions it into
    pruning zero-invariant groups (PZIGs), exposes the sparse optimizers
    (HESSO / DHSPG / LHSPG) over those groups, and finally constructs the
    compressed sub-network from the group-sparse solution.

    Parameters
    ----------
    model : torch.nn.Module, optional
        Model to compress. If both *model* and *dummy_input* are given,
        the graph is built immediately in ``__init__``.
    dummy_input : tensor or tuple, optional
        Example input used to trace the model.
    compress_mode : str
        ``'prune'`` (supported) or ``'erase'`` (not yet released).
    skip_patterns : optional
        Patterns of operators/parameters to exclude from pruning
        (forwarded to ``Graph``; exact format defined there).
    strict_out_nodes : bool
        Forwarded to ``Graph``; controls output-node handling.
    """

    def __init__(self, model=None, dummy_input=None, compress_mode='prune', skip_patterns=None, strict_out_nodes=False):
        self._graph = None
        self._model = model
        self._dummy_input = dummy_input
        self._skip_patterns = skip_patterns
        self._strict_out_nodes = strict_out_nodes
        self._mode = compress_mode
        # Eagerly build the graph only when both pieces are available;
        # otherwise the caller is expected to invoke initialize() later.
        if self._model is not None and self._dummy_input is not None:
            self.initialize(model=self._model, dummy_input=self._dummy_input,
                            skip_patterns=self._skip_patterns,
                            strict_out_nodes=self._strict_out_nodes)
            if self._mode == 'prune':
                self.partition_pzigs()
                self.set_trainable()
                self._graph.cluster_node_groups()
            elif self._mode == 'erase':
                # Will be released
                raise NotImplementedError
        # Populated by construct_subnet().
        self.compressed_model_path = None
        self.full_group_sparse_model_path = None

    def cluster_node_groups(self, num_clusters=1):
        """Cluster the graph's node groups into *num_clusters* clusters."""
        self._graph.cluster_node_groups(num_clusters=num_clusters)

    def initialize(self, model=None, dummy_input=None, skip_patterns=None, strict_out_nodes=False):
        """(Re)build the dependency graph by tracing *model* on *dummy_input*.

        The model is switched to eval mode before tracing so that
        training-only behaviors (dropout, batch-norm updates) do not
        affect the trace.
        """
        model = model.eval()
        self._model = model
        self._dummy_input = dummy_input
        self._graph = Graph(model, dummy_input, skip_patterns=skip_patterns,
                            strict_out_nodes=strict_out_nodes)

    def partition_pzigs(self):
        """Partition the graph into pruning zero-invariant groups."""
        build_pruning_dependency_graph(self._graph)

    def visualize(self, out_dir=None, view=False, vertical=True, by_node_groups=True):
        """Render the dependency graph with graphviz into ``<name>_zig.gv``.

        Fix: the ``'_zig.gv'`` suffix was previously appended only when the
        model had no ``name`` attribute (a conditional-expression precedence
        bug); it is now always appended to the base name.
        """
        base_name = self._model.name if hasattr(self._model, 'name') else type(self._model).__name__
        target = os.path.join(out_dir if out_dir is not None else './', base_name + '_zig.gv')
        self._graph.build_dot(vertical=vertical, by_node_groups=by_node_groups).render(target, view=view)

    def hesso(self, lr=0.1, weight_decay=None, first_momentum=None, second_momentum=None, \
              variant='sgd', target_group_sparsity=0.5, start_pruning_step=0, \
              pruning_steps=1, pruning_periods=1, \
              dampening=None, group_divisible=1, fixed_zero_groups=True, importance_score_criteria='default'):
        """Create and return a HESSO optimizer over the graph's param groups.

        NOTE(review): *fixed_zero_groups* is accepted but not forwarded to
        HESSO (unlike dhspg/lhspg) — confirm whether HESSO supports it.
        """
        from .optimizer import HESSO
        self._optimizer = HESSO(
            params=self._graph.get_param_groups(),
            lr=lr,
            weight_decay=weight_decay,
            first_momentum=first_momentum,
            second_momentum=second_momentum,
            dampening=dampening,
            variant=variant,
            target_group_sparsity=target_group_sparsity,
            start_pruning_step=start_pruning_step,
            pruning_periods=pruning_periods,
            pruning_steps=pruning_steps,
            group_divisible=group_divisible,
            importance_score_criteria=importance_score_criteria
        )
        return self._optimizer

    def dhspg(self, lr=0.1, weight_decay=None, first_momentum=None, second_momentum=None, \
              variant='sgd', target_group_sparsity=0.5, tolerance_group_sparsity=0.01, start_pruning_step=0, \
              pruning_steps=1, pruning_periods=1, device='cuda', \
              dampening=None, group_divisible=1, fixed_zero_groups=True, importance_score_criteria='default'):
        """Create and return a DHSPG optimizer over the graph's param groups."""
        from .optimizer import DHSPG
        self._optimizer = DHSPG(
            params=self._graph.get_param_groups(),
            lr=lr,
            weight_decay=weight_decay,
            first_momentum=first_momentum,
            second_momentum=second_momentum,
            dampening=dampening,
            variant=variant,
            target_group_sparsity=target_group_sparsity,
            tolerance_group_sparsity=tolerance_group_sparsity,
            start_pruning_step=start_pruning_step,
            pruning_periods=pruning_periods,
            pruning_steps=pruning_steps,
            group_divisible=group_divisible,
            fixed_zero_groups=fixed_zero_groups,
            importance_score_criteria=importance_score_criteria,
            device=device
        )
        return self._optimizer

    def lhspg(self, lr=0.1, epsilon=0.0, weight_decay=None, first_momentum=None, second_momentum=None, \
              variant='sgd', target_group_sparsity=0.5, tolerance_group_sparsity=0.01, start_pruning_step=0, \
              pruning_steps=1, pruning_periods=1, device='cuda', \
              dampening=None, group_divisible=1, fixed_zero_groups=True, lora_update_freq=4, importance_score_criteria=None):
        """Create and return a LHSPG (LoRA-aware) optimizer.

        NOTE(review): *epsilon* is accepted but not forwarded to LHSPG —
        verify against the optimizer's signature.
        """
        from .optimizer import LHSPG
        self._optimizer = LHSPG(
            params=self._graph.get_param_groups(),
            lr=lr,
            weight_decay=weight_decay,
            first_momentum=first_momentum,
            second_momentum=second_momentum,
            dampening=dampening,
            variant=variant,
            target_group_sparsity=target_group_sparsity,
            tolerance_group_sparsity=tolerance_group_sparsity,
            start_pruning_step=start_pruning_step,
            pruning_periods=pruning_periods,
            pruning_steps=pruning_steps,
            group_divisible=group_divisible,
            fixed_zero_groups=fixed_zero_groups,
            importance_score_criteria=importance_score_criteria,
            device=device,
            lora_update_freq=lora_update_freq
        )
        return self._optimizer

    def h2spg(self, **kwargs):
        # Will be released
        raise NotImplementedError

    def skip_operators(self, operator_list=None):
        """Exclude the given operator types from pruning.

        Default changed from a shared mutable ``list()`` to ``None``
        (behavior-compatible); ``None`` is treated as an empty list.
        """
        self._graph.skip_operators([] if operator_list is None else operator_list)

    def set_trainable(self):
        """Mark the graph's prunable parameters as trainable."""
        self._graph.set_trainable()

    def construct_subnet(self, merge_lora_to_base=False, unmerge_lora_to_base=False, export_huggingface_format=False, export_float16=False, out_dir='./', \
                         full_group_sparse_model_dir=None, compressed_model_dir=None, save_full_group_sparse_model=True, ckpt_format='torch'):
        """Build and save the compressed model from the group-sparse solution.

        Sets ``self.compressed_model_path`` and
        ``self.full_group_sparse_model_path`` on success. Directories
        default to *out_dir* when not given explicitly.
        """
        full_group_sparse_model_dir = out_dir if full_group_sparse_model_dir is None else full_group_sparse_model_dir
        compressed_model_dir = out_dir if compressed_model_dir is None else compressed_model_dir
        if self._mode == 'prune':
            self.compressed_model_path, self.full_group_sparse_model_path = automated_pruning_compression(
                oto_graph=self._graph,
                model=self._model,
                merge_lora_to_base=merge_lora_to_base,
                unmerge_lora_to_base=unmerge_lora_to_base,
                export_huggingface_format=export_huggingface_format,
                export_float16=export_float16,
                full_group_sparse_model_dir=full_group_sparse_model_dir,
                compressed_model_dir=compressed_model_dir,
                save_full_group_sparse_model=save_full_group_sparse_model,
                ckpt_format=ckpt_format)
        elif self._mode == 'erase':
            # Will be released
            raise NotImplementedError

    def random_set_zero_groups(self, target_group_sparsity=None):
        """Randomly zero out groups (sanity/robustness testing utility)."""
        self._graph.random_set_zero_groups(target_group_sparsity=target_group_sparsity)

    def mark_unprunable_by_node_ids(self, node_ids=None):
        """Mark every node group containing any of *node_ids* as unprunable.

        Default changed from a shared mutable ``list()`` to ``None``
        (behavior-compatible); ``None`` is treated as an empty list.
        """
        node_ids = [] if node_ids is None else node_ids
        for node_group in self._graph.node_groups.values():
            for node_id in node_ids:
                if node_id in node_group.nodes:
                    node_group.is_prunable = False

    def compute_flops(self, compressed=False, verbose=False):
        # Will be released (note: .flops.compute_flops is imported at module
        # level but not yet wired in here)
        raise NotImplementedError

    def compute_num_params(self, compressed=False):
        # Will be released
        raise NotImplementedError