In [None]:
# Expose only GPU 0 to this kernel; set here, before torch is imported in the next cell.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import datetime as dt
from pathlib import Path

import IPython.display as ipd
import numpy as np
import soundfile as sf
import torch
from tqdm.auto import tqdm

# Hifigan imports
from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator as HiFiGAN
# Matcha imports
from matcha.models.matcha_tts import MatchaTTS
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.model import denormalize
from matcha.utils.utils import get_user_data_dir, intersperse

In [None]:
# autoreload 2 re-imports edited modules (e.g. changes under matcha/) before each cell runs,
# so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # GPU when visible, else CPU

## Filepaths

In [None]:
import os

# Defaults match the original hard-coded locations; override via environment
# variables so the notebook is not tied to a specific /workspace layout.
MATCHA_CHECKPOINT = os.environ.get("MATCHA_CHECKPOINT", "/workspace/matcha_ljspeech.ckpt")
OUTPUT_FOLDER = os.environ.get("MATCHA_OUTPUT_FOLDER", "synth_output")

In [None]:
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf

def load_hparams_from_ckpt(ckpt_path: str, allow_unsafe_fallback: bool = True):
    """Read the hyper-parameters stored in a checkpoint without instantiating the model.

    The file is first loaded with ``weights_only=True`` (PyTorch's restricted
    unpickler), with OmegaConf ``DictConfig``/``ListConfig`` registered as safe.
    If that fails and ``allow_unsafe_fallback`` is true, it is re-loaded with
    full unpickling. That path can execute arbitrary code and must only be used
    on checkpoints from a trusted source. A warning is emitted when it happens.

    Args:
        ckpt_path: Path to the checkpoint file.
        allow_unsafe_fallback: Retry with ``weights_only=False`` when the safe
            load fails. Defaults to ``True`` (the previous behaviour); set
            ``False`` to re-raise the safe-load error instead.

    Returns:
        The hyper-parameters found under ``"hyper_parameters"``, ``"hparams"`` or
        ``"hparams_initial"`` (first truthy one wins), or ``{}`` if none is set.
        OmegaConf containers are converted to plain containers with
        interpolations resolved.

    Raises:
        TypeError: If the loaded checkpoint is not a mapping.
        Exception: Whatever ``torch.load`` raises on the final attempt.
    """
    import warnings

    # Let OmegaConf containers through the restricted (weights_only) unpickler.
    torch.serialization.add_safe_globals([DictConfig, ListConfig])
    try:
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    except Exception as safe_err:
        if not allow_unsafe_fallback:
            raise
        # Full unpickling can run arbitrary code: only acceptable for trusted files.
        warnings.warn(
            f"weights_only load of {ckpt_path!r} failed ({safe_err}); falling back to "
            "full unpickling. Only do this for checkpoints from a trusted source.",
            stacklevel=2,
        )
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    if not hasattr(ckpt, "get"):
        raise TypeError(f"Expected a checkpoint mapping in {ckpt_path!r}, got {type(ckpt).__name__}")

    # Different writers store hyper-parameters under different keys.
    hparams = (
        ckpt.get("hyper_parameters")
        or ckpt.get("hparams")
        or ckpt.get("hparams_initial")
        or {}
    )

    # DictConfig/ListConfig -> plain Python containers.
    if isinstance(hparams, (DictConfig, ListConfig)):
        hparams = OmegaConf.to_container(hparams, resolve=True)

    return hparams

# === Usage: build an (untrained) model from the checkpoint's hyper-parameters ===
hparams = load_hparams_from_ckpt(MATCHA_CHECKPOINT)

# Inspect the stored hyper-parameters if needed:
# print(hparams)

# The stored hyper-parameters are passed straight to the constructor.
model_empty = MatchaTTS(**hparams)  # weights are NOT loaded
model_empty.eval()
# Count inline: `count_params` is only defined in the following cell, so calling it
# here raised NameError on a fresh kernel (Restart & Run All).
print(f"Empty model! Parameter count: {sum(p.numel() for p in model_empty.parameters()):,}")

