In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
import numpy as np

In [None]:
import japanize_matplotlib

# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings raised by libraries.
import warnings
warnings.simplefilter('ignore')

# Matplotlib defaults: font size 14, 6x6 figures, grid lines switched on
plt.rcParams.update({
    'font.size': 14,
    'figure.figsize': (6, 6),
    'axes.grid': True,
})

# NumPy printing: no scientific notation, 5 digits of precision
np.set_printoptions(precision=5, suppress=True)

In [None]:
import os
from ttslearn.env import is_colab
from os.path import exists

# Move into the recipe directory; the relative path depends on where the
# notebook was started (a local "notebooks" folder or Google Colab).
cwd = os.getcwd()
if cwd.endswith("notebooks"):
    _recipe_dir = "../recipes/transtron/"
elif is_colab():
    _recipe_dir = "recipes/transtron/"
else:
    # Any other location: stay where we are
    _recipe_dir = None

if _recipe_dir is not None:
    os.chdir(_recipe_dir)

In [None]:
class LayerNorm(nn.LayerNorm):
    """LayerNorm that normalises in float32 and returns the caller's dtype."""

    def forward(self, x):
        """Normalise ``x`` in float32, then cast the result back to ``x.dtype``."""
        caller_dtype = x.dtype
        normalised = super().forward(x.float())
        return normalised.to(caller_dtype)


