In [None]:
import torch
from wavlm import WavLM, WavLMConfig
from tqdm import tqdm
import librosa
import pandas as pd
import json

In [None]:
from transformers import AutoProcessor, HubertModel
from datasets import load_dataset
import soundfile as sf

# Checkpoint identifier shared by the input processor and the encoder so the
# two can never drift apart.
hubert_checkpoint = "facebook/hubert-large-ls960-ft"

processor = AutoProcessor.from_pretrained(hubert_checkpoint)
# Load the encoder weights, then move the module onto the GPU.
model = HubertModel.from_pretrained(hubert_checkpoint)
model = model.cuda()

In [None]:
# Disabled alternative: build a WavLM encoder from a local checkpoint file.
# If re-enabled, this rebinds `model`, replacing whatever was assigned to that
# name earlier in the notebook.
# pretrained_path = "/data/codes/prep_ps_pykaldi/pretrained/wavlm-base+.pt"
# checkpoint = torch.load(pretrained_path)

# config = WavLMConfig(checkpoint['cfg'])
# model = WavLM(config).cuda()
# model.load_state_dict(checkpoint['model'])
# _ = model.eval()

In [None]:
# Alignment table: tab-separated, no header row, one utterance per line with
# two fields (utterance id, alignment string).
align_path = "/data/codes/prep_ps_pykaldi/exp/sm/test/merged_align.out"
align_df = pd.read_csv(
    align_path,
    sep="\t",
    header=None,
    names=["id", "alignment"],
)
# Preview the first two rows as the cell output.
align_df.head(2)

In [None]:
def extract_feature(alignment, features):
    """Average frame-level features over each aligned phoneme segment.

    Args:
        alignment: Sequence of ``(phoneme, start_frame, duration)`` entries.
            The last entry's ``start_frame + duration`` is taken as the total
            number of aligned frames. Entries whose phoneme is ``"SIL"`` are
            skipped and do not get their own output row.
        features: 2-D tensor indexed by frame along dim 0
            (presumably ``(num_frames, feature_dim)`` -- TODO confirm).

    Returns:
        A ``(pooled, phonemes)`` tuple. ``phonemes`` lists the non-silence
        labels in alignment order and ``pooled[i]`` is the mean feature
        vector of the frames assigned to ``phonemes[i]``. When some aligned
        frames belong to no phoneme (silence or gaps), one extra trailing row
        holds the mean over those frames. Pooling runs on ``features``' own
        device and dtype.

    Notes:
        * If ``features`` and the alignment disagree on the number of frames,
          only the frames present in both are pooled.
        * A segment left without any frame yields a zero vector, not NaN.
        * An empty alignment yields an empty ``(0, feature_dim)`` tensor and
          an empty phoneme list.
    """
    if len(alignment) == 0:
        return features.new_zeros((0, features.shape[-1])), []

    # Frame -> segment id; -1 marks frames not claimed by any phoneme.
    total_frames = alignment[-1][1] + alignment[-1][2]
    indices = torch.full((total_frames,), -1, dtype=torch.long)
    phonemes = []
    for phoneme, start_frame, duration in alignment:
        if phoneme == "SIL":
            continue
        indices[start_frame:start_frame + duration] = len(phonemes)
        phonemes.append(phoneme)

    # Unclaimed frames share one bucket placed after all phoneme buckets.
    # Using len(phonemes) rather than indices.max() + 1 keeps the bucket id
    # from colliding with a phoneme that ended up with zero frames.
    num_segments = len(phonemes)
    unclaimed = indices == -1
    if unclaimed.any():
        indices[unclaimed] = num_segments
        num_segments += 1

    if num_segments == 0:
        # No phonemes and no frames at all: nothing to pool.
        return features.new_zeros((0, features.shape[-1])), phonemes

    # (frames, segments) membership matrix on the same device/dtype as the
    # features so the matmul works on CPU, any GPU, and any float precision.
    weights = torch.nn.functional.one_hot(indices, num_classes=num_segments)
    weights = weights.to(device=features.device, dtype=features.dtype)

    # Pool only the frames that exist on both sides.
    usable_frames = min(weights.shape[0], features.shape[0])
    weights = weights[:usable_frames]
    features = features[:usable_frames]

    # Normalise each column to sum to 1 (mean pooling); clamp avoids 0/0 for
    # segments that have no frame.
    frame_counts = weights.sum(0, keepdim=True).clamp(min=1)
    weights = weights / frame_counts

    features = torch.matmul(weights.transpose(0, 1), features)

    return features, phonemes

In [None]:
wav_dir = "/data/codes/prep_ps_pykaldi/prep_data/wav"

# Audio is resampled to this rate on load and the processor is told the same rate.
sample_rate = 16000
# Each encoder output frame is duplicated this many times before pooling
# (presumably to match the alignment's finer frame rate -- TODO confirm).
frame_repeat = 2

# Send inputs to wherever the model's weights live rather than assuming a device.
model_device = next(model.parameters()).device

# One record per utterance: id, phoneme labels, and one pooled vector per phoneme.
wavlm_features = []
for row_label in tqdm(align_df.index):
    wav_id = align_df.at[row_label, "id"]
    # The alignment column stores a JSON-encoded list consumed by extract_feature.
    alignment = json.loads(align_df.at[row_label, "alignment"])

    wav, sr = librosa.load(f'{wav_dir}/{wav_id}.wav', sr=sample_rate)

    with torch.no_grad():
        inputs = processor(wav, return_tensors="pt", sampling_rate=sample_rate)
        hidden_states = model(inputs["input_values"].to(model_device)).last_hidden_state

    # Take the single batch item and repeat every frame back-to-back
    # (row 0, row 0, row 1, row 1, ...).
    features = hidden_states[0].repeat_interleave(frame_repeat, dim=0)

    features, phonemes = extract_feature(alignment, features)

    sample = {
        "id": wav_id,
        "phoneme": phonemes,
        # Keep only the per-phoneme rows; any trailing extra row is dropped.
        "features": features[0:len(phonemes)].cpu().numpy()
    }
    wavlm_features.append(sample)