In [None]:
import hydra
import torch
from audiocraft.models import MultiBandDiffusion
from audiotools import AudioSignal
from df.enhance import enhance, init_df
from huggingface_hub import hf_hub_download
from vocos import Vocos

from pflow_encodec.data.tokenizer import EncodecTokenizer, TextTokenizer

In [None]:
def load_model(ckpt_path, device="cpu"):
    """Rebuild a model from a checkpoint file and return it with its data config.

    The checkpoint is read as a dict holding "model_config" (passed to
    ``hydra.utils.instantiate``), "state_dict" and "data_config".

    Args:
        ckpt_path: Path to the checkpoint file.
        device: Device the model is moved to after loading (default "cpu").

    Returns:
        Tuple ``(model, data_config)``: the model in eval mode on ``device`` and the
        checkpoint's ``data_config`` entry.
    """
    # NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    net = hydra.utils.instantiate(checkpoint["model_config"])
    net.load_state_dict(checkpoint["state_dict"])

    return net.eval().to(device), checkpoint["data_config"]

In [None]:
# Fetch the pretrained LibriTTS checkpoint from the Hugging Face Hub (cached locally).
ckpt_path = hf_hub_download(
    repo_id="seastar105/pflow-encodec-libritts",
    filename="libritts_base.ckpt",
)

In [None]:
# Instantiate the model on the GPU; data_config carries the normalisation stats used at inference.
model, data_config = load_model(ckpt_path, device="cuda")

In [None]:
# Reference speaker prompt whose voice the synthesis should imitate.
prompt_path = hf_hub_download(
    repo_id="seastar105/pflow-encodec-libritts",
    filename="prompt_samples/prompt1.wav",
)

In [None]:
# Text front-end: turns input strings into token tensors via encode_text (used in pflow_inference).
text_tokenizer = TextTokenizer()

In [None]:
# Encodec wrapper used throughout: encode_file/encode_audio -> latents, quantize_latents -> codes,
# decode_latents -> waveform.
encodec_tokenizer = EncodecTokenizer()

In [None]:
# Load the speech-enhancement model and state used by df_enhance; the third returned value is not needed.
df_model, df_states, _ = init_df()

In [None]:
# Vocos vocoder trained on 24 kHz Encodec tokens, in eval mode on the GPU.
vocos_model = (
    Vocos.from_pretrained("charactr/vocos-encodec-24khz")
    .eval()
    .cuda()
)

In [None]:
# MultiBandDiffusion decoder for 24 kHz Encodec tokens.
# NOTE(review): bw=6 presumably selects the 6 kbps variant; confirm it matches the number of
# codebooks mbd_decode passes in.
mbd_model = MultiBandDiffusion.get_mbd_24khz(bw=6)

In [None]:
@torch.inference_mode()
def pflow_inference(
    model, text, prompt_path, data_config, cfg_scale=1.0, nfe=16, ode_method="midpoint", return_latent=False
):
    """Synthesize ``text`` conditioned on a prompt audio file with a P-Flow model.

    Uses the module-level ``encodec_tokenizer`` and ``text_tokenizer``.

    Args:
        model: Model exposing ``generate``; its parameters' device is used for inference.
        text: Input text to synthesize.
        prompt_path: Path to the prompt audio file (encoded with ``encodec_tokenizer.encode_file``).
        data_config: Dict providing "mean", "std" and "text2latent_ratio".
        cfg_scale: Classifier-free guidance scale forwarded to ``model.generate``.
        nfe: Number of function evaluations forwarded to ``model.generate``.
        ode_method: ODE solver name forwarded to ``model.generate``.
        return_latent: If True, return the de-normalised latents instead of a waveform.

    Returns:
        A CPU tensor: the de-normalised latents when ``return_latent`` is True, otherwise the
        output of ``encodec_tokenizer.decode_latents``.
    """
    run_device = next(model.parameters()).device
    prompt_latent = encodec_tokenizer.encode_file(prompt_path).to(run_device)

    latent_mean, latent_std = data_config["mean"], data_config["std"]
    ratio = data_config["text2latent_ratio"]

    tokens = text_tokenizer.encode_text(text).to(run_device).unsqueeze(0)

    # The model works on normalised latents: scale the prompt in, and scale the output back out.
    normalised_prompt = (prompt_latent - latent_mean) / latent_std
    generated = model.generate(
        tokens,
        normalised_prompt,
        cfg_scale=cfg_scale,
        nfe=nfe,
        ode_method=ode_method,
        upscale_ratio=ratio,
    )
    latent = generated * latent_std + latent_mean

    if return_latent:
        return latent.cpu()

    codec_input = latent.to(device=encodec_tokenizer.device, dtype=encodec_tokenizer.dtype)
    return encodec_tokenizer.decode_latents(codec_input).cpu()

In [None]:
@torch.inference_mode()
def mbd_decode(mbd_model, latent):
    """Quantize Encodec latents and render them to audio with MultiBandDiffusion.

    Uses the module-level ``encodec_tokenizer`` for quantization.

    Args:
        mbd_model: Model exposing ``tokens_to_wav``.
        latent: Continuous Encodec latents.

    Returns:
        The waveform produced by ``mbd_model.tokens_to_wav``, moved to CPU.
    """
    latent_on_codec = latent.to(device=encodec_tokenizer.device)
    # NOTE(review): every codebook returned by quantize_latents is forwarded; confirm this matches
    # the bandwidth the MBD model was built for.
    return mbd_model.tokens_to_wav(encodec_tokenizer.quantize_latents(latent_on_codec)).cpu()

