<a href="https://colab.research.google.com/github/wasimmadha/dubbing-project-research/blob/main/translatoron_training_understanding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
## Imports

import torch
import numpy as np
from scipy.signal import get_window
import librosa.util as librosa_util
import random
import torch.utils.data
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

import numpy as np
from scipy.io.wavfile import read
import torch

import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
import numpy as np

import random
import torch
from torch.utils.tensorboard import SummaryWriter


import torch
import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable


import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
from librosa.util import pad_center, tiny

import torch
from librosa.filters import mel as librosa_mel_fn

%matplotlib inline

In [None]:
### hparams.py

class mapDict(dict):
    """Dictionary whose entries can also be read, set and deleted as attributes.

    Reading an attribute that is not a key returns ``None`` (``dict.get``
    semantics) instead of raising ``AttributeError``.
    """

    def __getattr__(self, name):
        # Only invoked when regular attribute lookup fails, so real dict
        # methods (``keys``, ``update`` ...) are never shadowed by entries.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        # Attribute assignment is stored as a dictionary entry.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        # A missing key surfaces as KeyError, exactly like ``del d[name]``.
        dict.__delitem__(self, name)



def create_hparams(hparams_string=None, verbose=False):
    """Return the default hyper-parameters wrapped in a ``mapDict``.

    ``hparams_string`` and ``verbose`` are accepted but never read by this
    function; the returned values are always the defaults listed below.
    """
    # Experiment parameters
    experiment = dict(
        epochs=2,
        iters_per_checkpoint=10,
        seed=1234,
        dynamic_loss_scaling=True,
        fp16_run=False,
        distributed_run=False,
        dist_backend="nccl",
        dist_url="tcp://localhost:54321",
        cudnn_enabled=True,
        cudnn_benchmark=False,
        ignore_layers=['embedding.weight'],
    )

    # Data parameters
    data = dict(
        load_mel_from_disk=False,
        training_files='data/train',
        validation_files='data/val',
        text_cleaners=['english_cleaners'],
    )

    # Audio parameters
    audio = dict(
        max_wav_value=32768.0,
        sampling_rate=22050,
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=80,
        mel_fmin=0.0,
        mel_fmax=8000.0,
    )

    # Dataset location and output-audio parameters.
    # NOTE(review): the two roots are absolute paths on one developer's
    # machine; they have to be edited before the notebook runs elsewhere.
    dataset = dict(
        input_data_root=r'C:\Users\Wasim\DubbingProject\Speech2Speech\google_research\translatotron\data\prepared_data\source\train',
        output_data_root=r'C:\Users\Wasim\DubbingProject\Speech2Speech\google_research\translatotron\data\prepared_data\target\train',
        train_size=0.75,
        out_channels=1025,
    )

    # Model parameters
    model = dict(
        symbols_embedding_dim=512,
        # Encoder parameters
        encoder_kernel_size=5,
        encoder_n_convolutions=3,
        encoder_embedding_dim=128,
        # Decoder parameters
        n_frames_per_step=1,  # currently only 1 is supported
        decoder_rnn_dim=256,
        prenet_dim=32,
        max_decoder_steps=1000,
        gate_threshold=0.5,
        p_attention_dropout=0.1,
        p_decoder_dropout=0.1,
        # Attention parameters
        attention_rnn_dim=256,
        attention_dim=128,
        attention_heads=4,
        # Location Layer parameters
        attention_location_n_filters=32,
        attention_location_kernel_size=31,
        # Mel-post processing network parameters
        postnet_embedding_dim=128,
        postnet_kernel_size=5,
        postnet_n_convolutions=2,
    )

    # Optimization hyperparameters
    optimization = dict(
        use_saved_learning_rate=False,
        learning_rate=1e-3,
        weight_decay=1e-6,
        grad_clip_thresh=1.0,
        batch_size=1,
        mask_padding=True,  # set model's padded outputs to padded values
    )

    # Merge the groups in declaration order so key order stays stable.
    hparams = mapDict()
    for group in (experiment, data, audio, dataset, model, optimization):
        hparams.update(group)
    return hparams


In [None]:
## utils.py

def get_mask_from_lengths(lengths):
    """Build a boolean mask that is True on the valid positions of each row.

    PARAMS
    ------
    lengths: 1-D integer tensor holding one sequence length per batch entry

    RETURNS
    -------
    mask: bool tensor of shape (len(lengths), max(lengths)); ``mask[b, t]`` is
        True when ``t < lengths[b]``. It lives on the same device as
        ``lengths``.
    """
    max_len = int(torch.max(lengths).item())
    # Create the position index on the device of `lengths` instead of forcing
    # a CUDA buffer, so the function also works on CPU-only machines.
    ids = torch.arange(max_len, dtype=torch.long, device=lengths.device)
    mask = ids.unsqueeze(0) < lengths.unsqueeze(1)
    return mask


def load_wav_to_torch(full_path):
    """Read a WAV file and return ``(samples, sampling_rate)``.

    The samples are converted to float32 and wrapped in a ``torch.FloatTensor``;
    no rescaling is applied here. The sampling rate is returned as reported by
    ``scipy.io.wavfile.read``.
    """
    rate, samples = read(full_path)
    waveform = torch.FloatTensor(samples.astype(np.float32))
    return waveform, rate


def load_filepaths_and_text(filename, split="|"):
    """Parse a UTF-8 text file into a list of per-line field lists.

    Each line is stripped of surrounding whitespace and then split on
    ``split``; the result holds one list of fields per line, in file order.
    """
    rows = []
    with open(filename, encoding='utf-8') as handle:
        for raw_line in handle:
            rows.append(raw_line.strip().split(split))
    return rows


def to_gpu(x):
    """Return a contiguous copy/view of ``x``, moved to CUDA when available.

    When no CUDA device is available the tensor stays where it is. The result
    is wrapped in ``torch.autograd.Variable`` in both cases.
    """
    tensor = x.contiguous()
    if not torch.cuda.is_available():
        return torch.autograd.Variable(tensor)
    # non_blocking lets the host-to-device copy overlap with other work.
    return torch.autograd.Variable(tensor.cuda(non_blocking=True))


In [None]:
## plotting_utils.py
def save_figure_to_numpy(fig):
    """Convert an already-drawn figure into a uint8 RGB image array.

    PARAMS
    ------
    fig: object exposing ``fig.canvas`` with ``get_width_height()`` and either
        ``tostring_rgb()`` or ``buffer_rgba()``; the canvas must have been
        drawn before calling this function

    RETURNS
    -------
    data: writable ``np.uint8`` array of shape (height, width, 3)
    """
    canvas = fig.canvas
    if hasattr(canvas, "tostring_rgb"):
        width, height = canvas.get_width_height()
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        # It returns a read-only view, hence the copy to keep the result
        # writable like before.
        flat = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
        return flat.reshape((height, width, 3)).copy()

    # Canvases without tostring_rgb: read the RGBA buffer and drop alpha.
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[..., :3].copy()


