In [1]:
# Set CUDA_VISIBLE_DEVICES=3 in this kernel's environment. This is the first
# cell, ahead of the `import torch` cell, so the variable is in place before
# anything can initialise CUDA.
# NOTE(review): the GPU index is machine-specific — adjust before re-running elsewhere.
%env CUDA_VISIBLE_DEVICES=3

env: CUDA_VISIBLE_DEVICES=3


In [2]:
from pathlib import Path
from utils.load_yaml import HpsYaml
import torch
import numpy as np
from src import build_model
import os
from tqdm import tqdm
from data_objects.kaldi_interface import KaldiInterface

In [3]:
def build_vq_model(model_config, model_file, device):
    """Instantiate the VQ model named in ``model_config`` and load its weights.

    Args:
        model_config: Mapping-like config. ``model_config["model_name"]`` is
            handed to ``build_model`` to obtain the model class, and
            ``model_config["model"]`` is passed to that class's constructor.
        model_file: Checkpoint path given to ``torch.load``; the loaded object
            must contain a ``"model"`` entry holding the state dict.
        device: Device the model is moved to; also used as ``map_location``
            when reading the checkpoint.

    Returns:
        The constructed model, moved to ``device``, with the checkpoint
        weights loaded and ``eval()`` called on it.
    """
    vq_cls = build_model(model_config["model_name"])
    vq_model = vq_cls(model_config["model"])
    vq_model = vq_model.to(device)

    # NOTE(review): torch.load unpickles the file — only use trusted checkpoints.
    checkpoint = torch.load(model_file, map_location=device)
    vq_model.load_state_dict(checkpoint["model"])

    vq_model.eval()
    return vq_model

def get_bnfs(spk_id, utterance_id, kaldi_dir):
    """Look up the ``'bnf'`` feature of one utterance through a KaldiInterface.

    Args:
        spk_id: Speaker identifier; first half of the lookup key.
        utterance_id: Utterance identifier; second half of the lookup key.
        kaldi_dir: Directory holding ``wav.scp`` and ``bnf/feats.scp``.

    Returns:
        Whatever ``KaldiInterface.get_feature`` returns for the key
        ``"<spk_id>_<utterance_id>"`` and feature type ``'bnf'``.
    """
    wav_scp_path = str(os.path.join(kaldi_dir, 'wav.scp'))
    bnf_scp_path = str(os.path.join(kaldi_dir, 'bnf/feats.scp'))
    kaldi = KaldiInterface(wav_scp=wav_scp_path, bnf_scp=bnf_scp_path)

    utt_key = '_'.join((spk_id, utterance_id))
    return kaldi.get_feature(utt_key, 'bnf')

In [4]:
# Build models
print("Load VQ-model...")
# Use the GPU when one is visible to this process, otherwise fall back to the
# CPU so the notebook still runs on machines without the pinned GPU index.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Training config and checkpoint of the VQ model (placeholder paths — edit
# these to point at the real files before running).
vq_train_config = Path('/path/to/conf/vq_128.yaml')
bnf2code_config = HpsYaml(vq_train_config)
bnf2code_model_file = Path('/path/to/ckpt/vq128/loss_step_100000.pth')

# Module-level model used by translate2code() in a later cell.
bnf2code_model = build_vq_model(bnf2code_config, bnf2code_model_file, device)

Load VQ-model...


In [5]:
@torch.no_grad()
def translate2code(bnf_fpath):
    """Run one saved feature array through the VQ model's ``inference``.

    Relies on the notebook globals ``device`` and ``bnf2code_model``.

    Args:
        bnf_fpath: Path of the array file to read with ``np.load``.

    Returns:
        Tuple ``(bnf_qn, indices)``: the two values returned by
        ``bnf2code_model.inference``, each moved to the CPU and converted to
        a NumPy array.
    """
    features = torch.from_numpy(np.load(bnf_fpath)).to(device)

    # squeeze() with no dim removes *every* size-1 axis of the loaded array.
    # NOTE(review): an input with a length-1 axis reaches inference() without
    # that axis — confirm this is what the model expects.
    model_input = features.squeeze()

    quantized, code_ids = bnf2code_model.inference(model_input)
    return quantized.cpu().numpy(), code_ids.cpu().numpy()


In [6]:
# Input directory of saved BNF/PPG arrays and root of the output tree
# (placeholder paths — edit before running).
base_bnf_fpath = '/path/to/ppgs'
output_dir = '/path/to/output/vq128'

# Fail with a clear message instead of the bare StopIteration that
# next(os.walk(...)) raises when the directory does not exist.
if not os.path.isdir(base_bnf_fpath):
    raise FileNotFoundError(f"BNF input directory not found: {base_bnf_fpath}")

# Create both output directories once, up front, rather than on every iteration.
ppgs_dir = os.path.join(output_dir, "ppgs")
indices_dir = os.path.join(output_dir, "indices")
os.makedirs(ppgs_dir, exist_ok=True)
os.makedirs(indices_dir, exist_ok=True)

# Only the top level of the input directory is processed (first os.walk entry).
root, _, files = next(os.walk(base_bnf_fpath))
for file in tqdm(files):
    bnf_qn, indices = translate2code(os.path.join(root, file))

    # Outputs keep the input file name: quantized features under ppgs/,
    # codebook indices under indices/.
    bnf_fname = os.path.join(ppgs_dir, file)
    np.save(bnf_fname, bnf_qn, allow_pickle=False)

    ind_fname = os.path.join(indices_dir, file)
    np.save(ind_fname, indices, allow_pickle=False)

100%|██████████| 25745/25745 [00:43<00:00, 589.17it/s]
