## Note
- size: 特徴量
- length: 時系列

In [78]:
from torch.utils.data import Dataset

import torch
from torch import nn, optim
import pandas as pd
import torchaudio
import librosa
import numpy as np
import math

from torchmetrics.functional import char_error_rate, word_error_rate

import glob
import os
import re
import copy


In [79]:
tkwargs_int = {
    "dtype": torch.int32,
    "device": "cuda",
}
tkwargs_float = {
    "dtype": torch.float32,
    "device": "cuda",
}

In [80]:
from torch.utils.data import Dataset
from torchaudio.functional import resample
from transformers import Wav2Vec2CTCTokenizer
from datasets import load_dataset
from datasets import Audio
from typing import List
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import torch

import os
import pickle
import json

class TIMITDatasetWav(Dataset):
    def __init__(self, dataset_pkl_path: str, vocab_file_path: str, resample_rate: int = 16000):
        # Paremeter

        self.resample_rate = resample_rate

        if not os.path.isfile(dataset_pkl_path):
            dataset = load_dataset("../datasets/loading_scripts/timit.py", data_dir="../datasets/TIMIT/")
            dataset = dataset.remove_columns(["id"])

            self.extract_vocab(train_all_texts=dataset["train"]["text"], vocab_file_path=vocab_file_path)
            
            self.tokenizer = Wav2Vec2CTCTokenizer(
                vocab_file_path, unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
            )
            def prepare_dataset(batch):
                audio = batch["audio"]
                batch["input_values"] = resample(audio, orig_freq=audio["sampling_rate"], new_freq=self.resample_rate)
                batch["input_length"] = len(batch["input_values"])
                batch["labels"] = self.tokenizer(batch["text"]).input_ids
                return batch

            dataset = dataset.map(
                prepare_dataset, remove_columns=dataset.column_names["train"], num_proc=4
            )
            with open(dataset_pkl_path, "wb") as f:
                pickle.dump(dataset, f)

        with open(dataset_pkl_path, "rb") as f:
            self.dataset = pickle.load(f)

        with open(vocab_file_path, "r") as vocab_file:
            vocab = json.load(vocab_file)
        self.vocab = vocab
        self.pad_token_id = vocab["[PAD]"]
        self.unk_token_id = vocab["[UNK]"]
        self.ctc_token_id = vocab["_"]
        

    def __len__(self):
        return len(self.dataset["train"])
    
    def __getitem__(self, idx):

        return  idx, torch.tensor(self.dataset["train"][idx]["input_values"]["array"]), torch.tensor(self.dataset["train"][idx]["labels"]) 
    
    def collate_fn(self, batch):
        idxs, wavs, text_idxs = zip(*batch)
        original_wav_lens = torch.tensor(np.array([len(wav) for wav in wavs]))
        original_text_idx_lens = torch.tensor(np.array([len(text_idx) for text_idx in text_idxs]))
        # padding for spectrogram_db
        padded_wavs = []
        for wav in wavs:
            padded_wav = np.pad(wav, ((0, max(original_wav_lens)-wav.shape[0])), "constant", constant_values=0)
            padded_wavs.append(padded_wav)
        
        padded_wavs = torch.tensor(np.array(padded_wavs))

        # padding and packing for text_idx
        padded_text_idxs = pad_sequence(text_idxs, batch_first=True, padding_value=self.pad_token_id)

        return idxs, padded_wavs, padded_text_idxs, original_wav_lens, original_text_idx_lens

    def extract_vocab(
        self,
        train_all_texts: List = None, 
        test_all_texts : List = None,
        vocab_file_path: str = "./vocab.json",
        ) -> None:
        if train_all_texts is None:
            train_all_texts = []
        if test_all_texts is None:
            test_all_texts = []

        all_text = " ".join(train_all_texts + test_all_texts)
        vocab_list = list(set(all_text))

        vocab = {v: k for k, v in enumerate(vocab_list)}
        # use | as delimeter in stead of " "
        vocab["|"] = vocab[" "]
        # dekete unused char
        del vocab[" "]
        # add unk and pad token
        vocab["[UNK]"] = len(vocab)
        vocab["[PAD]"] = len(vocab)
        vocab["_"] = len(vocab)

        with open(vocab_file_path, "w") as vocab_file:
            json.dump(vocab, vocab_file)



            

In [81]:
from torchaudio import pipelines
bundle = pipelines.WAV2VEC2_BASE

model_sample_rate = bundle.sample_rate
wav_dir_path = "../datasets/waves_yesno/"
dataset = TIMITDatasetWav(dataset_pkl_path="./timit_dataset.pkl", vocab_file_path="./timit_vocab.json", resample_rate=16000)

In [82]:
from torch.utils.data import random_split, DataLoader
# 学習データとテストデータに分割
## 合計サイズが元のサイズと同一になるように注意

