In [1]:
import torch
from torch.nn.functional import pad
import torch.nn
from data import YesNoDataset, LibriSpeechDataset, get_dataloader
from model import CausalConformerModel, TorchAudioConformerModel
from torchaudio.functional import rnnt_loss
from torch.nn.functional import log_softmax
from torchmetrics.functional import char_error_rate
from tqdm import tqdm
import torchaudio
import glob
import json
import os
from tokenizer import SentencePieceTokenizer
from create_json import create_librispeech_json
from modules import torchaudio_conformer
from modules.conformer.normalization import TimewiseBatchNormalization

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Disabled one-off step: the call below sits inside a string literal, so nothing is executed;
# the cell only echoes the text back as its output.
# NOTE(review): presumably this builds the character-level SentencePiece model from the
# 960h transcripts — remove the surrounding quotes to run it.
"""
SentencePieceTokenizer.create_model(
    transcription_file_path="vocabs/librispeech_train_960h_transcripts.txt",
    model_prefix="librispeech_char",
    num_tokens=1024,
    model_type="char",
    character_coverage=1.0,
)
"""

'\nSentencePieceTokenizer.create_model(\n    transcription_file_path="vocabs/librispeech_train_960h_transcripts.txt",\n    model_prefix="librispeech_char",\n    num_tokens=1024,\n    model_type="char",\n    character_coverage=1.0,\n)\n'

In [26]:
# Normalization layer under test. input_size=3 matches dim 1 of the (1, 3, 1) probe tensors
# fed to it in the following cells.
batch_norm = TimewiseBatchNormalization(input_size=3)

In [27]:
# Training-mode probe: switch the layer to train mode, feed one random (1, 3, 1) tensor,
# and print the input followed by the layer's output.
# NOTE(review): the saved output is all zeros, presumably because the statistics are
# computed from this single sample — confirm against the module implementation.
batch_norm.train()
x = torch.randn(1, 3, 1)
print(x)
print(batch_norm(x))

tensor([[[ 1.5375],
         [ 2.1648],
         [-1.5052]]])
tensor([[[0.],
         [0.],
         [0.]]], grad_fn=<TransposeBackward0>)


In [29]:
# Re-compute the normalization by hand from the statistics stored on the layer,
# so the result can be compared with what `batch_norm(x)` returns.
time_and_chan_wise_mean = batch_norm.moving_mean
time_and_chan_wise_var = batch_norm.moving_var

# eps is applied twice: once inside the square root and once on the resulting std.
# NOTE(review): presumably this mirrors the module's own formula — confirm.
time_and_chan_wise_std = torch.sqrt(time_and_chan_wise_var + batch_norm.eps)
time_and_chan_wise_std = time_and_chan_wise_std + batch_norm.eps
normalized_x = (x - time_and_chan_wise_mean) / time_and_chan_wise_std

# Show the result; without a trailing expression the cell produces no output on re-run,
# although the saved notebook shows a tensor here.
normalized_x


tensor([[[0.],
         [0.],
         [0.]]])

In [30]:
# Eval-mode probe: same experiment as the training-mode cell, but with the layer in eval mode
# and a fresh random (1, 3, 1) input.
# NOTE(review): the saved output has magnitudes in the thousands; presumably the stored
# variance is close to zero so the division is dominated by eps — confirm in the module.
batch_norm.eval()
x = torch.randn(1, 3, 1)
print(x)
print(batch_norm(x))

tensor([[[ 1.1630],
         [-1.8460],
         [ 0.0052]]])
tensor([[[ -3744.7776],
         [-40104.4258],
         [ 15103.0430]]], grad_fn=<TransposeBackward0>)


In [7]:
# Display the running mean stored on the layer.
# NOTE(review): the saved output has shape (1, 3, 4), which does not match the (1, 3, 1)
# probes above — it appears to come from an earlier run with a different input.
batch_norm.moving_mean

tensor([[[ 0.7907, -1.0085,  1.5201,  0.0829],
         [ 0.7827, -1.2146,  0.3603,  0.6646],
         [-2.3024, -1.9154, -0.2944,  0.4802]]])

In [3]:
# GPU index used on the original machine.
PREFERRED_CUDA_INDEX = 4

if torch.cuda.is_available():
    # Clamp to the last visible GPU: "cuda:4" does not exist on hosts with fewer than
    # five devices, and every later `.to(DEVICE)` would fail there.
    DEVICE = torch.device(
        f"cuda:{min(PREFERRED_CUDA_INDEX, torch.cuda.device_count() - 1)}"
    )