In [None]:
def load_model(checkpoint_path):
    """Restore MatchaTTS via ``MatchaTTS.load_from_checkpoint`` onto ``device`` and switch it to eval mode.

    Args:
        checkpoint_path: Path to the trained checkpoint.

    Returns:
        The restored model (``nn.Module.eval()`` returns the module itself).
    """
    restored = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
    return restored.eval()
def count_params(module):
    """Return the total number of parameters of ``module`` as a comma-grouped string (e.g. ``"2,020"``).

    Args:
        module: Any object exposing ``parameters()`` that yields tensors (e.g. an ``nn.Module``).
    """
    return f"{sum(p.numel() for p in module.parameters()):,}"

model = load_model(checkpoint_path=MATCHA_CHECKPOINT)
print("Model loaded! Parameter count: " + count_params(model))  # original note: ~18.2M

## Load the vocoder (Vocos; a HiFi-GAN loader is kept below as an alternative)

In [None]:
from vocos import Vocos
from huggingface_hub import snapshot_download, hf_hub_download
import torch

# load vocoder
def load_vocoder(is_local=False, local_path="", device=None, hf_cache_dir=None):
    """Build a Vocos vocoder from its config and load its pretrained weights.

    Args:
        is_local: If true, read ``config.yaml`` and ``pytorch_model.bin`` from
            ``local_path``; otherwise download both from the Hugging Face repo
            ``charactr/vocos-mel-24khz``.
        local_path: Directory containing the two files (required when ``is_local``).
        device: Target device. ``None`` (default) picks CUDA when available and
            CPU otherwise. The former hard-coded ``"cuda"`` default failed on
            CPU-only machines. Passing ``"cuda"`` explicitly still works.
        hf_cache_dir: Optional Hugging Face cache directory for downloads.

    Returns:
        The Vocos model in eval mode on ``device``.

    Raises:
        ValueError: If ``is_local`` is true but ``local_path`` is empty.
    """
    # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
    from vocos.feature_extractors import EncodecFeatures

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if is_local:
        if not local_path:
            raise ValueError("local_path must be set when is_local=True")
        print(f"Load vocos from local path {local_path}")
        config_path = str(Path(local_path) / "config.yaml")
        model_path = str(Path(local_path) / "pytorch_model.bin")
    else:
        print("Download Vocos from huggingface charactr/vocos-mel-24khz")
        repo_id = "charactr/vocos-mel-24khz"
        config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
        model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")

    vocoder = Vocos.from_hparams(config_path)
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)

    # The saved weights do not include the EnCodec feature extractor's parameters;
    # copy them from the freshly built model so load_state_dict sees every key.
    if isinstance(vocoder.feature_extractor, EncodecFeatures):
        encodec_parameters = {
            "feature_extractor.encodec." + key: value
            for key, value in vocoder.feature_extractor.encodec.state_dict().items()
        }
        state_dict.update(encodec_parameters)
    vocoder.load_state_dict(state_dict)
    return vocoder.eval().to(device)

# Pass the notebook's `device` explicitly so this also runs on CPU-only machines.
vocoder = load_vocoder(True, local_path='/workspace/matchatrain/vocoder', device=device)
# NOTE: the Denoiser imported from matcha.hifigan.denoiser is left disabled here;
# its compatibility with a Vocos model is unverified.
# denoiser = Denoiser(vocoder, mode='zeros')

def load_hifigan(checkpoint_path):
    """Load Matcha's HiFi-GAN generator (config ``v1``) as an alternative vocoder.

    This function was previously also named ``load_vocoder``. That silently
    replaced the Vocos loader defined earlier in this cell, so re-running
    ``load_vocoder(True, local_path=...)`` afterwards raised a TypeError.
    It now has its own name.

    Args:
        checkpoint_path: Path to a checkpoint whose ``'generator'`` entry holds
            the generator state dict.

    Returns:
        The generator on ``device``, in eval mode, with weight norm removed.

    NOTE(review): ``to_waveform`` calls ``vocoder.decode(mel)``; confirm this
    generator offers a compatible interface before swapping it in.
    """
    generator = HiFiGAN(AttrDict(v1)).to(device)
    # Restricted unpickler; assumes the file only holds tensors / plain containers
    # (TODO confirm for your checkpoint; use weights_only=False only for trusted files).
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    generator.load_state_dict(checkpoint['generator'])
    generator.eval()
    generator.remove_weight_norm()
    return generator

