# Setup

## Imports

In [2]:
import json
import os.path

In [3]:
from vaiutils import plot_images, path_consts

## Define Useful Variables and Functions

In [4]:
# Publish every (name, value) pair yielded by path_consts('COCO') as a
# module-level variable. NOTE(review): DIR_DATA and DIR_OUTPUT, used further
# down, are presumably created here -- confirm against vaiutils.
# Assigning through globals() avoids building and exec-ing source text.
for k, v in path_consts('COCO'):
    globals()[k] = v

In [5]:
def list_to_dict(dict_list, key):
    """Re-index a list of dicts by the value each one stores under ``key``.

    Args:
        dict_list: Iterable of dicts; every dict must contain ``key``.
        key: Field whose value becomes the key of the returned dict.

    Returns:
        Dict mapping ``item[key]`` to a new dict holding the remaining fields
        of ``item``. When several items share the same value, the last one
        wins.
    """
    def without_key(record):
        # Keep every field except the one used for indexing.
        return {field: value for field, value in record.items() if field != key}

    return {record[key]: without_key(record) for record in dict_list}

In [6]:
def search_caption(search_string, match_all=False, case_sensitive=False):
    """Find the (file_id, beam_size) pairs whose captions contain the terms.

    Reads the module-level ``data`` mapping, indexed as
    ``data[file_id][beam_size]['captions']``.

    Args:
        search_string: Comma-separated substrings to look for. Terms are used
            verbatim, so whitespace around the commas is significant.
        match_all: If True a caption must contain every term; otherwise a
            single matching term is enough.
        case_sensitive: If False, both the terms and the captions are
            lower-cased before comparison.

    Returns:
        Sorted list of unique ``(file_id, beam_size)`` tuples for which at
        least one caption matched.
    """
    if not case_sensitive:
        search_string = search_string.lower()
    terms = search_string.split(',')
    combine = all if match_all else any

    def caption_matches(caption):
        if not case_sensitive:
            caption = caption.lower()
        return combine(term in caption for term in terms)

    results = set()
    for file_id, beams in data.items():
        for beam_size, entry in beams.items():
            captions = entry['captions']
            # Non-list values (presumably a single caption string) are
            # wrapped so both shapes can be scanned the same way.
            if not isinstance(captions, list):
                captions = [captions]
            # One matching caption is enough to report this pair; stop there
            # instead of collecting duplicates that are discarded later.
            if any(caption_matches(caption) for caption in captions):
                results.add((file_id, beam_size))
    return sorted(results)

In [7]:
def show_captions(beam_size=1, file_idx=None, search_string=None):
    """Plot one image from ``val2017`` together with its caption(s).

    Relies on names defined outside this function: the ``data`` mapping
    (``data[file_idx][beam_size]``), ``DIR_DATA``, ``plot_images``,
    ``search_caption`` and ``unique`` / ``randint`` / ``imread``.
    NOTE(review): the last three are presumably provided by a pylab-style
    namespace -- confirm.

    Args:
        beam_size: Key of the result to display for the image. Replaced by a
            randomly chosen matching beam size when ``search_string`` is given.
        file_idx: Key of the image in ``data``, also used as its file name.
            A random image is picked when None.
        search_string: Optional comma-separated search terms. A leading ``~``
            means "match any term" (default: all terms must match). The search
            is case-sensitive only if the query contains upper-case characters.

    Returns:
        None. Prints a message and returns early when nothing matches the
        search or when the requested image / beam size is not in ``data``.
    """
    if search_string is not None:
        # Any upper-case character in the query makes the search case-sensitive.
        case_sensitive = search_string != search_string.lower()
        # startswith (unlike indexing with [0]) is also safe for an empty query.
        match_all = not search_string.startswith('~')
        if not match_all:
            search_string = search_string[1:]
        search_results = search_caption(search_string, match_all, case_sensitive)
        if len(search_results) == 0:
            print('Picture not found.')
            return
        # Pick a random matching image, then a random matching beam size for it.
        file_idx = unique([s[0] for s in search_results])
        file_idx = file_idx[randint(len(file_idx))]
        beam_sizes = [s[1] for s in search_results if s[0] == file_idx]
        beam_size = beam_sizes[randint(len(beam_sizes))]
        show_captions(beam_size, file_idx)
        return

    if file_idx is None:
        file_idx = randint(len(data))
        file_idx = list(data.keys())[file_idx]

    try:
        captions = data[file_idx][beam_size]
    except KeyError:
        print('Did not find image with beam size', beam_size)
        # Without this return the code below would fail on the unbound
        # ``captions`` variable.
        return

    if 'probabilities' in captions:
        # One "<caption>    (<probability>)" line per caption.
        titles = ''.join(
            "{}    ({:.2f})\n".format(caption, probability)
            for caption, probability in zip(captions['captions'],
                                            captions['probabilities'])
        )
    else:
        titles = captions['captions']

    plot_images([imread(os.path.join(DIR_DATA, 'val2017', file_idx))], titles)

## Checks

In [8]:
# Stop early with a clear error if the captions file is missing. An explicit
# raise is used because assert statements are skipped under `python -O`.
if not os.path.exists(os.path.join(DIR_OUTPUT, 'Train.json')):
    raise FileNotFoundError("Captions not found.")

## Load Data

In [9]:
# Load the caption results; only the top-level 'data' entry is kept.
# JSON text is UTF-8, so decode explicitly instead of using the locale default.
with open(os.path.join(DIR_OUTPUT, 'Train.json'), encoding='utf-8') as f:
    json_data = json.load(f)['data']

In [10]:
# Re-index the raw records twice: first by 'file_id', then each record's
# 'results' list by 'beam_size', giving data[file_id][beam_size] -> result.
data = {
    file_id: list_to_dict(record['results'], 'beam_size')
    for file_id, record in list_to_dict(json_data, 'file_id').items()
}
# The raw list is no longer needed once it has been re-indexed.
del json_data

# Show Captions

In [None]:
# Show a randomly chosen image that has a caption containing 'cake'.
# The query is all lower-case, so show_captions runs a case-insensitive search.
show_captions(search_string='cake')