In [28]:
from data import LibriLightDataset
import json
import numpy as np
import torch
import copy
train_dataset = LibriLightDataset(
    subset="9h",
    identifier_to_phones_file_path="phones/librispeech_normalized_phones_no_bcl.json",
    vocab_file_path="vocabs/libri-light_9h.json"
)
test_dataset = LibriLightDataset(
    subset="1h",
    identifier_to_phones_file_path="phones/librispeech_normalized_phones_no_bcl.json",
    vocab_file_path="vocabs/libri-light_9h.json"
)


In [30]:
target_phonemes = set(
    ["aa",
    "ae",
    "ah",
    "aw",
    "ay",
    "b",
    "ch",
    "d",
    "dh",
    "dx",
    "eh",
    "axr",
    "ey",
    "f",
    "g",
    "bcl",
    "hh",
    "ih",
    "iy",
    "jh",
    "k",
    "el",
    "em",
    "en",
    "eng",
    "ow",
    "oy",
    "p",
    "r",
    "s",
    "sh",
    "t",
    "th",
    "uh",
    "uw",
    "v",
    "w",
    "y",
    "z",]
)

In [31]:
phone_to_idx = {phone: idx for idx, phone in enumerate(target_phonemes)}

In [35]:
def calculate_tf_idf_over_ds(dataset: torch.utils.data.Dataset):
    target_phones = set(
        ["aa",
        "ae",
        "ah",
        "aw",
        "ay",
        "b",
        "ch",
        "d",
        "dh",
        "dx",
        "eh",
        "axr",
        "ey",
        "f",
        "g",
        "bcl",
        "hh",
        "ih",
        "iy",
        "jh",
        "k",
        "el",
        "em",
        "en",
        "eng",
        "ow",
        "oy",
        "p",
        "r",
        "s",
        "sh",
        "t",
        "th",
        "uh",
        "uw",
        "v",
        "w",
        "y",
    "z",]
    )
    phone_to_idx = {phone: idx for idx, phone in enumerate(target_phones)}
    df = np.zeros(39, dtype=np.float32)
    tf = np.zeros(39, dtype=np.float32)

    for idx in range(len(dataset)):
        phones = dataset[idx][-1]
        unique_phones = set(phones)
        for target_phone in list(target_phones):
            if target_phone in unique_phones:
                df[phone_to_idx[target_phone]] += 1
        for phone in phones:
            if phone is not None:
                tf[phone_to_idx[phone]] += 1
                
    tf = tf / tf.sum()

    df = df / len(dataset)
    df += 1e-8
    idf = np.log(1 / df)

    return tf * idf

In [52]:
def cos_sim(a, b):
    a += 1e-8
    b += 1e-8
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

In [58]:
test_tf_idf_over_ds = calculate_tf_idf_over_ds(test_dataset)

In [59]:
token_counter = np.zeros(39, dtype=np.float32)
document_counter = np.zeros(39, dtype=np.float32)
not_sampled_indices = set(range(len(train_dataset)))
sampled_indices = set()
limit_duration = 600
sampled_duration = 0
while sampled_duration < limit_duration:
    similarities = {}
    count = 0
    for idx in not_sampled_indices:
        count += 1
        print(f"{sampled_duration / limit_duration * 100:.2f}% Inner: {count / len(not_sampled_indices) * 100:.2f}%", end="\r")
        token_counter_copy = copy.deepcopy(token_counter)
        document_counter_copy = copy.deepcopy(document_counter)
        phones = train_dataset[idx][-1]
        unique_phones = set(phones)
        for target_phone in list(target_phonemes):
            if target_phone in unique_phones:
                document_counter_copy[phone_to_idx[target_phone]] += 1
        for phone in phones:
            if phone is not None:
                token_counter_copy[phone_to_idx[phone]] += 1
        tf = token_counter_copy / token_counter_copy.sum()
        df = document_counter_copy / (len(sampled_indices) + 1)
        df += 1e-8
        idf = np.log(1 / df)
        sampled_tf_idf_over_ds = tf * idf
        similarity = cos_sim(test_tf_idf_over_ds, sampled_tf_idf_over_ds)
        similarities[idx] = similarity

    max_similaritiy_idx = max(similarities, key=similarities.get)
    sampled_indices.add(max_similaritiy_idx)
    not_sampled_indices.remove(max_similaritiy_idx)
    sampled_duration += train_dataset[max_similaritiy_idx][2] / 16000
    phones = train_dataset[max_similaritiy_idx][-1]
    unique_phones = set(phones)
    for target_phone in list(target_phonemes):
        if target_phone in unique_phones:
            document_counter[phone_to_idx[target_phone]] += 1
    for phone in phones:
        if phone is not None:
            token_counter[phone_to_idx[phone]] += 1

97.43% Inner: 100.00%

In [60]:
import pickle
sampled_dataset = torch.utils.data.Subset(train_dataset, list(sampled_indices))
with open("max_sim_sampled_dataset.pkl", "wb") as f:
    pickle.dump(sampled_dataset, f)