In [5]:
import editdistance


def calc_cer(target_text, predicted_text) -> float:
    """Character error rate of ``predicted_text`` measured against ``target_text``.

    The rate is the edit distance between the two sequences divided by the
    length of the target.

    Args:
        target_text: Reference transcription; its elements are compared one by
            one (characters when a string is passed).
        predicted_text: Hypothesis transcription to score.

    Returns:
        0.0 when both inputs are empty, 1.0 when the target is empty but the
        prediction is not, otherwise ``distance / len(target_text)``. The
        result can exceed 1.0 when the prediction is much longer than the
        target.
    """
    if len(target_text) == 0:
        # The ratio is undefined for an empty target. An empty prediction is a
        # perfect match; a non-empty one is counted as a full error rather
        # than being silently scored as perfect.
        return 0.0 if len(predicted_text) == 0 else 1.0

    return editdistance.distance(predicted_text, target_text) / len(target_text)


def calc_wer(target_text, predicted_text) -> float:
    """Word error rate of ``predicted_text`` measured against ``target_text``.

    Both strings are split into words on whitespace; the rate is the edit
    distance between the two word lists divided by the number of target words.

    Args:
        target_text: Reference transcription as a string.
        predicted_text: Hypothesis transcription as a string.

    Returns:
        0.0 when neither string contains a word, 1.0 when the target has no
        words but the prediction does, otherwise
        ``distance / number_of_target_words``. The result can exceed 1.0 when
        the prediction has many more words than the target.
    """
    # split() without a separator drops empty tokens: "" and whitespace-only
    # text give [], and repeated spaces do not create empty "words".
    # split(' ') returned [''] for "", so the empty-target guard below could
    # never trigger.
    target_words = target_text.split()
    predicted_words = predicted_text.split()
    if len(target_words) == 0:
        # Undefined ratio: perfect only if the prediction is empty as well.
        return 0.0 if len(predicted_words) == 0 else 1.0

    return editdistance.distance(predicted_words, target_words) / len(target_words)


In [6]:
import numpy as np

# Check calc_wer / calc_cer against reference values.
# Each case is (target, prediction, expected WER, expected CER).
for target, pred, expected_wer, expected_cer in (
    (
        "if you can not measure it you can not improve it",
        "if you can nt measure t yo can not i",
        0.454,
        0.25,
    ),
    (
        "if you cant describe what you are doing as a process you dont know what youre doing",
        "if you cant describe what you are doing as a process you dont know what youre doing",
        0.0,
        0.0,
    ),
    (
        "one measurement is worth a thousand expert opinions",
        "one  is worth thousand opinions",
        0.375,
        0.392,
    ),
):
    wer = calc_wer(target, pred)
    cer = calc_cer(target, pred)
    # The expected values are given to three decimals, hence atol=1e-3.
    assert np.isclose(wer, expected_wer, atol=1e-3), (
        f"true: {target}, pred: {pred}, expected wer {expected_wer} != your wer {wer}"
    )
    assert np.isclose(cer, expected_cer, atol=1e-3), (
        f"true: {target}, pred: {pred}, expected cer {expected_cer} != your cer {cer}"
    )

In [7]:
from hw_asr.text_encoder.ctc_char_text_encoder import CTCCharTextEncoder

In [9]:
# Bind the object returned by CTCCharTextEncoder.get_simple_alphabet() to `encoder`.
# NOTE(review): presumably a text encoder built over a default alphabet -- confirm in hw_asr.
encoder = CTCCharTextEncoder.get_simple_alphabet()

In [55]:
import torch

In [142]:
# Restore the objects inspected in the following cells from .pth files in the
# current working directory; the files are read in the order listed.
# NOTE(review): torch.load unpickles its input, so only load files you trust.
(
    log_probs,
    input_lengths,
    targets,
    target_lengths,
    spectrogram1,
    spectrogram2,
    result,
) = (
    torch.load(saved_path)
    for saved_path in (
        "log_probs.pth",
        "input_lengths.pth",
        "targets.pth",
        "target_lengths.pth",
        "spectrogram1.pth",
        "spectrogram2.pth",
        "result.pth",
    )
)

