-
Notifications
You must be signed in to change notification settings - Fork 2
/
convert_vgg.py
27 lines (19 loc) · 848 Bytes
/
convert_vgg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""Convert a pretrained VGG16 checkpoint into initial weights for SSD300.

Download the VGG16 model from the PyTorch model zoo:
https://download.pytorch.org/models/vgg16-397923af.pth
"""
import os

import torch

from model import SSD300
# Load the pretrained VGG state dict (parameter name -> tensor).
# map_location="cpu" keeps the script runnable on hosts without a GPU even
# if the checkpoint was serialized from CUDA tensors.
vgg = torch.load("/home/pzl/pytorch-hed/model/vgg16.pth", map_location="cpu")
ssd = SSD300()

# Entries of VGG's `features` whose weight and bias are copied into the
# entry of `ssd.base` at the same index.
layer_indices = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21]
for layer_idx in layer_indices:
    ssd.base[layer_idx].weight.data = vgg['features.%d.weight' % layer_idx]
    ssd.base[layer_idx].bias.data = vgg['features.%d.bias' % layer_idx]

# VGG features 24, 26 and 28 are copied into named attributes of the SSD
# model (conv5_1..conv5_3) instead of entries of `ssd.base`.
conv5_layers = (
    (ssd.conv5_1, 24),
    (ssd.conv5_2, 26),
    (ssd.conv5_3, 28),
)
for conv, feature_idx in conv5_layers:
    conv.weight.data = vgg['features.%d.weight' % feature_idx]
    conv.bias.data = vgg['features.%d.bias' % feature_idx]

# torch.save does not create missing parent directories, so make sure the
# output directory exists before writing.
# NOTE(review): 'pretained' looks like a typo for 'pretrained'; the path is
# kept as-is because other code may load from this exact location - confirm.
output_path = 'pretained/ssd.pth'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torch.save(ssd.state_dict(), output_path)