In [5]:
import editdistance


def calc_cer(target_text, predicted_text) -> float:
    """Character error rate (CER) of a prediction against a reference.

    The CER is the edit distance between the two strings, counted in
    characters, divided by the number of characters in the reference.

    Args:
        target_text: Reference (ground-truth) transcription.
        predicted_text: Hypothesis to score.

    Returns:
        The error rate as a float. It may exceed 1.0 when the prediction is
        much longer than the reference. An empty reference yields 0.0
        whatever the prediction is, because the ratio is undefined there.
    """
    # Guard against division by zero. Return a float (not the int 0) so the
    # result type always matches the annotation.
    if len(target_text) == 0:
        return 0.0

    return editdistance.distance(predicted_text, target_text) / len(target_text)


def calc_wer(target_text, predicted_text) -> float:
    """Word error rate (WER) of a prediction against a reference.

    Both strings are tokenized with ``split(' ')`` and the WER is the edit
    distance between the two token lists divided by the number of reference
    tokens. Splitting on a single space means consecutive spaces produce
    empty-string tokens, which take part in the distance like any other word.

    Args:
        target_text: Reference (ground-truth) transcription.
        predicted_text: Hypothesis to score.

    Returns:
        The error rate as a float. It may exceed 1.0 when the prediction has
        many more words than the reference. An empty reference yields 0.0
        whatever the prediction is, matching ``calc_cer``.
    """
    # The emptiness check has to happen BEFORE tokenizing: "".split(' ')
    # returns [''] (length 1), so a length check on the token list never fires.
    if len(target_text) == 0:
        return 0.0

    target_words = target_text.split(' ')
    predicted_words = predicted_text.split(' ')
    return editdistance.distance(predicted_words, target_words) / len(target_words)


In [6]:
import numpy as np

# Sanity-check calc_wer / calc_cer on hand-computed reference cases.
# Each case is (target, prediction, expected WER, expected CER).
for target, pred, expected_wer, expected_cer in (
    (
        "if you can not measure it you can not improve it",
        "if you can nt measure t yo can not i",
        0.454,
        0.25,
    ),
    (
        "if you cant describe what you are doing as a process you dont know what youre doing",
        "if you cant describe what you are doing as a process you dont know what youre doing",
        0.0,
        0.0,
    ),
    (
        "one measurement is worth a thousand expert opinions",
        "one  is worth thousand opinions",
        0.375,
        0.392,
    ),
):
    # Compute both metrics first, then compare them with a 1e-3 tolerance.
    wer, cer = calc_wer(target, pred), calc_cer(target, pred)
    assert np.isclose(wer, expected_wer, atol=1e-3), (
        f"true: {target}, pred: {pred}, expected wer {expected_wer} != your wer {wer}"
    )
    assert np.isclose(cer, expected_cer, atol=1e-3), (
        f"true: {target}, pred: {pred}, expected cer {expected_cer} != your cer {cer}"
    )

In [7]:
from hw_asr.text_encoder.ctc_char_text_encoder import CTCCharTextEncoder

In [9]:
# Create the text encoder via CTCCharTextEncoder.get_simple_alphabet().
# NOTE(review): presumably returns an encoder over a default alphabet —
# confirm against hw_asr.text_encoder.ctc_char_text_encoder.
encoder = CTCCharTextEncoder.get_simple_alphabet()

In [17]:
import torch

In [42]:
# Load the four saved tensors that are later passed to the CTC loss.
# The paths are relative, so this cell depends on the kernel's working directory.
# NOTE(review): torch.load unpickles its input — only load files from a
# trusted source.
log_probs = torch.load("log_probs.pth")
input_lengths = torch.load("input_lengths.pth")
targets = torch.load("targets.pth")
target_lengths = torch.load("target_lengths.pth")

In [43]:
from torch.nn import CTCLoss


In [44]:
# Show the shape of every tensor that is passed to the CTC loss.
log_probs.shape, targets.shape, input_lengths.shape, target_lengths.shape

(torch.Size([944, 20, 28]),
 torch.Size([20, 178]),
 torch.Size([20]),
 torch.Size([20]))

