In [4]:
import os
import yaml
import json
import torch
import numpy as np
from main import inference, train
from tsnecuda import TSNE
import random
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import IPython

In [2]:
# Data-exploration helpers: speaker-embedding t-SNE plot and mel/pitch/energy plot

def visualize_speaker_emb(data_path, k_sample, n_speakers=50):
    """Project sampled speaker embeddings to 2-D with t-SNE and scatter-plot them.

    For each of the first ``n_speakers`` speaker sub-directories of ``data_path``
    (sorted by name so the selection is reproducible), ``k_sample`` files are
    drawn *with replacement* from ``<data_path>/<speaker>/speaker/`` and loaded
    with ``np.load``.

    Args:
        data_path: Root directory containing one sub-directory per speaker.
        k_sample: Number of embedding files to sample per speaker.
        n_speakers: Maximum number of speakers to include (default 50).

    Returns:
        (X_embedded, speaker_idx): the t-SNE coordinates (one row per sampled
        embedding, 2 columns) and an integer array giving, for each row, the
        position of its speaker in the sorted speaker list.

    Raises:
        ValueError: if no embedding files were found at all.
    """
    # os.listdir order is arbitrary; sort so "the first n speakers" is stable,
    # and ignore stray files that are not speaker directories.
    speakers = sorted(
        name for name in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, name))
    )[:n_speakers]

    values = []
    labels = []
    for idx, speaker in enumerate(speakers):
        emb_dir = os.path.join(data_path, speaker, 'speaker')
        files = [os.path.join(emb_dir, f) for f in os.listdir(emb_dir)]
        if not files:
            # random.choices raises IndexError on an empty population; skip the speaker.
            continue
        values.extend(np.load(path) for path in random.choices(files, k=k_sample))
        # Labels are built per speaker so they stay aligned even when a speaker is skipped.
        labels.extend([idx] * k_sample)

    if not values:
        raise ValueError(f'no speaker embeddings found under {data_path!r}')

    # NOTE(review): squeeze(1) assumes every saved embedding has shape (1, D) — confirm.
    embeddings = np.array(values).squeeze(1)
    X_embedded = TSNE(n_components=2, perplexity=15, learning_rate=10).fit_transform(embeddings)
    speaker_idx = np.asarray(labels)

    tsne_df = pd.DataFrame({'X': X_embedded[:, 0],
                            'Y': X_embedded[:, 1],
                            'speaker': speaker_idx})
    sns.scatterplot(x='X', y='Y', data=tsne_df,
                    hue='speaker',
                    palette='Accent')

    return X_embedded, speaker_idx
    


def plot_mel(mel, energy, pitch, metadata=None, titles=['mel','energy','pitch']):
    """Draw a mel-spectrogram with pitch (F0) and energy curves overlaid on twin axes.

    Args:
        mel: 2-D array-like passed to ``imshow``. ``mel.shape[0]`` is used as the
            vertical extent and ``mel.shape[1]`` as the number of frames (x-range
            of the overlaid curves).
        energy: 1-D sequence of energy values, plotted against frame index.
        pitch: 1-D sequence of pitch values, plotted against frame index.
        metadata: Optional statistics dict. ``metadata['pitch']`` must hold
            ``'mean'``, ``'std'`` and ``'max'``; ``metadata['energy']`` must hold
            ``'min'`` and ``'max'``. When given, pitch is de-normalised as
            ``pitch * std + mean`` and both curves get fixed y-limits. When
            ``None``, pitch is plotted unchanged and the y-limits are autoscaled.
        titles: Accepted for backward compatibility; currently unused.

    Returns:
        The matplotlib ``Figure``.
    """
    pitch_max = energy_min = energy_max = None
    if metadata is not None:
        pitch_stats = metadata['pitch']
        pitch = pitch * pitch_stats['std'] + pitch_stats['mean']
        # The stored max is on the normalised scale; map it onto the same scale as pitch.
        pitch_max = pitch_stats['max'] * pitch_stats['std'] + pitch_stats['mean']
        energy_min = metadata['energy']['min']
        energy_max = metadata['energy']['max']

    # One panel is enough: pitch and energy are overlaid on the mel axes below.
    fig, axes = plt.subplots(1, 1, squeeze=False)
    mel_ax = axes[0][0]

    def _add_axis(fig, old_ax):
        # Transparent axes placed exactly over old_ax, used to overlay a second curve.
        ax = fig.add_axes(old_ax.get_position(), anchor="W")
        ax.set_facecolor("None")
        return ax

    mel_ax.imshow(mel, origin="lower")
    mel_ax.set_aspect(2.5, adjustable="box")
    mel_ax.set_ylim(0, mel.shape[0])
    mel_ax.set_title('Mel', fontsize="medium")
    mel_ax.tick_params(labelsize="x-small", left=False, labelleft=False)
    mel_ax.set_anchor("W")

    # Pitch overlay, y-axis on the left.
    ax1 = _add_axis(fig, mel_ax)
    ax1.plot(pitch, color="tomato")
    ax1.set_xlim(0, mel.shape[1])
    if pitch_max is not None:
        ax1.set_ylim(0, pitch_max)
    ax1.set_ylabel("F0", color="tomato")
    ax1.tick_params(
        labelsize="x-small", colors="tomato", bottom=False, labelbottom=False
    )

    # Energy overlay, y-axis on the right.
    ax2 = _add_axis(fig, mel_ax)
    ax2.plot(energy, color="darkviolet")
    ax2.set_xlim(0, mel.shape[1])
    if energy_min is not None:
        ax2.set_ylim(energy_min, energy_max)
    ax2.set_ylabel("Energy", color="darkviolet")
    ax2.yaxis.set_label_position("right")
    ax2.tick_params(
        labelsize="x-small",
        colors="darkviolet",
        bottom=False,
        labelbottom=False,
        left=False,
        labelleft=False,
        right=True,
        labelright=True,
    )
    return fig

In [3]:
# Synthesis controls forwarded to `inference` as multiplicative scales.
control = {
    'duration': 1.0,
    'pitch': 1.0,
    'energy': 1.0,
}

TEXT = 'Hi! How are you doing?'
CHECKPOINT_PATH = 'fastspeech2/checkpoints/epoch=119-step=96695.ckpt'

output, wav_paths = inference(
    text=TEXT,
    checkpoint_path=CHECKPOINT_PATH,
    vocoder='hifi',
    control=control,
)

mel_postnet, mel_preds, pitch_preds, energy_preds, log_duration_preds = output

# exp(x) - 1 inverts a log(1 + d) duration encoding; round to whole frames, never negative.
durations = torch.clamp(torch.round(torch.exp(log_duration_preds) - 1), min=0)
frame_repeats = durations.long()
# Repeat each pitch/energy value by its duration (presumably phoneme -> frame level;
# confirm shapes against the model output).
pitch_preds = torch.repeat_interleave(pitch_preds, frame_repeats, dim=0)
energy_preds = torch.repeat_interleave(energy_preds, frame_repeats, dim=0)

# NOTE(review): pass the dataset's pitch/energy statistics as `metadata` so plot_mel
# can de-normalise the curves. Trailing ';' stops Jupyter from rendering the
# returned Figure a second time.
plot_mel(mel_postnet, energy_preds, pitch_preds);

Raw Text Sequence: Hi! How are you doing
Phoneme Sequence: {HH AY1 sp HH AW1 AA1 R Y UW1 D UW1 IH0 NG}
Converting Melspectrogram to wav file


  torch.tensor(mel, device=device))
100%|████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2054.02it/s]

TTS generation complete! outputs saved to output_wavs





NameError: name 'duration' is not defined