In [None]:
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import datetime as dt
from pathlib import Path

import IPython.display as ipd
import numpy as np
import soundfile as sf
import torch
from tqdm.auto import tqdm
import pandas as pd

# Hifigan imports
from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator as HiFiGAN
# Matcha imports
from matcha.models.matcha_tts import MatchaTTS
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.model import denormalize
from matcha.utils.utils import get_user_data_dir, intersperse

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
# The autoreload magics above re-import edited modules automatically, so code changes take effect without restarting the kernel

## Filepaths

In [None]:
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf

def load_hparams_from_ckpt(ckpt_path: str, allow_unsafe_fallback: bool = True):
    """Read the hyper-parameters stored in a checkpoint without building a model.

    The checkpoint is first loaded with ``weights_only=True``. If that fails
    and ``allow_unsafe_fallback`` is true, a warning is emitted and the file is
    loaded again with ``weights_only=False``.

    Args:
        ckpt_path: Path of the checkpoint file to read.
        allow_unsafe_fallback: When True (default, matches the previous
            behaviour) a failed restricted load is retried with full
            unpickling. Set to False for checkpoints from untrusted sources;
            the original loading error is then re-raised.

    Returns:
        The first non-empty value found under ``"hyper_parameters"``,
        ``"hparams"`` or ``"hparams_initial"``; an empty dict if none is
        present. OmegaConf containers are converted to plain Python
        containers with interpolations resolved.

    Raises:
        FileNotFoundError: If ``ckpt_path`` does not exist.
        Exception: Whatever the restricted load raised, when
            ``allow_unsafe_fallback`` is False.
    """
    import warnings

    # Allow-list the OmegaConf container types for restricted loading.
    # NOTE: this registers them for the whole process, not just this call.
    torch.serialization.add_safe_globals([DictConfig, ListConfig])
    try:
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    except FileNotFoundError:
        # Retrying in unsafe mode cannot help with a missing file.
        raise
    except Exception as safe_load_error:
        if not allow_unsafe_fallback:
            raise
        # SECURITY: weights_only=False runs the full unpickler, which can
        # execute arbitrary code. Only acceptable for trusted checkpoints.
        warnings.warn(
            f"Restricted load of {ckpt_path!r} failed ({safe_load_error!r}); "
            "retrying with weights_only=False. Only do this for trusted checkpoints.",
            stacklevel=2,
        )
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # Different checkpoint layouts use different keys; take the first
    # non-empty one.
    hparams = {}
    for key in ("hyper_parameters", "hparams", "hparams_initial"):
        candidate = ckpt.get(key)
        if candidate:
            hparams = candidate
            break

    # DictConfig / ListConfig -> plain dict / list.
    if isinstance(hparams, (DictConfig, ListConfig)):
        hparams = OmegaConf.to_container(hparams, resolve=True)

    return hparams

# Use the GPU when one is visible to this process, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Matcha-TTS checkpoint to load and the folder name intended for synthesis output.
MATCHA_CHECKPOINT = "/workspace/last.ckpt"
OUTPUT_FOLDER = "synth_output"

# === Usage example ===
hparams = load_hparams_from_ckpt(MATCHA_CHECKPOINT)

# The hyper-parameters can usually be passed straight in as constructor arguments.
# model_empty = MatchaTTS(**hparams)   # weights are not loaded
# model_empty.eval()
# print(f"Empty model! Parameter count: {count_params(model_empty)}")

In [None]:
def load_model(checkpoint_path):
    """Restore a MatchaTTS model from ``checkpoint_path`` in evaluation mode.

    The checkpoint is mapped onto the module-level ``device``.
    """
    restored = MatchaTTS.load_from_checkpoint(
        checkpoint_path,
        map_location=device,
    )
    restored.eval()
    return restored

def count_params(x):
    """Return the total element count of ``x.parameters()`` as a comma-grouped string."""
    total = 0
    for param in x.parameters():
        total += param.numel()
    return f"{total:,}"

# Load the acoustic model once; `synthesise` reads this module-level `model`.
model = load_model(MATCHA_CHECKPOINT)
print(f"Model loaded! Parameter count: {count_params(model)}") # 18.2M

## Load HiFi-GAN (Vocoder)

In [None]:
from vocos import Vocos
from huggingface_hub import snapshot_download, hf_hub_download
import torch

def load_vocoder(checkpoint_path):
    """Build a HiFi-GAN generator with the ``v1`` config and load its weights.

    The weights are read from the ``'generator'`` entry of the checkpoint at
    ``checkpoint_path``. The generator is placed on the module-level
    ``device``, switched to eval mode, and has weight norm removed before it
    is returned.
    """
    generator = HiFiGAN(AttrDict(v1)).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator'])
    generator.eval()
    generator.remove_weight_norm()
    return generator

