In [1]:
import torch
from torch.nn.functional import pad
import torch.nn
from data import YesNoDataset, LibriSpeechDataset
from model import CausalConformerModel
from torchaudio.functional import rnnt_loss
from torch.nn.functional import log_softmax
from torchmetrics.functional import char_error_rate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Preferred GPU for this notebook. Selecting "cuda:4" unconditionally only
# fails later (at the first `.to(DEVICE)`) on hosts that expose fewer than
# five GPUs, so fall back to the first GPU when the preferred index is absent.
PREFERRED_CUDA_INDEX = 4

if torch.cuda.is_available():
    _cuda_index = (
        PREFERRED_CUDA_INDEX
        if PREFERRED_CUDA_INDEX < torch.cuda.device_count()
        else 0
    )
    DEVICE = torch.device(f"cuda:{_cuda_index}")
else:
    # No CUDA runtime available: run everything on the CPU.
    DEVICE = torch.device("cpu")

In [11]:
# YesNo data: the dataset object provides both the collate function used for
# batching and the index -> token table used to turn indices back into text.
dataset = YesNoDataset(model_sample_rate=16000, wav_dir_path="datasets/waves_yesno/")
idx_to_token = dataset.idx_to_token

# Fixed ordering (no shuffling) and the final partial batch is kept.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    collate_fn=dataset.collate_fn,
    batch_size=8,
    shuffle=False,
    drop_last=False,
)

In [3]:
# NOTE(review): this import ideally belongs in the notebook's top import cell;
# it is hoisted to the top of this cell so the dependency is visible before use.
import IPython.display as ipd

# Single source of truth for the sample rate: the dataset resamples to this
# rate and the audio widget below must play back at the same rate, otherwise
# the preview is pitched/sped incorrectly.
LIBRI_SAMPLE_RATE = 16000

libri_dataset = LibriSpeechDataset(
    root="./datasets",
    split="train",
    resampling_rate=LIBRI_SAMPLE_RATE,
    vocab_file_path="vocabs/librispeech_train_960h.json",
)
libri_dataloader = torch.utils.data.DataLoader(
    libri_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=libri_dataset.collate_fn,
    drop_last=False,
)

# Audio preview of the first item. The second-to-last field of the item is
# presumably the waveform -- TODO confirm against LibriSpeechDataset.__getitem__.
ipd.Audio(libri_dataset[0][-2], rate=LIBRI_SAMPLE_RATE)

In [46]:
# Location of the saved YesNo checkpoint.
cpt_path = "./cpts/yesno_causal_conformer/836b576aac7744c2857c00ae170f7b91/model_YesNo_80.pth"

# NOTE(review): torch.load unpickles the file, so only point it at trusted
# checkpoints. Tensors are remapped onto DEVICE while loading.
cpt = torch.load(cpt_path, map_location=DEVICE)

model = CausalConformerModel(
    vocab_size=8,
    # --- encoder ---
    encoder_input_size=40,
    encoder_subsampled_input_size=128,
    encoder_num_conformer_blocks=1,
    encoder_ff_hidden_size=128,
    encoder_conv_hidden_size=128,
    encoder_conv_kernel_size=16,
    encoder_mha_num_heads=4,
    encoder_dropout=0,
    # --- encoder subsampling ---
    encoder_subsampling_kernel_size1=3,
    encoder_subsampling_stride1=2,
    encoder_subsampling_kernel_size2=3,
    encoder_subsampling_stride2=2,
    encoder_num_previous_frames="all",
    # --- embedding / predictor / joint network ---
    embedding_size=64,
    predictor_hidden_size=64,
    predictor_num_layers=1,
    jointnet_hidden_size=64,
    # --- decoding ---
    blank_idx=dataset.blank_idx,  # requires the YesNo `dataset` cell to have run
    decoder_buffer_size=4,
)
model = model.to(DEVICE)

# Restore the trained weights stored under the "model" key of the checkpoint;
# left as the last expression so the key-matching report is displayed.
model.load_state_dict(cpt["model"])


<All keys matched successfully>

In [48]:
# Put the model in evaluation mode before decoding.
model.eval()

# Take the first batch from the YesNo dataloader; only the encoder input is
# moved to DEVICE here, the remaining fields are used as returned by collate_fn.
benc_input, bpred_input, benc_input_length, bpred_input_length, baudio_sec = next(iter(dataloader))
benc_input = benc_input.to(DEVICE)

