-
Notifications
You must be signed in to change notification settings - Fork 5
/
vit_ladder.py
470 lines (409 loc) · 16.9 KB
/
vit_ladder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import math
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor, nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.registry import MODELS
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import constant_init, trunc_normal_init
from mmaction.utils import ConfigType, OptConfigType
from mmaction.models.backbones.vit_mae import get_sinusoid_encoding
class Adapter(BaseModule):
    """Bottleneck adapter with a temporal depth-wise convolution.

    Tokens are projected down to ``int(embed_dims * mlp_ratio)`` channels,
    activated with GELU, mixed along the temporal axis by a depth-wise
    ``Conv1d`` followed by a point-wise ``Conv1d`` (added back as a
    residual), then projected up to ``embed_dims`` and scaled by the
    learnable scalar ``gamma``.

    Args:
        embed_dims (int): Channel dimension of the input tokens.
        mlp_ratio (float): Ratio of the bottleneck width to ``embed_dims``.
            Defaults to 0.25.
        kernel_size (int): Kernel size of the temporal depth-wise conv.
            Defaults to 3.
        dilation (int): Dilation of the temporal depth-wise conv.
            Defaults to 1.
        temporal_size (int): Length of the temporal axis the tokens are
            reshaped to before the convolution. Defaults to 384.
        with_cp (bool): If True, run the forward pass through
            ``torch.utils.checkpoint``. Defaults to False.
    """

    def __init__(
        self,
        embed_dims: int,
        mlp_ratio: float = 0.25,
        kernel_size: int = 3,
        dilation: int = 1,
        temporal_size: int = 384,
        with_cp: bool = False,
    ) -> None:
        super().__init__()
        self.with_cp = with_cp
        self.temporal_size = temporal_size
        bottleneck = int(embed_dims * mlp_ratio)

        # Temporal mixing: one filter per channel (depth-wise), then a 1x1
        # convolution across channels. The padding keeps the temporal length
        # unchanged for odd kernel sizes.
        self.dwconv = nn.Conv1d(
            bottleneck,
            bottleneck,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size // 2) * dilation,
            dilation=dilation,
            groups=bottleneck,
        )
        self.conv = nn.Conv1d(bottleneck, bottleneck, 1)
        nn.init.normal_(self.dwconv.weight, mean=0.0, std=math.sqrt(2.0 / kernel_size))
        nn.init.zeros_(self.dwconv.bias)
        nn.init.normal_(self.conv.weight, mean=0.0, std=math.sqrt(2.0 / bottleneck))
        nn.init.zeros_(self.conv.bias)

        # Adapter projections around the temporal mixing.
        self.down_proj = nn.Linear(embed_dims, bottleneck)
        self.act = nn.GELU()
        self.up_proj = nn.Linear(bottleneck, embed_dims)
        self.gamma = nn.Parameter(torch.ones(1))
        trunc_normal_init(self.down_proj, std=0.02, bias=0)
        # The last projection layer is initialised with constant 0.
        constant_init(self.up_proj, 0)

    def forward(self, x: Tensor, h: int, w: int) -> Tensor:
        """Apply the adapter to a batch of tokens.

        Args:
            x (Tensor): Tokens of shape ``(B, N, C)``. ``B * N`` must be
                divisible by ``temporal_size * h * w`` for the reshape below.
            h (int): Height of the token grid.
            w (int): Width of the token grid.

        Returns:
            Tensor: Adapter output with the same ``(B, N)`` leading shape as
            ``x`` and ``embed_dims`` channels.
        """

        def _inner_forward(tokens):
            hidden = self.act(self.down_proj(tokens))

            # e.g. (48, 8*10*10, 384) in the original author's notes
            batch, length, channels = hidden.shape
            # [b, t, h, w, c] -> [b*h*w, c, t] so Conv1d runs along time
            temporal = hidden.reshape(-1, self.temporal_size, h, w, channels)
            temporal = temporal.permute(0, 2, 3, 4, 1).flatten(0, 2)
            temporal = self.conv(self.dwconv(temporal))
            # [b*h*w, c, t] -> [b, t, h, w, c] -> original token layout
            temporal = temporal.unflatten(0, (-1, h, w)).permute(0, 4, 1, 2, 3)
            hidden = hidden + temporal.reshape(batch, length, channels)

            return self.up_proj(hidden) * self.gamma

        if self.with_cp:
            return cp.checkpoint(_inner_forward, x, use_reentrant=False)
        return _inner_forward(x)
