<a href="https://colab.research.google.com/github/whdid502/stt_model_project/blob/decoder/attention_decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
from torch import Tensor

import numpy as np
import time
from typing import Tuple, Optional, Any

## Encoder

In [2]:
class BaseRNN(nn.Module):
    supported_rnns = {
        'lstm': nn.LSTM,
        'gru': nn.GRU,
        'rnn': nn.RNN
    }

    def __init__(
            self,
            input_size: int,                       # size of input
            hidden_dim: int = 512,                 # dimension of RNN`s hidden state vector
            num_layers: int = 1,                   # number of recurrent layers
            rnn_type: str = 'lstm',                # number of RNN layers
            dropout_p: float = 0.3,                # dropout probability
            bidirectional: bool = True,            # if True, becomes a bidirectional rnn
            device: str = 'cuda'                   # device - 'cuda' or 'cpu'
    ) -> None:
        super(BaseRNN, self).__init__()
        rnn_cell = self.supported_rnns[rnn_type]
        self.rnn = rnn_cell(input_size, hidden_dim, num_layers, True, True, dropout_p, bidirectional)
        self.hidden_dim = hidden_dim
        self.device = device

    def forward(self, *args, **kwargs):
        raise NotImplementedError


In [3]:
class CNNExtractor(nn.Module):
    supported_activations = {
        'hardtanh': nn.Hardtanh(0, 20, inplace=True),
        'relu': nn.ReLU(inplace=True),
        'elu': nn.ELU(inplace=True),
        'leaky_relu': nn.LeakyReLU(inplace=True),
        'gelu': nn.GELU()
    }

    def __init__(self, activation: str = 'hardtanh') -> None:
        super(CNNExtractor, self).__init__()
        self.activation = CNNExtractor.supported_activations[activation]

    def forward(self, inputs: Tensor, input_lengths: Tensor) -> Optional[Any]:
        raise NotImplementedError

In [4]:
class VGGExtractor(CNNExtractor):
    def __init__(self, activation: str, mask_conv: bool):
        super(VGGExtractor, self).__init__(activation)
        self.mask_conv = mask_conv
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=64),
            self.activation,
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=64),
            self.activation,
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=128),
            self.activation,
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=128),
            self.activation,
            nn.MaxPool2d(2, stride=2)
        )

    def forward(self, inputs: Tensor, input_lengths: Tensor) -> Optional[Any]:
        conv_feat = self.conv(inputs)
        output = conv_feat

        return output

In [5]:
class Listener(BaseRNN):
  def __init__(
            self,
            input_size: int,                       # size of input
            hidden_dim: int = 512,                 # dimension of RNN`s hidden state
            device: str = 'cuda',                  # device - 'cuda' or 'cpu'
            dropout_p: float = 0.3,                # dropout probability
            num_layers: int = 3,                   # number of RNN layers
            bidirectional: bool = True,            # if True, becomes a bidirectional encoder
            rnn_type: str = 'lstm',                # type of RNN cell
            extractor: str = 'vgg',                # type of CNN extractor
            activation: str = 'hardtanh',          # type of activation function
            mask_conv: bool = False                # flag indication whether apply mask convolution or not
    ) -> None:
        self.mask_conv = mask_conv
        self.extractor = extractor.lower()
        self.device = device

        if self.extractor == 'vgg':
            input_size = (input_size - 1) << 5 if input_size % 2 else input_size << 5
            super(Listener, self).__init__(input_size, hidden_dim, num_layers, rnn_type, dropout_p, bidirectional, device)
            self.conv = VGGExtractor(activation, mask_conv)
        else:
            raise ValueError("Unsupported Extractor : {0}".format(extractor))

  def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
    conv_feat = self.conv(inputs.unsqueeze(1), input_lengths).to(self.device)
    conv_feat = conv_feat.transpose(1, 2)

    batch_size, seq_length, num_channels, hidden_dim = conv_feat.size()
    conv_feat = conv_feat.contiguous().view(batch_size, seq_length, num_channels * hidden_dim)

    if self.training:
        self.rnn.flatten_parameters()

    output, hidden = self.rnn(conv_feat)

    return output, hidden


## Multi-head Attention

In [6]:
def scaled_dot_product_attention(q, k, v, mask) :
    scaled_attention_logits = torch.bmm(q, k.transpose(1,2)) / np.sqrt(k.size(-1))

    if mask is not None :
        scaled_attention_logits.masked_fill_(mask, -1e9)

    attention_weights = torch.nn.functional.softmax(scaled_attention_logits, -1)
    output = torch.bmm(attention_weights, v)
    
    return output, attention_weights

