In [None]:
# %load get_object_audio_pairs.py
import os
import json
import pickle
from random import choice, shuffle
from shutil import copyfile
from collections import defaultdict

def process_split(split, img_thresh, wav_thresh):
    """Write ``{split}_split.tsv`` pairing object boxes with spoken-word slices.

    Reads every pickle in ``small_dataset/{split}/box`` (dicts with
    ``'objects'`` and ``'boxes'`` entries) and every JSON file in
    ``small_dataset/{split}/json`` (dicts with ``'timecode'``,
    ``'duration'``, ``'wavFilename'`` and ``'imgID'`` entries).  A word in a
    caption is linked to an object when the lower-cased word equals the
    lower-cased object label and both belong to the same image id.

    The output file first receives one "match" row (score 1) for every
    (audio slice, box) combination of a linked object, then "mismatch" rows
    (score 0) in which an audio slice is paired with a randomly chosen box of
    a *different* object label.  Mismatch rows are produced in full passes
    over all sufficiently long slices, two rows per slice, until at least
    twice as many mismatch rows as match rows exist.

    Args:
        split: Name of the dataset split; selects the input directories and
            the output file name.
        img_thresh: Minimum value of ``min(box[2] - box[0], box[3] - box[1])``
            for a box to be used (presumably the shorter box side in pixels,
            with boxes stored as (x0, y0, x1, y1) -- confirm against the
            producer of the box files).
        wav_thresh: Minimum ``end - start`` of an audio slice for it to be
            used (presumably milliseconds -- timecodes are used as-is and the
            JSON duration is multiplied by 1000).

    Returns:
        The total number of data rows written (matches plus mismatches).

    Notes:
        * If a slice's label is the only label with a usable box, no mismatch
          can be drawn for it and the slice is skipped; if no slice can be
          mismatched at all, the function stops early instead of looping
          forever, so fewer mismatch rows than the target may be written.
        * JSON files without any ``'WORD'`` timecode entry are ignored.
    """
    json_in_path = 'small_dataset/{}/json'.format(split)
    box_in_path = 'small_dataset/{}/box'.format(split)

    # Both dicts are keyed by (lower-cased object label, image file id).
    object_to_audio_slice = {}
    object_to_img_boxes = {}

    for fl in os.listdir(box_in_path):
        fl_id = fl.split('.')[0]
        # NOTE(review): pickle.load can execute arbitrary code; only run this
        # on box files from a trusted source.
        with open('{}/{}'.format(box_in_path, fl), 'rb') as box_file:
            img_data = pickle.load(box_file)
        img_objects = img_data['objects']
        img_boxes = img_data['boxes']
        # Register every label first so labels without a box still get a key.
        for obj in img_objects:
            key = (obj.lower(), fl_id)
            object_to_audio_slice[key] = []
            object_to_img_boxes.setdefault(key, [])
        img_fl = fl_id + '.jpg'
        for obj, box in zip(img_objects, img_boxes):
            object_to_img_boxes[(obj.lower(), fl_id)].append((img_fl, box))

    fl_id_format = 'COCO_val2014_{:012d}'

    for fl in os.listdir(json_in_path):
        with open('{}/{}'.format(json_in_path, fl), 'r') as json_file:
            json_data = json.load(json_file)
        duration = float(json_data['duration']) * 1000
        wav_file_name = json_data['wavFilename']
        fl_id = fl_id_format.format(json_data['imgID'])

        word_starts = [
            (entry[2].lower(), float(entry[0]))
            for entry in json_data['timecode']
            if entry[1] == 'WORD'
        ]
        if not word_starts:
            # Nothing was spoken (or transcribed) in this file.
            continue

        # A word ends where the next one starts; the last word ends with the
        # recording itself.
        end_times = [start for _, start in word_starts[1:]] + [duration]
        for (word, start), end in zip(word_starts, end_times):
            slices = object_to_audio_slice.get((word, fl_id))
            if slices is not None:
                slices.append((wav_file_name, (start, end)))

    # The trailing tab before the newline is kept so the file layout stays
    # exactly what existing consumers of these TSV files read.
    line_format = '{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t\n'
    file_id_format = '{:09d}'
    total_objects = 0

    with open('{}_split.tsv'.format(split), 'w') as out_file:
        out_file.write(line_format.format(
            'file_id', 'img_file', 'box0', 'box1', 'box2', 'box3',
            'wav_file', 'start', 'end', 'object_org', 'object_mm', 'score'))

        # Matching pairs: every long-enough slice with every big-enough box.
        for (obj, fl_id), slices in object_to_audio_slice.items():
            for wav_file_name, (start, end) in slices:
                if end - start < wav_thresh:
                    continue
                for img_file_name, box in object_to_img_boxes[(obj, fl_id)]:
                    if min(box[2] - box[0], box[3] - box[1]) < img_thresh:
                        continue
                    total_objects += 1
                    print(total_objects)
                    out_file.write(line_format.format(
                        file_id_format.format(total_objects), img_file_name,
                        box[0], box[1], box[2], box[3],
                        wav_file_name, start, end, obj, obj, 1))

        # Pool of usable boxes per label, de-duplicated across images.
        obj_to_img_info = defaultdict(set)
        for (obj, fl_id), img_boxes in object_to_img_boxes.items():
            for img_file_name, box in img_boxes:
                if min(box[2] - box[0], box[3] - box[1]) < img_thresh:
                    continue
                # tuple() makes list-like boxes hashable for the set.
                obj_to_img_info[obj].add((img_file_name, tuple(box)))
        obj_to_img_info = {
            obj: list(infos) for obj, infos in obj_to_img_info.items()
        }
        objects = list(obj_to_img_info)

        # For each label, the labels a mismatch may be drawn from.
        other_objects = {}

        mm_target = total_objects * 2
        mm_count = 0
        while mm_count < mm_target:
            wrote_any = False
            for (obj, fl_id), slices in object_to_audio_slice.items():
                for wav_file_name, (start, end) in slices:
                    if end - start < wav_thresh:
                        continue

                    if obj not in other_objects:
                        other_objects[obj] = [o for o in objects if o != obj]
                    candidates = other_objects[obj]
                    if not candidates:
                        # No other label has a usable box; a mismatch cannot
                        # be formed for this slice.
                        continue

                    # Two independent mismatches per slice.
                    for _ in range(2):
                        obj_mm = choice(candidates)
                        img_file_name, box = choice(obj_to_img_info[obj_mm])
                        total_objects += 1
                        print(total_objects)
                        out_file.write(line_format.format(
                            file_id_format.format(total_objects),
                            img_file_name, box[0], box[1], box[2], box[3],
                            wav_file_name, start, end, obj, obj_mm, 0))

                    mm_count += 2
                    wrote_any = True
            if not wrote_any:
                # A whole pass produced nothing, so further passes cannot
                # either; stop rather than spin forever.
                break

    return total_objects

# Filtering thresholds applied identically to every split.
img_thresh = 50  # smallest accepted bounding-box side, in pixels
wav_thresh = 500  # shortest accepted audio slice, in milliseconds

# Generate the pair file for each split and report how many rows it holds.
for split in ('train', 'test', 'val'):
    total_objects = process_split(split, img_thresh, wav_thresh)
    print(split, total_objects)
