<a href="https://colab.research.google.com/github/wasimmadha/dubbing-project-research/blob/main/auxiliary_task_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [85]:
!tar -xzf '/content/drive/MyDrive/Dubbing Project/libspeech/dev-clean.tar.gz' -C /content/
!tar -xzf '/content/drive/MyDrive/Dubbing Project/libspeech/train-clean-100.tar.gz' -C /content/

In [86]:
!pip install pydub wandb



In [87]:
## Imports

import torch
from pydub import AudioSegment
import numpy as np
from scipy.signal import get_window
import librosa.util as librosa_util
import random
import torch.utils.data
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

import numpy as np
from scipy.io.wavfile import read

import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
import numpy as np
import json

import random
from torch.utils.tensorboard import SummaryWriter


import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable


import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
from librosa.util import pad_center, tiny

from librosa.filters import mel as librosa_mel_fn

from huggingface_hub import PyTorchModelHubMixin
import gc
import wandb
from tqdm import tqdm

%matplotlib inline

#### hparams

In [88]:
### hparams.py

class mapDict(dict):
  __getattr__ = dict.get
  __setattr__ = dict.__setitem__
  __delattr__ = dict.__delitem__



def create_hparams(hparams_string=None,verbose=False):
  hparams = {
    ################################
    # Experiment Parameters        #
    ################################
    "epochs":7,
    "iters_per_checkpoint":10,
    "seed":1234,
    "dynamic_loss_scaling":True,
    "fp16_run":False,
    "distributed_run":False,
    "dist_backend":"nccl",
    "dist_url":"tcp://localhost:54321",
    "cudnn_enabled":True,
    "cudnn_benchmark":False,
    "ignore_layers":['embedding.weight'],

    ################################
    # Data Parameters             #
    ################################
    "load_mel_from_disk":False,
    "training_files":'data/train',
    "validation_files":'data/val',
    "text_cleaners":['english_cleaners'],

    ################################
    # Audio Parameters             #
    ################################
    "max_wav_value":32768.0,
    "sampling_rate":22050,
    "filter_length":1024,
    "hop_length":256,
    "win_length":1024,
    "n_mel_channels":80,
    "mel_fmin":0.0,
    "mel_fmax":8000.0,

    #Data parameters
    "input_data_root": r'C:\Users\Wasim\DubbingProject\Speech2Speech\google_research\translatotron\data\prepared_data\source\train',
    "output_data_root": r'C:\Users\Wasim\DubbingProject\Speech2Speech\google_research\translatotron\data\prepared_data\target\train',
    "train_size": 0.75,
    #Output Audio Parameters
    "out_channels":1025,
    ################################
    # Model Parameters             #
    ################################
    "symbols_embedding_dim":512,

    # Encoder parameters
    "encoder_kernel_size":5,
    "encoder_n_convolutions":3,
    "encoder_embedding_dim":128,

    # Decoder parameters
    "n_frames_per_step":1,  # currently only 1 is supported
    "decoder_rnn_dim":256,
    "prenet_dim":32,
    "max_decoder_steps":1000,
    "gate_threshold":0.5,
    "p_attention_dropout":0.1,
    "p_decoder_dropout":0.1,

    # Attention parameters
    "attention_rnn_dim":256,
    "attention_dim":128,
    "attention_heads": 4,

    # Location Layer parameters
    "attention_location_n_filters":32,
    "attention_location_kernel_size":31,

    # Mel-post processing network parameters
    "postnet_embedding_dim":128,
    "postnet_kernel_size":5,
    "postnet_n_convolutions":2,

    ################################
    # Optimization Hyperparameters #
    ################################
    "use_saved_learning_rate":False,
    "learning_rate":1e-3,
    "weight_decay":1e-6,
    "grad_clip_thresh":1.0,
    "batch_size":32,
    "mask_padding":True
    # set model's padded outputs to padded values
  }

  hparams = mapDict(hparams)

  return hparams

### Dataset

#### Reading Files

In [89]:
def read_json(file_path):
  with open(file_path, 'r') as file:
    file_content = json.load(file)
  return file_content

In [90]:
train_data = read_json('/content/drive/MyDrive/Dubbing Project/libspeech/train_output_dict.json')
valid_data = read_json('/content/drive/MyDrive/Dubbing Project/libspeech/valid_output_dict.json')
reverse_phoneme = read_json('/content/drive/MyDrive/Dubbing Project/libspeech/reverse_phoneme_dict.json')
phoneme_list = read_json('/content/drive/MyDrive/Dubbing Project/libspeech/phoneme_dict.json')

In [91]:
len(train_data.keys()), len(valid_data.keys())

(28539, 2703)

In [92]:
list(train_data.keys())[0], train_data[list(train_data.keys())[0]].keys()

('5678-43302-0000', dict_keys(['file_path', 'text', 'phonemes']))

In [93]:
list(valid_data.keys())[0], valid_data[list(valid_data.keys())[0]].keys()

('6313-66125-0000', dict_keys(['file_path', 'text', 'phonemes']))

In [94]:
print("Total Phonemes in reverse: ", len(reverse_phoneme))

Total Phonemes in reverse:  72


In [95]:
print("Total Phonemes: ", len(phoneme_list))

Total Phonemes:  72


In [96]:
phoneme_list

