In [None]:
import torch
from torch import nn
from torch.nn import functional as F

In [None]:
from torch import Tensor

# Normal TCN (causal 1-D convolution)

In [None]:
class CausalConv1d(nn.Conv1d):
    """1-D convolution that only looks at the current and past time steps.

    The input is zero-padded on the *left* by ``(kernel_size - 1) * dilation``
    samples and then passed through an unpadded ``nn.Conv1d``. Output step ``i``
    therefore depends only on input steps ``<= i * stride``. With ``stride=1``
    the output keeps the input length ``T``; in general it has
    ``(T - 1) // stride + 1`` steps.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Temporal kernel size.
        stride: Convolution stride.
        dilation: Spacing between kernel taps.
        groups: Number of blocked connections (``groups == in_channels`` gives
            a depthwise convolution).
        bias: Whether to add a learnable bias.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 dilation: int = 1,
                 groups: int = 1,
                 bias: bool = True) -> None:
        # Padding is applied manually (past side only) in forward(), so the
        # underlying Conv1d is built without any padding of its own.
        super().__init__(in_channels,
                         out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=0,
                         dilation=dilation,
                         groups=groups,
                         bias=bias)
        # Left padding that makes the receptive field strictly causal.
        # Set after Module.__init__ so the module is fully initialised first.
        self.causal_padding = (kernel_size - 1) * dilation

    def forward(self, input: Tensor) -> Tensor:
        """Apply the causal convolution.

        Args:
            input: Tensor of shape (N, C_in, T).
        Returns:
            Tensor of shape (N, C_out, (T - 1) // stride + 1).
        """
        if self.causal_padding:
            # Pad only the past side of the time axis. Padding both sides and
            # then dropping the last `padding` outputs is only correct for
            # stride == 1; left padding is correct for any stride.
            input = F.pad(input, (self.causal_padding, 0))
        return super().forward(input)

    def extra_repr(self) -> str:
        return f"{super().extra_repr()}, causal_padding={self.causal_padding}"

In [None]:
torch.manual_seed(0)  # make the randomly initialised weights (and the output below) reproducible
a = torch.arange(1, 10, dtype=torch.float32).view(1, 1, -1)  # (N=1, C=1, T=9) ramp 1..9
tcn = CausalConv1d(1, 1, kernel_size=3, stride=1, dilation=2)
tcn(a)

tensor([[[-0.0629, -0.5868, -0.7813, -0.9759, -1.5539, -2.1320, -2.7101,
          -3.2882, -3.8663]]], grad_fn=<SliceBackward0>)

In [None]:
tcn = CausalConv1d(1, 1, kernel_size=3, stride=1, dilation=2, bias=False)
# Overwrite the kernel in place with all ones; no_grad keeps this edit out of autograd.
with torch.no_grad():
    tcn.weight.fill_(1.0)
tcn(a)

tensor([[[ 1.,  2.,  4.,  6.,  9., 12., 15., 18., 21.]]],
       grad_fn=<SliceBackward0>)

# DTC Block

In [None]:
# Depthwise variant: groups equals the channel count, so each output channel is fed by one input channel.
dd_tcn_1 = CausalConv1d(3, 3, kernel_size=3, stride=1, dilation=2, groups=3)
dd_tcn_1.weight.shape

torch.Size([3, 1, 3])

In [None]:
class DTCBlock(nn.Module):
    """Dilated Temporal Convolution (DTC) block.

    Data flow (see ``forward``)::

        causal dilated depthwise conv -> BN -> 1x1 conv -> ReLU -> BN
        -> 1x1 conv -> BN -> (+ input, only if in_channels == out_channels) -> ReLU

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the block.
        kernel_size: Kernel size of the depthwise causal convolution.
        dilation: Dilation of the depthwise causal convolution.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 dilation: int) -> None:
        super().__init__()
        # Hyper-parameters are kept as attributes; forward() reads the channel
        # counts to decide whether a residual connection is possible.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation

        # Depthwise stage: groups == in_channels, so every channel is filtered
        # over time on its own.
        self.dilated_depth_tcn = CausalConv1d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            groups=in_channels,
        )
        self.bn1 = nn.BatchNorm1d(in_channels)
        # Pointwise (kernel_size=1) stages mix information across channels.
        self.point_conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu1 = nn.ReLU()
        self.point_conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(out_channels)
        self.relu2 = nn.ReLU()

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input: torch.Tensor: Input tensor (N, in_channels, T)
        Returns
            torch.Tensor: Output tensor (N, out_channels, T)
        """
        hidden = self.bn1(self.dilated_depth_tcn(input))
        # NOTE(review): ReLU is applied *before* bn2 here, whereas the common
        # ordering is conv -> BN -> ReLU; confirm this is intentional.
        hidden = self.bn2(self.relu1(self.point_conv1(hidden)))
        hidden = self.bn3(self.point_conv2(hidden))
        # The residual is only added when the channel counts line up.
        has_residual = self.in_channels == self.out_channels
        return self.relu2(input + hidden if has_residual else hidden)

In [None]:
x = torch.randn(128, 64, 81)  # (N, C, T) dummy batch; randn avoids torch.Tensor(...)'s uninitialised memory
dtcblock = DTCBlock(64, 64, kernel_size=5, dilation=4)
dtcblock(x).shape

torch.Size([128, 64, 81])

# DTC Stack

In [None]:
class DTCStack(nn.Module):
    """Stack of ``stack_size`` DTC blocks with exponentially growing dilation.

    Block ``i`` uses dilation ``2 ** i`` (1, 2, 4, ...). Only the first block
    maps ``in_channels`` to ``out_channels``; every later block keeps
    ``out_channels``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by every block.
        kernel_size: Kernel size shared by all blocks.
        stack_size: Number of blocks (must be >= 1).
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stack_size: int) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stack_size = stack_size

        dilations = [2 ** level for level in range(stack_size)]
        # dilations[0] is indexed explicitly, so stack_size < 1 raises IndexError.
        head = DTCBlock(in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        dilation=dilations[0])
        tail = [DTCBlock(out_channels,
                         out_channels,
                         kernel_size=kernel_size,
                         dilation=rate)
                for rate in dilations[1:]]
        self.stack = nn.Sequential(head, *tail)

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input: torch.Tensor: Input tensor (N, in_channels, T)
        Returns
            torch.Tensor: Output tensor (N, out_channels, T)
        """
        return self.stack(input)

In [None]:
x = torch.randn(128, 64, 81)  # (N, C, T) dummy batch; randn avoids torch.Tensor(...)'s uninitialised memory
dtcstack = DTCStack(64, 64, kernel_size=5, stack_size=4)
dtcstack(x).shape

torch.Size([128, 64, 81])

# MDTC

In [None]:
class MDTC(nn.Module):
    """Multi-scale Dilated Temporal Convolution (MDTC) network.

    A preprocessing DTC block (dilation 1) is followed by ``stack_num`` DTC
    stacks applied one after another. The outputs of *all* stacks are summed.
    In classification mode the summed features are average-pooled over time
    and passed through a two-layer MLP.

    Args:
        in_channels: Channels of the input features (N, in_channels, T).
        out_channels: Channels used by every DTC block/stack.
        kernel_size: Kernel size of the depthwise causal convolutions.
        stack_num: Number of DTC stacks.
        stack_size: Number of DTC blocks in each stack.
        classification: If truthy, append pooling + MLP classification head.
        hidden_size: Width of the MLP hidden layer (required in classification mode).
        num_classes: Number of output classes (required in classification mode).
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stack_num: int,
                 stack_size: int,
                 classification: bool = None,
                 hidden_size: int = None,
                 num_classes: int = None) -> None:
        super().__init__()
        self.stack_num = stack_num
        self.classification = classification
        self.preprocessing_tdc = DTCBlock(in_channels,
                                          out_channels,
                                          kernel_size=kernel_size,
                                          dilation=1)
        self.stack = nn.ModuleList([
            DTCStack(out_channels,
                     out_channels,
                     kernel_size=kernel_size,
                     stack_size=stack_size)
            for _ in range(stack_num)
        ])

        if classification:
            # NOTE: checked with `assert`, so this validation is skipped under `python -O`.
            assert hidden_size and num_classes, \
                "In classification mode you should give the model hidden_size and num_classes"
            self.avgpool = nn.AdaptiveAvgPool1d(1)
            self.fc1 = nn.Linear(out_channels, hidden_size)
            self.relu1 = nn.ReLU()
            self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input: Tensor of shape (N, in_channels, T).
        Returns:
            (N, out_channels, T) summed multi-scale features, or
            (N, num_classes) logits in classification mode.
        """
        features = self.preprocessing_tdc(input)
        output = 0
        # Each stack refines the previous stack's output; every stack's output
        # contributes to the multi-scale sum.
        for stack in self.stack:
            features = stack(features)
            output = output + features

        if self.classification:
            # Squeeze only the pooled time axis: a bare .squeeze() would also
            # drop the batch axis when N == 1 and return (num_classes,).
            output = self.avgpool(output).squeeze(-1)
            output = self.relu1(self.fc1(output))
            output = self.fc2(output)

        return output

In [None]:
%pip install -q torchinfo==1.8.0

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [None]:
x = torch.randn(128, 40, 81)  # (N, C=40, T=81) dummy batch; randn avoids uninitialised memory
mdtc = MDTC(in_channels=40, out_channels=64, kernel_size=5,
            stack_num=4, stack_size=4,
            classification=True, hidden_size=64, num_classes=11)
mdtc(x).shape

torch.Size([128, 11])

In [None]:
from torchinfo import summary

# Layer-by-layer overview of the model on a CPU dummy batch of shape (N, C, T).
summary(model=mdtc, input_size=(128, 40, 81), device="cpu")

Layer (type:depth-idx)                        Output Shape              Param #
MDTC                                          [128, 11]                 --
├─DTCBlock: 1-1                               [128, 64, 81]             --
│    └─CausalConv1d: 2-1                      [128, 40, 81]             240
│    └─BatchNorm1d: 2-2                       [128, 40, 81]             80
│    └─Conv1d: 2-3                            [128, 64, 81]             2,624
│    └─ReLU: 2-4                              [128, 64, 81]             --
│    └─BatchNorm1d: 2-5                       [128, 64, 81]             128
│    └─Conv1d: 2-6                            [128, 64, 81]             4,160
│    └─BatchNorm1d: 2-7                       [128, 64, 81]             128
│    └─ReLU: 2-8                              [128, 64, 81]             --
├─ModuleList: 1-2                             --                        --
│    └─DTCStack: 2-9                          [128, 64, 81]             --
│    │    └

In [None]:
avgpool = nn.AdaptiveAvgPool1d(1)
out = torch.randn(128, 64, 81)  # randn instead of torch.Tensor(...), which returns uninitialised memory
avgpool(out).shape

torch.Size([128, 64, 1])

In [None]:
# Toy (N=1, C=2, T=3) input with per-row means 2 and 3.
a = torch.tensor([[[1., 2., 3.],
                   [2., 3., 4.]]])
a.shape

torch.Size([1, 2, 3])

In [None]:
# Averages over the time axis: each row of `a` collapses to its mean -> [[2.], [3.]]
avgpool(a)

tensor([[[2.],
         [3.]]])

# Test: applying `nn.Linear` along the channel axis

In [None]:
linear = nn.Linear(40, 64)
# nn.Linear acts on the last axis, so move channels last: (N, C, T) -> (N, T, C).
# randn instead of torch.Tensor(...), which returns uninitialised memory.
x = torch.randn(128, 40, 81).transpose(1, 2)
linear(x).shape

torch.Size([128, 81, 64])