In [15]:
import json
import os
from data import LibriLightDataset, TEDLIUMRelease2Dataset, TEDLIUMRelease2SpecificTalkDataset
from sampler import PhonemeKLSampler
import numpy as np
import pickle
import matplotlib.pyplot as plt
import torch

In [26]:
# Libri-Light "10h" subset. Its phone sequences are counted further down to build
# the reference distribution (libri_phone_distribute). Its vocab file is also the
# one passed to the TED-LIUM datasets, so all datasets share one token inventory.
libri_light_dataset = LibriLightDataset(
    vocab_file_path="vocabs/libri-light_10h.json",
    subset="10h",
    identifier_to_phones_file_path="phones/librispeech_normalized_phones_no_bcl.json",
)

extracting vocab...


In [27]:
# TED-LIUM release 2 training split, tokenized with the same Libri-Light 10h
# vocab file and paired with the TED-LIUM phone labels file.
tedlium2_dataset = TEDLIUMRelease2Dataset(
    subset="train",
    vocab_file_path="vocabs/libri-light_10h.json",
    identifier_to_phones_file_path="phones/ted2_normalized_phones_no_bcl.json",
)

Number of phones not found: 14698


In [28]:
# Peek at one item. The displayed output shows a 7-tuple. Reading the values
# (not confirmed against the dataset class):
#   item[2] looks like the waveform length in samples (the selection loop
#     divides it by 16000 to get seconds);
#   item[4] equals the length of the token-id tensor in item[3];
#   item[5] is the transcript;
#   item[6] is the phone sequence.
# TODO confirm what item[0] is.
tedlium2_dataset[0]

(0,
 tensor([-0.2830, -0.3177, -0.3274,  ..., -0.1134, -0.0843, -0.0593]),
 tensor(30720),
 tensor([26, 24,  4, 18, 16, 23, 14, 28,  2, 18,  8, 13, 28, 23, 24, 27,  5]),
 tensor(17),
 'today because of\n',
 ['t', 'ah', 'd', 'ey', 'b', 'ih', 'k', 'aa', 'z', 'ah', 'v'])

In [2]:
# Legacy, disabled: one-off generation of a phone -> index map from the values
# of phones/converter.json. Kept for provenance only.
# NOTE(review): this wrote phones/phone_idx_map.json, but the next cell loads
# phones/phone_to_idx.json. Confirm which file is authoritative.
# NOTE(review): enumerate() over a set of strings gives an order that is not
# stable across interpreter sessions (hash randomization), so re-running this
# would not reproduce the same indices. Use sorted(unique_phones) if revived.
#with open("phones/converter.json", "r") as f:
#    converter = json.load(f)
#    unique_phones = set(converter.values())
#phone_idx_map = {phone: idx for idx, phone in enumerate(unique_phones)}
#with open("phones/phone_idx_map.json", "w") as f:
#    json.dump(phone_idx_map, f)

In [4]:
# Phone -> integer index map. Every phone counter in this notebook is a list or
# array of length len(phone_to_idx), indexed by these values. The values must
# therefore be exactly 0..len-1: a gap or duplicate would make counts collide or
# raise IndexError.
with open("phones/phone_to_idx.json", "r") as f:
    phone_to_idx = json.load(f)
assert sorted(phone_to_idx.values()) == list(range(len(phone_to_idx))), (
    "phone_to_idx values must be a permutation of 0..len(phone_to_idx)-1"
)

