In [6]:
import torch
from torch.nn.functional import pad
import torch.nn
from data import YesNoDataset
from torchaudio.functional import rnnt_loss
from torch.nn.functional import log_softmax
from torchmetrics.functional import char_error_rate
from modules.conformer.conformer import CausalConformerBlock
from modules.subsampling import Conv2DSubSampling
from modules.encoder import CausalConformerEncoder

In [2]:
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 8
dataset = YesNoDataset(
    wav_dir_path="datasets/waves_yesno/",
    model_sample_rate=16000,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=dataset.collate_fn,
    drop_last=False,
)
num_labels = len(dataset.idx_to_token)
token_to_idx = dataset.token_to_idx
blank_idx = dataset.blank_idx
pad_idx = dataset.pad_idx
idx_to_token = dataset.idx_to_token

In [10]:
from torch.nn.functional import log_softmax
class Model(torch.nn.Module):
    def __init__(self, input_size, subsampled_input_size, num_labels):
        super(Model, self).__init__()
        self.input_size = input_size
        self.subsampled_input_size = subsampled_input_size
        self.num_labels = num_labels
        self.encoder = CausalConformerEncoder(
            input_size=input_size,
            subsampled_input_size=subsampled_input_size,
            ff_hidden_size=32,
            conv_hidden_size=64,
            conv_kernel_size=8,
            mha_num_heads=4,
            dropout=0.1,
            num_conformer_blocks=1,
        )
        self.fc = torch.nn.Linear(subsampled_input_size, num_labels)
    
    def forward(self, padded_input, input_lengths):
        padded_output, subsampled_input_lengths = self.encoder(padded_input, input_lengths)
        padded_output = self.fc(padded_output)
        padded_log_prob = log_softmax(padded_output, dim=2)
        return padded_log_prob, subsampled_input_lengths

In [11]:
def ctc_simple_decode(hypotheses_idxs, idx_to_token, padding_idx, blank_idx):

    hypotheses_idxs = hypotheses_idxs.cpu().numpy()
    hypotheses = []
    for hypothesis_idxs in hypotheses_idxs:
        hypothesis = []
        prev_idx = -1
        for idx in hypothesis_idxs:
            if idx == blank_idx:
                continue
            elif idx == prev_idx:
                continue
            elif idx == padding_idx:
                continue
            else:
                hypothesis.append(idx_to_token[idx])
                prev_idx = idx
        hypotheses.append("".join(hypothesis))
    return hypotheses

In [12]:
from torch import nn, optim

import time

model = Model(input_size=40, subsampled_input_size=32, num_labels=num_labels).to(DEVICE)

ctc_loss = nn.CTCLoss(reduction="sum", blank=dataset.blank_idx)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
# Adam

for i in range(50):
    t0 = time.time()
    model.train()
    epoch_loss = 0
    cnt = 0
    total_cer = 0
    print(f"epoch: {i}")
    for _, (padded_spectrogram_dbs, padded_text_idxs, original_spectrofram_db_lens, original_text_idx_lens) in enumerate(dataloader):
        cnt += 1
        optimizer.zero_grad()
        padded_spectrogram_dbs = padded_spectrogram_dbs.to(DEVICE)
        padded_text_idxs = padded_text_idxs.to(DEVICE)
      
        padded_log_probs, sub_sampled_padded_spectrogram_db_lens = model(padded_spectrogram_dbs, original_spectrofram_db_lens)
        loss = ctc_loss(padded_log_probs.transpose(1,0), padded_text_idxs, sub_sampled_padded_spectrogram_db_lens, original_text_idx_lens)
        loss.backward()
        optimizer.step()
        # lossはバッチ内平均ロス
        epoch_loss += (loss.item() / BATCH_SIZE)

        hypotheses_idxs = padded_log_probs.argmax(dim=2) 
        hypotheses = ctc_simple_decode(hypotheses_idxs, idx_to_token=idx_to_token, padding_idx=pad_idx, blank_idx=blank_idx)
        teachers = ctc_simple_decode(padded_text_idxs, idx_to_token=idx_to_token, padding_idx=pad_idx, blank_idx=blank_idx)
        total_cer += char_error_rate(hypotheses, teachers)

    t1 = time.time()
    print(f"{i} epoch: {epoch_loss / cnt} loss, CER: {total_cer / cnt}, {t1 - t0} sec")

epoch: 0
0 epoch: 73.52474880218506 loss, CER: 0.8217387795448303, 6.232017993927002 sec
epoch: 1
1 epoch: 48.50467109680176 loss, CER: 0.9225143790245056, 7.184962511062622 sec
epoch: 2
2 epoch: 47.51719260215759 loss, CER: 0.9158633351325989, 7.084156036376953 sec
epoch: 3
3 epoch: 47.17652463912964 loss, CER: 0.8755994439125061, 7.802304029464722 sec
epoch: 4
4 epoch: 46.967416286468506 loss, CER: 0.885417103767395, 6.042619943618774 sec
epoch: 5
5 epoch: 46.80734133720398 loss, CER: 0.8968003392219543, 5.90822958946228 sec
epoch: 6
6 epoch: 46.67368817329407 loss, CER: 0.9079841375350952, 5.460905313491821 sec
epoch: 7
7 epoch: 46.54531240463257 loss, CER: 0.9292165040969849, 6.78022313117981 sec
epoch: 8
8 epoch: 46.336510181427 loss, CER: 0.9083645343780518, 6.57177734375 sec
epoch: 9
9 epoch: 45.97253894805908 loss, CER: 0.9038147926330566, 6.470540523529053 sec
epoch: 10
10 epoch: 45.085493087768555 loss, CER: 0.925201952457428, 6.819506406784058 sec
epoch: 11
11 epoch: 42.8085

KeyboardInterrupt: 