In [None]:
from vocos import Vocos
from huggingface_hub import snapshot_download, hf_hub_download
import torch

# load vocoder
def load_vocoder(is_local=False, local_path="", device="cuda", hf_cache_dir=None):
    """Build a Vocos vocoder and load its pretrained weights.

    Args:
        is_local: If True, read ``config.yaml`` and ``pytorch_model.bin`` from
            ``local_path``; otherwise fetch both files from the Hugging Face
            repo ``charactr/vocos-mel-24khz``.
        local_path: Directory holding the two files (only used when
            ``is_local`` is True). An empty string means the current working
            directory.
        device: Device the vocoder is moved to after loading.
        hf_cache_dir: Cache directory forwarded to ``hf_hub_download``
            (only used when ``is_local`` is False).

    Returns:
        The vocoder in eval mode, moved to ``device``.
    """
    import os

    # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
    if is_local:
        print(f"Load vocos from local path {local_path}")
        # os.path.join instead of f"{local_path}/...": an empty local_path no
        # longer resolves to "/config.yaml" at the filesystem root, and a
        # trailing separator in local_path does not produce a doubled slash.
        config_path = os.path.join(local_path, "config.yaml")
        model_path = os.path.join(local_path, "pytorch_model.bin")
    else:
        print("Download Vocos from huggingface charactr/vocos-mel-24khz")
        repo_id = "charactr/vocos-mel-24khz"
        config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
        model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")

    vocoder = Vocos.from_hparams(config_path)
    # Weights are loaded on CPU first; the whole model is moved to `device` at the end.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    from vocos.feature_extractors import EncodecFeatures

    if isinstance(vocoder.feature_extractor, EncodecFeatures):
        # Merge the freshly constructed encodec sub-module's own parameters into
        # the checkpoint under the "feature_extractor.encodec." prefix before the
        # load below — presumably because the checkpoint does not ship them.
        encodec_parameters = {
            "feature_extractor.encodec." + key: value
            for key, value in vocoder.feature_extractor.encodec.state_dict().items()
        }
        state_dict.update(encodec_parameters)
    vocoder.load_state_dict(state_dict)
    vocoder = vocoder.eval().to(device)

    return vocoder

In [None]:
# Choose the device explicitly: load_vocoder defaults to "cuda", which fails on
# a machine without a GPU. Same cuda-or-cpu rule as the decoding cell.
voco = load_vocoder(
    True,
    local_path='/workspace/matchatrain/vocoder',
    device="cuda" if torch.cuda.is_available() else "cpu",
)

In [None]:
import torch, torchaudio, soundfile as sf
from vocos import Vocos
import matplotlib.pyplot as plt
from audiotools import AudioSignal
import time

wav_path = "/workspace/matchatrain/gpt-4o-mini.mp3"
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Load the audio, downmix multi-channel input to mono, resample to 24 kHz
y, sr = torchaudio.load(wav_path)
if y.size(0) > 1:
    y = y.mean(dim=0, keepdim=True)
if sr != 24000:
    y = torchaudio.functional.resample(y, sr, 24000)
sr = 24000

# 2) Mel spectrogram (24 kHz, n_mels=100, hop=256)
mel_fn = torchaudio.transforms.MelSpectrogram(
    sample_rate=sr, n_fft=1024, hop_length=256, win_length=1024,
    f_min=0.0, f_max=sr/2, n_mels=100, power=1.0, center=True, norm=None
)
mel = mel_fn(y)                                 # [1, 100, T]
mel = torch.log(mel.clamp(min=1e-5))            # log-mel; the clamp avoids log(0)

# Visualization
plt.figure(figsize=(8, 4))
plt.imshow(mel[0].cpu().numpy(), aspect='auto', origin='lower',
           interpolation='none')
plt.colorbar(label="Log-Mel energy")
plt.xlabel("Frames")
plt.ylabel("Mel bins (100)")
plt.title("Log-Mel Spectrogram")
plt.tight_layout()
plt.show()

# 3) Decode the mel back to a waveform and time it.
# Feed the mel on the device the vocoder's weights actually live on, so the
# input cannot end up on a different device than the model.
vocoder_device = next(voco.parameters()).device

# vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device).eval()
# CUDA kernels run asynchronously: synchronize before starting and before
# stopping the clock, otherwise the timer may stop before the GPU work is done.
if vocoder_device.type == "cuda":
    torch.cuda.synchronize(vocoder_device)
st = time.perf_counter()
with torch.inference_mode():
    wav_hat = voco.decode(mel.to(vocoder_device))      # [1, samples]
if vocoder_device.type == "cuda":
    torch.cuda.synchronize(vocoder_device)
print(time.perf_counter() - st)

# Convert once and reuse for both the file on disk and the inline player.
wav_np = wav_hat[0].cpu().numpy()
sf.write("reconstructed.wav", wav_np, sr)
AudioSignal(wav_np, sample_rate=sr).widget()