In [None]:
import torch
from torch import nn, Tensor
from torch.nn import functional as F

# Base Block

In [None]:
class BaseBlock(nn.Module):
    """Stride-1 channel-split unit (ShuffleNetV2-style basic unit).

    The input is split into two halves along the channel axis. The first half
    is kept as an identity path; the second half goes through
    1x1 conv -> BN -> ReLU -> depthwise KxK conv -> BN -> 1x1 conv -> BN -> ReLU.
    The two halves are concatenated, channel-shuffled with two groups and passed
    through a final ReLU.

    Args:
        in_channels: number of input channels C.
        out_channels: sets the width of the transformed branch
            (``out_channels // 2`` maps). The block outputs
            ``ceil(C / 2) + out_channels // 2`` channels, which equals
            ``out_channels`` only when ``in_channels == out_channels``
            (e.g. ``BaseBlock(24, 72, 3)`` yields 48 channels). That total
            must be even for the two-group channel shuffle.
        kernel_size: spatial size of the depthwise convolution.
        stride, padding, dilation: depthwise-convolution settings. They must
            keep the spatial size unchanged (stride 1, "same" padding), otherwise
            the identity and transformed halves cannot be concatenated.
        bias: whether the convolutions use a bias term.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 padding: int = 1,
                 dilation: int = 1,
                 bias: bool = True) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        out_ahalf = out_channels // 2
        # torch.chunk puts the extra channel (odd C) in the first chunk, so the
        # transformed half always has floor(C / 2) channels.
        in_ahalf = in_channels // 2

        self.pconv1 = nn.Conv2d(in_ahalf, out_ahalf, kernel_size=1, bias=bias)
        self.bn1 = nn.BatchNorm2d(out_ahalf)
        self.relu1 = nn.ReLU(inplace=True)

        self.dwconv = nn.Conv2d(out_ahalf,
                                out_ahalf,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding,
                                dilation=dilation,
                                groups=out_ahalf,  # depthwise
                                bias=bias)
        self.bn2 = nn.BatchNorm2d(out_ahalf)

        self.pconv2 = nn.Conv2d(out_ahalf, out_ahalf, kernel_size=1, bias=bias)
        self.bn3 = nn.BatchNorm2d(out_ahalf)
        self.relu2 = nn.ReLU(inplace=True)

        self.relu3 = nn.ReLU(inplace=True)

    @staticmethod
    def _channel_shuffle(x: Tensor, groups: int) -> Tensor:
        """Interleave the channels of ``groups`` equal groups (ShuffleNet shuffle).

        The channel count must be divisible by ``groups``.
        """
        n, c, h, w = x.shape
        x = x.view(n, groups, c // groups, h, w)
        x = x.transpose(1, 2).contiguous()
        return x.view(n, c, h, w)

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input: tensor of shape (N, C, H, W) with C == in_channels.

        Returns:
            Tensor of shape (N, ceil(C / 2) + out_channels // 2, H, W).
        """
        # Split once; the first chunk is the identity path.
        identity, transformed = torch.chunk(input, 2, dim=1)
        transformed = self.relu1(self.bn1(self.pconv1(transformed)))
        transformed = self.bn2(self.dwconv(transformed))
        transformed = self.relu2(self.bn3(self.pconv2(transformed)))
        output = torch.cat([identity, transformed], dim=1)
        # Without a shuffle the identity half would pass untouched through every
        # following BaseBlock of a stage; interleaving the halves lets the next
        # block's transformed branch see channels from both paths.
        output = self._channel_shuffle(output, groups=2)
        return self.relu3(output)

In [None]:
# BaseBlock only preserves the channel count when in_channels == out_channels;
# with (24, 72) it returns 24 // 2 + 72 // 2 = 48 channels.
baseblock = BaseBlock(24, 72, 3)
# torch.randn rather than torch.Tensor(...): the latter is uninitialized memory
# (may hold NaN/inf) and would feed garbage into BatchNorm's running statistics.
base_input = torch.randn(128, 24, 20, 51)
with torch.no_grad():
    base_output = baseblock(base_input)
