In [1]:
from fastai.text import *

In [2]:
# Toy target and predicted token-id sequences used by the example cells of this notebook.
targ = [1, 2, 3, 4, 5, 1, 2]
pred = [1, 2, 3, 7, 5, 1, 1]

In [3]:
# Count how many times each token id occurs in the prediction and in the target.
cnt_pred, cnt_targ = Counter(pred), Counter(targ)

In [4]:
# Inspect the token counts of the prediction.
cnt_pred

Counter({1: 3, 2: 1, 3: 1, 7: 1, 5: 1})

In [39]:
# Indexing a Counter with a key it has never seen (10) gives 0 instead of raising.
cnt_pred[1], cnt_pred[10]

(3, 0)

In [8]:
# List the (token, count) pairs stored in the prediction counter.
list((g,c) for g,c in cnt_pred.items())

[(1, 3), (2, 1), (3, 1), (7, 1), (5, 1)]

In [5]:
# Clipped unigram matches: each predicted token is credited at most as many times as it occurs in the target.
corrects = sum([min(c, cnt_targ[g]) for g,c in cnt_pred.items()])
corrects

5

In [15]:
class NGram():
    "Hashable wrapper around a sequence of token ids, so that n-grams can be used as `Counter`/`dict` keys."
    def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n
    def __eq__(self, other):
        # Let Python fall back to its default comparison for foreign types instead of failing on `other.ngram`.
        if not isinstance(other, NGram): return NotImplemented
        if len(self.ngram) != len(other.ngram): return False
        # `np.all` returns a numpy bool; convert it so callers always get a plain `bool`.
        return bool(np.all(np.array(self.ngram) == np.array(other.ngram)))
    def __hash__(self):
        # Read the n-gram as a number written in base `max_n`. Each token is cast to a Python int first,
        # so numpy integer tokens cannot silently overflow a fixed-width integer for large vocabularies.
        return int(sum([int(o) * self.max_n**i for i,o in enumerate(self.ngram)]))

`max_n` should be set to the vocab size: each token id then acts as a digit in base `max_n`, so we're sure that two different n-grams of the same length never get the same hash.

In [16]:
def get_grams(x, n, max_n=5000):
    "Return `x` itself when `n` is 1, otherwise the list of `NGram`s for every window of `n` consecutive items of `x`."
    if n == 1: return x
    n_windows = len(x) - n + 1
    grams = []
    # A non-positive `n_windows` gives an empty range, so sequences shorter than `n` produce `[]`.
    for start in range(n_windows):
        grams.append(NGram(x[start:start+n], max_n=max_n))
    return grams

In [18]:
# Sanity check: two NGram objects built from the same slice compare equal.
NGram(targ[0:1]) == NGram(targ[0:1])

True

In [20]:
# With n=1, get_grams returns the input sequence as is.
get_grams(targ, 1)

[1, 2, 3, 4, 5, 1, 2]

In [22]:
# Bigrams of the target, unwrapped from their NGram objects for display.
[ng.ngram for ng in get_grams(targ, 2)]

[[1, 2], [2, 3], [3, 4], [4, 5], [5, 1], [1, 2]]

In [24]:
# Trigrams of the target, unwrapped from their NGram objects for display.
[ng.ngram for ng in get_grams(targ,3)]

[[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 1], [5, 1, 2]]

In [29]:
# n (8) is larger than len(targ) (7), so there is no n-gram to return.
[ng.ngram for ng in get_grams(targ,8)]

[]

In [30]:
# range(7-8+1) is range(0), so the loop body never runs and nothing is printed.
for i in range(7-8+1):
    print("no")

In [31]:
# range(1) contains the single element 0.
list(range(1))

[0]

In [33]:
# A negative stop gives an empty range.
list(range(-1))

[]

In [36]:
# Counter groups equal NGram objects under one key (this relies on NGram.__hash__ and NGram.__eq__).
[(ng.ngram, c) for ng, c in Counter(get_grams(targ, 2)).items()]

[([1, 2], 2), ([2, 3], 1), ([3, 4], 1), ([4, 5], 1), ([5, 1], 1)]

In [40]:
def get_correct_ngrams(pred, targ, n, max_n=5000):
    "Return the clipped count of `n`-grams of `pred` that also occur in `targ`, and the number of `n`-grams in `pred`."
    pred_grams = get_grams(pred, n, max_n=max_n)
    targ_grams = get_grams(targ, n, max_n=max_n)
    # Counter intersection keeps, for every n-gram, the smaller of its two counts.
    overlap = Counter(pred_grams) & Counter(targ_grams)
    return sum(overlap.values()), len(pred_grams)

In [43]:
# Unigrams: (clipped matches, number of unigrams in the prediction).
get_correct_ngrams(pred, targ, 1)

