In [3]:
import inspect
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Conv2d

In [2]:
class Conv2Plus1D(nn.Sequential):
    """Factorised (2+1)D convolution: spatial 1x3x3 conv -> BatchNorm -> ReLU -> temporal 3x1x1 conv.

    Args:
        in_planes: input channels.
        out_planes: output channels of the temporal conv.
        midplanes: channels between the spatial and the temporal conv.
        stride: applied to H and W by the spatial conv and to D (time) by the temporal conv.
        padding: zero padding on the axes each conv convolves over.

    No BatchNorm/ReLU follows the temporal conv; the output is the raw conv result.
    """

    def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: int = 1) -> None:
        spatial = nn.Conv3d(
            in_planes, midplanes,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False,
        )
        temporal = nn.Conv3d(
            midplanes, out_planes,
            kernel_size=(3, 1, 1),
            stride=(stride, 1, 1),
            padding=(padding, 0, 0),
            bias=False,
        )
        super().__init__(spatial, nn.BatchNorm3d(midplanes), nn.ReLU(inplace=True), temporal)

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        """Stride for a shortcut projection: this conv strides time as well as H and W."""
        return (stride,) * 3

In [32]:
class R2Plus1dStem(nn.Sequential):
    """R(2+1)D stem; unlike the default stem it uses a separated (spatial, then temporal) 3D convolution.

    Maps 3 input channels to 64 through a 45-channel intermediate. Only the spatial conv
    strides (by 2 in H and W); the temporal extent is preserved.
    """

    def __init__(self) -> None:
        hidden, width = 45, 64
        super().__init__(
            # spatial 1x7x7 conv, stride 2 over H and W
            nn.Conv3d(3, hidden, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False),
            nn.BatchNorm3d(hidden),
            nn.ReLU(inplace=True),
            # temporal 3x1x1 conv, no striding
            nn.Conv3d(hidden, width, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False),
            nn.BatchNorm3d(width),
            nn.ReLU(inplace=True),
        )

In [33]:
class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1x1 reduce -> ``conv_builder`` conv -> 1x1x1 expand.

    The block emits ``out_planes * expansion`` channels. ``stride`` is handed to the
    middle conv only, so whenever the shortcut needs a different shape the caller must
    pass a matching ``downsample`` module.

    Args:
        inplanes: channels of the block input.
        out_planes: bottleneck width (channels of the first two convs).
        conv_builder: callable ``(in, out, midplanes, stride)`` returning the middle conv.
        stride: stride forwarded to ``conv_builder``.
        downsample: optional module applied to the input to form the shortcut.
    """

    expansion = 4

    def __init__(
        self,
        inplanes: int,
        out_planes: int,
        conv_builder: Callable[..., nn.Module],
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()

        # Hidden width for a (2+1)D middle conv: the R(2+1)D parameter-matching formula
        # with temporal extent t = 3 and spatial extent d = 3.
        # NOTE(review): computed from `inplanes` although the middle conv maps
        # out_planes -> out_planes — confirm this is intended.
        t = d = 3
        midplanes = (inplanes * out_planes * t * d * d) // (inplanes * d * d + t * out_planes)
        wide = out_planes * self.expansion

        self.conv1 = nn.Sequential(
            nn.Conv3d(inplanes, out_planes, kernel_size=1, bias=False),
            nn.BatchNorm3d(out_planes),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            conv_builder(out_planes, out_planes, midplanes, stride),
            nn.BatchNorm3d(out_planes),
            nn.ReLU(inplace=True),
        )
        self.conv3 = nn.Sequential(
            nn.Conv3d(out_planes, wide, kernel_size=1, bias=False),
            nn.BatchNorm3d(wide),
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Return ReLU(branch(x) + shortcut(x)), where shortcut is ``downsample`` or the identity."""
        branch = self.conv3(self.conv2(self.conv1(x)))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)

