In [1]:
from torchsummaryX import summary
import torch
from package.config import Config
from model.speller import Speller
from model.listener import Listener
from model.listenAttendSpell import ListenAttendSpell

# Hyper-parameter bundle read by the model-building code further down.
# Keyword order is free here: the values echoed by config.py's print_log()
# in this cell's output follow their own order, not the call order.
config = Config(
    # switches
    use_cuda=True,
    use_bidirectional=True,
    use_label_smooth=True,
    use_multistep_lr=False,
    use_augment=False,
    use_pickle=False,
    input_reverse=True,
    load_model=False,
    model_path=None,
    # network dimensions
    hidden_dim=256,
    listener_layer_size=5,
    speller_layer_size=3,
    n_head=12,
    max_len=151,
    dropout=0.5,
    # run settings
    seed=1,
    batch_size=32,
    worker_num=1,
    max_epochs=40,
    teacher_forcing=1.0,
    augment_ratio=1.0,
    # learning rates
    init_lr=1e-4,
    high_plateau_lr=3e-4,
    low_plateau_lr=1e-5,
)

# Build the two sub-networks from `config`, then wrap them in the full model.
# NOTE(review): config.use_cuda is True, yet both modules are created with
# device='cpu' -- confirm that is intentional for this notebook.
listener = Listener(
    in_features=80,
    hidden_dim=config.hidden_dim,
    n_layers=config.listener_layer_size,
    bidirectional=config.use_bidirectional,
    rnn_type='gru',
    dropout_p=config.dropout,
    device='cpu',
)

speller = Speller(
    n_class=2040,
    sos_id=2037,
    eos_id=2038,
    max_length=config.max_len,
    k=8,
    # config.hidden_dim, doubled when config.use_bidirectional is set.
    hidden_dim=config.hidden_dim * (2 if config.use_bidirectional else 1),
    n_head=config.n_head,
    n_layers=config.speller_layer_size,
    rnn_type='gru',
    dropout_p=config.dropout,
    device='cpu',
)

model = ListenAttendSpell(listener, speller)

[2020-04-30 03:09:07,099 config.py:93 - print_log()] use_bidirectional : True
[2020-04-30 03:09:07,104 config.py:94 - print_log()] use_pickle : False
[2020-04-30 03:09:07,106 config.py:95 - print_log()] use_augment : False
[2020-04-30 03:09:07,108 config.py:96 - print_log()] augment_ratio : 1.00
[2020-04-30 03:09:07,110 config.py:97 - print_log()] input_reverse : True
[2020-04-30 03:09:07,112 config.py:98 - print_log()] hidden_dim : 256
[2020-04-30 03:09:07,115 config.py:99 - print_log()] listener_layer_size : 5
[2020-04-30 03:09:07,117 config.py:100 - print_log()] speller_layer_size : 3
[2020-04-30 03:09:07,119 config.py:101 - print_log()] n_head : 12
[2020-04-30 03:09:07,122 config.py:102 - print_log()] dropout : 0.50
[2020-04-30 03:09:07,125 config.py:103 - print_log()] batch_size : 32
[2020-04-30 03:09:07,127 config.py:104 - print_log()] worker_num : 1
[2020-04-30 03:09:07,130 config.py:105 - print_log()] max_epochs : 40
[2020-04-30 03:09:07,132 config.py:106 - print_log()] initial

In [2]:
# Display the assembled model; the notebook renders its repr, i.e. the nested
# listing of sub-modules shown in this cell's output.
model

ListenAttendSpell(
  (listener): Listener(
    (conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Hardtanh(min_val=0, max_val=20, inplace=True)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): Hardtanh(min_val=0, max_val=20, inplace=True)
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): Hardtanh(min_val=0, max_val=20, inplace=True)
      (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): Hardtanh(min_val=0, max_val=20, inplace=True)
      (12): MaxPool2d(kernel_size

In [None]:
# Push one batch of dummy tensors through the model so torchsummaryX can
# report the layers it traverses.
batch_size = 32     # same value as the batch_size passed to Config
seq_length = 120    # number of frames in the dummy acoustic input
feature_size = 80   # same value as the Listener's in_features
max_length = 150    # number of token ids in each dummy target sequence

# Dummy acoustic features: (batch_size, seq_length, feature_size).
inputs = torch.zeros((batch_size, seq_length, feature_size))

# Dummy targets: one row of integer token ids per utterance.
# Fix: the original allocated (batch_size, seq_length, max_length), which put
# the acoustic frame axis into the label tensor.
# NOTE(review): the (batch_size, max_length) layout is assumed from the usual
# decoder-input convention -- confirm against ListenAttendSpell.forward.
scripts = torch.zeros((batch_size, max_length), dtype=torch.long)

# Extra positional arguments after `inputs` are forwarded to model.forward.
summary(model, inputs, scripts)