In [17]:
import nltk
from model import MyWav2Vec2ConformerForPreTraining
import torch
import numpy as np
import pickle
from typing import List
from gruut import sentences
import gruut

In [16]:
def _text_to_phonemes(text: str, lang: str = "en-us") -> List[str]:
    """Return the flat list of phonemes gruut produces for ``text``.

    Sentences and words are visited in order; words for which gruut reports
    no phonemes are skipped. Leading "ˈ" marks are stripped from every
    phoneme.
    """
    # Go through the module attribute (``gruut.sentences``) instead of the
    # bare ``sentences`` name, so this helper keeps working even if the
    # notebook-level name ``sentences`` gets rebound to something else.
    return [
        phoneme.lstrip("ˈ")
        for sent in gruut.sentences(text, lang=lang)
        for word in sent
        if word.phonemes
        for phoneme in word.phonemes
    ]


text = "brighter than early dawn's most brilliant dye are blown clear bands of color through the sky that swirl and sweep and meet to break and foam like rainbow veils upon a bubble's dome"
phoneme_list1 = _text_to_phonemes(text)
text = "in a sunset glowing of crimson and gold she lies the glory of the world a beached king's galley whose sails are furled who is hung with tapestries rich and old"
phoneme_list2 = _text_to_phonemes(text)

In [None]:
# NOTE: this cell only held the unfinished expression "gruut." (a SyntaxError);
# it was never executed and there is nothing to run here.

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [3]:
# Load the pretrained checkpoint, then move the model to the selected device.
model = MyWav2Vec2ConformerForPreTraining.from_pretrained(
    "facebook/wav2vec2-conformer-rel-pos-large"
)
model = model.to(DEVICE)

In [4]:
# Codebook layout read from the model config:
#   G = number of codevector groups, V = codevectors per group.
G, V = model.config.num_codevector_groups, model.config.num_codevectors_per_group
# Passed as ``max_index`` (number of unit-width histogram bins) to the KL helper.
max_index = V * G
eps = 1e-08

In [9]:
def _index_distribution(sequence: np.ndarray, bins: np.ndarray, eps: float) -> np.ndarray:
    """Return the smoothed, normalized histogram of one index sequence over ``bins``."""
    hist, bin_edges = np.histogram(sequence, bins=bins, density=True)
    # Additive smoothing keeps log(p / q) finite for bins that are empty.
    hist += eps
    return hist / (np.diff(bin_edges) * hist.sum())


