In [1]:
import torch
from torch.nn.functional import pad
import torch.nn
from data import YesNoDataset
from model import CausalConformerModel
from torchaudio.functional import rnnt_loss
from torch.nn.functional import log_softmax
from torchmetrics.functional import char_error_rate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# GPU ordinal to prefer when CUDA is available.
GPU_INDEX = 4
if torch.cuda.is_available():
    # torch.device("cuda:N") can be built for any N, but moving tensors to a non-existent
    # ordinal fails at .to() time — clamp to the GPUs actually present on this machine.
    DEVICE = torch.device(f"cuda:{min(GPU_INDEX, torch.cuda.device_count() - 1)}")
else:
    DEVICE = torch.device("cpu")

In [3]:
# YesNo evaluation set, loaded with model_sample_rate=16000.
dataset = YesNoDataset(wav_dir_path="datasets/waves_yesno/", model_sample_rate=16000)

# Fixed order and no dropped tail: every utterance is visited exactly once, in the
# same order on every run, so the first batch below is reproducible.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=8,
    shuffle=False,
    drop_last=False,
    collate_fn=dataset.collate_fn,
)

In [4]:
# Lookup from vocabulary index to token string, used to render index sequences as text.
idx_to_token = getattr(dataset, "idx_to_token")

In [5]:

# Small causal Conformer transducer. These hyperparameters must match the ones the
# checkpoint was trained with: load_state_dict (strict by default) raises on any
# missing/unexpected key or tensor-shape mismatch.
model = CausalConformerModel(
    vocab_size=8,  # NOTE(review): presumably len(idx_to_token) including blank — confirm
    encoder_input_size=40,
    encoder_subsampled_input_size=32,
    encoder_num_conformer_blocks=1,
    encoder_ff_hidden_size=32,
    encoder_conv_hidden_size=32,
    encoder_conv_kernel_size=16,
    encoder_mha_num_heads=4,
    encoder_dropout=0,
    encoder_subsampling_kernel_size1=3,
    encoder_subsampling_stride1=2,
    encoder_subsampling_kernel_size2=3,
    encoder_subsampling_stride2=2,
    # NOTE(review): the streaming-decoding cell has its own NUM_PREVIOUS_FRAMES;
    # presumably the two should be kept consistent — confirm.
    encoder_num_previous_frames="all",
    embedding_size=64,
    predictor_hidden_size=64,
    predictor_num_layers=1,
    jointnet_hidden_size=64,
    blank_idx=dataset.blank_idx,
).to(DEVICE)
# map_location remaps tensors saved on another device (e.g. a specific GPU) onto DEVICE,
# so the checkpoint also loads on CPU-only machines. torch.load unpickles the file —
# only load checkpoints from trusted sources.
checkpoint = torch.load("model_YesNo.pth", map_location=DEVICE)
# Switch dropout/normalization layers to inference behaviour. Done before loading so the
# cell still ends with (and displays) the load_state_dict result.
model.eval()
model.load_state_dict(checkpoint["model_state_dict"])


<All keys matched successfully>

In [6]:
# First batch of the (unshuffled) loader: padded encoder inputs, reference token
# indices, and their per-utterance lengths.
benc_input, bpred_input, benc_input_length, bpred_input_length = next(iter(dataloader))
benc_input = benc_input.to(DEVICE)
# Inference only: skip building an autograd graph (saves memory and time).
with torch.no_grad():
    bhyp_token_indices = model.greedy_inference(benc_input, benc_input_length)

In [7]:
# Names for the padded encoder inputs and their unpadded lengths, iterated per utterance.
enc_inputs, enc_input_lengths = benc_input, benc_input_length

In [8]:
# Simulated streaming greedy RNN-T decoding.
# For each input frame i (from FIRST_FRAME_IDX on), the encoder is re-run on the input
# window ending at frame i. Windows are encoded BUFFER_SIZE at a time, and only the newest
# encoder output of each window is passed to the joint network.
BUFFER_SIZE = 4
# "all" = whole prefix; an int N = the last N input frames (current frame included).
# NOTE(review): the mask cell's int mode allows N previous frames *plus* the current one,
# but that mask presumably acts on subsampled frames — confirm the intended window.
NUM_PREVIOUS_FRAMES = "all"
FIRST_FRAME_IDX = 5  # NOTE(review): presumably the shortest prefix the subsampling convs accept — confirm
MAX_SYMBOLS_PER_FRAME = 10  # emission stops once more than this many tokens came from one frame


def _greedy_decode_step(enc_output, pred_output, hidden, hyp_token_indices):
    """Greedily emit tokens for a single encoder frame.

    Repeatedly queries the joint network, appending each non-blank argmax token to
    ``hyp_token_indices`` and advancing the predictor. Stops on blank, or once more than
    MAX_SYMBOLS_PER_FRAME tokens have been emitted for this frame.

    Returns the updated ``(pred_output, hidden)`` so predictor state carries over.
    """
    num_token_indices = 0
    while True:
        logits = model.jointnet(enc_output, pred_output)[0]
        pred_token_idx = torch.argmax(logits, dim=-1)
        if pred_token_idx == dataset.blank_idx:
            break
        num_token_indices += 1
        token = pred_token_idx.item()
        hyp_token_indices.append(token)
        pred_input = torch.tensor([[token]], dtype=torch.int32, device=DEVICE)
        pred_output, hidden = model.predictor.forward_wo_prepend(pred_input, torch.tensor([1]), hidden=hidden)
        if num_token_indices > MAX_SYMBOLS_PER_FRAME:
            break
    return pred_output, hidden


