# 05-7: Evaluate Text Generation

## BLEU

In [None]:
import math

In [None]:
def calculate_brevity_penalty(reference_len: int, candidate_len: int) -> float:
    """Return BLEU's brevity penalty and print which case applied.

    Args:
        reference_len: Length of the reference.
        candidate_len: Length of the candidate.

    Returns:
        1.0 when the candidate is at least as long as the reference
        (this includes the case where both are 0), 0.0 when only the
        candidate is empty, and ``exp(1 - reference_len / candidate_len)``
        otherwise.

    Raises:
        ValueError: If either length is negative.
    """
    if any(length < 0 for length in (reference_len, candidate_len)):
        raise ValueError("Length cannot be negative")

    if candidate_len > reference_len:
        # r/c < 1, so exp(1 - r/c) would exceed 1; the penalty is capped at 1.
        print(f"Candidate length \t ({candidate_len}) \t is greater than the reference length \t ({reference_len}), \t so the Brevity Penalty is equal to \t 1.000")
        penalty = 1.0
    elif candidate_len == reference_len:
        # r/c == 1, so exp(0) == 1.
        print(f"Candidate length \t ({candidate_len}) \t is equal to the reference length \t ({reference_len}), \t so the Brevity Penalty is equal to \t 1.000")
        penalty = 1.0
    elif candidate_len == 0:
        # r/0 -> inf, so exp(-inf) -> 0; handled explicitly to avoid dividing by zero.
        print(f"Candidate length \t ({candidate_len}) \t is equal to 0.0, \t\t\t\t so the Brevity Penalty is equal to \t 0.000")
        penalty = 0.0
    else:
        # Candidate is shorter than the reference: exp(1 - r/c) lies in (0, 1).
        penalty = math.exp(1 - reference_len / candidate_len)
        print(f"Candidate length \t ({candidate_len}) \t is less than the reference length \t ({reference_len}),\t so the Brevity Penalty is equal to \t {penalty:.3f}")

    return penalty

In [None]:
def calculate_brevity_penalty_2(reference_len: int, candidate_len: int) -> float:
    """Return BLEU's brevity penalty without printing anything.

    Args:
        reference_len: Length of the reference.
        candidate_len: Length of the candidate.

    Returns:
        ``min(1, exp(1 - reference_len / candidate_len))`` for a non-empty
        candidate. For an empty candidate, 1.0 if the reference is empty
        too and 0.0 otherwise.

    Raises:
        ValueError: If either length is negative.
    """
    if any(length < 0 for length in (reference_len, candidate_len)):
        raise ValueError("Length cannot be negative")

    if candidate_len != 0:
        # min() caps the penalty at 1 when the candidate is the longer one.
        return min(1.0, math.exp(1 - reference_len / candidate_len))

    # Empty candidate: the ratio is undefined, so the result is decided by hand.
    return 1.0 if reference_len == 0 else 0.0

In [None]:
# Candidate sentences to be scored: two non-empty sentences and one empty
# string (the empty one exercises the zero-length edge cases).
candidates = ["It is a guide to action which ensures that the military always obeys the commands of the party.",
              "It is to insure the troops forever hearing the activity guidebook that party direct.",
              ""]

In [None]:
# Reference sentences that the candidates are compared against.
references = ["It is a guide to action that ensures that the military will forever heed Party commands.",
              "It is the guiding principle which guarantees the military forces always being under the command of the Party.",
              "It is the practical guide for the army always to heed the directions of the party."]

In [None]:
from itertools import product

In [None]:
# Brevity penalty for every (reference, candidate) pair, using the verbose
# implementation (it prints one line per pair).
# NOTE: len() is applied to the raw strings, so these lengths are character
# counts, not word counts.
bp1 = [calculate_brevity_penalty(len(reference), len(candidate)) for reference, candidate in product(references, candidates)]

In [None]:
# Same pairs as above, scored with the compact implementation.
# NOTE: len() is applied to the raw strings, so these lengths are character
# counts, not word counts.
bp_2 = [calculate_brevity_penalty_2(len(reference), len(candidate)) for reference, candidate in product(references, candidates)]

In [None]:
# True when both implementations returned exactly the same penalty for every pair.
bp1 == bp_2

### Precision

$\text{modified precision}(n) = \cfrac{\sum \text{Count Clip}(n)}{\sum \text{Count n-gram}_{candidate}}$

