-
Notifications
You must be signed in to change notification settings - Fork 590
/
edist.py
84 lines (78 loc) · 2.77 KB
/
edist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import re
import numpy as np
from scipy.ndimage import filters
def levenshtein(a, b):
    """Return the Levenshtein (edit) distance between a and b.

    Classic dynamic programme kept to two rows: the shorter input indexes the
    row, so working memory is proportional to min(len(a), len(b)).
    (Row-swapping formulation after the compact one from hetland.org.)
    """
    if len(a) > len(b):
        shorter, longer = b, a
    else:
        shorter, longer = a, b
    width = len(shorter)
    row = list(range(width + 1))
    for i, long_item in enumerate(longer, 1):
        above = row
        row = [i] + [0] * width
        for j, short_item in enumerate(shorter, 1):
            # Diagonal step: free on a match, cost 1 on a substitution.
            substitution = above[j - 1] + (1 if short_item != long_item else 0)
            row[j] = min(above[j] + 1, row[j - 1] + 1, substitution)
    return row[width]
def xlevenshtein(a, b, context=1):
    """Calculate the Levenshtein distance between a and b and list the differences.

    ``a`` and ``b`` are aligned along a minimum-cost edit path. A column where
    one string has no counterpart shows ``"_"`` on that side. Runs of
    mismatched alignment columns, widened by ``context`` columns on each side,
    are reported as confusion pairs.

    Args:
        a: string being compared against the reference.
        b: reference (ground-truth) string; may be empty.
        context: number of neighbouring alignment columns included on each
            side of a mismatch when building the confusion pairs.

    Returns:
        ``(cost, confusions)``. ``cost`` is the edit distance as an int.
        ``confusions`` is a list of ``(a_segment, b_segment)`` string pairs,
        one per merged region of differences. Segments may contain ``"_"``
        gaps and the ``" "`` padding added at either end.
    """
    from itertools import groupby

    from scipy.ndimage import minimum_filter

    n, m = len(a), len(b)
    if a == b:
        return 0, []  # speed up for the easy case

    # dists[i][j] is the edit distance between b[:i] and a[:j].
    # sources[i][j] is the predecessor cell on a minimum-cost path into (i, j).
    # Row 0 and column 0 get predecessors too, so every backtrace ends at
    # (0, 0) and leading insertions/deletions stay in the alignment.
    dists = [[0] * (n + 1) for _ in range(m + 1)]
    sources = [[None] * (n + 1) for _ in range(m + 1)]
    for j in range(1, n + 1):
        dists[0][j] = j
        sources[0][j] = (0, j - 1)
    for i in range(1, m + 1):
        dists[i][0] = i
        sources[i][0] = (i - 1, 0)

    for i in range(1, m + 1):
        previous, current = dists[i - 1], dists[i]
        b_char = b[i - 1]
        for j in range(1, n + 1):
            # Ties favour the step from above, then from the left, then the
            # diagonal. This decides which equal-cost alignment is reported.
            best, src = previous[j] + 1, (i - 1, j)
            if current[j - 1] + 1 < best:
                best, src = current[j - 1] + 1, (i, j - 1)
            delta = 1 if a[j - 1] != b_char else 0
            if previous[j - 1] + delta < best:
                best, src = previous[j - 1] + delta, (i - 1, j - 1)
            current[j] = best
            sources[i][j] = src
    cost = dists[m][n]

    # Backtrace from (m, n) to (0, 0); each step produces one alignment column.
    # Moving from (i0, j0) to (i, j) consumes a[j0] if j changed and b[i0]
    # if i changed. A side that consumes nothing shows "_".
    a_cols, b_cols = [], []
    i, j = m, n
    while sources[i][j] is not None:
        i0, j0 = sources[i][j]
        a_cols.append(a[j0] if j != j0 else "_")
        b_cols.append(b[i0] if i != i0 else "_")
        i, j = i0, j0
    al = "".join(reversed(a_cols))
    bl = "".join(reversed(b_cols))

    # Mark matching columns, then erode the matches with a minimum filter so
    # that each mismatch also claims `context` neighbouring columns. The space
    # padding gives mismatches at either end room to claim context.
    pad = " " * context
    al = pad + al + pad
    bl = pad + bl + pad
    same = np.array([x == y for x, y in zip(al, bl)], "i")
    same = minimum_filter(same, 1 + 2 * context)

    # Each maximal run of non-matching columns becomes one confusion pair.
    # Grouping on the mask (instead of splitting on a sentinel character)
    # works for inputs containing any character, including "~".
    confusions = []
    for matched, run in groupby(zip(al, bl, same), key=lambda col: col[2]):
        if not matched:
            cols = list(run)
            confusions.append(
                ("".join(c[0] for c in cols), "".join(c[1] for c in cols))
            )
    return cost, confusions