In [None]:
import torch
import torch.nn.functional as F
import json
import numpy as np
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.asr.pytorch_backend.asr_init import load_trained_model
from espnet.asr.pytorch_backend.asr import CustomConverter

In [None]:
# All computation in this notebook runs on the CPU; the model and the tensors
# created in later cells are moved to this device.
device_name = "cpu"
device = torch.device(device_name)

In [None]:
# Read the "utts" table from the dev_clean data.json dump.
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's locale-dependent default encoding.
data_json_path = "/root/espnet/egs/librispeech/asr1/dump/dev_clean/deltafalse/data.json"
with open(data_json_path, "r", encoding="utf-8") as f:
    data_json = json.load(f)["utts"]

In [None]:
# Split the utterance table into batches. The second positional argument of
# make_batchset is 1 here (presumably the batch size — TODO confirm).
utts_per_batch = 1
data_batches = make_batchset(data_json, utts_per_batch)

In [None]:
# Loader applied to a batch of data.json entries (ASR mode, outputs loaded too).
# The data comes from the dev_clean dump and is only scored in this notebook,
# so the preprocessing mode is set to evaluation ('train': False) rather than
# training. With preprocess_conf=None this is presumably a no-op today, but it
# keeps any train-only preprocessing off if a config is supplied later.
load_tr = LoadInputsAndTargets(
    mode='asr',
    load_output=True,
    preprocess_conf=None,
    preprocess_args={'train': False},  # evaluation mode for held-out data
)

In [None]:
# Restore the trained model together with the arguments it was trained with,
# then move it to `device`.
model_path = "/root/espnet/egs/librispeech/asr1/exp/train_clean_100_pytorch_phone_hybrid/results/model.acc.best"
model, train_args = load_trained_model(model_path)
model = model.to(device=device)
# This notebook only runs inference: leave training mode so dropout and similar
# layers are disabled and repeated runs give the same scores. Assigning the
# result (eval() returns the module) avoids dumping the model repr as cell output.
model = model.eval()

In [None]:
# Batch converter, called later as `converter(data_input, device)`.
# Configured with subsampling_factor=1 and a float32 dtype.
converter = CustomConverter(
    subsampling_factor=1,
    dtype=torch.float32,
)

In [None]:
# Display the first batch so its structure can be inspected.
data_batches[0]

In [None]:
# Map every symbol in the model's output vocabulary (train_args.char_list) to
# its integer id, i.e. its position in that list. enumerate() yields plain
# Python ints, so no NumPy scalars end up in the lookup table.
phone_to_int = {phone: idx for idx, phone in enumerate(train_args.char_list)}

In [None]:
# Keyword to search for, written as a space-separated symbol sequence
# (presumably ARPAbet-style phones for "manifest", padded with <space> on both
# sides — TODO confirm). It is split on " " again when the tokens are looked up.
keyword_symbols = ["<space>", "M", "AE", "N", "AH", "F", "EH", "S", "T", "<space>"]
keyword = " ".join(keyword_symbols)

In [None]:
# Look up the integer id of each keyword symbol, then wrap the ids as a
# single-row tensor (one target sequence) on `device`.
keyword_ids = [phone_to_int[symbol] for symbol in keyword.split(" ")]
keyword_tokens = torch.tensor([keyword_ids]).to(device)

In [None]:
# Run the loader on the first batch and wrap the result in a one-element list,
# which is the form handed to `converter` in the next cell.
first_batch = data_batches[0]
data_input = [load_tr(first_batch)]

In [None]:
# Convert the loaded batch with `converter`, targeting `device`.
# The result is indexed as data[0] and data[1] when the encoder is called.
data = converter(data_input, device)

In [None]:
# Run the encoder on the converted batch (data[0], data[1]).
# Nothing is trained here, so autograd is disabled: the encoder output then
# carries no graph back through the encoder, which saves memory for every
# score computed from it later.
with torch.no_grad():
    hs_pad, hlens, _ = model.enc(data[0], data[1])

In [None]:
# Sliding-window keyword search: score many [start_idx, end_idx) spans of the
# encoder output with the model's CTC loss against `keyword_tokens`, and keep
# the span with the lowest loss.
MAX_FRAMES = 400   # only the first 400 encoder frames are searched
MIN_WINDOW = 50    # shortest candidate span, in encoder frames
MAX_WINDOW = 500   # longest candidate span, in encoder frames
START_STRIDE = 2   # step between candidate start frames
END_STRIDE = 4     # step between candidate end frames

print(hs_pad.shape)
hs_pad = hs_pad[:, :MAX_FRAMES]
print(hs_pad.shape)

num_frames = hs_pad.size(1)
best_loss = float("inf")
# Reset explicitly so a result left over from an earlier run is never reported
# when this run finds no candidate.
best_indices = None
ctc_losses = []
indices = []

# Scoring only: without autograd, the losses collected below are plain numbers
# and do not each keep a computation graph alive.
with torch.no_grad():
    for start_idx in range(0, num_frames - MIN_WINDOW, START_STRIDE):
        # NOTE(review): this bound stops candidate ends short of the last
        # frames of the utterance — confirm whether that is intended.
        end_stop = min(start_idx + MAX_WINDOW, num_frames - 1)
        for end_idx in range(start_idx + MIN_WINDOW, end_stop, END_STRIDE):
            window_len = end_idx - start_idx
            # Slice exactly `window_len` frames so the tensor agrees with the
            # length passed to CTC.
            window = hs_pad[:, start_idx:end_idx]
            ctc_loss = model.ctc(window, torch.tensor([window_len]), keyword_tokens).item()
            ctc_losses.append(ctc_loss)
            indices.append([start_idx, end_idx])
            if ctc_loss < best_loss:
                best_loss = ctc_loss
                best_indices = [start_idx, end_idx]

if best_indices is None:
    raise ValueError(
        "Keyword search found no scorable window: the utterance has {0} encoder "
        "frames (at least {1} are needed), or every CTC loss was inf/NaN.".format(
            num_frames, MIN_WINDOW + 1
        )
    )

In [None]:
# Report the best-scoring span. Frame indices are scaled by 4/100 to obtain
# seconds (presumably 4x encoder subsampling of 10 ms frames — TODO confirm).
start_frame = best_indices[0]
end_frame = best_indices[1]
start_sec = start_frame * 4 / 100
end_sec = end_frame * 4 / 100
print(
    "Best loss : {0} \t keyword found between {1} to {2} seconds".format(
        str(best_loss), str(start_sec), str(end_sec)
    )
)

In [None]:
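# Debug aid (disabled): uncomment to list every candidate window ranked from lowest to highest CTC loss.
# print(sorted(list(zip(ctc_losses, indices))))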