{'AA2': 0,
 'CH': 1,
 'ZH': 2,
 'DH': 3,
 'UW0': 4,
 'W': 5,
 'IY0': 6,
 'OY0': 7,
 'UH2': 8,
 'D': 9,
 'Y': 10,
 'Z': 11,
 'OY2': 12,
 'HH': 13,
 'AY2': 14,
 'IH0': 15,
 'EH1': 16,
 'UW2': 17,
 'UH0': 18,
 'JH': 19,
 'ER2': 20,
 'V': 21,
 'IY2': 22,
 'G': 23,
 'M': 24,
 'AH1': 25,
 'irish': 26,
 'AA0': 27,
 'AY0': 28,
 'OW1': 29,
 'ER1': 30,
 'AO0': 31,
 'EY1': 32,
 'L': 33,
 'AA1': 34,
 'AY1': 35,
 'SH': 36,
 'ER0': 37,
 'AH0': 38,
 'UH1': 39,
 'K': 40,
 'S': 41,
 'AH2': 42,
 'AO2': 43,
 'IY1': 44,
 'OW0': 45,
 'T': 46,
 'IH1': 47,
 'AW2': 48,
 'EH0': 49,
 'org,': 50,
 'AO1': 51,
 'NG': 52,
 'EH2': 53,
 'R': 54,
 '#': 55,
 'EY0': 56,
 'AW0': 57,
 'AE2': 58,
 'AE1': 59,
 'OW2': 60,
 'N': 61,
 'B': 62,
 'AE0': 63,
 'F': 64,
 'AW1': 65,
 'P': 66,
 'TH': 67,
 'EY2': 68,
 'IH2': 69,
 'UW1': 70,
 'OY1': 71}

In [97]:
## Adding Stop Token
phoneme_list['<end/>'] = 72
reverse_phoneme[72] = '<end/>'

phoneme_list['<start/>'] = 73
reverse_phoneme[73] = '<start/>'

In [98]:
print("Total Phonemes after adding reverse: ", len(reverse_phoneme))
print("Total Phonemes after adding: ", len(phoneme_list))

Total Phonemes after adding reverse:  74
Total Phonemes after adding:  74


#### Dataset Class

In [99]:
def load_wav_to_torch(full_path):
    # print(full_path)
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate

# def load_flac_to_torch(full_path):
#     print("Path: ", full_path)
#     # Read FLAC file using pydub
#     audio = AudioSegment.from_file(full_path, format="flac")

#     # Convert to a numpy array
#     data = audio.get_array_of_samples()

#     # Save the data as a temporary WAV file

#     temp_wav_path = "temp.wav"
#     audio.export(temp_wav_path, format="wav")

#     # Read the temporary WAV file using scipy.io.wavfile.read
#     sample_rate, data = read(temp_wav_path)
#     return torch.FloatTensor(data.astype(np.float32)), sample_rate

def load_flac_to_torch(full_path):
    # Read FLAC file using pydub
    audio = AudioSegment.from_file(full_path, format="flac")

    # Convert to a numpy array
    data = np.array(audio.get_array_of_samples(), dtype=np.float32)

    # Normalize the data to the range [-1, 1]
    data /= np.abs(data).max()

    # Convert numpy array to PyTorch tensor
    tensor_data = torch.from_numpy(data)

    # Print the sample rate
    sample_rate = audio.frame_rate

    return tensor_data, sample_rate

def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)

