Skip to content

Commit

Permalink
Refactor BLEU Metric.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 601823244
  • Loading branch information
tf-model-analysis-team authored and tfx-copybara committed Jan 26, 2024
1 parent 80b2a02 commit d7adb3f
Show file tree
Hide file tree
Showing 2 changed files with 284 additions and 85 deletions.
206 changes: 133 additions & 73 deletions tensorflow_model_analysis/metrics/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,51 @@
_BLEU_NAME_DEFAULT = 'BLEU'


# TODO: b/287700355) - Add __slots__ to this dataclass.
# TODO: b/287700355 - Add __slots__ to _Accumulator
@dataclasses.dataclass
class _Accumulator:
"""Accumulator for _BleuCombiner.
Attributes:
matching_ngrams: A list containing the number of matching n-grams between
the hypothesis and the reference for each n. This should be initialized as
np.zeros(max_ngram_order).
total_ngrams: A list containing the total number of n-grams for each n. Like
'matching_ngrams', this should be initialized as an
np.zeros(max_ngram_order).
hyp_len: The number of unigrams (words) in the hypothesis.
ref_len: The number of unigrams (words) in the reference.
matching_ngrams[n - 1] = number of matching n-grams for n > 0
matching_ngrams[0] = number of matching unigrams
matching_ngrams[1] = number of matching bigrams
...
total_ngrams[n - 1] = (
max(number of n-grams in hyp, number of n-grams in ref) for n > 0
)
total_ngrams[] follows same pattern as matching_ngrams[]
For hypotheses and references, ending punctuation (periods, exclamation
points, etc.) count as their own unigram.
For example, 'Google.' has 2 unigrams: 'Google' and '.'.
"""

matching_ngrams: np.ndarray
total_ngrams: np.ndarray
hyp_len: int = 0
ref_len: int = 0

def __eq__(self, other):
return (
np.array_equal(self.matching_ngrams, other.matching_ngrams)
and np.array_equal(self.total_ngrams, other.total_ngrams)
and self.hyp_len == other.hyp_len
and self.ref_len == other.ref_len
)


# TODO: b/287700355 - Add __slots__ to this dataclass.
@dataclasses.dataclass
class _RefInfo:
ngrams: collections.Counter[dict[tuple[str], int]] # n-grams and counts
Expand Down Expand Up @@ -92,6 +136,36 @@ def __init__(
self.key = key
self.bleu_metric = sacrebleu.BLEU(**bleu_kwargs)

def _extract_statistics_for_empty_reference(
self, hypotheses: Sequence[str]
) -> list[_Accumulator]:
"""Returns sentence-level statistics when there are no references.
Args:
hypotheses: A sequence of hypothesis strings.
Returns:
A list of _Accumulators of segment statistics.
"""
sum_hyp_len = 0
for hypothesis in hypotheses:
_, hyp_len = sacrebleu.helpers.extract_all_word_ngrams(
hypothesis, 1, self.bleu_metric.max_ngram_order
)
sum_hyp_len += hyp_len

# No n-grams.
matching_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int)
total_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int)

return [
_Accumulator(
matching_ngrams=matching_ngrams,
total_ngrams=total_ngrams,
hyp_len=sum_hyp_len,
)
]

def _preprocess_segment(self, sentence: str) -> str:
"""Given a sentence, lowercases (optionally) and tokenizes it."""
if self.bleu_metric.lowercase:
Expand Down Expand Up @@ -153,7 +227,7 @@ def _compute_segment_statistics(
self,
hypothesis: str,
ref_info: _RefInfo,
) -> list[int]:
) -> _Accumulator:
"""Given a (pre-processed) hypothesis sentence and already computed reference n-grams & lengths, returns the best match statistics across the references.
Args:
Expand All @@ -162,7 +236,7 @@ def _compute_segment_statistics(
the list of reference lengths.
Returns:
A list of integers with match statistics.
An _Accumulator with match statistics.
"""
# Extract n-grams for the hypothesis.
hyp_ngrams, hyp_len = sacrebleu.helpers.extract_all_word_ngrams(
Expand All @@ -173,8 +247,8 @@ def _compute_segment_statistics(

# Count the stats.
# Although counter has its internal & and | operators, this is faster.
matching_ngrams = [0] * self.bleu_metric.max_ngram_order
total_ngrams = matching_ngrams[:]
matching_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int)
total_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int)

for hyp_ngram, hyp_count in hyp_ngrams.items():
# n-gram order.
Expand All @@ -188,14 +262,18 @@ def _compute_segment_statistics(
if hyp_ngram in ref_ngrams:
matching_ngrams[n] += min(hyp_count, ref_ngrams[hyp_ngram])

# Return a flattened list as per 'stats' semantics.
return [hyp_len, ref_len] + matching_ngrams + total_ngrams
return _Accumulator(
matching_ngrams=matching_ngrams,
total_ngrams=total_ngrams,
hyp_len=hyp_len,
ref_len=ref_len,
)

def _extract_corpus_statistics(
self,
hypotheses: Sequence[str],
references: Optional[Sequence[Sequence[str]]],
) -> list[list[int]]:
references: Sequence[Sequence[str]],
) -> list[_Accumulator]:
"""Reads the corpus and returns sentence-level match statistics for faster re-computations esp during statistical tests.
Args:
Expand All @@ -205,8 +283,12 @@ def _extract_corpus_statistics(
batch_size_of_hypotheses).
Returns:
A list where each sublist corresponds to segment statistics.
A list of _Accumulators of segment statistics.
"""
if np.all((np.array(references) == [''])):
# Empty Reference.
return self._extract_statistics_for_empty_reference(hypotheses)

stats = []
tok_count = 0