In [7]:
# Count phones per TED-LIUM 2 training talk. Each talk has one .stm file, and the
# talk id is that file name without its extension.
# NOTE: this is slow (the last recorded run was interrupted with KeyboardInterrupt).
# The sampling cells below load phones/tedlium2_phone_counters.pkl instead,
# presumably a saved copy of this result. TODO confirm.
stm_dir_path = "datasets/TEDLIUM_release2/train/stm/"
phone_counters = {}
for stm_file in sorted(os.listdir(stm_dir_path)):
    talk_id, extension = os.path.splitext(stm_file)
    if extension != ".stm":
        continue
    talk_dataset = TEDLIUMRelease2SpecificTalkDataset(
        identifier_to_phones_file_path="phones/ted2_normalized_phones_no_bcl.json",
        subset="train",
        vocab_file_path="vocabs/libri-light_10h.json",
        talk_id=talk_id,
    )
    counts = [0] * len(phone_to_idx)
    # Index explicitly (as the Libri-Light counting cell does) instead of relying
    # on implicit Dataset iteration. The phone sequence is the last item element.
    for i in range(len(talk_dataset)):
        for phone in talk_dataset[i][-1]:
            # None entries are skipped (presumably phones that could not be
            # mapped; TODO confirm against the dataset class).
            if phone is not None:
                counts[phone_to_idx[phone]] += 1
    # Store only after the talk is fully counted, so an interrupted run never
    # leaves a half-counted entry behind.
    phone_counters[talk_id] = counts

KeyboardInterrupt: 

In [None]:
def _smoothed_distribution(counts):
    """Return the float32 counts plus 1e-8, each divided by their (smoothed) total, as a list."""
    smoothed = np.array(counts, dtype=np.float32)
    smoothed += 1e-8  # no entry is exactly zero after this
    norm = sum(smoothed)
    return [value / norm for value in smoothed]


# Per-talk phone distributions derived from the raw per-talk counts.
phone_dists = {
    talk_id: _smoothed_distribution(counts)
    for talk_id, counts in phone_counters.items()
}


In [23]:
# Reference phone distribution over the Libri-Light 10h subset. The phone
# sequence is the last element of each item, and None entries are skipped.
libri_phone_counts = [0] * len(phone_to_idx)
for i in range(len(libri_light_dataset)):
    for phone in libri_light_dataset[i][-1]:
        if phone is not None:
            libri_phone_counts[phone_to_idx[phone]] += 1
libri_phone_counts = np.array(libri_phone_counts, dtype=np.float32)
if libri_phone_counts.sum() == 0:
    raise ValueError("no phones were counted for the Libri-Light dataset")
# Add a small epsilon so every entry is strictly positive, which keeps a KL
# divergence against this distribution finite. Normalise by the *smoothed*
# total so the distribution sums to 1, as the per-talk distributions do.
smoothed_libri_counts = libri_phone_counts + 1e-8
libri_phone_distribute = smoothed_libri_counts / smoothed_libri_counts.sum()

