In [1]:
import torch
from torch.nn.functional import pad
import torch.nn
from data import YesNoDataset, LibriSpeechDataset, get_dataloader, LibriSpeechTextDataset, get_text_dataloader
from model import CausalConformerModel, TorchAudioConformerModel
from torchaudio.functional import rnnt_loss
from torch.nn.functional import log_softmax
from torchmetrics.functional import char_error_rate, word_error_rate
from tqdm import tqdm
import torchaudio
import glob
import json
from tqdm import tqdm
import os
from tokenizer import SentencePieceTokenizer
from create_json import create_librispeech_json
from modules import torchaudio_conformer
from modules.conformer.normalization import CausalLayerNormalization

  from .autonotebook import tqdm as notebook_tqdm


In [22]:
""""
json_dict = {}
with open("./datasets/librispeech/librispeech-lm-norm.txt", "r") as f:
    texts = []
    # get num of raws
    counter = 0
    for line in tqdm(f.readlines()):
        # remove \n
        line = line[:-1]
        if line == "":
            continue
        json_dict[counter] = {
            "raw_transcript": line,
        }
        counter += 1
with open("./json/librispeech_text.json", "w") as f:
    json.dump(json_dict, f, indent=2)
"""

100%|██████████| 40418261/40418261 [00:31<00:00, 1303168.51it/s]


In [2]:
"""
SentencePieceTokenizer.create_model(
    transcription_file_path="vocabs/librispeech_train_960h_transcripts.txt",
    model_prefix="librispeech_char",
    num_tokens=1024,
    model_type="char",
    character_coverage=1.0,
)
"""

'\nSentencePieceTokenizer.create_model(\n    transcription_file_path="vocabs/librispeech_train_960h_transcripts.txt",\n    model_prefix="librispeech_char",\n    num_tokens=1024,\n    model_type="char",\n    character_coverage=1.0,\n)\n'

In [2]:
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

In [26]:
tokenizer = SentencePieceTokenizer(
    model_file_path="./vocabs/librispeech_1024_bpe.model"
)
dataset = LibriSpeechTextDataset(
    json_file_path="./json/librispeech_text.json",
    tokenizer=tokenizer,
)
dataloader = get_text_dataloader(
    dataset,
    batch_text_len=1000,
    num_workers=8,
    pad_idx=tokenizer.pad_token_id,
    pin_memory=True,
)



In [3]:
tokenizer = SentencePieceTokenizer(
    model_file_path="./vocabs/librispeech_1024_bpe.model"
)
dataset = LibriSpeechDataset(
    resampling_rate=16000,
    tokenizer=tokenizer,
    json_file_path="./json/librispeech_train-clean-100.json",
)
dataloader = get_dataloader(
    dataset,
    batch_sec=60,
    num_workers=8,
    pad_idx=tokenizer.pad_token_id,
    pin_memory=True,
)

iterator = iter(dataloader)
for _ in range(30):
    next(iterator)
_, benc_input, bpred_input, benc_input_length, bpred_input_length, baudio_sec = next(iterator)
benc_input = benc_input.to(DEVICE)
bpred_input = bpred_input.to(DEVICE)

Batch Prepare: 100%|██████████| 28539/28539 [00:00<00:00, 380216.38it/s]


In [5]:
bpred_input_length

tensor([60, 74, 14, 67], dtype=torch.int32)

In [4]:
bpred_input

