In [None]:
from collections import Counter

import numpy as np
from nltk.translate.bleu_score import sentence_bleu

In [None]:
# Reference translations. Every candidate in `translations` is scored against
# this whole list (multi-reference BLEU: n-gram counts are clipped by the
# maximum count found in any one reference).
reference_translations = [
    "resources have to be sufficient and they have to be predictable",
    "adequate and predictable resources are required",
]
# Candidate translations to be scored, one BLEU score per entry.
translations = [
    "there is a need for adequate and predictable resources",
    "resources be sufficient and predictable to",
]


def get_ngram(sentence, n):
    """Return the contiguous word n-grams of ``sentence`` as space-joined strings.

    The sentence is tokenized with ``str.split()`` (any run of whitespace is a
    separator). For ``n >= 1``, a sentence with fewer than ``n`` words yields
    an empty list.
    """
    tokens = sentence.split()
    ngrams = []
    for start in range(len(tokens) - n + 1):
        window = tokens[start : start + n]
        ngrams.append(" ".join(window))
    return ngrams


def precision(references, translation, n):
    """Modified (clipped) n-gram precision p_n of ``translation`` against ``references``.

    Each distinct n-gram of the translation is credited at most as many times
    as it occurs in any single reference (the maximum count over references),
    as in the BLEU definition. The fraction is printed as
    ``p{n}: numerator/denominator= value``.

    Args:
        references: list of reference sentences (whitespace-tokenized strings).
        translation: candidate sentence (whitespace-tokenized string).
        n: n-gram order.

    Returns:
        The clipped precision as a float. A translation with fewer than ``n``
        words has no n-grams, and 0.0 is returned instead of raising
        ZeroDivisionError (NLTK's ``modified_precision`` likewise clamps the
        denominator and yields 0).
    """
    translation_counts = Counter(get_ngram(translation, n))
    reference_counts = [Counter(get_ngram(ref, n)) for ref in references]
    numerator = sum(
        # Counter returns 0 for unseen n-grams; default=0 covers an empty reference list.
        min(count, max((ref[ngram] for ref in reference_counts), default=0))
        for ngram, count in translation_counts.items()
    )
    denominator = sum(translation_counts.values())
    # No n-grams in the translation -> nothing matched; avoid dividing by zero.
    value = numerator / denominator if denominator else 0.0
    print(f"p{n}: {numerator}/{denominator}= {value:.3f}")
    return value


def brevity_penalty(references, translation):
    """BLEU brevity penalty BP for ``translation``.

    c is the translation length in words. r is the length of the reference
    whose word count is closest to c; on a tie, the shorter reference wins.
    BP = 1 if c > r, else exp(1 - r / c). Both c and r are printed.

    Args:
        references: list of reference sentences (whitespace-tokenized strings).
        translation: candidate sentence (whitespace-tokenized string).

    Returns:
        The brevity penalty. An empty translation gets 0 (as NLTK does)
        instead of raising ZeroDivisionError on ``r / c``.

    Raises:
        ValueError: if ``references`` is empty.
    """
    c = len(translation.split())
    ref_lengths = [len(ref.split()) for ref in references]
    if not ref_lengths:
        raise ValueError("at least one reference translation is required")
    # Closest reference length; the secondary key breaks ties toward the shorter one.
    r = min(ref_lengths, key=lambda length: (abs(length - c), length))
    print(f"len(c): {c}")
    print(f"len(r): {r}")
    if c == 0:
        return 0
    return 1 if c > r else np.exp(1 - r / c)


def bleu(references, translation, lambdas):
    """Sentence-level BLEU of ``translation`` against ``references``.

    BLEU = BP * exp(sum_n lambdas[n-1] * log p_n) for n = 1..len(lambdas),
    where BP comes from ``brevity_penalty`` and p_n from ``precision``.
    BP and every p_n are printed as they are computed.

    Args:
        references: list of reference sentences (whitespace-tokenized strings).
        translation: candidate sentence (whitespace-tokenized string).
        lambdas: weight for each n-gram order, starting at unigrams.

    Returns:
        The BLEU score. It is 0.0 when the translation is empty, or when any
        p_n with a non-zero weight is 0. In that case the weighted geometric
        mean is 0, and computing it through log(0) = -inf would warn, and gives
        NaN for an order whose weight is 0. Orders with weight 0 do not
        contribute to the sum.
    """
    if not translation.split():
        return 0.0
    bp = brevity_penalty(references, translation)
    print(f"BP: {bp:.3f}")
    precisions = [precision(references, translation, n) for n in range(1, len(lambdas) + 1)]
    if any(lam != 0 and p == 0 for lam, p in zip(lambdas, precisions)):
        return 0.0
    # Skip zero-weight orders so 0 * log(p) never contributes NaN.
    log_mean = sum(lam * np.log(p) for lam, p in zip(lambdas, precisions) if lam != 0)
    return bp * np.exp(log_mean)

In [None]:
# Score each candidate with the implementation above and with NLTK.
# NLTK's sentence_bleu expects tokenized input: a list of token lists for the
# references and a token list for the hypothesis. Passing raw strings makes it
# iterate over characters and report a character-level score instead, so both
# sides are tokenized with the same str.split() used by bleu().
lambdas = [0.5, 0.5]
reference_tokens = [ref.split() for ref in reference_translations]

for translation in translations:
    bleu_score = bleu(reference_translations, translation, lambdas=lambdas)
    nltk_bleu_score = sentence_bleu(reference_tokens, translation.split(), weights=lambdas)
    print(translation)
    print(f"BLEU: {bleu_score:.3f}")
    print(f"NLTK BLEU: {nltk_bleu_score:.3f}")
    print()