else:
    DEVICE = torch.device("cpu")

In [5]:
# Restore a TorchAudioConformerModel from a training checkpoint.
# The checkpoint dict provides the constructor kwargs ("model_args") and the weights ("model").
# Note: this cell and the CausalConformerModel cell both bind `model`; whichever ran last wins.
model_path = "./artifacts/librispeech_small_test/a3ed4353137f48a39a789451847375ac/artifacts/model_70.pth"
with open(model_path, "rb") as f:
    # map_location remaps tensors saved on another device (e.g. a GPU that is absent here)
    # onto DEVICE; without it loading fails on the CPU fallback.
    # SECURITY: torch.load unpickles the file — only load checkpoints from a trusted source.
    cpt = torch.load(f, map_location=DEVICE)
model_state = cpt["model"]
model_args = cpt["model_args"]
model = TorchAudioConformerModel(**model_args).to(DEVICE)
# Last expression: displays the key-matching report returned by load_state_dict.
model.load_state_dict(model_state)

<All keys matched successfully>

In [4]:
# Restore a CausalConformerModel from a training checkpoint.
# The checkpoint dict provides the constructor kwargs ("model_args") and the weights ("model").
# Note: this cell and the TorchAudioConformerModel cell both bind `model`; whichever ran last wins.
model_path = "./librispeech_small/artifacts/f4291c590aba4faf8aff7947af9af524/artifacts/model_20.pth"
with open(model_path, "rb") as f:
    # map_location remaps tensors saved on another device (e.g. a GPU that is absent here)
    # onto DEVICE; without it loading fails on the CPU fallback.
    # SECURITY: torch.load unpickles the file — only load checkpoints from a trusted source.
    cpt = torch.load(f, map_location=DEVICE)
model_state = cpt["model"]
model_args = cpt["model_args"]
model = CausalConformerModel(**model_args).to(DEVICE)
# Last expression: displays the key-matching report returned by load_state_dict.
model.load_state_dict(model_state)

<All keys matched successfully>

In [7]:
import random

# Fix Python's RNG before the loader is created.
# NOTE(review): presumably batching/shuffling in the data module draws from `random` — confirm.
random.seed(0)

tokenizer = SentencePieceTokenizer(model_file_path="./vocabs/librispeech_1024_bpe.model")
dataset = LibriSpeechDataset(
    json_file_path="./json/librispeech_train-clean-100.json",
    tokenizer=tokenizer,
    resampling_rate=16000,
)
dataloader = get_dataloader(
    dataset,
    pad_idx=tokenizer.pad_token_id,
    batch_sec=60,
    num_workers=1,
    pin_memory=True,
)

# Discard the first 30 batches and keep the 31st as the working batch for the cells below.
iterator = iter(dataloader)
for _ in range(30):
    next(iterator)
(
    _,
    benc_input,
    bpred_input,
    benc_input_length,
    bpred_input_length,
    baudio_sec,
) = next(iterator)

# Only the two padded input tensors are moved to DEVICE; the length objects are left as returned.
benc_input, bpred_input = benc_input.to(DEVICE), bpred_input.to(DEVICE)

Batch Prepare: 100%|██████████| 28539/28539 [00:00<00:00, 352943.80it/s]


In [8]:
# Evaluate the loaded model on the working batch:
#   1. full-batch forward pass and RNN-T loss (summed, then divided by the batch size),
#   2. greedy decoding of the first sample only,
#   3. side-by-side print of hypothesis and reference text.
model.eval()
with torch.no_grad():
    bpadded_output, bpadded_ctc_log_probs, bsubsampled_enc_input_length = model(
        padded_enc_input=benc_input,
        enc_input_lengths=benc_input_length,
        padded_pred_input=bpred_input,
        pred_input_lengths=bpred_input_length,
    )
    loss = rnnt_loss(
        logits=bpadded_output,
        targets=bpred_input,
        logit_lengths=bsubsampled_enc_input_length.to(DEVICE),
        target_lengths=bpred_input_length.to(DEVICE),
        blank=tokenizer.blank_token_id,
        reduction="sum",
    )
    # Per-utterance average of the summed loss.
    print(loss / bpred_input.shape[0])

    # Decode only sample 0; the batch dimension is re-added with unsqueeze(0).
    bhyp_token_indices = model.greedy_inference(
        enc_inputs=benc_input[0].unsqueeze(0), enc_input_lengths=[benc_input_length[0]]
    )
    # Reference token ids for every sample, with the padding beyond each length cut off.
    bans_token_indices = [
        bpred_input[i, : bpred_input_length[i]].tolist() for i in range(bpred_input.shape[0])
    ]
    # Hypotheses are converted to text once (the original did this twice with identical arguments).
    bhyp_text = tokenizer.batch_token_ids_to_text(bhyp_token_indices)
    bans_text = tokenizer.batch_token_ids_to_text(bans_token_indices)
    # zip stops at the shorter list, so only the decoded sample is printed.
    for hyp_text, ans_text in zip(bhyp_text, bans_text):
        print(f"hyp: {hyp_text}")
        print(f"ans: {ans_text}")