In [143]:
from torch.nn import CTCLoss


In [144]:
# Display the shape of each loaded tensor as one tuple.
tuple(
    loaded.shape
    for loaded in (
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        spectrogram1,
        spectrogram2,
        result,
    )
)

(torch.Size([658, 4, 28]),
 torch.Size([4, 127]),
 torch.Size([4]),
 torch.Size([4]),
 torch.Size([4, 658, 128]),
 torch.Size([4, 658, 128]),
 torch.Size([4, 658, 128]))

In [145]:
# reduction='none' keeps one loss value per batch element instead of reducing to a scalar.
loss = CTCLoss(reduction='none')

In [146]:
# Evaluate the CTC loss on the loaded batch; the cell output is the loss tensor.
loss(
    log_probs=log_probs,
    targets=targets,
    input_lengths=input_lengths,
    target_lengths=target_lengths,
)

tensor([1587.6580, 1313.0886, 1682.4307, 1681.2706], grad_fn=<CtcLossBackward>)

In [147]:
# Display the target_lengths tensor that is passed to the CTC loss call.
target_lengths

tensor([112,  82, 127, 124])

In [148]:
# Display the input_lengths tensor that is passed to the CTC loss call.
input_lengths

tensor([615, 500, 656, 658])

In [149]:
# For every entry along the first axis: index 0 of the second axis, first 20
# values of the last axis.
# NOTE(review): axes are presumably (batch, time, frequency) -- confirm.
spectrogram1[:, 0, :20]

