In [140]:
import torch
from torch import nn
import numpy as np
from torch.nn import functional as F

## ReferenceEncoder

In [56]:
class ReferenceEncoder(nn.Module):
    """Reference encoder of GST-Tacotron.

    Six Conv2d(kernel 3, stride 2, padding 1) -> BatchNorm2d -> ReLU stages
    downsample a mel spectrogram in both time and frequency. The resulting
    feature map is flattened per frame and summarised by a single-layer GRU.

    Args:
        n_mel: number of mel bins of the input.
        conv_channels1 .. conv_channels6: output channels of each conv stage.
        n_unit: hidden size of the GRU.
    """

    def __init__(
        self,
        n_mel = 80,
        conv_channels1 = 32,
        conv_channels2 = 32,
        conv_channels3 = 64,
        conv_channels4 = 64,
        conv_channels5 = 128,
        conv_channels6 = 128,
        n_unit = 128
    ):
        super().__init__()
        conv_channels_list = [
            conv_channels1,
            conv_channels2,
            conv_channels3,
            conv_channels4,
            conv_channels5,
            conv_channels6,
        ]

        convs = []
        for layer, out_channels in enumerate(conv_channels_list):
            # kernel 3 / stride 2 / padding 1 maps a length L to ceil(L / 2)
            n_mel = (n_mel + 1) // 2
            in_channels = 1 if layer == 0 else conv_channels_list[layer - 1]
            convs += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
        self.convs = nn.Sequential(*convs)

        # GRU input per frame = channels of the last stage x remaining mel bins
        self.gru = nn.GRU(n_mel * conv_channels_list[-1], n_unit, batch_first=True)

    def forward(self, seqs):
        """Encode a batch of mel spectrograms.

        Args:
            seqs: (B, 1, seq_len, n_mel) input spectrograms.

        Returns:
            outputs: (B, new_seq_len, n_unit) GRU outputs for every frame.
            hidden: (1, B, n_unit) final GRU hidden state.
        """
        out = self.convs(seqs)  # (B, channels, new_seq_len, new_n_mel)
        batch_size, channels, n_frames, n_bins = out.shape
        # Move the time axis next to the batch axis *before* flattening so that
        # each GRU step sees the features of exactly one frame. An explicit
        # reshape (instead of squeeze) also keeps the batch axis when B == 1.
        out = out.transpose(1, 2).reshape(batch_size, n_frames, channels * n_bins)
        outputs, hidden = self.gru(out)  # (B, new_seq_len, n_unit), (1, B, n_unit)
        return outputs, hidden

In [58]:
# Smoke test: 10 all-zero spectrograms of 16000 frames x 80 mel bins.
# Use the GPU when available instead of hard-coding it; avoid shadowing builtin `input`.
device = "cuda" if torch.cuda.is_available() else "cpu"
ref = ReferenceEncoder().to(device)
mel_batch = torch.zeros(10, 1, 16000, 80, device=device)
ref_output, hidden_state = ref(mel_batch)
print(ref_output.shape)

torch.Size([10, 250, 128])


## StyleTokenLayer

In [129]:
# https://qiita.com/m__k/items/646044788c5f94eadc8d

class StyleTokenLayer(nn.Module):
    """Style token layer: single-head dot-product attention over learnable tokens.

    Args:
        emb_size: dimensionality of each style token.
        n_tokens: number of style tokens.
        device: device the token table is created on.
    """

    def __init__(
        self,
        emb_size = 256,
        n_tokens = 10,
        device = "cuda"
    ):
        super().__init__()
        self.tanh = nn.Tanh()
        self.emb_size = emb_size
        # Softmax over the token axis of the (B, seq_len, n_tokens) scores.
        self.softmax = nn.Softmax(dim=2)
        # Registered as a Parameter so the tokens are updated by the optimizer,
        # stored in state_dict and moved by .to()/.cuda() with the module.
        self.tokens = nn.Parameter(torch.randn(n_tokens, emb_size).to(device))

    def forward(self, query=None, token_num=0):
        """Attend over the style tokens.

        Args:
            query: (B, seq_len, emb_size // 2) reference-encoder outputs, or None
                at inference.
            token_num: index of the token returned when ``query`` is None.

        Returns:
            Training: ``(style_embedding (B, emb_size), attention_weight (B, seq_len, n_tokens))``.
            Inference: the single embedding ``tanh(tokens[token_num])`` of shape (emb_size,).
        """
        if query is None:  # inference
            # Apply the same tanh used for the training-time values so the
            # returned embedding lives in the same space.
            return self.tanh(self.tokens[token_num])

        # NOTE(review): the reference embedding is duplicated to reach emb_size,
        # which assumes its last dim is emb_size // 2 (128 vs 256 by default);
        # the original author was unsure about this choice.
        query = torch.cat((query, query), 2)  # (B, seq_len, emb_size)
        tokens = self.tanh(self.tokens)  # (n_tokens, emb_size); used as keys and values
        # Scaled dot product (the GST paper uses cosine similarity; dot product
        # is used here for simplicity). matmul broadcasts the token table over B.
        s = torch.matmul(query, tokens.t()) / self.emb_size ** 0.5  # (B, seq_len, n_tokens)
        attention_weight = self.softmax(s)  # (B, seq_len, n_tokens)
        out = torch.matmul(attention_weight, tokens)  # (B, seq_len, emb_size)
        out = torch.mean(out, 1)  # (B, seq_len, emb_size) -> (B, emb_size)

        return out, attention_weight

