diff --git a/cotk/metric/metric.py b/cotk/metric/metric.py
index 10d7ccfd..f8e929af 100644
--- a/cotk/metric/metric.py
+++ b/cotk/metric/metric.py
@@ -463,7 +463,7 @@ def run_f(self, ele):
 		'''Auxiliary function which returns:
 		* **sentence-self-bleu**: sentence-self-bleu value.
 		'''
-		return sentence_bleu(ele[0], ele[1], ele[2], smoothing_function=SmoothingFunction().method1)
+		return sentence_bleu(ele[0], ele[1], smoothing_function=SmoothingFunction().method1)
 
 	def close(self):
 		'''Return a dict which contains:
@@ -476,21 +476,17 @@ def close(self):
 		ref = self.hyps[:self.sample]
 
 		try:
-			result = {}
-			for ngram in range(2, 5):
-				weight = tuple((1. / ngram for _ in range(ngram)))
-				if self.sample >= 1000:
-					pool = Pool(multiprocessing.cpu_count())
-					bleu_irl = pool.map(self.run_f, [(ref[:i]+ref[i+1:self.sample], ref[i], weight) \
-						for i in range(self.sample)])
-					pool.close()
-					pool.join()
-				else:
-					bleu_irl = []
-					for i in range(self.sample):
-						bleu_irl.append(self.run_f((ref[:i]+ref[i+1:], ref[i], weight)))
-				result["self-bleu-%d"%ngram] = 1.0 * sum(bleu_irl) / len(bleu_irl)
-			return result
+			bleu_irl = []
+			if self.sample >= 1000:
+				pool = Pool(multiprocessing.cpu_count())
+				bleu_irl = pool.map(self.run_f, [(ref[:i]+ref[i+1:self.sample], ref[i]) \
+					for i in range(self.sample)])
+				pool.close()
+				pool.join()
+			elif self.sample > 1:
+				for i in range(self.sample):
+					bleu_irl.append(self.run_f((ref[:i]+ref[i+1:], ref[i])))
+			return {"self-bleu" : 1.0 * sum(bleu_irl) / len(bleu_irl)}
 		except ZeroDivisionError as _:
 			raise ZeroDivisionError("Bleu smoothing divided by zero. This is a known bug of corpus_bleu, \
 				usually caused when there is only one sample and the sample length is 1.")
diff --git a/tests/metric/test_metric.py b/tests/metric/test_metric.py
index 0408bb10..cc00ed3b 100644
--- a/tests/metric/test_metric.py
+++ b/tests/metric/test_metric.py
@@ -7,7 +7,7 @@
 
 from cotk.metric import MetricBase, \
 	BleuPrecisionRecallMetric, EmbSimilarityPrecisionRecallMetric, \
-	PerplexityMetric, MultiTurnPerplexityMetric, BleuCorpusMetric, MultiTurnBleuCorpusMetric, \
+	PerplexityMetric, MultiTurnPerplexityMetric, BleuCorpusMetric, SelfBleuCorpusMetric, MultiTurnBleuCorpusMetric, \
 	SingleTurnDialogRecorder, MultiTurnDialogRecorder, LanguageGenerationRecorder, HashValueRecorder, \
 	MetricChain
 from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
@@ -671,6 +671,66 @@ def test_bleu_bug(self):
 			bcm.forward(data)
 			bcm.close()
 
+self_bleu_test_parameter = generate_testcase(\
+	(zip(test_argument), "add"),
+	(zip(test_shape, test_type), "multi"),
+	(zip(test_gen_len), "multi"),
+)
+
+
+class TestSelfBleuCorpusMetric:
+	def run_f(self, ele):
+		'''Auxiliary function which returns:
+		* **sentence-self-bleu**: sentence-self-bleu value.
+		'''
+		return sentence_bleu(ele[0], ele[1], smoothing_function=SmoothingFunction().method1)
+
+	def get_self_bleu(self, dataloader, input, gen_key):
+		gens = []
+		for gen_sen in input[gen_key]:
+			gen_sen_processed = dataloader.trim_index(gen_sen)
+			gens.append(gen_sen_processed)
+		refs = copy.deepcopy(gens)
+		bleu_irl = []
+		for i in range(len(gens)):
+			bleu_irl.append(self.run_f((refs[:i]+refs[i+1:], refs[i])))
+		return 1.0 * sum(bleu_irl) / len(bleu_irl)
+
+	@pytest.mark.parametrize('argument, shape, type, gen_len', self_bleu_test_parameter)
+	def test_close(self, argument, shape, type, gen_len):
+		# 'default' or 'custom'
+		# 'pad' or 'jag'
+		# 'list' or 'array'
+		# 'equal' or 'unequal'
+		# 'random', 'non-empty', 'empty'
+		# 'random', 'non-empty', 'empty'
+		dataloader = FakeDataLoader()
+		gen_key = 'gen' \
+			if argument == 'default' else 'gk'
+		data = dataloader.get_data(gen_key=gen_key, \
+			to_list=(type == 'list'), \
+			pad=(shape == 'pad'), \
+			gen_len=gen_len)
+		_data = copy.deepcopy(data)
+		if argument == 'default':
+			bcm = SelfBleuCorpusMetric(dataloader)
+		else:
+			bcm = SelfBleuCorpusMetric(dataloader, gen_key)
+
+		bcm.forward(data)
+		assert np.isclose(bcm.close()['self-bleu'], self.get_self_bleu(dataloader, data, gen_key))
+		assert same_dict(data, _data)
+
+	def test_self_bleu_bug(self):
+		dataloader = FakeDataLoader()
+		gen = [[1]]
+		data = {'gen': gen}
+		bcm = SelfBleuCorpusMetric(dataloader)
+
+		with pytest.raises(ZeroDivisionError):
+			bcm.forward(data)
+			bcm.close()
+
 multi_bleu_test_parameter = generate_testcase(\
 	(zip(test_argument), "add"),
 	(zip(test_shape, test_type), "multi"),