$\text{Count Clip}(n) = \min(\text{Count n-gram}_{candidate}, \max(\text{Count n-gram}_{reference}))$

In [None]:
from collections import Counter
from fractions import Fraction
from itertools import tee


def ngrams(sequence, n):
    """Return an iterator over the contiguous n-grams of ``sequence``.

    Each n-gram is a tuple of ``n`` consecutive items. Nothing is yielded
    when ``sequence`` has fewer than ``n`` items or when ``n`` is 0.

    Args:
        sequence: Any iterable of items (a list of tokens, a string, ...).
        n: Size of each n-gram.

    Returns:
        A ``zip`` iterator of ``n``-tuples.
    """
    # One independent copy of the input per position inside the n-gram.
    streams = tee(iter(sequence), n)

    # Each pass advances every copy from index `offset` onwards by one item,
    # so copy k ends up k items ahead of copy 0.
    for offset in range(1, n):
        for stream in streams[offset:]:
            next(stream, None)

    # Read in lockstep, the staggered copies form the sliding window; zip
    # stops as soon as the most advanced copy runs out.
    return zip(*streams)


def count_clip(counts: Counter, max_counts: dict) -> dict:
    """Clip each n-gram count to its maximum allowed count.

    Args:
        counts: Mapping from n-gram to its count in the candidate.
        max_counts: Mapping from n-gram to the largest count allowed for it.

    Returns:
        A new dict mapping every n-gram in ``counts`` to
        ``min(count, max_counts[ngram])``. An n-gram with no entry in
        ``max_counts`` is clipped to 0 instead of raising ``KeyError``.
    """
    return {
        ngram: min(count, max_counts.get(ngram, 0))
        for ngram, count in counts.items()
    }
        

def calculate_modified_precision(references, candidate, n):
    """Compute the modified n-gram precision of a candidate sentence.

    The candidate and every reference are tokenized with ``str.split()``.
    Each candidate n-gram count is clipped to the largest count of that
    n-gram in any single reference. The precision is the clipped total
    divided by the number of candidate n-grams.

    Args:
        references: Iterable of reference sentences (strings).
        candidate: Candidate sentence (string).
        n: Size of the n-grams to compare.

    Returns:
        A ``Fraction`` equal to clipped matches / candidate n-grams. The
        denominator is forced to at least 1, so a candidate with fewer than
        ``n`` tokens gives 0/1. Where the interpreter supports it the
        fraction is left unreduced so ``numerator`` and ``denominator`` are
        the raw counts; otherwise it is reduced (the value is the same).
    """
    candidate = candidate.split()
    candidate_counts = Counter(ngrams(candidate, n)) if len(candidate) >= n else Counter()

    # Seed every candidate n-gram with 0 so each one has an entry even when
    # `references` is empty (previously that case raised KeyError while clipping).
    max_counts = dict.fromkeys(candidate_counts, 0)
    for ref in references:
        reference = ref.split()
        reference_counts = (
            Counter(ngrams(reference, n)) if len(reference) >= n else Counter()
        )
        # Counter returns 0 for n-grams the reference does not contain.
        for ngram in candidate_counts:
            max_counts[ngram] = max(max_counts[ngram], reference_counts[ngram])

    clipped_counts = count_clip(candidate_counts, max_counts)
    numerator = sum(clipped_counts.values())

    # Ensures that denominator is minimum 1 to avoid ZeroDivisionError.
    denominator = max(1, sum(candidate_counts.values()))

    try:
        # `_normalize` is a private keyword of Fraction; it keeps the raw
        # counts visible instead of reducing the fraction.
        return Fraction(numerator, denominator, _normalize=False)
    except TypeError:
        # Python 3.12 removed the private keyword; fall back to the public
        # constructor, which returns the same value in lowest terms.
        return Fraction(numerator, denominator)

In [None]:
# Print a header followed by each reference on its own line. The list of
# print() results is bound to `_`, presumably so the notebook cell does not
# echo a list of None values.
print("References\n")
_ = [print(reference) for reference in references]

In [None]:
# Print a header followed by each candidate with its index. The list of
# print() results is bound to `_`, presumably so the notebook cell does not
# echo a list of None values.
print("Candidates\n")
_ = [print(f"Candidate {i} is '{candidate}'") for i, candidate in enumerate(candidates)]

