In [16]:
%matplotlib inline
import IPython.display
import os
import Tacotron_model.taco_model as M
import torch
from torch.utils.data import DataLoader
from Tacotron_model.util import collate_fn
from Tacotron_model.util import text_to_sequence, wav_to_spectrogram, sequence_to_text
from dataset import TextToSpeechDataset
from torch.autograd import Variable
from tqdm import tqdm
import torch.nn.functional as F
from Tacotron_model.visualize import show_spectrogram, show_attention
from Tacotron_model.griffinlim import TacotronSTFT
from Tacotron_model.griffinlim import griffin_lim

### Text encoding (characters → symbol ids)


In [17]:
# Symbol inventory: index 0 is padding, the last symbol marks end-of-sentence.
eos = '~'
pad = '_'
chars = pad + 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'\"(),-.:;? ' + eos
char_to_id = dict(zip(chars, range(len(chars))))


def text_to_sequence(text, eos=eos):
    """Encode ``text`` as a list of symbol ids, terminated by the ``eos`` marker.

    Each character is looked up in ``char_to_id``; a character outside
    ``chars`` raises ``KeyError``.
    """
    text += eos
    return list(map(char_to_id.__getitem__, text))

In [18]:
# Model hyper-parameters: teacher-forcing ratio and vocabulary size.
teacher_forcing_ratio = 0.5
num_chars = len(chars)


In [19]:
def inference_text(text, encoder=None):
    """Package raw text into the sample dict used for inference.

    Args:
        text: raw input string.
        encoder: optional callable mapping a string to a list of symbol ids.
            Defaults to the module-level ``text_to_sequence`` (looked up at
            call time, so redefinitions in later cells are honoured).

    Returns:
        dict with ``'text'`` (the input, unchanged) and ``'embedded_text'``
        (the encoded id list).
    """
    if encoder is None:
        encoder = text_to_sequence
    return {'text': text, 'embedded_text': encoder(text)}

In [20]:
# Sentence to synthesize.
text = "Hey there, this is inference of our Text-to-Speech model"

In [21]:
# Encode the sentence into the sample dict consumed below.
sample = inference_text(text=text)

In [22]:
# Unpack the sample (both keys are read before either name is bound) and show the ids.
text, embedded_text = (sample[key] for key in ('text', 'embedded_text'))
print(embedded_text)

[8, 31, 51, 64, 46, 34, 31, 44, 31, 58, 64, 46, 34, 35, 45, 64, 35, 45, 64, 35, 40, 32, 31, 44, 31, 40, 29, 31, 64, 41, 32, 64, 41, 47, 44, 64, 20, 31, 50, 46, 59, 46, 41, 59, 19, 42, 31, 31, 29, 34, 64, 39, 41, 30, 31, 38, 65]


### Loading from Checkpoint

In [23]:
# Path to the trained checkpoint. Set the TACOTRON_CHECKPOINT environment
# variable to point elsewhere instead of editing this machine-specific path.
checkpoint = os.environ.get(
    'TACOTRON_CHECKPOINT',
    '/home/rajaniep/code/UntitledFolder/project/src/log/tacotron/Training-naive/last.pth.tar',
)

In [24]:
def load_checkpoint(checkpoint, map_location=None):
    """Load a checkpoint file written with ``torch.save``.

    Args:
        checkpoint: path to the checkpoint file.
        map_location: forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to load
            GPU-saved tensors on a CPU-only machine. ``None`` keeps torch's
            default device mapping.

    Returns:
        The deserialized object. For the checkpoint used in this notebook it
        is a dict with keys ``'step'``, ``'state_dict'`` and ``'optim_dict'``.

    Raises:
        FileNotFoundError: if ``checkpoint`` does not exist.
    """
    if not os.path.exists(checkpoint):
        # Raising a plain string is a TypeError in Python 3; use a real exception.
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))
    return torch.load(checkpoint, map_location=map_location)

In [25]:
# Build the MelSpectrogramNet from Tacotron_model.taco_model, passing the
# vocabulary size and teacher-forcing ratio defined above. These are positional
# arguments — NOTE(review): confirm their order against the constructor signature.
# Trained weights are loaded into this instance by the checkpoint cell below.
model = M.MelSpectrogramNet(num_chars, teacher_forcing_ratio)


  "num_layers={}".format(dropout, num_layers))


