In [44]:
from data import TIMITDatasetMelSpecdb, get_dataloader
import torch
import torchaudio
from model import Model
from torchmetrics.functional import char_error_rate, word_error_rate
from modules.decoders.ctc import greedy_decoder
from quantizer import Quantizer
from typing import Dict
import pickle

In [45]:
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# "test" split, served in a fixed order; the final partial batch is kept.
test_dataset = TIMITDatasetMelSpecdb("test")
test_dataloader = get_dataloader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    drop_last=False,
    collate_fn=test_dataset.collate_fn,
)

Using custom data configuration default-7dc5a6ddcdc99305
Found cached dataset timit (/home/shibutani/fs/.cache/huggingface/datasets/timit/default-7dc5a6ddcdc99305/0.0.0/e393649805e8c068eb5c3311baf236f53ffa81289ecc57e285c6e06a31f00ba8)
100%|██████████| 2/2 [00:00<00:00, 455.01it/s]


In [46]:
# Load the cached {dataset index -> list of quantized code indices} mapping.
# NOTE(security): pickle.load can execute arbitrary code, so only open a file
# that was produced by a trusted run of this notebook.
with open("quantized_indices_memory.pkl", "rb") as cache_file:
    quantized_indices_memory = pickle.load(cache_file)
# Number of cached entries (displayed as the cell output).
len(quantized_indices_memory)

4620

In [47]:
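# NOTE: one-off script that wrote "quantized_indices_memory.pkl" (the file loaded above).
# It is kept commented out for reference only; re-running it requires `train_dataset`,
# which is not created until a later cell.
#quantizer = Quantizer(DEVICE)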
#quantized_indices_memory = {}
#for idx in range(len(train_dataset)):
#    # show progress
#    if idx % 100 == 0:
#        print(f"{idx / len(train_dataset) * 100:.2f}%")
#    audio = train_dataset[idx][-1].unsqueeze(0).to(DEVICE)
#    quantized_indices = quantizer.quantize(audio)
#    quantized_indices_memory[idx] = quantized_indices[0].tolist()
#with open("quantized_indices_memory.pkl", "wb") as f:
#    pickle.dump(quantized_indices_memory, f)

In [48]:
def calculate_kl_divergence_between_uniform_and_empirical_distribution(sampled_quantized_idx_count: torch.Tensor):
    """Return KL(empirical || uniform) for a 1-D tensor of per-bin counts.

    The empirical distribution is the counts normalised to sum to one; the
    reference is the uniform distribution over the same number of bins. The
    result is in nats: 0 for perfectly uniform counts and ``log(num_bins)``
    when all of the mass sits in a single bin.

    Empty bins are handled with the convention ``0 * log(0) = 0``, so callers
    do not need to add an epsilon to the counts beforehand (doing so is still
    harmless).

    Args:
        sampled_quantized_idx_count: 1-D tensor of non-negative counts, one
            entry per bin.

    Returns:
        0-dim tensor holding the KL divergence (never negative, never NaN for
        valid input).

    Raises:
        ValueError: if the tensor is empty, contains a negative count, or its
            sum is not a positive finite number.
    """
    assert sampled_quantized_idx_count.ndim == 1
    counts = sampled_quantized_idx_count
    if not torch.is_floating_point(counts):
        # Integer counts are accepted; the arithmetic below needs floats.
        counts = counts.to(torch.float32)
    if counts.numel() == 0:
        raise ValueError("Counts must contain at least one bin.")
    if bool((counts < 0).any()):
        raise ValueError("Counts must be non-negative.")
    total = counts.sum()
    if not bool(torch.isfinite(total)) or bool(total <= 0):
        raise ValueError("Counts must sum to a positive finite value.")

    empirical_distribution = counts / total
    uniform_distribution = torch.ones_like(empirical_distribution) / len(empirical_distribution)
    # Only occupied bins contribute; evaluating log() on empty bins would give
    # 0 * (-inf) = NaN and silently poison the sum.
    occupied = empirical_distribution > 0
    ratio = empirical_distribution[occupied] / uniform_distribution[occupied]
    kl_divergence = torch.sum(empirical_distribution[occupied] * torch.log(ratio))
    # With valid counts the divergence is mathematically >= 0, so a negative
    # value can only be floating-point rounding; clamp it instead of raising.
    return torch.clamp(kl_divergence, min=0)

In [50]:
# Sanity check on synthetic data: one random count in [0, 200) for each of the
# 320 * 320 bins, nudged by a tiny epsilon so no bin is exactly zero.
sampled_quantized_idx_count = torch.randint(0, 200, (320 * 320,), dtype=torch.float32) + 1e-8
calculate_kl_divergence_between_uniform_and_empirical_distribution(sampled_quantized_idx_count)

tensor(0.1952)

In [51]:
class KLBasedSampler:
    """Greedy subset selection that flattens the histogram of quantized codes.

    A small random seed set is drawn first. After that, every step adds the
    not-yet-selected item whose quantized codes bring the pooled code histogram
    closest (in KL divergence) to the uniform distribution, until the target
    subset size is reached.

    Args:
        quantized_indices_memory: mapping from dataset index to the sequence of
            quantized code indices of that item. Codes must lie in
            ``[0, num_codes)``.
        dataset: dataset to sample from; only ``len(dataset)`` is used during
            selection.
        ratio: fraction of the dataset to select.
        device: stored on the instance; the selection itself runs on the
            device of the count tensor (CPU).
        num_codes: number of histogram bins, i.e. distinct quantized codes.
            Defaults to ``320 * 320``, the value that used to be hard-coded.
    """

    def __init__(self, quantized_indices_memory: Dict, dataset: torch.utils.data.Dataset, ratio: float, device: torch.device, num_codes: int = 320 * 320):
        self.device = device
        self.dataset = dataset
        self.ratio = ratio
        self.num_codes = num_codes
        self.target_num_samples = int(len(dataset) * ratio)
        # 10% of the target is drawn at random to seed the histogram.
        self.initial_num_samples = int(0.1 * self.target_num_samples)
        self.sampled_indices = set()
        self.not_sampled_indices = set(range(len(dataset)))
        # Running histogram of the codes of all selected items.
        self.sampled_quantized_idx_count = torch.zeros(num_codes, dtype=torch.float32)
        self.quantized_indices_memory = quantized_indices_memory
        # sample initial samples
        self.sample_initial_samples()

    def _code_histogram(self, quantized_indices):
        """Count how often each code occurs in ``quantized_indices``.

        Returns a tensor shaped and typed like ``self.sampled_quantized_idx_count``.
        Done with one ``torch.bincount`` call instead of a Python loop over the
        individual codes.

        Raises:
            IndexError: if a code is >= the number of histogram bins.
        """
        counts = self.sampled_quantized_idx_count
        codes = torch.as_tensor(quantized_indices, dtype=torch.long, device=counts.device)
        histogram = torch.bincount(codes, minlength=counts.numel())
        if histogram.numel() != counts.numel():
            # bincount grows its output to fit the largest code it sees.
            raise IndexError(
                f"quantized index out of range for a histogram with {counts.numel()} bins"
            )
        return histogram.to(counts.dtype)

    def calculate_kl_divergence_between_uniform_and_empirical_distribution(self, sample):
        """Calculate KL divergence between uniform and empirical distribution.

        The empirical distribution is the current pooled histogram with the
        codes of ``sample`` added. The instance state is not modified.

        Args:
            sample (1D torch.Tensor): quantized indices tensor
        Returns:
            kl_divergence (0-dim torch.Tensor): KL divergence between the
                empirical and the uniform distribution, in nats.
        Raises:
            ValueError: if the computed divergence is negative.
        """
        assert sample.ndim == 1
        # Histogram as it would look if `sample` were selected.
        sampled_quantized_idx_count = self.sampled_quantized_idx_count + self._code_histogram(sample)
        # calculate KL divergence
        # avoid log(0) on empty bins
        sampled_quantized_idx_count += 1e-8
        empirical_distribution = sampled_quantized_idx_count / sampled_quantized_idx_count.sum()
        uniform_distribution = torch.ones_like(empirical_distribution) / len(empirical_distribution)
        kl_divergence = torch.sum(empirical_distribution * torch.log(empirical_distribution / uniform_distribution))
        if kl_divergence < 0:
            raise ValueError("KL divergence must be positive.")
        return kl_divergence

    def sample_initial_samples(self):
        """Seed the selection with ``initial_num_samples`` random items."""
        initial_indices = torch.randperm(len(self.dataset))[:self.initial_num_samples].tolist()
        self.sampled_indices.update(initial_indices)
        self.not_sampled_indices.difference_update(initial_indices)
        # update sampled_quantized_idx_count
        for idx in initial_indices:
            self.sampled_quantized_idx_count += self._code_histogram(self.quantized_indices_memory[idx])

    def sample(self):
        """Grow the selection greedily and return it as a ``Subset``.

        Each iteration scores every remaining item and adds the one that gives
        the lowest KL divergence. The loop stops when the target size is
        reached or when no candidates remain (e.g. ``ratio`` > 1).

        Returns:
            torch.utils.data.Subset over ``self.dataset`` with the selected
            indices in ascending order.
        """
        while len(self.sampled_indices) < self.target_num_samples and self.not_sampled_indices:
            # show progress
            print(f"sampled {len(self.sampled_indices) / self.target_num_samples * 100:.2f} %")
            # calculate KL divergence for each not sampled index
            kl_divergences = {}
            for idx in list(self.not_sampled_indices):
                quantized_indices = torch.as_tensor(self.quantized_indices_memory[idx], dtype=torch.long)
                kl_divergence = self.calculate_kl_divergence_between_uniform_and_empirical_distribution(quantized_indices)
                # Plain floats keep the dictionary light and compare the same way.
                kl_divergences[idx] = kl_divergence.item()
            # select the index with the minimum KL divergence
            min_kl_divergence_idx = min(kl_divergences, key=kl_divergences.get)
            print(f"min_kl_divergence: {kl_divergences[min_kl_divergence_idx]}")
            # update sampled_indices and not_sampled_indices
            self.sampled_indices.add(min_kl_divergence_idx)
            self.not_sampled_indices.remove(min_kl_divergence_idx)
            # update sampled_quantized_idx_count
            self.sampled_quantized_idx_count += self._code_histogram(
                self.quantized_indices_memory[min_kl_divergence_idx]
            )

        # Sorted so the subset's index order does not depend on set iteration order.
        return torch.utils.data.Subset(self.dataset, sorted(self.sampled_indices))


In [52]:
# "train" split; passed to KLBasedSampler and random_sampler below, and its
# vocab / collate_fn / ctc_token_id are used by the training cells.
train_dataset = TIMITDatasetMelSpecdb("train")

Using custom data configuration default-7dc5a6ddcdc99305
Found cached dataset timit (/home/shibutani/fs/.cache/huggingface/datasets/timit/default-7dc5a6ddcdc99305/0.0.0/e393649805e8c068eb5c3311baf236f53ffa81289ecc57e285c6e06a31f00ba8)
100%|██████████| 2/2 [00:00<00:00, 495.46it/s]


In [53]:
# Select 10% of the training set with the KL-based greedy sampler.
kl_based_sampler = KLBasedSampler(
    quantized_indices_memory=quantized_indices_memory,
    dataset=train_dataset,
    ratio=0.1,
    device=DEVICE,
)
subset = kl_based_sampler.sample()


sampled 9.96 %
min_kl_divergence: 6.10467004776001
sampled 10.17 %
min_kl_divergence: 6.0709309577941895
sampled 10.39 %
min_kl_divergence: 6.040905475616455
sampled 10.61 %
min_kl_divergence: 6.013134956359863
sampled 10.82 %


KeyboardInterrupt: 

In [None]:
import pickle

# Persist the Subset returned by kl_based_sampler.sample() to disk.
with open("subset.pkl", "wb") as subset_file:
    pickle.dump(subset, subset_file)

In [3]:
def random_sampler(dataset: torch.utils.data.Dataset, ratio: float):
    """Return a uniformly random ``Subset`` holding ``int(len(dataset) * ratio)`` items.

    The indices are the leading entries of a random permutation of the dataset
    positions, so no item is picked twice.
    """
    total = len(dataset)
    keep = int(total * ratio)
    shuffled_positions = torch.randperm(total)
    chosen = shuffled_positions[:keep].tolist()
    return torch.utils.data.Subset(dataset, chosen)

In [4]:
# Random 80% subset of the training set and the loader that feeds the training loop.
random_train_subset = random_sampler(train_dataset, 0.8)
random_train_dataloader = get_dataloader(
    random_train_subset,
    batch_size=32,
    shuffle=True,
    drop_last=True,
    collate_fn=train_dataset.collate_fn,
)

In [5]:
# Pick the compute device (GPU when available) and report the choice.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"This Learning is running on {DEVICE}")

This Learning is running on cuda


In [6]:
# Run configuration.
# NOTE(review): input_feature_size, self_attn_dim, feed_forward_dim and
# total_size are not read by the model or training cells visible in this
# notebook; the same values (80 / 256 / 1024) appear there as literals.
input_feature_size = 80
self_attn_dim = 256
feed_forward_dim = 1024
# Number of output labels: one per entry of the training vocabulary.
token_size = len(train_dataset.vocab.keys())
# Number of passes over the training dataloader.
num_epochs = 40
# Size of the full training split (not of the random subset that is trained on).
total_size = len(train_dataset)

In [7]:
from modules.preprocessing.subsampling import Conv2DSubSampling
from modules.transformers.encoder import TransformerEncoder
from modules.transformers.scheduler import TransformerLR
from torch import nn


class Model(nn.Module):
    """Acoustic model: Conv2D subsampling -> Transformer encoder -> linear layer -> log-softmax.

    In this notebook its output is trained with CTC loss.

    NOTE: this definition shadows the ``Model`` imported from ``model`` at the
    top of the notebook.

    Args:
        nlabel: number of output labels (size of the last output dimension).
        in_size: feature dimension of the input frames.
        d_model: output size of the subsampling layer, which is also the
            input dimension of self-attention and of the final linear layer.
        nlayer: number of Transformer encoder layers.
        nhead: number of attention heads.
        feed_forward_dim: hidden size of the encoder's feed-forward blocks.
        dropout: dropout rate passed to the encoder.

    The defaults reproduce the configuration that used to be hard-coded, and
    the submodule names are unchanged, so existing checkpoints still load.
    """

    def __init__(self, nlabel, in_size=80, d_model=256, nlayer=12, nhead=4, feed_forward_dim=1024, dropout=0.1):
        super().__init__()
        self.nlabel = nlabel
        # out_size is the input dimension of self-attention.
        self.conv2d_sub_sampling = Conv2DSubSampling(
            in_size=in_size, out_size=d_model, kernel_size1=3, kernel_size2=3, stride1=2, stride2=1)
        # in_hidden_size is the feed-forward dimension.
        self.transformer_encoder = TransformerEncoder(
            in_size=d_model, nlayer=nlayer, nhead=nhead, in_hidden_size=feed_forward_dim, dropout=dropout, norm_first=True)
        self.fc = nn.Linear(d_model, nlabel, bias=True)
        self.log_softmax = nn.functional.log_softmax

    def forward(self, x, x_lengths):
        """Compute per-frame label log-probabilities.

        Args:
            x: [B, T, in_size] padded input features.
            x_lengths: [B] sequence lengths before padding.

        Returns:
            log_probs: [B, T', nlabel] log-probabilities, where T' is the
                subsampled length.
            y_lengths: [B] lengths of the non-padded part of each output
                sequence (as reported by the subsampling layer).
        """
        subsampled_x, subsampled_x_length = self.conv2d_sub_sampling(x, x_lengths)
        # The encoder's second return value is not used here.
        encoded, _ = self.transformer_encoder(
            subsampled_x, subsampled_x_length)  # [B, T', d_model]
        y = self.fc(encoded)  # [B, T', nlabel]
        y_lengths = subsampled_x_length
        log_probs = self.log_softmax(y, dim=-1)  # [B, T', nlabel]
        return log_probs, y_lengths

In [8]:

# Build the model, the CTC criterion, the optimizer and the LR scheduler, then
# train for num_epochs, evaluating on the test dataloader after every epoch.
asr_model = Model(nlabel=token_size).to(DEVICE)
# reduction="sum": the loss is summed over the batch; it is divided by the batch
# size below when it is accumulated for reporting.
ctc_loss = torch.nn.CTCLoss(reduction="sum", blank=train_dataset.ctc_token_id)
# NOTE(review): lr=0.01 is presumably overridden by TransformerLR - confirm.
optimizer = torch.optim.Adam(asr_model.parameters(), lr=0.01, betas=(0.9, 0.98), eps=1e-9)
scheduler = TransformerLR(optimizer, d_model=256, warmup_steps=8000) # the LR is roughly 0.0017 at the end of warmup

for epoch in range(num_epochs):
    # ---- training pass ----
    asr_model.train()
    # Per-epoch accumulators; each is averaged over the number of batches.
    train_epoch_loss = 0
    train_epoch_cer = 0
    train_epoch_wer = 0
    train_cnt = 0

    for i, (bidx, bx, bx_len, by, by_len, texts, phonetic_details) in enumerate(random_train_dataloader):
        bx = bx.to(DEVICE)
        bx_len = bx_len.to(DEVICE)
        by = by.to(DEVICE)
        by_len = by_len.to(DEVICE)
        optimizer.zero_grad()
        log_probs, y_lengths = asr_model(bx, bx_len)
        # The model returns [B, T', nlabel]; swap the first two axes to get the
        # time-major layout that is passed to CTCLoss.
        loss = ctc_loss(log_probs.transpose(1, 0), by, y_lengths, by_len)
        loss.backward()
        optimizer.step()
        # The scheduler is stepped once per batch, not once per epoch.
        scheduler.step()
        # Summed batch loss divided by the batch size.
        train_epoch_loss += (loss.item() / bx.size(0))

        # calculate CER and WER from the frame-wise argmax (greedy) hypothesis
        hypothesis = torch.argmax(log_probs, dim=-1)
        # NOTE(review): "[PAD]", "|" and "_" are presumably the padding, word
        # delimiter and blank tokens - confirm against greedy_decoder.
        hypotheses = greedy_decoder(hypothesis, train_dataset.vocab, "[PAD]", "|", "_")
        answers = greedy_decoder(by, train_dataset.vocab, "[PAD]", "|", "_")
        train_epoch_cer += char_error_rate(hypotheses, answers)
        train_epoch_wer += word_error_rate(hypotheses, answers)

        train_cnt += 1

    print(f"Epoch {epoch + 1} of {num_epochs} train loss: {train_epoch_loss / train_cnt}, train CER: {train_epoch_cer / train_cnt}, train WER: {train_epoch_wer / train_cnt}")

    # ---- evaluation pass (no gradient tracking, no parameter updates) ----
    asr_model.eval()
    test_epoch_loss = 0
    test_epoch_cer = 0
    test_epoch_wer = 0
    test_cnt = 0
    with torch.no_grad():
        for i, (bidx, bx, bx_len, by, by_len, texts, phonetic_details) in enumerate(test_dataloader):
            bx = bx.to(DEVICE)
            bx_len = bx_len.to(DEVICE)
            by = by.to(DEVICE)
            by_len = by_len.to(DEVICE)
            log_probs, y_lengths = asr_model(bx, bx_len)
            loss = ctc_loss(log_probs.transpose(1, 0), by, y_lengths, by_len)
            test_epoch_loss += (loss.item() / bx.size(0))

            # calculate CER and WER; decoding uses the training vocabulary
            hypothesis = torch.argmax(log_probs, dim=-1)
            hypotheses = greedy_decoder(hypothesis, train_dataset.vocab, "[PAD]", "|", "_")
            answers = greedy_decoder(by, train_dataset.vocab, "[PAD]", "|", "_")
            test_epoch_cer += char_error_rate(hypotheses, answers)
            test_epoch_wer += word_error_rate(hypotheses, answers)

            test_cnt += 1
        
    print(f"Epoch {epoch + 1} of {num_epochs} test loss: {test_epoch_loss / test_cnt}, test CER: {test_epoch_cer / test_cnt}, test WER: {test_epoch_wer / test_cnt}")


Epoch 1 of 40 train loss: 207.77212298849355, train CER: 0.9964577555656433, train WER: 1.0
Epoch 1 of 40 test loss: 156.152664472472, test CER: 1.0, test WER: 1.0
Epoch 2 of 40 train loss: 153.84124742590862, train CER: 1.0, train WER: 1.0
Epoch 2 of 40 test loss: 151.7391800790463, test CER: 1.0, test WER: 1.0
Epoch 3 of 40 train loss: 150.87309464164403, train CER: 1.0, train WER: 1.0
Epoch 3 of 40 test loss: 149.18733704764887, test CER: 1.0, test WER: 1.0
Epoch 4 of 40 train loss: 143.1272523962933, train CER: 0.9989821314811707, train WER: 1.0
Epoch 4 of 40 test loss: 133.57229225590544, test CER: 0.980032205581665, test WER: 1.0
Epoch 5 of 40 train loss: 124.56327003810716, train CER: 0.8546525835990906, train WER: 1.0004898309707642
Epoch 5 of 40 test loss: 114.21900853570902, test CER: 0.7276081442832947, test WER: 0.9961405396461487
Epoch 6 of 40 train loss: 107.17956980829653, train CER: 0.676473081111908, train WER: 0.992084264755249
Epoch 6 of 40 test loss: 98.615533648796

KeyboardInterrupt: 