def dynamic_range_decompression(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C

def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function.  By default, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm)**2
    win_sq = librosa_util.pad_center(win_sq, n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x


def get_mask_from_lengths(lengths):
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
    mask = (ids < lengths.unsqueeze(1)).bool()
    return mask


def load_filepaths_and_text(filename, split="|"):
    with open(filename, encoding='utf-8') as f:
        filepaths_and_text = [line.strip().split(split) for line in f]
    return filepaths_and_text


def to_gpu(x):
    x = x.contiguous()

    if torch.cuda.is_available():
        x = x.cuda(non_blocking=True)
    return torch.autograd.Variable(x)

In [100]:
## stft.py

"""
BSD 3-Clause License

Copyright (c) 2017, Prem Seetharaman
All rights reserved.

* Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice,
  this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this
  list of conditions and the following disclaimer in the
  documentation and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from this
  software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert(filter_length >= win_length)
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(data=fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        forward_transform = F.conv1d(
            input_data,
            Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(
            torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase

    def inverse(self, magnitude, phase):
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0])
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False)
            window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
        inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]

        return inverse_transform

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction

In [101]:
class TacotronSTFT(torch.nn.Module):
    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_basis', mel_basis)

    def spectral_normalize(self, magnitudes):
        output = dynamic_range_compression(magnitudes)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def spectrogram(self,y):
        assert(torch.min(y.data) >= -1)
        assert(torch.max(y.data) <= 1)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data

        return magnitudes

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        # assert(torch.min(y.data) >= -1)
        # assert(torch.max(y.data) <= 1)
        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output

In [102]:
train_audio_list = [i['file_path'] for i in train_data.values()]
train_phonemes_list = [i['phonemes'] for i in train_data.values()]

In [103]:
### data_utils.py
class AudioPhonemeDataset(torch.utils.data.Dataset):
    """
        1) loads audio,text pairs
        2) normalizes text and converts them to sequences of one-hot vectors
        3) computes mel-spectrograms from audio files.
    """
    def __init__(self, audio_paths, phoneme_list, hparams):
        self.inputs = audio_paths
        self.outputs = phoneme_list
        self.text_cleaners = hparams.text_cleaners
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.load_mel_from_disk = hparams.load_mel_from_disk
        self.stft = TacotronSTFT(
            hparams.filter_length, hparams.hop_length, hparams.win_length,
            hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
            hparams.mel_fmax)

    def get_mel_spec_pair(self, index):
        # separate filename and text
        # lin = self.get_spec(self.outputs[index])
        # mel = self.get_mel(self.inputs[index])
        inputs = self.get_mel(self.inputs[index])
        outputs = self.get_phonemes_list(self.outputs[index])

        return (inputs,outputs)

    def get_mel(self, filename):
        if not self.load_mel_from_disk:
            audio, sampling_rate = load_flac_to_torch(filename)
            # audio = audio[:, 0]
            # if sampling_rate != self.stft.sampling_rate:
            #     raise ValueError("{} {} SR doesn't match target {} SR".format(
            #         sampling_rate, self.stft.sampling_rate))
            audio_norm = audio / self.max_wav_value
            audio_norm = audio_norm.unsqueeze(0)
            audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
            melspec = self.stft.mel_spectrogram(audio_norm)
            melspec = torch.squeeze(melspec, 0)
            melspec_arr = melspec[0].numpy()
        else:
            melspec = torch.from_numpy(np.load(filename))
            assert melspec.size(0) == self.stft.n_mel_channels, (
                'Mel dimension mismatch: given {}, expected {}'.format(
                    melspec.size(0), self.stft.n_mel_channels))

        return melspec

    def get_phonemes_list(self, phonemes):
        phonemes.insert(0, '<start/>')
        phonemes.append('<end/>')

        tensor_phonemes = torch.zeros(hparams.vocab_size, len(phonemes))

        for i, phoneme in enumerate(phonemes):
            tensor_phonemes[phoneme_list[phoneme], i] = 1

        return tensor_phonemes


    def get_spec(self, filename):
        if not self.load_mel_from_disk:
            audio, sampling_rate = load_wav_to_torch(filename)
            if sampling_rate != self.stft.sampling_rate:
                raise ValueError("{} {} SR doesn't match target {} SR".format(
                    sampling_rate, self.stft.sampling_rate))
            audio_norm = audio / self.max_wav_value
            audio_norm = audio_norm.unsqueeze(0)
            audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
            spec = self.stft.spectrogram(audio_norm)
            spec = torch.squeeze(spec, 0)
        else:
            spec = torch.from_numpy(np.load(filename))

        return spec

    def __getitem__(self, index):
        return self.get_mel_spec_pair(index)

    def __len__(self):
        return len(self.inputs)


class AudioPhonemesCollate():
    """ Zero-pads model inputs and targets based on number of frames per setep
    """
    def __init__(self, n_frames_per_step):
        self.n_frames_per_step = n_frames_per_step

    def __call__(self, batch):
        """Collate's training batch from normalized text and mel-spectrogram
        PARAMS
        ------
        batch: [mel_normalized, text_normalized]
        """

        num_mels = batch[0][0].size(0)


        # Sort sequences by input length in descending order
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x[0].size(1) for x in batch]),
            dim=0, descending=True
        )

        max_input_len = input_lengths[0]

        # Pad input sequences
        input_padded = torch.FloatTensor(len(batch), num_mels, max_input_len)
        input_padded.zero_()
        for i in ids_sorted_decreasing:
            mel = batch[i][0]
            input_padded[i, :, :mel.size(1)] = mel

        max_target_len = max([x[1].size(1) for x in batch])

        # include mel padded and gate padded
        target_padded = torch.FloatTensor(len(batch), hparams.vocab_size, max_target_len)
        target_padded.zero_()
        gate_padded = torch.FloatTensor(len(batch), max_target_len)
        gate_padded.zero_()
        output_lengths = torch.LongTensor(len(batch))

        for i in ids_sorted_decreasing:
            target = batch[i][1]
            target_padded[i, :, :target.size(1)] = target
            gate_padded[i, :len(target)] = 1
            output_lengths[i] = len(target)

        return input_padded, input_lengths, target_padded, gate_padded, \
            output_lengths

In [104]:
hparams = create_hparams()

torch.manual_seed(hparams.seed)
torch.cuda.manual_seed(hparams.seed)

In [105]:
hparams.vocab_size = len(phoneme_list)

In [106]:
for audio in train_audio_list:
  if not audio:
    print(audio)

In [107]:
train_audio_list = [i['file_path'] for i in train_data.values()]
train_phonemes_list = [i['phonemes'] for i in train_data.values()]

train_dataset = AudioPhonemeDataset(train_audio_list, train_phonemes_list, hparams)

In [108]:
valid_audio_list = [i['file_path'] for i in valid_data.values()]
valid_phonemes_list = [i['phonemes'] for i in valid_data.values()]

val_dataset = AudioPhonemeDataset(valid_audio_list, valid_phonemes_list, hparams)

In [109]:
audio, phonemes = train_dataset[0]

In [110]:
audio.shape, phonemes.shape

(torch.Size([80, 937]), torch.Size([74, 157]))

In [111]:
collate_fn = AudioPhonemesCollate(hparams.n_frames_per_step)

train_loader = DataLoader(train_dataset, num_workers=2, shuffle=True,
                          batch_size=3, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)

valid_dataloader = DataLoader(val_dataset, num_workers=2, shuffle=True,
                          batch_size=3, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)


In [112]:
for i, batch in enumerate(train_loader):
  input_padded, input_lengths, target_padded, gate_padded, \
            output_lengths = batch

  print("Inputs Shape: ", input_padded.shape, input_lengths.shape)

  print("Outputs: ",)
  print(target_padded)
  print(gate_padded)

  break

Inputs Shape:  torch.Size([3, 80, 926]) torch.Size([3])
Outputs: 
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [1., 0., 0.,  ..., 0., 0., 0.]]])
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.

### Model

In [29]:
## layers.py

class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal


class TacotronSTFT(torch.nn.Module):
    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_basis', mel_basis)

    def spectral_normalize(self, magnitudes):
        output = dynamic_range_compression(magnitudes)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def spectrogram(self,y):
        assert(torch.min(y.data) >= -1)
        assert(torch.max(y.data) <= 1)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data

        return magnitudes

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        # assert(torch.min(y.data) >= -1)
        # assert(torch.max(y.data) <= 1)
        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output


In [30]:
## model.py

class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1,
                                      dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size,num_heads=4):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        # self.v = LinearNorm(attention_dim, 1, bias=False)
        self.attention = nn.MultiheadAttention(attention_dim,num_heads,batch_first=True)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """

        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        # energies = self.v(torch.tanh(
        #     processed_query + processed_attention_weights + processed_memory))
        # energies = energies.squeeze(-1)
        return processed_query,processed_attention_weights

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cummulative attention weights
        mask: binary mask for padded data
        """
        # alignment = self.get_alignment_energies(
            # attention_hidden_state, processed_memory, attention_weights_cat)

        processed_query,processed_attention_weights = self.get_alignment_energies(
        attention_hidden_state, processed_memory, attention_weights_cat)

        # if mask is not None:
        #     alignment.data.masked_fill_(mask, self.score_mask_value)

        # attention_weights = F.softmax(alignment, dim=1)
        # attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        # attention_context = attention_context.squeeze(1)

        attention_context,attention_weights = self.attention(processed_query,processed_attention_weights,processed_memory)
        return attention_context.squeeze(1), attention_weights.squeeze(1)


# class Attention(nn.Module):
#     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
#                  attention_location_n_filters, attention_location_kernel_size):
#         super(Attention, self).__init__()
#         self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
#                                       bias=False, w_init_gain='tanh')
#         self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
#                                        w_init_gain='tanh')
#         self.v = LinearNorm(attention_dim, 1, bias=False)
#         self.location_layer = LocationLayer(attention_location_n_filters,
#                                             attention_location_kernel_size,
#                                             attention_dim)
#         self.score_mask_value = -float("inf")

#     def get_alignment_energies(self, query, processed_memory,
#                                attention_weights_cat):
#         """
#         PARAMS
#         ------
#         query: decoder output (batch, n_mel_channels * n_frames_per_step)
#         processed_memory: processed encoder outputs (B, T_in, attention_dim)
#         attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