class Linear(nn.Linear):
    """Linear layer whose weight and bias are cast to the input dtype on each call."""

    def forward(self, x):
        """Apply the affine map using parameters converted to ``x.dtype``."""
        weight = self.weight.to(x.dtype)
        bias = self.bias
        if bias is not None:
            bias = bias.to(x.dtype)
        return F.linear(x, weight, bias)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head dot-product attention usable as self- or cross-attention.

    Args:
        n_state: Size of the feature dimension of inputs and outputs.
        n_head: Number of heads the feature dimension is split into.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(self, x, xa=None, mask=None):
        """Attend from ``x`` to ``xa`` (or to ``x`` itself when ``xa`` is None).

        Returns:
            Tuple of the projected attention output and the detached
            pre-softmax attention scores.
        """
        # Keys and values come from `xa` for cross-attention, otherwise from `x`
        source = x if xa is None else xa

        q = self.query(x)
        k = self.key(source)
        v = self.value(source)

        attended, scores = self.qkv_attention(q, k, v, mask)
        return self.out(attended), scores

    def qkv_attention(self, q, k, v, mask=None):
        """Split q/k/v into heads, apply softmax attention and merge the heads.

        ``mask`` (if given) is added to the scores after being cropped to its
        top-left ``(n_ctx, n_ctx)`` corner, where ``n_ctx`` is the query length.
        """
        _, n_ctx, n_state = q.shape
        # Applied to both q and k, so their product is scaled by head_dim ** -0.5
        scale = (n_state // self.n_head) ** -0.25

        def split_heads(t):
            # (batch, time, n_state) -> (batch, time, n_head, head_dim)
            return t.view(*t.shape[:2], self.n_head, -1)

        q = split_heads(q).permute(0, 2, 1, 3) * scale
        k = split_heads(k).permute(0, 2, 3, 1) * scale
        v = split_heads(v).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()

        weights = F.softmax(qk, dim=-1).to(q.dtype)
        merged = (weights @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
        return merged, qk.detach()

In [None]:
# Smoke test: self-attention on a dummy (2, 130, 512) input
tmp_self_attn = MultiHeadAttention(n_state=512, n_head=2)

input_self_attn = torch.ones(2, 130, 512)
# Additive mask; the attention layer crops it to the query length
input_self_attn_mask = torch.ones(512, 512)

a = tmp_self_attn(input_self_attn, mask=input_self_attn_mask)
print(a)

In [None]:
# Smoke test: cross-attention from a (2, 130, 512) query to a (2, 1300, 512) memory
tmp_cross_attn = MultiHeadAttention(n_state=512, n_head=2)

input_corss_attn_a = torch.ones(2, 130, 512)
input_cross_attn_b = torch.ones(2, 1300, 512)
input_cross_attn_mask = None  # defined for reference; no mask is passed below

b = tmp_cross_attn(input_corss_attn_a, xa=input_cross_attn_b)

print(b)

In [None]:
class ResidualAttentionBlock(nn.Module):
    """Pre-norm residual block: self-attention, optional cross-attention, then an MLP.

    Args:
        n_state: Feature dimension of the block.
        n_head: Number of attention heads.
        cross_attention: When True, a cross-attention sub-layer over ``xa`` is added.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.cross_attn_ln = LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        # Position-wise feed-forward network with a 4x wider hidden layer
        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_mlp),
            nn.ReLU(),
            nn.Linear(n_mlp, n_state),
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(self, x, xa, mask=None):
        """Apply the block to ``x``.

        ``mask`` is forwarded to the self-attention only; the cross-attention
        over ``xa`` (when enabled) is always called with ``mask=None``.
        """
        self_attended = self.attn(self.attn_ln(x), mask=mask)[0]
        x = x + self_attended
        if self.cross_attn is not None:
            cross_attended = self.cross_attn(self.cross_attn_ln(x), xa, mask=None)[0]
            x = x + cross_attended
        return x + self.mlp(self.mlp_ln(x))

In [None]:
# Smoke test: one encoder-style block (self-attention only) on dummy data
encoder_layer = ResidualAttentionBlock(512, 2, cross_attention=False)
#encoder_layer.eval()

x = torch.ones(2, 130, 512)
mask = torch.ones(130, 130)  # defined for reference; the call below passes no mask

a = encoder_layer(x, x, mask=None)

print(a)

In [None]:
class Encoder(nn.Module):
    """Text encoder: embedding -> 1-D conv stack -> learned positions -> attention blocks.

    Args:
        num_vocab: Size of the input symbol vocabulary (index 0 is padding).
        embed_dim: Embedding size; also the width of the attention blocks.
        conv_layers: Number of Conv1d/BatchNorm/ReLU/Dropout groups.
        conv_channels: Output channels of every convolution. The conv output is
            added to ``embed_dim``-wide position embeddings, so it has to equal
            ``embed_dim`` whenever ``conv_layers > 0``.
        conv_kernel_size: Kernel size of the convolutions.
        num_enc_layers: Number of residual attention blocks.
        num_heads: Attention heads per block.
        enc_dropout_rate: Dropout applied after adding the position embeddings.
        conv_dropout_rate: Dropout inside the convolution stack.
        input_maxlen: Number of learned positions, i.e. the longest accepted input.
        ffn_dim: Accepted for interface compatibility; not used by this class.
    """

    def __init__(
        self,
        num_vocab=51,
        embed_dim=512,
        conv_layers=3,
        conv_channels=512,
        conv_kernel_size=5,
        num_enc_layers = 3,
        num_heads = 2,
        enc_dropout_rate = 0.1,
        conv_dropout_rate = 0.1,
        input_maxlen = 300,
        ffn_dim = 1024
    ):
        super(Encoder, self).__init__()
        # Symbol embedding and learned absolute position embedding
        self.embed = nn.Embedding(num_vocab, embed_dim, padding_idx=0)
        self.pos_emb = nn.Embedding(input_maxlen, embed_dim)

        # Stack of 1-D convolutions: models local temporal context
        convs = []
        for layer in range(conv_layers):
            # Only the first convolution sees the embedding; later ones see conv output
            in_channels = embed_dim if layer == 0 else conv_channels
            convs += [
                nn.Conv1d(
                    in_channels,
                    conv_channels,
                    conv_kernel_size,
                    padding=(conv_kernel_size - 1) // 2,
                    bias=False,  # redundant: BatchNorm follows directly
                ),
                nn.BatchNorm1d(conv_channels),
                nn.ReLU(),
                nn.Dropout(conv_dropout_rate),
            ]
        self.convs = nn.Sequential(*convs)

        self.blocks = nn.ModuleList(
            [ResidualAttentionBlock(embed_dim, num_heads) for _ in range(num_enc_layers)]
        )
        self.input_maxlen = input_maxlen
        self.dropout = nn.Dropout(p=enc_dropout_rate)
        self.num_enc_layers = num_enc_layers

    def forward(self, x, in_lens):
        """Encode a batch of symbol ids.

        Args:
            x: Long tensor of symbol ids, indexed as (batch, time).
            in_lens: Input lengths. Not used here (no padding mask is applied).

        Returns:
            Tensor of shape (batch, time, embed_dim).

        Raises:
            ValueError: If the sequence is longer than ``input_maxlen``.
        """
        emb = self.embed(x)
        # Conv1d expects (batch, channels, time); the embedding is (batch, time, channels)
        out = self.convs(emb.transpose(1, 2)).transpose(1, 2)

        seq_len = out.size(1)
        if seq_len > self.input_maxlen:
            raise ValueError(
                "input length {} exceeds input_maxlen {}".format(seq_len, self.input_maxlen)
            )
        # Look up only the positions that are needed, on the same device as the features
        positions = torch.arange(seq_len, device=out.device)
        x = self.dropout(out + self.pos_emb(positions))

        for block in self.blocks:
            x = block(x, x, mask = None)

        return x  # (batch_size, input_seq_len, d_model)

In [None]:
# Encoder instance used by the smoke test in the next cell
tmp_encoder = Encoder(
    num_vocab=51, embed_dim=512,
    conv_layers=3, conv_channels=512, conv_kernel_size=5,
    num_enc_layers=3, num_heads=2,
    enc_dropout_rate=0.1, conv_dropout_rate=0.1,
    input_maxlen=300, ffn_dim=1024,
)

In [None]:
#tmp_encoder.eval()
# Dummy batch: 2 sequences of 130 symbol ids each
a = torch.ones( (2, 130), dtype=torch.long)

# One length per sequence (previously the batch size was appended for every row)
in_lens = [len(seq) for seq in a]

b = tmp_encoder( a, in_lens )
print( " size of b:{}".format( b.size()))
print( b )

In [None]:
class Prenet(nn.Module):
    """Stack of ``layers`` Linear -> ReLU -> Dropout groups applied to decoder inputs.

    Dropout is implemented with ``nn.Dropout`` modules, so it is active only
    while the module is in training mode.

    Args:
        in_dim: Size of the input features.
        layers: Number of Linear/ReLU/Dropout groups.
        hidden_dim: Output size of every linear layer.
        dropout: Dropout probability.
    """

    def __init__(self, in_dim, layers=2, hidden_dim=256, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        stack = []
        width = in_dim
        for _ in range(layers):
            stack.extend([nn.Linear(width, hidden_dim), nn.ReLU(), nn.Dropout(dropout)])
            width = hidden_dim
        self.prenet = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through every layer of the stack in order."""
        return self.prenet(x)

In [None]:
# Decoder-style block: self-attention plus cross-attention
decoder_layer = ResidualAttentionBlock(512, 2, cross_attention=True)

In [None]:
#decoder_layer.eval()

# Smoke test: the same dummy tensor serves as decoder input and as memory
x = torch.ones(2, 1300, 512)
mask = torch.ones(1300, 1300)

a = decoder_layer(x, x, mask=mask)

print(a)

In [None]:
class Decoder(nn.Module):
    """Autoregressive decoder: Prenet -> learned positions -> causal attention blocks.

    Args:
        decoder_hidden_dim: Width of the attention blocks and position embeddings.
        out_dim: Number of output feature channels per frame.
        layers: Number of residual attention blocks (each with cross-attention).
        prenet_layers: Number of layers in the Prenet.
        prenet_hidden_dim: Prenet output size. It is added to the position
            embeddings, so it has to equal ``decoder_hidden_dim``.
        prenet_dropout: Dropout probability inside the Prenet.
        ffn_dim: Accepted for interface compatibility; not used by this class.
        dropout_rate: Accepted for interface compatibility; not used by this class.
        dec_input_maxlen: Number of learned positions, i.e. the longest accepted target.
        num_heads: Attention heads per block.
    """

    def __init__(
        self,
        decoder_hidden_dim=512,
        out_dim=80,
        layers=4,
        prenet_layers=2,
        prenet_hidden_dim=512,
        prenet_dropout=0.5,
        ffn_dim=1024,
        dropout_rate = 0.1,
        dec_input_maxlen=3000,
        num_heads = 2
    ):
        super().__init__()
        self.out_dim = out_dim
        self.num_heads = num_heads

        # Prenet applied to the previous output frames
        self.prenet = Prenet(out_dim, prenet_layers, prenet_hidden_dim, prenet_dropout)

        # Attention blocks with cross-attention over the encoder outputs
        self.blocks = nn.ModuleList(
            [ResidualAttentionBlock(decoder_hidden_dim, num_heads, cross_attention=True) for _ in range(layers)]
        )

        self.pos_emb = nn.Embedding(dec_input_maxlen, decoder_hidden_dim)

        # Output projections: feature frame and stop-token logit
        self.feat_out = nn.Linear(decoder_hidden_dim, out_dim, bias=False)
        self.prob_out = nn.Linear(decoder_hidden_dim, 1)

        self.dec_input_maxlen = dec_input_maxlen
        self.layers = layers

    def forward(self, encoder_outs, in_lens, decoder_targets=None):
        """Predict output frames and stop-token logits from teacher-forced inputs.

        Args:
            encoder_outs: Encoder output passed as memory to every block.
            in_lens: Input lengths. Not used here (no padding mask is applied).
            decoder_targets: Decoder input frames, indexed as (batch, time, out_dim).

        Returns:
            Tuple ``(outs, logits, attention_weights)`` where ``outs`` is
            (batch, out_dim, time), ``logits`` is (batch, time) and
            ``attention_weights`` is an empty dict (nothing is collected).

        Raises:
            ValueError: If ``decoder_targets`` is None or longer than ``dec_input_maxlen``.
        """
        if decoder_targets is None:
            raise ValueError("decoder_targets is required; pass the previous output frames")

        prenet_out = self.prenet(decoder_targets)
        T = prenet_out.size(1)
        if T > self.dec_input_maxlen:
            raise ValueError(
                "target length {} exceeds dec_input_maxlen {}".format(T, self.dec_input_maxlen)
            )

        # Look up only the needed positions, on the same device as the features
        positions = torch.arange(T, device=prenet_out.device)
        x = prenet_out + self.pos_emb(positions)

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere.
        # It is identical for every block, so build it once.
        look_ahead_mask = torch.full((T, T), float("-inf"), device=x.device).triu_(1)
        for block in self.blocks:
            x = block(x, encoder_outs, mask=look_ahead_mask)

        # (batch, time, out_dim) -> (batch, out_dim, time)
        outs = self.feat_out(x).permute(0, 2, 1)
        logits = self.prob_out(x).squeeze(2)

        attention_weights = {}
        return outs, logits, attention_weights



In [None]:
# Decoder instance used by the smoke test in the next cell
tmp_decoder = Decoder(
    decoder_hidden_dim=512, out_dim=80, layers=2,
    prenet_layers=2, prenet_hidden_dim=512, prenet_dropout=0.5,
    ffn_dim=1024, dropout_rate=0.1,
    dec_input_maxlen=3000, num_heads=2,
)


In [None]:
#tmp_decoder.eval()

# Dummy encoder output: 2 utterances, 130 steps, 512 features
a = torch.ones( (2, 130, 512), dtype=torch.float)

# One length per utterance (previously the batch size was appended for every row)
in_lens = [len(seq) for seq in a]

# Dummy teacher-forcing input: 1300 frames of 80 channels
b = torch.ones( (2, 1300, 80 ), dtype=torch.float )

c, d, e = tmp_decoder( a, in_lens, b )

print( "size of c:{}".format( c.size()))
print( c )

In [None]:
class Postnet(nn.Module):
    """Convolutional post-network applied to (batch, channels, time) features.

    Every layer is Conv1d -> BatchNorm1d -> [Tanh] -> Dropout; the Tanh is
    omitted after the last layer, which maps back to ``in_dim`` channels.

    Args:
        in_dim: Number of input (and output) channels.
        layers: Number of convolution layers.
        channels: Channel count of the intermediate layers.
        kernel_size: Convolution kernel size.
        dropout: Dropout probability after every layer.
    """

    def __init__(
        self,
        in_dim=80,
        layers=5,
        channels=512,
        kernel_size=5,
        dropout=0.5,
    ):
        super().__init__()
        final = layers - 1
        modules = []
        for idx in range(layers):
            n_in = in_dim if idx == 0 else channels
            n_out = in_dim if idx == final else channels
            modules.append(
                nn.Conv1d(
                    n_in,
                    n_out,
                    kernel_size,
                    stride=1,
                    padding=(kernel_size - 1) // 2,
                    bias=False,
                )
            )
            modules.append(nn.BatchNorm1d(n_out))
            if idx != final:
                modules.append(nn.Tanh())
            modules.append(nn.Dropout(dropout))
        self.postnet = nn.Sequential(*modules)

    def forward(self, xs):
        """Return the output of the convolution stack for ``xs``."""
        return self.postnet(xs)

In [None]:
class Transtron(nn.Module):
    """Text-to-feature model: Encoder -> Decoder -> Postnet residual refinement.

    The constructor arguments are forwarded positionally to ``Encoder``
    (``num_vocab`` .. ``enc_ffn_dim``), ``Decoder`` (``decoder_hidden_dim`` ..
    ``dec_num_heads``) and ``Postnet`` (``postnet_in_dim`` .. ``postnet_dropout``).
    The Postnet is applied to the decoder output, so ``postnet_in_dim`` has to
    match ``out_dim``.
    """

    def __init__(self,
            num_vocab=52,
            embed_dim=512,
            conv_layers=3,
            conv_channels=512,
            conv_kernel_size=5,
            num_enc_layers = 4,
            enc_num_heads = 2,
            enc_dropout_rate = 0.1,
            conv_dropout_rate = 0.1,
            enc_input_maxlen = 300,
            enc_ffn_dim = 1024,
            decoder_hidden_dim=512,
            out_dim=80,
            num_dec_layers=4,
            prenet_layers=2,
            prenet_hidden_dim=512,
            prenet_dropout=0.5,
            dec_ffn_dim=1024,
            dec_dropout_rate = 0.1,
            dec_input_maxlen=3000,
            dec_num_heads = 2,
            postnet_in_dim=80,
            postnet_layers=5,
            postnet_channels=512,
            postnet_kernel_size=5,
            postnet_dropout=0.5
        ):
        super().__init__()
        self.encoder = Encoder(
            num_vocab,
            embed_dim,
            conv_layers,
            conv_channels,
            conv_kernel_size,
            num_enc_layers,
            enc_num_heads,
            enc_dropout_rate,
            conv_dropout_rate,
            enc_input_maxlen,
            enc_ffn_dim
        )
        self.decoder = Decoder(
            decoder_hidden_dim,
            out_dim,
            num_dec_layers,
            prenet_layers,
            prenet_hidden_dim,
            prenet_dropout,
            dec_ffn_dim,
            dec_dropout_rate,
            dec_input_maxlen,
            dec_num_heads
        )
        self.postnet = Postnet(
            postnet_in_dim,
            postnet_layers,
            postnet_channels,
            postnet_kernel_size,
            postnet_dropout
        )

    def forward(self, seq, in_lens, decoder_targets):
        """Teacher-forced forward pass.

        Args:
            seq: Input symbol ids passed to the encoder.
            in_lens: Input lengths, forwarded to encoder and decoder.
            decoder_targets: Decoder input frames.

        Returns:
            Tuple ``(outs, outs_fine, logits, att_ws)``: decoder output and
            Postnet-refined output, both transposed to (B, T, C), plus the
            stop-token logits and the decoder's attention-weight dict.
        """
        # Encoder: latent representation of the input text
        encoder_outs = self.encoder(seq, in_lens)

        # Decoder: output frames and stop-token logits
        outs, logits, att_ws = self.decoder(encoder_outs, in_lens, decoder_targets)

        # Postnet predicts a residual that is added to the decoder output
        outs_fine = outs + self.postnet(outs)

        # (B, C, T) -> (B, T, C)
        outs = outs.transpose(2, 1)
        outs_fine = outs_fine.transpose(2, 1)

        return outs, outs_fine, logits, att_ws

In [None]:
model = Transtron(
    # encoder
    num_vocab=52, embed_dim=512,
    conv_layers=3, conv_channels=512, conv_kernel_size=5,
    num_enc_layers=4, enc_num_heads=2,
    enc_dropout_rate=0.1, conv_dropout_rate=0.1,
    enc_input_maxlen=300, enc_ffn_dim=2048,
    # decoder
    decoder_hidden_dim=512, out_dim=80, num_dec_layers=4,
    prenet_layers=2, prenet_hidden_dim=512, prenet_dropout=0.5,
    dec_ffn_dim=2048, dec_dropout_rate=0.1,
    dec_input_maxlen=3000, dec_num_heads=2,
    # postnet
    postnet_in_dim=80, postnet_layers=5, postnet_channels=512,
    postnet_kernel_size=5, postnet_dropout=0.5,
)

In [None]:
# Show the module hierarchy of the assembled model
print(model)

In [None]:
# Helper needed for training
def ensure_divisible_by(feats, N):
    """Trim trailing items so that ``len(feats)`` is a multiple of ``N``.

    ``feats`` is returned unchanged (same object) when ``N`` is 1 or the
    length is already divisible; otherwise a shortened slice is returned.
    """
    remainder = 0 if N == 1 else len(feats) % N
    if remainder == 0:
        return feats
    return feats[: len(feats) - remainder]

In [None]:
#学習で必要な関数
from ttslearn.util import pad_1d, pad_2d

def collate_fn_transtron(batch):
    """Collate ``(input, output)`` pairs into padded batch tensors.

    Args:
        batch: Sequence of items whose element 0 is the input feature array and
            element 1 the output feature array.

    Returns:
        Tuple ``(x_batch, in_lens, y_batch, out_lens, stop_flags)``. The inputs
        and outputs are padded to the longest item via ``pad_1d`` / ``pad_2d``
        (presumably zero padding; confirm against ttslearn). ``stop_flags`` is 0
        everywhere except from each utterance's last frame onwards, where it is 1.
    """
    inputs = [item[0] for item in batch]
    outputs = [ensure_divisible_by(item[1], 1) for item in batch]

    in_lens = [len(seq) for seq in inputs]
    out_lens = [len(seq) for seq in outputs]
    in_max_len, out_max_len = max(in_lens), max(out_lens)

    x_batch = torch.stack([torch.from_numpy(pad_1d(seq, in_max_len)) for seq in inputs])
    y_batch = torch.stack([torch.from_numpy(pad_2d(seq, out_max_len)) for seq in outputs])
    in_lens = torch.tensor(in_lens, dtype=torch.long)
    out_lens = torch.tensor(out_lens, dtype=torch.long)

    # Stop flag: 1 from the final valid frame to the end of the padded sequence
    stop_flags = torch.zeros(y_batch.shape[0], y_batch.shape[1])
    for row, length in enumerate(out_lens):
        stop_flags[row, length - 1 :] = 1.0

    return x_batch, in_lens, y_batch, out_lens, stop_flags

In [None]:
#学習で必要なミニバッチデータ
from pathlib import Path
from ttslearn.train_util import Dataset
from functools import partial

def _sorted_npy(directory):
    """Return the ``*.npy`` files directly inside ``directory`` in sorted order."""
    return sorted(Path(directory).glob("*.npy"))


# Normalised input / output features for the train and dev splits
in_paths_dev = _sorted_npy("./dump/jsut_sr16000/norm/dev/in_tacotron/")
in_paths = _sorted_npy("./dump/jsut_sr16000/norm/train/in_tacotron/")
out_paths_dev = _sorted_npy("./dump/jsut_sr16000/norm/dev/out_tacotron/")
out_paths = _sorted_npy("./dump/jsut_sr16000/norm/train/out_tacotron/")

dataset = Dataset(in_paths, out_paths)
dataset_dev = Dataset(in_paths_dev, out_paths_dev)
collate_fn = partial(collate_fn_transtron)

# Both loaders share the same settings.
# NOTE(review): the training loader does not shuffle; confirm this is intended.
_loader_kwargs = dict(batch_size=16, collate_fn=collate_fn, num_workers=0)
data_loader = torch.utils.data.DataLoader(dataset, **_loader_kwargs)
data_loader_dev = torch.utils.data.DataLoader(dataset_dev, **_loader_kwargs)

# Fetch one minibatch and report the tensor sizes
in_feats, in_lens, out_feats, out_lens, stop_flags = next(iter(data_loader))
print("入力特徴量のサイズ:", tuple(in_feats.shape))
print("出力特徴量のサイズ:", tuple(out_feats.shape))
print("stop flags のサイズ:", tuple(stop_flags.shape))

In [None]:
#学習前にミニバチデータの可視化（教師データ,out_feats)

import librosa.display
import matplotlib.pyplot as plt
import numpy as np
from ttslearn.notebook import get_cmap, init_plot_style, savefig
cmap = get_cmap()
init_plot_style()

sr = 16000
# Hop size in samples for a 12.5 ms frame shift (same for every panel)
hop_length = int(sr * 0.0125)

# One panel per utterance of the minibatch, sharing both axes
fig, ax = plt.subplots(len(out_feats), 1, figsize=(8,10), sharex=True, sharey=True)
for n in range(len(in_feats)):
    panel = ax[n]
    x = out_feats[n].data.numpy()
    mesh = librosa.display.specshow(
        x.T,
        sr=sr,
        x_axis="time",
        y_axis="frames",
        hop_length=hop_length,
        cmap=cmap,
        ax=panel,
    )
    fig.colorbar(mesh, ax=panel)
    mesh.set_clim(-4, 4)
    # The x label is re-added on the bottom panel below, so clear it here
    panel.set_xlabel("")

ax[-1].set_xlabel("Time [sec]")
for a in ax:
    a.set_ylabel("Mel channel")

plt.tight_layout()
savefig("fig/e2etts_impl_minibatch")

In [None]:
#学習の前準備

from torch import optim

# `lr` is the learning rate
optimizer = optim.Adam(
    model.parameters(),
    lr=0.0001,
    eps=1e-9,
    amsgrad=True,
)

# `gamma` is the decay factor applied to the learning rate every `step_size` scheduler steps
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500000, gamma=0.5)

In [None]:
import gc 

# Run a garbage-collection pass before training starts
gc.collect()

# Mark every model parameter as trainable
for param in model.parameters():
    param.requires_grad_(True)

In [None]:
#学習

from ttslearn.util import make_non_pad_mask
from ttslearn.tacotron import Tacotron2TTS
from tqdm.notebook import tqdm
from IPython.display import Audio

def _transtron_losses(net, batch):
    """Run one teacher-forced forward pass and return the three loss terms.

    Args:
        net: Model called as ``net(in_feats, in_lens, decoder_inputs)``.
        batch: Tuple ``(in_feats, in_lens, out_feats, out_lens, stop_flags)``
            as produced by the data loader.

    Returns:
        Tuple ``(decoder_out_loss, postnet_out_loss, stop_token_loss)``.
    """
    in_feats, in_lens, out_feats, out_lens, stop_flags = batch

    # Sort by input length (descending) and apply the SAME permutation to every
    # per-utterance tensor, including the stop flags, so rows stay aligned.
    in_lens, indices = torch.sort(in_lens, dim=0, descending=True)
    in_feats, out_feats = in_feats[indices], out_feats[indices]
    out_lens, stop_flags = out_lens[indices], stop_flags[indices]

    # Teacher forcing: the decoder input is the target shifted right by one
    # frame, with an all-zero frame in first position.
    decoder_inputs = torch.zeros_like(out_feats)
    decoder_inputs[:, 1:, :] = out_feats[:, :-1, :]

    outs, outs_fine, logits, _ = net(in_feats, in_lens, decoder_inputs)

    # Exclude the zero-padded frames from the losses. Mask: (B x T x 1)
    mask = make_non_pad_mask(out_lens).unsqueeze(-1)
    frame_mask = mask.squeeze(-1)
    targets = out_feats.masked_select(mask)

    decoder_out_loss = nn.MSELoss()(outs.masked_select(mask), targets)
    postnet_out_loss = nn.MSELoss()(outs_fine.masked_select(mask), targets)
    stop_token_loss = nn.BCEWithLogitsLoss()(
        logits.masked_select(frame_mask), stop_flags.masked_select(frame_mask)
    )
    return decoder_out_loss, postnet_out_loss, stop_token_loss


# Per-epoch history rows.
# train: epoch, iteration, decoder_out, postnet_out, stop_token, loss, lr
# dev:   epoch, iteration, decoder_out, postnet_out, stop_token, loss
history = np.zeros((0, 7))
history_dev = np.zeros((0, 6))

num_epochs = 1000
it_train = 0
it_dev = 0
for epoch in range( num_epochs ):

    # ---------------- training ----------------
    model.train()
    total_decoder_out_loss = 0
    total_postnet_out_loss = 0
    total_stop_token_loss = 0
    total_loss = 0
    count = 0
    pbar = tqdm(data_loader, desc='train')
    for batch in pbar:
        count += 1

        # Forward pass and losses
        decoder_out_loss, postnet_out_loss, stop_token_loss = _transtron_losses(model, batch)
        loss = decoder_out_loss + postnet_out_loss + stop_token_loss

        total_loss += loss.item()
        total_decoder_out_loss += decoder_out_loss.item()
        total_postnet_out_loss += postnet_out_loss.item()
        total_stop_token_loss += stop_token_loss.item()

        it_train += 1
        # Reset accumulated gradients, backpropagate, update the parameters
        optimizer.zero_grad()
        loss.backward()
        # (optional) nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2)
        optimizer.step()
        # Record the learning rate used for this step, then advance the scheduler
        current_lr = optimizer.param_groups[0]["lr"]
        lr_scheduler.step()

        # Show the running average loss on the progress bar
        avg_loss = total_loss / count
        pbar.set_postfix( loss = avg_loss )

    avg_loss = total_loss / count
    avg_decoder_out_loss = total_decoder_out_loss / count
    avg_postnet_out_loss = total_postnet_out_loss / count
    avg_stop_token_loss = total_stop_token_loss / count

    print(f"epoch: {epoch+1:3d}, train it: {it_train:6d}, decoder_out: {avg_decoder_out_loss:.5f}, postnet_out: {avg_postnet_out_loss:.5f}, stop_token: {avg_stop_token_loss:.5f}, loss: {avg_loss:.5f}")
    item = np.array([epoch+1, it_train, avg_decoder_out_loss, avg_postnet_out_loss, avg_stop_token_loss, avg_loss, current_lr])
    history = np.vstack((history, item))

    # ---------------- validation ----------------
    model.eval()
    total_dev_decoder_out_loss = 0
    total_dev_postnet_out_loss = 0
    total_dev_stop_token_loss = 0
    total_dev_loss = 0
    count = 0
    pbar = tqdm(data_loader_dev, desc='dev')
    # No gradients are needed for validation; skip building the autograd graph
    with torch.no_grad():
        for batch in pbar:
            count += 1

            dev_decoder_out_loss, dev_postnet_out_loss, dev_stop_token_loss = _transtron_losses(model, batch)
            dev_loss = dev_decoder_out_loss + dev_postnet_out_loss + dev_stop_token_loss

            total_dev_loss += dev_loss.item()
            total_dev_decoder_out_loss += dev_decoder_out_loss.item()
            total_dev_postnet_out_loss += dev_postnet_out_loss.item()
            total_dev_stop_token_loss += dev_stop_token_loss.item()

            # Show the running average loss on the progress bar
            avg_dev_loss = total_dev_loss / count
            pbar.set_postfix( dev_loss = avg_dev_loss )

            it_dev += 1

    avg_dev_loss = total_dev_loss / count
    avg_dev_decoder_out_loss = total_dev_decoder_out_loss / count
    avg_dev_postnet_out_loss = total_dev_postnet_out_loss / count
    avg_dev_stop_token_loss = total_dev_stop_token_loss / count

    print(f"epoch: {epoch+1:3d}, dev it: {it_dev:6d}, decoder_out: {avg_dev_decoder_out_loss:.5f}, postnet_out: {avg_dev_postnet_out_loss:.5f}, stop_token: {avg_dev_stop_token_loss:.5f}, loss: {avg_dev_loss:.5f}")
    item = np.array([epoch+1, it_dev, avg_dev_decoder_out_loss, avg_dev_postnet_out_loss, avg_dev_stop_token_loss, avg_dev_loss])
    history_dev = np.vstack((history_dev, item))

    

In [None]:
#モデルのセーブ

import torch
import pandas as pd

# Write the train / dev loss histories as header-less CSV files
hist_df = pd.DataFrame(history)
hist_dev_df = pd.DataFrame(history_dev)
for _frame, _csv_name in (
    (hist_df, 'history_ch28.csv'),
    (hist_dev_df, 'history_dev_ch28.csv'),
):
    _frame.to_csv(_csv_name, header=False, index=False)

# Save both the whole pickled module and its state dict
torch.save(model, 'transtron_weight28.pth')
torch.save(model.state_dict(), 'transtron_weight28_state_dict.pth')

In [None]:
# Put the model in evaluation mode before running inference
model.eval()

def inference( in_feats, tts_model=None ):
    """Synthesise one utterance with greedy autoregressive decoding.

    Args:
        in_feats: Input symbol ids of a single utterance; a batch dimension is
            added internally (assumes a 1-D long tensor — confirm against caller).
        tts_model: Model exposing ``encoder``, ``decoder`` and ``postnet``.
            Defaults to the notebook-level ``model``. The model is not switched
            to evaluation mode here; call ``.eval()`` beforehand.

    Returns:
        Tuple ``(outs, outs_fine, logits, att_ws)`` for the single utterance:
        the generated frames (time, channels), the Postnet-refined frames, the
        stop-token logits of the last decoder call and its attention dict.

    Raises:
        ValueError: If ``in_feats`` is empty.
    """
    net = model if tts_model is None else tts_model

    in_feats = torch.unsqueeze( in_feats, 0 )
    in_lens = [len( feats ) for feats in in_feats]
    if in_lens[0] == 0:
        raise ValueError("in_feats must contain at least one symbol")

    # Inference needs no gradients; without this the autograd graph would grow
    # with every decoding step.
    with torch.no_grad():
        # Encoder: latent representation of the input text
        encoder_outs = net.encoder(in_feats, in_lens)

        # Generate at most 10 frames per input symbol
        decoder_targets_maxlen = in_lens[0] * 10
        out_dim = getattr(net.decoder, "out_dim", 80)
        # Start from a single all-zero frame
        decoder_targets = encoder_outs.new_zeros((encoder_outs.size()[0], 1, out_dim))

        for i in range(decoder_targets_maxlen ):
            # Decoder: frames predicted so far and stop-token logits
            outs, logits, att_ws = net.decoder(encoder_outs, in_lens, decoder_targets)
            # Stop once the last stop probability reaches 0.5, but never before step 41
            if i > 40 and torch.sigmoid(logits[0, -1]) >= 0.5:
                break
            # `outs` is (B, C, T): append the newest frame to the decoder input
            newest_frame = outs[:, :, -1].unsqueeze(1)
            decoder_targets = torch.cat( (decoder_targets, newest_frame), dim = 1 )

        # Postnet predicts a residual on top of the generated frames.
        # NOTE(review): the sequence still starts with the all-zero initial
        # frame, as before; drop index 0 if it should not be part of the output.
        outs = torch.permute(decoder_targets, (0, 2, 1))
        outs_fine = outs + net.postnet(outs)

    # (B, C, T) -> (B, T, C)
    outs = outs.transpose(2, 1)
    outs_fine = outs_fine.transpose(2, 1)

    return outs[0], outs_fine[0], logits[0], att_ws