In [None]:
%cd /data/codes/prep_ps_pykaldi/
import pandas as pd
import numpy as np
import torch
import pickle
import json
import re

In [None]:
# Fixed sequence length that every per-phone array below is padded to.
MAX_LENGTH = 32

# Everything lives under one project root; change ROOT_DIR / EXP_DIR to relocate the run.
ROOT_DIR = "/data/codes/prep_ps_pykaldi"
EXP_DIR = f"{ROOT_DIR}/exp/sm/train_new"

metadata_path = f"{ROOT_DIR}/prep_data/jsonl_v1/info_qt_10_trainset.jsonl"
align_path = f"{EXP_DIR}/merged_align.out"
gop_path = f"{EXP_DIR}/merged_gop.pkl"
out_dir = EXP_DIR

In [None]:
def load_jsonl(path):
    """Read a JSON-Lines file into a DataFrame, one row per JSON object.

    Blank or whitespace-only lines are skipped. Passing them to ``json.loads``
    would raise ``JSONDecodeError``.

    Args:
        path: path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        pandas.DataFrame built from the decoded records, in file order.
    """
    with open(path, "r", encoding="utf-8") as f:
        # Stream the file instead of readlines() so the raw text is not held twice.
        records = [json.loads(line) for line in f if line.strip()]
    return pd.DataFrame(records)

# Utterance-level metadata (one JSON record per utterance).
metadata = load_jsonl(path=metadata_path)
metadata.head(n=2)

In [None]:
# NOTE(review): pickle.load can execute arbitrary code -- only load GOP files produced by this pipeline.
with open(gop_path, 'rb') as gop_file:
    gop_features = pickle.load(gop_file)

# extract_gop_feature / extract_phonemes look samples up with str(id), and align_df ids are
# strings, so normalise metadata ids to str before filtering. .copy() makes the filtered frame
# independent, so the column assignments in later cells do not trigger SettingWithCopyWarning.
metadata = metadata.assign(id=metadata["id"].astype(str))
metadata = metadata[metadata["id"].isin(set(gop_features))].copy()
metadata.head(2)

In [None]:
def extract_gop_feature(id, gop_dict=None):
    """Stack the per-phone GOP feature vectors of one utterance, dropping "SIL" phones.

    Args:
        id: utterance id; looked up as ``str(id)``.
        gop_dict: mapping ``utt_id -> {"gopt": [...], "phones": [[...]]}``. Defaults to
            the module-level ``gop_features`` (loaded from ``gop_path``).

    Returns:
        numpy array with one row per non-"SIL" phone. ``sample["gopt"]`` is paired
        positionally with ``sample["phones"][0]``.

    Raises:
        ValueError: if the utterance has no non-"SIL" phone.
    """
    if gop_dict is None:
        gop_dict = gop_features
    sample = gop_dict[str(id)]
    features = [
        np.array(feature)
        for feature, phoneme in zip(sample["gopt"], sample["phones"][0])
        if phoneme != "SIL"
    ]
    if not features:
        # np.stack([]) would raise an opaque "need at least one array to stack".
        raise ValueError(f"utterance {id!r} has no non-SIL phones")
    return np.stack(features)

def extract_phonemes(id, gop_dict=None):
    """Return the non-"SIL" phone labels of one utterance from the GOP output.

    Each label keeps only the part before the first ``_`` and has its digits removed,
    e.g. ``"AH0_B" -> "AH"``.

    NOTE: a later cell redefines ``extract_phonemes`` to parse alignment strings. Once
    that cell has run, this GOP-based version is shadowed.

    Args:
        id: utterance id; looked up as ``str(id)``.
        gop_dict: mapping ``utt_id -> {"phones": [[...]], ...}``. Defaults to the
            module-level ``gop_features``.

    Returns:
        list of phone label strings, in order.
    """
    if gop_dict is None:
        gop_dict = gop_features
    sample = gop_dict[str(id)]
    return [
        # Raw string: "\d" in a normal literal is an invalid escape (SyntaxWarning on 3.12+).
        re.sub(r"\d", "", phoneme.split("_")[0])
        for phoneme in sample["phones"][0]
        if phoneme != "SIL"
    ]

