In [1]:
import torch
from torch.nn.functional import pad
import torch.nn
from data import YesNoDataset
from model import Model
from torchaudio.functional import rnnt_loss
from torch.nn.functional import log_softmax
from torchmetrics.functional import char_error_rate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

In [3]:
dataset = YesNoDataset(
    wav_dir_path="datasets/waves_yesno/",
    model_sample_rate=16000,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=dataset.collate_fn,
    drop_last=False,
)

In [4]:
idx_to_label = {v: k for k, v in dataset.label_to_idx.items()}

In [5]:
dataset.label_to_idx

{'_': 0, 'y': 1, 'e': 2, 's': 3, 'n': 4, 'o': 5, '<space>': 6}

In [6]:

model = Model(
    vocab_size=7,
    encoder_input_size=40,
    encoder_hidden_size=64,
    encoder_num_layers=1,
    embedding_size=64,
    predictor_hidden_size=64,
    predictor_num_layers=1,
    jointnet_hidden_size=64,
    blank_idx=0,
).to(DEVICE)
checkpoint = torch.load("model_YesNo.pth")
model.load_state_dict(checkpoint["model_state_dict"])


<All keys matched successfully>

In [7]:
encoder = model.encoder
predictor = model.predictor
jointnet = model.jointnet

In [8]:
mel_specs, tokens, mel_spec_lengths, token_lengths = next(iter(dataloader))
mel_specs = mel_specs.to(DEVICE)
tokens = tokens.to(DEVICE)

enc_inputs = mel_specs
enc_input_lengths = mel_spec_lengths

In [9]:
with torch.no_grad():
    batch_hyp_tokens = []
    for i, (enc_input, enc_input_length) in enumerate(zip(enc_inputs, enc_input_lengths)):
        if enc_input.size(0) > enc_input_length:
            enc_input = enc_input[:enc_input_length, :]

        enc_output, _ = encoder(
            enc_input.unsqueeze(0), torch.tensor([enc_input.size(0)])
        )  # [1, subsampled_enc_input_length, output_size]
        pred_input = torch.tensor([[0]], dtype=torch.int32).to(enc_output.device)
        pred_output, hidden = predictor.forward_wo_prepend(pred_input, torch.tensor([1]), hidden=None)
        # [1, 1, output_size]
        timestamp = 0
        hyp_tokens = []
        while timestamp < enc_output.size(1):
            enc_output_at_timestamp = enc_output[0, timestamp]
            logits = jointnet(enc_output_at_timestamp.view(1, 1, -1), pred_output)
            pred_token = logits.argmax(dim=-1)
            if pred_token != 0:
                hyp_tokens.append(pred_token.item())
                pred_input = torch.tensor([[pred_token]], dtype=torch.int32).to(enc_output.device)
                pred_output, hidden = predictor.forward_wo_prepend(pred_input, torch.tensor([1]), hidden=hidden)
            else:
                timestamp += 1
        batch_hyp_tokens.append(hyp_tokens)
        print("hyp_tokens:", hyp_tokens)
        print("tokens    :", tokens[i, :token_lengths[i]].tolist())

hyp_tokens: [1, 2, 3, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6, 1, 2, 3, 6, 1, 2, 3]
tokens    : [1, 2, 3, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6, 1, 2, 3, 6, 1, 2, 3]
hyp_tokens: [1, 2, 3, 6, 1, 2, 3, 6, 4, 5, 6, 4, 5, 6, 1, 2, 3, 6, 4, 5, 6, 1, 2, 3, 6, 4, 5]
tokens    : [1, 2, 3, 6, 1, 2, 3, 6, 4, 5, 6, 4, 5, 6, 1, 2, 3, 6, 4, 5, 6, 1, 2, 3, 6, 4, 5]
hyp_tokens: [1, 2, 3, 6, 4, 5, 6, 1, 2, 3, 6, 1, 2, 3, 6, 1, 2, 3, 6, 1, 2, 3, 6, 4, 5, 6, 1, 2, 3]
tokens    : [1, 2, 3, 6, 4, 5, 6, 1, 2, 3, 6, 1, 2, 3, 6, 1, 2, 3, 6, 1, 2, 3, 6, 4, 5, 6, 1, 2, 3]
hyp_tokens: [1, 2, 3, 6, 1, 2, 3, 6, 1, 2, 3, 6, 1, 2, 3, 6, 4, 5, 6, 1, 2, 3, 6, 4, 5, 6, 4, 5]
tokens    : [1, 2, 3, 6, 1, 2, 3, 6, 1, 2, 3, 6, 1, 2, 3, 6, 4, 5, 6, 1, 2, 3, 6, 4, 5, 6, 4, 5]
hyp_tokens: [4, 5, 6, 4, 5, 6, 1, 2, 3, 6, 1, 2, 3, 6, 1, 2, 3, 6, 4, 5, 6, 4, 5, 6, 4, 5]
tokens    : [4, 5, 6, 4, 5, 6, 1, 2, 3, 6, 1, 2, 3, 6, 1, 2, 3, 6, 4, 5, 6, 4, 5, 6, 4, 5]
hyp_tokens: [4, 5, 6, 1, 2, 3, 6, 1, 2, 3, 6, 1, 2, 3,

In [32]:
batch_hyp_tokens
hyp_texts = ["".join([idx_to_label[idx] for idx in hyp_token]) for hyp_token in batch_hyp_tokens]

In [33]:
ans_tokens = [tokens[i, : token_lengths[i]].tolist() for i in range(tokens.shape[0])]
ans_texts = ["".join([idx_to_label[idx] for idx in ans_token]) for ans_token in ans_tokens]

In [35]:
char_error_rate(hyp_texts, ans_texts)

tensor(0.1993)

In [23]:
label_to_idx = dataset.label_to_idx
idx_to_label = {v: k for k, v in label_to_idx.items()}

In [20]:
"""
# pred_input: [1, B, 1]
model.eval()
pred_input = torch.tensor([0], dtype=torch.int32).to(DEVICE)
pred_output, hidden = predictor.inference(pred_input)
timestamp = 0
results = []
while timestamp < encoder_output.shape[0]:
    encoder_output_t = encoder_output[timestamp].view(1,1,-1)
    # 各encoder_output[i]は内部的に過去の情報を持っているため、jointで渡すのは現時刻の潜在表現で良い
    joint_output = jointnet(encoder_output_t, pred_output)
    pred_token_idx = torch.argmax(joint_output, dim=-1)
    if pred_token_idx == 0:
        timestamp += 1
    else:
        results.append(pred_token_idx.view(1).item())
        pred_input = torch.tensor([pred_token_idx], dtype=torch.int32).to(DEVICE)
        pred_output, hidden = predictor.inference(pred_input, hidden)

print(results)
"""

AttributeError: 'Predictor' object has no attribute 'inference'