Skip to content

Commit

Permalink
add test_self_bleu_metric (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
JianGuanTHU authored and hzhwcmhf committed Jul 1, 2019
1 parent 9f87ea4 commit dcb751f
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 17 deletions.
28 changes: 12 additions & 16 deletions cotk/metric/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def run_f(self, ele):
'''Auxiliary function which returns:
* **sentence-self-bleu**: sentence-self-bleu value.
'''
return sentence_bleu(ele[0], ele[1], ele[2], smoothing_function=SmoothingFunction().method1)
return sentence_bleu(ele[0], ele[1], smoothing_function=SmoothingFunction().method1)

def close(self):
'''Return a dict which contains:
Expand All @@ -476,21 +476,17 @@ def close(self):
ref = self.hyps[:self.sample]

try:
result = {}
for ngram in range(2, 5):
weight = tuple((1. / ngram for _ in range(ngram)))
if self.sample >= 1000:
pool = Pool(multiprocessing.cpu_count())
bleu_irl = pool.map(self.run_f, [(ref[:i]+ref[i+1:self.sample], ref[i], weight) \
for i in range(self.sample)])
pool.close()
pool.join()
else:
bleu_irl = []
for i in range(self.sample):
bleu_irl.append(self.run_f((ref[:i]+ref[i+1:], ref[i], weight)))
result["self-bleu-%d"%ngram] = 1.0 * sum(bleu_irl) / len(bleu_irl)
return result
bleu_irl = []
if self.sample >= 1000:
pool = Pool(multiprocessing.cpu_count())
bleu_irl = pool.map(self.run_f, [(ref[:i]+ref[i+1:self.sample], ref[i]) \
for i in range(self.sample)])
pool.close()
pool.join()
elif self.sample > 1:
for i in range(self.sample):
bleu_irl.append(self.run_f((ref[:i]+ref[i+1:], ref[i])))
return {"self-bleu" : 1.0 * sum(bleu_irl) / len(bleu_irl)}
except ZeroDivisionError as _:
raise ZeroDivisionError("Bleu smoothing divided by zero. This is a known bug of corpus_bleu, \
usually caused when there is only one sample and the sample length is 1.")
Expand Down
62 changes: 61 additions & 1 deletion tests/metric/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from cotk.metric import MetricBase, \
BleuPrecisionRecallMetric, EmbSimilarityPrecisionRecallMetric, \
PerplexityMetric, MultiTurnPerplexityMetric, BleuCorpusMetric, MultiTurnBleuCorpusMetric, \
PerplexityMetric, MultiTurnPerplexityMetric, BleuCorpusMetric, SelfBleuCorpusMetric, MultiTurnBleuCorpusMetric, \
SingleTurnDialogRecorder, MultiTurnDialogRecorder, LanguageGenerationRecorder, HashValueRecorder, \
MetricChain
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
Expand Down Expand Up @@ -671,6 +671,66 @@ def test_bleu_bug(self):
bcm.forward(data)
bcm.close()

# Parameter grid consumed by TestSelfBleuCorpusMetric.test_close; every generated
# case is unpacked there as (argument, shape, type, gen_len).
# NOTE(review): the exact meaning of the "add" / "multi" combination modes is
# defined by generate_testcase, which is not visible here -- confirm before changing.
self_bleu_test_parameter = generate_testcase(
    (zip(test_argument), "add"),
    (zip(test_shape, test_type), "multi"),
    (zip(test_gen_len), "multi"),
)


class TestSelfBleuCorpusMetric:
    '''Tests for ``SelfBleuCorpusMetric``.

    The expected value is recomputed here with a plain serial loop
    (see ``get_self_bleu``) and compared with what the metric reports.
    '''

    def run_f(self, ele):
        '''Auxiliary function which scores one sentence against a reference set.

        Arguments:
            ele: indexable whose item ``0`` holds the reference sentences and
                whose item ``1`` holds the hypothesis sentence.

        Returns:
            * **sentence-self-bleu**: sentence-level BLEU value, smoothed with
              ``SmoothingFunction().method1``.
        '''
        references, hypothesis = ele[0], ele[1]
        smoother = SmoothingFunction().method1
        return sentence_bleu(references, hypothesis, smoothing_function=smoother)

    def get_self_bleu(self, dataloader, input, gen_key):
        '''Reference implementation of self-BLEU used as the expected value.

        Every generated sentence under ``input[gen_key]`` is trimmed with
        ``dataloader.trim_index`` and then scored (via ``run_f``) against all
        the *other* trimmed sentences; the mean of those scores is returned.
        '''
        trimmed = [dataloader.trim_index(sentence) for sentence in input[gen_key]]
        # Work on an independent copy so scoring can never touch the trimmed inputs.
        pool = copy.deepcopy(trimmed)
        scores = [
            # Leave-one-out: the sentence at ``idx`` is the hypothesis,
            # everything else is the reference set.
            self.run_f((pool[:idx] + pool[idx + 1:], held_out))
            for idx, held_out in enumerate(pool)
        ]
        return 1.0 * sum(scores) / len(scores)

    @pytest.mark.parametrize('argument, shape, type, gen_len', self_bleu_test_parameter)
    def test_close(self, argument, shape, type, gen_len):
        '''``close()`` must agree with the serial reference and leave the input untouched.

        * ``argument``: ``'default'`` uses the metric's default key (``'gen'``);
          any other value passes the custom key ``'gk'`` explicitly.
        * ``shape``: ``'pad'`` requests padded data from the fake dataloader.
        * ``type``: ``'list'`` requests list data from the fake dataloader.
        * ``gen_len``: forwarded unchanged to ``FakeDataLoader.get_data``.
        '''
        dataloader = FakeDataLoader()
        use_default_key = argument == 'default'
        gen_key = 'gen' if use_default_key else 'gk'

        data = dataloader.get_data(
            gen_key=gen_key,
            to_list=(type == 'list'),
            pad=(shape == 'pad'),
            gen_len=gen_len,
        )
        # Snapshot taken before the metric sees the data, to detect mutation.
        _data = copy.deepcopy(data)

        metric_args = (dataloader,) if use_default_key else (dataloader, gen_key)
        bcm = SelfBleuCorpusMetric(*metric_args)
        bcm.forward(data)

        reported = bcm.close()['self-bleu']
        expected = self.get_self_bleu(dataloader, data, gen_key)
        assert np.isclose(reported, expected)
        assert same_dict(data, _data)

    def test_self_bleu_bug(self):
        '''A single one-token sample is expected to raise ``ZeroDivisionError``.'''
        dataloader = FakeDataLoader()
        data = {'gen': [[1]]}
        bcm = SelfBleuCorpusMetric(dataloader)

        with pytest.raises(ZeroDivisionError):
            bcm.forward(data)
            bcm.close()

multi_bleu_test_parameter = generate_testcase(\
(zip(test_argument), "add"),
(zip(test_shape, test_type), "multi"),
Expand Down

0 comments on commit dcb751f

Please sign in to comment.