# Per-utterance GOP matrix and GOP-side phone sequence, looked up by utterance id.
utterance_ids = metadata["id"]
metadata["features"] = utterance_ids.apply(extract_gop_feature)
metadata["kaldi_phoneme"] = utterance_ids.apply(extract_phonemes)
metadata.head(2)

In [None]:
# Tab-separated "<utt_id>\t<json alignment>" rows. Read ids as strings so numeric-looking ids
# such as "0012" keep their leading zeros and compare equal to the string ids in metadata.
align_df = pd.read_csv(align_path, names=["id", "alignment"], sep="\t", dtype={"id": str})

def extract_duration(alignment, frame_shift=0.02):
    """Convert a JSON alignment into per-phone durations, skipping "SIL" segments.

    Args:
        alignment: JSON string encoding a list of ``[phoneme, start_frame, num_frames]``.
        frame_shift: duration of one alignment frame (default 0.02, the original constant).
            NOTE(review): the WavLM/HuBERT cells repeat each model frame twice to match
            the alignment length. That suggests alignment frames are half the model
            stride -- confirm 0.02 is the intended per-frame duration.

    Returns:
        list of ``round(num_frames * frame_shift, 4)``, one per non-SIL segment, in order.
    """
    return [
        round(num_frames * frame_shift, 4)
        for phoneme, _start, num_frames in json.loads(alignment)
        if phoneme != "SIL"
    ]

def extract_phonemes(alignment):
    """List the phone labels of a JSON alignment, dropping "SIL" segments.

    Each label is cut at its first ``_`` but keeps any digits, e.g. ``"AH0_B" -> "AH0"``.
    Note: this definition replaces the GOP-based ``extract_phonemes`` from the earlier cell.

    Args:
        alignment: JSON string encoding a list of ``[phoneme, start, duration]`` triples.

    Returns:
        list of label strings, in alignment order.
    """
    segments = json.loads(alignment)
    return [label.split("_")[0] for label, _start, _dur in segments if label != "SIL"]

# Derive per-phone durations and labels from each alignment string; ids become plain str.
alignment_strings = align_df["alignment"]
align_df["durations"] = alignment_strings.map(extract_duration)
align_df["phonemes"] = alignment_strings.map(extract_phonemes)
align_df["id"] = align_df["id"].map(str)
align_df.head(2)

In [None]:
metadata = pd.merge(metadata, align_df[["id", "durations", "alignment"]], how="left", on="id")

# A left join leaves NaN for utterances without an alignment. Later cells call .copy() and
# json.loads on these columns and would crash obscurely, so fail here with the offending ids.
missing_alignment = metadata["alignment"].isna()
if missing_alignment.any():
    raise ValueError(
        f"{int(missing_alignment.sum())} utterances have no alignment, e.g. "
        f"{metadata.loc[missing_alignment, 'id'].head(5).tolist()}"
    )
metadata.head(2)

In [None]:
# Number of entries in `arpas` per utterance, plotted as a 100-bin histogram.
arpa_lengths = metadata["arpas"].map(len)
metadata["length"] = arpa_lengths
metadata["length"].hist(bins=100)

### Extract sentence scores

In [None]:
# One row of utterance-level scores per sample, in metadata row order.
sentence_scores = torch.tensor(
    [metadata["utterance_scores"][row].copy() for row in metadata.index]
).numpy()
print(sentence_scores.shape)
np.save(f'{out_dir}/sentence_scores.npy', sentence_scores)
sentence_scores = None  # written to disk above; drop the in-memory copy

### Extract word scores

In [None]:
# Broadcast each word's score to its phones via word_ids, then pad with -1 to MAX_LENGTH.
word_scores = []

for row in metadata.index:
    word_score = metadata["word_scores"][row]
    word_id = metadata["word_ids"][row]
    if len(word_id) > MAX_LENGTH:
        # Negative padding would silently be [] and torch.stack would fail with a size mismatch.
        raise ValueError(
            f"{metadata['id'][row]}: {len(word_id)} entries exceed MAX_LENGTH={MAX_LENGTH}"
        )

    word_score_in_phone_levels = [word_score[wid] for wid in word_id]
    word_score_in_phone_levels += [-1] * (MAX_LENGTH - len(word_score_in_phone_levels))
    word_scores.append(torch.tensor(word_score_in_phone_levels))