# Module-level vocoder and the denoiser derived from it; `to_waveform` uses
# this `denoiser` global regardless of which vocoder it is given.
vocoder = load_vocoder('/workspace/generator_v1')
denoiser = Denoiser(vocoder, mode='zeros')

In [None]:
@torch.inference_mode()
def process_text(text: str):
    """Convert raw text into the tensors consumed by ``model.synthesise``.

    The text is cleaned with ``english_cleaners2``, turned into an id
    sequence, and interspersed with id 0 (presumably a blank token; confirm
    against ``matcha.utils``). Tensors are created on the module-level
    ``device``.

    Returns:
        dict with keys ``x_orig`` (input text), ``x`` (id tensor with a
        leading batch axis of size 1), ``x_lengths`` (length of ``x`` along
        its last axis) and ``x_phones`` (the ids rendered back to text).
    """
    token_ids = text_to_sequence(text, ['english_cleaners2'])[0]
    interspersed_ids = intersperse(token_ids, 0)
    x = torch.tensor(interspersed_ids, dtype=torch.long, device=device)[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
    return {
        'x_orig': text,  # Hi how are you today?
        'x': x,  # ids of phoneme embedding
        'x_lengths': x_lengths,
        # _h_ˈ_a_ɪ_ _h_ˌ_a_ʊ_ _ɑ_ː_ɹ_ _j_u_ː_ _t_ə_d_ˈ_e_ɪ_?_
        'x_phones': sequence_to_text(x.squeeze(0).tolist()),
    }

@torch.inference_mode()
def synthesise(text, spks=None):
    """Run the module-level ``model`` on ``text`` and return its output dict.

    ``n_timesteps``, ``temperature`` and ``length_scale`` are read from
    module-level globals at call time, so they must be defined before the
    first call.

    Args:
        text: Text to synthesise.
        spks: Optional speaker argument forwarded unchanged to
            ``model.synthesise``.

    Returns:
        The dict returned by ``model.synthesise``, extended with ``start_t``
        (timestamp taken just before inference) and every entry produced by
        ``process_text``.
    """
    inputs = process_text(text)

    started_at = dt.datetime.now()
    result = model.synthesise(
        inputs['x'],
        inputs['x_lengths'],
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spks,
        length_scale=length_scale,
    )
    # Fold the timestamp and the text-processing outputs into one dict.
    result['start_t'] = started_at
    result.update(inputs)
    return result

@torch.inference_mode()
def to_waveform(mel, vocoder, denoiser_strength=0.00025):
    """Vocode a mel-spectrogram into a denoised CPU waveform.

    Args:
        mel: Mel-spectrogram passed directly to ``vocoder``.
        vocoder: Callable that maps ``mel`` to audio.
        denoiser_strength: ``strength`` forwarded to the denoiser. The
            default reproduces the previously hard-coded value.

    Returns:
        The denoised audio on the CPU with all size-1 dimensions squeezed.

    Note:
        Denoising uses the module-level ``denoiser``, which was built from
        the module-level ``vocoder``. Passing a different ``vocoder`` here
        does not change the denoiser.
    """
    # Clamp to [-1, 1] before denoising.
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0), strength=denoiser_strength)
    # A single .cpu().squeeze() suffices; the original applied it twice.
    return audio.cpu().squeeze()
    
def save_to_folder(filename: str, output: dict, folder: str):
    """Write the mel and waveform in ``output`` to ``folder``.

    ``folder`` is created (with parents) if missing. ``output['mel']`` is
    saved with ``np.save`` under ``filename``, and ``output['waveform']`` is
    written to ``<filename>.wav`` at 22050 Hz with ``PCM_24`` encoding.
    """
    target_dir = Path(folder)
    target_dir.mkdir(parents=True, exist_ok=True)

    mel_array = output['mel'].cpu().numpy()
    np.save(target_dir / f'{filename}', mel_array)

    wav_path = target_dir / f'{filename}.wav'
    sf.write(wav_path, output['waveform'], 22050, 'PCM_24')

# Smoke test: show what process_text produces for a sample sentence.
print(process_text("Hi how are you today?"))

## Setup text to synthesise

In [None]:
import pandas as pd

# Pipe-separated evaluation file without a header row, so columns are
# labelled 0, 1, ...
eval_df = pd.read_csv("/workspace/matchatrain/LJSpeech-1.1/test.csv", sep="|", header=None)
print(len(eval_df))
print(eval_df.iloc[0])

# Column 1 holds the sentences to synthesise.
texts = eval_df[1].tolist()
print(texts[0])

In [None]:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset

