In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


### Dilated Inception

- https://github.com/nuist-cs/MS-STGNN/blob/main/layer.py
- https://github.com/KawaiiAsh/deepLearning-modules-toolbox/blob/main/Temporal_conv.py

In [3]:
class dilated_inception(nn.Module):
    """Parallel dilated convolutions of several widths plus a length-restoring MLP.

    Four ``nn.Conv2d`` branches with kernel sizes ``(1, 2)``, ``(1, 3)``,
    ``(1, 6)`` and ``(1, 7)`` are applied to the same input with a shared
    dilation along the last axis. The branch outputs are cropped to the length
    of the shortest one, concatenated along the channel axis, and passed
    through ``Linear -> ReLU -> Linear`` (acting on the last axis) so that the
    result again has ``seq_len`` entries on that axis.

    Args:
        cin: Number of input channels. Also used as the hidden width of the
            MLP applied to the last axis.
        cout: Requested number of output channels. Every branch produces
            ``cout // 4`` channels, so the concatenated output has
            ``(cout // 4) * 4`` channels, which equals ``cout`` only when
            ``cout`` is a multiple of 4.
        dilation_factor: Dilation of every branch along the last axis.
        seq_len: Size of the last axis of the tensors given to ``forward``.

    Raises:
        ValueError: If ``seq_len`` is too short for the widest dilated kernel,
            i.e. ``seq_len - dilation_factor * 6 < 1``.
    """

    def __init__(self, cin, cout, dilation_factor, seq_len):
        super().__init__()
        self.tconv = nn.ModuleList()
        self.padding = 0  # the branches are unpadded; kept as an attribute for existing readers
        self.seq_len = seq_len
        self.kernel_set = [2, 3, 6, 7]

        # An unpadded dilated convolution shortens the last axis by
        # dilation * (kernel - 1). The widest kernel therefore yields the
        # shortest branch, and every other branch is cropped to that length.
        widest_kernel = max(self.kernel_set)
        cropped_len = seq_len - dilation_factor * (widest_kernel - 1)
        if cropped_len < 1:
            # Fail early with an actionable message instead of building a
            # Linear layer with a zero or negative number of input features.
            raise ValueError(
                f"seq_len={seq_len} is too short for kernel size {widest_kernel} "
                f"with dilation_factor={dilation_factor}: need seq_len > "
                f"{dilation_factor * (widest_kernel - 1)}"
            )

        # The requested channels are split evenly across the branches.
        branch_channels = int(cout // len(self.kernel_set))
        for kern in self.kernel_set:
            self.tconv.append(
                nn.Conv2d(cin, branch_channels, (1, kern), dilation=(1, dilation_factor))
            )

        # Maps the cropped last axis back to seq_len entries.
        self.out = nn.Sequential(
            nn.Linear(cropped_len, cin),
            nn.ReLU(),
            nn.Linear(cin, self.seq_len),
        )

    def forward(self, input):
        """Apply all branches, align their lengths, and restore the last axis.

        Args:
            input: 4-D tensor whose channel axis (dim 1) has ``cin`` entries
                and whose last axis has ``seq_len`` entries; any other last-axis
                size does not match the first ``Linear`` layer.

        Returns:
            Tensor with the same leading and third dimensions as ``input``,
            ``(cout // 4) * 4`` channels, and ``seq_len`` entries on the last
            axis.
        """
        branches = [conv(input) for conv in self.tconv]

        # Keep only the trailing entries so every branch matches the shortest one.
        keep = min(branch.size(3) for branch in branches)
        branches = [branch[..., -keep:] for branch in branches]

        x = torch.cat(branches, dim=1)
        return self.out(x)

In [4]:
class temporal_conv(nn.Module):
    """Gated temporal block made of two identically configured inception branches.

    One branch is squashed with ``tanh`` and the other with ``sigmoid``; the
    block returns their element-wise product.

    Args:
        cin: Passed to both ``dilated_inception`` branches as ``cin``.
        cout: Passed to both branches as ``cout``.
        dilation_factor: Passed to both branches as ``dilation_factor``.
        seq_len: Passed to both branches as ``seq_len``.
    """

    def __init__(self, cin, cout, dilation_factor, seq_len):
        super().__init__()

        # Both branches share one configuration but own separate weights.
        branch_kwargs = dict(
            cin=cin,
            cout=cout,
            dilation_factor=dilation_factor,
            seq_len=seq_len,
        )
        self.filter_convs = dilated_inception(**branch_kwargs)
        self.gated_convs = dilated_inception(**branch_kwargs)

    def forward(self, X):
        """Return ``tanh(filter_convs(X)) * sigmoid(gated_convs(X))``."""
        signal = torch.tanh(self.filter_convs(X))
        gate = torch.sigmoid(self.gated_convs(X))
        return signal * gate

In [8]:
# Layout: (B, in_channels, height_or_nodes, time)
# height_or_nodes: usually 1 (a plain time series), or the node dimension
# when there are several nodes (a graph time series).
X = torch.randn((1, 32, 6, 24))

In [9]:
# Gated temporal block configured to match X: 32 channels, 24 entries on the last axis.
Model = temporal_conv(
    cin=32,
    cout=32,
    dilation_factor=1,
    seq_len=24,
)

In [10]:
# Forward pass on the sample batch; with cout == cin the block should preserve the input shape.
out = Model(X)
# Inline invariant so a full re-run fails loudly if the shape contract ever changes.
assert out.shape == X.shape, f"expected {tuple(X.shape)}, got {tuple(out.shape)}"
out.shape

torch.Size([1, 32, 6, 24])