# In [None]:
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from datasets import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu
import torch.nn.functional as F
from tqdm import tqdm
import argparse

# Target device for the model and every tensor created below: CUDA when present, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Let cuDNN autotune convolution algorithms; this only pays off because input images have a fixed size.
cudnn.benchmark = True

def get_args():
    """
    Define parameters for evaluation.

    :return: arguments
    """

    parser = argparse.ArgumentParser(description='Evaluation')

    parser.add_argument('--checkpoint_folder', '-cf', default='/Users/rohitharavindramyla/Desktop/CSCI2470_Project/checkpoints/',
                    help='path to checkpoint')
    parser.add_argument('--dataset', '-d', default='flickr8k', help='dataset')
    parser.add_argument('--beam_size', '-b', default=5, type=int, help='beam size for beam search')

    args = parser.parse_args()
    return args

def _beam_search(encoder, decoder, image, word_map, beam_size, max_steps=50):
    """
    Decode a single image with beam search and return the best word-id sequence.

    :param encoder: image encoder taken from the checkpoint
    :param decoder: caption decoder taken from the checkpoint
    :param image: image tensor already on ``device`` (batch of one)
    :param word_map: word -> id mapping containing '<start>' and '<end>'
    :param beam_size: number of beams kept at each step
    :param max_steps: decoding stops once ``step`` exceeds this value
    :return: list of word ids (including '<start>' and, if reached, '<end>')
    """
    k = beam_size
    vocab_size = len(word_map)
    end_id = word_map['<end>']
    # Decoders without an ``adaptive_att`` attribute use the plain attention path.
    adaptive = getattr(decoder, 'adaptive_att', False)

    # Encode
    if adaptive:
        encoder_out, v_g = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
    else:
        encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
    encoder_dim = encoder_out.size(3)

    # Flatten encoding, then treat the problem as having a batch size of k
    encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
    num_pixels = encoder_out.size(1)
    encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

    # Top k previous words, sequences and scores; initially just <start> with score 0
    k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)  # (k, 1)
    seqs = k_prev_words  # (k, 1)
    top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

    # Completed sequences and their (float) scores
    complete_seqs = []
    complete_seqs_scores = []

    step = 1
    h, c = decoder.init_hidden_state(encoder_out)

    # s (<= k) shrinks as beams are removed once they emit <end>
    while True:
        embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

        if adaptive:
            g_t = decoder.sigmoid(decoder.affine_embed(embeddings) + decoder.affine_decoder(h))
            s_t = g_t * torch.tanh(c)
            h, c = decoder.decode_step_adaptive(
                torch.cat([embeddings, v_g.expand_as(embeddings)], dim=1), (h, c))  # (s, decoder_dim)
            attention_weighted_encoding, _ = decoder.adaptive_attention(encoder_out, h, s_t)
            scores = decoder.fc(h) + decoder.fc_encoder(attention_weighted_encoding)
        else:
            awe, _ = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)
            gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
            awe = gate * awe
            h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)
            scores = decoder.fc(h)  # (s, vocab_size)

        scores = F.log_softmax(scores, dim=1)
        scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

        # On the first step all k rows are identical (same <start>, h, c), so use row 0 only
        if step == 1:
            top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
        else:
            # Unroll and find top scores, and their unrolled indices
            top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

        # Convert unrolled indices to (beam, word) indices. Floor division keeps
        # the beam index an integer tensor; plain ``/`` yields floats that
        # cannot be used for indexing.
        prev_word_inds = torch.div(top_k_words, vocab_size, rounding_mode='floor')  # (s)
        next_word_inds = top_k_words % vocab_size  # (s)

        # Add new words to sequences
        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

        # Split beams into finished (emitted <end>) and unfinished ones
        next_words = next_word_inds.tolist()
        incomplete_inds = [ind for ind, w in enumerate(next_words) if w != end_id]
        complete_inds = [ind for ind, w in enumerate(next_words) if w == end_id]

        if complete_inds:
            complete_seqs.extend(seqs[complete_inds].tolist())
            complete_seqs_scores.extend(top_k_scores[complete_inds].tolist())
        k -= len(complete_inds)  # reduce beam length accordingly

        if k == 0:
            break

        # Proceed with incomplete sequences only
        keep = prev_word_inds[incomplete_inds]
        seqs = seqs[incomplete_inds]
        h = h[keep]
        c = c[keep]
        encoder_out = encoder_out[keep]
        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

        # Break if things have been going on too long
        if step > max_steps:
            break
        step += 1

    # No beam reached <end> before the step limit: fall back to the unfinished
    # beams instead of crashing on max() of an empty list.
    if not complete_seqs:
        complete_seqs = seqs.tolist()
        complete_seqs_scores = top_k_scores.squeeze(1).tolist()

    best = complete_seqs_scores.index(max(complete_seqs_scores))
    return complete_seqs[best]