def _flush_buffer(buffer, pred_output, hidden, hyp_token_indices):
    """Encode the buffered input windows as one padded batch and decode each window's last frame.

    Returns the updated ``(pred_output, hidden)``.
    """
    batch = torch.nn.utils.rnn.pad_sequence(buffer, batch_first=True, padding_value=0)
    batch_lengths = torch.tensor([len(x) for x in buffer])
    batch_enc_output, batch_subsampled_length = model.encoder(batch, batch_lengths)
    for j in range(len(batch_enc_output)):
        subsampled_length = batch_subsampled_length[j]
        # NOTE: JointNet is only linear layers applied per time step (independent across
        # time), so the current (newest) encoder output alone is sufficient.
        enc_output = batch_enc_output[j][subsampled_length - 1].view(1, 1, -1)
        pred_output, hidden = _greedy_decode_step(enc_output, pred_output, hidden, hyp_token_indices)
    return pred_output, hidden


batch_hyp_token_indices = []
# Pure inference: without no_grad the predictor hidden state would chain an autograd
# graph across the entire utterance.
with torch.no_grad():
    for enc_input, enc_input_length in zip(enc_inputs, enc_input_lengths):
        # Strip padding beyond this utterance's true length.
        if enc_input.size(0) > enc_input_length:
            enc_input = enc_input[:enc_input_length, :]
        hyp_token_indices = []
        buffer = []
        # Prime the predictor with the blank token as the start symbol.
        pred_input = torch.tensor([[dataset.blank_idx]], dtype=torch.int32).to(DEVICE)
        pred_output, hidden = model.predictor.forward_wo_prepend(pred_input, torch.tensor([1]), hidden=None)
        for i in range(FIRST_FRAME_IDX, enc_input.shape[0]):
            if NUM_PREVIOUS_FRAMES == "all":
                buffer.append(enc_input[: i + 1])
            else:
                buffer.append(enc_input[max(i + 1 - NUM_PREVIOUS_FRAMES, 0): i + 1])
            if len(buffer) == BUFFER_SIZE:
                pred_output, hidden = _flush_buffer(buffer, pred_output, hidden, hyp_token_indices)
                buffer = []
        # Decode the trailing windows that did not fill a whole buffer; otherwise the
        # last (up to BUFFER_SIZE - 1) frames of the utterance would never be decoded.
        if buffer:
            pred_output, hidden = _flush_buffer(buffer, pred_output, hidden, hyp_token_indices)
        batch_hyp_token_indices.append(hyp_token_indices)


In [11]:
def _indices_to_text(token_indices):
    """Concatenate the token strings for a sequence of vocabulary indices."""
    return "".join(idx_to_token[idx] for idx in token_indices)


# Reference index sequences with the padding beyond each utterance's length removed.
bans_token_indices = [
    row[: bpred_input_length[i]].tolist() for i, row in enumerate(bpred_input)
]
bhyp_text = [_indices_to_text(hyp) for hyp in batch_hyp_token_indices]
bans_text = [_indices_to_text(ans) for ans in bans_token_indices]

In [12]:
# Hypothesis and reference for each utterance, separated by a blank line.
for i, hyp in enumerate(bhyp_text):
    print(f"hyp: {hyp}", f"ans: {bans_text[i]}", "", sep="\n")

hyp: no no no no no no yes yes
ans: yes no no no no no yes yes

hyp: yes yes no no yes no yes no
ans: yes yes no no yes no yes no

hyp: no no yes yes yes yes no yes
ans: yes no yes yes yes yes no yes

hyp: yes yes yes yes no yes no no
ans: yes yes yes yes no yes no no

hyp: no no yes yes yes no no no
ans: no no yes yes yes no no no

hyp: no yes yes yes yes yes yes yes
ans: no yes yes yes yes yes yes yes

hyp: no yes no yes yes yes no no
ans: no yes no yes yes yes no no

hyp: no no yes yes yes no yes no
ans: yes no yes yes yes no yes no



In [13]:
def make_future_and_previous_mask(input_length, num_previous_frames="all"):
    """Build boolean self-attention masks for causal, optionally windowed, attention.

    ``True`` marks a (query i, key j) pair that is masked:
    * future keys (j > i) are always masked;
    * if ``num_previous_frames`` is an int N, keys more than N steps back (j < i - N)
      are masked too, so each query sees at most N previous frames plus itself.

    Args:
        input_length: sequence length L.
        num_previous_frames: ``"all"`` (no look-back limit) or a non-negative int.

    Returns:
        ``(future_mask, previous_mask, future_and_previous_mask)``, each a bool tensor
        of shape (L, L).

    Raises:
        ValueError: if ``num_previous_frames`` is a negative int.
    """
    future_mask = torch.triu(torch.ones(input_length, input_length), diagonal=1).bool()
    if num_previous_frames == "all":
        previous_mask = torch.zeros(input_length, input_length).bool()
    else:
        if num_previous_frames < 0:
            raise ValueError(f"num_previous_frames must be >= 0 or 'all', got {num_previous_frames}")
        # tril with diagonal=-(N+1) marks every j <= i - (N + 1), i.e. j < i - N.
        previous_mask = torch.tril(
            torch.ones(input_length, input_length), diagonal=-(num_previous_frames + 1)
        ).bool()
    return future_mask, previous_mask, torch.logical_or(future_mask, previous_mask)


x = torch.ones(3, 10, 5)
input_length = x.shape[1]
NUM_PREVIOUS_FRAMES = "all"
future_mask, previous_mask, future_and_previous_mask = make_future_and_previous_mask(
    input_length, NUM_PREVIOUS_FRAMES
)
print(future_mask)
print(previous_mask)

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])
tensor([[False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, Fa

In [15]:
# Display the combined mask. With NUM_PREVIOUS_FRAMES == "all", previous_mask is all False,
# so this equals future_mask.
future_and_previous_mask

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])