diff --git a/docs/source/models/video_mvit.rst b/docs/source/models/video_mvit.rst
index d5be1245ac9..cd23754b7bb 100644
--- a/docs/source/models/video_mvit.rst
+++ b/docs/source/models/video_mvit.rst
@@ -12,7 +12,7 @@ The MViT model is based on the
 
 Model builders
 --------------
 
-The following model builders can be used to instantiate a MViT model, with or
+The following model builders can be used to instantiate a MViT v1 or v2 model, with or
 without pre-trained weights. All the model builders internally rely on the
 ``torchvision.models.video.MViT`` base class. Please refer to the `source code
@@ -24,3 +24,4 @@ more details about this class.
     :template: function.rst
 
     mvit_v1_b
+    mvit_v2_s
diff --git a/test/expect/ModelTester.test_mvit_v2_s_expect.pkl b/test/expect/ModelTester.test_mvit_v2_s_expect.pkl
new file mode 100644
index 00000000000..5ae3e4a0d76
Binary files /dev/null and b/test/expect/ModelTester.test_mvit_v2_s_expect.pkl differ
diff --git a/test/test_models.py b/test/test_models.py
index bc83874ee4f..5061888d71d 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -309,6 +309,9 @@ def _check_input_backprop(model, inputs):
     "mvit_v1_b": {
         "input_shape": (1, 3, 16, 224, 224),
     },
+    "mvit_v2_s": {
+        "input_shape": (1, 3, 16, 224, 224),
+    },
 }
 # speeding up slow models:
 slow_models = [
diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py
index cfa82a4b851..7283a21bb0d 100644
--- a/torchvision/models/video/mvit.py
+++ b/torchvision/models/video/mvit.py
@@ -19,12 +19,11 @@
     "MViT",
     "MViT_V1_B_Weights",
     "mvit_v1_b",
+    "MViT_V2_S_Weights",
+    "mvit_v2_s",
 ]
 
 
-# TODO: Consider handle 2d input if Temporal is 1
-
-
 @dataclass
 class MSBlockConfig:
     num_heads: int
@@ -106,28 +105,121 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten
         return x, (T, H, W)
 
 
+def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor:
+    if embedding.shape[0] == d:
+        return embedding
+
+    return (
+        nn.functional.interpolate(
+            embedding.permute(1, 0).unsqueeze(0),
+            size=d,
+            mode="linear",
+        )
+        .squeeze(0)
+        .permute(1, 0)
+    )
+
+
+def _add_rel_pos(
+    attn: torch.Tensor,
+    q: torch.Tensor,
+    q_thw: Tuple[int, int, int],
+    k_thw: Tuple[int, int, int],
+    rel_pos_h: torch.Tensor,
+    rel_pos_w: torch.Tensor,
+    rel_pos_t: torch.Tensor,
+) -> torch.Tensor:
+    # Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932
+    q_t, q_h, q_w = q_thw
+    k_t, k_h, k_w = k_thw
+    dh = int(2 * max(q_h, k_h) - 1)
+    dw = int(2 * max(q_w, k_w) - 1)
+    dt = int(2 * max(q_t, k_t) - 1)
+
+    # Scale up rel pos if shapes for q and k are different.
+    q_h_ratio = max(k_h / q_h, 1.0)
+    k_h_ratio = max(q_h / k_h, 1.0)
+    dist_h = torch.arange(q_h)[:, None] * q_h_ratio - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio
+    q_w_ratio = max(k_w / q_w, 1.0)
+    k_w_ratio = max(q_w / k_w, 1.0)
+    dist_w = torch.arange(q_w)[:, None] * q_w_ratio - (torch.arange(k_w)[None, :] + (1.0 - k_w)) * k_w_ratio
+    q_t_ratio = max(k_t / q_t, 1.0)
+    k_t_ratio = max(q_t / k_t, 1.0)
+    dist_t = torch.arange(q_t)[:, None] * q_t_ratio - (torch.arange(k_t)[None, :] + (1.0 - k_t)) * k_t_ratio
+
+    # Interpolate rel pos if needed.
+    rel_pos_h = _interpolate(rel_pos_h, dh)
+    rel_pos_w = _interpolate(rel_pos_w, dw)
+    rel_pos_t = _interpolate(rel_pos_t, dt)
+    Rh = rel_pos_h[dist_h.long()]
+    Rw = rel_pos_w[dist_w.long()]
+    Rt = rel_pos_t[dist_t.long()]
+
+    B, n_head, _, dim = q.shape
+
+    r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim)
+    rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh)  # [B, H, q_t, qh, qw, k_h]
+    rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw)  # [B, H, q_t, qh, qw, k_w]
+    # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
+    r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim)
+    # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
+    rel_q_t = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
+    # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
+    rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)
+
+    # Combine rel pos.
+    rel_pos = (
+        rel_h_q[:, :, :, :, :, None, :, None]
+        + rel_w_q[:, :, :, :, :, None, None, :]
+        + rel_q_t[:, :, :, :, :, :, None, None]
+    ).reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w)
+
+    # Add it to attention
+    attn[:, :, 1:, 1:] += rel_pos
+
+    return attn
+
+
+def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_with_cls_embed: bool):
+    if residual_with_cls_embed:
+        x.add_(shortcut)
+    else:
+        x[:, :, 1:, :] += shortcut[:, :, 1:, :]
+    return x
+
+
+torch.fx.wrap("_add_rel_pos")
+torch.fx.wrap("_add_shortcut")
+
+
 class MultiscaleAttention(nn.Module):
     def __init__(
         self,
+        input_size: List[int],
         embed_dim: int,
+        output_dim: int,
         num_heads: int,
         kernel_q: List[int],
         kernel_kv: List[int],
         stride_q: List[int],
         stride_kv: List[int],
         residual_pool: bool,
+        residual_with_cls_embed: bool,
+        rel_pos_embed: bool,
         dropout: float = 0.0,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
+        self.output_dim = output_dim
         self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = output_dim // num_heads
         self.scaler = 1.0 / math.sqrt(self.head_dim)
         self.residual_pool = residual_pool
+        self.residual_with_cls_embed = residual_with_cls_embed
 
-        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
-        layers: List[nn.Module] = [nn.Linear(embed_dim, embed_dim)]
+        self.qkv = nn.Linear(embed_dim, 3 * output_dim)
+        layers: List[nn.Module] = [nn.Linear(output_dim, output_dim)]
         if dropout > 0.0:
             layers.append(nn.Dropout(dropout, inplace=True))
         self.project = nn.Sequential(*layers)
@@ -177,24 +269,52 @@ def __init__(
                 norm_layer(self.head_dim),
             )
 
+        self.rel_pos_h: Optional[nn.Parameter] = None
+        self.rel_pos_w: Optional[nn.Parameter] = None
+        self.rel_pos_t: Optional[nn.Parameter] = None
+        if rel_pos_embed:
+            size = max(input_size[1:])
+            q_size = size // stride_q[1] if len(stride_q) > 0 else size
+            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
+            spatial_dim = 2 * max(q_size, kv_size) - 1
+            temporal_dim = 2 * input_size[0] - 1
+            self.rel_pos_h = nn.Parameter(torch.zeros(spatial_dim, self.head_dim))
+            self.rel_pos_w = nn.Parameter(torch.zeros(spatial_dim, self.head_dim))
+            self.rel_pos_t = nn.Parameter(torch.zeros(temporal_dim, self.head_dim))
+            nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
+            nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
+            nn.init.trunc_normal_(self.rel_pos_t, std=0.02)
+
     def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
         B, N, C = x.shape
         q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(dim=2)
 
         if self.pool_k is not None:
-            k = self.pool_k(k, thw)[0]
+            k, k_thw = self.pool_k(k, thw)
+        else:
+            k_thw = thw
         if self.pool_v is not None:
             v = self.pool_v(v, thw)[0]
         if self.pool_q is not None:
             q, thw = self.pool_q(q, thw)
 
         attn = torch.matmul(self.scaler * q, k.transpose(2, 3))
+        if self.rel_pos_h is not None and self.rel_pos_w is not None and self.rel_pos_t is not None:
+            attn = _add_rel_pos(
+                attn,
+                q,
+                thw,
+                k_thw,
+                self.rel_pos_h,
+                self.rel_pos_w,
+                self.rel_pos_t,
+            )
         attn = attn.softmax(dim=-1)
 
         x = torch.matmul(attn, v)
         if self.residual_pool:
-            x.add_(q)
-        x = x.transpose(1, 2).reshape(B, -1, C)
+            _add_shortcut(x, q, self.residual_with_cls_embed)
+        x = x.transpose(1, 2).reshape(B, -1, self.output_dim)
         x = self.project(x)
 
         return x, thw
@@ -203,13 +323,18 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten
 class MultiscaleBlock(nn.Module):
     def __init__(
         self,
+        input_size: List[int],
         cnf: MSBlockConfig,
         residual_pool: bool,
+        residual_with_cls_embed: bool,
+        rel_pos_embed: bool,
+        proj_after_attn: bool,
         dropout: float = 0.0,
         stochastic_depth_prob: float = 0.0,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
     ) -> None:
         super().__init__()
+        self.proj_after_attn = proj_after_attn
 
         self.pool_skip: Optional[nn.Module] = None
         if _prod(cnf.stride_q) > 1:
@@ -219,24 +344,30 @@ def __init__(
                 nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip), None  # type: ignore[arg-type]
             )
 
+        attn_dim = cnf.output_channels if proj_after_attn else cnf.input_channels
+
         self.norm1 = norm_layer(cnf.input_channels)
-        self.norm2 = norm_layer(cnf.input_channels)
+        self.norm2 = norm_layer(attn_dim)
         self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d)
 
         self.attn = MultiscaleAttention(
+            input_size,
             cnf.input_channels,
+            attn_dim,
             cnf.num_heads,
             kernel_q=cnf.kernel_q,
             kernel_kv=cnf.kernel_kv,
             stride_q=cnf.stride_q,
             stride_kv=cnf.stride_kv,
+            rel_pos_embed=rel_pos_embed,
             residual_pool=residual_pool,
+            residual_with_cls_embed=residual_with_cls_embed,
             dropout=dropout,
             norm_layer=norm_layer,
         )
         self.mlp = MLP(
-            cnf.input_channels,
-            [4 * cnf.input_channels, cnf.output_channels],
+            attn_dim,
+            [4 * attn_dim, cnf.output_channels],
             activation_layer=nn.GELU,
             dropout=dropout,
             inplace=None,
@@ -249,36 +380,45 @@ def __init__(
             self.project = nn.Linear(cnf.input_channels, cnf.output_channels)
 
     def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
+        x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x)
+        x_attn, thw_new = self.attn(x_norm1, thw)
+        x = x if self.project is None or not self.proj_after_attn else self.project(x_norm1)
         x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0]
+        x = x_skip + self.stochastic_depth(x_attn)
 
-        x = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x)
-        x, thw = self.attn(x, thw)
-        x = x_skip + self.stochastic_depth(x)
+        x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x)
+        x_proj = x if self.project is None or self.proj_after_attn else self.project(x_norm2)
 
-        x_norm = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x)
-        x_proj = x if self.project is None else self.project(x_norm)
-
-        return x_proj + self.stochastic_depth(self.mlp(x_norm)), thw
+        return x_proj + self.stochastic_depth(self.mlp(x_norm2)), thw_new
 
 
 class PositionalEncoding(nn.Module):
-    def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int) -> None:
+    def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int, rel_pos_embed: bool) -> None:
         super().__init__()
         self.spatial_size = spatial_size
         self.temporal_size = temporal_size
 
         self.class_token = nn.Parameter(torch.zeros(embed_size))
-        self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size))
-        self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size))
-        self.class_pos = nn.Parameter(torch.zeros(embed_size))
+        self.spatial_pos: Optional[nn.Parameter] = None
+        self.temporal_pos: Optional[nn.Parameter] = None
+        self.class_pos: Optional[nn.Parameter] = None
+        if not rel_pos_embed:
+            self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size))
+            self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size))
+            self.class_pos = nn.Parameter(torch.zeros(embed_size))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        hw_size, embed_size = self.spatial_pos.shape
-        pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0)
-        pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size))
-        pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0)
         class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1)
-        return torch.cat((class_token, x), dim=1).add_(pos_embedding)
+        x = torch.cat((class_token, x), dim=1)
+
+        if self.spatial_pos is not None and self.temporal_pos is not None and self.class_pos is not None:
+            hw_size, embed_size = self.spatial_pos.shape
+            pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0)
+            pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size))
+            pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0)
+            x.add_(pos_embedding)
+
+        return x
 
 
 class MViT(nn.Module):
@@ -288,12 +428,18 @@ def __init__(
         temporal_size: int,
         block_setting: Sequence[MSBlockConfig],
         residual_pool: bool,
+        residual_with_cls_embed: bool,
+        rel_pos_embed: bool,
+        proj_after_attn: bool,
         dropout: float = 0.5,
         attention_dropout: float = 0.0,
         stochastic_depth_prob: float = 0.0,
         num_classes: int = 400,
         block: Optional[Callable[..., nn.Module]] = None,
         norm_layer: Optional[Callable[..., nn.Module]] = None,
+        patch_embed_kernel: Tuple[int, int, int] = (3, 7, 7),
+        patch_embed_stride: Tuple[int, int, int] = (2, 4, 4),
+        patch_embed_padding: Tuple[int, int, int] = (1, 3, 3),
     ) -> None:
         """
        MViT main class.
@@ -303,12 +449,19 @@ def __init__(
             temporal_size (int): The temporal size ``T`` of the input.
             block_setting (sequence of MSBlockConfig): The Network structure.
             residual_pool (bool): If True, use MViTv2 pooling residual connection.
+            residual_with_cls_embed (bool): If True, the addition on the residual connection will include
+                the class embedding.
+            rel_pos_embed (bool): If True, use MViTv2's relative positional embeddings.
+            proj_after_attn (bool): If True, apply the projection after the attention.
             dropout (float): Dropout rate. Default: 0.0.
             attention_dropout (float): Attention dropout rate. Default: 0.0.
             stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
             num_classes (int): The number of classes.
             block (callable, optional): Module specifying the layer which consists of the attention and mlp.
             norm_layer (callable, optional): Module specifying the normalization layer to use.
+            patch_embed_kernel (tuple of ints): The kernel of the convolution that patchifies the input.
+            patch_embed_stride (tuple of ints): The stride of the convolution that patchifies the input.
+            patch_embed_padding (tuple of ints): The padding of the convolution that patchifies the input.
         """
         super().__init__()
         # This implementation employs a different parameterization scheme than the one used at PyTorch Video:
@@ -330,16 +483,19 @@ def __init__(
         self.conv_proj = nn.Conv3d(
             in_channels=3,
             out_channels=block_setting[0].input_channels,
-            kernel_size=(3, 7, 7),
-            stride=(2, 4, 4),
-            padding=(1, 3, 3),
+            kernel_size=patch_embed_kernel,
+            stride=patch_embed_stride,
+            padding=patch_embed_padding,
         )
 
+        input_size = [size // stride for size, stride in zip((temporal_size,) + spatial_size, self.conv_proj.stride)]
+
         # Spatio-Temporal Class Positional Encoding
         self.pos_encoding = PositionalEncoding(
             embed_size=block_setting[0].input_channels,
-            spatial_size=(spatial_size[0] // self.conv_proj.stride[1], spatial_size[1] // self.conv_proj.stride[2]),
-            temporal_size=temporal_size // self.conv_proj.stride[0],
+            spatial_size=(input_size[1], input_size[2]),
+            temporal_size=input_size[0],
+            rel_pos_embed=rel_pos_embed,
         )
 
         # Encoder module
@@ -350,13 +506,20 @@ def __init__(
 
             self.blocks.append(
                 block(
+                    input_size=input_size,
                     cnf=cnf,
                     residual_pool=residual_pool,
+                    residual_with_cls_embed=residual_with_cls_embed,
+                    rel_pos_embed=rel_pos_embed,
+                    proj_after_attn=proj_after_attn,
                     dropout=attention_dropout,
                     stochastic_depth_prob=sd_prob,
                     norm_layer=norm_layer,
                 )
             )
+
+            if len(cnf.stride_q) > 0:
+                input_size = [size // stride for size, stride in zip(input_size, cnf.stride_q)]
         self.norm = norm_layer(block_setting[-1].output_channels)
 
         # Classifier module
@@ -380,6 +543,8 @@ def __init__(
                 nn.init.trunc_normal_(weights, std=0.02)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Convert if necessary (B, C, H, W) -> (B, C, 1, H, W)
+        x = _unsqueeze(x, 5, 2)[0]
         # patchify and reshape: (B, C, T, H, W) -> (B, embed_channels[0], T', H', W') -> (B, THW', embed_channels[0])
         x = self.conv_proj(x)
         x = x.flatten(2).transpose(1, 2)
@@ -420,6 +585,9 @@ def _mvit(
         temporal_size=temporal_size,
         block_setting=block_setting,
         residual_pool=kwargs.pop("residual_pool", False),
+        residual_with_cls_embed=kwargs.pop("residual_with_cls_embed", True),
+        rel_pos_embed=kwargs.pop("rel_pos_embed", False),
+        proj_after_attn=kwargs.pop("proj_after_attn", False),
         stochastic_depth_prob=stochastic_depth_prob,
         **kwargs,
     )
@@ -461,6 +629,37 @@ class MViT_V1_B_Weights(WeightsEnum):
     DEFAULT = KINETICS400_V1
 
 
+class MViT_V2_S_Weights(WeightsEnum):
+    KINETICS400_V1 = Weights(
+        url="https://download.pytorch.org/models/mvit_v2_s-ae3be167.pth",
+        transforms=partial(
+            VideoClassification,
+            crop_size=(224, 224),
+            resize_size=(256,),
+            mean=(0.45, 0.45, 0.45),
+            std=(0.225, 0.225, 0.225),
+        ),
+        meta={
+            "min_size": (224, 224),
+            "min_temporal_size": 16,
+            "categories": _KINETICS400_CATEGORIES,
+            "recipe": "https://github.com/facebookresearch/SlowFast/blob/main/MODEL_ZOO.md",
+            "_docs": (
+                "The weights were ported from the paper. The accuracies are estimated on video-level "
+                "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
+            ),
+            "num_params": 34537744,
+            "_metrics": {
+                "Kinetics-400": {
+                    "acc@1": 80.757,
+                    "acc@5": 94.665,
+                }
+            },
+        },
+    )
+    DEFAULT = KINETICS400_V1
+
+
 @register_model()
 def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
     """
@@ -548,6 +747,138 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = T
         temporal_size=16,
         block_setting=block_setting,
         residual_pool=False,
+        residual_with_cls_embed=False,
+        stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2),
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )
+
+
+@register_model()
+def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
+    """
+    Constructs a small MViTV2 architecture from
+    `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__.
+
+    Args:
+        weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.video.MViT_V2_S_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.video.MViT``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.video.MViT_V2_S_Weights
+        :members:
+    """
+    weights = MViT_V2_S_Weights.verify(weights)
+
+    config: Dict[str, List] = {
+        "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8],
+        "input_channels": [96, 96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768],
+        "output_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768],
+        "kernel_q": [
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+        ],
+        "kernel_kv": [
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+        ],
+        "stride_q": [
+            [1, 1, 1],
+            [1, 2, 2],
+            [1, 1, 1],
+            [1, 2, 2],
+            [1, 1, 1],
+            [1, 1, 1],
+            [1, 1, 1],
+            [1, 1, 1],
+            [1, 1, 1],
+            [1, 1, 1],
+            [1, 1, 1],
+            [1, 1, 1],
+            [1, 1, 1],
+            [1, 1, 1],
+            [1, 2, 2],
+            [1, 1, 1],
+        ],
+        "stride_kv": [
+            [1, 8, 8],
+            [1, 4, 4],
+            [1, 4, 4],
+            [1, 2, 2],
+            [1, 2, 2],
+            [1, 2, 2],
+            [1, 2, 2],
+            [1, 2, 2],
+            [1, 2, 2],
+            [1, 2, 2],
+            [1, 2, 2],
+            [1, 2, 2],
+            [1, 2, 2],
+            [1, 2, 2],
+            [1, 1, 1],
+            [1, 1, 1],
+        ],
+    }
+
+    block_setting = []
+    for i in range(len(config["num_heads"])):
+        block_setting.append(
+            MSBlockConfig(
+                num_heads=config["num_heads"][i],
+                input_channels=config["input_channels"][i],
+                output_channels=config["output_channels"][i],
+                kernel_q=config["kernel_q"][i],
+                kernel_kv=config["kernel_kv"][i],
+                stride_q=config["stride_q"][i],
+                stride_kv=config["stride_kv"][i],
+            )
+        )
+
+    return _mvit(
+        spatial_size=(224, 224),
+        temporal_size=16,
+        block_setting=block_setting,
+        residual_pool=True,
+        residual_with_cls_embed=False,
+        rel_pos_embed=True,
+        proj_after_attn=True,
         stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2),
         weights=weights,
         progress=progress,
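A minimal usage sketch of the builder this patch adds, assuming the patch is applied and torchvision exposes mvit_v2_s and MViT_V2_S_Weights under torchvision.models.video; the dummy input shape mirrors the one registered above in test/test_models.py:

    import torch
    from torchvision.models.video import MViT_V2_S_Weights, mvit_v2_s

    # Build the architecture; pass weights=MViT_V2_S_Weights.KINETICS400_V1 to load the ported checkpoint instead.
    model = mvit_v2_s(weights=None)
    model.eval()

    # Dummy clip shaped (batch, channels, frames, height, width), as in the test config.
    clip = torch.rand(1, 3, 16, 224, 224)

    with torch.inference_mode():
        logits = model(clip)

    print(logits.shape)  # torch.Size([1, 400]): Kinetics-400 classes by default

When pretrained weights are used, MViT_V2_S_Weights.KINETICS400_V1.transforms() returns the matching VideoClassification preprocessing (resize to 256, center crop 224, and the mean/std listed in the meta above), which should be applied to the clip before the forward pass.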