tensor([[ 106,  886,   38,  320,  871,  151,    8,  663,   55,  106,  886,   38,
          320,  871,  422,  151,    8,   88,   13,   78,   18,   17,  267,  145,
         1005,  191, 1009,   43,  327,   87,    7,  344,  209,   25,  225,  156,
           28,    8,  659, 1004,  755,  179,   92, 1001,  999,  580,    8,   15,
          331, 1014,   25,   34,  968,   82,  140,  750,   27,   93, 1004,    3,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
            4,    4],
        [ 179,  925,   76,   24,  361,   34,  910,  125,   31,   26,  368,  758,
          741,  154,  919,   24,   27,  141,   58,  170, 1012,  119,   25,  277,
         1016,   21,  122,  138,   17,  595,   93,   10,  147,  271,   62,   39,
            8,  488,   25,  127,  443,  505,  198,   39, 1001,  965,  308,   39,
          198,  303,  710,  938,   51,  589,   62,   22,  980,   76,  553,  996,
          164,   98,  825,  812,   61,  128,  160,    8,  636,  125,  406,   67,
      

In [13]:
bpred_input.view(-1).shape

torch.Size([272])

In [3]:
a = torch.randn(2, 3)
torch.nn.functional.pad(a, (1, 0), value=0).shape

torch.Size([2, 4])

In [5]:
model_path = "./artifacts/librispeech_clean-100/1a8150c0e8f64726b4f7058ff8fc6c74/artifacts/model_40.pth"
with open(model_path, "rb") as f:
    cpt = torch.load(f)
model_state = cpt["model"]
model_args = cpt["model_args"]
model = CausalConformerModel(**model_args).to(DEVICE)
model.load_state_dict(model_state)

<All keys matched successfully>

In [7]:
with torch.no_grad():
    model.eval()
    bpadded_output, bpadded_ctc_log_probs, bsubsampled_enc_input_length = model(
        padded_enc_input=benc_input,
        enc_input_lengths=benc_input_length,
        padded_pred_input=bpred_input,
        pred_input_lengths=bpred_input_length,
    )
    loss = rnnt_loss(
        logits=bpadded_output,
        targets=bpred_input,
        logit_lengths=bsubsampled_enc_input_length.to(DEVICE),
        target_lengths=bpred_input_length.to(DEVICE),
        blank=tokenizer.blank_token_id,
        reduction="sum",
    )
    bhyp_token_indices = model.streaming_greedy_inference(
        enc_inputs=benc_input, enc_input_lengths=benc_input_length
    )
    bhyp_text = tokenizer.batch_token_ids_to_text(bhyp_token_indices)
    bans_token_indices = [
        bpred_input[i, : bpred_input_length[i]].tolist() for i in range(bpred_input.shape[0])
    ]
    bhyp_text = tokenizer.batch_token_ids_to_text(bhyp_token_indices)
    bans_text = tokenizer.batch_token_ids_to_text(bans_token_indices)
    
    cer = char_error_rate(bhyp_text, bans_text) * benc_input.shape[0]
    wer = word_error_rate(bhyp_text, bans_text) * benc_input.shape[0]




5it [01:20, 16.09s/it]


tensor(0.0658)

In [8]:
bhyp_token_indices

[[605,
  19,
  209,
  646,
  356,
  119,
  8,
  904,
  604,
  25,
  106,
  886,
  24,
  111,
  87,
  7,
  369,
  10,
  356,
  78,
  304,
  59,
  106,
  886,
  269,
  150,
  35,
  87,
  7,
  10,
  559,
  28,
  416,
  69,
  19,
  31,
  169,
  573,
  1014],
 [8,
  16,
  54,
  63,
  5,
  42,
  119,
  107,
  75,
  84,
  260,
  38,
  25,
  361,
  398,
  63,
  753,
  927,
  341,
  218,
  112,
  61,
  58,
  246,
  50,
  43,
  455,
  48,
  592,
  149,
  27,
  22,
  186,
  26,
  58,
  27,
  8,
  273,
  203,
  25,
  376,
  631,
  8,
  92,
  998,
  42,
  120,
  56,
  88,
  670,
  1004,
  151,
  293,
  55,
  32,
  58,
  467,
  19,
  95,
  382,
  199,
  414,
  997,
  1016,
  213,
  998,
  1004,
  384,
  628,
  24,
  266,
  699,
  572,
  28,
  8,
  86,
  102,
  1012,
  86,
  64,
  153],
 [674,
  422,
  108,
  187,
  62,
  25,
  262,
  1021,
  1000,
  1014,
  24,
  350,
  31,
  9,
  700,
  323,
  144,
  81,
  95,
  7,
  369,
  53,
  42,
  1001,
  312,
  31,
  220,
  1013,
  32],
 [145,
  461,
  32,
  

In [25]:
tokenizer.eos_token_id

3

In [24]:
bpadded_output[0][68, :, 3]

tensor([-26.7156, -26.0835, -25.7718, -28.0198, -25.3701, -27.5337, -23.8810,
        -24.2610, -26.7484, -29.9810, -24.0653, -24.8224, -26.5370, -32.6300,
        -32.5643, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628,
        -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628,
        -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628,
        -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628,
        -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628,
        -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628,
        -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628,
        -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628, -12.9628,
        -12.9628, -12.9628, -12.9628, -12.9628], device='cuda:4')

In [14]:
bhyp_token_indices

[[103, 19, 407, 19, 209, 402, 208, 465, 94, 27, 141, 58, 327],
 [35,
  209,
  314,
  158,
  145,
  656,
  58,
  96,
  230,
  388,
  129,
  534,
  444,
  281,
  94,
  184,
  206,
  67,
  8,
  672,
  34,
  447,
  128,
  95,
  31,
  790,
  1004,
  79,
  94,
  141,
  32,
  112,
  103,
  128,
  582,
  25,
  155,
  26,
  17,
  212,
  8,
  322,
  814,
  145,
  989],
 [593,
  519,
  27,
  694,
  27,
  129,
  975,
  313,
  338,
  25,
  344,
  103,
  27,
  603,
  275,
  61,
  19,
  253,
  669,
  27,
  583,
  90,
  1016,
  154,
  237,
  6,
  1001,
  1004,
  8,
  92,
  335,
  69,
  303,
  371,
  28,
  189,
  117,
  996,
  17,
  1005,
  168,
  267,
  96,
  131,
  45,
  517,
  788,
  25,
  112,
  129,
  10,
  44,
  118,
  1000,
  469,
  71,
  422,
  27,
  256,
  491,
  25,
  159,
  324,
  115,
  129,
  190,
  69,
  383,
  27,
  256,
  306,
  171,
  999],
 [202,
  71,
  209,
  226,
  32,
  156,
  28,
  256,
  444,
  25,
  397,
  87,
  114,
  256,
  231,
  150,
  46,
  72,
  167,
  998,
  101,
  8,
  

In [11]:
for i in range(benc_input.shape[0]):
    print(bhyp_text[i])
    print(bans_text[i])
    print()

but i think i will remember not to do it again
but i think i will remember not to do it again

he will see no change it is true my young men did not go out on the war path they had dreams for not doing so but they love and venerate the great white chief
he will see no change it is true my young men did not go out on the war path they had dreams for not doing so but they love and venerate the great white chief

please to return to my service well and good but to suppose that i am going to disturb or unhens the ancient usage of knight errantry is all nonsense and so my sancho get you back to your house and explain my intentions to your teresa
please to return to my service well and good but to suppose that i am going to disturb or unhinge the ancient usage of knight errantry is all nonsense and so my sancho get you back to your house and explain my intentions to your teresa

if you will bring one of your men and come with me yourself said gimblet at the conclusion of the interview
if you

In [None]:

# calc softmax
p = torch.exp(bpadded_output) / torch.exp(bpadded_output).sum(dim=-1, keepdim=True)
p[0][:, :, 0]

tensor([[1.0000, 0.9954, 0.7857,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 0.9913, 0.9820,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 0.9791, 0.9395,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9745, 0.9999, 0.9968,  ..., 1.0000, 1.0000, 1.0000],
        [0.9396, 0.9999, 0.9977,  ..., 1.0000, 1.0000, 1.0000],
        [0.9995, 0.9949, 0.9989,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:4')

In [2]:
logp[0][0][0]

NameError: name 'logp' is not defined

In [3]:
enc_output, _ = model.encoder(
                benc_input[0].unsqueeze(0), torch.tensor([benc_input[0].size(0)])
            )  # [1, subsampled_enc_input_length, output_size]
pred_input = torch.tensor([[tokenizer.blank_token_id]], dtype=torch.int32).to(enc_output.device)
pred_output, hidden = model.predictor.forward_wo_prepend(pred_input, torch.tensor([1]), hidden=None)
timestamp = 0
hyp_tokens = []
while timestamp < enc_output.shape[1]:
    enc_output_at_t = enc_output[0, timestamp, :]
    logits = model.jointnet(enc_output_at_t.view(1, 1, -1), pred_output)
    pred_token = logits.argmax(dim=-1)
    if timestamp == 0:
        break
    if pred_token != tokenizer.blank_token_id:
        hyp_tokens.append(pred_token.item())
        pred_input = torch.tensor([[pred_token]], dtype=torch.int32).to(enc_output.device)
        pred_output, hidden = model.predictor.forward_wo_prepend(pred_input, torch.tensor([1]), hidden=hidden)
    else:
        timestamp += 1
    
    if len(hyp_tokens) > 100:
        print("detect")
        break
print(torch.sort(logits[0][0][0]))

NameError: name 'model' is not defined

'▁at'

In [None]:
model.eval()
hyp_token_indices = []
for i in tqdm(range(7, enc_input.shape[0])):
    enc_x = enc_input[:i, :].unsqueeze(0)
    enc_x_len = torch.tensor([i])
    enc_y, enc_sub_x_len = model.encoder(enc_x, enc_x_len)
    joint_x = enc_y[0, -1, :].unsqueeze(0).unsqueeze(0)
    num_token_indices = 0
    while True:
        logits = model.jointnet(joint_x, pred_output)
        hyp = torch.argmax(logits, dim=-1)
        if hyp == tokenizer.blank_token_id:
            break
        else:
            hyp_token_indices.append(hyp.item())
            pred_input = torch.tensor([[hyp]], dtype=torch.int32).to(enc_input.device)
            pred_output, hidden = model.predictor.forward_wo_prepend(
                pred_input, torch.tensor([1]), hidden=hidden
            )
            num_token_indices += 1
        
        if num_token_indices > 5:
            break


In [None]:
hyp_token_indices

In [None]:
model.eval()
hyp_token_indices = []
for i in tqdm(range(8, enc_input.shape[0])):
    enc_x = enc_input[:i, :].unsqueeze(0)
    enc_x_len = torch.tensor([i])
    enc_y, enc_sub_x_len = model.encoder(enc_x, enc_x_len)
    break
    batch_enc_output, batch_subsampled_length = model.encoder(batch, batch_lengths)
    subsampled_length = batch_subsampled_length[0]
    # NOTE: JointNetは線形層を通しているだけであり、時刻に関して独立->現在のenc_outだけで十分
    enc_output = batch_enc_output[0][subsampled_length - 1].view(1, 1, -1)
    num_token_indices = 0
    while True:
        logits = model.jointnet(enc_output, pred_output)[0]
        pred_token_idx = torch.argmax(logits, dim=-1)
        if pred_token_idx == model.blank_idx:
            break
        else:
            num_token_indices += 1
            hyp_token_indices.append(pred_token_idx.item())
            pred_input = torch.tensor([[pred_token_idx]], dtype=torch.int32).to(enc_input.device)
            pred_output, hidden = model.predictor.forward_wo_prepend(
                pred_input, torch.tensor([1]), hidden=hidden
            )

        if num_token_indices >= 5:
            break
hyp_text = tokenizer.token_ids_to_text(hyp_token_indices)
print(hyp_text)

In [None]:
with torch.no_grad():
    model.eval()
    hyp = model.streaming_greedy_inference(
        enc_inputs=benc_input[0].unsqueeze(0), enc_input_lengths=[benc_input_length[0]]
    )
    hyp_text = tokenizer.batch_token_ids_to_text(hyp)
    print(hyp_text)

In [None]:
bans_token_indices = [
    bpred_input[i, : bpred_input_length[i]].tolist() for i in range(bpred_input.shape[0])
]

In [None]:
tokenizer.batch_token_ids_to_text(bans_token_indices)

In [None]:
tokenizer.batch_token_ids_to_text([hyp_token_indices])

In [None]:
out, _ = model(
    padded_enc_input=benc_input,
    enc_input_lengths=benc_input_length,
    padded_pred_input=bpred_input,
    pred_input_lengths=bpred_input_length,
)

In [None]:
model.blank_idx

In [None]:
print(out[0].argmax(-1))

In [None]:
sampled_idx = []
for bidx, bx, by, bx_len, by_len, baudio_sec in dataloader:
    sampled_idx.append(bidx)
    print(sum(baudio_sec))
print(sampled_idx)
sampled_indices = [item for sublist in sampled_idx for item in sublist]
print(len(set(sampled_indices)) == len(dataset))

In [None]:
libri_dataset = LibriSpeechDataset(
    json_file_path="./json/librispeech_dev-other.json", resampling_rate=16000, tokenizer=tokenizer
)
libri_dataloader = get_dataloader(
    libri_dataset,
    batch_sec=100,
    num_workers=1,
    pad_idx=tokenizer.pad_token_id,
    pin_memory=True,
)

In [None]:
libri_dataset[10]

In [None]:
sampled_idx = []
for bidx, bx, by, bx_len, by_len, baudio_sec in libri_dataloader:
    sampled_idx.append(bidx)
    print(sum(baudio_sec))
print(sampled_idx)
sampled_indices = [item for sublist in sampled_idx for item in sublist]
print(len(set(sampled_indices)) == len(libri_dataset))

In [3]:
x = torch.ones(3, 10, 5)
input_length=x.shape[1]
future_mask = torch.triu(torch.ones(input_length, input_length), diagonal=1).bool()
print(future_mask)
NUM_PREVIOUS_FRAMES = "all"
# mask before NUM_PREVIOUS_FRAMES
input_length=x.shape[1]
if NUM_PREVIOUS_FRAMES == "all":
    previous_mask = torch.zeros(input_length, input_length).bool()
else:
    previous_mask=torch.tril(torch.ones(input_length, input_length), diagonal=-(NUM_PREVIOUS_FRAMES+1)).bool()
future_and_previous_mask = torch.logical_or(future_mask, previous_mask)
print(future_and_previous_mask)

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])
tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  T

In [None]:
x = torch.ones(3, 10, 5) # [B, D, T]


In [None]:
cum_sum = x.cumsum(dim=-1).sum(dim=0).repeat(x.shape[0], 1, 1)
cum_num_element = (
    (torch.arange(1, x.shape[-1] + 1) * x.shape[0]).repeat(x.shape[0], x.shape[1], 1).to(x.device)
)
cum_mean = cum_sum / cum_num_element
cum_var = ((x - cum_mean) ** 2).cumsum(dim=-1).sum(dim=0).repeat(x.shape[0], 1, 1) / cum_num_element
cum_std = torch.sqrt(cum_var + self.eps)
cum_std = cum_std + self.eps
normalized_x = (x - cum_mean) / cum_std
if self.affine:
    normalized_x = normalized_x * self.gamma + self.beta

In [None]:
eps = 1e-8
time_and_chan_wise_sum = x.sum(dim=0).unsqueeze(0).repeat(x.shape[0], 1, 1)
time_and_chan_wise_mean = time_and_chan_wise_sum / x.shape[0]
time_and_chan_wise_var = ((x - time_and_chan_wise_mean) ** 2).sum(dim=0).unsqueeze(0).repeat(x.shape[0], 1, 1) / x.shape[0]
time_and_chan_wise_std = torch.sqrt(time_and_chan_wise_var + eps)
time_and_chan_wise_std = time_and_chan_wise_std + eps
normalized_x = (x - time_and_chan_wise_mean) / time_and_chan_wise_std
#if self.affine:
#    normalized_x = normalized_x * self.gamma + self.beta



In [None]:
eps = 1e-8
time_and_batch_wise_sum = x.sum(dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
time_and_batch_wise_mean = time_and_batch_wise_sum / x.shape[1]
time_and_batch_wise_var = ((x - time_and_batch_wise_mean) ** 2).sum(dim=1).unsqueeze(1).repeat(1, x.shape[1], 1) / x.shape[1]
time_and_batch_wise_std = torch.sqrt(time_and_batch_wise_var + eps)
time_and_batch_wise_std = time_and_batch_wise_std + eps
normalized_x = (x - time_and_batch_wise_mean) / time_and_batch_wise_std

In [None]:
x = torch.randint(0, 10, (3, 10, 5))
x # [B, T, D]
x = x.transpose(1, 2) # [B, D, T]
gamma = torch.ones(x.shape[1])

In [None]:
n_x = x.transpose(1, 2) # [B, T, D]
n_x = n_x * gamma
n_x = n_x.transpose(1, 2) # [B, D, T]

In [None]:
n_x