## Note
- size: 特徴量
- length: 時系列

In [1]:
from torch.utils.data import Dataset

import torch
from torch import nn, optim
import pandas as pd
import torchaudio
import librosa
import numpy as np
import math

from torchmetrics.functional import char_error_rate

import glob
import os
import re
import copy


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tkwargs_int = {
    "dtype": torch.int32,
    "device": "cuda",
}
tkwargs_float = {
    "dtype": torch.float32,
    "device": "cuda",
}

In [3]:
class YesNoDataset(Dataset):
    def __init__(self, wav_dir_path, model_sample_rate):
        super().__init__()

        dataset = []
        columns = ["path", "text_idx"]
        self.labels = ["y", "e", "s", "n", "o", "<space>", "_"]
        self.label_to_idx = {label: i for i, label in enumerate(self.labels)}
        for wav_file_path in glob.glob(wav_dir_path + "*.wav"):
            file_name = os.path.splitext(os.path.basename(wav_file_path))[0]
            text_idx = []
            for c in file_name:
                if c == "1":
                    text_idx += [self.label_to_idx[ic] for ic in "yes"] 
                elif c == "0":
                    text_idx += [self.label_to_idx[ic] for ic in "no"] 
                elif c == "_":
                    text_idx.append(self.label_to_idx["<space>"])
                else:
                    raise ValueError("Invalid Dir Path")
            dataset.append([wav_file_path, text_idx])
        
        self.dataset = pd.DataFrame(dataset, columns=columns)
        self.model_sample_rate = model_sample_rate
        self.spectrogram_transformer = torchaudio.transforms.MelSpectrogram(
            # スペクトル設定
            sample_rate=self.model_sample_rate,
            n_fft=1024,
            # スペクトログラム設定
            win_length= None,
            hop_length= 512,
            window_fn= torch.hann_window,
            # メルスペクトログラム設定
            n_mels=80,
            power=2.0,
        )
    
    
    def __len__(self):
        return self.dataset.shape[0]
    
    def __getitem__(self, idx):
        wav_file_path = self.dataset.iloc[idx, 0]
        text_idx = self.dataset.iloc[idx, 1]
        wav_data, sample_rate = torchaudio.load(wav_file_path)
        if sample_rate != self.model_sample_rate:
            wav_data = torchaudio.functional.resample(wav_data, sample_rate, self.model_sample_rate)
            sample_rate = self.model_sample_rate
        spectrogram = self.spectrogram_transformer(wav_data)
        spectrogram_db = librosa.amplitude_to_db(spectrogram)

        return spectrogram_db[0].transpose(1,0), torch.tensor(text_idx)

        

In [4]:
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

def collate_fn(batch):
    # spectrogram_db: tensor[Time, Melbins]
    # text_idx: tensor[text_len]
    spectrogram_dbs, text_idxs = zip(*batch)
   
    original_spectrogram_db_lens = torch.tensor(np.array([len(spectrogram_db) for spectrogram_db in spectrogram_dbs]))
    original_text_idx_lens = torch.tensor(np.array([len(text_idx) for text_idx in text_idxs]))

    # padding for spectrogram_db
    padded_spectrogram_dbs = []
    for spectrogram_db in spectrogram_dbs:
        padded_spectrogram_db = np.pad(spectrogram_db, ((0,max(original_spectrogram_db_lens)-spectrogram_db.shape[0]),(0,0)), "constant", constant_values=0)
        padded_spectrogram_dbs.append(padded_spectrogram_db)
    
    padded_spectrogram_dbs = torch.tensor(np.array(padded_spectrogram_dbs))

    # padding and packing for text_idx
    padded_text_idxs = pad_sequence(text_idxs, batch_first=True, padding_value=-1)

    return padded_spectrogram_dbs, padded_text_idxs, original_spectrogram_db_lens, original_text_idx_lens

In [5]:
model_sample_rate = 8000
wav_dir_path = "../../datasets/waves_yesno/"
dataset = YesNoDataset(wav_dir_path, model_sample_rate)