#         RETURNS
#         -------
#         alignment (batch, max_time)
#         """

#         processed_query = self.query_layer(query.unsqueeze(1))
#         processed_attention_weights = self.location_layer(attention_weights_cat)
#         energies = self.v(torch.tanh(
#             processed_query + processed_attention_weights + processed_memory))

#         energies = energies.squeeze(-1)
#         return energies

#     def forward(self, attention_hidden_state, memory, processed_memory,
#                 attention_weights_cat, mask):
#         """
#         PARAMS
#         ------
#         attention_hidden_state: attention rnn last output
#         memory: encoder outputs
#         processed_memory: processed encoder outputs
#         attention_weights_cat: previous and cummulative attention weights
#         mask: binary mask for padded data
#         """
#         alignment = self.get_alignment_energies(
#             attention_hidden_state, processed_memory, attention_weights_cat)

#         if mask is not None:
#             alignment.data.masked_fill_(mask, self.score_mask_value)

#         attention_weights = F.softmax(alignment, dim=1)
#         attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
#         attention_context = attention_context.squeeze(1)

#         return attention_context, attention_weights


class Prenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super(Prenet, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [LinearNorm(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
        for linear in self.layers:
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x


class Postnet(nn.Module):
    """Postnet
        - Five 1-d convolution with 512 channels and kernel size 5
    """

    def __init__(self, hparams):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.vocab_size, hparams.postnet_embedding_dim,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hparams.postnet_embedding_dim))
        )

        for i in range(1, hparams.postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(hparams.postnet_embedding_dim,
                             hparams.postnet_embedding_dim,
                             kernel_size=hparams.postnet_kernel_size, stride=1,
                             padding=int((hparams.postnet_kernel_size - 1) / 2),
                             dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hparams.postnet_embedding_dim))
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.postnet_embedding_dim, hparams.vocab_size,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(hparams.vocab_size))
            )

    def forward(self, x):
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

        return x


class Encoder(nn.Module):
    """Encoder module:
        - Three 1-d convolution banks
        - Bidirectional LSTM
    """
    def __init__(self, hparams):
        super(Encoder, self).__init__()

        # convolutions = []
        # for _ in range(hparams.encoder_n_convolutions):
        #     conv_layer = nn.Sequential(
        #         ConvNorm(hparams.encoder_embedding_dim,
        #                  hparams.encoder_embedding_dim,
        #                  kernel_size=hparams.encoder_kernel_size, stride=1,
        #                  padding=int((hparams.encoder_kernel_size - 1) / 2),
        #                  dilation=1, w_init_gain='relu'),
        #         nn.BatchNorm1d(hparams.encoder_embedding_dim))
        #     convolutions.append(conv_layer)
        # self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
                            int(hparams.encoder_embedding_dim/2), num_layers= 1, # According to paper it should be 8
                            batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths):
        # for conv in self.convolutions:
        #     x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        # pytorch tensor are not reversible, hence the conversion
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)

        return outputs

    def inference(self, x):
        # for conv in self.convolutions:
        #     x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs


class PhonemesDecoder(nn.Module):
    def __init__(self, hparams):
        super(PhonemesDecoder, self).__init__()
        self.n_mel_channels = hparams.n_mel_channels
        self.out_channels = hparams.out_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.encoder_embedding_dim = hparams.encoder_embedding_dim
        self.attention_rnn_dim = hparams.attention_rnn_dim
        self.decoder_rnn_dim = hparams.decoder_rnn_dim
        self.prenet_dim = hparams.prenet_dim
        self.max_decoder_steps = hparams.max_decoder_steps
        self.gate_threshold = hparams.gate_threshold
        self.p_attention_dropout = hparams.p_attention_dropout
        self.p_decoder_dropout = hparams.p_decoder_dropout
        self.attention_dim = hparams.attention_dim

        self.prenet = Prenet(
            hparams.vocab_size,
            [hparams.prenet_dim, hparams.prenet_dim])

        self.attention_rnn = nn.LSTMCell(
            hparams.prenet_dim + hparams.encoder_embedding_dim,
            hparams.attention_rnn_dim)

        self.attention_layer = Attention(
            hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size)

        self.decoder_rnn = nn.LSTMCell(
            hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
            hparams.decoder_rnn_dim, 1)

        self.linear_projection = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
            hparams.vocab_size, w_init_gain='sigmoid')

        self.gate_layer = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
            bias=True, w_init_gain='sigmoid')

    def get_go_frame(self, memory):
        """ Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: decoder outputs

        RETURNS
        -------
        decoder_input: all zeros frames
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(
            B, hparams.vocab_size).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(
            B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

        RETURNS
        -------
        inputs: processed decoder inputs

        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """ Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs:
        gate_outputs: gate output energies
        alignments:

        RETURNS
        -------
        mel_outputs:
        gate_outpust: gate output energies
        alignments:
        """
        # (T_out, B) -> (B, T_out)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out, B) -> (B, T_out)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, hparams.vocab_size)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """ Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: previous mel output

        RETURNS
        -------
        mel_output:
        gate_output: gate output energies
        attention_weights:
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(
            self.attention_hidden, self.p_attention_dropout, self.training)

        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
              self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask)

        # print(self.attention_context.shape,self.attention_weights.shape)
        self.attention_weights_cum += self.attention_weights
        decoder_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(
            self.decoder_hidden, self.p_decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat(
            (self.decoder_hidden, self.attention_context), dim=1)
        decoder_output = self.linear_projection(
            decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, memory,decoder_inputs,memory_lengths,output_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        # print(memory.shape)
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)
        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))
        # print(decoder_inputs.shape)
        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        # print(mel_outputs.shape)
        return mel_outputs, gate_outputs, alignments

    def inference(self, memory):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            decoder_input = self.prenet(decoder_input)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                break
            elif len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments


class AuxiliaryPhonemesModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, hparams):
        super(AuxiliaryPhonemesModel, self).__init__()
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        # self.embedding = nn.Embedding(
        #     hparams.n_symbols, hparams.symbols_embedding_dim)
        self.linear = nn.Linear(hparams.n_mel_channels,hparams.encoder_embedding_dim)
        # std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
        # val = sqrt(3.0) * std  # uniform bounds for std
        # self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(hparams)
        self.decoder = PhonemesDecoder(hparams)
        self.postnet = Postnet(hparams)

    def parse_batch(self, batch):
        # # text_padded, input_lengths, mel_padded, gate_padded, \
        # #     output_lengths = batch
        # mels, gate,input_lengths, spec, output_lengths = batch
        # mels = to_gpu(mels).float()
        # gate = to_gpu(gate).float()
        # input_lengths = to_gpu(input_lengths).long()
        # inp_len = torch.max(input_lengths.data).item()
        # spec = to_gpu(spec).float()
        # out_lengths = to_gpu(out_lenghts).long()
        # out_len = torch.max(out_lengths).item()
        # # gate_padded = to_gpu(gate_padded).float()
        # # output_lengths = to_gpu(output_lengths).long()

        # return (
        #     (mels,gate,input_lengths,max_len,spec,output_lengths,out_len),
        #     (spec,output_lengths,out_len))
        input_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths = batch
        input_padded = to_gpu(input_padded).float()
        input_lengths = to_gpu(input_lengths).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()

        return (
            (input_padded, input_lengths, mel_padded, max_len, output_lengths),
            (mel_padded, gate_padded))

    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            mask = mask.expand(hparams.vocab_size, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return outputs

    def forward(self, inputs):
        # mels,gate,input_lengths,max_len,spec,output_lengths,out_len = inputs
        # mels, input_lengths = mels.data, input_lengths.data

        # # embedded_inputs = self.embedding(text_inputs).transpose(1, 2)

        # encoder_outputs = self.encoder(mels, input_lengths)

        # mel_outputs, gate_outputs, alignments = self.decoder(
        #     encoder_outputs, spec,input_lengths,output_lengths)

        # mel_outputs_postnet = self.postnet(mel_outputs)
        # mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        # return self.parse_output(
        #     [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
        #     output_lengths)
        inputs, input_lengths, mels, max_len, output_lengths = inputs
        input_lengths, output_lengths = input_lengths.data, output_lengths.data

        embedded_inputs = self.linear(inputs.transpose(1,2)).transpose(1, 2)

        encoder_outputs = self.encoder(embedded_inputs, input_lengths)
        # print(encoder_outputs.shape)
        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, mels, input_lengths,output_lengths)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return mel_outputs, mel_outputs_postnet, gate_outputs, alignments

    def inference(self, inputs):
        # embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(inputs)
        mel_outputs, gate_outputs, alignments = self.decoder.inference(
            encoder_outputs)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        outputs = self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

        return outputs


