In [5]:
import numpy as np
import json


def calculate_iou(pred_span, gold_span):
    """Return the temporal IoU between two 1-D spans.

    Both spans are indexable ``(start, end)`` pairs. When the spans do not
    overlap (or the covering hull has no length) the int ``0`` is returned.
    Otherwise the result is overlap length divided by the length of the hull
    spanning both spans, which equals their union because they overlap.
    """
    g_lo, g_hi = gold_span[0], gold_span[1]
    p_lo, p_hi = pred_span[0], pred_span[1]
    overlap = max(0, min(g_hi, p_hi) - max(g_lo, p_lo))
    # Hull length: a true union only when the spans overlap, which is the
    # only case in which the ratio below is reached.
    hull = max(0, max(g_hi, p_hi) - min(g_lo, p_lo))
    if hull <= 0 or overlap <= 0:
        return 0
    ratio = overlap / hull
    return ratio


def check_ans(pred_span, gold_spans):
    """Return the highest IoU between ``pred_span`` and any gold span.

    ``gold_spans`` is either a single ``(start, end)`` pair or a sequence of
    such pairs. It is treated as a single pair when its first element is not
    a list or tuple.
    """
    is_single_span = not isinstance(gold_spans[0], (list, tuple))
    candidates = [gold_spans] if is_single_span else gold_spans
    scores = [calculate_iou(pred_span, candidate) for candidate in candidates]
    return max(scores)


def get_iou_at_different_turns(example, max_turns=4):
    """Score each turn of a recursive (coarse-to-fine) grounding prediction.

    The search window starts as ``[0, example['duration']]``. At each turn it
    narrows according to the predicted answer:

    - ``'beginning'`` keeps the first half.
    - ``'middle'`` keeps the middle half.
    - ``'end'`` keeps the last half.
    - Any other answer leaves the window unchanged.

    If ``example['pred_answer_list']`` has fewer than ``max_turns`` entries,
    it is padded with ``'throughout'``. A longer list is not truncated.

    Returns a list with one IoU per turn, each computed between the current
    window and ``example['gt_span']`` via ``check_ans``.
    """
    answers = example['pred_answer_list']
    padding = ['throughout'] * (max_turns - len(answers))
    lo, hi = 0, example['duration']

    ious = []
    for answer in answers + padding:
        quarter = (hi - lo) / 4
        if answer == 'beginning':
            hi = hi - 2 * quarter
        elif answer == 'middle':
            lo, hi = lo + quarter, hi - quarter
        elif answer == 'end':
            lo = lo + 2 * quarter
        ious.append(check_ans((lo, hi), example['gt_span']))
    return ious


In [None]:
# Change output_fname as needed.
output_fname = '../outputs/charades_sta-recursive_grounding-4_turns.jsonl'
gold_fname = '../data/test-anno/charades_sta-recursive_grounding.json'
max_turns = 4

# Use context managers so both file handles are closed deterministically.
with open(gold_fname) as f:
    gold_data = json.load(f)
with open(output_fname) as f:
    output_lines = f.readlines()

# zip() stops silently at the shorter input. Report a length mismatch
# (e.g. a partially written prediction file) instead of hiding it.
if len(output_lines) != len(gold_data):
    print('warning: %d prediction lines vs %d gold annotations; evaluating the first %d'
          % (len(output_lines), len(gold_data), min(len(output_lines), len(gold_data))))

res_list_all_turns = list()

for line, gold in zip(output_lines, gold_data):
    example = json.loads(line)
    # Without a duration there is no initial window to narrow; skip the example.
    if example['duration'] is None:
        continue
    example['gt_span'] = gold['answer']
    res_list = get_iou_at_different_turns(example, max_turns)
    res_list_all_turns.append(res_list)

print(output_fname, 'num examples:', len(res_list_all_turns))
if not res_list_all_turns:
    # Avoid np.mean([]) -> nan and a ZeroDivisionError in the recall ratios.
    print('no evaluable examples; skipping metrics')
else:
    for turns in range(max_turns):
        print('turns: %d' % (turns + 1))
        iou_list = [ious[turns] for ious in res_list_all_turns]
        print('mean iou: %.4f' % np.mean(iou_list))
        # Fraction of examples whose IoU is strictly greater than each threshold.
        recalls = tuple(
            len([i for i in iou_list if i > thres]) / len(iou_list)
            for thres in [0.3, 0.5, 0.7]
        )
        print('iou@0.3/0.5/0.7: %.4f/%.4f/%.4f' % recalls)