tensor(12.7601, device='cuda:4')
hyp: the explained also what they wanted by acting as if they had a piece of blubber in their mouths and then pretending to cut instead of territ i have not as yet noticed the four years whom we had on board
ans: they explained also what they wanted by acting as if they had a piece of blubber in their mouth and then pretending to cut instead of tear it i have not as yet noticed the fuegians whom we had on board


In [47]:

# Softmax over the last axis of the joint-network output.
# torch.softmax is numerically stable; the hand-written exp(x) / exp(x).sum() form overflows
# to inf/nan for large logits and evaluates exp twice.
p = torch.softmax(bpadded_output, dim=-1)

# Probability assigned to index 0 of the last axis at every position of the first sample.
# NOTE(review): index 0 is presumably the blank token — confirm against tokenizer.blank_token_id.
p[0][:, :, 0]

tensor([[1.0000, 0.9954, 0.7857,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 0.9913, 0.9820,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 0.9791, 0.9395,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9745, 0.9999, 0.9968,  ..., 1.0000, 1.0000, 1.0000],
        [0.9396, 0.9999, 0.9977,  ..., 1.0000, 1.0000, 1.0000],
        [0.9995, 0.9949, 0.9989,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:4')

In [35]:
# `logp` is not assigned anywhere in the saved notebook, so this cell fails on a fresh kernel.
# NOTE(review): assumed to be the log-softmax of the joint output; the saved values
# (about 0 at index 0, large negatives elsewhere) are consistent with that — confirm.
logp = log_softmax(bpadded_output, dim=-1)

# Log-probabilities at the first position of the first sample.
logp[0][0][0]

tensor([-2.0266e-06, -2.5104e+01, -2.5104e+01,  ..., -3.5290e+01,
        -2.8254e+01, -3.2530e+01], device='cuda:4')

In [18]:
# Step-by-step greedy transducer decoding of sample 0, written out by hand so the joint
# logits can be inspected. Runs with autograd enabled (no torch.no_grad()).
# NOTE(review): the encoder length passed here is the padded length benc_input[0].size(0),
# not benc_input_length[0] as used for model.greedy_inference — confirm this is intended.
enc_output, _ = model.encoder(
                benc_input[0].unsqueeze(0), torch.tensor([benc_input[0].size(0)])
            )  # [1, subsampled_enc_input_length, output_size]
# Prime the predictor with a single blank token and no previous hidden state.
pred_input = torch.tensor([[tokenizer.blank_token_id]], dtype=torch.int32).to(enc_output.device)
pred_output, hidden = model.predictor.forward_wo_prepend(pred_input, torch.tensor([1]), hidden=None)
timestamp = 0
hyp_tokens = []
while timestamp < enc_output.shape[1]:
    # Combine the encoder frame at `timestamp` with the current predictor output.
    enc_output_at_t = enc_output[0, timestamp, :]
    logits = model.jointnet(enc_output_at_t.view(1, 1, -1), pred_output)
    pred_token = logits.argmax(dim=-1)
    # Debug stop: timestamp starts at 0, so the loop always exits here on its first pass and
    # the decoding logic below is never reached. Remove these two lines to decode fully.
    if timestamp == 0:
        break
    if pred_token != tokenizer.blank_token_id:
        # Non-blank: emit the token and advance the predictor, staying on the same frame.
        hyp_tokens.append(pred_token.item())
        pred_input = torch.tensor([[pred_token]], dtype=torch.int32).to(enc_output.device)
        pred_output, hidden = model.predictor.forward_wo_prepend(pred_input, torch.tensor([1]), hidden=hidden)
    else:
        # Blank: move on to the next encoder frame.
        timestamp += 1
    
    # Safety valve against runaway emission.
    if len(hyp_tokens) > 100:
        print("detect")
        break
# Sorted joint logits of the last evaluated step (frame 0 because of the debug stop above).
print(torch.sort(logits[0][0][0]))

torch.return_types.sort(
values=tensor([-34.1344, -32.8127, -32.6079,  ...,  -5.3108,  -3.6692,   9.8395],
       device='cuda:4', grad_fn=<SortBackward0>),
indices=tensor([631, 278, 879,  ..., 101,  25,   0], device='cuda:4'))


'▁at'

In [None]:
# Simulated streaming decode: for every prefix length i >= 7 the encoder is re-run on the
# first i frames, and the last encoder output is decoded greedily until a blank is
# predicted or six tokens have been emitted for that prefix.
# NOTE(review): `enc_input`, `pred_output` and `hidden` are not assigned in this cell.
# `pred_output`/`hidden` are left over from the step-by-step decoding cell; `enc_input` is
# not assigned in any visible cell (presumably a single utterance such as benc_input[0]).
model.eval()
hyp_token_indices = []
for i in tqdm(range(7, enc_input.shape[0])):
    enc_x_len = torch.tensor([i])
    enc_x = enc_input[:i, :].unsqueeze(0)
    enc_y, enc_sub_x_len = model.encoder(enc_x, enc_x_len)
    # Most recent encoder frame, shaped [1, 1, feature].
    joint_x = enc_y[0, -1, :][None, None, :]
    num_token_indices = 0
    # Stops once more than five tokens have been emitted for this prefix.
    while num_token_indices <= 5:
        logits = model.jointnet(joint_x, pred_output)
        hyp = torch.argmax(logits, dim=-1)
        if hyp == tokenizer.blank_token_id:
            break
        # Non-blank: record the token and feed it back into the predictor.
        hyp_token_indices.append(hyp.item())
        pred_input = torch.tensor([[hyp]], dtype=torch.int32).to(enc_input.device)
        pred_output, hidden = model.predictor.forward_wo_prepend(
            pred_input, torch.tensor([1]), hidden=hidden
        )
        num_token_indices += 1


In [None]:
# Display the token ids collected by the streaming decode loop.
hyp_token_indices

In [None]:
# Draft of a second streaming decode loop.
# As saved, the `break` right after the encoder call ends the loop on its first pass, so
# everything after it inside the loop is unreachable, `hyp_token_indices` stays empty and
# the final decode runs on an empty list.
# NOTE(review): `enc_input`, `batch`, `batch_lengths`, `pred_output` and `hidden` are not
# assigned in this cell; `batch` and `batch_lengths` are not assigned in any visible cell.
model.eval()
hyp_token_indices = []
for i in tqdm(range(8, enc_input.shape[0])):
    # Encode the first i frames only.
    enc_x = enc_input[:i, :].unsqueeze(0)
    enc_x_len = torch.tensor([i])
    enc_y, enc_sub_x_len = model.encoder(enc_x, enc_x_len)
    # Debug stop: nothing below this line in the loop body is executed.
    break
    batch_enc_output, batch_subsampled_length = model.encoder(batch, batch_lengths)
    subsampled_length = batch_subsampled_length[0]
    # NOTE: the JointNet only applies linear layers and is independent across time steps,
    # so the current encoder output alone is sufficient.
    enc_output = batch_enc_output[0][subsampled_length - 1].view(1, 1, -1)
    num_token_indices = 0
    while True:
        logits = model.jointnet(enc_output, pred_output)[0]
        pred_token_idx = torch.argmax(logits, dim=-1)
        if pred_token_idx == model.blank_idx:
            break
        else:
            # Non-blank: record the token and feed it back into the predictor.
            num_token_indices += 1
            hyp_token_indices.append(pred_token_idx.item())
            pred_input = torch.tensor([[pred_token_idx]], dtype=torch.int32).to(enc_input.device)
            pred_output, hidden = model.predictor.forward_wo_prepend(
                pred_input, torch.tensor([1]), hidden=hidden
            )

        # Cap of five emitted tokens per step.
        if num_token_indices >= 5:
            break
hyp_text = tokenizer.token_ids_to_text(hyp_token_indices)
print(hyp_text)

In [None]:
# Decode the first sample of the working batch with the model's built-in streaming greedy
# search and print the resulting text.
model.eval()
with torch.no_grad():
    hyp = model.streaming_greedy_inference(
        # benc_input[:1] keeps the batch dimension, same as benc_input[0].unsqueeze(0).
        enc_inputs=benc_input[:1],
        enc_input_lengths=[benc_input_length[0]],
    )
    hyp_text = tokenizer.batch_token_ids_to_text(hyp)
    print(hyp_text)

In [None]:
# Reference token ids per sample: each row of the padded target tensor is cut to its own
# length and converted to a plain Python list.
bans_token_indices = [
    padded_row[: bpred_input_length[row_no]].tolist()
    for row_no, padded_row in enumerate(bpred_input)
]

In [None]:
# Display the reference transcripts decoded from the per-sample target token ids.
tokenizer.batch_token_ids_to_text(bans_token_indices)

In [None]:
# Display the hypothesis text; the single id list is wrapped in an outer list because the
# batch decoder is called with a list of id lists elsewhere in this notebook.
tokenizer.batch_token_ids_to_text([hyp_token_indices])

In [None]:
# Full-batch forward pass; only the first returned value (the joint output) is kept.
# The evaluation cell unpacks three values from the same call, so `out, _ = ...` raises
# "too many values to unpack"; starred unpacking accepts two or more return values.
# no_grad avoids building an autograd graph for what is an inspection-only pass.
with torch.no_grad():
    out, *_ = model(
        padded_enc_input=benc_input,
        enc_input_lengths=benc_input_length,
        padded_pred_input=bpred_input,
        pred_input_lengths=bpred_input_length,
    )

In [None]:
# Display the blank index stored on the model.
model.blank_idx

In [None]:
# Print the arg-max index over the last axis for every position of the first sample's output.
print(out[0].argmax(-1))

In [None]:
# Walk the whole training loader once, printing each batch's total audio duration, then
# check that the number of distinct collected indices equals the dataset size.
# NOTE(review): assumes the first element of each batch holds hashable dataset indices.
sampled_idx = []
for bidx, bx, by, bx_len, by_len, baudio_sec in dataloader:
    sampled_idx.append(bidx)
    print(sum(baudio_sec))
print(sampled_idx)

# Flatten the per-batch index collections into a single list.
sampled_indices = [index for batch_ids in sampled_idx for index in batch_ids]
print(len(dataset) == len(set(sampled_indices)))

In [None]:
# dev-other split, batched by total audio duration (batch_sec=100).
libri_dataset = LibriSpeechDataset(
    resampling_rate=16000,
    tokenizer=tokenizer,
    json_file_path="./json/librispeech_dev-other.json",
)
libri_dataloader = get_dataloader(
    libri_dataset,
    pad_idx=tokenizer.pad_token_id,
    batch_sec=100,
    num_workers=1,
    pin_memory=True,
)

In [None]:
# Display one dataset item (index 10) to inspect what the dataset returns.
libri_dataset[10]

In [None]:
# Same coverage check as for the training loader, run on the dev-other loader: print each
# batch's total audio duration, then compare the number of distinct indices to the dataset size.
# NOTE(review): assumes the first element of each batch holds hashable dataset indices.
sampled_idx = []
for bidx, bx, by, bx_len, by_len, baudio_sec in libri_dataloader:
    sampled_idx.append(bidx)
    print(sum(baudio_sec))
print(sampled_idx)

# Flatten the per-batch index collections into a single list.
sampled_indices = [index for batch_ids in sampled_idx for index in batch_ids]
print(len(libri_dataset) == len(set(sampled_indices)))

In [3]:
# Build a square boolean mask over dim 1 of x that combines
#   * a "future" part: True strictly above the diagonal, and
#   * a "previous" part: True for entries more than NUM_PREVIOUS_FRAMES below the diagonal.
# NOTE(review): True presumably marks positions attention must not look at — confirm at the use site.
x = torch.ones(3, 10, 5)
input_length = x.size(1)

future_mask = torch.ones(input_length, input_length).triu(diagonal=1).bool()
print(future_mask)

# "all" leaves the whole history visible; an integer limits the look-back window.
NUM_PREVIOUS_FRAMES = "all"
if NUM_PREVIOUS_FRAMES != "all":
    previous_mask = (
        torch.ones(input_length, input_length).tril(diagonal=-(NUM_PREVIOUS_FRAMES + 1)).bool()
    )
else:
    previous_mask = torch.zeros(input_length, input_length, dtype=torch.bool)

future_and_previous_mask = future_mask | previous_mask
print(future_and_previous_mask)

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])
tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  T

In [None]:
# Constant test input for the normalization experiments below; axes are read as
# [B, D, T] (batch, feature, time).
x = torch.ones(3, 10, 5) # [B, D, T]


In [None]:
# Cumulative normalization along the last axis of x, written at notebook level.
# The original cell was pasted from a method and referenced `self.eps`, `self.affine`,
# `self.gamma` and `self.beta`, which raise NameError outside the class.
# NOTE(review): 1e-8 is the value the neighbouring cells use; the module's own eps may differ.
eps = 1e-8

# Running sum over the last axis, pooled over dim 0 and copied back to every dim-0 entry.
cum_sum = x.cumsum(dim=-1).sum(dim=0).repeat(x.shape[0], 1, 1)
# Number of values contributing at each position: (position index + 1) * size of dim 0.
cum_num_element = (
    (torch.arange(1, x.shape[-1] + 1) * x.shape[0]).repeat(x.shape[0], x.shape[1], 1).to(x.device)
)
cum_mean = cum_sum / cum_num_element
# Each squared deviation is taken against the running mean at its own position before
# being accumulated.
cum_var = ((x - cum_mean) ** 2).cumsum(dim=-1).sum(dim=0).repeat(x.shape[0], 1, 1) / cum_num_element
# eps is applied twice: inside the square root and again on the resulting std.
cum_std = torch.sqrt(cum_var + eps)
cum_std = cum_std + eps
normalized_x = (x - cum_mean) / cum_std
# The optional affine step of the method (normalized_x * gamma + beta) is left out here
# because no module instance provides gamma/beta in this cell.

In [None]:
# Normalize x with statistics taken over dim 0 only, i.e. a separate mean/variance for
# every (dim 1, dim 2) position.
eps = 1e-8
time_and_chan_wise_sum = x.sum(dim=0, keepdim=True).repeat(x.shape[0], 1, 1)
time_and_chan_wise_mean = time_and_chan_wise_sum / x.shape[0]
time_and_chan_wise_var = (
    (x - time_and_chan_wise_mean).pow(2).sum(dim=0, keepdim=True).repeat(x.shape[0], 1, 1)
    / x.shape[0]
)
# eps goes in twice: under the square root and on top of the resulting std.
time_and_chan_wise_std = (time_and_chan_wise_var + eps).sqrt() + eps
normalized_x = (x - time_and_chan_wise_mean) / time_and_chan_wise_std
# Affine step intentionally disabled in this cell:
#if self.affine:
#    normalized_x = normalized_x * self.gamma + self.beta



In [None]:
# Normalize x with statistics taken over dim 1 only, i.e. a separate mean/variance for
# every (dim 0, dim 2) position.
eps = 1e-8
time_and_batch_wise_sum = x.sum(dim=1, keepdim=True).repeat(1, x.shape[1], 1)
time_and_batch_wise_mean = time_and_batch_wise_sum / x.shape[1]
time_and_batch_wise_var = (
    (x - time_and_batch_wise_mean).pow(2).sum(dim=1, keepdim=True).repeat(1, x.shape[1], 1)
    / x.shape[1]
)
# eps goes in twice: under the square root and on top of the resulting std.
time_and_batch_wise_std = (time_and_batch_wise_var + eps).sqrt() + eps
normalized_x = (x - time_and_batch_wise_mean) / time_and_batch_wise_std

In [None]:
# Random integer input created as [B, T, D] = (3, 10, 5) and immediately re-laid-out as
# [B, D, T]; gamma holds one scale of 1.0 per entry of dim 1 (D) of the transposed tensor.
x = torch.randint(0, 10, (3, 10, 5)).transpose(1, 2)  # [B, D, T]
gamma = torch.ones(x.size(1))

In [None]:
# Apply gamma along D: move D to the last axis so the 1-D gamma broadcasts against it,
# scale, then restore the [B, D, T] layout.
n_x = (x.transpose(1, 2) * gamma).transpose(1, 2)  # [B, D, T]

In [None]:
# Display the gamma-scaled tensor.
n_x