In [None]:
@torch.inference_mode()
def vocos_decode(vocos_model, latent, bandwidth_id=3):
    """Quantize Encodec latents and decode them to audio with Vocos.

    Uses the module-level ``encodec_tokenizer`` for quantization.

    Args:
        vocos_model: Vocos model exposing ``codes_to_features`` and ``decode``.
        latent: Continuous Encodec latents.
        bandwidth_id: Index into Vocos' Encodec bandwidth table (1.5, 3, 6, 12 kbps). The number of
            codebooks given to Vocos is derived from it (2, 4, 8, 16) so the two always agree.
            Defaults to 3 (12 kbps / 16 codebooks), the previously hard-coded setting.

    Returns:
        Decoded waveform, moved to CPU.

    Raises:
        ValueError: If ``bandwidth_id`` is not in 0..3.
    """
    if bandwidth_id not in range(4):
        raise ValueError(f"bandwidth_id must be 0, 1, 2 or 3, got {bandwidth_id!r}")
    # 24 kHz Encodec uses 0.75 kbps per codebook: 1.5/3/6/12 kbps -> 2/4/8/16 codebooks.
    n_codebooks = 2 ** (bandwidth_id + 1)

    codes = encodec_tokenizer.quantize_latents(latent.to(device=encodec_tokenizer.device))
    # NOTE(review): .squeeze() assumes a batch of one and a codebook-first layout afterwards — confirm.
    codes = codes.squeeze()[:n_codebooks, :]

    features = vocos_model.codes_to_features(codes)
    bandwidth = torch.tensor([bandwidth_id], device=features.device)
    audio = vocos_model.decode(features, bandwidth_id=bandwidth)
    return audio.cpu()

In [None]:
@torch.inference_mode()
def df_enhance(df_model, df_state, audio, sr=24000):
    """Denoise ``audio`` with the speech-enhancement model, preserving its sample rate.

    ``enhance`` requires audio at the enhancer's own rate (``df_state.sr()``), while the audio in
    this notebook is 24 kHz. The input is therefore resampled to the enhancer rate, enhanced, and
    resampled back to ``sr``.

    Args:
        df_model: Enhancement model from ``init_df``.
        df_state: Enhancement state from ``init_df``.
        audio: Waveform tensor; a 3-D tensor with a leading dim of size 1 is squeezed to 2-D.
        sr: Sample rate of ``audio`` and of the returned waveform (default 24000).

    Returns:
        Enhanced waveform at sample rate ``sr``.
    """
    # Local import: `resample` lives in the already-used `df` package but is not imported at file top.
    from df.io import resample

    if audio.ndim == 3:
        audio = audio.squeeze(0)

    df_sr = df_state.sr()
    if sr != df_sr:
        audio = resample(audio, sr, df_sr)

    enhanced = enhance(df_model, df_state, audio)

    if sr != df_sr:
        enhanced = resample(enhanced, df_sr, sr)
    return enhanced

In [None]:
# Play back the reference prompt for comparison with the synthesized outputs below.
AudioSignal(prompt_path).embed(display=False)

In [None]:
# Sentence to synthesize in the prompt speaker's voice.
text = "P-Flow encodec is Text-to-Speech model trained on Encodec latent using Flow Matching."

In [None]:
# Baseline: decode the generated latents directly with the Encodec decoder.
pflow_result = pflow_inference(
    model,
    text,
    prompt_path,
    data_config,
    cfg_scale=1.0,
    nfe=16,
    ode_method="midpoint",
)
pflow_signal = AudioSignal(pflow_result, 24000)
pflow_signal = pflow_signal.ensure_max_of_audio()
pflow_signal.embed(display=False)

In [None]:
# Re-encode the P-Flow waveform and re-render it with MultiBandDiffusion.
latents = encodec_tokenizer.encode_audio(
    pflow_result.to(encodec_tokenizer.device)
)
mbd_recon = mbd_decode(mbd_model, latents)
mbd_signal = AudioSignal(mbd_recon, 24000)
mbd_signal = mbd_signal.ensure_max_of_audio()
mbd_signal.embed(display=False)

In [None]:
# MultiBandDiffusion output after speech enhancement.
mbd_df_result = df_enhance(df_model, df_states, mbd_recon)
mbd_df_signal = AudioSignal(mbd_df_result, 24000)
mbd_df_signal = mbd_df_signal.ensure_max_of_audio()
mbd_df_signal.embed(display=False)

In [None]:
# Re-encode the P-Flow waveform and re-render it with the Vocos vocoder.
latents = encodec_tokenizer.encode_audio(
    pflow_result.to(encodec_tokenizer.device)
)
vocos_recon = vocos_decode(vocos_model, latents)
vocos_signal = AudioSignal(vocos_recon, 24000)
vocos_signal = vocos_signal.ensure_max_of_audio()
vocos_signal.embed(display=False)

In [None]:
# Vocos output after speech enhancement.
vocos_df_result = df_enhance(df_model, df_states, vocos_recon)
vocos_df_signal = AudioSignal(vocos_df_result, 24000)
vocos_df_signal = vocos_df_signal.ensure_max_of_audio()
vocos_df_signal.embed(display=False)