In [7]:
class MultiHeadAttention(torch.nn.Module) :
    def __init__(self, d_model=512, num_heads=8) :
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads

        assert d_model % num_heads == 0

        self.depth = d_model // num_heads
        
        self.wq = torch.nn.Linear(d_model, d_model, bias=True)
        self.wk = torch.nn.Linear(d_model, d_model, bias=True)
        self.wv = torch.nn.Linear(d_model, d_model, bias=True)

        self.linear = torch.nn.Linear(d_model, d_model, bias=True) # ??

    def forward(self, q, k, v, mask=None) :        
        batch_size = v.size(0)

        q = self.wq(q).view(batch_size, -1, self.num_heads, self.depth)
        k = self.wk(k).view(batch_size, -1, self.num_heads, self.depth)
        v = self.wv(v).view(batch_size, -1, self.num_heads, self.depth)

        # split heads
        q = q.permute(2,0,1,3).contiguous().view(batch_size * self.num_heads, -1, self.depth)
        k = k.permute(2,0,1,3).contiguous().view(batch_size * self.num_heads, -1, self.depth)
        v = v.permute(2,0,1,3).contiguous().view(batch_size * self.num_heads, -1, self.depth)

        if mask is not None :
            mask = mask.repeat(self.num_heads, 1, 1)

        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        scaled_attention = scaled_attention.view(self.num_heads, batch_size, -1, self.depth)
        scaled_attention = scaled_attention.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.d_model)
        output = self.linear(scaled_attention) # TODO : check

        return output, attention_weights

In [8]:
# temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
# y = torch.rand((1, 60, 512))  # (batch_size, encoder_sequence, d_model)
# out, attn = temp_mha(y, y, y, mask=None)

# display(out.shape, attn.shape)
# display(y)
# out

## Decode

In [9]:
class DecoderStep(torch.nn.Module) :
    def __init__(self, num_classes, LSTM_num=1, d_model=1024, num_heads=4, dropout_p=0.3, device='cuda'):
        super(DecoderStep, self).__init__()
        self.d_model = d_model
        self.device = device

        self.embedding = torch.nn.Embedding(num_classes, d_model)
        self.input_dropout = torch.nn.Dropout(dropout_p)

        self.uniDirLSTM = torch.nn.LSTM(input_size=d_model, hidden_size=d_model, num_layers=LSTM_num, bias=True, batch_first=True, dropout=dropout_p, bidirectional=False)

        self.mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        
        self.layernorm1 = torch.nn.LayerNorm(d_model, eps=1e-6)
        self.layernorm2 = torch.nn.LayerNorm(d_model, eps=1e-6)

        self.linear1 = torch.nn.Linear(d_model, d_model, bias=True)
        self.linear2 = torch.nn.Linear(d_model, num_classes, bias=False)

    def forward(self, input_var, hidden, enc_output) :
        # enc_output.shape == (batch_size, input_seq_len, d_model)
        batch_size, output_lengths = input_var.size(0), input_var.size(1)

        embedded = self.embedding(input_var).to(self.device)
        embedded = self.input_dropout(embedded)

        if self.training :
            self.uniDirLSTM.flatten_parameters()

        out1, hidden = self.uniDirLSTM(embedded, hidden)
        
        context, attn_weights_block = self.mha(out1, enc_output, enc_output) # (batch_size, target_seq_len, d_model)
        out2 = self.layernorm1(context + out1).view(-1, self.d_model) # (batch_size, target_seq_len, d_model)

        out_proj = self.linear1(out2)
        output = self.layernorm2(out_proj + out2).view(batch_size, -1, self.d_model) # (batch_size, target_seq_len, d_model)

        output = self.linear2(torch.tanh(output).contiguous().view(-1, self.d_model))

        output = torch.nn.functional.log_softmax(output, dim=1)
        output = output.view(batch_size, output_lengths, -1).squeeze(1)

        return output, hidden, attn_weights_block

In [10]:
class Decoder(torch.nn.Module) :
    def __init__(self, num_classes, max_length=150, d_model=1024, num_heads=4, LSTM_num=2, dropout_p=0.3, device='cuda'):
        super(Decoder, self).__init__()

        self.d_model = d_model
        # self.num_layers = num_layers

        self.dec_layer = DecoderStep(num_classes=num_classes, LSTM_num=LSTM_num, d_model=d_model, num_heads=num_heads, dropout_p=dropout_p, device=device)

    def forward(self, inputs, enc_outputs) :
        assert enc_outputs is not None or inputs is not None

        hidden = None
        result, decode_dict = list(), dict()

        if not self.training :
            decode_dict[Speller.KEY_ATTENTION_SCORE] = list()
            decode_dict[Speller.KEY_SEQUENCE_SYMBOL] = list()

        max_lengths = inputs.size(1) - 1 # minus the start of sequence symbol

        input_var = inputs[:, 0].unsqueeze(1)
        
        # TODO : delete
        print("🎶 input_var size : ", input_var.size())

        for di in range(max_lengths) :
            step_output, hidden, attn_weights_block = self.dec_layer(input_var, hidden, enc_outputs)
            result.append(step_output)
            input_var = result[-1].topk(1)[1]

            if not self.training :
                decode_dict[Speller.KEY_ATTENTION_SCORE].append(attn)
                decode_dict[Speller.KEY_SEQUENCE_SYMBOL].append(input_var)
                eos_batches = input_var.data.eq(2) # eq(eos_id)

                if eos_batches.dim() > 0 :
                    eos_batches = eos_batches.cpu().view(-1).numpy()
                    update_idx = ((lengths > di) & eos_batches) != 0
                    lengths[update_idx] = len(decode_dict[Speller.KEY_SEQUENCE_SYMBOL])

        del decode_dict # TODO : check that this is necessary

        return result