word_scores = torch.stack(word_scores, dim=0).numpy()
print(word_scores.shape)
np.save(f'{out_dir}/word_scores.npy', word_scores)
word_scores = None

### Extract word ids

In [None]:
# Per-phone word index of every utterance, padded with -1 to MAX_LENGTH.
word_ids = []

for row in metadata.index:
    word_id = metadata["word_ids"][row]
    if len(word_id) > MAX_LENGTH:
        # Negative padding would silently be [] and torch.stack would fail with a size mismatch.
        raise ValueError(
            f"{metadata['id'][row]}: {len(word_id)} entries exceed MAX_LENGTH={MAX_LENGTH}"
        )
    # List concatenation builds a new list, so the stored metadata value is left untouched.
    word_ids.append(torch.tensor(word_id + [-1] * (MAX_LENGTH - len(word_id))))

word_ids = torch.stack(word_ids, dim=0).numpy()
print(word_ids.shape)
np.save(f'{out_dir}/word_ids.npy', word_ids)
word_ids = None

### Extract gop

In [None]:
# GOP feature rows per non-SIL phone, zero-padded to MAX_LENGTH rows.
gops = []

for row in metadata.index:
    gop = metadata["features"][row]  # array stacked by extract_gop_feature; assumed 2-D
    if len(gop) > MAX_LENGTH:
        # Negative padding would silently be [] and torch.stack would fail with a size mismatch.
        raise ValueError(
            f"{metadata['id'][row]}: {len(gop)} phones exceed MAX_LENGTH={MAX_LENGTH}"
        )
    padding = [[0] * len(gop[0])] * (MAX_LENGTH - len(gop))
    gops.append(torch.tensor(gop.tolist() + padding))

gops = torch.stack(gops, dim=0).numpy()
print(gops.shape)
np.save(f'{out_dir}/gop.npy', gops)
gops = None

### Extract duration

In [None]:
# Per-phone durations (from extract_duration), zero-padded to MAX_LENGTH.
durations = []

for row in metadata.index:
    duration = metadata["durations"][row]
    if len(duration) > MAX_LENGTH:
        # Negative padding would silently be [] and torch.stack would fail with a size mismatch.
        raise ValueError(
            f"{metadata['id'][row]}: {len(duration)} phones exceed MAX_LENGTH={MAX_LENGTH}"
        )
    # `+` builds a new list, so the stored metadata list is not mutated.
    durations.append(torch.tensor(duration + [0] * (MAX_LENGTH - len(duration))))

durations = torch.stack(durations, dim=0).numpy()
np.save(f'{out_dir}/duration.npy', durations)
durations = None

### Extract phone scores

In [None]:
# Per-phone scores, padded with -1 to MAX_LENGTH.
phone_scores = []

for row in metadata.index:
    phone_score = metadata["phone_scores"][row]
    if len(phone_score) > MAX_LENGTH:
        # Negative padding would silently be [] and torch.stack would fail with a size mismatch.
        raise ValueError(
            f"{metadata['id'][row]}: {len(phone_score)} phones exceed MAX_LENGTH={MAX_LENGTH}"
        )
    # `+` builds a new list, so the stored metadata list is not mutated.
    phone_scores.append(torch.tensor(phone_score + [-1] * (MAX_LENGTH - len(phone_score))))

phone_scores = torch.stack(phone_scores, dim=0).numpy()
np.save(f'{out_dir}/phone_scores.npy', phone_scores)
phone_scores = None

### Extract phone ids

In [None]:
# Phone label -> integer id mapping (must contain a "PAD" entry, used below).
phone_dict_path = "/data/codes/prep_ps_pykaldi/resources/phone_dict.json"
with open(phone_dict_path, encoding="utf-8") as phone_dict_file:
    phone_dict = json.load(phone_dict_file)

In [None]:
# Map each canonical phone (digits stripped) to its id from phone_dict and pad with PAD.
phone_ids = []

