In [1]:
import torch
from torch.nn.functional import pad
import torch.nn
from data import YesNoDataset
from torchaudio.functional import rnnt_loss
from torch.nn.functional import log_softmax
from torchmetrics.functional import char_error_rate
from modules.conformer.convolution import CausalConvolutionLayer
from modules.subsampling import Conv2DSubSampling

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 8
dataset = YesNoDataset(
    wav_dir_path="datasets/waves_yesno/",
    model_sample_rate=16000,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=dataset.collate_fn,
    drop_last=False,
)
num_labels = len(dataset.idx_to_token)
token_to_idx = dataset.token_to_idx
blank_idx = dataset.blank_idx
pad_idx = dataset.pad_idx
idx_to_token = dataset.idx_to_token

In [3]:
from torch.nn.functional import log_softmax
class Model(torch.nn.Module):
    def __init__(self, input_length, subsampled_input_length, num_labels):
        super(Model, self).__init__()
        self.input_length = input_length
        self.subsampled_input_length = subsampled_input_length
        self.num_labels = num_labels
        self.subsampling = Conv2DSubSampling(input_length, subsampled_input_length, 3, 2, 3, 2)
        self.causal_conv = CausalConvolutionLayer(
            input_channels=subsampled_input_length,
            hidden_channels=32,
            depthwise_kernel_size=12,
            dropout=0
        ) # [B, T, D]
        self.fc = torch.nn.Linear(subsampled_input_length, num_labels)
    
    def forward(self, padded_input, input_lengths):
        subsampled_padded_input, subsampled_input_lengths = self.subsampling(padded_input, input_lengths)
        padded_output = self.causal_conv(subsampled_padded_input)
        padded_output = self.fc(padded_output)
        padded_log_prob = log_softmax(padded_output, dim=2)
        return padded_log_prob, subsampled_input_lengths

In [4]:
def ctc_simple_decode(hypotheses_idxs, idx_to_token, padding_idx, blank_idx):

    hypotheses_idxs = hypotheses_idxs.cpu().numpy()
    hypotheses = []
    for hypothesis_idxs in hypotheses_idxs:
        hypothesis = []
        prev_idx = -1
        for idx in hypothesis_idxs:
            if idx == blank_idx:
                continue
            elif idx == prev_idx:
                continue
            elif idx == padding_idx:
                continue
            else:
                hypothesis.append(idx_to_token[idx])
                prev_idx = idx
        hypotheses.append("".join(hypothesis))
    return hypotheses

In [5]:
from torch import nn, optim

import time

model = Model(input_length=40, subsampled_input_length=32, num_labels=num_labels).to(DEVICE)

ctc_loss = nn.CTCLoss(reduction="sum", blank=dataset.blank_idx)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
# Adam

for i in range(50):
    t0 = time.time()
    model.train()
    epoch_loss = 0
    cnt = 0
    total_cer = 0
    for _, (padded_spectrogram_dbs, padded_text_idxs, original_spectrofram_db_lens, original_text_idx_lens) in enumerate(dataloader):
        cnt += 1
        optimizer.zero_grad()
        padded_spectrogram_dbs = padded_spectrogram_dbs.to(DEVICE)
        padded_text_idxs = padded_text_idxs.to(DEVICE)
      
        padded_log_probs, sub_sampled_padded_spectrogram_db_lens = model(padded_spectrogram_dbs, original_spectrofram_db_lens)
        loss = ctc_loss(padded_log_probs.transpose(1,0), padded_text_idxs, sub_sampled_padded_spectrogram_db_lens, original_text_idx_lens)
        loss.backward()
        optimizer.step()
        # lossはバッチ内平均ロス
        epoch_loss += (loss.item() / BATCH_SIZE)

        hypotheses_idxs = padded_log_probs.argmax(dim=2) 
        hypotheses = ctc_simple_decode(hypotheses_idxs, idx_to_token=idx_to_token, padding_idx=pad_idx, blank_idx=blank_idx)
        teachers = ctc_simple_decode(padded_text_idxs, idx_to_token=idx_to_token, padding_idx=pad_idx, blank_idx=blank_idx)
        total_cer += char_error_rate(hypotheses, teachers)

    t1 = time.time()
    print(f"{i} epoch: {epoch_loss / cnt} loss, CER: {total_cer / cnt}, {t1 - t0} sec")

0 epoch: 175.83381366729736 loss, CER: 0.9370034337043762, 1.3335866928100586 sec
1 epoch: 141.56728553771973 loss, CER: 0.7443107962608337, 0.48404765129089355 sec
2 epoch: 105.41268348693848 loss, CER: 0.9437316656112671, 0.465836763381958 sec
3 epoch: 70.4853253364563 loss, CER: 0.9689525961875916, 0.48598313331604004 sec
4 epoch: 53.91745448112488 loss, CER: 0.9706601500511169, 0.39849090576171875 sec
5 epoch: 51.39614009857178 loss, CER: 0.9695340394973755, 1.102468729019165 sec
6 epoch: 49.97663640975952 loss, CER: 0.9580181837081909, 0.7339046001434326 sec
7 epoch: 49.06383752822876 loss, CER: 0.968400239944458, 0.4747483730316162 sec
8 epoch: 48.33909487724304 loss, CER: 0.9650269150733948, 0.4532613754272461 sec
9 epoch: 47.74324631690979 loss, CER: 0.9638643860816956, 0.4011263847351074 sec
10 epoch: 47.006303787231445 loss, CER: 0.9712439179420471, 0.5733516216278076 sec
11 epoch: 45.912123680114746 loss, CER: 0.9666744470596313, 0.6571249961853027 sec
12 epoch: 44.139127969