In [130]:
# Attend over the style tokens using the reference-encoder output computed above;
# build the token table on the same device as that output instead of hard-coding cuda.
style = StyleTokenLayer(device=ref_output.device).to(ref_output.device)
style_output, attention_weight = style(ref_output)
print(style_output.shape)

torch.Size([10, 256])


### GSTLayer

In [102]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

class GSTLayer(nn.Module):
    """Global Style Token layer = ReferenceEncoder + StyleTokenLayer.

    Args:
        n_mel: number of mel bins of the reference spectrograms.
        emb_size: style token size; the reference GRU uses emb_size // 2 units
            because StyleTokenLayer duplicates its query.
        n_tokens: number of style tokens.
        device: device the sub-modules are created on.
    """

    def __init__(
        self,
        n_mel=80,
        emb_size=256,
        n_tokens=10,
        device="cuda",
    ):
        super().__init__()
        self.ref = ReferenceEncoder(n_mel=n_mel, n_unit=emb_size // 2).to(device)
        self.style = StyleTokenLayer(emb_size=emb_size, n_tokens=n_tokens, device=device).to(device)

    def forward(self, seqs=None, token_num=0):
        """Compute the style embedding.

        At inference (``seqs is None``) no reference audio is used: the token
        ``token_num`` itself is returned as the embedding.

        Args:
            seqs: sequence of (T_i, n_mel) spectrograms, or None at inference.
            token_num: token index used at inference.

        Returns:
            Training: ``(style_out (B, 1, emb_size), style_weight (B, T', n_tokens))``.
            Inference: the single token embedding returned by StyleTokenLayer.
        """
        if seqs is None:  # inference
            return self.style(token_num=token_num)

        # Zero-pad to (B, T_max, n_mel). NOTE(review): zero frames may not be
        # "silence" for log-mel features, and the padding also enters the GRU.
        padded = pad_sequence(seqs, batch_first=True)
        ref_out, _ = self.ref(torch.unsqueeze(padded, 1))  # (B, 1, T_max, n_mel) input
        style_out, style_weight = self.style(ref_out)
        return torch.unsqueeze(style_out, 1), style_weight

In [103]:
gstLayer = GSTLayer()
gstLayer.cuda()
# Two reference spectrograms of different lengths; GSTLayer pads them internally.
# Named mel_list rather than `input` to avoid shadowing the builtin.
mel_list = [torch.zeros(15000, 80).to("cuda"), torch.zeros(16000, 80).to("cuda")]
out, att = gstLayer(mel_list)
print(out.shape)

torch.Size([2, 1, 256])


TODO（残作業）
StyleTokenLayer の出力を `expand_as` でエンコーダ出力と同じ形に拡張し、エンコーダの出力に加算してからデコーダへ入力するように修正する。

## TextEncoder

In [109]:
class ConvEncoder(nn.Module):
    """Character embedding followed by a stack of 1-D convolution blocks.

    Each block is Conv1d (no bias) -> BatchNorm1d -> ReLU -> Dropout(0.5). The
    convolutions model local dependencies between neighbouring characters.

    Args:
        num_vocab: vocabulary size; id 0 is the padding symbol.
        embed_dim: character embedding size.
        conv_layers: number of convolution blocks.
        conv_channels: channels of every convolution block.
        conv_kernel_size: kernel size; odd sizes keep the sequence length.
    """

    def __init__(
        self,
        num_vocab=40,
        embed_dim=256,
        conv_layers=3,
        conv_channels=256,
        conv_kernel_size=5,
    ):
        super().__init__()
        self.embed = nn.Embedding(num_vocab, embed_dim, padding_idx=0)

        blocks = []
        channels_in = embed_dim
        for _ in range(conv_layers):
            blocks.extend(
                [
                    nn.Conv1d(
                        channels_in,
                        conv_channels,
                        conv_kernel_size,
                        padding=(conv_kernel_size - 1) // 2,
                        bias=False,
                    ),
                    nn.BatchNorm1d(conv_channels),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                ]
            )
            channels_in = conv_channels
        self.convs = nn.Sequential(*blocks)

    def forward(self, seqs):
        """(B, T) symbol ids -> (B, T, conv_channels) features."""
        # Conv1d works on (B, C, T) whereas the embedding yields (B, T, C).
        features = self.embed(seqs).transpose(1, 2)
        return self.convs(features).transpose(1, 2)

In [64]:
# Show the module stack produced by the default hyperparameters.
ConvEncoder(num_vocab=40, embed_dim=256, conv_layers=3, conv_channels=256, conv_kernel_size=5)

ConvEncoder(
  (embed): Embedding(40, 256, padding_idx=0)
  (convs): Sequential(
    (0): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.5, inplace=False)
    (4): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.5, inplace=False)
    (8): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (9): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Dropout(p=0.5, inplace=False)
  )
)