def calculate_average_kl_divergence(
    indices: List[np.ndarray], max_index: int, include_self_pairs: bool = True
) -> float:
    """Average pairwise KL divergence between index-sequence histograms.

    Each sequence is turned into a histogram over ``max_index`` unit-width
    bins spanning ``[0, max_index]``, smoothed with a small epsilon and
    normalized. KL(p_i || p_j) is then summed over all ordered pairs (i, j)
    and divided by the number of pairs.

    Args:
        indices: all index sequences belonging to the same utterance content
            (or the same microphone); documented shape is
            (all_data_size, num_codebooks * seq_len), where all_data_size is
            the number of recordings in the group.
        max_index: number of histogram bins (upper edge of the last bin).
        include_self_pairs: when True (default, the historical behavior) the
            n self-pairs, whose divergence is 0, are counted in the
            denominator (n * n). When False the mean is taken over the
            n * (n - 1) ordered pairs with i != j, which makes averages
            comparable between groups of different sizes.

    Returns:
        The mean KL divergence over the selected pairs.

    Raises:
        ValueError: if ``indices`` is empty, or holds fewer than two
            sequences while ``include_self_pairs`` is False.
    """
    num_sequences = len(indices)
    if num_sequences == 0:
        raise ValueError("indices must contain at least one sequence")
    if not include_self_pairs and num_sequences < 2:
        raise ValueError("at least two sequences are required when include_self_pairs=False")

    total_number = num_sequences * num_sequences
    report_every = 10000
    eps = 1e-8
    bins = np.linspace(0, max_index, max_index + 1)

    # Build every distribution exactly once (n histograms instead of 2 * n**2).
    probs = np.stack([_index_distribution(sequence, bins, eps) for sequence in indices])

    kl_sum = 0.0
    counter = 0
    for i in range(num_sequences):
        # KL(p_i || p_j) for all j at once; row_kl[j] is one pair's divergence.
        row_kl = np.sum(probs[i] * np.log(probs[i] / probs), axis=1)
        # Sanity check: a distribution has zero divergence from itself.
        assert np.abs(row_kl[i]) < eps
        kl_sum += row_kl.sum()

        # Emit the same progress lines as a pair-by-pair loop would: one line
        # each time the number of processed pairs reaches a multiple of 10000.
        previous = counter
        counter += num_sequences
        first_milestone = (previous // report_every + 1) * report_every
        for done in range(first_milestone, counter + 1, report_every):
            print(f"progress: {done}, {done / total_number * 100:.2f}%")

    # Self-pairs contribute 0 to the sum, so only the denominator changes.
    num_pairs = total_number if include_self_pairs else total_number - num_sequences
    return kl_sum / num_pairs

In [10]:
# Randomly sample 1/100 of the clean-100 keys (the size below is `// 100`).
# (nexus6 failed to quantize some utterances, so the keys are sampled from the
#  nexus6 keys.)
# KL divergences are computed only for these sampled keys.
SAMPLING_SEED = 0  # fixed seed so the same utterances are drawn on every run
f_name = "pickles/nexus6_quantized_indices.pkl"
with open(f_name, "rb") as f:
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    nexus6_quantized_indices = pickle.load(f)
# Legacy alias: this dict was previously (misleadingly) named after "matrix".
matrix_quantized_indices = nexus6_quantized_indices
sampled_keys = np.random.default_rng(SAMPLING_SEED).choice(
    list(nexus6_quantized_indices.keys()),
    size=len(nexus6_quantized_indices) // 100,
    replace=False,
)

In [11]:
# Average KL divergence within each microphone
mic_names = ["matrix", "nexus6", "pseye", "respeaker", "shure", "usb"]
mic_kl_divergences = {}
for mic_name in mic_names:
    print(f"mic: {mic_name}")
    f_name = f"pickles/{mic_name}_quantized_indices.pkl"
    with open(f_name, "rb") as f:
        quantized_indices = pickle.load(f)

    # Keep only the sampled utterances of this microphone.
    selected_quantized_indices = [quantized_indices[key] for key in sampled_keys]
    kl_divergence = calculate_average_kl_divergence(selected_quantized_indices, max_index)
    print(f"kl_divergence: {kl_divergence}")
    mic_kl_divergences[mic_name] = kl_divergence

# Mean of the per-microphone averages
print(f"average kl_divergence: {np.mean([mic_kl_divergences[name] for name in mic_names])}")


mic: matrix
progress: 10000, 18.58%
progress: 20000, 37.16%
progress: 30000, 55.74%
progress: 40000, 74.32%
progress: 50000, 92.90%
kl_divergence: 1.4124794345111977
mic: nexus6
progress: 10000, 18.58%
progress: 20000, 37.16%
progress: 30000, 55.74%
progress: 40000, 74.32%
progress: 50000, 92.90%
kl_divergence: 0.9872670687122076
mic: pseye
progress: 10000, 18.58%
progress: 20000, 37.16%
progress: 30000, 55.74%
progress: 40000, 74.32%
progress: 50000, 92.90%
kl_divergence: 1.3830109286163514
mic: respeaker
progress: 10000, 18.58%
progress: 20000, 37.16%
progress: 30000, 55.74%
progress: 40000, 74.32%
progress: 50000, 92.90%
kl_divergence: 1.177017115959084
mic: shure
progress: 10000, 18.58%
progress: 20000, 37.16%
progress: 30000, 55.74%
progress: 40000, 74.32%
progress: 50000, 92.90%
kl_divergence: 1.1112444849123906
mic: usb
progress: 10000, 18.58%
progress: 20000, 37.16%
progress: 30000, 55.74%
progress: 40000, 74.32%
progress: 50000, 92.90%
kl_divergence: 1.2632468073121972
average

In [12]:
# Average KL divergence within each utterance (across microphones)

# Read every microphone's pickle exactly once and keep only the sampled
# utterances, instead of re-reading all files for every utterance.
sampled_indices_by_mic = {}
for mic_name in mic_names:
    f_name = f"pickles/{mic_name}_quantized_indices.pkl"
    with open(f_name, "rb") as f:
        quantized_indices = pickle.load(f)
    sampled_indices_by_mic[mic_name] = {key: quantized_indices[key] for key in sampled_keys}

utterance_kl_divergences = {}
for idx, utterance_key in enumerate(sampled_keys):
    print(f"utterance_key: {utterance_key}")
    print(f"progress: {idx}, {idx / len(sampled_keys) * 100:.2f}%")
    # One index sequence per microphone for this utterance, in mic_names order.
    selected_quantized_indices = [
        sampled_indices_by_mic[mic_name][utterance_key] for mic_name in mic_names
    ]
    kl_divergence = calculate_average_kl_divergence(selected_quantized_indices, max_index)
    utterance_kl_divergences[utterance_key] = kl_divergence
    print(f"kl_divergence: {kl_divergence}")

# Mean of the per-utterance averages
print(f"average kl_divergence: {np.mean(list(utterance_kl_divergences.values()))}")

utterance_key: 4014-186179-0009.wav
progress: 0, 0.00%
kl_divergence: 0.24982588351576154
utterance_key: 3982-182255-0039.wav
progress: 1, 0.43%
kl_divergence: 0.5920018075510902
utterance_key: 4051-11218-0007.wav
progress: 2, 0.86%
kl_divergence: 0.3423940646548176
utterance_key: 3879-174923-0022.wav
progress: 3, 1.29%
kl_divergence: 0.3104055583548819
utterance_key: 6836-76549-0014.wav
progress: 4, 1.72%
kl_divergence: 0.4934494701185529
utterance_key: 5688-41232-0034.wav
progress: 5, 2.16%
kl_divergence: 0.4209460848989326
utterance_key: 4441-76250-0022.wav
progress: 6, 2.59%
kl_divergence: 0.46592952371861124
utterance_key: 7447-91187-0020.wav
progress: 7, 3.02%
kl_divergence: 0.6523703903022311
utterance_key: 89-219-0041.wav
progress: 8, 3.45%
kl_divergence: 0.3554165077674751
utterance_key: 2289-152254-0002.wav
progress: 9, 3.88%
kl_divergence: 0.3558718505124323
utterance_key: 3214-167606-0010.wav
progress: 10, 4.31%
kl_divergence: 0.24218741239991365
utterance_key: 441-130108-0

In [31]:
from transformers import BertTokenizer, BertModel

# NOTE: machine-specific absolute path; adjust when running elsewhere.
BERT_CACHE_DIR = "/home/shibutani/fs/.cache/huggingface/transformers"

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", cache_dir=BERT_CACHE_DIR)
bert_model = BertModel.from_pretrained("bert-base-uncased", cache_dir=BERT_CACHE_DIR).to(DEVICE)

t2 = "You're a graduate student at the University of Tokyo."
# Alternative second sentence for a non-trivial comparison:
# t3 = "I'm interested in speech recognition using neural network and machine learning and natural language processing."
t3 = "You're a graduate student at the University of Tokyo."
# Named ``bert_sentences`` (not ``sentences``) so the ``sentences`` function
# imported from gruut is not shadowed by a list.
bert_sentences = [t2, t3]
encoded_input = tokenizer(bert_sentences, padding=True, return_tensors="pt")
input_ids = encoded_input["input_ids"].to(DEVICE)
attention_mask = encoded_input["attention_mask"].to(DEVICE)
# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = bert_model(input_ids, attention_mask=attention_mask)
# First element of the model output, used below as per-token hidden states.
last_hidden_states = outputs[0]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [33]:
# Sum the token embeddings within each sentence (normalized by sentence length):
# padded positions are zeroed by the attention mask before summing, and the sum
# is divided by the number of unmasked tokens.
sentence_embed_vecs = (last_hidden_states * attention_mask[..., None]).sum(dim=1)
sentence_embed_vecs = sentence_embed_vecs / attention_mask.sum(dim=1, keepdim=True)
# Cosine similarity between the two pooled sentence vectors.
cos_similarity = torch.nn.functional.cosine_similarity(
    sentence_embed_vecs[0],
    sentence_embed_vecs[1],
    dim=0,
)
cos_similarity.item()

1.0