-
Notifications
You must be signed in to change notification settings - Fork 1
/
i3d.py
138 lines (121 loc) · 5.34 KB
/
i3d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
import torchvision
from torch.nn.init import normal, constant
from transforms import *
import models.i3dnon
import models.s3d
class I3DModel(torch.nn.Module):
    """I3D (Inflated 3D ConvNet) video classifier.

    Wraps a 3D-convolutional backbone resolved from the project-local
    ``models`` package (an inflated ResNet-101/152 or S3D-G).  For the
    ResNet backbones the final ``fc`` layer is replaced by dropout plus a
    fresh classification head (``self.new_fc``); S3D-G receives the class
    count directly in its constructor.

    Args:
        num_class: number of output classes.
        sample_frames: frames sampled per clip (stored for callers; not
            used directly in this class).
        modality: input modality tag (e.g. ``'RGB'``); stored for callers.
        base_model: backbone name; must contain ``'resnet101'`` /
            ``'resnet152'`` or ``'S3DG'``, otherwise ``ValueError`` is raised.
        dropout: dropout probability applied before the new classification
            head; ``0`` disables dropout and re-initializes the backbone's
            own final layer instead.
    """

    def __init__(self, num_class, sample_frames, modality,
                 base_model='resnet101',
                 dropout=0.8):
        super(I3DModel, self).__init__()
        self.modality = modality
        self.sample_frames = sample_frames
        self.reshape = True
        self.dropout = dropout
        self.num_class = num_class
        self.base_model_name = base_model
        print(("""
Initializing I3D with base model: {}.
I3D Configurations:
    input_modality:     {}
    sample_frames:      {}
    dropout_ratio:      {}
""".format(base_model, self.modality, self.sample_frames, self.dropout)))
        self._prepare_base_model(base_model)
        # Only the ResNet backbones need their classifier head rebuilt;
        # S3DG already got num_classes in its constructor.
        if 'resnet101' in base_model or 'resnet152' in base_model:
            self._prepare_i3d(num_class)

    def _prepare_i3d(self, num_class):
        """Rebuild the backbone's classifier head for ``num_class`` outputs.

        With ``dropout == 0`` the backbone's last layer is replaced by a new
        ``Linear`` classifier and ``self.new_fc`` is ``None``; otherwise the
        last layer becomes ``Dropout`` and ``self.new_fc`` maps features to
        ``num_class`` logits.  The new classifier's weights are drawn from
        N(0, 0.001) and its bias is zeroed.
        """
        feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features
        if self.dropout == 0:
            setattr(self.base_model, self.base_model.last_layer_name,
                    torch.nn.Linear(feature_dim, num_class))
            self.new_fc = None
        else:
            setattr(self.base_model, self.base_model.last_layer_name,
                    torch.nn.Dropout(p=self.dropout))
            self.new_fc = torch.nn.Linear(feature_dim, num_class)

        std = 0.001
        # torch.nn.init.normal/constant were deprecated in PyTorch 0.4 and
        # removed in 1.0; use the in-place underscore variants instead.
        if self.new_fc is None:
            classifier = getattr(self.base_model, self.base_model.last_layer_name)
            torch.nn.init.normal_(classifier.weight, 0, std)
            torch.nn.init.constant_(classifier.bias, 0)
        else:
            torch.nn.init.normal_(self.new_fc.weight, 0, std)
            torch.nn.init.constant_(self.new_fc.bias, 0)

    def _prepare_base_model(self, base_model):
        """Instantiate the backbone and record its input normalization.

        Raises:
            ValueError: if ``base_model`` names no known backbone.
        """
        if 'resnet101' in base_model or 'resnet152' in base_model:
            self.base_model = getattr(models, base_model)()
            #self.base_model = getattr(models, base_model)(pretrained=True)
            self.base_model.last_layer_name = 'fc'
            self.input_size = 224
            # ImageNet channel statistics.
            self.input_mean = [0.485, 0.456, 0.406]
            self.input_std = [0.229, 0.224, 0.225]
        elif 'S3DG' in base_model:
            self.base_model = getattr(models, base_model)(num_classes=self.num_class)
            # self.base_model.last_layer_name = 'softmax'
            self.input_size = 224
            self.input_mean = [0.485, 0.456, 0.406]
            self.input_std = [0.229, 0.224, 0.225]
        else:
            raise ValueError('Unknown base model: {}'.format(base_model))

    def get_optim_policies(self):
        """Group parameters into per-group learning-rate/decay policies.

        Returns a list of dicts (``params``/``lr_mult``/``decay_mult``/``name``)
        suitable for an optimizer's param-group argument: the first Conv3d
        gets its own group, biases get doubled LR and no weight decay, and
        BatchNorm parameters get no weight decay.

        Raises:
            ValueError: on an atomic module type with parameters that has no
                assigned policy, so new layer types cannot slip through silently.
        """
        first_conv_weight = []
        first_conv_bias = []
        normal_weight = []
        normal_bias = []
        bn = []

        conv_cnt = 0
        for m in self.modules():
            if isinstance(m, torch.nn.Conv3d):
                ps = list(m.parameters())
                conv_cnt += 1
                if conv_cnt == 1:
                    first_conv_weight.append(ps[0])
                    if len(ps) == 2:
                        first_conv_bias.append(ps[1])
                else:
                    normal_weight.append(ps[0])
                    if len(ps) == 2:
                        normal_bias.append(ps[1])
            elif isinstance(m, torch.nn.Linear):
                ps = list(m.parameters())
                normal_weight.append(ps[0])
                if len(ps) == 2:
                    normal_bias.append(ps[1])
            elif isinstance(m, torch.nn.BatchNorm3d):  # enable BN
                bn.extend(list(m.parameters()))
            elif len(m._modules) == 0:
                if len(list(m.parameters())) > 0:
                    raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m)))
        return [
            {'params': first_conv_weight, 'lr_mult': 1, 'decay_mult': 1,
             'name': "first_conv_weight"},
            {'params': first_conv_bias, 'lr_mult': 2, 'decay_mult': 0,
             'name': "first_conv_bias"},
            {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1,
             'name': "normal_weight"},
            {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0,
             'name': "normal_bias"},
            {'params': bn, 'lr_mult': 1, 'decay_mult': 0,
             'name': "BN scale/shift"},
        ]

    def forward(self, input):
        """Run the backbone; for ResNet backbones with dropout, apply new_fc."""
        out = self.base_model(input)
        if self.dropout > 0 and ('resnet101' in self.base_model_name
                                 or 'resnet152' in self.base_model_name):
            out = self.new_fc(out)
        return out

    @property
    def crop_size(self):
        """Spatial crop size fed to the network."""
        return self.input_size

    @property
    def scale_size(self):
        """Resize target before cropping (256 when input_size is 224)."""
        return self.input_size * 256 // 224

    def get_augmentation(self, mode='train'):
        """Build the transform pipeline for ``mode`` ('train' or 'val').

        Raises:
            ValueError: for an unrecognized mode (previously this fell
                through and silently returned ``None``).
        """
        resize_range_min = self.scale_size
        if mode == 'train':
            resize_range_max = self.input_size * 320 // 224
            return torchvision.transforms.Compose(
                [GroupRandomResizeCrop([resize_range_min, resize_range_max], self.input_size),
                 GroupRandomHorizontalFlip(is_flow=False),
                 GroupColorJitter(brightness=0.05, contrast=0.05,
                                  saturation=0.05, hue=0.05)])
        elif mode == 'val':
            return torchvision.transforms.Compose([GroupScale(resize_range_min),
                                                   GroupCenterCrop(self.input_size)])
        raise ValueError('Unknown augmentation mode: {}'.format(mode))