This notebook is to test the vocoder and to find a nice voice for the assistant.

In [1]:
import torch
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset
from audio_helper import play_audio

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The processor converts raw text into model inputs; it is not moved to `device`.
_tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")

# Text-to-speech model and vocoder, both placed on `device`.
_tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
_tts_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

# Sentence synthesized by the cells below, and the rate handed to play_audio.
message = "Ricardo, how are you doing today?"
sampling_rate = 16000

# TODO replace with simpler approach
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")


In [2]:
# Number of entries in the loaded embeddings split.
len(embeddings_dataset)

7931

In [3]:
# Speaker embedding: row 1 of the x-vector dataset, with a leading dimension of size 1 added.
_speaker_embeddings = torch.tensor(embeddings_dataset[1]["xvector"]).unsqueeze(0).to(device)

# Tokenize the message, synthesize it with that speaker, and play the result from the CPU.
inputs = _tts_processor(text=message, return_tensors="pt").to(device)
speech = _tts_model.generate_speech(
    input_ids=inputs["input_ids"],
    speaker_embeddings=_speaker_embeddings,
    vocoder=_tts_vocoder,
)
play_audio(speech.cpu(), sampling_rate)


I could loop over all the embeddings, but let's create a grid of play buttons instead.

In [18]:
from IPython.display import display
import ipywidgets as widgets

PAGE_SIZE = 12  # Buttons created per load_more_buttons() call (its default n)
GRID_COLUMNS = 4  # Buttons placed in each HBox row of the grid

# Lazy generator of (button label, synthesized audio) pairs, one per speaker embedding
def audio_data_generator(text=None):
    """Lazily synthesize one clip per speaker embedding in ``embeddings_dataset``.

    The text is tokenized once, when the first item is requested, and reused
    for every speaker. All clips therefore say the same sentence, even if the
    global ``message`` is rebound while the generator is only partly consumed.

    Args:
        text: Sentence to synthesize. Defaults to the global ``message`` as it
            is when the first item is requested.

    Yields:
        ``(label, speech)`` pairs, where ``label`` is ``f'button_{i}'`` for
        dataset row ``i`` and ``speech`` is the output of ``generate_speech``
        moved to the CPU.
    """
    spoken_text = message if text is None else text
    # Loop-invariant: the input ids depend only on the text, not on the speaker.
    input_ids = _tts_processor(text=spoken_text, return_tensors="pt").to(device)["input_ids"]
    for i in range(len(embeddings_dataset)):
        speaker = torch.tensor(embeddings_dataset[i]["xvector"]).unsqueeze(0).to(device)
        speech = _tts_model.generate_speech(input_ids, speaker, vocoder=_tts_vocoder)
        yield f'button_{i}', speech.to('cpu')

audio_gen = audio_data_generator()  # shared generator consumed by load_more_buttons()
audio_data_dict = {}  # button label -> synthesized audio, read by on_button_click()

# Click handler shared by every voice button
def on_button_click(b):
    """Play the clip stored in ``audio_data_dict`` under the clicked button's description."""
    play_audio(audio_data_dict[b.description], sampling_rate)

# Append the next page of voice buttons to the grid
def load_more_buttons(container, n=PAGE_SIZE):
    """Pull up to ``n`` clips from ``audio_gen`` and add their buttons to ``container``.

    Each clip is stored in ``audio_data_dict`` under its label and gets a
    button wired to ``on_button_click``. The new buttons are appended to
    ``container.children`` as ``HBox`` rows of at most ``GRID_COLUMNS``
    buttons. Stops early once the generator is exhausted.
    """
    fresh_buttons = []
    for _ in range(n):
        try:
            label, audio_data = next(audio_gen)
        except StopIteration:
            # Nothing left to synthesize.
            break
        audio_data_dict[label] = audio_data
        btn = widgets.Button(description=label)
        btn.on_click(on_button_click)
        fresh_buttons.append(btn)

    # Chunk the new buttons into rows of GRID_COLUMNS.
    grid_rows = [
        widgets.HBox(fresh_buttons[start:start + GRID_COLUMNS])
        for start in range(0, len(fresh_buttons), GRID_COLUMNS)
    ]
    container.children = (*container.children, *grid_rows)

# Container that load_more_buttons() fills; populate it with the first page
button_container = widgets.VBox()
load_more_buttons(button_container)

# Fixed-size, bordered box around the grid with vertical overflow set to auto
scrollable_box = widgets.Box(
    children=[button_container],
    layout=widgets.Layout(
        display='block',
        width='500px',  # Adjust width as needed
        height='300px',
        border='1px solid black',
        overflow_y='auto',
    ),
)

# "Load More" control: each click appends another page of buttons
def load_more_on_click(b):
    """Add the next page of voice buttons to ``button_container``."""
    load_more_buttons(button_container)

