In [None]:
from nemo.collections.tts.models import T5TTS_Model
from nemo.collections.tts.data.text_to_speech_dataset import T5TTSDataset, DatasetSample
from omegaconf.omegaconf import OmegaConf, open_dict
import torch
import os
import soundfile as sf
from IPython.display import display, Audio
import os
# Restrict this kernel to a single GPU. With CUDA_DEVICE_ORDER=PCI_BUS_ID the
# index given in CUDA_VISIBLE_DEVICES refers to PCI bus ordering.
# NOTE(review): these only take effect if set before CUDA is first initialised
# in this process — keep this cell ahead of any GPU work.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="0",
)

In [None]:
# --- Inputs -----------------------------------------------------------------
# Checkpoint whose 'state_dict' entry is loaded into the model, and the hparams
# YAML whose `cfg` node describes that model.
# (alternative checkpoint: "/datap/misc/continuouscheckpoints/edresson_epoch21.ckpt")
checkpoint_file = "/data/t5_new_cp/checkpoints/unnormalizedLalign005_singleencoder_kernel3_epoch_20.ckpt"
hparams_file = "/data/t5_new_cp/configs/unnormalizedLalign005_singleencoder_kernel3_hparams.yaml"

# Audio codec checkpoint injected into the model config as `codecmodel_path`.
# (alternative codec: "/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo")
codecmodel_path = "/data/codec_checkpoints/codecs-no-eliz/AudioCodec_21Hz_no_eliz_without_wavlm_disc.nemo"

# --- Outputs ----------------------------------------------------------------
# Directory that receives the generated predicted_audio_<n>.wav files.
out_dir = "/datap/misc/t5tts_inference_notebook_samples"

In [None]:
# Read the model section of the training hparams and patch it for inference.
model_cfg = OmegaConf.load(hparams_file).cfg

with open_dict(model_cfg):
    # Point the model at the codec checkpoint chosen in the paths cell.
    model_cfg.codecmodel_path = codecmodel_path
    if hasattr(model_cfg, 'text_tokenizer'):
        # Backward compatibility for models trained with absolute paths in text_tokenizer
        model_cfg.text_tokenizer.g2p.phoneme_dict = "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt"
        model_cfg.text_tokenizer.g2p.heteronyms = "scripts/tts_dataset_files/heteronyms-052722"
        model_cfg.text_tokenizer.g2p.phoneme_probability = 1.0
    # No training/validation data is needed for inference.
    model_cfg.train_ds = None
    model_cfg.validation_ds = None


model = T5TTS_Model(cfg=model_cfg)
# Load weights from checkpoint file
print("Loading weights from checkpoint")
# map_location="cpu" keeps the load independent of the device the checkpoint was
# saved from; the model is moved to the GPU explicitly further down.
# torch.load unpickles the file, so only point this at trusted checkpoints.
ckpt = torch.load(checkpoint_file, map_location="cpu")
model.load_state_dict(ckpt['state_dict'])
print("Loaded weights.")

if model_cfg.t5_decoder.pos_emb == "learnable":
    # The kv cache is only enabled when neither flash self-attention nor flash
    # cross-attention is in use. The original condition tested
    # use_flash_self_attention twice, so the second check was a no-op.
    # NOTE(review): `use_flash_x_attention` is assumed to be the cross-attention
    # flag — confirm against the decoder config. A missing key is treated as
    # False, which reproduces the previous behaviour for such configs.
    if (model_cfg.t5_decoder.use_flash_self_attention is False) and (
        model_cfg.t5_decoder.get('use_flash_x_attention', False) is False
    ):
        print("Using kv cache for inference.")
        model.use_kv_cache_for_inference = True

model.cuda()
model.eval()

In [None]:
# Special-token ids are taken straight from the loaded model so the dataset and
# the model agree on them.
special_token_ids = {
    name: getattr(model, name)
    for name in (
        "bos_id",
        "eos_id",
        "context_audio_bos_id",
        "context_audio_eos_id",
        "audio_bos_id",
        "audio_eos_id",
    )
}

# Test-mode dataset built with an empty `dataset_meta`; the sample(s) to
# synthesise are attached to `test_dataset.data_samples` in a later cell.
test_dataset = T5TTSDataset(
    # What to load
    dataset_meta={},
    dataset_type='test',
    load_cached_codes_if_available=True,
    load_16khz_audio=(model.model_type == 'single_encoder_sv_tts'),
    # Audio / codec geometry, read from the model config
    sample_rate=model_cfg.sample_rate,
    codec_model_downsample_factor=model_cfg.codec_model_downsample_factor,
    num_audio_codebooks=model_cfg.num_audio_codebooks,
    # Duration limits
    min_duration=0.5,
    max_duration=20,
    context_duration_min=model.cfg.get('context_duration_min', 5.0),
    context_duration_max=model.cfg.get('context_duration_max', 5.0),
    # Text handling; the tokenizers themselves are assigned right after construction
    tokenizer_config=None,
    use_text_conditioning_tokenizer=model.use_text_conditioning_encoder,
    pad_context_text_to_max_duration=model.pad_context_text_to_max_duration,
    prior_scaling_factor=None,
    **special_token_ids,
)