In [11]:
class LAS(torch.nn.Module) :
    def __init__(self, num_classes, input_size=80, hidden_dim=512, dropout_p=0.15, mask_conv=None, max_len=150, num_heads=4, 
                 dec_num_layers=2, enc_num_layers=3, device='cuda'):
        super(LAS, self).__init__()

        self.encoder = Listener(input_size=input_size, hidden_dim=hidden_dim, device=device, dropout_p=dropout_p, num_layers=enc_num_layers)
        self.decoder = Decoder(num_classes=num_classes, max_length=max_len, d_model=hidden_dim << 1, LSTM_num=dec_num_layers, dropout_p=dropout_p, device=device)

        # TODO : check
        # flatten parameter
        self.encoder.rnn.flatten_parameters()
        self.decoder.dec_layer.uniDirLSTM.flatten_parameters()

    def forward(self, inputs, input_lengths, targets=None):
        output, hidden = self.encoder(inputs, input_lengths)
        print("😎 encoding done -> output size : ", output.size())

        result = self.decoder(targets, output)
        print("🧐 decoder done")

        return result

## data 준비

In [12]:
# TODO
# 0. csv 파일 내에 있는 data path를 통해 audio->feature vector & label 을 tuple로 묶어서 list에 넣어두기
# 1. max length로 padding

## Train

In [13]:
# def train_step(model, epoch, inputs, targets, device='cuda') :
#     cer = 1.0
#     epoch_loss_total = 0.
#     total_num = 0
#     # timestep = 0

#     train_start_time = time.time()

#     model.train() # model을 train mode로 변경

#     inputs = inputs.to(device)
#     targets = targets.to(device)
#     model = model.to(device)

#     if isinstance(model, nn.DataParallel):
#         model.module.flatten_parameters()
#     else:
#         model.flatten_parameters()


#     result = model(inputs, input_lengths, targets)
#     result = torch.stack(result, dim=1).to(device)
#     targets = targets[:, 1:] # TODO : check that this is necessary -> 맨 앞에 하나를 빼고 target을 넣음????

#     top_1 = result.max(-1)[1]


In [14]:
# output = net(input)
# target = torch.randn(10)  # a dummy target, for example
# target = target.view(1, -1)  # make it the same shape as output
# criterion = nn.MSELoss()

# loss = criterion(output, target)
# print(loss)

In [15]:
# def train(model, batch_size, num_epochs) :
#     print("[INFO] train start")

#     for epoch in range(num_epochs) :
#         # train
#         epoch_loss, epoch_cer = train_step(model, epoch)

#         # checkpoint 저장
#         if (epoch+1) % 5 == 0 :
#             pass

#         # print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, 
#         #                                         train_loss.result(), 
#         #                                         train_accuracy.result()))

#         # print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

#     optimizer = torch.optim.Adam(model.module.parameter(), lr=learning_rate, weight_decay=weight_decay)


## 실행

In [16]:
# hyper-parameter
INPUT_SIZE = 80
HIDDEN_DIM = 256
DROPOUT_P = 0.15
MAX_LEN = 150
NUM_HEADS = 4
ENC_NUM_LAYERS = 3
DEC_NUM_LAYERS = 2
DEVICE = 'cuda'

NUM_CLASSES = 50 # TODO : dataset으로 label의 개수 넣어주기
LEARNING_RATE = 1e-06
WEIGHT_DECAY = 1e-05
BATCH_SIZE = 8
NUM_EPOCHS = 20

In [17]:
tmp_input = torch.rand((BATCH_SIZE,951,INPUT_SIZE), dtype=torch.float64).uniform_(0,200)
tmp_input_length = torch.randint(0, INPUT_SIZE, size=(BATCH_SIZE,))
tmp_target = torch.rand((BATCH_SIZE,59), dtype=torch.float64).uniform_(0,49)

tmp_input = tmp_input.float()
tmp_target = tmp_target.long()

model = torch.nn.DataParallel(LAS(num_classes=NUM_CLASSES, input_size=INPUT_SIZE, hidden_dim=HIDDEN_DIM, dropout_p=DROPOUT_P, max_len=12, 
                                  num_heads=NUM_HEADS, dec_num_layers=DEC_NUM_LAYERS, enc_num_layers=ENC_NUM_LAYERS, device=DEVICE)).to('cuda')

print("model 초기화 성공")

tmp_input.to('cuda')
tmp_target.to('cuda')
model.to('cuda')

fn_out = model(tmp_input, tmp_input_length, tmp_target)

len(fn_out)

model 초기화 성공
😎 encoding done -> output size :  torch.Size([8, 237, 512])
🎶 input_var size :  torch.Size([8, 1])
🧐 decoder done


58