Skip to content

Commit

Permalink
addressing reviews + fixing lint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
varisd committed Jan 2, 2019
1 parent 81bbdbb commit 0409618
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 38 deletions.
30 changes: 15 additions & 15 deletions neuralmonkey/evaluators/chrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def score_instance(self,
reference: List[str]) -> float:
hyp_joined = " ".join(hypothesis)
hyp_chars = [x for x in list(hyp_joined) if x not in self.ignored]
hyp_ngrams = self._get_ngrams(hyp_chars, self.n)
hyp_ngrams = _get_ngrams(hyp_chars, self.n)

ref_joined = " ".join(reference)
ref_chars = [x for x in list(ref_joined) if x not in self.ignored]
ref_ngrams = self._get_ngrams(ref_chars, self.n)
ref_ngrams = _get_ngrams(ref_chars, self.n)

if not hyp_chars or not ref_chars:
if "".join(hyp_chars) == "".join(ref_chars):
Expand Down Expand Up @@ -69,7 +69,7 @@ def chr_r(self, hyp_ngrams: NGramDicts, ref_ngrams: NGramDicts) -> float:
ref_count, hyp_ngrams[m - 1][ngr])
return np.mean(np.divide(
count_matched, count_all, out=np.ones_like(count_all),
where=(count_all!=0)))
where=(count_all != 0)))

def chr_p(self, hyp_ngrams: NGramDicts, ref_ngrams: NGramDicts) -> float:
count_all = np.zeros(self.n)
Expand All @@ -83,18 +83,18 @@ def chr_p(self, hyp_ngrams: NGramDicts, ref_ngrams: NGramDicts) -> float:
hyp_count, ref_ngrams[m - 1][ngr])
return np.mean(np.divide(
count_matched, count_all, out=np.ones_like(count_all),
where=(count_all!=0)))

def _get_ngrams(self, tokens: List[str], n: int) -> NGramDicts:
ngr_dicts = []
for m in range(1, n + 1):
ngr_dict = {} # type: Dict[str, int]
# if m > len(tokens), return an empty dict
for i in range(m, len(tokens) + 1):
ngr = "".join(tokens[i - m:i])
ngr_dict[ngr] = ngr_dict.setdefault(ngr, 0) + 1
ngr_dicts.append(ngr_dict)
return ngr_dicts
where=(count_all != 0)))


def _get_ngrams(tokens: List[str], n: int) -> NGramDicts:
ngr_dicts = []
for m in range(1, n + 1):
ngr_dict = {} # type: Dict[str, int]
for i in range(m, len(tokens) + 1):
ngr = "".join(tokens[i - m:i])
ngr_dict[ngr] = ngr_dict.setdefault(ngr, 0) + 1
ngr_dicts.append(ngr_dict)
return ngr_dicts


# pylint: disable=invalid-name
Expand Down
29 changes: 6 additions & 23 deletions neuralmonkey/tests/test_chrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,22 @@
import unittest

from neuralmonkey.evaluators.chrf import ChrFEvaluator
from neuralmonkey.tests.test_bleu import DECODED, REFERENCE


CORPUS_DECODED = [
"colorful thoughts furiously sleep",
"little piglet slept all night",
"working working working working working be be be be be be be",
"ich bin walrus",
"walrus for präsident"
]

CORPUS_REFERENCE = [
"the colorless ideas slept furiously",
"pooh slept all night",
"working class hero is something to be",
"I am the working class walrus",
"walrus for president"
]

TOKENS = ["a", "b", "a"]
NGRAMS = [
{"a": 2, "b" : 1},
{"ab": 1, "ba" : 1},
{"aba" : 1},
{"a": 2, "b": 1},
{"ab": 1, "ba": 1},
{"aba": 1},
{}]


DECODED = [d.split() for d in CORPUS_DECODED]
REFERENCE = [r.split() for r in CORPUS_REFERENCE]

FUNC = ChrFEvaluator()
FUNC_P = FUNC.chr_p
FUNC_R = FUNC.chr_r
FUNC_NGRAMS = FUNC._get_ngrams


class TestChrF(unittest.TestCase):

def test_empty_decoded(self):
Expand Down Expand Up @@ -68,5 +50,6 @@ def test_get_ngrams(self):
for i, _ in enumerate(NGRAMS):
self.assertDictEqual(ngrams_out[i], NGRAMS[i])


if __name__ == "__main__":
unittest.main()

0 comments on commit 0409618

Please sign in to comment.