In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import torchaudio

import os
import sys

In [2]:
# Directory holding the project's helper modules (MDTC.py, GSC_zip.py).
# NOTE(review): the path looks like a Colab-mounted Google Drive location —
# confirm the drive is mounted before this cell runs.
LIB_PATH = '/content/drive/MyDrive/GSC/GSC_helper'

# Make the helper modules importable. The membership check keeps the cell
# idempotent: re-running it no longer appends a duplicate sys.path entry.
if LIB_PATH not in sys.path:
    sys.path.append(LIB_PATH)
from MDTC import MDTC
from GSC_zip import zipzip, unzipzip

In [4]:
# Pinned to the release this notebook was run with, so a fresh run installs the
# same torchinfo; %pip installs into the environment of the running kernel.
%pip install torchinfo==1.8.0

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [5]:
from torchinfo import summary

## Model

In [13]:
class MDTC_training(nn.Module):
    """Classifier that chains a linear projection with four MDTC networks.

    The input features are first projected to 64 channels by ``self.linear``
    and then passed through ``net1`` .. ``net4`` in sequence. Only ``net4`` is
    built with ``classification = True``, so it produces the final output.

    Args:
        lr: Learning rate. It is only stored on the instance as ``self.lr``;
            nothing in this class uses it.
        in_channels: Size of the feature axis of the input, i.e. the input
            width of the linear projection.
        num_classes: Number of output classes, forwarded to ``net4``.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, lr, in_channels, num_classes, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr = lr
        # Project the in_channels-wide feature axis to the 64 channels that
        # every MDTC stage below expects.
        self.linear = nn.Linear(in_channels, 64)
        self.net1 = MDTC(in_channels = 64,
             out_channels = 64,
             kernel_size = 7,
             stack_num = 4,
             stack_size = 5)
        self.net2 = MDTC(in_channels = 64,
             out_channels = 64,
             kernel_size = 5,
             stack_num = 3,
             stack_size = 4)
        self.net3 = MDTC(in_channels = 64,
             out_channels = 64,
             kernel_size = 4,
             stack_num = 3,
             stack_size = 3)
        # Final stage carries the classification head. num_classes now comes
        # from the constructor argument instead of a hard-coded 12.
        self.net4 = MDTC(in_channels = 64,
             out_channels = 64,
             kernel_size = 3,
             stack_num = 3,
             stack_size = 4,
             classification = True,
             hidden_size = 64,
             num_classes = num_classes)

    def forward(self, input):
        """Run the full stack on a batch of features.

        Args:
            input: Tensor laid out as (batch, 1, in_channels, frames); a 3-D
                (batch, in_channels, frames) tensor is accepted as well.

        Returns:
            The output of ``net4``. For the (128, 1, 80, 101) input summarised
            in this notebook, torchinfo reports an output shape of [128, 12].
        """
        x = input
        # Remove only the singleton axis at position 1. A bare squeeze() would
        # also drop a batch axis of size 1 and break the transposes below.
        if x.dim() == 4:
            x = x.squeeze(1)
        # nn.Linear acts on the last axis, so move the feature axis there,
        # then move it back to position 1 for the MDTC stages.
        x = self.linear(x.transpose(1, 2))
        x = self.net1(x.transpose(1, 2))
        x = self.net2(x)
        x = self.net3(x)
        x = self.net4(x)
        return x

In [14]:
# Instantiate the model and print a per-layer torchinfo summary for a dummy
# batch of shape (128, 1, 80, 101).
net = MDTC_training(
    lr=0.001,
    in_channels=80,
    num_classes=12,
)
summary(
    net,
    input_size=(128, 1, 80, 101),
)

Layer (type:depth-idx)                             Output Shape              Param #
MDTC_training                                      [128, 12]                 --
├─Linear: 1-1                                      [128, 101, 64]            5,184
├─MDTC: 1-2                                        [128, 64, 101]            --
│    └─DTCBlock: 2-1                               [128, 64, 101]            --
│    │    └─CausalConv1d: 3-1                      [128, 64, 101]            512
│    │    └─BatchNorm1d: 3-2                       [128, 64, 101]            128
│    │    └─Conv1d: 3-3                            [128, 64, 101]            4,160
│    │    └─BatchNorm1d: 3-4                       [128, 64, 101]            128
│    │    └─ReLU: 3-5                              [128, 64, 101]            --
│    │    └─Conv1d: 3-6                            [128, 64, 101]            4,160
│    │    └─BatchNorm1d: 3-7                       [128, 64, 101]            128
│    │    └─ReLU: 3-8 