pad_token_id = phone_dict["PAD"]
for row in metadata.index:
    arpas = metadata["arpas"][row]
    if len(arpas) > MAX_LENGTH:
        # Negative padding would silently be [] and torch.stack would fail with a size mismatch.
        raise ValueError(
            f"{metadata['id'][row]}: {len(arpas)} phones exceed MAX_LENGTH={MAX_LENGTH}"
        )
    # Raw string: "\d" in a normal literal is an invalid escape (SyntaxWarning on 3.12+).
    ids = [phone_dict[re.sub(r"\d", "", phn)] for phn in arpas]
    ids += [pad_token_id] * (MAX_LENGTH - len(ids))
    phone_ids.append(torch.tensor(ids))

phone_ids = torch.stack(phone_ids, dim=0).numpy()
np.save(f'{out_dir}/phone_ids.npy', phone_ids)
phone_ids = None

### Extract WavLM Feature

In [None]:
%cd /data/codes/prep_ps_pykaldi/wavlm
import torch
from wavlm import WavLM, WavLMConfig
from tqdm import tqdm
import librosa
import pandas as pd
import json

In [None]:
pretrained_path = "/data/codes/prep_ps_pykaldi/pretrained/wavlm-base+.pt"
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(pretrained_path)

wavlm_cfg = checkpoint['cfg']
config = WavLMConfig(wavlm_cfg)
model = WavLM(config)
model = model.eval().cuda()
# Left as the last expression so the notebook shows the state-dict load result.
model.load_state_dict(checkpoint['model'])

In [None]:
def extract_feature(alignment, features):
    """Average frame-level features over each non-"SIL" phone segment.

    Args:
        alignment: decoded alignment, a list of ``[phoneme, start_frame, num_frames]``;
            "SIL" segments are skipped.
        features: 2-D tensor ``(frames, dim)`` on any device. If its length differs from
            the alignment length, only the overlapping frames are used.

    Returns:
        ``(pooled, phonemes)``: a CPU tensor ``(len(phonemes), dim)`` holding the mean
        feature of each non-SIL segment, and the non-SIL labels in alignment order.
        A segment with no usable frames yields a zero row.
    """
    num_frames = alignment[-1][1] + alignment[-1][2]
    # Frame -> segment index; -1 marks frames that belong to no kept phone (SIL / gaps).
    frame_to_segment = torch.full((num_frames,), -1, dtype=torch.long)
    phonemes = []
    for phoneme, start_frame, duration in alignment:
        if phoneme == "SIL":
            continue
        frame_to_segment[start_frame:start_frame + duration] = len(phonemes)
        phonemes.append(phoneme)

    num_segments = len(phonemes)
    # Always route unassigned frames to one extra trailing class that is dropped at the end.
    # Deriving it from max()+1 loses the last real phone when no frame is unassigned.
    frame_to_segment[frame_to_segment == -1] = num_segments

    usable = min(features.shape[0], num_frames)
    weights = torch.nn.functional.one_hot(frame_to_segment[:usable], num_classes=num_segments + 1)
    weights = weights.to(device=features.device, dtype=features.dtype)
    # Column-normalise so each segment averages its frames; clamp avoids 0/0 for empty segments.
    weights = weights / weights.sum(0, keepdim=True).clamp(min=1)
    pooled = torch.matmul(weights.transpose(0, 1), features[:usable])

    return pooled[:-1].cpu(), phonemes

In [None]:
wav_dir = "/data/codes/prep_ps_pykaldi/prep_data/wav"

# WavLM features pooled per non-SIL phone and zero-padded to MAX_LENGTH rows.
wavlm_features = []
for row in tqdm(metadata.index):
    wav_id = metadata["id"][row]
    alignment = json.loads(metadata["alignment"][row])
    wav, sr = librosa.load(f'{wav_dir}/{wav_id}.wav', sr=16000)

    input_values = torch.from_numpy(wav).unsqueeze(0).cuda()
    # Checkpoints whose cfg sets `normalize` expect a layer-normalised waveform (WavLM reference usage).
    if getattr(config, "normalize", False):
        input_values = torch.nn.functional.layer_norm(input_values, input_values.shape)
    with torch.no_grad():
        frame_features = model.extract_features(input_values)[0]

    # Repeat every model frame twice so the sequence lines up with the alignment frames.
    # NOTE(review): assumes alignment frames are half the model frame stride -- confirm.
    frame_features = frame_features[0].repeat_interleave(2, dim=0)

    phone_features, phonemes = extract_feature(alignment, frame_features)
    if len(phonemes) > MAX_LENGTH:
        raise ValueError(f"{wav_id}: {len(phonemes)} phones exceed MAX_LENGTH={MAX_LENGTH}")
    padding = torch.zeros(MAX_LENGTH - len(phonemes), phone_features.shape[1])
    wavlm_features.append(torch.cat([phone_features, padding], dim=0))