# vocoder = load_hifigan('/workspace/generator_v1')
# denoiser = Denoiser(vocoder, mode='zeros')

### Helper functions to synthesise

In [None]:
@torch.inference_mode()
def process_text(text: str):
    """Turn raw text into the phoneme-id tensors consumed by ``model.synthesise``.

    Returns a dict with:
        'x_orig':    the input text, e.g. "Hi how are you today?"
        'x':         (1, L) long tensor of phoneme ids on ``device``
        'x_lengths': (1,) long tensor holding L
        'x_phones':  ids decoded back to text, e.g. "_h_ˈ_a_ɪ_ _h_ˌ_a_ʊ_ _ɑ_ː_ɹ_ _j_u_ː_ _t_ə_d_ˈ_e_ɪ_?_"
    """
    phoneme_ids = text_to_sequence(text, ['english_cleaners2'])[0]
    interleaved = intersperse(phoneme_ids, 0)  # id 0 inserted between every pair of ids
    x = torch.tensor(interleaved, dtype=torch.long, device=device).unsqueeze(0)
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
    return {
        'x_orig': text,
        'x': x,
        'x_lengths': x_lengths,
        'x_phones': sequence_to_text(x.squeeze(0).tolist()),
    }


@torch.inference_mode()
def synthesise(text, spks=None):
    """Generate a mel spectrogram for ``text`` with the loaded Matcha model.

    Reads the module-level ``n_timesteps``, ``temperature`` and ``length_scale``
    (set in the Hyperparameters cell) and the global ``model``.

    Args:
        text: Input sentence.
        spks: Optional speaker id tensor forwarded to ``model.synthesise``.

    Returns:
        The dict returned by ``model.synthesise``, extended with ``'start_t'``
        (wall-clock time taken after text processing) and every key produced by
        ``process_text``.
    """
    inputs = process_text(text)

    # Timing starts after text processing, so it is excluded from the RTF.
    start_t = dt.datetime.now()
    result = model.synthesise(
        inputs['x'],
        inputs['x_lengths'],
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spks,
        length_scale=length_scale,
    )
    result.update(start_t=start_t, **inputs)
    return result

@torch.inference_mode()
def to_waveform(mel, vocoder, denoiser=None, denoiser_strength=0.00025):
    """Decode a mel spectrogram to a squeezed CPU waveform tensor.

    Args:
        mel: Input accepted by ``vocoder.decode`` (presumably (1, n_mels, frames) — TODO confirm).
        vocoder: Object exposing ``decode(mel)`` (the Vocos model loaded above).
        denoiser: Optional callable invoked as ``denoiser(audio, strength=...)``.
            The previous version always called a global ``denoiser`` that this
            notebook never defines (both assignments are commented out), which
            raised NameError. Denoising now runs only when one is passed.
        denoiser_strength: Value forwarded as ``strength`` to ``denoiser``.

    Returns:
        The decoded audio, clamped to [-1, 1] before the optional denoising,
        moved to CPU and squeezed.
    """
    audio = vocoder.decode(mel).clamp(-1, 1)
    if denoiser is not None:
        audio = denoiser(audio.squeeze(0), strength=denoiser_strength)
    return audio.cpu().squeeze()
    
