In [None]:
import math
from collections import Counter

import torch
from nltk import word_tokenize
from nltk.util import ngrams


class BLEU(object):
    """Sentence-level BLEU score over pre-tokenised text (Papineni et al., 2002).

    Every method is static, so the class is used as a namespace, e.g.
    ``BLEU.compute(candidate_tokens, [reference_tokens], weights=[0.25] * 4)``.

    No smoothing is applied: when any n-gram order that carries a non-zero
    weight has zero clipped precision, the score is 0 (the weighted geometric
    mean of the precisions then contains a zero factor).
    """

    @staticmethod
    def compute(candidate, references, weights):
        """Return the BLEU score of `candidate` against `references`.

        Args:
            candidate: sequence of string tokens produced by the system.
            references: sequence of reference token sequences.
            weights: one weight per n-gram order; ``weights[k]`` applies to
                (k + 1)-grams, e.g. ``[0.25] * 4`` for standard BLEU-4.

        Returns:
            Brevity penalty times the weighted geometric mean of the modified
            n-gram precisions. For non-negative weights summing to 1 the result
            lies in [0, 1].
        """
        # Tokens are compared case-insensitively.
        candidate = [token.lower() for token in candidate]
        references = [[token.lower() for token in reference] for reference in references]

        log_precision_sum = 0.0
        for n, weight in enumerate(weights, start=1):
            if not weight:
                # A zero-weighted order cannot change the score, so it must not
                # be allowed to zero it out either.
                continue
            p_n = BLEU.modified_precision(candidate, references, n)
            if p_n == 0:
                # log(0) is undefined and the geometric mean is 0. Skipping this
                # order instead would treat p_n = 0 as p_n = 1 and inflate the score.
                return 0.0
            log_precision_sum += weight * math.log(p_n)

        return BLEU.brevity_penalty(candidate, references) * math.exp(log_precision_sum)

    @staticmethod
    def modified_precision(candidate, references, n):
        """Clipped n-gram precision of `candidate` against `references`.

        Each candidate n-gram is credited at most as many times as it occurs
        in any single reference.

        Returns:
            Clipped matches divided by the number of candidate n-grams, or 0
            when the candidate has no n-grams of this order (fewer than `n`
            tokens) or no reference contains any of them.
        """
        candidate_counts = Counter(BLEU._ngrams(candidate, n))
        total = sum(candidate_counts.values())
        if total == 0:
            return 0

        # Counter union (`|`) keeps the per-n-gram maximum across references;
        # n-grams absent from every reference read back as 0.
        max_reference_counts = Counter()
        for reference in references:
            max_reference_counts |= Counter(BLEU._ngrams(reference, n))

        clipped = sum(min(count, max_reference_counts[gram])
                      for gram, count in candidate_counts.items())
        return clipped / total

    @staticmethod
    def brevity_penalty(candidate, references):
        """Brevity penalty for a candidate of length c.

        r is the reference length closest to c (ties go to the shorter
        reference). Returns 1 when c > r, exp(1 - r / c) otherwise, and 0
        for an empty candidate.

        Raises:
            ValueError: if `references` is empty and the candidate is not.
        """
        c = len(candidate)
        if c == 0:
            # exp(1 - r / 0) is undefined; an empty candidate earns no credit.
            return 0.0

        r = min((len(reference) for reference in references),
                key=lambda length: (abs(length - c), length))

        if c > r:
            return 1
        return math.exp(1 - r / c)

    @staticmethod
    def _ngrams(tokens, n):
        """Return an iterator over the contiguous n-token tuples of `tokens` (no padding)."""
        items = tuple(tokens)
        return zip(*(items[i:] for i in range(n)))
        


In [None]:
from EncoderDecoder import EncoderDecoder, EncoderCNN, DecoderRNN
from data_utils import Img2LatexDataset, load_img
from train_model import *
from torch.utils.data import DataLoader

# Evaluation configuration: edit these to point at another checkpoint or split.
MODEL_PATH = "./models/part1a.pt"
IMAGE_DIR = "../data/SyntheticData/images/"
TEST_CSV = "../data/SyntheticData/test.csv"
BATCH_SIZE = 2

model = load_model(MODEL_PATH)
model.eval()  # switch the model to evaluation mode before scoring

dataset = Img2LatexDataset(IMAGE_DIR, TEST_CSV)
# No shuffling: the test pass visits every sample regardless of order, and a
# fixed order keeps the sample batch below and the running scores reproducible.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

# Peek at one batch to sanity-check tensor shapes.
batch = next(iter(loader))
print(f"Images shape: {batch[0].shape}, formulas shape: {batch[1].shape}")

In [None]:
# Macro-averaged sentence BLEU over the test set: score each
# (prediction, ground truth) pair on its own, then take the mean.
BLEU_WEIGHTS = [1 / 4, 1 / 4, 1 / 4, 1 / 4]  # uniform 1- to 4-gram weights (BLEU-4)
LOG_EVERY_BATCHES = 5  # print a running score every this many batches

overall = 0.0
counted = 0  # number of scored samples; the mean divides by exactly this

with torch.no_grad():  # inference only, no autograd graph needed
    for batch_idx, (imgs, labels) in enumerate(loader, start=1):
        for img, label in zip(imgs, labels):
            # Join-then-split normalises both sides to whitespace-separated tokens.
            pred = " ".join(model(img)).split()
            # NOTE(review): `label` is the raw target id sequence; if it contains
            # special tokens (start/end/padding) they are scored too. Confirm.
            gt = " ".join(model.decoder.vocab[tok] for tok in label).split()
            overall += BLEU.compute(pred, [gt], weights=BLEU_WEIGHTS)
            counted += 1

        if batch_idx % LOG_EVERY_BATCHES == 0:
            print(f"Out of {counted}, BLEU score: {overall/counted}")

macro_bleu = overall / counted if counted else 0.0  # guard against an empty loader
print("Macro Bleu : ", macro_bleu)