In [6]:
from torch.utils.data import random_split, DataLoader
# 学習データとテストデータに分割
## 合計サイズが元のサイズと同一になるように注意
train_size = int(len(dataset) * 0.7)
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size]
)
BATCH_SIZE = 2
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    # 不完全なバッチの無視
    drop_last=True,
    # 高速化?
    pin_memory=True,
    collate_fn=collate_fn
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    # 不完全なバッチの無視
    drop_last=True,
    # 高速化?
    pin_memory=True,
    collate_fn=collate_fn
)

In [7]:
import sys
from modules.preprocessing.subsampling import Conv2DSubSampling
from modules.transformers.encoder import TransformerEncoder

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"This learning will be running on {device}.")
input_size = 80
subsampled_input_size = 80
num_labels = len(dataset.labels)
num_epochs = 40

This learning will be running on cuda.


これ以降、各モデルごとに実験用のコードを記述していきます。

In [9]:
def ctc_simple_decode(hypotheses_idxs, labels, padding_idx):
    # hypothesis_idxs: tensor(batch, time)
    # labels: np.array(num_labels)

    hypotheses_idxs = hypotheses_idxs.cpu().numpy()
    hypotheses = []
    blank_idx = labels.index("_")
    for hypothesis_idxs in hypotheses_idxs:
        hypothesis = []
        prev_idx = -1
        for idx in hypothesis_idxs:
            if idx == blank_idx:
                continue
            elif idx == prev_idx:
                continue
            elif idx == padding_idx:
                continue
            else:
                hypothesis.append(labels[idx])
                prev_idx = idx
        hypotheses.append("".join(hypothesis))
    return hypotheses

In [None]:
import pickle
with open("./padded_text_idxs.pkl", "wb") as f:
    pickle.dump(padded_text_idxs, f)
with open("./original_text_idx_lens.pkl", "wb") as f:
    pickle.dump(original_text_idx_lens, f)

In [13]:
from torch.utils.tensorboard import SummaryWriter
import time
from model import Model

model = Model(input_size, subsampled_input_size, num_labels).to(device)

ctc_loss = nn.CTCLoss(reduction="sum", blank=dataset.label_to_idx["_"])
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
# Adam

writer = SummaryWriter()

for i in range(num_epochs):
    t0 = time.time()
    model.train()
    epoch_loss = 0
    cnt = 0
    for _, (padded_spectrogram_dbs, padded_text_idxs, original_spectrofram_db_lens, original_text_idx_lens) in enumerate(train_dataloader):
        cnt += 1
        padded_spectrogram_dbs = padded_spectrogram_dbs.to(device)
        original_spectrofram_db_lens = original_spectrofram_db_lens.to(device)
        padded_text_idxs = padded_text_idxs.to(device)
        original_text_idx_lens = original_text_idx_lens.to(device)
        
        optimizer.zero_grad()
        log_probs, y_lengths  = model(x=padded_spectrogram_dbs, x_lengths=original_spectrofram_db_lens)
        memo_log_probs = log_probs
        memo_y_lengths = y_lengths
        memo_original_t_lengths = original_text_idx_lens
        loss = ctc_loss(log_probs.transpose(1, 0), padded_text_idxs, y_lengths, original_text_idx_lens)
        loss.backward()
        optimizer.step()
        # lossはバッチ内平均ロス
        epoch_loss += (loss.item() / BATCH_SIZE)
    # バッチ内平均ロスの和をイテレーション数で割ることで、一つのデータあたりの平均ロスを求める
    writer.add_scalar("Loss/Training", epoch_loss / cnt, i)

    model.eval()
    with torch.no_grad():
        epoch_test_loss = 0
        cnt = 0
        total_cer = 0
        for _, (padded_spectrogram_dbs, padded_text_idxs, original_spectrofram_db_lens, original_text_idx_lens) in enumerate(test_dataloader):
            cnt += 1
            padded_spectrogram_dbs = padded_spectrogram_dbs.to(device)
            original_spectrofram_db_lens = original_spectrofram_db_lens.to(device)
            padded_text_idxs = padded_text_idxs.to(device)
            original_text_idx_lens = original_text_idx_lens.to(device)
            
            log_probs, y_lengths  = model(x=padded_spectrogram_dbs, x_lengths=original_spectrofram_db_lens)
            loss = ctc_loss(log_probs.transpose(1, 0), padded_text_idxs, y_lengths, original_text_idx_lens)
            epoch_test_loss += loss.item()
            # for CER calculation
            hypotheses_idxs = log_probs.argmax(dim=2) 
            hypotheses = ctc_simple_decode(hypotheses_idxs, dataset.labels, -1)
            teachers = ctc_simple_decode(padded_text_idxs, dataset.labels, -1)
            total_cer += char_error_rate(hypotheses, teachers)

    writer.add_scalar("Loss/Test", epoch_test_loss / cnt, i)
    writer.add_scalar("CER/Test", total_cer / cnt, i)
    t1 = time.time()
    print(f"{i} epoch: {epoch_loss / cnt} loss, {epoch_test_loss / cnt} test loss, CER: {total_cer / cnt}, {t1 - t0} sec")