In [31]:
## loss_function.py
class Tacotron2Loss(nn.Module):
    def __init__(self):
        super(Tacotron2Loss, self).__init__()

    def forward(self, model_output, targets):
        phoneme_target, gate_target = targets[0], targets[1]
        mel_out, mel_out_postnet, gate_out, _ = model_output

        gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)

        mel_out_indices = torch.argmax(mel_out, dim=2)

        mel_out_postnet_indices = torch.argmax(mel_out_postnet, dim=2)

        phoneme_target_indices = torch.argmax(mel_out, dim=2)

        mel_loss = F.cross_entropy(mel_out, phoneme_target) + F.cross_entropy(mel_out_postnet, phoneme_target)
        return mel_loss + gate_loss

In [32]:
## loss_scaler.py

import torch

class LossScaler:

    def __init__(self, scale=1):
        self.cur_scale = scale

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        return False

    # `x` is a torch.Tensor
    def _has_inf_or_nan(x):
        return False

    # `overflow` is boolean indicating whether we overflowed in gradient
    def update_scale(self, overflow):
        pass

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss):
        scaled_loss = loss*self.loss_scale
        scaled_loss.backward()

class DynamicLossScaler:

    def __init__(self,
                 init_scale=2**32,
                 scale_factor=2.,
                 scale_window=1000):
        self.cur_scale = init_scale
        self.cur_iter = 0
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
#        return False
        for p in params:
            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
                return True

        return False

    # `x` is a torch.Tensor
    def _has_inf_or_nan(x):
        cpu_sum = float(x.float().sum())
        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
            return True
        return False

    # `overflow` is boolean indicating whether we overflowed in gradient
    def update_scale(self, overflow):
        if overflow:
            #self.cur_scale /= self.scale_factor
            self.cur_scale = max(self.cur_scale/self.scale_factor, 1)
            self.last_overflow_iter = self.cur_iter
        else:
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                self.cur_scale *= self.scale_factor
#        self.cur_scale = 1
        self.cur_iter += 1

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss):
        scaled_loss = loss*self.loss_scale
        scaled_loss.backward()

##### Testing

In [33]:
model = AuxiliaryPhonemesModel(hparams).cuda()

In [34]:

train_loader = DataLoader(train_dataset, num_workers=2, shuffle=True,
                          batch_size=hparams.batch_size, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)


In [35]:
spec_to_emb_linear = model.linear
encoder = model.encoder
decoder = model.decoder
postnet = model.postnet

In [36]:
def get_go_frame(memory):
    """ Gets all zeros frames to use as first decoder input
    PARAMS
    ------
    memory: decoder outputs

    RETURNS
    -------
    decoder_input: all zeros frames
    """
    B = memory.size(0)
    decoder_input = Variable(memory.data.new(
        B, hparams.vocab_size).zero_())
    return decoder_input

In [37]:
def parse_decoder_inputs(self, decoder_inputs):
    """ Prepares decoder inputs, i.e. mel outputs
    PARAMS
    ------
    decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

    RETURNS
    -------
    inputs: processed decoder inputs

    """
    # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
    decoder_inputs = decoder_inputs.transpose(1, 2)
    decoder_inputs = decoder_inputs.view(
        decoder_inputs.size(0),
        int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
    # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
    decoder_inputs = decoder_inputs.transpose(0, 1)
    return decoder_inputs


In [38]:
prenet = Prenet(
    hparams.vocab_size,
    [hparams.prenet_dim, hparams.prenet_dim])


In [39]:
def parse_decoder_inputs(decoder_inputs):
    """ Prepares decoder inputs, i.e. mel outputs
    PARAMS
    ------
    decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

    RETURNS
    -------
    inputs: processed decoder inputs

    """
    # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
    decoder_inputs = decoder_inputs.transpose(1, 2)
    decoder_inputs = decoder_inputs.view(
        decoder_inputs.size(0),
        int(decoder_inputs.size(1)/hparams.n_frames_per_step), -1)
    # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
    decoder_inputs = decoder_inputs.transpose(0, 1)
    return decoder_inputs


In [40]:
criterion = Tacotron2Loss()

In [41]:
for i, batch in enumerate(train_loader):
  inputs, input_lengths, target, gate_padded, output_lengths = batch

  print("Inputs Shape: ", inputs.shape)
  inputs = inputs.cuda()
  target = target.cuda()

  input_lengths, output_lengths = input_lengths.data, output_lengths.data

  embedded_inputs = spec_to_emb_linear(inputs.transpose(1,2)).transpose(1, 2)

  encoder_outputs = encoder(embedded_inputs.cuda(), input_lengths)

  print("Encoder Outputs Shape: ", encoder_outputs.shape)
  ## get_go_frame() -- get the zeros values of encoder in
  decoder_input = get_go_frame(encoder_outputs).unsqueeze(0)

  print("Decoder Input Array: ", decoder_input, decoder_input.shape)

  print("Output Shape before parsing: ", target.shape)

  decoder_inputs = parse_decoder_inputs(target)
  decoder_inputs = torch.cat((decoder_input.cuda(), decoder_inputs.cuda()), dim=0)

  print("Output Shape after parsing: ", decoder_inputs.shape)

  mel_outputs, gate_outputs, alignments = decoder(
      encoder_outputs, target, input_lengths.cuda(),output_lengths.cuda())

  mel_outputs_postnet = postnet(mel_outputs)


  loss = criterion((mel_outputs, mel_outputs_postnet, gate_outputs, alignments), (target, gate_padded.cuda()))

  print(loss.item())
  break

Inputs Shape:  torch.Size([32, 80, 999])
Encoder Outputs Shape:  torch.Size([32, 999, 128])
Decoder Input Array:  tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0') torch.Size([1, 32, 74])
Output Shape before parsing:  torch.Size([32, 74, 180])
Output Shape after parsing:  torch.Size([181, 32, 74])


  ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))