In [26]:
# Restore the trained weights into `model`. The checkpoint is a dict; its keys
# are printed so the expected 'step' / 'state_dict' / 'optim_dict' layout can be
# checked. Only 'state_dict' (the model weights) is used here; the optimizer
# state and step counter are not needed for evaluation.
print("Loading model from checkpoint at ", checkpoint)
model_state_dict = load_checkpoint(checkpoint)
print(model_state_dict.keys())
state_dict = model_state_dict['state_dict']

model.load_state_dict(state_dict)


Loading model from checkpoint at  /home/rajaniep/code/UntitledFolder/project/src/log/tacotron/Training-naive/last.pth.tar
dict_keys(['step', 'state_dict', 'optim_dict'])


In [27]:
# Encoded sentence as a 1-D int64 tensor. `torch.autograd.Variable` has been a
# no-op wrapper since PyTorch 0.4, so build the tensor directly with its dtype.
sequence = torch.tensor(embedded_text, dtype=torch.long)
print(sequence)

tensor([  8,  31,  51,  64,  46,  34,  31,  44,  31,  58,  64,  46,
         34,  35,  45,  64,  35,  45,  64,  35,  40,  32,  31,  44,
         31,  40,  29,  31,  64,  41,  32,  64,  41,  47,  44,  64,
         20,  31,  50,  46,  59,  46,  41,  59,  19,  42,  31,  31,
         29,  34,  64,  39,  41,  30,  31,  38,  65])


### Loading the dataset for a single audiobook

In [28]:
# Root directory of the audiobook used for evaluation. Set the TTS_BOOK_PATH
# environment variable to point elsewhere instead of editing this
# machine-specific path.
PATH = os.environ.get(
    'TTS_BOOK_PATH',
    '/home/rajaniep/code/UntitledFolder/project/en_US/by_book/female/judy_bieber/dorothy_and_wizard_oz',
)
# Text is encoded with the notebook's `text_to_sequence`; audio is converted
# with `wav_to_spectrogram` from Tacotron_model.util.
dataset = TextToSpeechDataset(path=PATH,
                              text_embeddings=text_to_sequence,
                              mel_transforms=wav_to_spectrogram)

data at path /home/rajaniep/code/UntitledFolder/project/en_US/by_book/female/judy_bieber/dorothy_and_wizard_oz


In [29]:
# Materialize the first NUM_EVAL_SAMPLES (embedded_text, mel) pairs and batch them.
NUM_EVAL_SAMPLES = 100
dataset_train_parts = {}
for i in range(NUM_EVAL_SAMPLES):
    # Index the dataset once per sample. Each access presumably loads and
    # transforms audio via mel_transforms, so the previous double lookup
    # repeated that work. TODO confirm TextToSpeechDataset.__getitem__ cost.
    item = dataset[i]
    # 'mel_spectograms' (sic) is the key name the dataset returns.
    dataset_train_parts[i] = (item['embedded_text'], item['mel_spectograms'])
dl_train = DataLoader(dataset_train_parts, batch_size=2, collate_fn=collate_fn)



### Evaluating the model (eval mode, teacher-forced)

In [30]:
import scipy
import shutil

import numpy as np
import librosa
from librosa import display
from optparse import OptionParser
from matplotlib import pyplot as plt

