Skip to content

Commit

Permalink
pass groups to the Editex
Browse files Browse the repository at this point in the history
  • Loading branch information
orsinium committed Mar 18, 2019
1 parent a7e60da commit 25d10e5
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions textdistance/algorithms/phonetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,12 @@ def __call__(self, *sequences):
class Editex(_Base):
"""
https://anhaidgroup.github.io/py_stringmatching/v0.3.x/Editex.html
http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.3856&rep=rep1&type=pdf
http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.18.2138&rep=rep1&type=pdf
https://github.com/chrislit/blob/master/abydos/distance/_editex.py
https://habr.com/ru/post/331174/ (RUS)
"""
letter_groups = (
groups = (
frozenset('AEIOUY'),
frozenset('BP'),
frozenset('CKQ'),
Expand All @@ -89,29 +91,41 @@ class Editex(_Base):
frozenset('SXZ'),
frozenset('CSZ'),
)
all_letters = frozenset('AEIOUYBPCKQDTLRMNGJFVSXZ')
ungrouped = frozenset('HW') # all letters in alphabet that not presented in `grouped`

def __init__(self, local=False, match_cost=0, group_cost=1, mismatch_cost=2):
def __init__(self, local=False, match_cost=0, group_cost=1, mismatch_cost=2,
groups=None, ungrouped=None):
self.match_cost = match_cost
self.group_cost = group_cost
self.mismatch_cost = mismatch_cost
self.local = local

if groups is not None:
if ungrouped is None:
raise ValueError('`ungrouped` argument required with `groups`')
self.groups = groups
self.ungrouped = ungrouped
self.grouped = frozenset.union(*self.groups)

# backward compat
if hasattr(self, 'letter_groups'):
self.groups = self.letter_groups

def maximum(self, *sequences):
return max(map(len, sequences)) * self.mismatch_cost

def r_cost(self, *elements):
if self._ident(*elements):
return self.match_cost
if any(map(lambda x: x not in self.all_letters, elements)):
if any(map(lambda x: x not in self.grouped, elements)):
return self.mismatch_cost
for group in self.letter_groups:
for group in self.groups:
if all(map(lambda x: x in group, elements)):
return self.group_cost
return self.mismatch_cost

def d_cost(self, *elements):
if not self._ident(*elements) and elements[0] in 'HW':
if not self._ident(*elements) and elements[0] in self.ungrouped:
return self.group_cost
return self.r_cost(*elements)

Expand Down

0 comments on commit 25d10e5

Please sign in to comment.