# NOTE(review): this rebinds the module-level `device` from a torch.device
# (set in the checkpoint cell) to a plain string. `process_text` reads
# `device` at call time, so later calls will see this string value.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision when CUDA is available, full precision otherwise.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Speech-recognition model used to transcribe the synthesised audio for WER/CER.
model_id = "openai/whisper-large-v3-turbo"

wmodel = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
wmodel.to(device)
processor = AutoProcessor.from_pretrained(model_id)

# ASR pipeline built from the model plus the processor's tokenizer and
# feature extractor.
pipe = pipeline(
    "automatic-speech-recognition",
    model=wmodel,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

# The three globals below are read by `synthesise` at call time and forwarded
# to `model.synthesise`.

## Number of ODE solver steps
n_timesteps = 10

## Speaking-rate control (passed as `length_scale`)
length_scale=1.0

## Sampling temperature
temperature = 0.667

## Synthesis

In [None]:
import re, unicodedata

# 영어용 간단 정규화 (소문자, 기호 제거, 공백 정리)
_punct = re.compile(r"[^\w\s]")
def _normalize(s: str) -> str:
    s = unicodedata.normalize("NFKC", s).lower()
    s = _punct.sub(" ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

def _lev(a, b):  # Levenshtein distance
    n, m = len(a), len(b)
    dp = list(range(m+1))
    for i in range(1, n+1):
        prev, dp[0] = dp[0], i
        for j in range(1, m+1):
            cur = dp[j]
            cost = 0 if a[i-1] == b[j-1] else 1
            dp[j] = min(dp[j] + 1, dp[j-1] + 1, prev + cost)
            prev = cur
    return dp[m]

def wer_en(ref: str, hyp: str) -> float:
    """Word error rate of ``hyp`` against ``ref`` after ``_normalize``.

    Computed as the word-level edit distance divided by the number of
    reference words. For an empty reference the result is 0.0 when the
    hypothesis is also empty and 1.0 otherwise.
    """
    ref_words = _normalize(ref).split()
    hyp_words = _normalize(hyp).split()
    if ref_words:
        return _lev(ref_words, hyp_words) / len(ref_words)
    return 1.0 if hyp_words else 0.0

def cer_en(ref: str, hyp: str) -> float:
    """Character error rate of ``hyp`` against ``ref`` after ``_normalize``.

    Spaces are removed before comparison. The result is the character-level
    edit distance divided by the number of reference characters. For an
    empty reference the result is 0.0 when the hypothesis is also empty and
    1.0 otherwise.
    """
    ref_chars = _normalize(ref).replace(" ", "")
    hyp_chars = _normalize(hyp).replace(" ", "")
    if ref_chars:
        return _lev(ref_chars, hyp_chars) / len(ref_chars)
    return 1.0 if hyp_chars else 0.0

In [None]:
import torch
import torch.nn.functional as F
import time
import torchaudio
from IPython.display import Audio

# The vocoder audio in this notebook is handled at 22050 Hz (see the rate
# used in save_to_folder), while Whisper's feature extractor works at 16 kHz.
VOCODER_SAMPLE_RATE = 22050
ASR_SAMPLE_RATE = 16000

# Downsample 22050 -> 16000 before ASR. The original built the resampler in
# the opposite direction (16000 -> 22050), so the audio sent to Whisper was
# upsampled instead of being brought to 16 kHz.
resampler = torchaudio.transforms.Resample(
    orig_freq=VOCODER_SAMPLE_RATE, new_freq=ASR_SAMPLE_RATE
)

wer_list, cer_list = [], []
outputs, rtfs = [], []
rtfs_w = []
for i, text in enumerate(tqdm(texts)):
    st = time.time()
    output = synthesise(text) #, torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0))
    waveform = to_waveform(output['mel'], vocoder)
    mono_16k = resampler(waveform)

    # Pass the sampling rate explicitly so the pipeline does not have to
    # assume one for the raw array.
    script = pipe({"raw": mono_16k.numpy(), "sampling_rate": ASR_SAMPLE_RATE})
    # print(text, '\n')
    # print(script['text'], '\n\n')

    _w = wer_en(text, script['text'])
    _c = cer_en(text, script['text'])
    wer_list.append(_w)
    cer_list.append(_c)

    # display(Audio(waveform.cpu().detach().numpy(), rate=22050))

    # Running averages every 50 utterances.
    if i%50==49:
        print(f"\nMean WER: {sum(wer_list)/len(wer_list):.4f}")
        print(f"Mean CER: {sum(cer_list)/len(cer_list):.4f}")

# Always report the final aggregate; the periodic print above only fires on
# multiples of 50 and would otherwise skip the tail of the test set.
if wer_list:
    print(f"\nFinal mean WER over {len(wer_list)} utterances: {sum(wer_list)/len(wer_list):.4f}")
    print(f"Final mean CER over {len(cer_list)} utterances: {sum(cer_list)/len(cer_list):.4f}")