class LadderNetwork(BaseModule):
    """Side network that sums adapter outputs onto the backbone output.

    Holds ``depth`` :class:`Adapter` modules (kernel size 3). In
    :meth:`forward`, adapter ``i`` processes ``feat_list[i]`` and its result
    is added to ``feat_list[-1]``.

    Args:
        depth (int): Number of adapters.
        embed_dims (int): Channel dimension of the features.
        temporal_size (int): Temporal length passed to every adapter.
            Defaults to 384.
        with_cp (bool): Whether the adapters use checkpointing.
            Defaults to False.
        mlp_ratio (float): Bottleneck ratio of every adapter.
            Defaults to 0.25.
    """

    def __init__(
        self,
        depth,
        embed_dims,
        temporal_size: int = 384,
        with_cp: bool = False,
        mlp_ratio: float = 0.25,
    ):
        super().__init__()
        self.depth = depth
        self.adapters = ModuleList(
            [
                Adapter(
                    embed_dims=embed_dims,
                    mlp_ratio=mlp_ratio,
                    kernel_size=3,
                    temporal_size=temporal_size,
                    with_cp=with_cp,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, feat_list: List[Tensor], h, w) -> Tensor:
        """Fuse the collected features.

        Args:
            feat_list (List[Tensor]): One feature per adapter; the last entry
                is also used as the base that the adapter outputs are added to.
            h: Height of the token grid, forwarded to the adapters.
            w: Width of the token grid, forwarded to the adapters.

        Returns:
            Tensor: ``feat_list[-1]`` plus the sum of all adapter outputs.
        """
        # The last feature map is the original output of the backbone.
        fused = feat_list[-1]
        for idx, adapter in enumerate(self.adapters):
            fused = fused + adapter(feat_list[idx], h, w)
        return fused
class Attention(BaseModule):
    """Multi-head Self-attention.

    Args:
        embed_dims (int): Dimensions of embedding.
        num_heads (int): Number of parallel attention heads.
        qkv_bias (bool): If True, add a learnable bias to q and v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        attn_drop_rate (float): Dropout ratio of attention weight.
            Defaults to 0.
        drop_rate (float): Dropout ratio of output. Defaults to 0.
        init_cfg (dict or ConfigDict, optional): The Config
            for initialization. Defaults to None.
    """

    def __init__(
        self,
        embed_dims: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        attn_drop_rate: float = 0.0,
        drop_rate: float = 0.0,
        init_cfg: OptConfigType = None,
        **kwargs,
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads

        self.scale = qk_scale or head_embed_dims**-0.5

        if qkv_bias:
            self._init_qv_bias()

        # The bias (if any) is kept in separate q/v parameters and assembled
        # in ``forward``, so the fused projection itself is bias-free.
        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=False)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(drop_rate)

    def _init_qv_bias(self) -> None:
        """Create the learnable q and v biases (k has no bias)."""
        self.q_bias = nn.Parameter(torch.zeros(self.embed_dims))
        self.v_bias = nn.Parameter(torch.zeros(self.embed_dims))

    def forward(self, x: Tensor) -> Tensor:
        """Defines the computation performed at every call.

        Args:
            x (Tensor): The input data with size of (B, N, C).

        Returns:
            Tensor: The output of the attention block, same size as inputs.
        """
        B, N, C = x.shape

        if hasattr(self, "q_bias"):
            # k gets a fixed zero bias; only q and v biases are learnable.
            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        else:
            qkv = self.qkv(x)

        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # The fused kernel scales q @ k^T by head_dim ** -0.5. When a custom
        # ``qk_scale`` was configured, fold the difference into q so that the
        # override is honoured instead of being silently ignored.
        default_scale = q.shape[-1] ** -0.5
        if self.scale != default_scale:
            q = q * (self.scale / default_scale)

        # The functional form applies dropout whenever dropout_p > 0, so it
        # has to be switched off explicitly outside of training mode.
        dropout_p = self.attn_drop.p if self.training else 0.0

        # Fused equivalent of softmax(scale * q @ k^T) -> dropout -> @ v.
        x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        x = x.transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(BaseModule):
    """Pre-norm transformer block: self-attention followed by an FFN.

    Both sub-layers are wrapped as ``x + drop_path(sublayer(norm(x)))``.

    Args:
        embed_dims (int): Dimensions of embedding.
        num_heads (int): Number of parallel attention heads.
        mlp_ratio (int): The ratio between the hidden layer and the
            input layer in the FFN. Defaults to 4.
        qkv_bias (bool): If True, add a learnable bias to q and v.
            Defaults to True.
        qk_scale (float): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        drop_rate (float): Dropout ratio of output. Defaults to 0.
        attn_drop_rate (float): Dropout ratio of attention weight.
            Defaults to 0.
        drop_path_rate (float): Dropout ratio of the residual branch.
            Defaults to 0.
        act_cfg (dict or ConfigDict): Config for activation layer in FFN.
            Defaults to `dict(type='GELU')`.
        norm_cfg (dict or ConfigDict): Config for norm layers.
            Defaults to `dict(type='LN', eps=1e-6)`.
        init_cfg (dict or ConfigDict, optional): The Config
            for initialization. Defaults to None.
    """

    def __init__(
        self,
        embed_dims: int,
        num_heads: int,
        mlp_ratio: int = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        act_cfg: ConfigType = dict(type="GELU"),
        norm_cfg: ConfigType = dict(type="LN", eps=1e-6),
        init_cfg: OptConfigType = None,
        **kwargs,
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        # Attention sub-layer.
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        attn_settings = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            drop_rate=drop_rate,
        )
        self.attn = Attention(embed_dims, **attn_settings)

        # Stochastic depth on the residual branches; a no-op when disabled.
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

        # Feed-forward sub-layer. The residual is added in ``forward``,
        # hence ``add_identity=False``.
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.mlp = FFN(
            embed_dims=embed_dims,
            feedforward_channels=int(embed_dims * mlp_ratio),
            act_cfg=act_cfg,
            ffn_drop=drop_rate,
            add_identity=False,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Defines the computation performed at every call.

        Args:
            x (Tensor): The input data with size of (B, N, C).

        Returns:
            Tensor: The output of the transformer block, same size as inputs.
        """
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        ffn_out = self.mlp(self.norm2(x))
        return x + self.drop_path(ffn_out)
@MODELS.register_module()
class VisionTransformerLadder(BaseModule):
    """Vision Transformer with support for patch or hybrid CNN input stage. An
    impl of `VideoMAE: Masked Autoencoders are Data-Efficient Learners for
    Self-Supervised Video Pre-Training <https://arxiv.org/pdf/2203.12602.pdf>`_

    The ViT (patch embedding and transformer blocks) is frozen and run
    without gradients; only the ladder network of adapters, fed with the
    outputs of the blocks listed in ``adapter_index``, receives gradients.

    Args:
        img_size (int): Size of input image. Defaults to 224.
        patch_size (int): Spatial size of one patch. Defaults to 16.
        in_channels (int): The number of channels of the input.
            Defaults to 3.
        embed_dims (int): Dimensions of embedding. Defaults to 768.
        depth (int): number of blocks in the transformer.
            Defaults to 12.
        num_heads (int): Number of parallel attention heads in
            TransformerCoder. Defaults to 12.
        mlp_ratio (float): The ratio between the hidden layer and the
            input layer in the FFN. Defaults to 4.
        qkv_bias (bool): If True, add a learnable bias to q and v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        drop_rate (float): Dropout ratio of output. Defaults to 0.
        attn_drop_rate (float): Dropout ratio of attention weight.
            Defaults to 0.
        drop_path_rate (float): Dropout ratio of the residual branch.
            Defaults to 0.
        norm_cfg (dict or Configdict): Config for norm layers.
            Defaults to `dict(type='LN', eps=1e-6)`.
        num_frames (int): Number of frames the positional embedding is
            built for (frames per attention window). Defaults to 16.
        tubelet_size (int): Temporal size of one patch. Defaults to 2.
        pretrained (str, optional): Checkpoint used for a ``Pretrained``
            init config. Default: None.
        with_cp (bool): Whether the adapters use checkpointing.
            Defaults to False.
        adapter_mlp_ratio (float): Bottleneck ratio of the adapters.
            Defaults to 0.25.
        total_frames (int): Total number of frames seen by the adapters;
            their temporal size is ``total_frames // tubelet_size``.
            Defaults to 768.
        adapter_index (list[int], optional): Indices of the transformer
            blocks whose outputs feed the ladder network. ``None`` selects
            the default ``range(6, 12)``.
        init_cfg (dict or list[dict]): Initialization config dict. Defaults to
            ``[
            dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
            dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
            ]``.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        embed_dims: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_cfg: ConfigType = dict(type="LN", eps=1e-6),
        num_frames: int = 16,  # frames per attention
        tubelet_size: int = 2,
        pretrained: Optional[str] = None,
        with_cp: bool = False,
        adapter_mlp_ratio: float = 0.25,
        total_frames: int = 768,
        adapter_index: Optional[List[int]] = list(range(6, 12)),
        init_cfg: Optional[Union[Dict, List[Dict]]] = [
            dict(type="TruncNormal", layer="Linear", std=0.02, bias=0.0),
            dict(type="Constant", layer="LayerNorm", val=1.0, bias=0.0),
        ],
        **kwargs,
    ) -> None:
        # mmengine's BaseModule.__init__ stores the ``init_cfg`` it receives,
        # so assigning ``self.init_cfg`` beforehand would be overwritten and
        # ``pretrained`` ignored. Hand the Pretrained config to the base class
        # instead; the blocks below still get the caller's ``init_cfg``.
        own_init_cfg = init_cfg
        if pretrained:
            own_init_cfg = dict(type="Pretrained", checkpoint=pretrained)
        super().__init__(init_cfg=own_init_cfg)

        self.embed_dims = embed_dims
        self.patch_size = patch_size

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type="Conv3d",
            kernel_size=(tubelet_size, patch_size, patch_size),
            stride=(tubelet_size, patch_size, patch_size),
            padding=(0, 0, 0),
            dilation=(1, 1, 1),
        )

        grid_size = img_size // patch_size
        num_patches = grid_size**2 * (num_frames // tubelet_size)
        self.grid_size = (grid_size, grid_size)

        # sine-cosine positional embeddings (fixed, hence a buffer)
        pos_embed = get_sinusoid_encoding(num_patches, embed_dims)
        self.register_buffer("pos_embed", pos_embed)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = ModuleList(
            [
                Block(
                    embed_dims=embed_dims,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=dpr[i],
                    norm_cfg=norm_cfg,
                    init_cfg=init_cfg,
                )
                for i in range(depth)
            ]
        )

        # build the ladder network; copy the indices so the shared default
        # list is never aliased by an instance
        if adapter_index is None:
            adapter_index = range(6, 12)
        self.adapter_index = list(adapter_index)
        self.ladders = LadderNetwork(
            depth=len(self.adapter_index),
            embed_dims=embed_dims,
            temporal_size=total_frames // tubelet_size,
            with_cp=with_cp,
            mlp_ratio=adapter_mlp_ratio,
        )

        # count the number of parameters in the backbone and in the adapters
        num_vit_param = 0
        num_adapter_param = 0
        for name, p in self.named_parameters():
            if "adapter" in name:
                num_adapter_param += p.numel()
            else:
                num_vit_param += p.numel()
        ratio = num_adapter_param / num_vit_param * 100
        print("ViT's param: {}, Adapter's params: {}, ratio: {:2.1f}%".format(num_vit_param, num_adapter_param, ratio))

    def forward(self, x: Tensor) -> Tensor:
        """Run the frozen ViT and fuse its features with the ladder network.

        Args:
            x (Tensor): 5-D input whose last two dimensions are the spatial
                height and width (presumably ``(B, C, T, H, W)`` — confirm
                against the caller).

        Returns:
            Tensor: Features of shape ``(B, embed_dims, T', h, w)`` with
            ``h = H // patch_size`` and ``w = W // patch_size``.
        """
        b, _, _, h, w = x.shape
        h //= self.patch_size
        w //= self.patch_size

        feat_list = self.forward_vit(x)

        # ladder network
        out = self.ladders(feat_list, h, w)
        out = out.reshape(b, -1, h, w, self.embed_dims)
        out = out.permute(0, 4, 1, 2, 3)
        return out

    @torch.no_grad()
    def forward_vit(self, x: Tensor) -> List[Tensor]:
        """Run the frozen ViT and collect the selected block outputs.

        Args:
            x (Tensor): Same 5-D input as :meth:`forward`.

        Returns:
            List[Tensor]: Detached outputs of the blocks whose index is in
            ``self.adapter_index``, in block order.
        """
        # Re-applied on every call so the backbone stays frozen even after
        # an outer ``model.train()``.
        self._freeze_layers()

        _, _, _, h, w = x.shape
        h //= self.patch_size
        w //= self.patch_size
        x = self.patch_embed(x)[0]

        if (h, w) != self.grid_size:
            # Resize the positional embedding spatially to the actual grid.
            pos_embed = self.pos_embed.reshape(-1, *self.grid_size, self.embed_dims)
            pos_embed = pos_embed.permute(0, 3, 1, 2)
            pos_embed = F.interpolate(pos_embed, size=(h, w), mode="bicubic", align_corners=False)
            pos_embed = pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
            pos_embed = pos_embed.reshape(1, -1, self.embed_dims)
        else:
            pos_embed = self.pos_embed

        x = x + pos_embed
        x = self.pos_drop(x)

        feat_list = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in self.adapter_index:
                feat_list.append(x.detach())
        return feat_list

    def _freeze_layers(self):
        """Put the original ViT in eval mode and stop its parameter updates."""
        # freeze patch_embed
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False

        # freeze every sub-module of every block; the adapters live in
        # ``self.ladders`` and are not touched here
        for block in self.blocks:
            for _, child in block.named_children():
                child.eval()
                for param in child.parameters():
                    param.requires_grad = False