wavlm_features = torch.stack(wavlm_features, dim=0).numpy()
np.save(f'{out_dir}/wavlm_features.npy', wavlm_features)
wavlm_features = None

### Extract Hubert Feature

In [None]:
from transformers import AutoProcessor, HubertModel
from datasets import load_dataset
import soundfile as sf
import torch
from tqdm import tqdm
import librosa
import pandas as pd
import json

In [None]:
hubert_checkpoint = "facebook/hubert-large-ls960-ft"
processor = AutoProcessor.from_pretrained(hubert_checkpoint)
model = HubertModel.from_pretrained(hubert_checkpoint).eval().cuda()

In [None]:
def extract_feature(alignment, features):
    """Average frame-level features over each non-"SIL" phone segment.

    Args:
        alignment: decoded alignment, a list of ``[phoneme, start_frame, num_frames]``;
            "SIL" segments are skipped.
        features: 2-D tensor ``(frames, dim)`` on any device. If its length differs from
            the alignment length, only the overlapping frames are used.

    Returns:
        ``(pooled, phonemes)``: a CPU tensor ``(len(phonemes), dim)`` holding the mean
        feature of each non-SIL segment, and the non-SIL labels in alignment order.
        A segment with no usable frames yields a zero row.
    """
    num_frames = alignment[-1][1] + alignment[-1][2]
    # Frame -> segment index; -1 marks frames that belong to no kept phone (SIL / gaps).
    frame_to_segment = torch.full((num_frames,), -1, dtype=torch.long)
    phonemes = []
    for phoneme, start_frame, duration in alignment:
        if phoneme == "SIL":
            continue
        frame_to_segment[start_frame:start_frame + duration] = len(phonemes)
        phonemes.append(phoneme)

    num_segments = len(phonemes)
    # Always route unassigned frames to one extra trailing class that is dropped at the end.
    # Deriving it from max()+1 loses the last real phone when no frame is unassigned.
    frame_to_segment[frame_to_segment == -1] = num_segments

    usable = min(features.shape[0], num_frames)
    weights = torch.nn.functional.one_hot(frame_to_segment[:usable], num_classes=num_segments + 1)
    weights = weights.to(device=features.device, dtype=features.dtype)
    # Column-normalise so each segment averages its frames; clamp avoids 0/0 for empty segments.
    weights = weights / weights.sum(0, keepdim=True).clamp(min=1)
    pooled = torch.matmul(weights.transpose(0, 1), features[:usable])

    return pooled[:-1].cpu(), phonemes

In [None]:
wav_dir = "/data/codes/prep_ps_pykaldi/prep_data/wav"

# HuBERT last-hidden-state features pooled per non-SIL phone and zero-padded to MAX_LENGTH rows.
hubert_features = []
for row in tqdm(metadata.index):
    wav_id = metadata["id"][row]
    alignment = json.loads(metadata["alignment"][row])
    wav, sr = librosa.load(f'{wav_dir}/{wav_id}.wav', sr=16000)

    with torch.no_grad():
        inputs = processor(wav, return_tensors="pt", sampling_rate=16000)
        frame_features = model(inputs["input_values"].cuda()).last_hidden_state

    # Repeat every model frame twice so the sequence lines up with the alignment frames.
    # NOTE(review): assumes alignment frames are half the model frame stride -- confirm.
    frame_features = frame_features[0].repeat_interleave(2, dim=0)

    phone_features, phonemes = extract_feature(alignment, frame_features)
    if len(phonemes) > MAX_LENGTH:
        raise ValueError(f"{wav_id}: {len(phonemes)} phones exceed MAX_LENGTH={MAX_LENGTH}")
    padding = torch.zeros(MAX_LENGTH - len(phonemes), phone_features.shape[1])
    hubert_features.append(torch.cat([phone_features, padding], dim=0))

