In [4]:
from torchsummaryX import summary
import torch
from utils.config import Config
from model.speller import Speller
from model.listener import Listener
from model.listenAttendSpell import ListenAttendSpell

# Experiment settings for this notebook, grouped by concern. Keyword order is
# irrelevant to the call; the values are exactly the ones used before.
config = Config(
    # --- model architecture ---
    hidden_dim=256,
    listener_layer_size=5,
    speller_layer_size=3,
    use_bidirectional=True,
    num_head=4,
    attn_dim=128,
    dropout=0.4,
    max_len=151,
    # --- audio features ---
    sr=16000,
    window_size=20,
    stride=10,
    n_mels=80,
    input_reverse=True,
    # --- data pipeline ---
    use_augment=True,
    augment_num=1,
    use_pickle=False,
    batch_size=32,
    worker_num=1,
    # --- optimisation ---
    max_epochs=40,
    lr=0.001,
    teacher_forcing_ratio=1.0,
    label_smoothing=0.1,
    seed=1,
    # --- runtime, logging and checkpoints ---
    # NOTE(review): use_cuda=True here, while the Listener in this notebook is
    # built with device='cpu' -- confirm which one is intended.
    use_cuda=True,
    print_every=10,
    save_result_every=1000,
    save_model_every=10000,
    load_model=False,
    model_path=None,
)

# Encoder: consumes 80-dimensional input features and is sized from `config`.
listener = Listener(
    in_features=80,
    hidden_dim=config.hidden_dim,
    num_layers=config.listener_layer_size,
    bidirectional=config.use_bidirectional,
    dropout_p=config.dropout,
    rnn_type='gru',
    device='cpu',
)

# The decoder is twice as wide as the encoder setting when the encoder is
# bidirectional (presumably because both directions are concatenated --
# confirm against model/listener.py); otherwise the widths match.
if config.use_bidirectional:
    speller_hidden_dim = config.hidden_dim * 2
else:
    speller_hidden_dim = config.hidden_dim

# Decoder: 2040 output classes, with ids 2037 / 2038 passed as the
# start-of-sequence / end-of-sequence markers.
speller = Speller(
    num_class=2040,
    sos_id=2037,
    eos_id=2038,
    max_length=config.max_len,
    hidden_dim=speller_hidden_dim,
    num_layers=config.speller_layer_size,
    num_head=config.num_head,
    attn_dim=config.attn_dim,
    dropout_p=config.dropout,
    rnn_type='gru',
    k=8,
)

# Full sequence-to-sequence model wrapping the encoder and the decoder.
model = ListenAttendSpell(listener, speller)

[2020-05-04 05:23:20,493 config.py:113 - print_log()] use_bidirectional : True
[2020-05-04 05:23:20,495 config.py:114 - print_log()] use_pickle : False
[2020-05-04 05:23:20,497 config.py:115 - print_log()] use_augment : True
[2020-05-04 05:23:20,499 config.py:116 - print_log()] augment_num : 1
[2020-05-04 05:23:20,500 config.py:117 - print_log()] input_reverse : True
[2020-05-04 05:23:20,501 config.py:118 - print_log()] hidden_dim : 256
[2020-05-04 05:23:20,503 config.py:119 - print_log()] listener_layer_size : 5
[2020-05-04 05:23:20,505 config.py:120 - print_log()] speller_layer_size : 3
[2020-05-04 05:23:20,506 config.py:121 - print_log()] rnn_type : gru
[2020-05-04 05:23:20,507 config.py:122 - print_log()] num_head : 4
[2020-05-04 05:23:20,509 config.py:123 - print_log()] attn_dim : 128
[2020-05-04 05:23:20,510 config.py:124 - print_log()] dropout : 0.40
[2020-05-04 05:23:20,512 config.py:125 - print_log()] batch_size : 32
[2020-05-04 05:23:20,513 config.py:126 - print_log()] worker

In [5]:
# Bare expression as the last line of the cell: the notebook displays the
# module's repr, i.e. the nested layer tree of the assembled model.
model

ListenAttendSpell(
  (listener): Listener(
    (conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Hardtanh(min_val=0, max_val=20, inplace=True)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): Hardtanh(min_val=0, max_val=20, inplace=True)
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): Hardtanh(min_val=0, max_val=20, inplace=True)
      (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): Hardtanh(min_val=0, max_val=20, inplace=True)
      (12): MaxPool2d(kernel_size

In [None]:
# Dummy batch used only to trace the model with torchsummaryX: the values are
# all zeros, only the shapes and dtypes matter.
batch_size = 32
seq_length = 120    # acoustic frames per utterance
feature_size = 80   # per-frame feature width; matches Listener(in_features=80)
max_length = 150    # target tokens per utterance

# Acoustic input: (batch, frames, features).
inputs = torch.zeros((batch_size, seq_length, feature_size))

# Target token ids: (batch, target_len). The earlier shape
# (batch_size, seq_length, max_length) mixed the acoustic frame axis into the
# transcript, which is not a valid batch of label sequences.
# NOTE(review): assumes the speller takes 2-D integer targets -- confirm
# against model/speller.py.
scripts = torch.zeros((batch_size, max_length), dtype=torch.long)

# Extra positional arguments are forwarded to the model's forward pass, so
# this traces model(inputs, scripts).
summary(model, inputs, scripts)