This notebook is to test attention performance of a TTS model on a list of hard sentences.

### Features of this notebook
- You can see visually how your model performs on each sentence and try to discern common problems.
- At the end, the final attention score is printed, showing the overall performance of your model. You can use this value for model selection.
- You can change the list of sentences, e.g. using the full list from https://openreview.net/forum?id=BJeFQ0NtPS.

In [None]:
%load_ext autoreload
%autoreload 2
import os
from pathlib import Path

import IPython
import numpy as np
import torch
from matplotlib import pyplot as plt

from TTS.api import TTS
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.manage import ModelManager

%matplotlib inline
# Default (width, height) used for matplotlib figures created in this notebook.
plt.rcParams["figure.figsize"] = (16, 5)

# Expose only GPU index 1 by default, but keep a CUDA_VISIBLE_DEVICES value that is already
# set in the environment instead of overwriting it (NB_OUTPUT_DIR below is likewise read
# from the environment). Export CUDA_VISIBLE_DEVICES before starting the kernel to override.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

In [None]:
# Evaluation sentences, ordered from shortest to longest. Each one is synthesized and
# scored separately by the loop further down.
SENTENCES = [
    # One- and two-word utterances.
    "Hurry.",
    "Warehouse.",
    "Allergic trouser.",
    # Short sentences.
    "I want more detailed information.",
    "Abstraction is often one floor above you.",
    "Any climbing dish listens to a cumbersome formula.",
    # Long sentences (literals wrapped to keep the source lines short).
    (
        "Nineteen twenty is when we are unique together "
        "until we realise we are all the same."
    ),
    (
        "If the Easter bunny and the tooth fairy had babies "
        "would they take your teeth and leave chocolate for you?"
    ),
]

In [None]:
# Download the pretrained model; the third value returned by download_model is not used here.
manager = ModelManager()
MODEL_PATH, CONFIG_PATH, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DDC")

# Use CUDA when torch reports an available device, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the TTS API object from the downloaded files and move it to the chosen device.
api = TTS(model_path=MODEL_PATH, config_path=CONFIG_PATH)
api = api.to(device)

# Generated audio goes to <NB_OUTPUT_DIR>/test_attention; NB_OUTPUT_DIR defaults to the
# current working directory when the variable is not set.
output_root = Path(os.getenv("NB_OUTPUT_DIR", "."))
OUT_PATH = output_root / "test_attention"
OUT_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
def _wav_stem(index, sentence, max_chars=200):
    """Return a file-system-safe stem such as ``02_Allergic_trouser`` for a sentence.

    Spaces become underscores and every character other than letters, digits, "_" and "-"
    is dropped, so punctuation (".", "?", ",", quotes, ...) and path separators ("/", "\\")
    from a user-supplied sentence list can never end up in the file name.
    """
    underscored = sentence[:max_chars].replace(" ", "_")
    kept = "".join(ch for ch in underscored if ch.isalnum() or ch in "_-")
    return f"{index:02d}_{kept}"


# Synthesize every sentence, collect its score from alignment_diagonal_score, show the audio
# and the alignment plot inline, and save the waveform under OUT_PATH.
attn_scores = []
for i, sentence in enumerate(SENTENCES):
    print(f" > {sentence}")
    outputs = api.synthesizer.tts_model.synthesize(sentence, use_griffin_lim=True)
    attn_scores.append(alignment_diagonal_score(outputs["alignments"]))

    # Audio player for listening to the result directly in the notebook.
    IPython.display.display(IPython.display.Audio(outputs["wav"], rate=api.synthesizer.output_sample_rate))

    # Alignment plot for visual inspection of this sentence.
    fig = plot_alignment(outputs["alignments"], fig_size=(8, 5))
    IPython.display.display(fig)

    # The index prefix keeps the saved files in the same order as SENTENCES.
    file_name = OUT_PATH / f"{_wav_stem(i, sentence)}.wav"
    api.synthesizer.save_wav(outputs["wav"], file_name)

In [None]:
# Average of the per-sentence scores collected in attn_scores; this is the single value
# the notebook introduction suggests using for model selection.
np.mean(attn_scores)