In [45]:
# CTC criterion with reduction='none': the per-sample losses are returned
# as-is instead of being averaged or summed into one scalar.
loss = CTCLoss(reduction='none')

In [46]:
# Evaluate the CTC criterion on the loaded batch; the cell displays the
# resulting tensor of loss values.
loss(
    log_probs=log_probs,
    targets=targets,
    input_lengths=input_lengths,
    target_lengths=target_lengths,
)

tensor([ 657.8795, 1683.1038, 1934.3074, 1541.3494, 1956.8157, 1615.7985,
         859.8023, 1875.2427, 2455.1677,  683.3677, 1890.3376, 1083.0039,
        1492.3206, 2146.5991, 1580.9023,  796.2955,  629.4090,  813.8438,
        1196.7618,  788.1266], grad_fn=<CtcLossBackward>)

In [47]:
# Display the target_lengths tensor that was passed to the CTC loss.
target_lengths

tensor([ 61, 111, 178,  93, 157, 105,  64, 148, 176,  36, 176,  71, 139, 174,
         89,  60,  37,  53, 103,  40])

In [32]:
# Sanity-check log_probs: number of NaN entries, minimum and maximum value.
torch.isnan(log_probs).sum(), torch.min(log_probs), torch.max(log_probs)

(tensor(0),
 tensor(-194.7010, grad_fn=<MinBackward1>),
 tensor(0., grad_fn=<MaxBackward1>))

In [33]:
# Sanity-check targets: number of NaN entries, minimum and maximum value.
torch.isnan(targets).sum(), torch.min(targets), torch.max(targets)

(tensor(0), tensor(0.), tensor(27.))

In [34]:
# Sanity-check input_lengths: number of NaN entries, minimum and maximum value.
# NOTE(review): this cell's execution count (34) is lower than that of the
# cell that loads the tensors (42), so the recorded output may predate the
# current files. The recorded min/max of 128 is also smaller than the largest
# recorded target length (178), which CTC cannot align — re-run after loading
# and confirm.
torch.isnan(input_lengths).sum(), torch.min(input_lengths), torch.max(input_lengths)

(tensor(0), tensor(128), tensor(128))

In [35]:
# Sanity-check target_lengths: number of NaN entries, minimum and maximum value.
torch.isnan(target_lengths).sum(), torch.min(target_lengths), torch.max(target_lengths)

(tensor(0), tensor(36), tensor(178))

In [20]:
# Display log_probs.
# NOTE(review): this prints the whole tensor; log_probs.shape or a small
# slice would keep the saved notebook output short.
log_probs

tensor([[[-3.3639, -3.3025, -3.3200,  ..., -3.2768, -3.3583, -3.3349],
         [-3.3635, -3.3029, -3.3197,  ..., -3.2787, -3.3583, -3.3350],
         [-3.3616, -3.3039, -3.3194,  ..., -3.2774, -3.3583, -3.3341],
         ...,
         [-3.3628, -3.3032, -3.3197,  ..., -3.2777, -3.3584, -3.3348],
         [-3.3579, -3.3041, -3.3199,  ..., -3.2759, -3.3585, -3.3329],
         [-3.3601, -3.3042, -3.3195,  ..., -3.2770, -3.3584, -3.3336]],

        [[-3.3606, -3.3038, -3.3199,  ..., -3.2769, -3.3586, -3.3336],
         [-3.3636, -3.3029, -3.3197,  ..., -3.2787, -3.3583, -3.3350],
         [-3.3549, -3.3037, -3.3224,  ..., -3.2755, -3.3582, -3.3314],
         ...,
         [-3.3635, -3.3027, -3.3198,  ..., -3.2785, -3.3582, -3.3352],
         [-3.3472, -3.3010, -3.3271,  ..., -3.2754, -3.3644, -3.3276],
         [-3.3640, -3.3030, -3.3198,  ..., -3.2760, -3.3582, -3.3350]],

        [[-3.3629, -3.3033, -3.3199,  ..., -3.2773, -3.3585, -3.3344],
         [-3.3635, -3.3029, -3.3196,  ..., -3