base_output.shape

torch.Size([128, 48, 20, 51])

# EdgeCRNN Block

In [None]:
class EdgeCRNNBlock(nn.Module):
    """Two-branch downsampling unit (ShuffleNetV2-style, no channel split).

    Both branches see the whole input:

    * ``branch1``: depthwise KxK conv (strided) -> BN -> 1x1 conv -> BN -> ReLU
    * ``branch2``: 1x1 conv -> BN -> ReLU -> depthwise KxK conv (strided) -> BN
      -> 1x1 conv -> BN -> ReLU

    Each branch produces ``out_channels`` maps. The results are concatenated,
    channel-shuffled with two groups and passed through a final ReLU, giving
    ``2 * out_channels`` output channels.

    Args:
        in_channels: number of input channels.
        out_channels: number of maps produced by *each* branch.
        kernel_size, stride, padding, dilation: settings of both depthwise convs.
        bias: whether the convolutions use a bias term.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 2,
                 padding: int = 0,
                 dilation: int = 1,
                 bias: bool = True) -> None:
        super().__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=kernel_size,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      groups=in_channels,  # depthwise
                      bias=bias),
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels,
                      out_channels,
                      kernel_size=kernel_size,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      groups=out_channels,  # depthwise
                      bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _channel_shuffle(x: Tensor, groups: int) -> Tensor:
        """Interleave the channels of ``groups`` equal groups (ShuffleNet shuffle)."""
        n, c, h, w = x.shape
        x = x.view(n, groups, c // groups, h, w)
        x = x.transpose(1, 2).contiguous()
        return x.view(n, c, h, w)

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input: tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, 2 * out_channels, H', W') where H', W' follow
            from the strided depthwise convolutions.
        """
        output = torch.cat([self.branch1(input), self.branch2(input)], dim=1)
        # Mix the two branches so the channel split in the following BaseBlock
        # does not keep branch1's maps on the identity path for the whole stage.
        output = self._channel_shuffle(output, groups=2)
        return self.relu(output)

In [None]:
# Each branch yields 72 maps, so the block outputs 2 * 72 = 144 channels.
edgecrnnblock = EdgeCRNNBlock(24, 72, 3)
# torch.randn rather than torch.Tensor(...), which is uninitialized memory.
edge_input = torch.randn(128, 24, 20, 51)
with torch.no_grad():
    edge_output = edgecrnnblock(edge_input)
edge_output.shape

torch.Size([128, 144, 9, 25])

# Stage Block