6.731008052825928


In [42]:
target_indices = torch.randint(0, 102, (1, 73))
target_indices.shape

torch.Size([1, 73])

In [43]:
mel_outputs.shape, target.shape

(torch.Size([32, 74, 180]), torch.Size([32, 74, 180]))

In [44]:
loss = nn.CrossEntropyLoss()(mel_outputs, target)
loss.item()

2.679483413696289

In [45]:
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
mel_outputs_postnet.shape

torch.Size([32, 74, 180])

In [46]:
mel_outputs.shape, gate_outputs, alignments.shape, mel_outputs_postnet.shape

(torch.Size([32, 74, 180]),
 tensor([[-0.1785, -0.2063, -0.1970,  ..., -0.2740, -0.2541, -0.2296],
         [-0.2193, -0.2056, -0.2584,  ..., -0.2476, -0.2999, -0.2636],
         [-0.0764, -0.0835, -0.1099,  ..., -0.1066, -0.1000, -0.1153],
         ...,
         [-0.0604, -0.0607, -0.0538,  ..., -0.0446, -0.0672, -0.0741],
         [-0.0148, -0.0260, -0.0252,  ..., -0.0358, -0.0763, -0.0445],
         [-0.0367, -0.0212, -0.0238,  ..., -0.0176, -0.0296, -0.0320]],
        device='cuda:0', grad_fn=<CloneBackward0>),
 torch.Size([32, 180, 999]),
 torch.Size([32, 74, 180]))

#### Training

In [47]:
model = AuxiliaryPhonemesModel(hparams).cuda()

In [48]:
train_loader = DataLoader(train_dataset, num_workers=2, shuffle=True,
                          batch_size=hparams.batch_size, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)

valid_loader = DataLoader(val_dataset, num_workers=2, shuffle=True,
                          batch_size=hparams.batch_size, pin_memory=False,
                          drop_last=True, collate_fn=collate_fn)


In [49]:
criterion = Tacotron2Loss()

In [50]:
learning_rate = hparams.learning_rate
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                              weight_decay=hparams.weight_decay)


In [51]:
model.train()