In [133]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Encoder(ConvEncoder):
    """Text encoder: ConvEncoder followed by a one-layer bidirectional LSTM.

    Args:
        num_vocab, embed_dim, conv_layers, conv_channels, conv_kernel_size:
            forwarded to ConvEncoder.
        hidden_dim: output size of the BLSTM (hidden_dim // 2 per direction).
    """

    def __init__(
        self,
        num_vocab=40,
        embed_dim=512,
        hidden_dim=256,
        conv_layers=3,
        conv_channels=512,
        conv_kernel_size=5,
    ):
        super().__init__(
            num_vocab, embed_dim, conv_layers, conv_channels, conv_kernel_size
        )
        # Bidirectional LSTM models long-range dependencies
        self.blstm = nn.LSTM(
            conv_channels, hidden_dim // 2, 1, batch_first=True, bidirectional=True
        )

    def forward(self, seqs, in_lens):
        """Encode padded symbol sequences.

        Args:
            seqs: (B, T) padded symbol ids.
            in_lens: valid length of each sequence (tensor or list, any order).

        Returns:
            (B, max(in_lens), hidden_dim) encoder outputs, in the input order.
        """
        # Embedding + convolution stack from ConvEncoder: (B, T, conv_channels)
        out = super().forward(seqs)

        # pack_padded_sequence needs CPU lengths; enforce_sorted=False accepts
        # batches in any length order and restores that order when unpacking.
        lengths = torch.as_tensor(in_lens, dtype=torch.int64).cpu()
        out = pack_padded_sequence(out, lengths, batch_first=True, enforce_sorted=False)
        out, _ = self.blstm(out)
        out, _ = pad_packed_sequence(out, batch_first=True)
        return out

In [111]:
# Show the module stack produced by the default hyperparameters.
Encoder(num_vocab=40, embed_dim=512, hidden_dim=256, conv_layers=3, conv_channels=512, conv_kernel_size=5)

Encoder(
  (embed): Embedding(40, 512, padding_idx=0)
  (convs): Sequential(
    (0): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.5, inplace=False)
    (4): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.5, inplace=False)
    (8): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (9): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Dropout(p=0.5, inplace=False)
  )
  (blstm): LSTM(512, 128, batch_first=True, bidirectional=True)
)

## LocationSensitiveAttention