In [None]:
class StageBlock(nn.Module):
    """One network stage: a stride-2 EdgeCRNNBlock followed by BaseBlocks.

    The EdgeCRNNBlock downsamples with stride 2 and produces
    ``2 * (out_channels // 2)`` channels; ``num_base_blks`` stride-1 BaseBlocks
    then operate at that resolution.

    Args:
        in_channels: channels of the stage input.
        out_channels: channels of the stage output (expected to be even).
        kernel_size: spatial size of every depthwise convolution.
        padding: padding of every depthwise convolution. For the BaseBlocks it
            must keep the spatial size unchanged, i.e.
            ``dilation * (kernel_size - 1) // 2``.
        dilation: dilation of every depthwise convolution.
        bias: whether the convolutions use a bias term.
        num_base_blks: number of BaseBlocks after the downsampling block.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 padding: int = 1,
                 dilation: int = 1,
                 bias: bool = True,
                 num_base_blks: int = 1) -> None:
        super().__init__()
        # The downsampling block concatenates two branches of out_ahalf maps each.
        out_ahalf = out_channels // 2
        layers = [EdgeCRNNBlock(in_channels,
                                out_ahalf,
                                kernel_size=kernel_size,
                                stride=2,
                                padding=padding,
                                dilation=dilation,  # previously accepted but ignored
                                bias=bias)]
        layers.extend(BaseBlock(out_channels,
                                out_channels,
                                kernel_size=kernel_size,
                                stride=1,
                                padding=padding,
                                dilation=dilation,
                                bias=bias)
                      for _ in range(num_base_blks))
        self.stack = nn.Sequential(*layers)

    def forward(self, input: Tensor) -> Tensor:
        """Run the downsampling block and the BaseBlocks in order."""
        return self.stack(input)

In [None]:
stageblk = StageBlock(72, 144, 3)
# torch.randn rather than torch.Tensor(...), which is uninitialized memory.
stage_input = torch.randn(128, 72, 20, 51)
with torch.no_grad():
    stage_output = stageblk(stage_input)
stage_output.shape

torch.Size([128, 144, 10, 26])

In [None]:
class EdgeCRNN(nn.Module):
    """CRNN classifier: convolutional front end, LSTM over the width axis, linear head.

    Pipeline (see ``forward``):
    conv3x3/BN/ReLU stem -> max-pool -> three StageBlocks -> 1x1 conv/BN/ReLU
    -> average over the height axis -> LSTM over the remaining width axis
    -> mean over LSTM time steps -> dropout -> linear layer.

    Args:
        in_channels: channels of the input tensor.
        hidden_size: LSTM hidden size.
        num_classes: number of output scores.
        dropout: dropout probability applied to the pooled LSTM features
            before the classifier.
        width_multiplier: scales every convolutional width (24, 72, 144, 288,
            512); widths are truncated with ``int``.
    """

    def __init__(self,
                 in_channels: int,
                 hidden_size: int,
                 num_classes: int,
                 dropout: float = 0.1,
                 width_multiplier: float = 1) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes

        def width(channels: int) -> int:
            return int(channels * width_multiplier)

        c24, c72, c144, c288, c512 = (width(c) for c in (24, 72, 144, 288, 512))

        self.conv1 = nn.Conv2d(in_channels, c24, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(c24)
        self.relu1 = nn.ReLU(inplace=True)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.stage2 = StageBlock(c24, c72, kernel_size=3, padding=1, bias=False)
        self.stage3 = StageBlock(c72, c144, kernel_size=3, padding=1,
                                 num_base_blks=2, bias=False)
        self.stage4 = StageBlock(c144, c288, kernel_size=3, padding=1, bias=False)

        self.conv5 = nn.Conv2d(c288, c512, kernel_size=1, bias=False)
        self.bn5 = nn.BatchNorm2d(c512)
        self.relu5 = nn.ReLU(inplace=True)

        # Average over the whole height axis, whatever its size. A fixed
        # AvgPool2d((3, 1)) only collapses it to 1 when exactly 3 rows reach this
        # point (e.g. 39-row inputs); otherwise squeeze(2) is a no-op and the
        # LSTM receives a 4-D tensor. For height 3 both give the same result.
        self.globalpool = nn.AdaptiveAvgPool2d((1, None))

        self.lstm = nn.LSTM(input_size=c512, hidden_size=hidden_size,
                            batch_first=True)

        # `dropout` used to be accepted but never applied.
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input: tensor of shape (N, in_channels, H, W); W becomes the LSTM
                sequence axis after the convolutional front end.

        Returns:
            Unnormalized scores of shape (N, num_classes).
        """
        features = self.relu1(self.bn1(self.conv1(input)))
        features = self.maxpool(features)
        features = self.stage4(self.stage3(self.stage2(features)))
        features = self.relu5(self.bn5(self.conv5(features)))
        # (N, C, H', T) -> (N, C, T) -> (N, T, C) for the batch_first LSTM.
        sequence = self.globalpool(features).squeeze(2).transpose(1, 2)
        hidden, _ = self.lstm(sequence)  # (N, T, hidden_size)
        pooled = hidden.mean(dim=1)      # average over time steps
        return self.fc(self.dropout(pooled))

In [None]:
# A (3, 1) average pool collapses a height-3 feature map to height 1 while
# keeping the width (the LSTM sequence length) unchanged.
globalpool = nn.AvgPool2d((3, 1), stride=(1, 1))
# torch.randn rather than torch.Tensor(...), which is uninitialized memory.
pool_input = torch.randn(128, 512, 3, 7)
globalpool(pool_input).shape

torch.Size([128, 512, 1, 7])