0 epoch: 112.33126237657335 loss, 94.17699940999348 test loss, CER: 1.0, 0.9137444496154785 sec
1 epoch: 110.18859608968098 loss, 94.06585523817274 test loss, CER: 0.8983814716339111, 0.9033684730529785 sec
2 epoch: 110.25487560696072 loss, 93.31545766194661 test loss, CER: 1.0, 0.9266366958618164 sec


KeyboardInterrupt: 

In [None]:
with torch.no_grad():
    total_cer = 0
    cnt = 0
    for _, (padded_spectrogram_dbs,padded_text_idxs, original_spectrofram_db_lens, original_text_idx_lens) in enumerate(test_dataloader):
        padded_spectrogram_dbs = padded_spectrogram_dbs.to(device)
        original_spectrofram_db_lens = original_spectrofram_db_lens.to(device)
        padded_text_idxs = padded_text_idxs.to(device)
        original_text_idx_lens = original_text_idx_lens.to(device)
        
        log_probs, y_lengths  = model(x=padded_spectrogram_dbs, x_lengths=original_spectrofram_db_lens)

        hypotheses_idxs = log_probs.argmax(dim=2)
        hypotheses = ctc_simple_decode(hypotheses_idxs, dataset.labels, -1)
        teachers = ctc_simple_decode(padded_text_idxs, dataset.labels, -1)
        for hypothesis, teacher in zip(hypotheses, teachers):
            print(f"hyp: {hypothesis}")
            print(f"tea: {teacher}")
        total_cer += char_error_rate(hypotheses, teachers)
        cnt += 1
    print(f"CER: {total_cer / cnt}")

hyp: <space>no<space>no<space>yes<space>no<space>yes<space>no<space>no<space>yes
tea: no<space>no<space>yes<space>no<space>yes<space>no<space>no<space>yes
hyp: no<space>no<space>yes<space>no<space>yes<space>no<space>no<space>noye
tea: no<space>no<space>yes<space>no<space>yes<space>no<space>no<space>no
hyp: no<space>yes<space>yes<space>yes<space>yes<space>yes<space>yes<space>yesye
tea: no<space>yes<space>yes<space>yes<space>yes<space>yes<space>yes<space>yes
hyp: no<space>no<space>es<space>yes<space>no<space>yes<space>yes<space>no
tea: no<space>no<space>yes<space>yes<space>no<space>yes<space>yes<space>no
hyp: no<space>yes<space>yes<space>no<space>no<space>yes<space>yes<space>no
tea: no<space>yes<space>yes<space>no<space>no<space>yes<space>yes<space>no
hyp: no<space>yes<space>yes<space>yes<space>no<space>no<space>no<space>noy
tea: no<space>yes<space>yes<space>yes<space>no<space>no<space>no<space>no
hyp: no<space>yes<space>yes<space>yes<space>yes<space>no<space>yes<space>no
tea: no<space>y