def plot_alignment_to_numpy(alignment, info=None):
    """Draw ``alignment`` as a heat map and return the rendered image array.

    PARAMS
    ------
    alignment: 2-D data accepted by ``Axes.imshow`` (presumably attention
        weights with encoder steps on the vertical axis -- confirm with caller)
    info: optional text appended below the x-axis label

    RETURNS
    -------
    The array produced by ``save_figure_to_numpy`` for the drawn figure.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    heatmap = ax.imshow(alignment, aspect='auto', origin='lower',
                        interpolation='none')
    fig.colorbar(heatmap, ax=ax)

    # The optional info text goes two lines below the axis caption.
    caption = 'Decoder timestep'
    caption = caption if info is None else caption + '\n\n' + info
    plt.xlabel(caption)
    plt.ylabel('Encoder timestep')
    plt.tight_layout()

    # The canvas has to be rendered before its pixels can be read back.
    fig.canvas.draw()
    image = save_figure_to_numpy(fig)
    plt.close()
    return image


def plot_spectrogram_to_numpy(spectrogram):
    """Draw ``spectrogram`` as an image with a colour bar and return its pixels.

    ``spectrogram`` is handed directly to ``Axes.imshow``; rows are labelled
    "Channels" and columns "Frames". The return value is whatever
    ``save_figure_to_numpy`` produces for the drawn figure.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    rendered = ax.imshow(spectrogram, aspect="auto", origin="lower",
                         interpolation='none')
    plt.colorbar(rendered, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    # Render first, then grab the pixel buffer and release the figure.
    fig.canvas.draw()
    image = save_figure_to_numpy(fig)
    plt.close()
    return image


def plot_gate_outputs_to_numpy(gate_targets, gate_outputs):
    """Scatter-plot target and predicted gate values and return the image.

    Both arguments are sequences plotted against their own index: targets as
    green '+' markers, predictions as red '.' markers. The return value is
    whatever ``save_figure_to_numpy`` produces for the drawn figure.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    series = (
        (gate_targets, dict(color='green', marker='+', label='target')),
        (gate_outputs, dict(color='red', marker='.', label='predicted')),
    )
    for values, look in series:
        ax.scatter(range(len(values)), values, alpha=0.5, s=1, **look)

    plt.xlabel("Frames (Green target, Red predicted)")
    plt.ylabel("Gate State")
    plt.tight_layout()

    # Render first, then grab the pixel buffer and release the figure.
    fig.canvas.draw()
    image = save_figure_to_numpy(fig)
    plt.close()
    return image


In [None]:
class Tacotron2Logger(SummaryWriter):
    """``SummaryWriter`` with helpers for logging training and validation state."""

    def __init__(self, logdir):
        super().__init__(logdir)

    def log_training(self, reduced_loss, grad_norm, learning_rate, duration,
                     iteration):
        """Write the four per-iteration training scalars at step ``iteration``."""
        scalars = (
            ("training.loss", reduced_loss),
            ("grad.norm", grad_norm),
            ("learning.rate", learning_rate),
            ("duration", duration),
        )
        for tag, value in scalars:
            self.add_scalar(tag, value, iteration)

    def log_validation(self, reduced_loss, model, y, y_pred, iteration):
        """Log validation loss, parameter histograms and diagnostic images.

        ``y`` unpacks to (mel_targets, gate_targets); ``y_pred`` unpacks to four
        items of which the last three are mel outputs, gate outputs and
        alignments. One randomly chosen batch entry is plotted.
        """
        self.add_scalar("validation.loss", reduced_loss, iteration)
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        def to_numpy(tensor):
            return tensor.data.cpu().numpy()

        # Distribution of every model parameter; dots in the name become
        # slashes in the tag.
        for name, param in model.named_parameters():
            self.add_histogram(name.replace('.', '/'), to_numpy(param), iteration)

        # Pick one batch entry and plot alignment, mel target/prediction and
        # gate target/prediction for it. Each image is rendered lazily, right
        # before it is written.
        idx = random.randint(0, alignments.size(0) - 1)
        renderers = (
            ("alignment",
             lambda: plot_alignment_to_numpy(to_numpy(alignments[idx]).T)),
            ("mel_target",
             lambda: plot_spectrogram_to_numpy(to_numpy(mel_targets[idx]))),
            ("mel_predicted",
             lambda: plot_spectrogram_to_numpy(to_numpy(mel_outputs[idx]))),
            ("gate",
             lambda: plot_gate_outputs_to_numpy(
                 to_numpy(gate_targets[idx]),
                 to_numpy(torch.sigmoid(gate_outputs[idx])))),
        )
        for tag, render in renderers:
            self.add_image(tag, render(), iteration, dataformats='HWC')

In [None]:
### distributed.py
def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.
    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.
    Arguments:
        tensors (Iterable[Tensor]): dense tensors to flatten.
    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    if len(tensors) == 1:
        return tensors[0].contiguous().view(-1)
    flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
    return flat

def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.
    Arguments:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
          unflatten flat.
    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    outputs = []
    offset = 0
    for tensor in tensors:
        numel = tensor.numel()
        outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
        offset += numel
    return tuple(outputs)


'''
This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py
launcher included with this example. It assumes that your run is using multiprocess with 1
GPU/process, that the model is on the correct device, and that torch.cuda.set_device has been
used to set the device.

Parameters are broadcast to the other processes on initialization of DistributedDataParallel,
and gradients are all-reduced at the end of the backward pass.
'''
class DistributedDataParallel(Module):
    """Wrap ``module`` so its gradients are averaged across processes.

    On construction every tensor in the wrapped module's ``state_dict`` is
    broadcast from rank 0. Each parameter that requires grad gets a hook that
    queues ``allreduce_params`` on the autograd engine; the queued callback
    flattens the gradients per dtype, all-reduces them, divides by the world
    size and copies the result back. A process group must already be
    initialised, since ``dist.broadcast`` is called in ``__init__``.
    """

    def __init__(self, module):
        super(DistributedDataParallel, self).__init__()
        #fallback for PyTorch 0.3
        if not hasattr(dist, '_backend'):
            self.warn_on_half = True
        else:
            self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False

        self.module = module
        # Set by forward(); initialised here so the callback never reads an
        # undefined attribute.
        self.needs_reduction = False

        for p in self.module.state_dict().values():
            if not torch.is_tensor(p):
                continue
            dist.broadcast(p, 0)

        def allreduce_params():
            # The callback is queued once per parameter hook; the flag makes
            # sure the reduction itself runs only once per forward pass.
            if(self.needs_reduction):
                self.needs_reduction = False
                buckets = {}
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        # Bucket by dtype so each flattened buffer is
                        # homogeneous (same key as apply_gradient_allreduce).
                        tp = param.data.dtype
                        if tp not in buckets:
                            buckets[tp] = []
                        buckets[tp].append(param)
                if self.warn_on_half:
                    if torch.float16 in buckets:
                        print("WARNING: gloo dist backend for half parameters may be extremely slow." +
                              " It is recommended to use the NCCL backend in this case. This currently requires" +
                              "PyTorch built from top of tree master.")
                        self.warn_on_half = False

                for tp in buckets:
                    bucket = buckets[tp]
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)

        for param in list(self.module.parameters()):
            def allreduce_hook(*unused):
                # Same engine handle as apply_gradient_allreduce uses.
                Variable._execution_engine.queue_callback(allreduce_params)
            if param.requires_grad:
                param.register_hook(allreduce_hook)

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module and arm the gradient reduction for this pass."""
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)

    '''
    def _sync_buffers(self):
        buffers = list(self.module._all_buffers())
        if len(buffers) > 0:
            # cross-node buffer sync
            flat_buffers = _flatten_dense_tensors(buffers)
            dist.broadcast(flat_buffers, 0)
            for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
                buf.copy_(synced)
     def train(self, mode=True):
        # Clear NCCL communicator and CUDA event cache of the default group ID,
        # These cache will be recreated at the later call. This is currently a
        # work-around for a potential NCCL deadlock.
        if dist._backend == dist.dist_backend.NCCL:
            dist._clear_group_cache()
        super(DistributedDataParallel, self).train(mode)
        self.module.train(mode)
    '''
'''
Modifies existing model to do gradient allreduce, but doesn't change class
so you don't need "module"
'''
def apply_gradient_allreduce(module):
    """Attach gradient all-reduce hooks to ``module`` in place and return it.

    Every tensor in ``module.state_dict()`` is broadcast from rank 0, so a
    process group must already be initialised. A forward hook arms the
    reduction on each forward pass; parameter hooks queue a callback on the
    autograd engine that flattens the gradients per dtype, all-reduces them,
    divides by the world size and copies the averaged values back.
    """
    if not hasattr(dist, '_backend'):
        module.warn_on_half = True
    else:
        module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False

    for p in module.state_dict().values():
        if not torch.is_tensor(p):
            continue
        dist.broadcast(p, 0)

    def allreduce_params():
        # Queued once per parameter hook; the flag keeps the reduction itself
        # to a single run per forward pass.
        if(module.needs_reduction):
            module.needs_reduction = False
            buckets = {}
            for param in module.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = param.data.dtype
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
            if module.warn_on_half:
                # Buckets are keyed by dtype, so half precision has to be
                # looked up as torch.float16 for the warning to ever fire.
                if torch.float16 in buckets:
                    print("WARNING: gloo dist backend for half parameters may be extremely slow." +
                          " It is recommended to use the NCCL backend in this case. This currently requires" +
                          "PyTorch built from top of tree master.")
                    module.warn_on_half = False

            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                dist.all_reduce(coalesced)
                coalesced /= dist.get_world_size()
                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                    buf.copy_(synced)

    for param in list(module.parameters()):
        def allreduce_hook(*unused):
            Variable._execution_engine.queue_callback(allreduce_params)
        if param.requires_grad:
            param.register_hook(allreduce_hook)

    def set_needs_reduction(self, input, output):
        self.needs_reduction = True

    module.register_forward_hook(set_needs_reduction)
    return module