def evaluate(args):
    """
    Evaluate a trained captioning model on the VAL split using beam search.

    :param args: namespace from get_args(); reads ``checkpoint_folder``,
        ``dataset`` and ``beam_size``, plus ``word_map_folder`` and
        ``data_folder`` when present (otherwise the historical defaults)
    :return: tuple (bleu1, bleu2, bleu3, bleu4) of corpus-level BLEU scores
    """
    data_name = '{:s}_5_cap_per_img_5_min_word_freq'.format(args.dataset)  # base name shared by data files
    word_map_folder = getattr(args, 'word_map_folder',
                              '/Users/rohitharavindramyla/Desktop/CSCI2470_Project/wordmaps')
    data_folder = getattr(args, 'data_folder',
                          '/Users/rohitharavindramyla/Desktop/CSCI2470_Project/data')  # folder with data files saved by create_input_files.py

    # Load model; map_location lets a GPU-saved checkpoint load on a CPU-only machine
    checkpoint = torch.load(os.path.join(args.checkpoint_folder,
                                         'BEST_checkpoint_{:s}.pth.tar'.format(data_name)),
                            map_location=device)
    decoder = checkpoint['decoder'].to(device)
    decoder.eval()
    encoder = checkpoint['encoder'].to(device)
    encoder.eval()

    # Load word map (word2ix)
    word_map_file = os.path.join(word_map_folder, 'WORDMAP_{:s}.json'.format(data_name))
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)
    special_tokens = {word_map['<start>'], word_map['<end>'], word_map['<pad>']}

    # Normalization transform
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # DataLoader; order does not affect corpus BLEU, so keep it deterministic
    loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=1, shuffle=False, pin_memory=False)

    # TODO: Batched Beam Search

    # References (true captions) and hypotheses (predictions) for each image:
    # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
    references = []
    hypotheses = []

    # Inference only: no autograd graph is needed for any decoding step
    with torch.no_grad():
        for image, _caps, _caplens, allcaps in tqdm(
                loader, desc="EVALUATING AT BEAM SIZE " + str(args.beam_size)):
            image = image.to(device)  # (1, 3, 256, 256)
            seq = _beam_search(encoder, decoder, image, word_map, args.beam_size)

            # References, with <start>, <end> and pads removed
            references.append([[w for w in cap if w not in special_tokens]
                               for cap in allcaps[0].tolist()])
            # Hypothesis, with the same special tokens removed
            hypotheses.append([w for w in seq if w not in special_tokens])

    assert len(references) == len(hypotheses)

    # Calculate BLEU scores (weights of each n-gram order sum to 1)
    bleu1 = corpus_bleu(references, hypotheses, weights=(1, 0, 0, 0))
    bleu2 = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0))
    bleu3 = corpus_bleu(references, hypotheses, weights=(1 / 3, 1 / 3, 1 / 3, 0))
    bleu4 = corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25))

    return bleu1, bleu2, bleu3, bleu4


if __name__ == '__main__':
    # Script entry point: parse CLI options, evaluate, and report BLEU-1..4.
    args = get_args()
    bleu1, bleu2, bleu3, bleu4 = evaluate(args)

    report = (
        ("\nBLEU-1 score @ beam size of %d is %.4f.", bleu1),
        ("\nBLEU-2 score @ beam size of %d is %.4f.", bleu2),
        ("\nBLEU-3 score @ beam size of %d is %.4f.", bleu3),
        ("\nBLEU-4 score @ beam size of %d is %.4f.", bleu4),
    )
    for template, score in report:
        print(template % (args.beam_size, score))
