# Speech synthesis with text predicted GST
In this notebook we show how to use the pretrained models. We use the best-performing model.

This notebook is partially adapted from a notebook in the original repository on Tacotron 2 (with DDC) synthesis
with GST, speaker embeddings and vocoder ([link](https://colab.research.google.com/drive/1t0TFC3vqU1nFow5p5FTPjtkT6rFJOSsB?usp=sharing))
and from a notebook for speaker-controlled speech synthesis
([link](https://github.com/vincenzo-scotti/ITAcotron_2/blob/ITAcotron2/notebooks/ITAcotron-2_synthesis.ipynb)).

Additional resources:
[link 1](https://colab.research.google.com/drive/1Gtt9EV1fFzuKbOdqUrLuAMuxBaot5v4F?usp=sharing),
[link 2](https://colab.research.google.com/drive/1-xI9HiG5B-APnwe7KQwtOBPp1gggg-jZ?usp=sharing).


## Environment

### Imports

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
import IPython
from IPython.display import Audio

import random

import numpy as np
import torch

import json

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

from TTS.utils.audio import AudioProcessor
from TTS.tts.utils.text.symbols import symbols, phonemes, make_symbols
from TTS.utils.io import load_config  # Config class loader
from TTS.tts.utils.generic_utils import setup_model  # TTS model setup
from TTS.tts.utils.io import load_checkpoint  # Model checkpoint loader
from TTS.vocoder.utils.generic_utils import setup_generator  # Vocoder model setup
from TTS.tts.utils.synthesis import synthesis  # Main wrapper for speech synthesis
from pathlib import Path
from encoder import inference as speaker_encoder_model
from encoder.params_model import model_embedding_size as speaker_embedding_size

# from model import ...
# from data import ...

from typing import Optional, Union, ByteString, List, Dict, Callable
from TTS.utils.io import AttrDict
from TTS.tts.models.tacotron2 import Tacotron2
from TTS.vocoder.models.fullband_melgan_generator import FullbandMelganGenerator

### Paths

In [None]:
lm_model_checkpoint_path = tokeniser_checkpoint_path = 'microsoft/DialoGPT-large'

tts_model_configs_path = './resources/models/tts/config.json'
tts_model_checkpoint_path = './resources/models/tts/best_model.pth.tar'
tts_model_speaker_file = './resources/models/tts/speakers.json'
tts_model_scale_stats_path = None

speaker_encoder_model_checkpoint_path = './resources/models/speaker_encoder/pretrained.pt'

vocoder_model_config_path = './resources/models/vocoder/config.json'
vocoder_model_checkpoint_path = './resources/models/vocoder/best_model.pth.tar'
vocoder_model_scale_stats_path = './resources/models/vocoder/scale_stats.npy'

gstt_model_checkpoint_path = './resources/models/gstt/'


### Constants

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Set random seed for reproducibility.

In [None]:
random_seed = 2307

random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

### Functions

Generate response given context.

In [None]:
def generate_response(
        context: List[str],
        dialogue_lm: transformers.GPT2LMHeadModel,
        tokeniser: transformers.GPT2Tokenizer,
        generate_kwargs: Optional[Dict] = None
) -> str:
    ...
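
As an illustration of how the context could be turned into model input, the sketch below joins the dialogue turns with the end-of-sequence token, which is the input format documented for DialoGPT-style models. The helper name `build_dialogue_prompt` and the hard-coded EOS string are assumptions made for this example; in the notebook the string would normally come from `tokeniser.eos_token`. This is not the body of `generate_response`, only a possible first step of it.

```python
from typing import List

# Assumption: every turn is terminated by the EOS token of the GPT-2 tokeniser family.
EOS_TOKEN = '<|endoftext|>'


def build_dialogue_prompt(context: List[str], eos_token: str = EOS_TOKEN) -> str:
    """Join the dialogue turns into a single prompt string (hypothetical helper)."""
    return ''.join(turn.strip() + eos_token for turn in context)


prompt = build_dialogue_prompt(['Hi, how are you?', 'Fine, thanks. And you?'])
print(prompt)
```

The resulting string can then be encoded by the tokeniser and passed to the language model, which continues it with the next turn.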

Speaker embedding computation.

In [None]:
def extract_speaker_embedding(
        reference_audio_path: Union[str, List[str]],
        audio_loading_fn: Callable,
        speaker_embedding_fn: Callable
) -> List[float]:
    if isinstance(reference_audio_path, str):
        return extract_speaker_embedding([reference_audio_path], audio_loading_fn, speaker_embedding_fn)

    speaker_embedding: List[float] = np.vstack([
        speaker_embedding_fn(audio_loading_fn(audio_file_path)) for audio_file_path in reference_audio_path
    ]).mean(axis=0).tolist()

    return speaker_embedding
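
To make the averaging step concrete, here is a dependency-free sketch of the same idea: one embedding is computed per reference clip and the embeddings are averaged element-wise. It uses plain Python instead of NumPy, and the names `average_embeddings`, `fake_load` and `fake_embed` are hypothetical stand-ins for the real audio loader and speaker encoder.

```python
from statistics import fmean
from typing import Callable, List


def average_embeddings(paths: List[str], load_fn: Callable, embed_fn: Callable) -> List[float]:
    """Embed every clip and return the element-wise mean of the embeddings."""
    embeddings = [embed_fn(load_fn(path)) for path in paths]
    return [fmean(values) for values in zip(*embeddings)]


def fake_load(path: str) -> str:
    # Stub loader: a real loader would return the preprocessed waveform.
    return path


def fake_embed(audio: str) -> List[float]:
    # Stub encoder: a real encoder would return a fixed-size speaker embedding.
    table = {'a.wav': [1.0, 0.0, 2.0], 'b.wav': [3.0, 4.0, 2.0]}
    return table[audio]


embedding = average_embeddings(['a.wav', 'b.wav'], fake_load, fake_embed)
print(embedding)  # [2.0, 2.0, 2.0]
```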


GST extraction from reference audio.

In [None]:
def extract_gst(
        reference_audio_path: str,
        gst_embedding_model: torch.nn.Module
) -> Dict[str, float]:
    ...
    # Run GST extraction
    gst = ...

    gst = {str(i): value for i, value in enumerate(gst)}

    return gst

GST prediction from text:

In [None]:
def predict_gst(
        context: Optional[List[str]] = None,
        response: Optional[str] = None,
        contextual_embeddings: Optional[torch.Tensor] = None,
) -> Dict[str, float]:
    # Input consistency check

    # Extract input contextual embeddings if not provided as input
    if contextual_embeddings is None:
        contextual_embeddings = ...

    # Run GST prediction
    gst = ...

    # Convert to desired data format
    gst = {str(i): value for i, value in enumerate(gst.cpu().squeeze().tolist())}

    return gst
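
The last step of the function maps each style token index (as a string) to its weight. The sketch below shows that conversion on a plain list, together with the inverse mapping, which can be handy when plotting. The helper names `to_gst_dict` and `from_gst_dict` are hypothetical and the weights are made-up values.

```python
from typing import Dict, List


def to_gst_dict(weights: List[float]) -> Dict[str, float]:
    """Map token index (as a string key) to its weight."""
    return {str(i): float(weight) for i, weight in enumerate(weights)}


def from_gst_dict(gst: Dict[str, float]) -> List[float]:
    """Recover the weights as a list ordered by token index."""
    return [gst[key] for key in sorted(gst, key=int)]


gst = to_gst_dict([0.1, -0.3, 0.0])
print(gst)  # {'0': 0.1, '1': -0.3, '2': 0.0}
```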


Text-to-speech function: given the text and, optionally, reference speaker embeddings and a GST, it generates the audio waveform.

In [None]:
def tts(
        text: str,
        tts_model: Tacotron2,
        tts_configs: AttrDict,
        tts_ap: Optional[AudioProcessor] = None,
        audio_loading_fn: Optional[Callable] = None,
        speaker_embedding_fn: Optional[Callable] = None,
        vocoder_model: Optional[FullbandMelganGenerator] = None,
        vocoder_configs: Optional[AttrDict] = None,
        vocoder_ap: Optional[AudioProcessor] = None,
        speaker_reference_clip_path: Optional[Union[List[str], str]] = None,
        speaker_embeddings: Optional[List[float]] = None,
        gst_reference_clip_path: Optional[str] = None,
        gst_style: Optional[Dict[str, float]] = None
) -> np.ndarray:
    # Input consistency check
    assert speaker_reference_clip_path is None or speaker_embeddings is None
    assert gst_reference_clip_path is None or gst_style is None
    # Helper that stretches the spectrogram along the time axis when the TTS and vocoder sample rates differ
    def interpolate_vocoder_input(scale_factor, spec):
        spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0)
        spec = torch.nn.functional.interpolate(spec, scale_factor=scale_factor, mode='bilinear').squeeze(0)
        return spec

    # Prepare inputs for actual synthesis
    use_cuda = device.type == 'cuda'
    use_gl = vocoder_model is None
    if speaker_reference_clip_path is not None:
        speaker_embeddings = extract_speaker_embedding(
            speaker_reference_clip_path, audio_loading_fn, speaker_embedding_fn
        )
    if gst_reference_clip_path is not None:
        gst_style = gst_reference_clip_path

    # Audio synthesis step
    waveform, _, _, mel_postnet_spec, _, _ = synthesis(
        tts_model,
        text,
        tts_configs,
        use_cuda,
        tts_ap,
        style_wav=gst_style,
        use_griffin_lim=use_gl,
        speaker_embedding=speaker_embeddings
    )

    # Postprocessing
    # if tts_configs.model == "Tacotron" and not use_gl:
    #     mel_postnet_spec = tts_model.ap.out_linear_to_mel(mel_postnet_spec.T).T
    if not use_gl:
        mel_postnet_spec = tts_ap._denormalize(mel_postnet_spec.T).T
        vocoder_input = vocoder_ap._normalize(mel_postnet_spec.T)

        output_scale_factor = vocoder_configs.audio.sample_rate / tts_configs.audio.sample_rate
        if output_scale_factor != 1.:
            scale_factor = [1., output_scale_factor]
            vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
        else:
            vocoder_input = torch.FloatTensor(vocoder_input).unsqueeze(0)
        # Run the vocoder on the model's device, then bring the waveform back to NumPy
        waveform = vocoder_model.inference(vocoder_input.to(device))
        waveform = waveform.cpu().numpy()
    waveform = waveform.squeeze()

    return waveform
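
The interpolation step in the function only matters when the two models were trained at different sample rates. The arithmetic can be sketched as follows, assuming the output length follows the rule PyTorch documents for `interpolate` with a `scale_factor` (floor of input size times scale). The helper name `resampled_frames` and the sample rates are illustrative assumptions, not values read from the notebook's configuration files.

```python
import math
from typing import Tuple


def resampled_frames(n_frames: int, tts_sample_rate: int, vocoder_sample_rate: int) -> Tuple[float, int]:
    """Return the time-axis scale factor and the resulting number of spectrogram frames."""
    scale_factor = vocoder_sample_rate / tts_sample_rate
    return scale_factor, math.floor(n_frames * scale_factor)


# Hypothetical rates: spectrogram predictor at 22050 Hz, vocoder at 24000 Hz
scale_factor, n_out = resampled_frames(100, 22050, 24000)
print(round(scale_factor, 4), n_out)  # 1.0884 108
```

When both rates are equal the scale factor is 1 and the spectrogram is passed to the vocoder unchanged, which corresponds to the `else` branch of the function.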

Dialogue printing.

In [None]:
def print_dialogue(context: List[str], response: str):
    text = 'Context:\n' + ''.join(f'\t- {line}\n' for line in context) + '\n' + f'Response:\n\t- {response}'
    print(text)

Generate playable audio from waveform inside a Jupyter notebook.

In [None]:
def play_audio(waveform: Union[np.ndarray, List[float], str, ByteString], sr: Optional[int] = None):
    IPython.display.display(IPython.display.Audio(data=waveform, rate=sr))
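
To check the playback helper without loading any model, one option is to build a short test tone with the standard library only. The sketch below generates a sine wave and packs it as 16-bit mono WAV bytes, which is one of the input types the helper's signature accepts. The function name `sine_wav_bytes` and its default parameters are assumptions made for this example.

```python
import io
import math
import struct
import wave


def sine_wav_bytes(frequency: float = 440.0, duration: float = 0.5, sample_rate: int = 16000) -> bytes:
    """Return a mono 16-bit PCM WAV file, as bytes, containing a sine tone."""
    n_frames = int(duration * sample_rate)
    # Half of full scale to keep the tone at a comfortable volume
    samples = (
        int(32767 * 0.5 * math.sin(2 * math.pi * frequency * t / sample_rate))
        for t in range(n_frames)
    )
    buffer = io.BytesIO()
    with wave.open(buffer, 'wb') as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample = 16 bit
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(b''.join(struct.pack('<h', sample) for sample in samples))
    return buffer.getvalue()


wav_bytes = sine_wav_bytes()
print(len(wav_bytes))
```

Inside the notebook, `play_audio(wav_bytes)` should then render an audio player; no sample rate is needed because it is stored in the WAV header.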

## Models

### Text generation
We load the model for text generation to generate responses and the corresponding tokeniser to encode the input and decode the output.

In [None]:
tokenizer: transformers.GPT2Tokenizer = AutoTokenizer.from_pretrained(tokeniser_checkpoint_path)
text_model: transformers.GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(lm_model_checkpoint_path)
text_model.to(device)
text_model.eval()

### Speech synthesis

#### Spectrogram predictor

In [None]:
tts_configs: AttrDict = load_config(tts_model_configs_path)
tts_configs.forward_attn_mask = True
if 'characters' in tts_configs.keys():
    symbols, phonemes = make_symbols(**tts_configs.characters)
n_chars = len(phonemes) if tts_configs.use_phonemes else len(symbols)
tts_configs.audio['stats_path'] = tts_model_scale_stats_path

tts_ap = AudioProcessor(**tts_configs.audio)

if tts_configs.use_external_speaker_embedding_file:
    with open(tts_model_speaker_file, 'r') as speaker_file:
        speaker_mapping = json.load(speaker_file)
    n_speakers = len(speaker_mapping)
    speaker_file_id = random.choice(list(speaker_mapping.keys()))  # FIXME Select random speaker
    speaker_embedding = speaker_mapping[speaker_file_id]['embedding']
else:
    n_speakers = 0  # FIXME this will cause a loading issue

tts_model: Tacotron2 = setup_model(n_chars, n_speakers, tts_configs, speaker_embedding_dim=speaker_embedding_size)
tts_model, _ = load_checkpoint(tts_model, tts_model_checkpoint_path, use_cuda=torch.cuda.is_available())
# tts_model.load_state_dict(torch.load(tts_model_checkpoint_path, map_location=torch.device('cpu'))['model'])
# tts_model.decoder.set_r(torch.load(tts_model_checkpoint_path, map_location=torch.device('cpu'))['r'])  #TODO see if correct/useful
tts_model.to(device)
tts_model.eval()

#### Global Style Token
We load the GST estimator to extract the style from a reference audio clip.
Note that the GST estimator is already loaded as part of the Tacotron 2 spectrogram predictor.

In [None]:
gst_model: torch.nn.Module = ...  # TODO reference the GST module already loaded inside tts_model

#### Speaker encoder

In [None]:
speaker_encoder_model.load_model(Path(speaker_encoder_model_checkpoint_path), device=device.type)
# speaker_encoder_model.preprocess_wav
# speaker_encoder_model.embed_utterance

#### Vocoder

In [None]:
vocoder_configs: AttrDict = load_config(vocoder_model_config_path)
vocoder_configs.audio['stats_path'] = vocoder_model_scale_stats_path

vocoder_ap = AudioProcessor(**vocoder_configs.audio)

vocoder_model: FullbandMelganGenerator = setup_generator(vocoder_configs)
vocoder_model.load_state_dict(torch.load(vocoder_model_checkpoint_path, map_location=torch.device('cpu'))['model'])
vocoder_model.remove_weight_norm()
vocoder_model.inference_padding = 0
vocoder_model.to(device)
vocoder_model.eval()

### GST prediction

In [None]:
gstt_model = ...
gstt_model.to(device)
gstt_model.eval()

## Data

### IEMOCAP
We load the dialogues and audio file paths of IEMOCAP to provide some examples and randomly select one sample.

In [None]:
data_set = IEMOCAP()

### Custom example

## Examples

### IEMOCAP

Get a random sample from IEMOCAP

In [None]:
context, response, audio_file_path = data_set[random.randint(0, len(data_set) - 1)]

#### Original

Show original dialogue and play original response

In [None]:
print_dialogue(context, response)
play_audio(audio_file_path)

Generate audio using original GST and original response

In [None]:
original_response_audio = tts()

Show original dialogue and play synthesised response with original GST

#### Predicted GST

Encode dialogue context and response into contextual embeddings

Predict the GST from the contextual embeddings

Generate audio using predicted GST and original response

Show original dialogue and play synthesised response with predicted GST

In [None]:
print_dialogue(context, response)
play_audio()


#### Predicted response

Generate a response using the LM (return also sequence of contextual embeddings)

Generate audio using original GST and predicted response

Show dialogue with generated response and play synthesised response with original GST

In [None]:
print_dialogue(context, )
play_audio()

#### Predicted GST and response

Generate a response using the LM (return also sequence of contextual embeddings).
Generation is done only if the previous step was skipped.

Predict the GST from the contextual embeddings

Generate audio using predicted GST and predicted response

Show dialogue with generated response and play synthesised response with predicted GST

In [None]:
print_dialogue(context, )
play_audio()


#### Predicted vs. original GST visualisation
Visualisation of the two GSTs on the same (original) response.

In [None]:
figure, axes = plt.subplots(nrows=1, ncols=2, figsize=(w, h), sharex=True, sharey=True)
axes = axes.flatten()

ax = axes[0]
...

ax = axes[1]
...

plt.show()

### Context
Gather the input dialogue context by entering one line at a time.
Enter an empty line to stop.

In [None]:
context = []
while turn := input('>>> '):
    context.append(turn)

### Generate response

Generate a response using the LM (return also sequence of contextual embeddings).

### Audio synthesis

Predict the GST from the contextual embeddings

Generate audio using predicted GST and predicted response

In [None]:
waveform = tts()

Show dialogue with generated response and play synthesised response with predicted GST

In [None]:
print_dialogue(context, )
play_audio()