hubert_features = torch.stack(hubert_features, dim=0).numpy()
# TODO: saving is disabled, so these features are computed and then discarded.
# np.save(f'{out_dir}/hubert_features.npy', hubert_features)
hubert_features = None

### Extract Wav2vec Feature

In [None]:
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from datasets import load_dataset
import soundfile as sf
import torch
from tqdm import tqdm
import librosa
import pandas as pd
import json

In [None]:
charsiu_path = "/data/codes/prep_ps_pykaldi/pretrained/charsiu"
model = Wav2Vec2Model.from_pretrained(charsiu_path)
model = model.eval().cuda()

In [None]:
def extract_feature(alignment, features):
    """Average frame-level features over each non-"SIL" phone segment.

    Args:
        alignment: decoded alignment, a list of ``[phoneme, start_frame, num_frames]``;
            "SIL" segments are skipped.
        features: 2-D tensor ``(frames, dim)`` on any device. If its length differs from
            the alignment length, only the overlapping frames are used.

    Returns:
        ``(pooled, phonemes)``: a CPU tensor ``(len(phonemes), dim)`` holding the mean
        feature of each non-SIL segment, and the non-SIL labels in alignment order.
        A segment with no usable frames yields a zero row.
    """
    num_frames = alignment[-1][1] + alignment[-1][2]
    # Frame -> segment index; -1 marks frames that belong to no kept phone (SIL / gaps).
    frame_to_segment = torch.full((num_frames,), -1, dtype=torch.long)
    phonemes = []
    for phoneme, start_frame, duration in alignment:
        if phoneme == "SIL":
            continue
        frame_to_segment[start_frame:start_frame + duration] = len(phonemes)
        phonemes.append(phoneme)

    num_segments = len(phonemes)
    # Always route unassigned frames to one extra trailing class that is dropped at the end.
    # Deriving it from max()+1 loses the last real phone when no frame is unassigned.
    frame_to_segment[frame_to_segment == -1] = num_segments

    usable = min(features.shape[0], num_frames)
    weights = torch.nn.functional.one_hot(frame_to_segment[:usable], num_classes=num_segments + 1)
    weights = weights.to(device=features.device, dtype=features.dtype)
    # Column-normalise so each segment averages its frames; clamp avoids 0/0 for empty segments.
    weights = weights / weights.sum(0, keepdim=True).clamp(min=1)
    pooled = torch.matmul(weights.transpose(0, 1), features[:usable])

    return pooled[:-1].cpu(), phonemes

In [None]:
wav_dir = "/data/codes/prep_ps_pykaldi/prep_data/wav"

# Charsiu wav2vec2 features pooled per non-SIL phone and zero-padded to MAX_LENGTH rows.
# Unlike the WavLM/HuBERT cells, model frames are pooled against the alignment without repetition.
wav2vec_features = []
for row in tqdm(metadata.index):
    wav_id = metadata["id"][row]
    alignment = json.loads(metadata["alignment"][row])
    wav, sr = librosa.load(f'{wav_dir}/{wav_id}.wav', sr=16000)

    input_values = torch.from_numpy(wav).unsqueeze(0).cuda()
    with torch.no_grad():
        frame_features = model(input_values).last_hidden_state
    if row % 100 == 0:
        torch.cuda.empty_cache()

    phone_features, phonemes = extract_feature(alignment, frame_features[0])
    if len(phonemes) > MAX_LENGTH:
        raise ValueError(f"{wav_id}: {len(phonemes)} phones exceed MAX_LENGTH={MAX_LENGTH}")
    padding = torch.zeros(MAX_LENGTH - len(phonemes), phone_features.shape[1])
    wav2vec_features.append(torch.cat([phone_features, padding], dim=0))

wav2vec_features = torch.stack(wav2vec_features, dim=0).numpy()
# TODO: saving is disabled, so these features are computed and then discarded.
# np.save(f'{out_dir}/wav2vec_features.npy', wav2vec_features)
wav2vec_features = None