-
Notifications
You must be signed in to change notification settings - Fork 0
/
global_pruning.py
48 lines (41 loc) · 2.04 KB
/
global_pruning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.nn as nn
def global_prune(model, p_prune=0.3, p_bern=1.0):
    """Globally prune the 2-D kernels of every ``nn.Conv2d`` in *model*.

    Each kernel is the spatial slice ``weight[o, i]`` of a conv weight of
    shape ``(out, in, kh, kw)``, and it is scored by its L1 norm. The scores
    of all convs are pooled. Kernels whose score is strictly below the global
    ``p_prune`` quantile become prune candidates. Each candidate is then
    zeroed independently with probability ``p_bern``.

    Zeroing is delegated to ``prune_from_mask``. That function stores the
    pre-pruning weight as ``weight_orig``, so ``restore_weight`` can undo it.

    Args:
        model: module whose ``nn.Conv2d`` submodules are pruned in place.
        p_prune: fraction of all kernels, in ``[0, 1]``, that are considered
            for pruning. ``0`` prunes nothing; ``1`` makes every kernel a
            candidate.
        p_bern: probability that a candidate kernel is actually pruned.
            ``1.0`` prunes all candidates deterministically.

    Returns:
        The same *model* object.

    Raises:
        ValueError: if ``p_prune`` is outside ``[0, 1]``.
    """
    if not 0.0 <= p_prune <= 1.0:
        # A negative p_prune would otherwise wrap around via negative indexing,
        # and p_prune > 1 would index past the end of the sorted norms.
        raise ValueError(f"p_prune must be in [0, 1], got {p_prune}")
    convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
    if not convs:
        # Nothing to prune (and torch.cat would fail on an empty list).
        return model
    with torch.no_grad():
        # Step 1: L1 norm of every kernel, shape (out, in) for each conv.
        # Each conv's weight only changes when that conv itself is pruned,
        # so computing all norms up front matches recomputing them per conv.
        norms = [torch.norm(conv.weight, p=1, dim=(2, 3)) for conv in convs]
        # Step 2: one threshold shared by all convs.
        threshold = _global_threshold(norms, p_prune)
        # Step 3: build each conv's keep-mask and prune.
        for conv, l1_norm in zip(convs, norms):
            out_, in_ = l1_norm.shape
            candidate = l1_norm < threshold  # True: prune candidate
            coin = torch.bernoulli(
                torch.full(candidate.shape, float(p_bern), device=candidate.device)
            )
            keep = 1 - candidate * coin  # 1: remain, 0: prune
            # Broadcast each kernel's decision over its spatial dims, and match
            # the weight dtype so half/bf16 weights are not up-cast.
            mask = keep.view(out_, in_, 1, 1).expand_as(conv.weight).to(conv.weight.dtype)
            prune_from_mask(conv, mask)
    return model


def _global_threshold(norms, p_prune):
    """Return the L1 value below which a kernel is a prune candidate."""
    ordered, _ = torch.sort(torch.cat([n.reshape(-1) for n in norms]))
    index = int(p_prune * ordered.numel())
    if index >= ordered.numel():
        # p_prune == 1: every (finite) kernel qualifies.
        return float("inf")
    return ordered[index]
def prune_from_mask(module, mask):
    """Zero the weights of *module* where *mask* is 0, keeping a backup.

    The current weight is copied into a new attribute,
    ``module.weight_orig``, so that ``restore_weight`` can undo the pruning.
    ``module.weight`` is then replaced by a new ``nn.Parameter`` holding
    ``weight * mask``.

    Args:
        module: module with a ``weight`` parameter (``nn.Conv2d`` in this file).
        mask: tensor broadcastable to ``module.weight``; 1 keeps a weight and
            0 prunes it.

    Returns:
        The same *module*.
    """
    weight = module.weight
    # detach() keeps the backup out of the autograd graph. A non-leaf tensor
    # attribute would, for example, make copy.deepcopy(model) fail. clone()
    # gives the backup its own storage.
    module.weight_orig = weight.detach().clone()
    with torch.no_grad():
        # Cast the mask so low-precision weights are not promoted by it.
        pruned = weight * mask.to(weight.dtype)
    # Preserve trainability instead of forcing requires_grad=True.
    module.weight = nn.Parameter(pruned, requires_grad=weight.requires_grad)
    return module
def restore_weight(model):
    """Undo ``prune_from_mask`` on every pruned ``nn.Conv2d`` in *model*.

    For each conv that has a ``weight_orig`` attribute, ``weight`` is replaced
    by a fresh ``nn.Parameter`` holding a copy of ``weight_orig``, and the
    ``weight_orig`` attribute is deleted. Convs without ``weight_orig`` are
    left untouched.

    ``weight_orig`` is a plain attribute, not a parameter or buffer, so
    ``model.to(...)`` and ``model.half()`` do not convert it. The copy is
    therefore moved to the device and dtype of the current ``weight``, and it
    keeps that weight's ``requires_grad`` flag.

    Returns:
        The same *model* object.
    """
    for module in model.modules():
        if not (isinstance(module, nn.Conv2d) and hasattr(module, 'weight_orig')):
            continue
        current = module.weight
        restored = (
            module.weight_orig.detach()
            .to(device=current.device, dtype=current.dtype)
            .clone()  # never share storage with the backup
        )
        module.weight = nn.Parameter(restored, requires_grad=current.requires_grad)
        del module.weight_orig
    return model