BATCH_SIZE = 8
train_dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    # 不完全なバッチの無視
    drop_last=True,
    # 高速化?
    pin_memory=True,
    collate_fn=dataset.collate_fn
)


In [83]:
import sys
sys.path.append("..")
from modules.preprocessing.subsampling import Conv2DSubSampling
from modules.transformers.encoder import TransformerEncoder
from torch import nn

In [84]:
class Model(nn.Module):
    def __init__(self, nlabel):
        super(Model, self).__init__()
        self.in_size = bundle._params["encoder_embed_dim"]
        self.nlabel = nlabel
        self.wav2vec_encoder = bundle.get_model()
        self.fc = nn.Linear(self.in_size, self.nlabel, bias=True)
        self.log_softmax = nn.functional.log_softmax
    
    def forward(self, x, x_lengths):
        # args:
        #   x: [B, T]
        #   x_lengths: [B]
        #       padding前のシーケンス長
        # return:
        #   log_prob: [B, T, nlabel]
        #   y_lengths: [B]
        #       非パディング部分のシーケンス長
        encoded, y_lengths = self.wav2vec_encoder.extract_features(x, x_lengths) # encoded: [L, B, T, in_size]

        y = self.fc(encoded[-1]) # [B, T', nlabel]
        
        log_probs = self.log_softmax(y, dim=2) # [B, T', nlabel]
        return log_probs, y_lengths
        

In [85]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"This learning will be running on {device}.")

num_labels = len(dataset.vocab)
num_epochs = 10

This learning will be running on cuda.


これ以降、各モデルごとに実験用のコードを記述していきます。

In [86]:
def ctc_simple_decode(hypotheses_idxs, vocab):
    # hypothesis_idxs: tensor(batch, time)
    # labels: np.array(num_labels)
    index_to_vocab = {v: k for k, v in vocab.items()}
    hypotheses_idxs = hypotheses_idxs.cpu().numpy()
    hypotheses = []
    padding_idx = vocab["[PAD]"]
    blank_idx = vocab["_"]
    separator_idx = vocab["|"]
    
    for hypothesis_idxs in hypotheses_idxs:
        hypothesis = []
        prev_idx = -1
        for idx in hypothesis_idxs:
            if idx == blank_idx:
                continue
            elif idx == prev_idx:
                continue
            elif idx == padding_idx:
                continue
            elif idx == separator_idx:
                hypothesis.append(" ")
                prev_idx = idx
            else:
                hypothesis.append(index_to_vocab[idx])
                prev_idx = idx
        hypotheses.append("".join(hypothesis))
    return hypotheses