tensor([[0.0000e+00, 6.2123e+02, 3.3449e+03, 0.0000e+00, 2.2509e+02, 2.6115e+02,
         0.0000e+00, 1.3664e+01, 1.4962e+00, 9.6077e+01, 1.1065e+02, 4.1968e+01,
         2.6610e+02, 0.0000e+00, 4.1583e+00, 4.1194e-01, 2.9761e+01, 8.5165e+00,
         2.3186e+00, 8.5824e-01],
        [0.0000e+00, 6.8417e-01, 3.6837e+00, 0.0000e+00, 1.1331e+02, 1.3147e+02,
         0.0000e+00, 3.6135e+02, 3.9570e+01, 1.4662e+02, 1.6886e+02, 3.3999e+01,
         2.1557e+02, 0.0000e+00, 1.6303e+02, 1.6151e+01, 1.1135e+01, 3.1863e+00,
         1.0416e+01, 3.8554e+00],
        [0.0000e+00, 2.3120e-02, 1.2448e-01, 0.0000e+00, 1.4084e+00, 1.6341e+00,
         0.0000e+00, 6.4227e+00, 7.0331e-01, 1.4390e+00, 1.6572e+00, 9.6978e-02,
         6.1489e-01, 0.0000e+00, 2.4434e-02, 2.4206e-03, 2.3275e-01, 6.6603e-02,
         3.4645e-01, 1.2824e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
       

In [150]:
# Same slice as taken from spectrogram1: index 0 of the second axis, first 20
# values of the last axis, for every entry along the first axis.
# NOTE(review): in the recorded output the four rows are nearly identical,
# unlike the spectrogram1 slice -- confirm whether that is expected.
spectrogram2[:, 0, :20]

tensor([[ 0.0000, -0.0737, -0.0737,  0.0000, -0.4424, -0.4424,  0.0000, -0.4050,
         -0.4050, -0.4306, -0.4306, -0.4426, -0.4426,  0.0000, -0.4436, -0.4436,
         -0.3608, -0.3608, -0.3634, -0.3634],
        [ 0.0000, -0.0749, -0.0749,  0.0000, -0.4426, -0.4426,  0.0000, -0.4048,
         -0.4048, -0.4305, -0.4305, -0.4427, -0.4427,  0.0000, -0.4435, -0.4435,
         -0.3608, -0.3608, -0.3634, -0.3634],
        [ 0.0000, -0.0749, -0.0749,  0.0000, -0.4429, -0.4429,  0.0000, -0.4050,
         -0.4050, -0.4307, -0.4307, -0.4428, -0.4428,  0.0000, -0.4436, -0.4436,
         -0.3608, -0.3608, -0.3634, -0.3634],
        [ 0.0000, -0.0749, -0.0749,  0.0000, -0.4429, -0.4429,  0.0000, -0.4050,
         -0.4050, -0.4307, -0.4307, -0.4428, -0.4428,  0.0000, -0.4436, -0.4436,
         -0.3608, -0.3608, -0.3634, -0.3634]], grad_fn=<SliceBackward>)

In [151]:
# For every entry along the first axis: index 0 of the second axis, first 5
# values of the last axis.
# NOTE(review): the recorded output shows identical rows for all four entries
# -- confirm whether that is expected.
result[:, 0, :5]

tensor([[ 0.0191,  0.0049,  0.0198, -0.0161, -0.0310],
        [ 0.0191,  0.0049,  0.0198, -0.0161, -0.0310],
        [ 0.0191,  0.0049,  0.0198, -0.0161, -0.0310],
        [ 0.0191,  0.0049,  0.0198, -0.0161, -0.0310]],
       grad_fn=<SliceBackward>)

In [139]:
# Index 0 of the first axis, every entry of the second axis, first 5 values of
# the last axis.
# NOTE(review): the recorded output shows identical rows for all four entries
# -- confirm whether that is expected.
log_probs[0, :, :5]

tensor([[-3.2745, -3.3577, -3.3136, -3.2968, -3.3361],
        [-3.2745, -3.3577, -3.3136, -3.2968, -3.3361],
        [-3.2745, -3.3577, -3.3136, -3.2968, -3.3361],
        [-3.2745, -3.3577, -3.3136, -3.2968, -3.3361]],
       grad_fn=<SliceBackward>)

In [140]:
# Apply log_softmax over the last dimension of log_probs and display the
# result; the trailing [:] is a full slice of that result.
log_probs.log_softmax(dim=-1)[:]

tensor([[[-3.2745, -3.3577, -3.3136,  ..., -3.3595, -3.3186, -3.3953],
         [-3.2745, -3.3577, -3.3136,  ..., -3.3595, -3.3186, -3.3953],
         [-3.2745, -3.3577, -3.3136,  ..., -3.3596, -3.3186, -3.3953],
         [-3.2745, -3.3577, -3.3136,  ..., -3.3595, -3.3186, -3.3953]],

        [[-3.2743, -3.3684, -3.3266,  ..., -3.3557, -3.3265, -3.3885],
         [-3.2743, -3.3683, -3.3266,  ..., -3.3557, -3.3265, -3.3885],
         [-3.2743, -3.3684, -3.3266,  ..., -3.3557, -3.3265, -3.3885],
         [-3.2743, -3.3684, -3.3266,  ..., -3.3557, -3.3265, -3.3885]],

        [[-3.2737, -3.3746, -3.3328,  ..., -3.3527, -3.3312, -3.3853],
         [-3.2737, -3.3746, -3.3327,  ..., -3.3527, -3.3312, -3.3853],
         [-3.2737, -3.3746, -3.3328,  ..., -3.3527, -3.3312, -3.3853],
         [-3.2737, -3.3746, -3.3328,  ..., -3.3527, -3.3312, -3.3852]],

        ...,

        [[-3.2734, -3.3837, -3.3410,  ..., -3.3464, -3.3374, -3.3824],
         [-3.2734, -3.3837, -3.3410,  ..., -3.3464, -3.33

In [141]:
# Display the full targets tensor.
# NOTE(review): the recorded output prints float values (e.g. "6."); CTC
# targets are class indices, so an integer dtype is presumably intended -- confirm.
targets

tensor([[ 6., 15., 21., 18., 27., 13., 15., 14., 20.,  8., 19., 27.,  8.,  1.,
          4., 27., 13.,  1.,  4.,  5., 27.,  7., 18.,  5.,  1., 20., 27.,  3.,
          8.,  1., 14.,  7.,  5., 19., 27.,  8.,  5., 27.,  2., 15., 18.,  5.,
         27.,  8.,  9., 13., 19.,  5., 12.,  6., 27., 13., 15., 18.,  5., 27.,
         12.,  9., 11.,  5., 27.,  1., 27., 13.,  1., 14., 27.,  8.,  9., 19.,
         27., 13.,  1., 14., 14.,  5., 18., 27., 23.,  1., 19., 27., 13., 21.,
          3.,  8., 27., 13., 15., 18.,  5., 27.,  3., 15., 14., 19.,  9.,  4.,
          5., 18.,  5.,  4., 27.,  1., 14.,  4., 27.,  7., 18.,  1., 22.,  5.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.],
        [23.,  9., 20.,  8., 27., 19., 15., 21., 16., 27.,  1., 14.,  4., 27.,
          6.,  9., 19.,  8., 27., 19.,  5., 18., 22.,  5., 27., 23.,  8.,  9.,
         20.,  5., 27., 23.,  9., 14.,  5., 19., 27., 19., 21.,  3.,  8., 27.,
          1., 19., 27., 18.,  8.,  5.

In [32]:
# Sanity check on log_probs: number of NaN entries, smallest value, largest value.
torch.isnan(log_probs).sum(), log_probs.min(), log_probs.max()

(tensor(0),
 tensor(-194.7010, grad_fn=<MinBackward1>),
 tensor(0., grad_fn=<MaxBackward1>))

In [33]:
# Sanity check on targets: number of NaN entries, smallest value, largest value.
torch.isnan(targets).sum(), targets.min(), targets.max()

(tensor(0), tensor(0.), tensor(27.))

In [34]:
# Sanity check on input_lengths: number of NaN entries, smallest value, largest value.
# NOTE(review): the recorded output (min = max = 128) does not match the
# input_lengths values displayed earlier in this notebook (615, 500, 656, 658),
# so it appears to come from a different kernel state -- re-run to refresh.
torch.isnan(input_lengths).sum(), input_lengths.min(), input_lengths.max()

(tensor(0), tensor(128), tensor(128))

In [35]:
# Sanity check on target_lengths: number of NaN entries, smallest value, largest value.
# NOTE(review): the recorded output (min 36, max 178) does not match the
# target_lengths values displayed earlier in this notebook (112, 82, 127, 124),
# so it appears to come from a different kernel state -- re-run to refresh.
torch.isnan(target_lengths).sum(), target_lengths.min(), target_lengths.max()

(tensor(0), tensor(36), tensor(178))

In [20]:
# Display the full log_probs tensor.
# NOTE(review): this cell's execution count (20) is lower than that of the cell
# that loads log_probs (142), so the recorded output may predate that load.
log_probs

tensor([[[-3.3639, -3.3025, -3.3200,  ..., -3.2768, -3.3583, -3.3349],
         [-3.3635, -3.3029, -3.3197,  ..., -3.2787, -3.3583, -3.3350],
         [-3.3616, -3.3039, -3.3194,  ..., -3.2774, -3.3583, -3.3341],
         ...,
         [-3.3628, -3.3032, -3.3197,  ..., -3.2777, -3.3584, -3.3348],
         [-3.3579, -3.3041, -3.3199,  ..., -3.2759, -3.3585, -3.3329],
         [-3.3601, -3.3042, -3.3195,  ..., -3.2770, -3.3584, -3.3336]],

        [[-3.3606, -3.3038, -3.3199,  ..., -3.2769, -3.3586, -3.3336],
         [-3.3636, -3.3029, -3.3197,  ..., -3.2787, -3.3583, -3.3350],
         [-3.3549, -3.3037, -3.3224,  ..., -3.2755, -3.3582, -3.3314],
         ...,
         [-3.3635, -3.3027, -3.3198,  ..., -3.2785, -3.3582, -3.3352],
         [-3.3472, -3.3010, -3.3271,  ..., -3.2754, -3.3644, -3.3276],
         [-3.3640, -3.3030, -3.3198,  ..., -3.2760, -3.3582, -3.3350]],

        [[-3.3629, -3.3033, -3.3199,  ..., -3.2773, -3.3585, -3.3344],
         [-3.3635, -3.3029, -3.3196,  ..., -3