<a href="https://colab.research.google.com/github/whdid502/stt_model_project/blob/encoder/STT_Encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# encoder



In [14]:
import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple, Optional, Any

In [11]:
class Listener(BaseRNN):
  def __init__(
            self,
            input_size: int,                       # size of input
            hidden_dim: int = 512,                 # dimension of RNN`s hidden state
            device: str = 'cuda',                  # device - 'cuda' or 'cpu'
            dropout_p: float = 0.3,                # dropout probability
            num_layers: int = 3,                   # number of RNN layers
            bidirectional: bool = True,            # if True, becomes a bidirectional encoder
            rnn_type: str = 'lstm',                # type of RNN cell
            extractor: str = 'vgg',                # type of CNN extractor
            activation: str = 'hardtanh',          # type of activation function
            mask_conv: bool = False                # flag indication whether apply mask convolution or not
    ) -> None:
        self.mask_conv = mask_conv
        self.extractor = extractor.lower()
        if self.extractor == 'vgg':
            input_size = (input_size - 1) << 5 if input_size % 2 else input_size << 5
            super(Listener, self).__init__(input_size, hidden_dim, num_layers, rnn_type, dropout_p, bidirectional, device)
            self.conv = VGGExtractor(activation, mask_conv)
        else:
            raise ValueError("Unsupported Extractor : {0}".format(extractor))

  def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
            conv_feat = self.conv(inputs.unsqueeze(1), input_lengths).to(self.device)
            conv_feat = conv_feat.transpose(1, 2)

            batch_size, seq_length, num_channels, hidden_dim = conv_feat.size()
            conv_feat = conv_feat.contiguous().view(batch_size, seq_length, num_channels * hidden_dim)

            if self.training:
                self.rnn.flatten_parameters()

            output, hidden = self.rnn(conv_feat)

            return output, hidden


In [12]:
class BaseRNN(nn.Module):
    supported_rnns = {
        'lstm': nn.LSTM,
        'gru': nn.GRU,
        'rnn': nn.RNN
    }

    def __init__(
            self,
            input_size: int,                       # size of input
            hidden_dim: int = 512,                 # dimension of RNN`s hidden state vector
            num_layers: int = 1,                   # number of recurrent layers
            rnn_type: str = 'lstm',                # number of RNN layers
            dropout_p: float = 0.3,                # dropout probability
            bidirectional: bool = True,            # if True, becomes a bidirectional rnn
            device: str = 'cuda'                   # device - 'cuda' or 'cpu'
    ) -> None:
        super(BaseRNN, self).__init__()
        rnn_cell = self.supported_rnns[rnn_type]
        self.rnn = rnn_cell(input_size, hidden_dim, num_layers, True, True, dropout_p, bidirectional)
        self.hidden_dim = hidden_dim
        self.device = device

    def forward(self, *args, **kwargs):
        raise NotImplementedError


In [15]:
class CNNExtractor(nn.Module):
    supported_activations = {
        'hardtanh': nn.Hardtanh(0, 20, inplace=True),
        'relu': nn.ReLU(inplace=True),
        'elu': nn.ELU(inplace=True),
        'leaky_relu': nn.LeakyReLU(inplace=True),
        'gelu': nn.GELU()
    }

    def __init__(self, activation: str = 'hardtanh') -> None:
        super(CNNExtractor, self).__init__()
        self.activation = CNNExtractor.supported_activations[activation]

    def forward(self, inputs: Tensor, input_lengths: Tensor) -> Optional[Any]:
        raise NotImplementedError

In [16]:
class VGGExtractor(CNNExtractor):
    def __init__(self, activation: str, mask_conv: bool):
        super(VGGExtractor, self).__init__(activation)
        self.mask_conv = mask_conv
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=64),
            self.activation,
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=64),
            self.activation,
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=128),
            self.activation,
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=128),
            self.activation,
            nn.MaxPool2d(2, stride=2)
        )

    def forward(self, inputs: Tensor, input_lengths: Tensor) -> Optional[Any]:
        conv_feat = self.conv(inputs)
        output = conv_feat

        return output

In [18]:
x = torch.randn(20,40)
listener = Listener(20)
y = listener.forward(x, x.size())
print(y)

TypeError: ignored