In [1]:
import nltk
from data import TIMITDataset, LibriSpeechDataset
from model import MyWav2Vec2ConformerForPreTraining
import torch
from torch.utils.data import DataLoader
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
from utils import indices2indices

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
libri_dataset = LibriSpeechDataset(split="dev")
train_dataloader = DataLoader(
    libri_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=8,
    # 不完全なバッチの無視
    drop_last=True,
    # 高速化?
    pin_memory=True,
    collate_fn=libri_dataset.collate_fn
)

In [4]:
model = MyWav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large").to(DEVICE)

In [5]:
@torch.no_grad()
def quantize(bx: torch.Tensor, model: MyWav2Vec2ConformerForPreTraining, num_groups: int, num_codevectors_per_group: int):
    # output length from conv layer
    sequence_length = model._get_feat_extract_output_lengths(bx.shape[1]).item()
    mask_time_indices = _compute_mask_indices(
        shape=(2, sequence_length), mask_prob=0.2, mask_length=2
    )
    sampled_negative_indices = _sample_negative_indices(
        features_shape=(2, sequence_length),
        num_negatives=model.config.num_negatives,
        mask_time_indices=mask_time_indices,
    )
    mask_time_indices = torch.tensor(mask_time_indices, device=bx.device, dtype=torch.long)
    sampled_negative_indices = torch.tensor(
        data=sampled_negative_indices, device=bx.device, dtype=torch.long
    )
    model.eval()
    outputs = model(
        bx, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
    )
    quantized_indices = indices2indices(outputs.quntized_indices, num_codevectors_per_group)
    quantized_indices = quantized_indices.to("cpu").numpy().reshape(quantized_indices.shape[0], quantized_indices.shape[1] * quantized_indices.shape[2]).astype(str)
    return quantized_indices

In [6]:
from nltk.lm import KneserNeyInterpolated
from nltk.lm.preprocessing import pad_both_ends
from nltk.util import trigrams, everygrams
from nltk.lm import Vocabulary
from nltk.lm.counter import NgramCounter
ngram_order = 3

In [7]:
from collections import Counter
counts = Counter()
ngram_text = []
G = model.config.num_codevector_groups
V = model.config.num_codevectors_per_group
for i, (bidx, bx, bx_len, by, by_len) in enumerate(train_dataloader):
    print(f"{i} th batch")
    bx = bx.to(DEVICE)
    bx_len = bx_len.to(DEVICE)
    by = by.to(DEVICE)
    by_len = by_len.to(DEVICE)
    quantized_indices = quantize(bx, model, G, V)

    for quantized_index in quantized_indices:
        words = list(pad_both_ends(quantized_index, n=ngram_order))
        counts.update(words)
        ngrams = everygrams(words, max_len=ngram_order)
        ngram_text.append(ngrams)

vocab = Vocabulary(counts=counts, unk_cutoff=1)
counter = NgramCounter(ngram_text)
lm = KneserNeyInterpolated(ngram_order, vocabulary=vocab, counter=counter)

0 th batch
1 th batch
2 th batch
3 th batch
4 th batch
5 th batch
6 th batch
7 th batch
8 th batch
9 th batch
10 th batch
11 th batch
12 th batch
13 th batch
14 th batch
15 th batch
16 th batch
17 th batch
18 th batch
19 th batch
20 th batch
21 th batch
22 th batch
23 th batch
24 th batch
25 th batch
26 th batch
27 th batch
28 th batch
29 th batch
30 th batch
31 th batch
32 th batch
33 th batch
34 th batch
35 th batch
36 th batch
37 th batch
38 th batch
39 th batch
40 th batch
41 th batch
42 th batch
43 th batch
44 th batch
45 th batch
46 th batch
47 th batch
48 th batch
49 th batch
50 th batch
51 th batch
52 th batch
53 th batch
54 th batch
55 th batch
56 th batch
57 th batch
58 th batch
59 th batch
60 th batch
61 th batch
62 th batch
63 th batch
64 th batch
65 th batch
66 th batch
67 th batch
68 th batch
69 th batch
70 th batch
71 th batch
72 th batch
73 th batch
74 th batch
75 th batch
76 th batch
77 th batch
78 th batch
79 th batch
80 th batch
81 th batch
82 th batch
83 th batch
84

In [None]:
import pickle
lm = pickle.load(open("lm.pkl", "rb"))

In [11]:
libri_test_dataset = LibriSpeechDataset(split="test")
test_dataloader = DataLoader(
    libri_test_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=8,
    # 不完全なバッチの無視
    drop_last=True,
    # 高速化?
    pin_memory=True,
    collate_fn=libri_test_dataset.collate_fn
)

In [49]:
G = model.config.num_codevector_groups
V = model.config.num_codevectors_per_group
for i, (bidx, bx, bx_len, by, by_len) in enumerate(test_dataloader):
    print(f"{i} th batch")
    bx = bx.to(DEVICE)
    bx_len = bx_len.to(DEVICE)
    by = by.to(DEVICE)
    by_len = by_len.to(DEVICE)
    quantized_indices = quantize(bx, model, G, V)

    for quantized_index in quantized_indices:
        words = list(pad_both_ends(quantized_index, n=ngram_order))
        ngrams = list(trigrams(words))
        logscore = 0.
        for ngram in ngrams:
            logscore += lm.logscore(ngram[2], ngram[0:2])

        print(logscore)

0 th batch
-2268.60451160972
-1266.4918311164806
1 th batch
-1718.9433447405788
-1799.855163499567
2 th batch
-1440.6129299593842
-721.5751773708581
3 th batch
-2388.9709854632183
-1509.0191162733208
4 th batch
-1089.304381836828
-1084.0131292605913
5 th batch
-967.9734608834524
-1367.468428369365
6 th batch
-2411.995239089369
-1568.079036872662
7 th batch
-744.9439193708306
-2343.2683792952043
8 th batch
-1249.5702999244527
-1413.6641272107245
9 th batch
-2038.7033158759757
-597.062010032942
10 th batch
-1202.9860515568073
-701.698351734117
11 th batch
-3082.872023269759
-2389.516890470288
12 th batch
-2032.018659618461
-1192.803007587298
13 th batch


KeyboardInterrupt: 