In [None]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [None]:
from torchinfo import summary

In [None]:
model = EdgeCRNN(1, 64, 12, width_multiplier=1)
# Dummy input of shape (batch, channels, 39, 101); presumably 39 feature rows
# by 101 frames — TODO confirm against the data pipeline.
# torch.randn rather than torch.Tensor(...), which is uninitialized memory.
x = torch.randn(1, 1, 39, 101)
# Shape check only: eval mode + no_grad so this forward does not update the
# BatchNorm running statistics or build an autograd graph.
model.eval()
with torch.no_grad():
    logits = model(x)
logits.shape

torch.Size([1, 12])

In [None]:
# Layer-by-layer shapes and parameter counts for a single (1, 1, 39, 101) input.
summary(model, input_size=(1, 1, 39, 101))

Layer (type:depth-idx)                        Output Shape              Param #
EdgeCRNN                                      [1, 12]                   --
├─Conv2d: 1-1                                 [1, 24, 39, 101]          216
├─BatchNorm2d: 1-2                            [1, 24, 39, 101]          48
├─ReLU: 1-3                                   [1, 24, 39, 101]          --
├─MaxPool2d: 1-4                              [1, 24, 20, 51]           --
├─StageBlock: 1-5                             [1, 72, 10, 26]           --
│    └─Sequential: 2-1                        [1, 72, 10, 26]           --
│    │    └─EdgeCRNNBlock: 3-1                [1, 72, 10, 26]           3,900
│    │    └─BaseBlock: 3-2                    [1, 72, 10, 26]           3,132
├─StageBlock: 1-6                             [1, 144, 5, 13]           --
│    └─Sequential: 2-2                        [1, 144, 5, 13]           --
│    │    └─EdgeCRNNBlock: 3-3                [1, 144, 5, 13]           17,568
│    │   

In [None]:
!pip install thop

In [None]:
from thop import profile

In [None]:
# thop.profile counts multiply-accumulate operations (MACs), not FLOPs
# (one MAC is roughly two FLOPs), so the number is labelled accordingly.
# An explicit input keeps the result independent of whichever `x` was defined last.
profile_input = torch.randn(1, 1, 39, 101)
macs, params = profile(model, inputs=(profile_input,))
print("MACs:%.2fM" % (macs / 1e6), "Parameters:%.2fM" % (params / 1e6))

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_avgpool() for <class 'torch.nn.modules.pooling.AvgPool2d'>.
[INFO] Register count_lstm() for <class 'torch.nn.modules.rnn.LSTM'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
FLOPs:1951.20M Parameters:0.45M


# Reference implementation from GitHub (`EdgeCRNN` module)

In [None]:
import EdgeCRNN as eCRNN

In [None]:
# Reference model from the GitHub code, at the same width (1x) as `model`.
model2 = eCRNN.EdgeCRNN(width_mult=1)

In [None]:
# Summary of the reference model; compare the per-layer parameter counts with `model`.
summary(model2, input_size=(128, 1, 39, 101))

Layer (type:depth-idx)                   Output Shape              Param #
EdgeCRNN                                 [128, 12]                 --
├─Sequential: 1-1                        [128, 24, 39, 101]        --
│    └─Conv2d: 2-1                       [128, 24, 39, 101]        216
│    └─BatchNorm2d: 2-2                  [128, 24, 39, 101]        48
│    └─ReLU: 2-3                         [128, 24, 39, 101]        --
├─MaxPool2d: 1-2                         [128, 24, 20, 51]         --
├─Sequential: 1-3                        [128, 288, 3, 7]          --
│    └─EdgeCRNN_Residual: 2-4            [128, 72, 10, 26]         --
│    │    └─Sequential: 3-1              [128, 36, 10, 26]         1,200
│    │    └─Sequential: 3-2              [128, 36, 10, 26]         2,700
│    └─EdgeCRNN_Residual: 2-5            [128, 72, 10, 26]         --
│    │    └─Sequential: 3-3              [128, 36, 10, 26]         3,132
│    └─EdgeCRNN_Residual: 2-6            [128, 144, 5, 13]         --
│    