In [21]:
import torch
from typing import Optional

In [22]:
def levenshtein_distance_list(r, h):

    # initialisation
    d = [[0] * (len(h)+1)] * (len(r)+1)

    # computation
    for i in range(1, len(r)+1):
        for j in range(1, len(h)+1):

            if r[i-1] == h[j-1]:
                d[i].append(d[i-1][j-1])
            else:
                substitution = d[i-1][j-1] + 1
                insertion = d[i][j-1] + 1
                deletion = d[i-1][j] + 1
                d[i].append(min(substitution, insertion, deletion))

    return d[len(r)][len(h)]

In [23]:
# https://martin-thoma.com/word-error-rate-calculation/


def levenshtein_distance(r: str, h: str, device: Optional[str] = None):

    # initialisation
    d = torch.zeros((2, len(h)+1), dtype=torch.long)  # , device=device)
    dold = 0
    dnew = 1

    # computation
    for i in range(1, len(r)+1):
        d[dnew, 0] = 0
        for j in range(1, len(h)+1):

            if r[i-1] == h[j-1]:
                d[dnew, j] = d[dnew-1, j-1]
            else:
                substitution = d[dnew-1, j-1] + 1
                insertion = d[dnew, j-1] + 1
                deletion = d[dnew-1, j] + 1
                d[dnew, j] = min(substitution, insertion, deletion)

        dnew, dold = dold, dnew

    dist = d[dnew, -1].item()

    return dist

In [24]:
# https://martin-thoma.com/word-error-rate-calculation/


def levenshtein_distance_torch(r: str, h: str, device: Optional[str] = None):

    # initialisation
    d = torch.zeros((2, len(h)+1), dtype=torch.long)  # , device=device)
    dold = 0
    dnew = 1

    # computation
    for i in range(1, len(r)+1):
        d[dnew, 0] = 0
        for j in range(1, len(h)+1):
            # print(r[i-1], h[j-1])
            if r[i-1] == h[j-1]:
                d[dnew, j] = d[dnew-1, j-1]
            else:
                substitution = d[dnew-1, j-1] + 1
                insertion = d[dnew, j-1] + 1
                deletion = d[dnew-1, j] + 1
                d[dnew, j] = min(substitution, insertion, deletion)

        dnew, dold = dold, dnew

    dist = d[dnew, -1].item()

    return dist

In [25]:
# https://martin-thoma.com/word-error-rate-calculation/
def levenshtein_distance_list_2(r: str, h: str):

    # initialisation
    dold = list(range(len(h)+1))
    dnew = list(0 for _ in range(len(h)+1))

    # computation
    for i in range(1, len(r)+1):
        dnew[0] = i
        for j in range(1, len(h)+1):
            if r[i-1] == h[j-1]:
                dnew[j] = dold[j-1]
            else:
                substitution = dold[j-1] + 1
                insertion = dnew[j-1] + 1
                deletion = dold[j] + 1
                dnew[j] = min(substitution, insertion, deletion)

        dnew, dold = dold, dnew

    return dold[-1]

In [26]:
def levenshtein_distance_list_3(r, h):

    # initialisation
    d = [[0] * (len(h)+1)] * 2

    # computation
    for i in range(1, len(r)+1):
        for j in range(1, len(h)+1):

            if r[i-1 % 2] == h[j-1]:
                d[i].append(d[i-1%2][j-1])
            else:
                substitution = d[i-1%2][j-1] + 1
                insertion = d[i%2][j-1] + 1
                deletion = d[i-1%2][j] + 1
                d[i].append(min(substitution, insertion, deletion))

    # print(d)
    return d[len(r)][len(h)]

In [27]:
def levenshtein_distance_numpy(r, h):
    # initialisation
    import numpy

    d = numpy.zeros((len(r) + 1) * (len(h) + 1), dtype=numpy.uint8)
    d = d.reshape((len(r) + 1, len(h) + 1))
    for i in range(len(r) + 1):
        for j in range(len(h) + 1):
            if i == 0:
                d[0][j] = j
            elif j == 0:
                d[i][0] = i

    # computation
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            if r[i - 1] == h[j - 1]:
                d[i][j] = d[i - 1][j - 1]
            else:
                substitution = d[i - 1][j - 1] + 1
                insertion = d[i][j - 1] + 1
                deletion = d[i - 1][j] + 1
                d[i][j] = min(substitution, insertion, deletion)

    return d[len(r)][len(h)]

In [28]:
r = "ab"
h = "cc"
levenshtein_distance(r, h) == 2

False

In [30]:
    r = "abcdddee"
    h = "abcddde"

    %timeit levenshtein_distance(r, h)

    jitted = torch.jit.script(levenshtein_distance)
    %timeit jitted(r, h)

    %timeit levenshtein_distance_list(r, h)
    %timeit levenshtein_distance_list_2(r, h)
    # %timeit levenshtein_distance_list_3(r, h)
    
    %timeit levenshtein_distance_torch(r, h)

    # jitted = torch.jit.script(levenshtein_distance_list)
    # %timeit jitted(r, h)

    # %timeit levenshtein_distance_array(r, h)
    # jitted = torch.jit.script(levenshtein_distance_array)
    # %timeit jitted(r, h)

2.73 ms ± 3.35 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.46 ms ± 2.34 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
41.9 µs ± 16.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
35.1 µs ± 58.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
2.7 ms ± 15.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


NameError: name 'levenshtein_distance_array' is not defined