(5, 7)

In [44]:
# Trigrams: (clipped matches, number of trigrams in the prediction).
get_correct_ngrams(pred, targ, 3)

(1, 5)

In [45]:
def sentence_bleu(pred, targ, max_n=5000):
    "BLEU score of the token sequence `pred` against the single reference `targ`, using 1- to 4-gram precisions."
    # An empty prediction has no n-grams at all; score it 0 instead of dividing by zero below.
    if len(pred) == 0: return 0.
    corrects = [get_correct_ngrams(pred,targ,n,max_n=max_n) for n in range(1,5)]
    # A prediction shorter than `n` has no `n`-grams (t == 0): use a precision of 0, which zeroes the score.
    n_precs = [c/t if t > 0 else 0. for c,t in corrects]
    # Brevity penalty: only predictions shorter than the target are penalized.
    len_penalty = exp(1 - len(targ)/len(pred)) if len(pred) < len(targ) else 1
    # Geometric mean of the four precisions, scaled by the brevity penalty.
    return len_penalty * ((n_precs[0]*n_precs[1]*n_precs[2]*n_precs[3]) ** 0.25)

In [47]:
# A prediction identical to the target gets the maximum score.
sentence_bleu(targ, targ)

1.0

In [48]:
# Same as targ except for the last token (9 instead of 2).
sentence_bleu([1, 2, 3, 4, 5, 1, 9], targ)

0.8091067115702212

In [49]:
def corpus_bleu(preds, targs, max_n=5000):
    "BLEU score over paired `preds`/`targs`: n-gram statistics are summed over all pairs before the precisions are computed."
    pred_len,targ_len,n_precs,counts = 0,0,[0]*4,[0]*4
    for pred,targ in zip(preds,targs):
        pred_len += len(pred)
        targ_len += len(targ)
        for i in range(4):
            # `get_correct_ngrams` is the helper defined in this notebook (the previous name `ngram_corrects` does not exist).
            c,t = get_correct_ngrams(pred, targ, i+1, max_n=max_n)
            n_precs[i] += c
            counts[i] += t
    # No predicted token at all: score 0 instead of dividing by zero below.
    if pred_len == 0: return 0.
    # An order with no predicted n-gram (t == 0) gets a precision of 0, which zeroes the score.
    n_precs = [c/t if t > 0 else 0. for c,t in zip(n_precs,counts)]
    # Brevity penalty computed on the total lengths.
    len_penalty = exp(1 - targ_len/pred_len) if pred_len < targ_len else 1
    return len_penalty * ((n_precs[0]*n_precs[1]*n_precs[2]*n_precs[3]) ** 0.25)

In [50]:
class CorpusBLEU(Callback):
    "Callback that accumulates n-gram statistics batch by batch and adds the corpus-level BLEU to the metrics at epoch end."
    def __init__(self, vocab_sz):
        # `vocab_sz` is forwarded as `max_n` to `get_correct_ngrams`.
        self.vocab_sz = vocab_sz
        self.name = 'bleu'

    def on_epoch_begin(self, **kwargs):
        "Reset the accumulated lengths, clipped matches and n-gram counts."
        self.pred_len,self.targ_len,self.n_precs,self.counts = 0,0,[0]*4,[0]*4

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Add the statistics of every (prediction, target) pair of the batch to the running totals."
        # Keep the highest-scoring index along the last axis.
        # NOTE(review): assumes `last_output` is (batch, seq_len, vocab) and `last_target` is (batch, seq_len) - confirm.
        last_output = last_output.argmax(dim=-1)
        for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()):
            # NOTE(review): rows are used as is, so any padding tokens they contain are counted too - confirm this is intended.
            self.pred_len += len(pred)
            self.targ_len += len(targ)
            for i in range(4):
                c,t = get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)
                self.n_precs[i] += c
                self.counts[i] += t

    def on_epoch_end(self, last_metrics, **kwargs):
        "Compute BLEU from the totals stored on `self` and append it to `last_metrics`."
        if self.pred_len == 0:
            # Nothing was predicted during the epoch: report 0 instead of dividing by zero.
            bleu = 0.
        else:
            # An order with no predicted n-gram (t == 0) gets a precision of 0, which zeroes the score.
            n_precs = [c/t if t > 0 else 0. for c,t in zip(self.n_precs,self.counts)]
            # Brevity penalty computed on the total lengths.
            len_penalty = exp(1 - self.targ_len/self.pred_len) if self.pred_len < self.targ_len else 1
            bleu = len_penalty * ((n_precs[0]*n_precs[1]*n_precs[2]*n_precs[3]) ** 0.25)
        return add_metrics(last_metrics, bleu)