In [34]:
class BasicBlock(nn.Module):
    """Two-conv residual block whose convs both come from ``conv_builder``.

    Output channels equal ``planes`` (``expansion == 1``). Only the first conv receives
    ``stride``; a shape-changing shortcut must be supplied via ``downsample``.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by both convs.
        conv_builder: callable ``(in, out, midplanes[, stride])`` returning a conv module.
        stride: stride forwarded to the first conv.
        downsample: optional module applied to the input to form the shortcut.
    """

    expansion = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        conv_builder: Callable[..., nn.Module],
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()

        # Hidden width for a (2+1)D conv: R(2+1)D parameter-matching formula, t = d = 3.
        # The same value is reused for the second conv.
        t = d = 3
        midplanes = (inplanes * planes * t * d * d) // (inplanes * d * d + t * planes)

        first = conv_builder(inplanes, planes, midplanes, stride)
        self.conv1 = nn.Sequential(first, nn.BatchNorm3d(planes), nn.ReLU(inplace=True))
        second = conv_builder(planes, planes, midplanes)
        self.conv2 = nn.Sequential(second, nn.BatchNorm3d(planes))
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Return ReLU(conv2(conv1(x)) + shortcut(x)), where shortcut is ``downsample`` or the identity."""
        branch = self.conv2(self.conv1(x))
        shortcut = self.downsample(x) if self.downsample is not None else x
        branch += shortcut
        return self.relu(branch)

In [35]:
class VideoResNet(nn.Module):
    """Video ResNet: a stem, four residual stages, 3-D global average pooling and a linear head.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); its ``expansion``
            attribute multiplies the channel width of every stage.
        conv_makers: conv-builder class used by each stage; entries 0-3 are read.
        nums_of_layers: number of blocks in each of the four stages; entries 0-3 are read.
        stem: zero-argument callable returning the stem module. The first stage is built
            for 64 input channels, so the stem must produce 64 channels.
        num_classes: output features of the final ``nn.Linear``.
        zero_init_residual: if True, zero the scale (``weight``) of the last BatchNorm in
            every residual branch, so each branch initially outputs zeros and each block
            reduces to ReLU(shortcut).
    """

    def __init__(
            self,
            block: Type[Union[BasicBlock, Bottleneck]],
            conv_makers: Sequence[Type[Union[Conv2Plus1D]]],
            nums_of_layers: List[int],
            stem: Callable[..., nn.Module],
            num_classes: int = 2,
            zero_init_residual: bool = False
    ) -> None:
        super().__init__()
        # Channels entering the next stage; _make_layer advances it after each stage.
        self.inplanes = 64

        self.stem = stem()

        # Stage widths 64/128/256/512 (times block.expansion); stages 2-4 use stride 2.
        self.layer1 = self._make_layer(block, conv_makers[0], 64, nums_of_layers[0], stride=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, nums_of_layers[1], stride=2)
        self.layer3 = self._make_layer(block, conv_makers[2], 256, nums_of_layers[2], stride=2)
        self.layer4 = self._make_layer(block, conv_makers[3], 512, nums_of_layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Weight initialisation: Kaiming (fan_out) convs, weight=1/bias=0 BatchNorm,
        # small-normal linear head.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            # The blocks keep each conv + BatchNorm pair in an nn.Sequential and define no
            # ``bn3`` attribute; the last BatchNorm of the residual branch is index 1 of the
            # final Sequential (conv3 for Bottleneck, conv2 for BasicBlock).
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.conv3[1].weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.conv2[1].weight, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Run a batch of clips through the network.

        Args:
            x: input batch in the (N, C, D, H, W) layout ``nn.Conv3d`` expects; C must
                match what ``stem`` accepts.

        Returns:
            Tensor of shape (N, num_classes) with unnormalised class scores.
        """
        x = self.stem(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # (N, C, 1, 1, 1) -> (N, C) for the linear head
        x = x.flatten(1)
        x = self.fc(x)

        return x

    def _make_layer(
            self,
            block: Type[Union[BasicBlock, Bottleneck]],
            conv_builder: Type[Union[Conv2Plus1D]],
            out_planes: int,
            num_of_blocks: int,
            stride: int = 1,
            ) -> nn.Sequential:
        """Build one stage of ``num_of_blocks`` blocks and advance ``self.inplanes``.

        Only the first block receives ``stride`` and, when the resolution or channel count
        changes, a 1x1x1 conv + BatchNorm shortcut projection.
        """
        downsample = None

        if stride != 1 or self.inplanes != out_planes * block.expansion:
            # The shortcut must match the main branch, which conv_builder may also stride in time.
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, out_planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False),
                nn.BatchNorm3d(out_planes * block.expansion)
            )
        layers = [block(self.inplanes, out_planes, conv_builder, stride, downsample)]

        self.inplanes = out_planes * block.expansion
        for _ in range(1, num_of_blocks):
            layers.append(block(self.inplanes, out_planes, conv_builder))

        return nn.Sequential(*layers)
    