# Reuse the model's own tokenizer setup (test mode) for the dataset.
(
    test_dataset.text_tokenizer,
    test_dataset.text_conditioning_tokenizer,
) = model._setup_tokenizers(model.cfg, mode='test')

In [None]:
# Switch between the two hand-written manifest entries below.
do_jhsd = True

# Both entries synthesise the same sentence.
synthesis_text = "This is a very long sentence to test the speed of my text to speech synthesis model, let's see how long it takes to generate the audio. Adding more words to make it even longer. Haha."

if do_jhsd:
    # Jensen data
    audio_dir = "/data/NV-RESTRICTED/JHSD/22khz"
    entry = {
        "audio_filepath": "AMP20_KEYNOTE-VOOnly-44khz-16bit-mono_17.wav",
        "duration": 4.84,
        "text": synthesis_text,
        "speaker": "dummy",
        # Optional key, currently unused:
        #   "context_text": "Speaker and Emotion: | Language:en Dataset:Riva Speaker:Lindy_CMU_FEARFUL |"
        # Alternative context clip: "GTC20_FALL_KEYNOTE-VOOnly-44khz-16bit-mono_221.wav" (5.35 s)
        "context_audio_filepath": "AMP20_KEYNOTE-VOOnly-44khz-16bit-mono_6.wav",
        "context_audio_duration": 8.02,
    }
else:
    # Rodney / Lindy data (alternative root: "/datap/misc/Datasets/riva")
    audio_dir = "/datap/misc/RodneyLindy44"
    entry = {
        # 22 kHz variant: "Lindy/22khz/CMU_HAPPY/LINDY_CMU_HAPPY_000567.wav"
        "audio_filepath": "Lindy/44khz/CMU_HAPPY/LINDY_CMU_HAPPY_000567.wav",
        "duration": 6.275604,
        "text": synthesis_text,
        "speaker": "dummy",
        # Optional key, currently unused:
        #   "context_text": "Speaker and Emotion: | Language:en Dataset:Riva Speaker:Lindy_CMU_FEARFUL |"
        # 22 kHz variant: "Rodney/22khz/DROP/RODNEY_DROP_000060.wav"
        "context_audio_filepath": "Rodney/44khz/DROP/RODNEY_DROP_000060.wav",
        "context_audio_duration": 8.0,
    }

# Wrap the entry in a DatasetSample; audio and features share one root directory.
data_sample = DatasetSample(
    dataset_name="sample",
    manifest_entry=entry,
    text=entry["text"],
    audio_dir=audio_dir,
    feature_dir=audio_dir,
    speaker=None,
    speaker_index=0,
)
test_dataset.data_samples = [data_sample]

# Single-sample batches, loaded in the main process, in order.
test_data_loader = torch.utils.data.DataLoader(
    test_dataset,
    collate_fn=test_dataset.collate_fn,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)


In [None]:
import time  # imported once here rather than on every loop iteration

# sf.write does not create missing directories, so make sure the target exists.
os.makedirs(out_dir, exist_ok=True)

# Running index used to name the output files across all batches.
item_idx = 0
for bidx, batch in enumerate(test_data_loader):
    print("Processing batch {} out of {}".format(bidx, len(test_data_loader)))
    # NOTE(review): the cache is reset with use_cache=True regardless of whether
    # model.use_kv_cache_for_inference was enabled at load time — confirm intended.
    model.t5_decoder.reset_cache(use_cache=True)

    # Move tensor entries to the GPU; leave everything else untouched.
    batch_cuda = {
        key: (batch[key].cuda() if isinstance(batch[key], torch.Tensor) else batch[key])
        for key in batch
    }

    st = time.time()
    # Gradients are not needed for generation; skip building the autograd graph.
    with torch.no_grad():
        predicted_audio, predicted_audio_lens, _, _ = model.infer_batch(
            batch_cuda, max_decoder_steps=500, temperature=0.5, topk=80
        )
    print("generation time", time.time() - st)

    for idx in range(predicted_audio.size(0)):
        # Trim the waveform to its reported length before writing it out.
        num_samples = int(predicted_audio_lens[idx])
        predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()
        predicted_audio_np = predicted_audio_np[:num_samples]
        audio_path = os.path.join(out_dir, f"predicted_audio_{item_idx}.wav")
        sf.write(audio_path, predicted_audio_np, model.cfg.sample_rate)
        display(Audio(audio_path))
        item_idx += 1

In [None]:
# Play the context clip referenced by the manifest entry, for comparison with
# the generated audio.
context_filepath = os.path.join(audio_dir, entry["context_audio_filepath"])
context_player = Audio(context_filepath)
display(context_player)