In [67]:
class LocationSensitiveAttention(nn.Module):
    """Location-sensitive attention as used in Tacotron 2.

    Energies are ``w(tanh(W(s) + V(h) + U(f)))``, where ``s`` is the decoder
    state, ``h`` the encoder outputs and ``f`` location features obtained by
    convolving the previous (cumulative) attention weights with ``F``.

    Args:
        encoder_dim: feature size of the encoder outputs.
        decoder_dim: size of the decoder state used as query.
        hidden_dim: size of the tanh layer of the energy computation.
        conv_channels: channels of the location convolution.
        conv_kernel_size: kernel size of the location convolution.
    """

    def __init__(
        self,
        encoder_dim=512,
        decoder_dim=1024,
        hidden_dim=128,
        conv_channels=32,
        conv_kernel_size=31,
    ):
        super().__init__()
        self.V = nn.Linear(encoder_dim, hidden_dim)
        self.W = nn.Linear(decoder_dim, hidden_dim, bias=False)
        self.U = nn.Linear(conv_channels, hidden_dim, bias=False)
        self.F = nn.Conv1d(
            1,
            conv_channels,
            conv_kernel_size,
            padding=(conv_kernel_size - 1) // 2,
            bias=False,
        )
        # NOTE: following the book's equation literally this would be bias=False,
        # but bias=True works fine in practice.
        self.w = nn.Linear(hidden_dim, 1)

    def forward(self, encoder_outs, src_lens, decoder_state, att_prev, mask=None):
        """Compute the context vector and attention weights for one step.

        Args:
            encoder_outs: (B, T_enc, encoder_dim) encoder outputs.
            src_lens: (B,) valid encoder lengths.
            decoder_state: (B, decoder_dim) query.
            att_prev: (B, T_enc) previous attention weights, or None on the first step.
            mask: optional bool (B, T_enc); True positions receive zero weight.

        Returns:
            (context (B, encoder_dim), attention weights (B, T_enc)).
        """
        if att_prev is None:
            # First step: uniform weights over the valid (non-padded) positions.
            valid = 1.0 - make_pad_mask(src_lens).to(
                device=decoder_state.device, dtype=decoder_state.dtype
            )
            att_prev = valid / src_lens.unsqueeze(-1).to(encoder_outs.device)

        # (B, T_enc) -> (B, 1, T_enc) -> (B, conv_channels, T_enc) -> (B, T_enc, conv_channels)
        location = self.F(att_prev.unsqueeze(1)).transpose(1, 2)

        # Energies, Eq. (9.13) of the book
        query = self.W(decoder_state).unsqueeze(1)
        hidden = torch.tanh(query + self.V(encoder_outs) + self.U(location))
        energies = self.w(hidden).squeeze(-1)

        if mask is not None:
            energies.masked_fill_(mask, -float("inf"))

        weights = F.softmax(energies, dim=1)

        # Weighted sum of the encoder outputs along the time axis
        context = (encoder_outs * weights.unsqueeze(-1)).sum(dim=1)

        return context, weights

In [68]:
# Show the layers produced by the default hyperparameters.
LocationSensitiveAttention(encoder_dim=512, decoder_dim=1024, hidden_dim=128, conv_channels=32, conv_kernel_size=31)

LocationSensitiveAttention(
  (V): Linear(in_features=512, out_features=128, bias=True)
  (W): Linear(in_features=1024, out_features=128, bias=False)
  (U): Linear(in_features=32, out_features=128, bias=False)
  (F): Conv1d(1, 32, kernel_size=(31,), stride=(1,), padding=(15,), bias=False)
  (w): Linear(in_features=128, out_features=1, bias=True)
)

## Pre-Net