In [87]:
from torch.optim.lr_scheduler import _LRScheduler
class TransformerLR(_LRScheduler):
    """TransformerLR class for adjustment of learning rate.

    The scheduling is based on the method proposed in 'Attention is All You Need'.
    """

    def __init__(self, optimizer, warmup_epochs=1000, last_epoch=-1, verbose=False):
        """Initialize class."""
        self.warmup_epochs = warmup_epochs
        self.normalize = self.warmup_epochs**0.5
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return adjusted learning rate."""
        step = self.last_epoch + 1
        scale = self.normalize * min(step**-0.5, step * self.warmup_epochs**-1.5)
        return [base_lr * scale for base_lr in self.base_lrs]

In [72]:
import time
for run in range(10):
    print(f"{run} th run")
    model = Model(num_labels).to(device)

    ctc_loss = nn.CTCLoss(reduction="sum", blank=dataset.ctc_token_id)
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.005)
    scheduler = TransformerLR(optimizer, warmup_epochs=1000)

    print("model initialized")

    for i in range(num_epochs):
        print(f"{i} th epoch")
        t0 = time.time()
        model.train()
        epoch_loss = 0
        cnt = 0
        for _, (idxs, padded_wavs, padded_text_idxs, original_wav_lens, original_text_idx_lens) in enumerate(train_dataloader):
            cnt += 1
            padded_wavs = padded_wavs.to(device)
            original_wav_lens = original_wav_lens.to(device)
            padded_text_idxs = padded_text_idxs.to(device)
            original_text_idx_lens = original_text_idx_lens.to(device)
            
            optimizer.zero_grad()
            
            log_probs, y_lengths  = model(x=padded_wavs, x_lengths=original_wav_lens)

            loss = ctc_loss(log_probs.transpose(1, 0), padded_text_idxs, y_lengths, original_text_idx_lens)
            loss.backward()
            optimizer.step()
            # lossはバッチ内平均ロス
            epoch_loss += (loss.item() / BATCH_SIZE)
        scheduler.step()
        # バッチ内平均ロスの和をイテレーション数で割ることで、一つのデータあたりの平均ロスを求める

        model.eval()
        input_to_wer = {}
        # idxが固定されるか確認する用
        input_to_teacher = {}
        with torch.no_grad():
            cnt = 0
            total_cer = 0
            total_wer = 0
            for _, (idxs, padded_wavs, padded_text_idxs, original_wav_lens, original_text_idx_lens) in enumerate(train_dataloader):
                cnt += 1
                padded_wavs = padded_wavs.to(device)
                original_wav_lens = original_wav_lens.to(device)
                padded_text_idxs = padded_text_idxs.to(device)
                original_text_idx_lens = original_text_idx_lens.to(device)
                
                log_probs, y_lengths  = model(x=padded_wavs, x_lengths=original_wav_lens)
                # for CER calculation
                hypotheses_idxs = log_probs.argmax(dim=2) 
                hypotheses = ctc_simple_decode(hypotheses_idxs, dataset.vocab)
                teachers = ctc_simple_decode(padded_text_idxs, dataset.vocab)
                total_cer += char_error_rate(hypotheses, teachers)
                batch_wer = 0
                for idx, hypothesis, teacher in zip(idxs, hypotheses, teachers):
                    input_to_wer[idx] = word_error_rate(hypothesis, teacher)
                    input_to_teacher[idx] = teacher
                    batch_wer += input_to_wer[idx]
                total_wer += batch_wer / len(idxs)


        t1 = time.time()
        print(f"{i} epoch: {epoch_loss / cnt} loss,  CER: {total_cer / cnt}, WER: {total_wer / cnt}, {t1 - t0} sec")

        checkpoint = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "random": random.getstate(),
            "np_random": np.random.get_state(), # numpy.randomを使用する場合は必要
            "torch": torch.get_rng_state(),
            "torch_random": torch.random.get_rng_state(),
            "cuda_random": torch.cuda.get_rng_state(), # gpuを使用する場合は必要
            "cuda_random_all": torch.cuda.get_rng_state_all(), # 複数gpuを使用する場合は必要
            "input_to_wer": input_to_wer,
            "input_to_teacher": input_to_teacher,
        }
        if (i + 1) % 10 == 0:
            torch.save(checkpoint, f"cpts/timit_finetune_checkpoint_{run}_{i}.pt")

model initialized
0 th epoch
0 epoch: 491.0145320793165 loss,  CER: 0.9462673664093018, WER: 1.0, 154.84885001182556 sec
1 th epoch
1 epoch: 347.0334681307629 loss,  CER: 1.0, WER: 1.0, 155.31140637397766 sec
2 th epoch
2 epoch: 205.50808662930228 loss,  CER: 1.0, WER: 1.0, 154.17100501060486 sec


In [None]:
with torch.no_grad():
    total_cer = 0
    cnt = 0
    for _, (idxs, padded_spectrogram_dbs,padded_text_idxs, original_spectrofram_db_lens, original_text_idx_lens) in enumerate(train_dataloader):
        padded_wavs = padded_wavs.to(device)
        original_wav_lens = original_wav_lens.to(device)
        padded_text_idxs = padded_text_idxs.to(device)
        original_text_idx_lens = original_text_idx_lens.to(device)
        
        log_probs, y_lengths  = model(x=padded_wavs, x_lengths=original_wav_lens)

        hypotheses_idxs = log_probs.argmax(dim=2)
        hypotheses = ctc_simple_decode(hypotheses_idxs, dataset.vocab)
        teachers = ctc_simple_decode(padded_text_idxs, dataset.vocab)
        for hypothesis, teacher in zip(hypotheses, teachers):
            print(f"hyp: {hypothesis}")
            print(f"tea: {teacher}")
        total_cer += char_error_rate(hypotheses, teachers)
        cnt += 1
    print(f"CER: {total_cer / cnt}")

hyp: cl-ohFeo-ozV,-Q-ecljtc
tea: Death reminds man of his sin, but it reminds him also of his transience.
hyp: cImRi-jVjcYmVcN-t-c-ctc
tea: So somebody else knew what would hapen to her father's money if she died.
hyp: -bNF-c-G-GKeF-lFicF-iRbRc"c
tea: When you're les fatigued, things just naturaly lok brighter.
hyp: c-xo-G-bG-p-c-z-c-
tea: Within a month, cals were up seventy per cent.
hyp: i -o
tea: Move the garbage nearer to the large window.
hyp: bcl-m-ioc-c-om-RkoV-c
tea: She had your dark suit in greasy wash water al year.
hyp: lczGF-G-z-R-cR-VR-RlhzhE"
tea: If you do, go to it.
hyp: -o-oto I
tea: Fil that canten with fresh spring water.
hyp: cl-ohFeo-ozV,-Q-ecljtc
tea: She had your dark suit in greasy wash water al year.
hyp: cImRi-jVjcYmVcN-t-c-ctc
tea: Coperation along with understanding aleviate dispute.
hyp: -bNF-c-G-GKeF-lFicF-iRbRc"c
tea: Every cab neds repainting often.
hyp: c-xo-G-bG-p-c-z-c-
tea: The dimensions of these waves dwarf al our usual standards of measurement.


KeyboardInterrupt: 