# Decoding never backpropagates, so disable autograd bookkeeping to avoid
# building a graph (and holding activations) for no reason.
with torch.no_grad():
    bhyp_token_indices = model.streaming_greedy_inference(benc_input, benc_input_length)

# Display the hypothesised token indices, one list per utterance in the batch.
bhyp_token_indices

[[0,
  1,
  2,
  5,
  3,
  4,
  5,
  3,
  4,
  5,
  3,
  4,
  5,
  3,
  4,
  5,
  3,
  4,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2],
 [0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  3,
  4,
  5,
  3,
  4,
  5,
  0,
  1,
  2,
  5,
  3,
  4,
  5,
  0,
  1,
  2,
  5,
  3,
  4],
 [0,
  1,
  2,
  5,
  3,
  4,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  3,
  4,
  5,
  0,
  1,
  2],
 [0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  3,
  4,
  5,
  0,
  1,
  2,
  5,
  3,
  4,
  5,
  3,
  4],
 [3,
  4,
  5,
  3,
  4,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  3,
  4,
  5,
  3,
  4,
  5,
  3,
  4],
 [3,
  4,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2],
 [3,
  4,
  5,
  0,
  1,
  2,
  5,
  3,
  4,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  0,
  1,
  2,
  5,
  3,
  4,
  5,
  3,
  4],
 [0,
  1,
  2,
  5,
  3,
  4

In [49]:
def _tokens_to_text(token_indices):
    """Concatenate the `idx_to_token` entries for a sequence of token indices."""
    return "".join(idx_to_token[token_idx] for token_idx in token_indices)


# Reference token indices: keep only the first `bpred_input_length[i]` entries
# of row i (the remainder is presumably padding -- TODO confirm in collate_fn).
bans_token_indices = [
    target_row[: bpred_input_length[row_idx]].tolist()
    for row_idx, target_row in enumerate(bpred_input)
]

# Hypothesis and reference transcripts as plain strings, in batch order.
bhyp_text = list(map(_tokens_to_text, bhyp_token_indices))
bans_text = list(map(_tokens_to_text, bans_token_indices))

In [50]:
# Print each hypothesis above its reference, with a blank line between pairs.
for pair_idx, hyp_line in enumerate(bhyp_text):
    print(f"hyp: {hyp_line}", f"ans: {bans_text[pair_idx]}", "", sep="\n")

hyp: yes no no no no no yes yes
ans: yes no no no no no yes yes

hyp: yes yes no no yes no yes no
ans: yes yes no no yes no yes no

hyp: yes no yes yes yes yes no yes
ans: yes no yes yes yes yes no yes

hyp: yes yes yes yes no yes no no
ans: yes yes yes yes no yes no no

hyp: no no yes yes yes no no no
ans: no no yes yes yes no no no

hyp: no yes yes yes yes yes yes yes
ans: no yes yes yes yes yes yes yes

hyp: no yes no yes yes yes no no
ans: no yes no yes yes yes no no

hyp: yes no yes yes yes no yes no
ans: yes no yes yes yes no yes no



In [13]:
# Toy input: only x.shape[1] is used below, as the side length of the square masks.
x = torch.ones(3, 10, 5)
input_length = x.size(1)

# True strictly above the main diagonal (column index greater than row index).
future_mask = torch.ones(input_length, input_length).triu(diagonal=1).bool()
print(future_mask)

# Either the string "all" (nothing earlier is masked) or an integer number of
# earlier positions that stay unmasked.
NUM_PREVIOUS_FRAMES = "all"
if NUM_PREVIOUS_FRAMES != "all":
    # True for entries more than NUM_PREVIOUS_FRAMES columns left of the diagonal.
    previous_mask = (
        torch.ones(input_length, input_length)
        .tril(diagonal=-(NUM_PREVIOUS_FRAMES + 1))
        .bool()
    )
else:
    previous_mask = torch.zeros(input_length, input_length, dtype=torch.bool)
print(previous_mask)

# An entry is masked when either of the two masks marks it.
future_and_previous_mask = future_mask | previous_mask

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])
tensor([[False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, Fa

In [15]:
# Display the combined mask: the element-wise OR of future_mask and previous_mask.
future_and_previous_mask

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])