In [69]:
class Prenet(nn.Module):
    """Decoder Pre-Net: stacked Linear -> ReLU layers with always-on dropout.

    Args:
        in_dim: input feature size (the previous output frame).
        layers: number of Linear/ReLU layers.
        hidden_dim: output size of every layer.
        dropout: dropout probability applied once after each layer.
    """

    def __init__(self, in_dim, layers=2, hidden_dim=256, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        prenet = nn.ModuleList()
        for layer in range(layers):
            prenet += [
                nn.Linear(in_dim if layer == 0 else hidden_dim, hidden_dim),
                nn.ReLU(),
            ]
        self.prenet = nn.Sequential(*prenet)

    def forward(self, x):
        """(..., in_dim) -> (..., hidden_dim)."""
        for layer in self.prenet:
            x = layer(x)
            # Dropout once per layer, after its activation. It was previously also
            # applied after the Linear, i.e. twice per layer. It is applied at both
            # training and inference time, hence training=True.
            if isinstance(layer, nn.ReLU):
                x = F.dropout(x, self.dropout, training=True)
        return x

In [70]:
# Show the Pre-Net for 80-bin mel frames with default settings.
Prenet(in_dim=80)

Prenet(
  (prenet): Sequential(
    (0): Linear(in_features=80, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
  )
)

## ZoneOutCell

In [71]:
class ZoneOutCell(nn.Module):
    """Zoneout wrapper around an LSTMCell-like cell.

    The wrapped cell must be called as ``cell(inputs, (h, c))``, return
    ``(h, c)``, and expose ``hidden_size``. In training mode each element of the
    new hidden/cell state is replaced by its previous value with probability
    ``zoneout``. In eval mode the expectation
    ``zoneout * old + (1 - zoneout) * new`` is used instead.
    """

    def __init__(self, cell, zoneout=0.1):
        super().__init__()
        self.cell = cell
        self.hidden_size = cell.hidden_size
        self.zoneout = zoneout

    def forward(self, inputs, hidden):
        """Run one cell step on ``inputs`` with state ``hidden = (h, c)``; returns ``(h, c)``."""
        next_hidden = self.cell(inputs, hidden)
        next_hidden = self._zoneout(hidden, next_hidden, self.zoneout)
        return next_hidden

    def _zoneout(self, h, next_h, prob):
        # Apply zoneout independently to the hidden state and the cell state.
        h_0, c_0 = h
        h_1, c_1 = next_h
        h_1 = self._apply_zoneout(h_0, h_1, prob)
        c_1 = self._apply_zoneout(c_0, c_1, prob)
        return h_1, c_1

    def _apply_zoneout(self, h, next_h, prob):
        if self.training:
            # mask == 1 keeps the previous value (same dtype/device as h).
            mask = torch.empty_like(h).bernoulli_(prob)
            return mask * h + (1 - mask) * next_h
        else:
            return prob * h + (1 - prob) * next_h

## Decoder

In [72]:
class Decoder(nn.Module):
    """Autoregressive Tacotron 2 decoder with location-sensitive attention.

    Each step attends over the encoder outputs (queried with the previous hidden
    state of the first LSTM layer). It then feeds [context; Pre-Net(previous frame)]
    through a stack of zoneout LSTMCells. Finally it projects
    [last LSTM output; context] to ``reduction_factor`` mel frames and
    ``reduction_factor`` stop-token logits.

    Args:
        encoder_hidden_dim: feature size of ``encoder_outs``.
        out_dim: number of mel bins per output frame.
        layers: number of stacked LSTMCells.
        hidden_dim: hidden size of each LSTMCell (also the attention query size).
        prenet_layers: number of Pre-Net layers.
        prenet_hidden_dim: Pre-Net hidden size.
        prenet_dropout: Pre-Net dropout probability.
        zoneout: zoneout probability of every LSTMCell.
        reduction_factor: number of frames predicted per decoder step.
        attention_hidden_dim: hidden size of the attention energy computation.
        attention_conv_channels: channels of the attention location convolution.
        attention_conv_kernel_size: kernel size of the attention location convolution.
    """

    def __init__(
        self,
        encoder_hidden_dim=512,
        out_dim=80,
        layers=2,
        hidden_dim=1024,
        prenet_layers=2,
        prenet_hidden_dim=256,
        prenet_dropout=0.5,
        zoneout=0.1,
        reduction_factor=1,
        attention_hidden_dim=128,
        attention_conv_channels=32,
        attention_conv_kernel_size=31,
    ):
        super().__init__()
        self.out_dim = out_dim

        # Attention mechanism
        self.attention = LocationSensitiveAttention(
            encoder_hidden_dim,
            hidden_dim,
            attention_hidden_dim,
            attention_conv_channels,
            attention_conv_kernel_size,
        )
        self.reduction_factor = reduction_factor

        # Pre-Net applied to the previous output frame
        self.prenet = Prenet(out_dim, prenet_layers, prenet_hidden_dim, prenet_dropout)

        # Unidirectional LSTM stack; layer 0 takes [context; Pre-Net output]
        self.lstm = nn.ModuleList()
        for layer in range(layers):
            lstm = nn.LSTMCell(
                encoder_hidden_dim + prenet_hidden_dim if layer == 0 else hidden_dim,
                hidden_dim,
            )
            lstm = ZoneOutCell(lstm, zoneout)
            self.lstm += [lstm]

        # Output projections: mel frames and stop-token logits
        proj_in_dim = encoder_hidden_dim + hidden_dim
        self.feat_out = nn.Linear(proj_in_dim, out_dim * reduction_factor, bias=False)
        self.prob_out = nn.Linear(proj_in_dim, reduction_factor)

    def _zero_state(self, hs):
        # Zero tensor of shape (B, hidden_size) with the dtype/device of ``hs``.
        init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size)
        return init_hs

    def forward(self, encoder_outs, in_lens, decoder_targets=None):
        """Decode mel frames from encoder outputs.

        Args:
            encoder_outs: (B, T_enc, encoder_hidden_dim) encoder outputs.
            in_lens: valid encoder length of each utterance; used to build the
                padding mask and the initial attention weights.
            decoder_targets: (B, Lmax, out_dim) target frames for teacher forcing,
                or None for free-running inference.

        Returns:
            outs: (B, out_dim, steps * reduction_factor) predicted frames.
            logits: (B, steps * reduction_factor) stop-token logits (pre-sigmoid).
            att_ws: (B, steps, T_enc) attention weights of every decoder step.
        """
        is_inference = decoder_targets is None

        # Keep every reduction_factor-th target frame (the last one of each group)
        # (B, Lmax, out_dim) ->  (B, Lmax/r, out_dim)
        if self.reduction_factor > 1 and not is_inference:
            decoder_targets = decoder_targets[
                :, self.reduction_factor - 1 :: self.reduction_factor
            ]

        # Number of decoder steps.
        # At inference the upper bound is set empirically from the encoder length
        # (10 steps per encoder frame).
        if is_inference:
            max_decoder_time_steps = int(encoder_outs.shape[1] * 10.0)
        else:
            max_decoder_time_steps = decoder_targets.shape[1]

        # Mask for the zero-padded encoder positions (True = padding)
        mask = make_pad_mask(in_lens).to(encoder_outs.device)

        # Initialize every LSTM layer's hidden and cell state with zeros
        h_list, c_list = [], []
        for _ in range(len(self.lstm)):
            h_list.append(self._zero_state(encoder_outs))
            c_list.append(self._zero_state(encoder_outs))

        # First decoder input: an all-zero "go" frame
        go_frame = encoder_outs.new_zeros(encoder_outs.size(0), self.out_dim)
        prev_out = go_frame

        # Attention weights of the previous step (None on the first step)
        prev_att_w = None

        # Main loop: one iteration per decoder step
        outs, logits, att_ws = [], [], []
        t = 0
        while True:
            # Context vector and attention weights
            att_c, att_w = self.attention(
                encoder_outs, in_lens, h_list[0], prev_att_w, mask
            )

            # Pre-Net on the previous output frame
            prenet_out = self.prenet(prev_out)

            # LSTM stack: layer 0 consumes [context; Pre-Net output], each later
            # layer consumes the hidden state of the layer below
            xs = torch.cat([att_c, prenet_out], dim=1)
            h_list[0], c_list[0] = self.lstm[0](xs, (h_list[0], c_list[0]))
            for i in range(1, len(self.lstm)):
                h_list[i], c_list[i] = self.lstm[i](
                    h_list[i - 1], (h_list[i], c_list[i])
                )
            # Outputs: mel frames (B, out_dim, r) and stop-token logits (B, r)
            hcs = torch.cat([h_list[-1], att_c], dim=1)
            outs.append(self.feat_out(hcs).view(encoder_outs.size(0), self.out_dim, -1))
            logits.append(self.prob_out(hcs))
            att_ws.append(att_w)

            # Next decoder input
            if is_inference:
                prev_out = outs[-1][:, :, -1]  # (B, out_dim): last predicted frame
            else:
                # Teacher forcing
                prev_out = decoder_targets[:, t, :]

            # Cumulative attention weights (fed to the location features)
            prev_att_w = att_w if prev_att_w is None else prev_att_w + att_w

            t += 1
            # Stop conditions: step limit, or (inference only) any utterance in
            # the batch predicting a stop probability >= 0.5
            if t >= max_decoder_time_steps:
                break
            if is_inference and (torch.sigmoid(logits[-1]) >= 0.5).any():
                break

        # Concatenate the per-step outputs
        logits = torch.cat(logits, dim=1)  # (B, Lmax)
        outs = torch.cat(outs, dim=2)  # (B, out_dim, Lmax)
        att_ws = torch.stack(att_ws, dim=1)  # (B, Lmax, Tmax)

        # NOTE(review): outs is already (B, out_dim, Lmax) here, so this view looks like a no-op
        if self.reduction_factor > 1:
            outs = outs.view(outs.size(0), self.out_dim, -1)  # (B, out_dim, Lmax)

        return outs, logits, att_ws

## Post-Net

In [73]:
class Postnet(nn.Module):
    """Tacotron 2 Post-Net: stacked Conv1d + BatchNorm1d layers predicting a residual.

    Every layer except the last is followed by tanh; every layer (including the
    last) is followed by dropout. The first layer maps in_dim -> channels and the
    last maps channels -> in_dim, so the output has the same channel count as the input.

    Args:
        in_dim: number of input/output channels (mel bins).
        layers: number of convolution layers.
        channels: channels of the intermediate layers.
        kernel_size: convolution kernel size.
        dropout: dropout probability after every layer.
    """

    def __init__(self, in_dim=80, layers=5, channels=512, kernel_size=5, dropout=0.5):
        super().__init__()
        modules = []
        last = layers - 1
        for idx in range(layers):
            is_last = idx == last
            conv_in = channels if idx else in_dim
            conv_out = in_dim if is_last else channels
            modules.append(
                nn.Conv1d(
                    conv_in,
                    conv_out,
                    kernel_size,
                    stride=1,
                    padding=(kernel_size - 1) // 2,
                    bias=False,
                )
            )
            modules.append(nn.BatchNorm1d(conv_out))
            if not is_last:
                modules.append(nn.Tanh())
            modules.append(nn.Dropout(dropout))
        self.postnet = nn.Sequential(*modules)

    def forward(self, xs):
        """(B, in_dim, T) -> (B, in_dim, T) residual for odd kernel sizes."""
        return self.postnet(xs)

In [74]:
# Show the layer stack produced by the default hyperparameters.
Postnet(in_dim=80, layers=5, channels=512, kernel_size=5, dropout=0.5)

Postnet(
  (postnet): Sequential(
    (0): Conv1d(80, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Tanh()
    (3): Dropout(p=0.5, inplace=False)
    (4): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Tanh()
    (7): Dropout(p=0.5, inplace=False)
    (8): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (9): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Tanh()
    (11): Dropout(p=0.5, inplace=False)
    (12): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (13): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Tanh()
    (15): Dropout(p=0.5, inplace=False)
    (16): Conv1d(512, 80, kernel_size=(5,), strid

残りの作業: 上で定義した各モジュール（Encoder・GSTLayer・Decoder・Postnet）をつなげてモデル全体を構成する。


### GST-Tacotronモデルの定義

In [138]:
from torch.nn.utils.rnn import pad_sequence

class GST_Tacotron(nn.Module):
    """Tacotron 2 conditioned on Global Style Tokens.

    The GST embedding (one vector per utterance) is added to every encoder
    output frame before decoding.
    """

    def __init__(self
    ):
        super().__init__()
        self.encoder = Encoder()
        self.gst = GSTLayer()
        self.decoder = Decoder()
        self.postnet = Postnet()

    def forward(self, seq, in_lens, decoder_targets, token_num=0):
        """Run the full model.

        Args:
            seq: (B, T_text) padded symbol ids.
            in_lens: valid length of each text sequence.
            decoder_targets: sequence of (T_i, n_mel) target spectrograms (lengths may
                differ), or None for inference, in which case style token
                ``token_num`` is used.
            token_num: style token index used at inference.

        Returns:
            outs (B, T, n_mel), outs_fine (B, T, n_mel), encoder_outs,
            stop-token logits, decoder attention weights, and GST attention
            weights (None at inference).
        """
        # Follow the device of the model parameters instead of hard-coding "cuda".
        device = next(self.parameters()).device
        seq = seq.to(device)
        # Latent representation of the text
        encoder_outs = self.encoder(seq, in_lens)

        if decoder_targets is None:
            # Inference: GSTLayer returns the embedding of a single style token.
            gst_outs = self.gst(token_num=token_num)
            gst_att_ws = None
            targets = None
        else:
            decoder_targets = [target.to(device) for target in decoder_targets]
            # Latent style representation of the reference audio
            gst_outs, gst_att_ws = self.gst(decoder_targets)
            # Zero-pad so targets of different lengths can be batched
            targets = pad_sequence(decoder_targets, batch_first=True)

        # Add the style embedding to every encoder frame (out of place rather than
        # mutating the encoder output).
        encoder_outs = encoder_outs + gst_outs.expand_as(encoder_outs)
        # NOTE: Decoder() expects 512-dim encoder features (encoder_hidden_dim=512)
        # while Encoder() outputs 256, so the features are duplicated here.
        # Decoder(encoder_hidden_dim=256) would avoid the duplication.
        encoder_outs = encoder_outs.repeat(1, 1, 2)

        # Mel spectrogram and stop-token prediction
        outs, logits, att_ws = self.decoder(encoder_outs, in_lens, targets)

        # Post-Net predicts a residual refinement of the mel spectrogram
        outs_fine = outs + self.postnet(outs)

        # (B, C, T) -> (B, T, C)
        outs = outs.transpose(2, 1)
        outs_fine = outs_fine.transpose(2, 1)

        return outs, outs_fine, encoder_outs, logits, att_ws, gst_att_ws

    def inference(self, seq, token_num=0):
        """Synthesize from unpadded symbol ids ((T,) or (B, T)) using style token ``token_num``."""
        seq = seq.unsqueeze(0) if len(seq.shape) == 1 else seq
        # One length per row; kept on the CPU because pack_padded_sequence
        # requires CPU lengths.
        in_lens = torch.full((seq.shape[0],), seq.shape[-1], dtype=torch.long)

        return self.forward(seq, in_lens, None, token_num=token_num)

### utils

In [136]:
def pad_1d(x, max_len, constant_values=0):
    """Right-pad a 1-D sequence to ``max_len``.

    Args:
        x (array-like): 1-D sequence to pad (e.g. a list of symbol ids).
        max_len (int): target length; must be >= len(x).
        constant_values (int, optional): value to pad with. Default: 0

    Returns:
        numpy.ndarray: padded array of length ``max_len``.

    Raises:
        ValueError: if ``len(x)`` exceeds ``max_len``.
    """
    pad_width = max_len - len(x)
    if pad_width < 0:
        raise ValueError(
            f"cannot pad a sequence of length {len(x)} to max_len={max_len}"
        )
    return np.pad(
        x,
        (0, pad_width),
        mode="constant",
        constant_values=constant_values,
    )


In [77]:
def get_dummy_input():
    """Return a padded batch of two dummy sentences and their lengths.

    Returns:
        seqs: (2, max_len) tensor of symbol ids, padded with pad_1d.
        in_lens: (2,) int64 tensor of unpadded lengths.
    """
    # Two arbitrary sentences, i.e. a batch size of 2
    texts = ("What is your favorite language?", "Hello world.")
    id_seqs = [text_to_sequence(text) for text in texts]
    lengths = [len(ids) for ids in id_seqs]
    in_lens = torch.tensor(lengths, dtype=torch.long)
    longest = max(lengths)
    seqs = torch.stack([torch.from_numpy(pad_1d(ids, longest)) for ids in id_seqs])
    return seqs, in_lens

In [92]:
def get_dummy_inout(device=None):
    """Dummy batch of text inputs plus decoder and stop-token targets.

    Args:
        device: device for the targets; defaults to "cuda" when available,
            otherwise "cpu".

    Returns:
        seqs, in_lens: as returned by get_dummy_input().
        decoder_targets: list of two (120, 80) all-ones mel-spectrogram targets.
        stop_tokens: (2, 120) binary stop-token targets on the same device.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    seqs, in_lens = get_dummy_input()

    # Targets for the decoder output (mel spectrogram)
    decoder_targets = [torch.ones(120, 80, device=device) for _ in range(2)]

    # Stop-token targets: the prediction is a probability, but the targets are
    # binary labels; 1 means the decoder output is complete. Kept on the same
    # device as the mel targets so a loss can combine them directly.
    stop_tokens = torch.zeros(2, 120, device=device)
    stop_tokens[:, -1:] = 1.0

    return seqs, in_lens, decoder_targets, stop_tokens

In [79]:
# 語彙の定義
characters = "abcdefghijklmnopqrstuvwxyz!'(),-.:;? "
# その他特殊記号
extra_symbols = [
    "^",  # 文の先頭を表す特殊記号 <SOS>
    "$",  # 文の末尾を表す特殊記号 <EOS>
]
_pad = "~"

# NOTE: パディングを 0 番目に配置
symbols = [_pad] + extra_symbols + list(characters)

# 文字列⇔数値の相互変換のための辞書
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

In [80]:
def text_to_sequence(text):
    """Convert text to symbol ids wrapped in <SOS> ("^") and <EOS> ("$").

    The text is lower-cased first (case is not distinguished, for simplicity).
    Characters outside the vocabulary raise KeyError.
    """
    body = [_symbol_to_id[ch] for ch in text.lower()]
    return [_symbol_to_id["^"], *body, _symbol_to_id["$"]]


def sequence_to_text(seq):
    """Map symbol ids back to a list of symbols (not joined into a string)."""
    return list(map(_id_to_symbol.__getitem__, seq))

In [113]:
def make_pad_mask(lengths, maxlen=None):
    """Make a boolean mask that is True at padded positions.

    Args:
        lengths (list, tuple, torch.Tensor or numpy.ndarray): valid length of
            each sequence.
        maxlen (int, optional): mask width. If None, use max(lengths).

    Returns:
        torch.BoolTensor: (B, maxlen) mask with mask[b, t] True iff t >= lengths[b].
    """
    if hasattr(lengths, "tolist"):  # torch.Tensor / numpy.ndarray (any device)
        lengths = lengths.tolist()
    lengths = [int(length) for length in lengths]
    if maxlen is None:
        maxlen = max(lengths)

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)  # (maxlen,)
    length_col = torch.tensor(lengths, dtype=torch.int64).unsqueeze(-1)  # (B, 1)
    # Broadcasting (1, maxlen) against (B, 1) yields the (B, maxlen) comparison.
    mask = seq_range.unsqueeze(0) >= length_col

    return mask

### トイモデル

In [141]:
# Fixed seed so weights and (always-on) dropout masks are reproducible on re-run.
torch.manual_seed(0)
seqs, in_lens, decoder_targets, stop_tokens = get_dummy_inout()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GST_Tacotron().to(device)
outs, outs_fine, encoder_outs, logits, att_ws, gst_att_ws = model(seqs, in_lens, decoder_targets)
assert outs.shape == outs_fine.shape

print("入力のサイズ:", tuple(seqs.shape))
print("エンコーダの出力のサイズ:", tuple(encoder_outs.shape))
print("デコーダの出力のサイズ:", tuple(outs.shape))
print("Stop token のサイズ:", tuple(logits.shape))
print("アテンション重みのサイズ:", tuple(att_ws.shape))

torch.Size([2, 33, 512])
入力のサイズ: (2, 33)
エンコーダの出力のサイズ: (2, 33, 512)
デコーダの出力のサイズ: (2, 120, 80)
Stop token のサイズ: (2, 120)
アテンション重みのサイズ: (2, 120, 33)