def save_to_folder(filename: "str | int", output: dict, folder: str, sample_rate: int = 22050):
    """Save ``output['mel']`` as ``<folder>/<filename>.npy`` and ``output['waveform']`` as ``<folder>/<filename>.wav``.

    Args:
        filename: Base name without extension. Ints (e.g. a loop index) are
            formatted via the f-string.
        output: Dict with ``'mel'`` and ``'waveform'``. Each may be a torch
            tensor (detached and moved to CPU) or an array-like.
        folder: Output directory, created with parents if missing.
        sample_rate: Rate written into the WAV header. Defaults to 22050, the
            previously hard-coded value. NOTE(review): must match the rate the
            vocoder actually produces.
    """
    out_dir = Path(folder)
    out_dir.mkdir(parents=True, exist_ok=True)

    def _to_numpy(value):
        # Tensors may live on the GPU; numpy/soundfile need host arrays.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    np.save(out_dir / f'{filename}.npy', _to_numpy(output['mel']))
    sf.write(out_dir / f'{filename}.wav', _to_numpy(output['waveform']), sample_rate, 'PCM_24')

print(process_text(text="Hi how are you today?"))  # sanity check of the text front-end

## Set up the text to synthesise

In [None]:
# Sentences to synthesise; each entry is played and saved by the Synthesis cell.
texts = [
    "The Secret Service believed that it was very doubtful that any President "
    "would ride regularly in a vehicle with a fixed top, even though transparent.",
]

### Hyperparameters

In [None]:
# Inference settings read as globals by `synthesise`.
n_timesteps = 10      # number of ODE solver steps
temperature = 0.667   # sampling temperature
length_scale = 1.0    # speaking-rate control, forwarded as `length_scale`

## Synthesis

In [None]:
import torch
import torch.nn.functional as F
import time

# NOTE(review): all rates below assume 22.05 kHz output. The Hugging Face fallback
# in `load_vocoder` is "charactr/vocos-mel-24khz", which suggests 24 kHz audio —
# confirm the local vocoder's output rate and update SAMPLE_RATE accordingly.
SAMPLE_RATE = 22050
# Mel bins expected by the vocoder; Matcha's mel is resized to this many bins.
VOCODER_N_MELS = 100


def _resize_mel_bins(mel, n_mels=VOCODER_N_MELS):
    """Bilinearly resize axis -2 of a 3-D mel (assumed (batch, n_mels, frames)) to ``n_mels``; frames unchanged.

    NOTE(review): this is an approximation. It does not account for any
    difference in mel filterbank or frequency range between Matcha and the vocoder.
    """
    return F.interpolate(
        mel.unsqueeze(dim=1),
        size=(n_mels, mel.size(-1)),
        mode="bilinear",
        align_corners=False,
    ).squeeze(dim=1)


outputs, rtfs, rtfs_w = [], [], []
for i, text in enumerate(tqdm(texts)):
    t0 = time.perf_counter()
    output = synthesise(text)  # multi-speaker: spks=torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0)
    mel_vocoder = _resize_mel_bins(output['mel'])
    print(f"Only model [{time.perf_counter() - t0}]")
    output['waveform'] = to_waveform(mel_vocoder, vocoder)
    print(f"With vocoder [{time.perf_counter() - t0}]")

    # Real Time Factor including the vocoder: wall time / audio duration.
    elapsed = (dt.datetime.now() - output['start_t']).total_seconds()
    rtf_w = elapsed * SAMPLE_RATE / output['waveform'].shape[-1]

    ## Pretty print
    print(f"{'*' * 53}")
    print(f"Input text - {i}")
    print(f"{'-' * 53}")
    print(output['x_orig'])
    print(f"{'*' * 53}")
    print(f"Phonetised text - {i}")
    print(f"{'-' * 53}")
    print(output['x_phones'])
    print(f"{'*' * 53}")
    print(f"RTF:\t\t{output['rtf']:.6f}")
    print(f"RTF Waveform:\t{rtf_w:.6f}")
    rtfs.append(output['rtf'])
    rtfs_w.append(rtf_w)
    outputs.append(output)

    ## Display the synthesised waveform
    ipd.display(ipd.Audio(output['waveform'], rate=SAMPLE_RATE))

    ## Save the generated mel + waveform as <OUTPUT_FOLDER>/<i>.npy / <i>.wav
    save_to_folder(i, output, OUTPUT_FOLDER)

print(f"Number of ODE steps: {n_timesteps}")
if rtfs:  # np.mean/np.std of an empty list would print nan with a warning
    print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
    print(f"Mean RTF Waveform (incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}")