Expand Down Expand Up @@ -238,73 +320,48 @@ def _extract_corpus_statistics(

return stats

def _compute_score_from_stats(self, stats: list[int]) -> sacrebleu.BLEUScore:
def _compute_score_from_accumulator(
self, accumulator: _Accumulator
) -> sacrebleu.BLEUScore:
"""Computes the final score from already aggregated statistics.
'stats' semantics are preserved here from the wrapped implementation.
stats = [hyp_len, ref_len, matching_ngrams, total_ngrams] where
hyp_len = number of unigrams (words) in the hypothesis
ref_len = number of unigrams (words) in the reference
Note, ending punctuation (periods, exclamation points, etc.) count as
their own unigram.
For example, 'Google.' has 2 unigrams: 'Google' and '.'
matching_ngrams[n - 1] = number of matching n-grams for n > 0
matching_ngrams[0] = number of matching unigrams
matching_ngrams[1] = number of matching bigrams
...
total_ngrams[n - 1] = number of n-grams in hyp for n > 0
total_ngrams[] follows same pattern as matching_ngrams[]
Args:
stats: A list of segment-level statistics.
accumulator: An accumulator containing segment-level statistics.
Returns:
A 'BLEUScore' object.
"""
bleu_metric = self.bleu_metric

# matching_ngrams[n - 1] = number of matching n-grams for n > 0
matching_ngrams = stats[2 : 2 + bleu_metric.max_ngram_order]

# total_ngrams[n - 1] = number of n-grams in hyp for n > 0
total_ngrams = stats[2 + bleu_metric.max_ngram_order :]

# hyp_len = number of unigrams (words) in the hypothesis
hyp_len = int(stats[0])

# ref_len = number of unigrams (words) in the reference
ref_len = int(stats[1])

# TODO: b/319702245 - Resolve the issue below in compute_bleu().
# We need to convert the np.ndarray's to a lists here.
# If we leave it as a np.ndarray of ints, then sacrebleu will not be able to
# add decimal smooth values to the stats list within compute_bleu().
# If we convert it to an np.ndarray of floats, then sacrebleu will not be
# able to propely set BLEUScore._verbose because there is no format code 'd'
# for floats.
return self.bleu_metric.compute_bleu(
correct=matching_ngrams,
total=total_ngrams,
sys_len=hyp_len,
ref_len=ref_len,
correct=accumulator.matching_ngrams.tolist(),
total=accumulator.total_ngrams.tolist(),
sys_len=accumulator.hyp_len,
ref_len=accumulator.ref_len,
smooth_method=bleu_metric.smooth_method,
smooth_value=bleu_metric.smooth_value,
effective_order=bleu_metric.effective_order,
max_ngram_order=bleu_metric.max_ngram_order,
)

def create_accumulator(self):
"""Accumulator is the running total of 'stats' of type np.ndarray.
Args: None.
Returns:
'stats' list of all zeros.
"""
# TODO: b/321082946 - Replace 'stats' semantics with a dataclass.
# len(stats)
# = len(hyp_len) + len(ref_len) + len(matching_ngrams) + len(total_ngrams)
# = 1 + 1 + max_ngram_order + max_ngram_order = 2 + 2 * max_ngram_order
return np.zeros(2 + 2 * self.bleu_metric.max_ngram_order, dtype=int)
return _Accumulator(
matching_ngrams=np.zeros(self.bleu_metric.max_ngram_order, dtype=int),
total_ngrams=np.zeros(self.bleu_metric.max_ngram_order, dtype=int),
)

def add_input(
self,
accumulator: np.ndarray,
accumulator: _Accumulator,
element: metric_types.StandardMetricInputs,
) -> np.ndarray:
) -> _Accumulator:
# references = labels, hypotheses = predictions
references, hypotheses, _ = next(
metric_util.to_label_prediction_example_weight(
Expand All @@ -318,28 +375,31 @@ def add_input(
)
)

# Sum accumulator and new stats
return accumulator + np.sum(
self._extract_corpus_statistics(hypotheses, references), axis=0
)
corpus_stats = self._extract_corpus_statistics(hypotheses, references)
corpus_stats.append(accumulator)

return self.merge_accumulators(corpus_stats)

def merge_accumulators(
self, list_of_stats: Iterable[np.ndarray]
) -> np.ndarray:
"""Sum of list of stats."""
return np.sum(list_of_stats, axis=0)
self, accumulators: Iterable[_Accumulator]
) -> _Accumulator:
accumulators = iter(accumulators)
result = next(accumulators)
for accumulator in accumulators:
result.hyp_len += accumulator.hyp_len
result.ref_len += accumulator.ref_len
result.matching_ngrams = np.sum(
[result.matching_ngrams, accumulator.matching_ngrams], axis=0
)
result.total_ngrams = np.sum(
[result.total_ngrams, accumulator.total_ngrams], axis=0
)
return result

def extract_output(
self, accumulator: np.ndarray
self, accumulator: _Accumulator
) -> dict[metric_types.MetricKey, sacrebleu.BLEUScore]:
# TODO: b/319702245 - Resolve the issue below in compute_bleu().
# We need to convert the accumulator to a list here.
# If we leave it as a np.ndarray of ints, then sacrebleu will not be able to
# add decimal smooth values to the stats list within compute_bleu().
# If we convert it to an np.ndarray of floats, then sacrebleu will not be
# able to propely set BLEUScore._verbose because there is no format code 'd'
# for floats.
return {self.key: self._compute_score_from_stats(accumulator.tolist())}
return {self.key: self._compute_score_from_accumulator(accumulator)}


def _bleu(
Expand Down
Loading

0 comments on commit d7adb3f

Please sign in to comment.