In [36]:
def make_video_resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    conv_makers: Sequence[Type[Union[Conv2Plus1D]]],
    layers: List[int],
    stem: Callable[..., nn.Module],
    **kwargs: Any,
) -> VideoResNet:
    """Build a :class:`VideoResNet`.

    Args:
        block: residual block class for every stage.
        conv_makers: conv-builder class per stage.
        layers: number of blocks per stage.
        stem: zero-argument callable returning the stem module.
        **kwargs: forwarded to ``VideoResNet`` (e.g. ``num_classes``, ``zero_init_residual``);
            previously these could not be set through the factory.

    Returns:
        The constructed, freshly initialised model.
    """
    return VideoResNet(block, conv_makers, layers, stem, **kwargs)

In [37]:
# Two BasicBlocks per stage ([2, 2, 2, 2]) with (2+1)D convs and the R(2+1)D stem,
# the same configuration as torchvision's r2plus1d_18.
r2plus1d_18 = make_video_resnet(
    block=BasicBlock,
    conv_makers=[Conv2Plus1D] * 4,
    layers=[2, 2, 2, 2],
    stem=R2Plus1dStem,
)
r2plus1d_18

VideoResNet(
  (stem): R2Plus1dStem(
    (0): Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
    (1): BatchNorm3d(45, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
    (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Sequential(
        (0): Conv2Plus1D(
          (0): Conv3d(64, 144, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
          (1): BatchNorm3d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv3d(144, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
        )
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru

## Deformable Conv Layer

In [None]:
class DeformableConvLayer(Conv2d):
    """Work-in-progress deformable 2-D convolution, ported from a Keras-style layer.

    NOTE(review): this class does not work as written.
    - ``torch.nn.Conv2d.__init__`` takes ``in_channels, out_channels, kernel_size, stride,
      padding, dilation, groups, bias, ...``; the Keras-style keywords passed to
      ``super().__init__`` (``filters``, ``strides``, ``data_format``, ``activation``, ...)
      are not accepted, so constructing this class raises ``TypeError``.
    - ``self.filters``, ``self.strides``, ``self.dilation_rate`` and ``self.use_bias`` are
      Keras attribute names; nothing in this class sets them.
    - ``build`` is a Keras hook that PyTorch never calls, and ``forward`` is unfinished.
    """

    def __init__(
            self,
            filters,
            kernel_size,
            strides=(1, 1),
            padding='same',
            data_format=None,
            dilation_rate=(1,1),
            num_deformable_group = None,
            activation = None,
            use_bias = None,
            kernel_initializer='glorot_uniform',
            bias_initializer='zeros',
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            **kwargs
    ):
        """Initialise the base convolution and store the deformable-group setting.

        `kernel_size`, `strides` and `dilation_rate` must have the same value in both axes.

        :param num_deformable_group: split output channels into groups; offsets are shared
            within each group. If this parameter is None, num_deformable_group=filters.
        :raises ValueError: if ``filters`` is not divisible by ``num_deformable_group``.
        """
        
        # NOTE(review): Keras Conv2D keyword names — torch.nn.Conv2d rejects these (TypeError).
        super().__init__(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs)
        self.kernel = None
        self.bias = None
        self.offset_layer_kernel = None
        # NOTE(review): spelled `offsetlayer_bias` here but `offset_layer_bias` in build/forward.
        self.offsetlayer_bias = None
        if num_deformable_group is None:
            num_deformable_group = filters
        if filters % num_deformable_group != 0:
            raise ValueError('"filters" mod "num_deformable_group must be zero')
        self.num_deformable_group = num_deformable_group
    
    def build(self, input_shape):
        """Create the depth-wise kernel, bias and offset-conv parameters (Keras-style hook).

        NOTE(review): PyTorch never calls ``build``; it would have to be invoked manually
        before ``forward``.

        :param input_shape: shape whose last entry is read as the channel count, i.e.
            channels-last (NHWC) as in Keras — TODO confirm; PyTorch convs use NCHW.
        """
        input_dim = int(input_shape[-1])
        #kernel_shape = self.kernel_size + (input_dim, self.filters)
        # we want to use depth-wise conv
        # Keras (kh, kw, in, out) layout; torch.nn.functional.conv2d expects (out, in/groups, kh, kw).
        kernel_shape = self.kernel_size + (self.filters * input_dim, 1)
        self.kernel = nn.Parameter(
            torch.zeros(kernel_shape, dtype=torch.float, requires_grad=True)
        )
        nn.init.xavier_uniform_(self.kernel)

        if self.use_bias:
            # NOTE(review): bias gets the full kernel shape; a per-output-channel bias would be 1-D.
            self.bias = nn.Parameter(
                torch.zeros(kernel_shape, dtype=torch.float),
                requires_grad=True
            )
        # NOTE(review): runs even when use_bias is falsy (its default is None), in which case
        # self.bias is still None from __init__ — likely meant to sit inside the `if` above.
        nn.init.zeros_(self.bias)
        
        # create offset conv layer 
        # One x/y offset pair per kernel position per deformable group.
        offset_num = self.kernel_size[0] * self.kernel_size[1] * self.num_deformable_group
        self.offset_layer_kernel = nn.Parameter(
            torch.zeros(self.kernel_size + (input_dim, offset_num * 2), dtype=torch.float), # 2 = x and y axes
            requires_grad=True
                    )
        # Zero kernel and zero bias, so the predicted offsets start at 0.
        nn.init.zeros_(self.offset_layer_kernel)

        self.offset_layer_bias = nn.Parameter(
            torch.zeros(offset_num * 2,),
            requires_grad=True
        )
        nn.init.zeros_(self.offset_layer_bias)

    def forward(self, inputs, training=None, **kwargs):
        """Unfinished: the body stops after the offset computation (``...`` placeholder).

        NOTE(review):
        - ``nn.Conv2d(...)`` below constructs a new module instead of applying a convolution.
          The functional form is
          ``torch.nn.functional.conv2d(input, weight, bias, stride, padding, dilation, groups)``,
          which expects ``weight`` as (out, in/groups, kh, kw). ``offset_layer_kernel`` is
          stored in Keras (kh, kw, in, out) order.
        - ``self.strides`` and ``self.dilation_rate`` are never set on this object.
        - The bias is passed to the conv and then added again afterwards.
        """
        #get offset shape [batch_size, out_h, out_w, filter_h * filter_w * channel_out * 2] (intended, channels-last)
        offset = nn.Conv2d(inputs, 
                  self.offset_layer_kernel, 
                  bias=self.offset_layer_bias, 
                  stride=self.strides, 
                  padding=self.padding, 
                  dilation=[1, self.dilation_rate, 1])
        # debug output
        print(offset.shape, "offset.shape")
        # NOTE(review): bias already passed to the conv above; adding it again double-counts it.
        offset += self.offset_layer_bias
        ... #Continue 

    
    def _pad_input(self, inputs):
        """Compute the 'same' padding the input would need, because we don't use the standard Conv() function.

        NOTE(review): unfinished — ``padding_list`` is built but never applied or returned,
        so the method currently returns None. ``inputs.shape[1:3]`` is read as (H, W),
        i.e. channels-last (NHWC) input is assumed — TODO confirm. If the list is later
        passed to ``torch.nn.functional.pad``, note that ``F.pad`` orders pads from the last
        dimension backwards (W before H), while this loop appends H first.

        :param inputs: input feature map.
        :return: intended to be the padded input feature map (not implemented yet).
        """

        #When padding is 'same', we should pad the feature map.
        # if padding == 'same', output size should be `ceil(input / stride)`
        if self.padding == 'same':
            in_shape = inputs.shape[1:3]
            padding_list = []
            for i in range(2):
                filter_size = self.kernel_size[i]
                dilation = self.dilation_rate[i]
                # Effective kernel extent once dilation gaps are inserted.
                dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
                # ceil(in / stride) vs. the unpadded output size floor((in - k) / stride) + 1
                same_output = (in_shape[i] + self.strides[i] - 1) // self.strides[i]
                valid_output = (in_shape[i] - dilated_filter_size + self.strides[i]) // self.strides[i]
                if same_output == valid_output:
                    padding_list += [0, 0]
                else:
                    # Split the total padding, putting the extra pixel (if odd) at the end.
                    p = dilated_filter_size - 1
                    p_0 = p // 2
                    padding_list += [p_0, p - p_0] 

SyntaxError: invalid syntax (95916157.py, line 67)

In [None]:
# Conv2d() with no arguments raises TypeError (in_channels, out_channels and kernel_size
# are required); display the constructor signature instead.
inspect.signature(Conv2d)