In [None]:
# One message per (candidate, n) pair for n = 1..4, showing the modified
# precision of that candidate against all references.
[f"The {j+1}-gram modified precision for candidate {i} is {calculate_modified_precision(references, candidate, j+1)}" for i, candidate in enumerate(candidates) for j in range(4)]

### n-gram overlap

$\text{n-gram overlap} = \exp(\sum_{n=1}^{N}w_n\log(\text{modified precision}(n)))$

In [None]:
def calculate_n_gram_overlap(references, candidate, weights=(0.25, 0.25, 0.25, 0.25)):
    """Compute the weighted geometric mean of the modified n-gram precisions.

    The i-th weight (1-based) is applied to the modified precision of
    i-grams, so the default evaluates 1-grams to 4-grams with equal weight.

    Args:
        references: Iterable of reference sentences (strings).
        candidate: Candidate sentence (string).
        weights: One weight per n-gram order, starting at n = 1.

    Returns:
        ``exp(sum(w_n * log(p_n)))`` over the orders whose precision is
        non-zero; orders with no match are left out of the sum. Returns 0.0
        when weights are given but no order has any match, and 1.0 when
        ``weights`` is empty.
    """
    weighted_logs = []
    order = 0  # stays 0 when `weights` is empty
    for order, weight in enumerate(weights, start=1):
        precision = calculate_modified_precision(references, candidate, order)
        # log(0) is undefined, so orders without any match are skipped.
        if precision.numerator > 0:
            # Each precision is paired with its own weight; the ratio is taken
            # as a plain float division, so no Fraction needs to be rebuilt.
            weighted_logs.append(
                weight * math.log(precision.numerator / precision.denominator)
            )

    # With nothing matched at any order the sum would be empty and exp(0)
    # would report a perfect overlap of 1.0; no overlap must score 0.
    if order and not weighted_logs:
        return 0.0

    return math.exp(math.fsum(weighted_logs))

def bleu(references, candidate, weights=(0.25, 0.25, 0.25, 0.25)):
    """Compute the BLEU score of a candidate against several references.

    Sentences are tokenized with ``str.split()``. The brevity penalty uses
    the reference whose token count is closest to the candidate's; on a tie
    the shorter reference is chosen.

    Args:
        references: Iterable of reference sentences (strings).
        candidate: Candidate sentence (string).
        weights: One weight per n-gram order, starting at n = 1.

    Returns:
        The brevity penalty multiplied by the n-gram overlap.
    """
    hyp_len = len(candidate.split())

    def closeness(ref_len):
        # Sort by distance to the candidate length first, then by length.
        return abs(ref_len - hyp_len), ref_len

    best_ref_len = min((len(ref.split()) for ref in references), key=closeness)

    penalty = calculate_brevity_penalty_2(best_ref_len, hyp_len)
    overlap = calculate_n_gram_overlap(references, candidate, weights)
    return penalty * overlap
    

### BLEU

$BLEU = \text{Brevity Penalty}\times\text{n-gram overlap}$

In [None]:
def bleu(references, candidate, weights=(0.25, 0.25, 0.25, 0.25)):
    """Compute the BLEU score of a candidate against several references.

    Sentences are tokenized with ``str.split()``. The brevity penalty uses
    the reference whose token count is closest to the candidate's; on a tie
    the shorter reference is chosen.

    Args:
        references: Iterable of reference sentences (strings).
        candidate: Candidate sentence (string).
        weights: One weight per n-gram order, starting at n = 1.

    Returns:
        The brevity penalty multiplied by the n-gram overlap.
    """
    hyp_len = len(candidate.split())

    def closeness(ref_len):
        # Sort by distance to the candidate length first, then by length.
        return abs(ref_len - hyp_len), ref_len

    best_ref_len = min((len(ref.split()) for ref in references), key=closeness)

    penalty = calculate_brevity_penalty_2(best_ref_len, hyp_len)
    overlap = calculate_n_gram_overlap(references, candidate, weights)
    return penalty * overlap

In [None]:
# BLEU score of the first candidate against all references, using the
# default weights (equal weight on 1-grams to 4-grams).
bleu(references, candidates[0])

### NLTK Implementation

In [None]:
!pip install -U nltk

In [None]:
# TODO: Implement BLEU score with NLTK library