load_more_button = widgets.Button(description="Load More")
load_more_button.on_click(load_more_on_click)

# Render the scrollable grid followed by the "Load More" button
display(scrollable_box, load_more_button)

Box(children=(VBox(children=(HBox(children=(Button(description='button_0', style=ButtonStyle()), Button(descri…

Button(description='Load More', style=ButtonStyle())

I thought this would be easier. Instead, let's find the voice directly in the dataset. From the dataset description, the speakers are:
    bdl (US male)
    slt (US female)
    jmk (Canadian male)
    awb (Scottish male)
    rms (US male)
    clb (US female)
    ksp (Indian male)

Let's find the Scottish male.

In [23]:
# Inspect a 500-row window of filenames; each name embeds a speaker code such as "bdl".
embeddings_dataset['filename'][1500:2000]

['cmu_us_bdl_arctic-wav-arctic_a0363',
 'cmu_us_bdl_arctic-wav-arctic_a0364',
 'cmu_us_bdl_arctic-wav-arctic_a0365',
 'cmu_us_bdl_arctic-wav-arctic_a0366',
 'cmu_us_bdl_arctic-wav-arctic_a0367',
 'cmu_us_bdl_arctic-wav-arctic_a0368',
 'cmu_us_bdl_arctic-wav-arctic_a0369',
 'cmu_us_bdl_arctic-wav-arctic_a0370',
 'cmu_us_bdl_arctic-wav-arctic_a0371',
 'cmu_us_bdl_arctic-wav-arctic_a0372',
 'cmu_us_bdl_arctic-wav-arctic_a0373',
 'cmu_us_bdl_arctic-wav-arctic_a0374',
 'cmu_us_bdl_arctic-wav-arctic_a0375',
 'cmu_us_bdl_arctic-wav-arctic_a0376',
 'cmu_us_bdl_arctic-wav-arctic_a0377',
 'cmu_us_bdl_arctic-wav-arctic_a0378',
 'cmu_us_bdl_arctic-wav-arctic_a0379',
 'cmu_us_bdl_arctic-wav-arctic_a0380',
 'cmu_us_bdl_arctic-wav-arctic_a0381',
 'cmu_us_bdl_arctic-wav-arctic_a0382',
 'cmu_us_bdl_arctic-wav-arctic_a0383',
 'cmu_us_bdl_arctic-wav-arctic_a0384',
 'cmu_us_bdl_arctic-wav-arctic_a0385',
 'cmu_us_bdl_arctic-wav-arctic_a0386',
 'cmu_us_bdl_arctic-wav-arctic_a0387',
 'cmu_us_bdl_arctic-wav-a

In [28]:
# Row positions whose filename mentions the "awb" speaker code.
indices = [
    position
    for position, name in enumerate(embeddings_dataset['filename'])
    if "awb" in name
]
print(f"awb indices min: {indices[0]}, max: {indices[-1]}")

awb indices min: 0, max: 1137


Ah, I was already using the Scottish voice (index 1 falls inside the awb range). Let's check the others.

In [29]:
# Print the first and last row index for each speaker code found in the filenames.
substrings = ['bdl', 'slt', 'jmk', 'awb', 'rms', 'clb', 'ksp']
# Read the filename column once instead of once per speaker code.
filenames = list(embeddings_dataset['filename'])
for subs in substrings:
    indices = [i for i, s in enumerate(filenames) if subs in s]
    print(f"{subs} indices min: {indices[0]}, max: {indices[-1]}")

bdl indices min: 1138, max: 2270
slt indices min: 6799, max: 7930
jmk indices min: 3403, max: 4534
awb indices min: 0, max: 1137
rms indices min: 5667, max: 6798
clb indices min: 2271, max: 3402
ksp indices min: 4535, max: 5666


In [None]:
# Audition the clb (US female) embeddings; the bounds are the min/max indices printed above.
CLB_FIRST, CLB_LAST = 2271, 3402
for i in range(CLB_FIRST, CLB_LAST + 1):  # + 1 so that CLB_LAST itself is played
    # Loop-local prompt, so the global `message` used by earlier cells is not overwritten.
    prompt = f"How you doing {i}?"
    _speaker_embeddings = torch.tensor(embeddings_dataset[i]["xvector"]).unsqueeze(0).to(device)
    # 3407 is nice
    # NOTE(review): 3407 is outside this loop's range (the table above puts it in jmk) — confirm the index.

    inputs = _tts_processor(text=prompt, return_tensors="pt").to(device)
    speech = _tts_model.generate_speech(inputs["input_ids"], _speaker_embeddings, vocoder=_tts_vocoder)
    play_audio(speech.to('cpu'), sampling_rate)


All voices are fairly consistent. I will keep using the US female voice.