In [1]:
import torch
import torch.nn as nn
from utils.model import GCN

In [155]:
OBSERVED_LEN = 10
x = torch.rand(8, 66, OBSERVED_LEN)

In [219]:
class Conv1Channel(nn.Module):
    def __init__(self, nb_filters=1, filter_size=1, stride=1, dilation=1):
        super().__init__()
        self.conv = nn.Conv1d(1, nb_filters, filter_size, stride=stride, padding=0, dilation=dilation, groups=1, bias=True, padding_mode='zeros')

    def forward(self, x):
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        x = x[:, None, :]
        x = self.conv(x)
        x = x.reshape(shape[0], shape[1], -1)
        return x

class TimeInceptionModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.convolutions = nn.ModuleList([])
        self.convolutions.append(Conv1Channel(nb_filters=12, filter_size=2))
        self.convolutions.append(Conv1Channel(nb_filters=9, filter_size=3))
        self.convolutions.append(Conv1Channel(nb_filters=7, filter_size=5))
        self.convolutions.append(Conv1Channel(nb_filters=6, filter_size=7))
#         self.convolutions.append(Conv1Channel(nb_filters=4, filter_size=3, dilation=2))
    
    def forward(self, x):
        out = x
        for conv in self.convolutions:
            y = conv(x)
            out = torch.cat((out, y), 2)
            print(y.shape)
        return out

In [218]:
mod = TimeInceptionModule()
mod(x).shape

torch.Size([8, 66, 108])
torch.Size([8, 66, 72])
torch.Size([8, 66, 42])
torch.Size([8, 66, 24])


torch.Size([8, 66, 256])

In [182]:
for k, v in mod.state_dict().items():
    print(k, v.shape)
print(sum([p.numel() for p in mod.parameters()]))

convolutions.0.conv.weight torch.Size([10, 1, 2])
convolutions.0.conv.bias torch.Size([10])
convolutions.1.conv.weight torch.Size([8, 1, 3])
convolutions.1.conv.bias torch.Size([8])
convolutions.2.conv.weight torch.Size([5, 1, 5])
convolutions.2.conv.bias torch.Size([5])
convolutions.3.conv.weight torch.Size([4, 1, 7])
convolutions.3.conv.bias torch.Size([4])
convolutions.4.conv.weight torch.Size([4, 1, 3])
convolutions.4.conv.bias torch.Size([4])
140


In [140]:
gcn = GCN(input_feature=35, hidden_feature=256, p_dropout=0.5,
                        num_stage=2, node_n=66)

In [26]:
x.shape

torch.Size([8, 66, 35])