In [1]:
import numpy as np
import pickle
import os
from tqdm import tqdm

In [2]:
from utils import load_vocab, decode_caption, load_caption, rrv_votes, load_annotations, print_image

In [3]:
vocab = load_vocab(dict_file = "../../outputs/vocab/5000/coco2014_vocab.json")
image_id_to_index, index_to_image_id, annotations_dict = load_annotations(annotations_dir="../../annotations/", 
                                                                          annotations_file='captions_val2014.json',
                                                                         map_file = "../../outputs/val_image_id_to_idx.csv")
print("Processed {} images".format(len(image_id_to_index)))
print("Processed {} images".format(len(annotations_dict.keys())))

word_to_idx
idx_to_word
Loaded dictionary...
Dictionary size: 5004
Error proccessing image_id: image_index
Skipping file person_keypoints_train2014.json
Skipping file instances_train2014.json
Skipping file instances_val2014.json
Skipping file person_keypoints_val2014.json
Processed 40504 images
Processed 40504 images


In [4]:
def reweighted_range_vote(votes, weights):
    """
    :param votes: N x N numpy array where votes[s][t] indicates the vote sentence s gives to sentence t. 
        Each row sums to 1
    :param weights: a numpy array of size N where weight[s] indicates by how much 
        the votes of s need to be weighted
    :return: An iterator giving the winners in order of the RRV
    """
    scores = weights @ votes
    for x in reversed(np.argsort(scores)):
        return x, scores[x]

In [17]:
def load_beam_caption_with_state(beam_size, im_id):
    backed_off = 0
    file_name = '../../outputs/beam_captions_with_hidden_states_{}/{}.pickle'.format(beam_size, im_id)
    if not os.path.isfile(file_name):
        backed_off = 1
        file_name = '../../outputs/beam_captions_with_hidden_states_{}/{}.pickle'.format(10, im_id)
    with open(file_name, 'rb') as file:
        beam_caption = pickle.load(file)
    return beam_caption, backed_off

In [24]:
def save_vote_captions(captions, beam_size):
    file_name = '../../outputs/voted_captions/{}/lstm_states.pickle'.format(beam_size)
    if os.path.isfile(file_name): 
        raise ValueError("File {} already exists".format(file_name))
    with open(file_name, 'wb') as file:
        pickle.dump(captions, file, pickle.HIGHEST_PROTOCOL)


In [25]:
def extract_lstm_states_caption(beam_size, im_id):
    try:
        beam_caption, backed_off = load_beam_caption_with_state(beam_size, im_id)
        sim_matrix = 1 - np.array(beam_caption['similarity_matrix'])
        probabilities = np.array(beam_caption['probabilities'])
        winner, voted_score = reweighted_range_vote(sim_matrix, probabilities)
        voted_probability = probabilities[winner]
        voted_caption = beam_caption['captions'][winner]
        return ([np.array(voted_caption)], voted_probability, voted_score), backed_off
    except Exception as e:
        print(e, im_id)
        return None, 0

In [26]:
beam_size = 100
voted_captions = {}
n_backed_off = 0
for image_id in tqdm(sorted(annotations_dict)):
    caption, backed_off = extract_lstm_states_caption(beam_size, image_id)
    n_backed_off += backed_off
    if caption:
        voted_captions[image_id] = caption

100%|██████████| 40504/40504 [03:23<00:00, 199.43it/s]


In [27]:
n_backed_off

0

In [28]:
len(voted_captions)

40504

In [29]:
len(voted_captions) - n_backed_off

40504

In [31]:
save_vote_captions(voted_captions, 100)

In [129]:
voted_captions

{0: ([array([  1,   4,  13,  47,   4,  92,  33,   4, 363,   2])],
  0.001008578222019707,
  0.024602035006023374),
 1: ([array([  1,   4, 164,   6, 769,  22,  14,   4, 164,   6, 769,   2])],
  0.0023497612972159117,
  0.011829763599684934),
 2: ([array([  1,   4,  65,  27,  11, 233,  33,   7,  25,   2])],
  0.0009554519319120989,
  0.005496788183957918),
 3: ([array([  1,   4, 129,  69,   5,   7,  70,   6,   4,  86,   2])],
  0.0034008859883146694,
  0.008995220673705234),
 4: ([array([  1,   4, 116,  15,   5,  34,   6,   4, 127,   2])],
  0.00022301811032278073,
  0.01654554356643812),
 5: ([array([  1,   4,  13,  11,  19,   8,   7,  64,   9,   4, 118,   2])],
  5.416136769383144e-05,
  0.006273803787368426),
 6: ([array([  1,   4,  38, 157,  11, 110,   4, 237,   6,  68,   2])],
  0.00011081797069672708,
  0.02017009964057937),
 7: ([array([  1,   4,  13,  19,   8,   7, 618,  28,   4, 118,   2])],
  0.00014054159768742266,
  0.006644337827372551),
 8: ([array([ 1,  4, 35,  6, 20, 58, 