## Evaluating the pretrained speaker: MS COCO Captions evaluation

The following notebook is to set up / re-use public code for image caption evaluation following the pipeline suggested in the original paper on MS COCO Captions. Additionally, given the specifics of the downstream task (namely, sampling from the model and minimizing the CCE loss against that), a baseline for the validation PPL on sampled captions is computed as well.  

Due to installation difficulties, and because it was not a part of the original paper, the SPICE score computation is commented out in the cloned code and not performed in the present evaluation. 

#### Utils
In order to compute the standard image caption evaluation metrics, the code provided in [this](https://github.com/daqingliu/coco-caption) repo is used. Since it requires the results to be formatted in a specific syntax, script below performs some utility mapping of validation annotation IDs to  validation image IDs. 

In [33]:
# reproducing the desired results format
import json
import torch
from pycocotools.coco import COCO
import math
import pandas as pd
from torchvision import transforms
import os
import sys
import torch.nn as nn
import numpy as np

In [10]:
# --> i need to produce {'image_id': XXX, 'caption': 'lower cased string'} items when validating. 
# when i iterate over items with my data loader, i get annotation ids. 
# so i need to map ann IDs to img IDs. I can do that via th COCO .loadAnns(annIds) and then retrieve 'image_id', 

# actually i'll just create a file for my entire val split
val_ids = torch.load("val_split_IDs_from_COCO_train.pt")
print(len(val_ids))
coco = COCO("../../../data/train/annotations/captions_train2014.json")
val_imgIDs = [coco.loadAnns(i)[0]['image_id'] for i in val_ids]
print(len(val_imgIDs))
# torch.save(val_imgIDs, "val_split_imgIDs_from_COCO_train.pt")

264048
loading annotations into memory...
Done (t=0.42s)
creating index...
index created!
264048


In [44]:
# further, it is necessary to create an annotation subset matching the val set in length
with open("../../../data/speaker_eval_results/decoder_scheduled_sampling_wGreedyDecoding_k150-1_results.json", "r") as f_in:
    coco_all = json.load(f_in)

In [46]:
# len(coco_all["images"])
# len(coco_all["annotations"])
coco_all[0]

{'image_id': 549810,
 'caption': 'Safari safari car print print rusted print rusted print though print though print'}

#### Evaluation function
The wrapper below takes in a trained model and performs the evaluation on a set number of validation images (set including the images used in the reference game) and on the test set (images used neither in pretraining nor experiments). 

In [2]:
# import agent modules from the actual repo
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
from agents.speaker import DecoderRNN
from utils.build_dataset import get_loader
from reference_game_utils.update_policy import clean_sentence
from utils.vocabulary import Vocabulary

In [4]:
from coco_caption.pycocotools.coco import COCO # not sure if this is needed bc of the loadRes method
from coco_caption.pycocoevalcap.eval import COCOEvalCap # TODO in readme add point about renaming!
# import matplotlib.pyplot as plt
import skimage.io as io
# import pylab
# pylab.rcParams['figure.figsize'] = (10.0, 8.0)

import json
from json import encoder
encoder.FLOAT_REPR = lambda o: format(o, '.3f')

In [47]:
def evaluate_speaker(
    model_path: str,
    num_val_imgs: int,
    res_path: str,
    val_ppl_path: str,
    vocab_file: str,
    download_dir: str,
    val_file: str, 
    vocab_threshold: int = 25,
    batch_size: int = 64,
    embed_size: int = 512,
    visual_embed_size: int = 512,
    hidden_size: int = 512,
    decoding_strategy: str = "greedy",
) -> None:
    """
    Evaluate a pretrained model.
    """
    
    # data loader
    transform_test = transforms.Compose([transforms.Resize((224, 224)), 
                                         transforms.ToTensor(), \
                                         transforms.Normalize((0.485, 0.456, 0.406), \
                                                          (0.229, 0.224, 0.225))])
    data_loader_test = get_loader(transform=transform_test,
                         mode='train',
                         batch_size=batch_size,
                         vocab_threshold=vocab_threshold,
                         vocab_from_file=True,
                         download_dir=download_dir, 
                         vocab_file=vocab_file,
                         dataset_path=val_file, 
                         num_imgs=num_val_imgs,
                         embedded_imgs=torch.load("../train_logs/COCO_train_ResNet_features_reshaped_dict.pt"),
                        )
    # add img IDs
    data_loader_test.dataset._img_ids_flat = val_imgIDs[:num_val_imgs]
    
#     print("ids ", data_loader_test.dataset.ids)
#     print("Features ", data_loader_test.dataset.embedded_imgs.shape)
    
    vocab_size = len(data_loader_test.dataset.vocab)
    # load model
    decoder = DecoderRNN(
        embed_size,
        hidden_size,
        vocab_size,
        visual_embed_size,
    )
    
    criterion = nn.CrossEntropyLoss().cuda() if torch.cuda.is_available() else nn.CrossEntropyLoss()
    
    # instantiate results 
    results = []
    val_running_loss = 0.0
    val_running_ppl = 0.0
    losses_list = []
    ppl_list = []
    counter = 0
    total = 0
    
    num_steps = math.ceil(len(data_loader_test.dataset)/batch_size)
    
    # configs for the caption evaluations
#     dataDir='.'
#     dataType='val2014'
#     algName = 'fakecap'
#     annFile='%s/annotations/captions_%s.json'%(dataDir,dataType)
#     subtypes=['results', 'evalImgs', 'eval']
#     [resFile, evalImgsFile, evalFile]= \
#     ['%s/results/captions_%s_%s_%s.json'%(dataDir,dataType,algName,subtype) for subtype in subtypes]
    
    
    for i in range(num_steps):
        counter += 1
        
        indices = data_loader_test.dataset.get_func_train_indices()
        
        new_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices=indices)
        data_loader_test.batch_sampler.sampler = new_sampler

        # Obtain the batch.
        targets, distractors, target_features, distractor_features, target_captions, distractor_captions = next(iter(data_loader_test)) 
        
        both_images = torch.cat((target_features.unsqueeze(1), distractor_features.unsqueeze(1)), dim=1)
        # retrieve image IDs
        batch_img_ids = [val_imgIDs[i[0]] for i in indices]
        
        max_seq_len = target_captions.shape[1]-1
        
        with torch.no_grad():
            # get prediction
            captions_pred, log_probs, outputs, entropies = decoder.sample(
                both_images, 
                max_sequence_length=max_seq_len, 
                decoding_strategy=decoding_strategy
            )
            # transform to natural language
            nl_captions_pred = clean_sentence(captions_pred, data_loader_test)
            
            # append to results list together with img ID
            for i, c in list(zip(batch_img_ids, nl_captions_pred)):