In [31]:
def spec2wav(mag, n_fft = 512, win_length = 158, hop_length = 128, num_iters=30, phase=None,
             random_state=None):
    """Reconstruct a waveform from a linear magnitude spectrogram (Griffin-Lim).

    Starting from ``phase`` (or a random phase), the function repeats two steps
    while keeping ``mag`` fixed:

    1. Invert the complex STFT with ``librosa.istft``.
    2. Re-estimate the phase from ``librosa.stft`` of that signal.

    Args:
        mag: magnitude array. Presumably shaped (1 + n_fft // 2, frames) to match
            ``librosa.stft`` output — TODO confirm for model outputs. ``istft``
            infers the FFT size from the number of rows.
        n_fft: FFT size used when re-analysing the intermediate waveform.
        win_length: STFT window length.
        hop_length: STFT hop length.
        num_iters: number of inverse transforms; must be positive.
        phase: optional initial phase in radians, same shape as ``mag``.
        random_state: optional int seed for the random initial phase. ``None``
            draws from the global ``np.random`` state.

    Returns:
        The waveform produced by the final ``librosa.istft`` call.

    Raises:
        ValueError: if ``num_iters`` is not positive.
    """
    if num_iters <= 0:
        raise ValueError("num_iters must be positive, got {}".format(num_iters))
    if phase is None:
        rng = np.random if random_state is None else np.random.RandomState(random_state)
        # Uniform over the full circle [0, 2*pi); the old pi * rand covered only half.
        phase = 2 * np.pi * rng.rand(*mag.shape)
    stft = mag * np.exp(1.j * phase)
    for _ in range(num_iters - 1):
        wav = librosa.istft(stft, win_length=win_length, hop_length=hop_length)
        rebuilt = librosa.stft(wav, n_fft=n_fft, win_length=win_length, hop_length=hop_length)
        # Keep only the phase of the re-analysed signal; the magnitude stays `mag`.
        stft = mag * np.exp(1.j * np.angle(rebuilt))
    return librosa.istft(stft, win_length=win_length, hop_length=hop_length)

In [32]:
import torch

from Tacotron_model.taco_model import MelSpectrogramNet
from graphviz import Digraph
from torch.autograd import Variable


# make_dot was moved to https://github.com/szagoruyko/pytorchviz
from torchviz import make_dot, make_dot_from_trace

In [34]:
# Teacher-forced evaluation of the restored model on `dl_train`.
EVAL_EPOCHS = 2      # number of passes over val_data
RENDER_GRAPH = True  # save the autograd graph of the first batch to Digraph.gv


def _make_stop_targets(targets, audio_lengths):
    """Build the stop-token targets: 1 at each utterance's last frame, else 0.

    Returns a (targets.size(1), targets.size(0)) float tensor. This layout
    assumes targets are (time, batch, ...) — TODO confirm against collate_fn.
    """
    stop_targets = torch.zeros(targets.size(1), targets.size(0))
    for i in range(len(stop_targets)):
        stop_targets[i, audio_lengths[i] - 1] = 1
    return stop_targets


val_data = dl_train
model.eval()
step = 0
for epoch in range(EVAL_EPOCHS):
    total_loss = 0.0
    pbar = tqdm(val_data, total=len(val_data), unit=' batches')
    for b, (text_batch, audio_batch, text_lengths, audio_lengths) in enumerate(pbar):
        # `Variable` wrappers are no-ops since PyTorch 0.4; use the tensors directly.
        text, targets = text_batch, audio_batch
        stop_targets = _make_stop_targets(targets, audio_lengths)

        # Autograd is only needed to draw the graph once. All other batches run
        # without building a graph, so evaluation does not hold activations.
        render = RENDER_GRAPH and step == 0
        with torch.set_grad_enabled(render):
            outputs, stop_tokens, attention = model(text, targets, teacher_forcing_ratio=1.0)
            spec_loss = F.mse_loss(outputs, targets)
            stop_loss = F.binary_cross_entropy_with_logits(stop_tokens, stop_targets)
            loss = spec_loss + stop_loss
        if render:
            g = make_dot(outputs, params=dict(list(model.named_parameters()) + [('text', text)]))
            # Write only the DOT source (default name Digraph.gv) without
            # launching a viewer; the graphviz Source cell below renders it.
            g.save()

        # `loss.data[0]` raises on a 0-dim tensor in PyTorch >= 0.4.
        total_loss += loss.item()
        step += 1
        pbar.set_postfix(loss=total_loss / (b + 1))
    print("Epoch:", epoch, " Step:", step, " Mean loss: ", total_loss / max(len(val_data), 1))





  0%|          | 0/50 [00:00<?, ? batches/s]

torch.Size([2, 169])
torch.Size([169, 2, 256])


KeyboardInterrupt: 

In [None]:
from graphviz import Source

# Display the autograd graph DOT file written by the evaluation cell.
Source.from_file(filename='Digraph.gv')