In [1]:
from collections import Counter

import numpy as np
import torch


def grouper(seq, n):
    '''Return every contiguous window of length n taken from seq.

    Windows are produced left to right by slicing, so each one is whatever
    ``seq[i : i + n]`` yields for that sequence. When n exceeds
    ``len(seq)`` no window exists and an empty list is returned.
    '''
    window_count = len(seq) - n + 1
    return [seq[start : start + n] for start in range(window_count)]


def n_gram_precision(reference, candidate, n):
    '''Calculate the modified (clipped) precision for n-grams of order n.

    Each distinct candidate n-gram is credited at most as many times as it
    occurs in the reference, so a candidate cannot inflate its score by
    repeating an n-gram that matches.

    Args:
        reference: sliceable sequence of hashable tokens (ground truth).
        candidate: sliceable sequence of hashable tokens (hypothesis).
        n: order of the n-grams to compare.

    Returns:
        Clipped match count divided by the number of candidate n-grams.

    Raises:
        AssertionError: if the candidate contains no n-gram of order n.
    '''
    # N-grams are stored as tuples so they are hashable and can be counted.
    counts_r = Counter(
        tuple(reference[i : i + n]) for i in range(len(reference) - n + 1)
    )
    counts_c = Counter(
        tuple(candidate[i : i + n]) for i in range(len(candidate) - n + 1)
    )
    total_num = sum(counts_c.values())
    assert total_num > 0, 'candidate has no n-grams of the requested order'
    # Counter intersection keeps min(count_c, count_r) per n-gram, which is
    # exactly the clipping step of BLEU's modified precision.
    total_matches = sum((counts_c & counts_r).values())
    return total_matches / total_num


def brevity_penalty(reference, candidate):
    '''Return the penalty applied to candidates shorter than the reference.

    The result is 0 for an empty candidate, 1 when the candidate is at
    least as long as the reference, and ``exp(1 - r / c)`` otherwise, where
    r and c are the reference and candidate lengths.
    '''
    cand_len = len(candidate)
    if not cand_len:
        return 0
    ref_len = len(reference)
    if cand_len >= ref_len:
        return 1
    length_ratio = ref_len / cand_len
    return np.exp(1 - length_ratio)


def BLEU_score_batch(reference_lst, hypothesis_lst, n):
    '''Calculate the mean sentence-level BLEU score over a batch.

    For every (reference, hypothesis) pair, the n-gram precisions of orders
    1 .. min(n, len(hypothesis)) are multiplied together, the n-th root of
    the product is taken, and the result is scaled by the brevity penalty.
    The per-pair scores are then averaged.

    Args:
        reference_lst: iterable of reference token sequences.
        hypothesis_lst: iterable of hypothesis token sequences, paired
            positionally with reference_lst.
        n: highest n-gram order to include.

    Returns:
        The mean of the per-pair scores, or 0.0 when there are no pairs.
        A pair whose hypothesis is empty contributes a score of 0.
    '''
    scores = []
    for reference, hypothesis in zip(reference_lst, hypothesis_lst):
        max_order = min(n, len(hypothesis))
        if max_order == 0:
            # An empty hypothesis scores 0 for this pair only; it must not
            # throw away the scores of the rest of the batch.
            scores.append(0.0)
            continue
        bp = brevity_penalty(reference, hypothesis)
        prec = 1
        for order in range(1, max_order + 1):
            prec *= n_gram_precision(reference, hypothesis, order)
        scores.append(prec ** (1 / n) * bp)
    if not scores:
        # Nothing to average; avoid dividing by zero on an empty batch.
        return 0.0
    return sum(scores) / len(scores)
    

In [2]:
# Sanity check: score two reference/hypothesis pairs of equal length, where
# each hypothesis differs from its reference in exactly one token, using
# n-grams up to order 4. The tensors are converted to nested Python lists
# before scoring. The value of this bare expression is the cell's output.
BLEU_score_batch(
    torch.tensor([[0, 2, 3, 3, 4, 5, 0], [0, 5, 7, 4, 3, 2, 0]]).tolist(),
    torch.tensor([[0, 2, 3, 3, 4, 8, 0], [0, 5, 7, 4, 3, 1, 0]]).tolist(),
    4
)

0.6434588841607617