#                 print("I and C ", i, c)
                
                if "end" in c.split(" "):
                    len_c = sum([1 for x in c.split(" ")[:c.split(" ").index("end")] if x != "end" ])
                else:
                    len_c = len(c.split(" "))            
                
                results.append({"image_id": i, "caption": " ".join(c.split()[:len_c])})
                
            # compute val PPL
            loss = criterion(outputs.transpose(1,2), target_captions[:, 1:]) 
            losses_list.append(loss.item())
            ppl = np.exp(loss.item())
            ppl_list.append(ppl)

            val_running_loss += loss.item()
            val_running_ppl += ppl
    
    print("Final average loss: ", val_running_loss / counter)
    print("Final average PPL: ", val_running_ppl / counter)
    
    # check if results dir exists
    os.makedirs("../../../data/speaker_eval_results/", exist_ok=True)
    
    # write out results file
    with open(res_path, "w") as f:
        json.dump(results, f)
    # write out validation PPLs
    df_out = pd.DataFrame({
        "loss": losses_list,
        "PPL": ppl_list,
    })
    df_out.to_csv(val_ppl_path)
    
    # now compute the evaluations, as proposed in the notebook from the repo referenced above 
    cocoRes = coco.loadRes(res_path)
    # create cocoEval object by taking coco and cocoRes
    coco_truth = COCO("../../../data/train/annotations/captions_train2014.json")
    cocoEval = COCOEvalCap(coco_truth, cocoRes) # data_loader_test.dataset.coco

    # evaluate on a subset of images by setting
    # cocoEval.params['image_id'] = cocoRes.getImgIds()
    # please remove this line when evaluating the full validation set
    cocoEval.params['image_id'] = cocoRes.getImgIds()

    # evaluate results
    cocoEval.evaluate()
    coco_metrics = cocoEval.eval.items()
    # TODO maybe write these smh too
    print("Final coco metrics: ", coco_metrics)

In [48]:
evaluate_speaker(
    model_path="../models/decoder-coco-512dim-scheduled_sampling_wGreedyDecoding_k150-1.pkl",
    num_val_imgs=500,
    res_path="../../../data/speaker_eval_results/decoder_scheduled_sampling_wGreedyDecoding_k150-1_results.json",
    val_ppl_path="../../../data/speaker_eval_results/decoder_scheduled_sampling_wGreedyDecoding_k150-1_val.json",
    vocab_file="vocab4000.pkl",
    download_dir="../../../data/train",
    val_file="val_split_IDs_from_COCO_train_tensor.pt", 
#     vocab_threshold: int = 25,
#     batch_size: int = 64,
#     embed_size: int = 512,
#     visual_embed_size: int = 512,
#     hidden_size: int = 512,
#     decoding_strategy: str = "greedy",
)

Vocabulary successfully loaded from vocab.pkl file!
loading annotations into memory...
Done (t=0.46s)
creating index...
index created!
Obtaining caption lengths...


100%|██████████| 500/500 [00:00<00:00, 114730.13it/s]
100%|██████████| 3700/3700 [00:00<00:00, 88550.54it/s]


Final average loss:  8.310003161430359
Final average PPL:  4064.401759388544
Loading and preparing results...
DONE (t=0.03s)
creating index...
index created!
loading annotations into memory...
Done (t=0.50s)
creating index...
index created!
tokenization...
setting up scorers...
computing Bleu score...


AssertionError: 