In [1]:
import math
from nltk import word_tokenize
from collections import Counter
from nltk.util import ngrams


class BLEU(object):
    @staticmethod
    def compute(candidate, references, weights):
        candidate = [c.lower() for c in candidate]
        references = [[r.lower() for r in reference] for reference in references]

        p_ns = (BLEU.modified_precision(candidate, references, i) for i, _ in enumerate(weights, start=1))
        s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n)

        bp = BLEU.brevity_penalty(candidate, references)
        return bp * math.exp(s)

    @staticmethod
    def modified_precision(candidate, references, n):
        counts = Counter(ngrams(candidate, n))

        if not counts:
            return 0

        max_counts = {}
        for reference in references:
            reference_counts = Counter(ngrams(reference, n))
            for ngram in counts:
                max_counts[ngram] = max(max_counts.get(ngram, 0), reference_counts[ngram])

        clipped_counts = dict((ngram, min(count, max_counts[ngram])) for ngram, count in counts.items())

        return sum(clipped_counts.values()) / sum(counts.values())
    
    @staticmethod
    def brevity_penalty(candidate, references):
        c = len(candidate)
        # r = min(abs(len(r) - c) for r in references)
        r = min(len(r) for r in references)

        if c > r:
            return 1
        else:
            return math.exp(1 - r / c)
        


In [2]:
from EncoderDecoder import EncoderDecoder, EncoderCNN, DecoderRNN
from data_utils import Img2LatexDataset, load_img
from train_model import *
from torch.utils.data import DataLoader

model = load_model("./models/model_old.pt")
model.eval()

dataset = Img2LatexDataset("../data/SyntheticData/images/", "../data/SyntheticData/test.csv")
loader = DataLoader(dataset, batch_size=2, shuffle=True)

batch = next(iter(loader))
print(f"Images shape: {batch[0].shape}, formulas shape: {batch[1].shape}")

Using device: cpu
Using device: cpu
LOADED MODEL to cpu
Images shape: torch.Size([2, 3, 224, 224]), formulas shape: torch.Size([2, 405])


In [3]:
imgs, labels = batch

preds = model(imgs)

print(preds)


Predicting token 1/629..
Predicting token 2/629..
Predicting token 3/629..
Predicting token 4/629..
Predicting token 5/629..
Predicting token 6/629..
Predicting token 7/629..
Predicting token 8/629..
Predicting token 9/629..
Predicting token 10/629..
Predicting token 11/629..
Predicting token 12/629..
Predicting token 13/629..
Predicting token 14/629..
Predicting token 15/629..
Predicting token 16/629..
Predicting token 17/629..
Predicting token 18/629..
Predicting token 19/629..
Predicting token 20/629..
Predicting token 21/629..
Predicting token 22/629..
Predicting token 23/629..
Predicting token 24/629..
Predicting token 25/629..
Predicting token 26/629..
Predicting token 27/629..
Predicting token 28/629..
Predicting token 29/629..
Predicting token 30/629..
Predicting token 31/629..
Predicting token 32/629..
Predicting token 33/629..
Predicting token 34/629..
Predicting token 35/629..
Predicting token 36/629..
Predicting token 37/629..
Predicting token 38/629..
Predicting token 39/6

In [4]:
ground_truths = [" ".join([model.decoder.vocab[tok] for tok in labels[i]]) for i in range(len(labels))]
print(ground_truths)
overall = 0
for gt, pred in zip(ground_truths, preds):
    gt = gt.split()
    pred = pred.split()
    overall += BLEU.compute(pred,[gt], weights=[1/4, 1/4, 1/4, 1/4])

print("Macro Bleu : ", overall/len(preds))

Macro Bleu :  0.04487193606856423