In [29]:
# Precomputed artefacts. No visible cell writes these files, so they come from
# an earlier or external run. Loading libri_phone_distribute here replaces the
# value computed by the Libri-Light counting cell.
# pickle.load can execute arbitrary code, so only load trusted local files.
def _load_pickle(path):
    """Unpickle and return the object stored at `path`."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


tedlium2_phone_distributes = _load_pickle("phones/tedlium2_phone_distributes.pkl")
tedlium2_phone_counters = _load_pickle("phones/tedlium2_phone_counters.pkl")
libri_phone_distribute = _load_pickle("phones/libri_phone_distribute.pkl")

In [30]:
def calculate_kl_divergence(p, q):
    """Return the Kullback-Leibler divergence KL(p || q) = sum_i p_i * log(p_i / q_i).

    Args:
        p: Sequence or array of probabilities (the distribution being measured).
        q: Sequence or array of the same shape (the reference distribution).

    Returns:
        The divergence in nats, as a Python float, computed in float64. Terms
        with p_i == 0 contribute 0 (the standard 0 * log 0 = 0 convention)
        instead of turning the whole sum into NaN. If q_i == 0 where p_i > 0,
        the result is +inf.

    Raises:
        ValueError: If p and q have different shapes.
    """
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    if p.shape != q.shape:
        raise ValueError(f"p and q must have the same shape, got {p.shape} and {q.shape}")
    support = p > 0  # only terms with p_i > 0 contribute
    with np.errstate(divide="ignore"):  # q_i == 0 legitimately yields +inf
        return float(np.sum(p[support] * np.log(p[support] / q[support])))

In [31]:
# Greedy "difficult" subset selection. Each pass adds the TED-LIUM talk whose
# inclusion makes the pooled phone distribution of the selected talks most
# divergent (highest KL) from the Libri-Light reference. Selection stops once
# the selected audio reaches TARGET_DURATION seconds. Whole talks are added, so
# the final duration can overshoot the target.
TARGET_DURATION = 600  # seconds of audio to select
# item[2] is divided by this to get seconds; presumably the waveform length in
# samples at 16 kHz. TODO confirm.
SAMPLE_RATE = 16000
MIN_PHONES_PER_TALK = 100  # talks with fewer phones are never considered

sampled_talk_ids = set()
# Filter out tiny talks once, instead of re-checking them on every pass.
not_sampled_talk_ids = {
    talk_id
    for talk_id in tedlium2_phone_distributes.keys()
    if sum(tedlium2_phone_counters[talk_id]) >= MIN_PHONES_PER_TALK
}
sampled_phone_counts = np.zeros(len(phone_to_idx), dtype=np.float32)
sampled_duration = 0
while sampled_duration < TARGET_DURATION:
    print(f"{sampled_duration/ TARGET_DURATION * 100:.2f}%")
    max_kl = -np.inf
    max_kl_talk_id = None
    # Sorted so ties are broken identically on every run (set order is not stable).
    for talk_id in sorted(not_sampled_talk_ids):
        candidate_counts = sampled_phone_counts + np.array(tedlium2_phone_counters[talk_id])
        # Smooth, then normalise by the smoothed total so the distribution sums to 1.
        smoothed_counts = candidate_counts + 1e-8
        candidate_distribute = smoothed_counts / smoothed_counts.sum()
        kl_divergence = calculate_kl_divergence(candidate_distribute, libri_phone_distribute)
        if kl_divergence > max_kl:
            max_kl = kl_divergence
            max_kl_talk_id = talk_id
    if max_kl_talk_id is None:
        # Every eligible talk is already selected. Stop instead of looping forever.
        print(f"Ran out of talks at {sampled_duration:.1f}s (target {TARGET_DURATION}s)")
        break
    sampled_talk_ids.add(max_kl_talk_id)
    not_sampled_talk_ids.remove(max_kl_talk_id)
    sampled_phone_counts += np.array(tedlium2_phone_counters[max_kl_talk_id])

    sampled_dataset = TEDLIUMRelease2SpecificTalkDataset(
        identifier_to_phones_file_path="phones/ted2_normalized_phones_no_bcl.json",
        subset="train",
        vocab_file_path="vocabs/libri-light_10h.json",
        talk_id=max_kl_talk_id,
    )
    for i in range(len(sampled_dataset)):
        sampled_duration += sampled_dataset[i][2].item() / SAMPLE_RATE

 



0.00%
74.74%
77.15%
78.62%
80.60%
87.99%
90.19%


In [32]:
# Turn the selected talks into one ConcatDataset and pickle it. Talks are
# concatenated in sorted id order so the saved dataset is reproducible
# (iterating the set directly gives an order that varies between sessions).
datasets = [
    TEDLIUMRelease2SpecificTalkDataset(
        identifier_to_phones_file_path="phones/ted2_normalized_phones_no_bcl.json",
        subset="train",
        vocab_file_path="vocabs/libri-light_10h.json",
        talk_id=talk_id,
    )
    for talk_id in sorted(sampled_talk_ids)
]
dataset = torch.utils.data.ConcatDataset(datasets)
# File name tracks the selection target (600 -> tedlium2_difficult_600.pkl).
# NOTE(review): pickling the dataset object ties the file to the current class
# definitions. Saving sorted(sampled_talk_ids) (e.g. as JSON) would be more robust.
with open(f"tedlium2_difficult_{TARGET_DURATION}.pkl", "wb") as f:
    pickle.dump(dataset, f)

In [34]:
# Total seconds of audio actually selected. The selection loop stops only once
# this reaches TARGET_DURATION, and it always adds every utterance of the chosen
# talk, so the value can overshoot the target.
sampled_duration

656.3999999999996