AuxiliaryPhonemesModel(
  (linear): Linear(in_features=80, out_features=128, bias=True)
  (encoder): Encoder(
    (lstm): LSTM(128, 64, batch_first=True, bidirectional=True)
  )
  (decoder): PhonemesDecoder(
    (prenet): Prenet(
      (layers): ModuleList(
        (0): LinearNorm(
          (linear_layer): Linear(in_features=74, out_features=32, bias=False)
        )
        (1): LinearNorm(
          (linear_layer): Linear(in_features=32, out_features=32, bias=False)
        )
      )
    )
    (attention_rnn): LSTMCell(160, 256)
    (attention_layer): Attention(
      (query_layer): LinearNorm(
        (linear_layer): Linear(in_features=256, out_features=128, bias=False)
      )
      (memory_layer): LinearNorm(
        (linear_layer): Linear(in_features=128, out_features=128, bias=False)
      )
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (location_layer): LocationLayer(
     

In [52]:
#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
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Token: 
Add token as git credential? (Y/n) 
Token is valid (permission: write).
[1m[31mCannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub.
Run the following command in your terminal in case you want to set the 'store

In [53]:
def train(model, train_dataloader, valid_dataloader, save_path='/content/drive/MyDrive/Dubbing Project/models'):

  criterion = Tacotron2Loss()

  learning_rate = hparams.learning_rate
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                weight_decay=hparams.weight_decay)

  best_valid_loss = float('inf')

  wandb.init(
      project="S2ST",
      config={
        "learning_rate": hparams.learning_rate,
        "architecture": "Translatotron",
        "dataset": "Librspeech 100 Hours of data",
        "epochs": hparams.epochs,
        'batch_size': hparams.batch_size,
        "description": f"translatoron first run",
      }
  )

  for epoch in range(hparams.epochs):
      # Training
      model.train()
      total_loss = 0.0

      with tqdm(train_dataloader, desc=f'Training Epoch {epoch + 1}/{hparams.epochs}', unit='batch') as t:
        for i, batch in enumerate(t):
            inputs, input_lengths, target, gate_padded, output_lengths = batch

            batch = inputs.cuda(), input_lengths.cuda(), target.cuda(), gate_padded.cuda(), output_lengths.cuda()

            mel_outputs, mel_outputs_postnet, gate_outputs, alignments = model(batch)

            loss = criterion((mel_outputs, mel_outputs_postnet, gate_outputs, alignments), (target.cuda(), gate_padded.cuda()))

            total_loss += loss.item()
            t.set_postfix(loss=loss.item(), total_loss=total_loss / (i + 1))

            optimizer.zero_grad()

            loss.backward()

            optimizer.step()

            wandb.log({"train_loss": total_loss / (i + 1)})

            gc.collect()
            torch.cuda.empty_cache()

      # Validation
      model.eval()
      total_valid_loss = 0.0

      gc.collect()
      torch.cuda.empty_cache()

      with torch.no_grad():
          with tqdm(valid_dataloader, desc=f'Validation Epoch {epoch + 1}/{hparams.epochs}', unit='batch') as v:
            for i, batch in enumerate(v):
                inputs, input_lengths, target, gate_padded, output_lengths = batch

                batch = inputs.cuda(), input_lengths.cuda(), target.cuda(), gate_padded.cuda(), output_lengths.cuda()

                mel_outputs, mel_outputs_postnet, gate_outputs, alignments = model(batch)

                loss = criterion((mel_outputs, mel_outputs_postnet, gate_outputs, alignments), (target.cuda(), gate_padded.cuda()))
                v.set_postfix(loss=loss.item(), total_loss=total_valid_loss / (i + 1))

                total_valid_loss += loss.item()
                wandb.log({"valid_loss": total_valid_loss / (i + 1)})

                gc.collect()
                torch.cuda.empty_cache()

      # Save the model
      save_filename = os.path.join(save_path, f'phoneme_predictor_{epoch + 1}.pt')
      torch.save(model.state_dict(), save_filename)

      gc.collect()
      torch.cuda.empty_cache()



In [54]:
gc.collect()
torch.cuda.empty_cache()


In [None]:
save_path = '/content/drive/MyDrive/Dubbing Project/models'

## f42e9dfb9ecc6347595dd9aa95ce1ce04e08004d
train(model, train_loader, valid_loader, save_path)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112426744445101, max=1.0…

Training Epoch 1/7: 100%|██████████| 891/891 [1:07:59<00:00,  4.58s/batch, loss=3.61, total_loss=4.49]
Validation Epoch 1/7: 100%|██████████| 84/84 [05:58<00:00,  4.27s/batch, loss=1.7, total_loss=1.8]
Training Epoch 2/7: 100%|██████████| 891/891 [1:07:52<00:00,  4.57s/batch, loss=3.41, total_loss=3.45]
Validation Epoch 2/7: 100%|██████████| 84/84 [05:58<00:00,  4.27s/batch, loss=1.53, total_loss=1.44]
Training Epoch 3/7: 100%|██████████| 891/891 [1:08:16<00:00,  4.60s/batch, loss=3.18, total_loss=3.04]
Validation Epoch 3/7: 100%|██████████| 84/84 [06:05<00:00,  4.35s/batch, loss=1.25, total_loss=1.22]
Training Epoch 4/7: 100%|██████████| 891/891 [1:08:20<00:00,  4.60s/batch, loss=3.03, total_loss=2.83]
Validation Epoch 4/7: 100%|██████████| 84/84 [06:05<00:00,  4.35s/batch, loss=1.43, total_loss=1.15]
Training Epoch 5/7: 100%|██████████| 891/891 [1:08:22<00:00,  4.60s/batch, loss=2.93, total_loss=2.7]
Validation Epoch 5/7: 100%|██████████| 84/84 [06:00<00:00,  4.29s/batch, loss=1.32, 

In [None]:
save_filename = os.path.join(save_path, f'final_phoneme_predictor.pt')
torch.save(model.state_dict(), save_filename)

In [None]:
from google.colab import runtime
runtime.unassign()

In [None]:
mel_outputs_postnet.shape

In [None]:
mel_out_postnet_indices = torch.argmax(mel_outputs_postnet, dim=1)

phoneme_target_indices = torch.argmax(target, dim=1)


In [None]:
mel_out_postnet_indices, phoneme_target_indices

In [None]:
# for i, batch in enumerate(train_loader):
#   inputs, input_lengths, target, gate_padded, output_lengths = batch

#   batch = inputs.cuda(), input_lengths.cuda(), target.cuda(), gate_padded.cuda(), output_lengths.cuda()

#   mel_outputs, mel_outputs_postnet, gate_outputs, alignments = model(batch)

#   loss = criterion((mel_outputs, mel_outputs_postnet, gate_outputs, alignments), (target.cuda(), gate_padded.cuda()))

#   if i % 10 == 0:
#     print("Loss is: ", loss.item())

#   loss.backward()

#   optimizer.step()

### Inference

In [56]:
save_path='/content/drive/MyDrive/Dubbing Project/models'

In [57]:
model.load_state_dict(torch.load(os.path.join(save_path, f'final_phoneme_predictor.pt')))

<All keys matched successfully>

In [79]:
for i, batch in enumerate(valid_dataloader):
    inputs, input_lengths, target, gate_padded, output_lengths = batch

    batch = inputs.cuda(), input_lengths.cuda(), target.cuda(), gate_padded.cuda(), output_lengths.cuda()

    mel_outputs, mel_outputs_postnet, gate_outputs, alignments = model(batch)

    break

In [80]:
mel_out_postnet_indices = torch.argmax(mel_outputs_postnet, dim=1)

phoneme_target_indices = torch.argmax(target, dim=1)


In [81]:
binary_tensor = torch.where(gate_outputs > 0, torch.tensor(1), torch.tensor(0))

In [82]:
mel_out_postnet_indices[0]

tensor([ 3, 61,  9, 38, 46, 38, 11, 41, 47,  3, 38, 46,  3, 25, 54,  5, 11, 61,
        29, 46, 38, 54, 15, 52, 59, 62, 65, 46,  3, 25, 62, 16,  9, 37, 38, 62,
        65, 46,  3, 25, 66, 54, 41,  9, 25,  3, 37, 65, 61, 46, 61,  9, 61,  9,
        61, 61, 46, 24, 61,  9, 41,  5, 15, 24, 24, 24, 15, 61, 24, 61, 61, 54,
        61, 33, 66, 24, 54, 61, 33, 61, 11, 61, 62, 61,  3, 54],
       device='cuda:0')

In [83]:
phoneme_target_indices[0]

tensor([59, 61,  9, 35,  5, 38, 33, 41, 44,  3, 38, 46,  3, 16, 54, 15, 11, 61,
        29, 36, 30, 40, 15, 52, 38, 62, 65, 46,  3, 25, 62, 71, 11, 37, 38, 62,
        65, 46,  3, 25, 23, 30, 33, 11, 35,  3, 37,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])

In [84]:
binary_tensor[0]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')