In [None]:
### audio_preprocessing.py

def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : int > 0 or None
        The length of the window function. Pass `None` to use `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    norm : [optional]
        Forwarded as the `norm` argument of `librosa.util.normalize`, which is
        applied to the window before it is squared.

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm)**2
    # `size` is passed by keyword, matching the pad_center call in STFT;
    # librosa versions with a keyword-only `size` reject the positional form.
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x


def griffin_lim(magnitudes, stft_fn, n_iters=30):
    """
    Estimate a signal from magnitudes alone by iterating inverse/forward STFT.

    PARAMS
    ------
    magnitudes: spectrogram magnitudes (torch tensor)
    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
    n_iters: number of transform/inverse refinement rounds after the initial
        inverse from random phase

    RETURNS
    -------
    signal: result of the last ``stft_fn.inverse`` call with dimension 1
        squeezed out
    """
    # Start from uniformly random phase angles, one per magnitude entry.
    shape = tuple(magnitudes.size())
    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*shape)))
    angles = angles.astype(np.float32)
    # Keep the phase on the same device as the magnitudes so the first
    # inverse does not mix CPU and CUDA tensors.
    angles = torch.from_numpy(angles).to(magnitudes.device)
    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)

    for _ in range(n_iters):
        # Keep the given magnitudes, adopt the phase of the current estimate.
        _, angles = stft_fn.transform(signal)
        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    return signal


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """
    Log-compress ``x`` after flooring it at ``clip_val``.

    PARAMS
    ------
    x: input tensor
    C: compression factor, multiplied in before the logarithm
    clip_val: lower bound applied to ``x`` so the logarithm stays finite
    """
    floored = torch.clamp(x, min=clip_val)
    return torch.log(floored * C)


def dynamic_range_decompression(x, C=1):
    """
    Invert the log compression: exponentiate ``x`` and divide by ``C``.

    PARAMS
    ------
    x: log-compressed tensor
    C: compression factor used to compress
    """
    expanded = torch.exp(x)
    return expanded / C


In [None]:
### data_utils.py
class TextMelLoader(torch.utils.data.Dataset):
    """
        Dataset of (input, output) mel-spectrogram pairs.

        1) takes two parallel lists of file paths (``audiopaths[0]`` are the
           inputs, ``audiopaths[1]`` the outputs)
        2) loads each file as audio, or as a saved ``.npy`` array when
           ``hparams.load_mel_from_disk`` is set
        3) computes mel-spectrograms from the audio with ``TacotronSTFT``
    """
    def __init__(self, audiopaths, hparams):
        self.inputs = audiopaths[0]
        self.outputs = audiopaths[1]
        self.text_cleaners = hparams.text_cleaners
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.load_mel_from_disk = hparams.load_mel_from_disk
        self.stft = TacotronSTFT(
            hparams.filter_length, hparams.hop_length, hparams.win_length,
            hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
            hparams.mel_fmax)

        # Eager pass over every file, kept from the original notebook: it
        # fails early on unreadable files and creates the two plot folders.
        # NOTE(review): the spectrograms computed here are discarded, so this
        # doubles the loading cost -- consider removing once plots are unused.
        for input_file in self.inputs:
            self.get_mel(input_file)
            os.makedirs('melspec_input', exist_ok=True)

        for output_file in self.outputs:
            self.get_mel(output_file)
            os.makedirs('melspec_output', exist_ok=True)

    @staticmethod
    def _first_channel(audio):
        """Return channel 0 of multi-channel audio; 1-D audio is returned as is."""
        return audio[:, 0] if audio.dim() > 1 else audio

    def get_mel_spec_pair(self, index):
        """Return the ``(input_mel, output_mel)`` tuple for position ``index``."""
        inputs = self.get_mel(self.inputs[index])
        outputs = self.get_mel(self.outputs[index])

        return (inputs, outputs)

    def get_mel(self, filename):
        """Return the mel-spectrogram for ``filename``.

        From audio: the samples are divided by ``max_wav_value``, given a
        leading batch dimension, passed to ``self.stft.mel_spectrogram`` and the
        batch dimension is squeezed out again. The file's sampling rate is not
        checked here. From disk: the array is loaded with ``np.load`` and its
        first dimension must equal ``self.stft.n_mel_channels``.
        """
        if not self.load_mel_from_disk:
            audio, sampling_rate = load_wav_to_torch(filename)

            # Stereo files arrive as (samples, channels); mono files are
            # already 1-D and must not be indexed by channel.
            audio = self._first_channel(audio)
            audio_norm = (audio / self.max_wav_value).unsqueeze(0)
            melspec = self.stft.mel_spectrogram(audio_norm)
            melspec = torch.squeeze(melspec, 0)
        else:
            melspec = torch.from_numpy(np.load(filename))
            assert melspec.size(0) == self.stft.n_mel_channels, (
                'Mel dimension mismatch: given {}, expected {}'.format(
                    melspec.size(0), self.stft.n_mel_channels))

        return melspec

    def get_spec(self, filename):
        """Return the spectrogram from ``self.stft.spectrogram`` for ``filename``.

        Raises ValueError when the file's sampling rate differs from
        ``self.stft.sampling_rate``. When ``load_mel_from_disk`` is set the
        array is loaded with ``np.load`` and returned unchecked.
        """
        if not self.load_mel_from_disk:
            audio, sampling_rate = load_wav_to_torch(filename)
            if sampling_rate != self.stft.sampling_rate:
                # Two placeholders for two values; the previous three-slot
                # template raised IndexError instead of this ValueError.
                raise ValueError("{} SR doesn't match target {} SR".format(
                    sampling_rate, self.stft.sampling_rate))
            audio = self._first_channel(audio)
            audio_norm = (audio / self.max_wav_value).unsqueeze(0)
            # NOTE(review): relies on the STFT wrapper exposing `spectrogram`;
            # confirm that TacotronSTFT defines it.
            spec = self.stft.spectrogram(audio_norm)
            spec = torch.squeeze(spec, 0)
        else:
            spec = torch.from_numpy(np.load(filename))

        return spec

    def __getitem__(self, index):
        return self.get_mel_spec_pair(index)

    def __len__(self):
        return len(self.inputs)


class TextMelCollate():
    """ Zero-pads model inputs and targets based on number of frames per step
    """
    def __init__(self, n_frames_per_step):
        self.n_frames_per_step = n_frames_per_step

    def _round_up(self, length):
        """Round ``length`` up to the next multiple of ``n_frames_per_step``."""
        remainder = length % self.n_frames_per_step
        if remainder != 0:
            length += self.n_frames_per_step - remainder
        return length

    def __call__(self, batch):
        """Collate a training batch of (input_mel, target_mel) pairs.

        PARAMS
        ------
        batch: list of (input_mel, target_mel) tuples; each mel is indexed as
            (channels, frames)

        RETURNS
        -------
        input_padded: (batch, channels, max_input_len) zero-padded inputs
        input_lengths: unpadded input lengths, sorted in decreasing order
        mel_padded: (batch, channels, max_target_len) zero-padded targets
        gate_padded: (batch, max_target_len), 1 from each target's last real
            frame onwards, 0 before it
        output_lengths: unpadded target lengths

        Row ``i`` of every returned tensor belongs to the same batch entry:
        all rows follow the decreasing-input-length order of ``input_lengths``.
        """
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x[0].size(1) for x in batch]),
            dim=0, descending=True)
        order = ids_sorted_decreasing.tolist()
        batch_size = len(batch)

        # Inputs. The maximum is taken as a plain int: rounding it up must not
        # write through a tensor view into input_lengths.
        num_input_mels = batch[0][0].size(0)
        max_input_len = self._round_up(int(input_lengths[0]))
        input_padded = torch.zeros(
            batch_size, num_input_mels, max_input_len, dtype=torch.float32)
        for row, source in enumerate(order):
            mel = batch[source][0]
            input_padded[row, :, :mel.size(1)] = mel

        # Targets, written in the same row order as the inputs.
        num_target_mels = batch[0][1].size(0)
        max_target_len = self._round_up(max(x[1].size(1) for x in batch))
        mel_padded = torch.zeros(
            batch_size, num_target_mels, max_target_len, dtype=torch.float32)
        gate_padded = torch.zeros(batch_size, max_target_len, dtype=torch.float32)
        output_lengths = torch.zeros(batch_size, dtype=torch.long)
        for row, source in enumerate(order):
            mel = batch[source][1]
            mel_padded[row, :, :mel.size(1)] = mel
            gate_padded[row, mel.size(1)-1:] = 1
            output_lengths[row] = mel.size(1)

        return input_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths


In [None]:
## stft.py

"""
BSD 3-Clause License

Copyright (c) 2017, Prem Seetharaman
All rights reserved.

* Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice,
  this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this
  list of conditions and the following disclaimer in the
  documentation and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from this
  software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    Short-time Fourier transform and its inverse implemented as 1-D
    convolutions with fixed Fourier bases stored as buffers. ``transform``
    returns magnitude and phase, ``inverse`` rebuilds a signal from them and
    ``forward`` chains both.
    """
    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        # Keep the non-redundant half of the spectrum; real parts are stacked
        # on top of imaginary parts.
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert(filter_length >= win_length)
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(data=fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        """Return ``(magnitude, phase)`` for a (batch, samples) signal.

        The phase is computed on detached values, so no gradient flows
        through it.
        """
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        half = int(self.filter_length / 2)
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (half, half, 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0)

        # First `cutoff` channels hold real parts, the rest imaginary parts.
        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.atan2(imag_part.detach(), real_part.detach())

        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Rebuild a (batch, 1, samples) signal from magnitude and phase."""
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0])
            # Follow the device of the input instead of the current CUDA
            # device, so inputs on any device are handled.
            window_sum = torch.from_numpy(window_sum).to(magnitude.device)
            approx_nonzero_indices = approx_nonzero_indices.to(magnitude.device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        # Drop the padding that transform() added on both sides.
        half = int(self.filter_length / 2)
        inverse_transform = inverse_transform[:, :, half:]
        inverse_transform = inverse_transform[:, :, :-half:]

        return inverse_transform

    def forward(self, input_data):
        """Transform then invert ``input_data``; also caches magnitude and phase."""
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction


In [None]:
## layers.py

class LinearNorm(torch.nn.Module):
    """``torch.nn.Linear`` whose weight is Xavier-uniform initialised.

    ``w_init_gain`` names the non-linearity passed to
    ``torch.nn.init.calculate_gain`` to pick the initialisation gain.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super().__init__()
        gain = torch.nn.init.calculate_gain(w_init_gain)
        layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        torch.nn.init.xavier_uniform_(layer.weight, gain=gain)
        self.linear_layer = layer

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        projected = self.linear_layer(x)
        return projected


class ConvNorm(torch.nn.Module):
    """``torch.nn.Conv1d`` whose weight is Xavier-uniform initialised.

    When ``padding`` is omitted the kernel size must be odd, and the padding
    is derived as ``int(dilation * (kernel_size - 1) / 2)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super().__init__()
        if padding is None:
            # The derived padding below is only defined for odd kernels.
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        conv_kwargs = dict(kernel_size=kernel_size, stride=stride,
                           padding=padding, dilation=dilation, bias=bias)
        self.conv = torch.nn.Conv1d(in_channels, out_channels, **conv_kwargs)

        init_gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(self.conv.weight, gain=init_gain)

    def forward(self, signal):
        """Apply the wrapped convolution to ``signal``."""
        return self.conv(signal)


class TacotronSTFT(torch.nn.Module):
    """STFT front end that yields magnitude and mel spectrograms.

    Wraps an ``STFT`` instance and a mel filterbank built with
    ``librosa_mel_fn``; the filterbank is registered as the ``mel_basis``
    buffer.
    """

    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super().__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)

        filterbank = librosa_mel_fn(
            sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels,
            fmin=mel_fmin, fmax=mel_fmax)
        self.register_buffer('mel_basis', torch.from_numpy(filterbank).float())

    def spectral_normalize(self, magnitudes):
        """Apply ``dynamic_range_compression`` to ``magnitudes``."""
        return dynamic_range_compression(magnitudes)

    def spectral_de_normalize(self, magnitudes):
        """Apply ``dynamic_range_decompression`` to ``magnitudes``."""
        return dynamic_range_decompression(magnitudes)

    def spectrogram(self,y):
        """Return the STFT magnitudes of ``y``.

        Asserts that every sample of ``y`` lies in [-1, 1].
        """
        assert(torch.min(y.data) >= -1)
        assert(torch.max(y.data) <= 1)

        stft_magnitudes, _ = self.stft_fn.transform(y)
        return stft_magnitudes.data

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
           (the range is not checked here)

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        stft_magnitudes, _ = self.stft_fn.transform(y)
        # Project the magnitudes onto the mel filterbank, then compress.
        mel_energies = torch.matmul(self.mel_basis, stft_magnitudes.data)
        return self.spectral_normalize(mel_energies)


In [None]:
## model.py

class LocationLayer(nn.Module):
    """Turns stacked attention weights into location features.

    A bias-free ``ConvNorm`` over the two input channels is followed by a
    bias-free ``LinearNorm`` projecting to ``attention_dim``.
    """

    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super().__init__()
        self.location_conv = ConvNorm(
            2, attention_n_filters,
            kernel_size=attention_kernel_size,
            padding=int((attention_kernel_size - 1) / 2),
            bias=False, stride=1, dilation=1)
        self.location_dense = LinearNorm(
            attention_n_filters, attention_dim,
            bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        """Convolve, swap axes 1 and 2, then project to ``attention_dim``."""
        conv_features = self.location_conv(attention_weights_cat)
        return self.location_dense(conv_features.transpose(1, 2))


class Attention(nn.Module):
    """Attention over encoder outputs computed with ``nn.MultiheadAttention``.

    The projected attention-RNN state is the query, the location features
    derived from the previous/cumulative attention weights are the keys, and
    the processed memory provides the values.
    """

    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size,num_heads=4):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.attention = nn.MultiheadAttention(attention_dim,num_heads,batch_first=True)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        # Not read anywhere in this class; kept for callers that may use it.
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """Project the query and the location features.

        Despite its name this method does not compute energies; it returns
        the two projections that ``forward`` feeds to the multi-head
        attention.

        PARAMS
        ------
        query: attention rnn output, last dim ``attention_rnn_dim``
        processed_memory: processed encoder outputs (unused here)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        processed_query: ``query_layer`` applied to ``query.unsqueeze(1)``
        processed_attention_weights: ``location_layer`` output
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        return processed_query,processed_attention_weights

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """Compute the attention context and weights for one decoder step.

        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs (unused here; values come from
            ``processed_memory``)
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: boolean mask for padded data, True at positions to ignore,
            or None to attend to every position

        RETURNS
        -------
        attention_context: multi-head attention output with dim 1 squeezed
        attention_weights: attention weights with dim 1 squeezed
        """
        processed_query, processed_attention_weights = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        # Forward the padding mask so padded encoder frames get zero weight;
        # without it the softmax spreads probability over padding as well.
        attention_context, attention_weights = self.attention(
            processed_query, processed_attention_weights, processed_memory,
            key_padding_mask=mask)
        return attention_context.squeeze(1), attention_weights.squeeze(1)


# class Attention(nn.Module):
#     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
#                  attention_location_n_filters, attention_location_kernel_size):
#         super(Attention, self).__init__()
#         self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
#                                       bias=False, w_init_gain='tanh')
#         self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
#                                        w_init_gain='tanh')
#         self.v = LinearNorm(attention_dim, 1, bias=False)
#         self.location_layer = LocationLayer(attention_location_n_filters,
#                                             attention_location_kernel_size,
#                                             attention_dim)
#         self.score_mask_value = -float("inf")

#     def get_alignment_energies(self, query, processed_memory,
#                                attention_weights_cat):
#         """
#         PARAMS
#         ------
#         query: decoder output (batch, n_mel_channels * n_frames_per_step)
#         processed_memory: processed encoder outputs (B, T_in, attention_dim)
#         attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

#         RETURNS
#         -------
#         alignment (batch, max_time)
#         """

#         processed_query = self.query_layer(query.unsqueeze(1))
#         processed_attention_weights = self.location_layer(attention_weights_cat)
#         energies = self.v(torch.tanh(
#             processed_query + processed_attention_weights + processed_memory))

#         energies = energies.squeeze(-1)
#         return energies

#     def forward(self, attention_hidden_state, memory, processed_memory,
#                 attention_weights_cat, mask):
#         """
#         PARAMS
#         ------
#         attention_hidden_state: attention rnn last output
#         memory: encoder outputs
#         processed_memory: processed encoder outputs
#         attention_weights_cat: previous and cummulative attention weights
#         mask: binary mask for padded data
#         """
#         alignment = self.get_alignment_energies(
#             attention_hidden_state, processed_memory, attention_weights_cat)

#         if mask is not None:
#             alignment.data.masked_fill_(mask, self.score_mask_value)

#         attention_weights = F.softmax(alignment, dim=1)
#         attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
#         attention_context = attention_context.squeeze(1)

#         return attention_context, attention_weights


class Prenet(nn.Module):
    """Stack of bias-free ``LinearNorm`` layers with ReLU and dropout.

    Dropout (p=0.5) is applied with ``training=True`` on every call, so it
    stays active regardless of the module's train/eval mode.
    """

    def __init__(self, in_dim, sizes):
        super().__init__()
        stack = []
        # Layer i maps from the previous layer's width (or in_dim) to sizes[i].
        for n_in, n_out in zip([in_dim] + sizes[:-1], sizes):
            stack.append(LinearNorm(n_in, n_out, bias=False))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        """Pass ``x`` through every layer: linear -> ReLU -> dropout."""
        hidden = x
        for layer in self.layers:
            activated = F.relu(layer(hidden))
            hidden = F.dropout(activated, p=0.5, training=True)
        return hidden


class Postnet(nn.Module):
    """Postnet
        - Five 1-d convolution with 512 channels and kernel size 5
          (actual counts and sizes are read from ``hparams``)
    """

    def __init__(self, hparams):
        super().__init__()
        kernel = hparams.postnet_kernel_size
        hidden = hparams.postnet_embedding_dim
        n_mels = hparams.n_mel_channels

        def conv_bn(c_in, c_out, gain):
            # One ConvNorm followed by batch norm over its output channels.
            return nn.Sequential(
                ConvNorm(c_in, c_out,
                         kernel_size=kernel, stride=1,
                         padding=int((kernel - 1) / 2),
                         dilation=1, w_init_gain=gain),
                nn.BatchNorm1d(c_out))

        # mel -> hidden, (n_convolutions - 2) x hidden -> hidden, hidden -> mel
        blocks = [conv_bn(n_mels, hidden, 'tanh')]
        for _ in range(1, hparams.postnet_n_convolutions - 1):
            blocks.append(conv_bn(hidden, hidden, 'tanh'))
        blocks.append(conv_bn(hidden, n_mels, 'linear'))

        self.convolutions = nn.ModuleList(blocks)

    def forward(self, x):
        """Apply tanh + dropout after every block except the last, which
        gets dropout only."""
        for block in list(self.convolutions)[:-1]:
            x = F.dropout(torch.tanh(block(x)), 0.5, self.training)
        return F.dropout(self.convolutions[-1](x), 0.5, self.training)


class Encoder(nn.Module):
    """Encoder module:
        - Bidirectional LSTM
        (no convolution layers are instantiated in this version; the input
        goes straight to the LSTM)
    """
    def __init__(self, hparams):
        super().__init__()
        feature_dim = hparams.encoder_embedding_dim
        # Each direction uses int(feature_dim / 2) hidden units.
        self.lstm = nn.LSTM(feature_dim,
                            int(feature_dim/2), num_layers= 1, # According to paper it should be 8
                            batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths):
        """Encode a padded batch.

        ``x`` has axes 1 and 2 swapped before it is packed with
        ``input_lengths`` and fed to the LSTM; the output is re-padded.
        """
        sequence = x.transpose(1, 2)

        # pack_padded_sequence is given the lengths as a CPU numpy array
        lengths = input_lengths.cpu().numpy()
        packed = nn.utils.rnn.pack_padded_sequence(
            sequence, lengths, batch_first=True)

        self.lstm.flatten_parameters()
        packed_out, _ = self.lstm(packed)

        padded_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True)
        return padded_out

    def inference(self, x):
        """Encode ``x`` without packing (no length information needed)."""
        sequence = x.transpose(1, 2)

        self.lstm.flatten_parameters()
        encoded, _ = self.lstm(sequence)
        return encoded


class Decoder(nn.Module):
    """Autoregressive decoder: prenet -> attention LSTM cell -> attention ->
    decoder LSTM cell -> linear projections for the mel frame and the gate.

    The per-utterance state (LSTM hidden/cell states, attention weights,
    attention context, memory, mask) is stored on ``self`` by
    ``initialize_decoder_states`` and updated in place by ``decode``.
    """
    def __init__(self, hparams):
        super(Decoder, self).__init__()
        self.n_mel_channels = hparams.n_mel_channels
        self.out_channels = hparams.out_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.encoder_embedding_dim = hparams.encoder_embedding_dim
        self.attention_rnn_dim = hparams.attention_rnn_dim
        self.decoder_rnn_dim = hparams.decoder_rnn_dim
        self.prenet_dim = hparams.prenet_dim
        self.max_decoder_steps = hparams.max_decoder_steps
        self.gate_threshold = hparams.gate_threshold
        self.p_attention_dropout = hparams.p_attention_dropout
        self.p_decoder_dropout = hparams.p_decoder_dropout
        self.attention_dim = hparams.attention_dim

        self.prenet = Prenet(
            hparams.n_mel_channels * hparams.n_frames_per_step,
            [hparams.prenet_dim, hparams.prenet_dim])

        # Input is the prenet output concatenated with the attention context.
        # NOTE(review): the context is sized as encoder_embedding_dim here and
        # in initialize_decoder_states, while Attention.forward returns the
        # nn.MultiheadAttention output (width attention_dim) -- the two
        # hparams presumably have to be equal; confirm.
        self.attention_rnn = nn.LSTMCell(
            hparams.prenet_dim + hparams.encoder_embedding_dim,
            hparams.attention_rnn_dim)

        self.attention_layer = Attention(
            hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size)

        # The third positional argument of nn.LSTMCell is ``bias``; 1 is truthy.
        self.decoder_rnn = nn.LSTMCell(
            hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
            hparams.decoder_rnn_dim, 1)

        self.linear_projection = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
            hparams.n_mel_channels * hparams.n_frames_per_step)

        self.gate_layer = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
            bias=True, w_init_gain='sigmoid')

    def get_go_frame(self, memory):
        """ Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: encoder outputs; only the batch size and the tensor type are
            taken from it

        RETURNS
        -------
        decoder_input: all zeros frames of shape
            (B, n_mel_channels * n_frames_per_step)
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(
            B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        # All recurrent state starts at zero, allocated via memory.data.new.
        self.attention_hidden = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(
            B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        # Projected once per utterance and reused at every decoder step.
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

        RETURNS
        -------
        inputs: processed decoder inputs, time-major, with
            n_frames_per_step frames grouped into each step

        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        # group n_frames_per_step consecutive frames into one decoder step
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """ Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs: list of per-step mel outputs
        gate_outputs: list of per-step gate output energies
        alignments: list of per-step attention weights

        RETURNS
        -------
        mel_outputs: stacked mel outputs, (B, n_mel_channels, T_out)
        gate_outputs: stacked gate output energies, batch-first
        alignments: stacked attention weights, batch-first
        """
        # (T_out, B) -> (B, T_out)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out, B) -> (B, T_out)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.n_mel_channels)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """ Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: prenet output for the previous mel frame (both callers
            apply ``self.prenet`` before calling this method)

        RETURNS
        -------
        mel_output: output of ``linear_projection`` for this step
        gate_output: gate output energies
        attention_weights: attention weights for this step
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(
            self.attention_hidden, self.p_attention_dropout, self.training)

        # previous weights and their running sum, stacked as two channels
        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
              self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask)

        # in-place accumulation of the attention weights across steps
        self.attention_weights_cum += self.attention_weights
        decoder_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(
            self.decoder_hidden, self.p_decoder_dropout, self.training)

        # both output heads read the decoder state joined with the context
        decoder_hidden_attention_context = torch.cat(
            (self.decoder_hidden, self.attention_context), dim=1)
        decoder_output = self.linear_projection(
            decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, memory,decoder_inputs,memory_lengths,output_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.
        output_lengths: accepted but not used by this method

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        # prepend the all-zero go frame to the teacher-forced inputs
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)
        # get_mask_from_lengths is inverted so that the mask marks padding
        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))
        mel_outputs, gate_outputs, alignments = [], [], []
        # one step per input frame except the last: step i consumes input i
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments

    def inference(self, memory):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            decoder_input = self.prenet(decoder_input)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            # NOTE(review): the comparison result is used directly as an
            # ``if`` condition, which only works when gate_output holds a
            # single element -- presumably batch size 1 is expected; confirm.
            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                break
            elif len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            # feed the raw prediction back in as the next input
            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments


class Tacotron2(nn.Module):
    """Encoder/decoder/postnet model driven by spectrogram-like inputs.

    Input frames are projected from ``n_mel_channels`` to
    ``encoder_embedding_dim`` by ``self.linear`` before they reach the
    encoder.
    """

    def __init__(self, hparams):
        super(Tacotron2, self).__init__()
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        # Per-frame projection of the input features to the encoder width.
        self.linear = nn.Linear(hparams.n_mel_channels,hparams.encoder_embedding_dim)
        self.encoder = Encoder(hparams)
        self.decoder = Decoder(hparams)
        self.postnet = Postnet(hparams)

    def parse_batch(self, batch):
        """Move a collated batch through ``to_gpu`` and split it into inputs
        and targets.

        PARAMS
        ------
        batch: 5-tuple (input_padded, input_lengths, mel_padded, gate_padded,
            output_lengths)

        RETURNS
        -------
        ((input_padded, input_lengths, mel_padded, max_len, output_lengths),
         (mel_padded, gate_padded))
        where ``max_len`` is the largest value in ``input_lengths``.
        """
        input_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths = batch
        input_padded = to_gpu(input_padded).float()
        input_lengths = to_gpu(input_lengths).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()

        return (
            (input_padded, input_lengths, mel_padded, max_len, output_lengths),
            (mel_padded, gate_padded))

    def parse_output(self, outputs, output_lengths=None):
        """Overwrite padded positions of the model outputs in place.

        Only active when ``mask_padding`` is set and ``output_lengths`` is
        given: both mel outputs are filled with 0.0 and the gate energies
        with 1e3 at padded positions. ``outputs`` is returned either way.
        """
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            # broadcast the (B, T) mask over the mel channel axis -> (B, C, T)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return outputs

    def forward(self, inputs):
        """Teacher-forced training pass.

        PARAMS
        ------
        inputs: the first element returned by ``parse_batch``, i.e.
            (input_padded, input_lengths, mel_padded, max_len, output_lengths)

        RETURNS
        -------
        [mel_outputs, mel_outputs_postnet, gate_outputs, alignments] after
        ``parse_output`` masking
        """
        inputs, input_lengths, mels, max_len, output_lengths = inputs
        input_lengths, output_lengths = input_lengths.data, output_lengths.data

        # nn.Linear acts on the last axis, hence the transpose round-trip.
        embedded_inputs = self.linear(inputs.transpose(1,2)).transpose(1, 2)

        encoder_outputs = self.encoder(embedded_inputs, input_lengths)
        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, mels, input_lengths,output_lengths)

        # the postnet predicts a residual that is added to the decoder output
        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
            output_lengths)

    def inference(self, inputs):
        """Free-running synthesis from input features.

        PARAMS
        ------
        inputs: features with the channel axis at dim 1. When dim 1 has
            ``self.linear.in_features`` channels the same input projection as
            in ``forward`` is applied; inputs with any other channel count
            are passed to the encoder unchanged (the previous behaviour), so
            already-projected features keep working.

        RETURNS
        -------
        [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]
        """
        # Mirror ``forward``: the encoder is trained on projected features,
        # so raw input frames must go through ``self.linear`` first.
        if inputs.size(1) == self.linear.in_features:
            inputs = self.linear(inputs.transpose(1, 2)).transpose(1, 2)

        encoder_outputs = self.encoder.inference(inputs)
        mel_outputs, gate_outputs, alignments = self.decoder.inference(
            encoder_outputs)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        outputs = self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

        return outputs


In [None]:
## loss_function.py
class Tacotron2Loss(nn.Module):
    """Sum of two mel MSE terms (before and after the postnet) and a
    BCE-with-logits term on the gate output."""

    def __init__(self):
        super().__init__()

    def forward(self, model_output, targets):
        """Return ``mse(mel) + mse(mel_postnet) + bce_with_logits(gate)``.

        ``targets`` is indexed as (mel_target, gate_target); both have
        ``requires_grad`` switched off. ``model_output`` is unpacked as
        (mel_out, mel_out_postnet, gate_out, _).
        """
        mel_target = targets[0]
        gate_target = targets[1]
        mel_target.requires_grad = False
        gate_target.requires_grad = False
        gate_labels = gate_target.view(-1, 1)

        mel_out, mel_out_postnet, gate_out, _ = model_output
        gate_logits = gate_out.view(-1, 1)

        mse = nn.MSELoss()
        decoder_term = mse(mel_out, mel_target)
        postnet_term = mse(mel_out_postnet, mel_target)
        mel_loss = decoder_term + postnet_term

        gate_loss = nn.BCEWithLogitsLoss()(gate_logits, gate_labels)
        return mel_loss + gate_loss

In [None]:
## loss_scaler.py

import torch

class LossScaler:
    """Static loss scaler: multiplies the loss by a fixed factor and never
    reports an overflow."""

    def __init__(self, scale=1):
        self.cur_scale = scale

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        """Always False: the static scaler does no overflow detection."""
        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        """Always False. Declared static so it can be called on the class or
        on an instance (without the decorator an instance call raises
        TypeError)."""
        return False

    # `overflow` is boolean indicating whether we overflowed in gradient
    def update_scale(self, overflow):
        """No-op: the scale is fixed."""
        pass

    @property
    def loss_scale(self):
        """The current (constant) scale factor."""
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        """Return every entry of ``grad_in`` multiplied by the loss scale."""
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss):
        """Back-propagate ``loss`` multiplied by the loss scale."""
        scaled_loss = loss*self.loss_scale
        scaled_loss.backward()

class DynamicLossScaler:
    """Loss scaler that shrinks on gradient overflow and grows periodically.

    On an overflow the scale is divided by `scale_factor` (never below 1).
    Otherwise it is multiplied by `scale_factor` whenever the number of
    iterations since the last overflow is a multiple of `scale_window`.
    """

    def __init__(self,
                 init_scale=2**32,
                 scale_factor=2.,
                 scale_window=1000):
        self.cur_scale = init_scale
        self.cur_iter = 0
        # -1 means "no overflow seen yet".
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        """Return True if any parameter with a gradient has a non-finite entry.

        Parameters whose `.grad` is None are skipped.
        """
        return any(
            p.grad is not None and self._has_inf_or_nan(p.grad.data)
            for p in params)

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        """Return True if `x` contains +inf, -inf or NaN.

        The check is element-wise. Summing the tensor and testing the sum can
        report an overflow for all-finite values whose sum exceeds the float
        range, so that approach is avoided.
        """
        return not bool(torch.isfinite(x).all())

    # `overflow` is boolean indicating whether we overflowed in gradient
    def update_scale(self, overflow):
        """Adjust the scale after one iteration and advance the iteration counter."""
        if overflow:
            # Back off, but never let the scale drop below 1.
            self.cur_scale = max(self.cur_scale/self.scale_factor, 1)
            self.last_overflow_iter = self.cur_iter
        elif (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
            # A full window without overflow: try a larger scale.
            self.cur_scale *= self.scale_factor
        self.cur_iter += 1

    @property
    def loss_scale(self):
        """Current loss scale."""
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        """Return `grad_in` with every entry multiplied by the loss scale."""
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss):
        """Multiply `loss` by the loss scale and call `.backward()` on the result."""
        scaled_loss = loss*self.loss_scale
        scaled_loss.backward()

In [None]:
# Build the hyper-parameter object used by every cell below.
hparams = create_hparams()

# Seed the CPU and CUDA torch generators from the configured seed before any
# module is constructed.
torch.manual_seed(hparams.seed)
torch.cuda.manual_seed(hparams.seed)

In [None]:
# Source (Hindi) and target (English) audio directories on the mounted Drive.
# The space before "/never_have" is part of the path as written.
input_path = '/content/drive/MyDrive/Dubbing Project/speech2speechData /never_have/wav_files/hindi'
output_path = '/content/drive/MyDrive/Dubbing Project/speech2speechData /never_have/wav_files/english'


def _sorted_file_paths(directory):
    """Return the full paths of the entries of `directory`, sorted by name."""
    return [os.path.join(directory, file_name)
            for file_name in sorted(os.listdir(directory))]


# os.listdir() returns entries in arbitrary order, so each directory is listed
# once and sorted. This keeps the train/validation slices disjoint and
# reproducible across runs.
# NOTE(review): pairing input[i] with output[i] assumes matching utterances sort
# to the same position in both directories - confirm against the file naming.
_all_input_files = _sorted_file_paths(input_path)
_all_output_files = _sorted_file_paths(output_path)

# First 10 files for training.
t_input_files = _all_input_files[:10]
t_output_files = _all_output_files[:10]

# Files 15..19 for validation.
v_input_files = _all_input_files[15:20]
v_output_files = _all_output_files[15:20]

In [None]:
# Datasets are built from (input_files, output_files) pairs of path lists.
trainset = TextMelLoader((t_input_files, t_output_files), hparams)
valset = TextMelLoader((v_input_files, v_output_files), hparams)

# Batches are assembled by TextMelCollate, configured with n_frames_per_step.
collate_fn = TextMelCollate(hparams.n_frames_per_step)

# Only a training loader is created in this cell; incomplete trailing batches
# are dropped (drop_last=True).
train_loader = DataLoader(trainset, num_workers=2, shuffle=True,
                          batch_size=hparams.batch_size, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)


In [None]:
# Display the first training sample (the recorded output is a tuple of tensors).
trainset[0]

(tensor([[ 2.8684e+00,  2.5318e+00,  1.7466e+00,  ...,  2.8774e+00,
           2.8729e+00,  3.4255e+00],
         [ 3.1703e+00,  3.2526e+00,  3.6137e+00,  ...,  3.7573e+00,
           4.1325e+00,  3.9588e+00],
         [ 2.8741e+00,  3.3548e+00,  3.5645e+00,  ...,  3.4109e+00,
           4.5083e+00,  4.8201e+00],
         ...,
         [-1.0485e+00, -8.5859e-01, -4.9844e-01,  ..., -7.8928e-02,
          -2.1793e-01, -5.3728e-01],
         [-1.1455e+00, -8.2102e-01, -6.4068e-01,  ...,  4.9775e-02,
          -1.6804e-04, -1.9076e-01],
         [-1.1992e+00, -6.7772e-01, -4.6821e-01,  ..., -2.6285e-01,
          -3.8602e-01, -5.0477e-01]]),
 tensor([[-3.0442, -3.7131, -6.1225,  ..., -7.6227, -4.9524, -4.2499],
         [-2.9478, -3.7208, -4.4675,  ..., -6.4821, -4.6230, -3.9773],
         [-2.9164, -2.9182, -2.7841,  ..., -5.6637, -4.5555, -3.9699],
         ...,
         [-8.6789, -8.5417, -8.0932,  ..., -9.0226, -8.8471, -8.7186],
         [-8.9122, -8.7363, -8.4004,  ..., -9.1009, -8.9

In [None]:
# Instantiate the full model on the GPU and keep handles to two sub-modules.
model = Tacotron2(hparams).cuda()
# NOTE: `spec_to_emb_linear` is rebound to a fresh nn.Linear in the next cell.
spec_to_emb_linear = model.linear
encoder = model.encoder

In [None]:
# Stand-alone mel -> embedding projection and bidirectional encoder LSTM.
# Both are created without .cuda(), and `spec_to_emb_linear` replaces the
# handle to `model.linear` taken in the previous cell.
spec_to_emb_linear = nn.Linear(
    in_features=hparams.n_mel_channels,
    out_features=hparams.encoder_embedding_dim,
)
encoder_lstm = nn.LSTM(
    input_size=hparams.encoder_embedding_dim,
    # Half the embedding size per direction.
    hidden_size=int(hparams.encoder_embedding_dim / 2),
    num_layers=8,
    batch_first=True,
    bidirectional=True,
)

In [None]:
# Display the configured pre-net layer width.
hparams.prenet_dim

32

In [None]:
# Pre-net: takes one decoder step (n_mel_channels * n_frames_per_step values)
# and is configured with two layer sizes of prenet_dim.
prenet = Prenet(
            hparams.n_mel_channels * hparams.n_frames_per_step,
            [hparams.prenet_dim, hparams.prenet_dim]).cuda()

# NOTE: `attention_layer` is constructed again, with the same arguments, in a
# later cell; that later instance is the one bound when the loop runs.
attention_layer = Attention(
            hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size).cuda()

postnet = Postnet(hparams).cuda()

In [None]:
def get_go_frame(memory):
    """ Builds the all-zeros frame used as the first decoder input
    PARAMS
    ------
    memory: tensor whose batch size (dim 0), dtype and device are copied;
        the walkthrough cell passes the encoder outputs

    RETURNS
    -------
    decoder_input: zeros of shape (B, n_mel_channels * n_frames_per_step)
    """
    batch_size = memory.size(0)
    frame_width = hparams.n_mel_channels * hparams.n_frames_per_step
    return memory.new_zeros((batch_size, frame_width))

def parse_decoder_inputs(decoder_inputs):
    """ Prepares decoder inputs, i.e. mel outputs
    PARAMS
    ------
    decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs,
        shaped (B, n_mel_channels, T_out)

    RETURNS
    -------
    inputs: processed decoder inputs, shaped
        (T_out / n_frames_per_step, B, n_mel_channels * n_frames_per_step)

    """
    # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
    decoder_inputs = decoder_inputs.transpose(1, 2)
    # Group n_frames_per_step consecutive frames into one decoder step.
    # reshape() is used instead of view(): the transposed tensor is not
    # contiguous, so view() raises as soon as n_frames_per_step > 1. With
    # n_frames_per_step == 1 the shape is unchanged and no copy is made.
    decoder_inputs = decoder_inputs.reshape(
        decoder_inputs.size(0),
        decoder_inputs.size(1) // hparams.n_frames_per_step, -1)
    # (B, T_steps, step_width) -> (T_steps, B, step_width)
    decoder_inputs = decoder_inputs.transpose(0, 1)
    return decoder_inputs

In [None]:
# Stand-alone copies of the decoder sub-modules, all placed on the GPU, so the
# decoding loop can be stepped through by hand in the walkthrough cell.

# Attention LSTM cell: input is the pre-net output concatenated with the
# attention context (prenet_dim + encoder_embedding_dim).
attention_rnn = nn.LSTMCell(
    hparams.prenet_dim + hparams.encoder_embedding_dim,
    hparams.attention_rnn_dim).cuda()

# NOTE: rebinds `attention_layer`, replacing the instance built in an earlier cell.
attention_layer = Attention(
    hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
    hparams.attention_dim, hparams.attention_location_n_filters,
    hparams.attention_location_kernel_size).cuda()

# Decoder LSTM cell: input is the attention hidden state concatenated with the
# attention context.
# NOTE(review): the third positional argument of nn.LSTMCell is `bias`, so the
# trailing `1` presumably just enables the bias - confirm against the torch docs.
decoder_rnn = nn.LSTMCell(
    hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
    hparams.decoder_rnn_dim, 1).cuda()

# Projects [decoder hidden, attention context] to one decoder step of mel values.
linear_projection = LinearNorm(
    hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
    hparams.n_mel_channels * hparams.n_frames_per_step).cuda()

# Projects [decoder hidden, attention context] to a single gate value.
gate_layer = LinearNorm(
    hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
    bias=True, w_init_gain='sigmoid').cuda()


In [None]:
def parse_decoder_outputs(mel_outputs, gate_outputs, alignments):
    """ Stacks the per-step decoder outputs into batch-major tensors
    PARAMS
    ------
    mel_outputs: list with one (B, n_mel_channels * n_frames_per_step) tensor per step
    gate_outputs: list with one (B,) tensor of gate output energies per step
    alignments: list with one attention-weight tensor per step

    RETURNS
    -------
    mel_outputs: (B, n_mel_channels, T_out)
    gate_outputs: gate output energies, (B, T_steps)
    alignments: stacked attention weights with the batch dimension first
    """
    # Step-major stack, then swap step and batch dimensions (left as a view).
    alignments = torch.stack(alignments).transpose(0, 1)

    # Stacking along dim 1 gives the batch-major, contiguous layout directly.
    gate_outputs = torch.stack(gate_outputs, dim=1)
    stacked_mels = torch.stack(mel_outputs, dim=1)

    # Decouple frames per step: one row of n_mel_channels values per frame.
    batch_size = stacked_mels.size(0)
    frames = stacked_mels.view(batch_size, -1, hparams.n_mel_channels)

    # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
    mel_outputs = frames.transpose(1, 2)

    return mel_outputs, gate_outputs, alignments

In [None]:
def parse_output(outputs, output_lengths=None):
    """Overwrite the padded time steps of the model outputs in place.

    PARAMS
    ------
    outputs: [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]
    output_lengths: per-sample output lengths; when None (or when
        hparams.mask_padding is falsy) nothing is modified

    RETURNS
    -------
    outputs: the same object that was passed in
    """
    if not hparams.mask_padding or output_lengths is None:
        return outputs

    # True at every padded position, shape (B, T).
    padding = ~get_mask_from_lengths(output_lengths)
    # Repeat the mask across the mel channels: (B, n_mel_channels, T).
    mel_padding = padding.unsqueeze(1).expand(-1, hparams.n_mel_channels, -1)

    for mel in outputs[:2]:
        mel.data.masked_fill_(mel_padding, 0.0)
    outputs[2].data.masked_fill_(padding, 1e3)  # gate energies

    return outputs


In [None]:
# Step-by-step walk through one teacher-forced forward pass, using the
# stand-alone modules defined in the cells above and printing the tensor shapes
# at each stage. Only the first batch is processed (see the `break` at the end).
for i, batch in enumerate(train_loader):
    inputs, input_lengths, mels, max_len, output_lengths = batch

    input_lengths, output_lengths = input_lengths.data, output_lengths.data

    # ---- Encoder -----------------------------------------------------------
    # `spec_to_emb_linear` is either `model.linear` (on the GPU) or the
    # stand-alone nn.Linear (created on the CPU), depending on which cells were
    # run, so the batch is moved to wherever the layer's weights live.
    linear_device = next(spec_to_emb_linear.parameters()).device
    embedded_inputs = spec_to_emb_linear(
        inputs.to(linear_device).transpose(1, 2)).transpose(1, 2)

    encoder_outputs = encoder(embedded_inputs.cuda(), input_lengths)

    print("Encoder Outputs Shape: ", encoder_outputs.shape)

    # ---- Decoder inputs (teacher forcing) ----------------------------------
    # All-zeros "go" frame, with a leading step dimension so it can be
    # concatenated in front of the target frames.
    decoder_input = get_go_frame(encoder_outputs).unsqueeze(0)

    print("Decoder Input Array: ", decoder_input, decoder_input.shape)

    print("Output Shape before parsing: ", mels.shape)
    decoder_inputs = parse_decoder_inputs(mels)
    print("Output shape after parsing: ", decoder_inputs.shape)

    decoder_inputs = torch.cat((decoder_input.cuda(), decoder_inputs.cuda()), dim=0)

    print("Decoder Shape after Concating: ", decoder_inputs.shape)

    ## Prenet
    decoder_inputs = prenet(decoder_inputs)

    print("Decoder Shape after PreNet: ", decoder_inputs.shape)

    # ---- Initial decoder / attention state ----------------------------------
    memory = encoder_outputs.cuda()
    # True at padded encoder positions.
    mask = ~get_mask_from_lengths(input_lengths.cuda())

    B = memory.size(0)
    MAX_TIME = memory.size(1)

    # Every state starts as zeros with the dtype and device of `memory`.
    attention_hidden = memory.new_zeros((B, hparams.attention_rnn_dim))
    print("Shape of attention_hidden: ", attention_hidden.shape)

    attention_cell = memory.new_zeros((B, hparams.attention_rnn_dim))
    print("Shape of attention_cell: ", attention_cell.shape)

    decoder_hidden = memory.new_zeros((B, hparams.decoder_rnn_dim))
    print("Shape of decoder_hidden: ", decoder_hidden.shape)

    decoder_cell = memory.new_zeros((B, hparams.decoder_rnn_dim))
    print("Shape of decoder_cell: ", decoder_cell.shape)

    attention_weights = memory.new_zeros((B, MAX_TIME))
    print("Shape of attention_weights: ", attention_weights.shape)

    attention_weights_cum = memory.new_zeros((B, MAX_TIME))
    print("Shape of attention_weights_cum: ", attention_weights_cum.shape)

    attention_context = memory.new_zeros((B, hparams.encoder_embedding_dim))
    print("Shape of attention_context: ", attention_context.shape)

    processed_memory = attention_layer.memory_layer(memory)

    # ---- Decoding loop ------------------------------------------------------
    mel_outputs, gate_outputs, alignments = [], [], []

    # One step per target frame; the last pre-net frame is never fed back in.
    for step in range(decoder_inputs.size(0) - 1):
        decoder_input = decoder_inputs[step]

        cell_input = torch.cat((decoder_input, attention_context), -1)

        attention_hidden, attention_cell = attention_rnn(
            cell_input, (attention_hidden, attention_cell))
        # training=True is hard-coded, so dropout is always active here.
        attention_hidden = F.dropout(
            attention_hidden, hparams.p_attention_dropout, True)

        # Previous and cumulative attention weights, stacked on a new dim 1.
        attention_weights_cat = torch.cat(
            (attention_weights.unsqueeze(1),
             attention_weights_cum.unsqueeze(1)), dim=1)
        attention_context, attention_weights = attention_layer(
            attention_hidden, memory, processed_memory,
            attention_weights_cat, mask)

        attention_weights_cum += attention_weights
        decoder_input = torch.cat(
            (attention_hidden, attention_context), -1)
        decoder_hidden, decoder_cell = decoder_rnn(
            decoder_input, (decoder_hidden, decoder_cell))
        decoder_hidden = F.dropout(
            decoder_hidden, hparams.p_decoder_dropout, True)

        decoder_hidden_attention_context = torch.cat(
            (decoder_hidden, attention_context), dim=1)
        decoder_output = linear_projection(
            decoder_hidden_attention_context)

        gate_prediction = gate_layer(decoder_hidden_attention_context)
        mel_outputs += [decoder_output.squeeze(1)]
        gate_outputs += [gate_prediction.squeeze(1)]
        alignments += [attention_weights]

        # Not fed back: the next iteration reads its input from
        # `decoder_inputs` (teacher forcing). Kept so the name is bound as before.
        decoder_input = decoder_output

    print("Outputs before parsing: ")
    print("Mel Outputs: ", len(mel_outputs), mel_outputs[0].shape)
    print("Gate Outputs: ", len(gate_outputs), gate_outputs[0].shape)
    print("Alignments: ", len(alignments), alignments[0].shape)

    mel_outputs, gate_outputs, alignments = parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

    # After parsing these are tensors, so len() is the batch size and [0] is
    # the first sample.
    print("Outputs after parsing: ")
    print("Mel Outputs: ", len(mel_outputs), mel_outputs[0].shape)
    print("Gate Outputs: ", len(gate_outputs), gate_outputs[0].shape)
    print("Alignments: ", len(alignments), alignments[0].shape)

    # ---- Post-net (residual) and padding mask -------------------------------
    mel_outputs_postnet = postnet(mel_outputs)
    mel_outputs_postnet = mel_outputs + mel_outputs_postnet

    [mel_outputs, mel_outputs_postnet, gate_outputs, alignments] = parse_output(
                [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
                output_lengths.cuda())
    break

Encoder Outputs Shape:  torch.Size([1, 1876, 128])
Decoder Input Array:  tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]], device='cuda:0') torch.Size([1, 1, 80])
Output Shape before parsing:  torch.Size([1, 80, 1876])
Output shape after parsing:  torch.Size([1876, 1, 80])
Decoder Shape after Concating:  torch.Size([1877, 1, 80])
Decoder Shape after PreNet:  torch.Size([1877, 1, 32])


  ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))


Shape of attention_hidden:  torch.Size([1, 256])
Shape of attention_cell:  torch.Size([1, 256])
Shape of decoder_hidden:  torch.Size([1, 256])
Shape of decoder_cell:  torch.Size([1, 256])
Shape of attention_weights:  torch.Size([1, 1876])
Shape of attention_weights_cum:  torch.Size([1, 1876])
Shape of attention_context:  torch.Size([1, 128])
Outputs before parsing: 
Mel Outputs:  1876 torch.Size([1, 80])
Gate Outputs:  1876 torch.Size([1])
Alignments:  1876 torch.Size([1, 1876])
Outputs after parsing: 
Mel Outputs:  1 torch.Size([80, 1876])
Gate Outputs:  1 torch.Size([1876])
Alignments:  1 torch.Size([1876, 1876])
