In [1]:
import torch
from torch.nn.functional import pad
import torch.nn
from data import YesNoDataset
from model import CausalConformerModel
from torchaudio.functional import rnnt_loss
from torch.nn.functional import log_softmax
from torchmetrics.functional import char_error_rate

In [2]:
# GPU index this notebook was developed on; change it to suit the local machine.
CUDA_DEVICE_INDEX = 4
# A hard-coded "cuda:4" breaks at the first `.to(DEVICE)` ("invalid device
# ordinal") on hosts with fewer than five GPUs. In that case fall back to the
# first GPU, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    DEVICE = torch.device(
        f"cuda:{CUDA_DEVICE_INDEX}"
        if CUDA_DEVICE_INDEX < torch.cuda.device_count()
        else "cuda:0"
    )
else:
    DEVICE = torch.device("cpu")

In [3]:
# Yes/No corpus; model_sample_rate presumably makes the dataset resample audio
# to the rate the model expects (16 kHz) — confirm in data.YesNoDataset.
dataset = YesNoDataset(
    wav_dir_path="datasets/waves_yesno/",
    model_sample_rate=16000,
)
# Fixed order (no shuffling) and keep the final partial batch; batching is
# delegated to the dataset's own collate_fn.
loader_options = dict(
    batch_size=8,
    shuffle=False,
    collate_fn=dataset.collate_fn,
    drop_last=False,
)
dataloader = torch.utils.data.DataLoader(dataset, **loader_options)

In [4]:
# Lookup from token index to its string form; used below to turn decoded
# index sequences back into text.
idx_to_token = getattr(dataset, "idx_to_token")

In [22]:

# NOTE(review): vocab_size is hard-coded; presumably it must match the
# dataset's vocabulary (including blank) — confirm against dataset.idx_to_token.
model = CausalConformerModel(
    vocab_size=8,
    encoder_input_size=40,
    encoder_subsampled_input_size=64,
    encoder_num_conformer_blocks=1,
    encoder_ff_hidden_size=64,
    encoder_conv_hidden_size=64,
    encoder_conv_kernel_size=8,
    encoder_mha_num_head=4,
    encoder_dropout=0,
    embedding_size=64,
    predictor_hidden_size=64,
    predictor_num_layers=1,
    jointnet_hidden_size=64,
    blank_idx=dataset.blank_idx,
).to(DEVICE)
# This notebook only runs inference: switch off training-mode behaviour of any
# dropout / normalisation layers instead of relying on greedy_inference to do it.
model.eval()
# map_location remaps tensors saved on another device (e.g. a different GPU)
# onto DEVICE, so the checkpoint also loads on CPU-only or smaller hosts.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load("model_YesNo.pth", map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])


<All keys matched successfully>

In [23]:
# Take the first batch; the 4-tuple layout comes from dataset.collate_fn
# (encoder input, prediction-network targets, and their lengths).
benc_input, bpred_input, benc_input_length, bpred_input_length = next(iter(dataloader))
benc_input = benc_input.to(DEVICE)
# benc_input_length is left on its original device, as before; presumably
# greedy_inference expects that — confirm.
# Decoding needs no gradients, so skip building an autograd graph.
with torch.no_grad():
    bhyp_token_indices = model.greedy_inference(benc_input, benc_input_length)

In [24]:
def _token_indices_to_text(token_indices):
    """Concatenate the vocabulary strings for ``token_indices`` into one string."""
    return "".join(idx_to_token[token_idx] for token_idx in token_indices)


# Reference index sequences: keep the first bpred_input_length[i] entries of
# each row (the remainder is presumably padding).
bans_token_indices = [
    row[: bpred_input_length[i]].tolist() for i, row in enumerate(bpred_input)
]
bhyp_text = [_token_indices_to_text(hyp) for hyp in bhyp_token_indices]
bans_text = [_token_indices_to_text(ans) for ans in bans_token_indices]

In [25]:
# Print each hypothesis above its reference transcript, pairs separated by a blank line.
for i, hyp in enumerate(bhyp_text):
    print(f"hyp: {hyp}")
    print(f"ans: {bans_text[i]}")
    print()

hyp: yes no no no no no yes yes
ans: yes no no no no no yes yes

hyp: yes yes no no yes yes no
ans: yes yes no no yes no yes no

hyp: yes no yes yes yes yes no yes
ans: yes no yes yes yes yes no yes

hyp: yes yes yes yes no yes no no
ans: yes yes yes yes no yes no no

hyp: no no yes yes yes no no no
ans: no no yes yes yes no no no

hyp: no yes yes yes yes yes yes yes
ans: no yes yes yes yes yes yes yes

hyp: no yes no yes yes yes no no
ans: no yes no yes yes yes no no

hyp: yes no yes yes yes no yes no
ans: yes no yes yes yes no yes no



In [1]:
import torch

In [3]:
# Dummy tensor of ones with shape (2, 3, 4); only x.size(1) == 3 is used below.
x = torch.ones((2, 3, 4))

In [4]:
seq_len = x.size(1)
# Additive causal mask: 0.0 on and below the diagonal (allowed), -inf strictly above it.
mask = torch.full((seq_len, seq_len), float('-inf')).triu(diagonal=1)

In [7]:
seq_len = x.size(1)
# Boolean causal mask: True strictly above the diagonal (future positions).
torch.ones(seq_len, seq_len).triu(diagonal=1).bool()

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [8]:
# Example valid lengths for a batch of 8 sequences.
input_lengths = torch.tensor([2, 3, 4, 5, 6, 7, 8, 9])
# Derive the padded length from the data rather than hard-coding it, so the
# padding mask below stays consistent if the example lengths are edited.
max_length = int(input_lengths.max())

In [9]:
# Padding mask via broadcasting a (1, max_length) position row against a
# (batch, 1) length column: True where position >= that sequence's length.
# The positions are built on input_lengths' device so the comparison also
# works when the lengths live on a GPU.
torch.arange(max_length, device=input_lengths.device)[None, :] >= input_lengths[:, None]

tensor([[False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False]])