In [1]:
import json

class AnnotationLoader:
    def __init__(self, train_path, val_path, origin_path, EGO4D_JSON_PATH):
        self.train_data = json.load(open(train_path))
        self.val_data = json.load(open(val_path))
        self.data = {**self.train_data, **self.val_data}
        
        self.origin_narration = json.load(open(origin_path))['videos']
        
        meta_data = json.load(open(EGO4D_JSON_PATH))['videos']
        self.meta_data = {}
        for meta_d in meta_data:
            self.meta_data[meta_d['video_uid']] = meta_d
        
    def get_data(self):
        return self.data
    
    def get_origin_narration(self):
        return self.origin_narration
    
    def get_meta_data(self):
        return self.meta_data

class BetaAlphaCalculator:
    def __init__(self, data, alpha=4.9):
        self.data = data
        self.beta_map = {}
        self.alpha = alpha
    
    def compute_beta(self):
        for video_uid, annotation_uid_narrations in self.data.items():
            for annotation_uid, narrations in annotation_uid_narrations.items():
                if len(narrations) == 0:
                    continue
                total_time = 0
                for i in range(len(narrations) - 1):
                    total_time += narrations[i+1]['time'] - narrations[i]['time']
                self.beta_map[annotation_uid] = total_time / len(narrations)
    
    def get_beta_map(self):
        return self.beta_map
    
    def get_alpha(self):
        return self.alpha


if __name__ == "__main__":
    train_path = "/home/zhangyl/videollm-online/datasets/ego4d/v2/annotations/narration_stream_train.json"
    val_path = "/home/zhangyl/videollm-online/datasets/ego4d/v2/annotations/narration_stream_val.json"
    origin_path = '/home/zhangyl/videollm-online/datasets/ego4d/v2/annotations/all_narrations_redacted.json'
    EGO4D_JSON_PATH = "/mnt/extra/dataset/ego4d/ego4d.json"
    loader = AnnotationLoader(train_path, val_path, origin_path, EGO4D_JSON_PATH)
    data = loader.get_data()
    origin_data = loader.get_origin_narration()
    meta_data = loader.get_meta_data()
    alpha = 4.9
    
    beta_alpha_calculator = BetaAlphaCalculator(data, alpha)
    beta_alpha_calculator.compute_beta()
    beta_map = beta_alpha_calculator.get_beta_map()
    alpha = beta_alpha_calculator.get_alpha()

In [21]:
class videoClipFilter:
    def __init__(self):
        self.filter_func_list = []
    
    def add_filter_func(self, func):
        self.filter_func_list.append(func)
    
    def filter(self, data, origin_data, meta_data, beta_map, alpha):
        filtered_data = {}
        for video_uid, annotation_uid_narrations in data.items():
            for annotation_uid, narrations in annotation_uid_narrations.items():
                is_filtered = False
                for func in self.filter_func_list:
                    if func(video_uid, annotation_uid, narrations, origin_data, meta_data, beta_map, alpha):
                        is_filtered = True
                        break
                    
                if not is_filtered:
                    if video_uid not in filtered_data:
                        filtered_data[video_uid] = {}
                    filtered_data[video_uid][annotation_uid] = narrations
        
        return filtered_data


def filter_func_aspect_ratio(video_uid, annotation_uid, narrations, origin_data, meta_data, beta_map, alpha):
    if video_uid not in meta_data:
        return True
    
    is_stereo = meta_data[video_uid]['is_stereo']
    if is_stereo:
        aspect_ratio = meta_data[video_uid]['video_metadata']['display_resolution_width'] / 2 / meta_data[video_uid]['video_metadata']['display_resolution_height']
    else:
        aspect_ratio = meta_data[video_uid]['video_metadata']['display_resolution_width'] / meta_data[video_uid]['video_metadata']['display_resolution_height']
    if (aspect_ratio < 0.5 or aspect_ratio > 2):
        return True
    
    return False

# def filter_func_unsure_tags(video_uid, annotation_uid, narrations, origin_data, meta_data, beta_map, alpha):
#     for narration in narrations:
#         if '#Unsure' in narration['text'] or '#unsure' in narration['text']:
#             return True
#     return False

# def filter_func_less(video_uid, annotation_uid, narrations, origin_data, meta_data, beta_map, alpha):
#     for narration in narrations:
#         if len(narration['text'].split()) <= 3:
#             return True
#     return False

def filter_func_beta(video_uid, annotation_uid, narrations, origin_data, meta_data, beta_map, alpha):
    if beta_map[annotation_uid] > 20 or beta_map[annotation_uid] < 1.5:
        return True
    return False

def filter_func_loose(video_uid, annotation_uid, narrations, origin_data, meta_data, beta_map, alpha):
    summs = origin_data[video_uid]['summaries']
    if len(summs) == 0:
        return True
    
    is_match = False
    for summ in summs:
        if summ['_annotation_uid'] == annotation_uid:
            is_match = True
            break
    if not is_match:
        True
    
    summ_len = summ['end_time'] - summ['start_time']
    narration_len = narrations[-1]['time'] - narrations[0]['time']
    
    if narrations[-1]['time'] > (summ['end_time']+0.5) or  (narrations[0]['time']) < (summ['start_time']-0.5): 
        print('narration end time:', narrations[-1]['time'], 'summ end time:', summ['end_time'])
        print(narrations[-1]['time'] > (summ['end_time']+0.5))
        print('narration start time:', narrations[0]['time'], 'summ start time:', summ['start_time'])
        print(narrations[0]['time'] < summ['start_time'])
        return True
    if summ_len < 250  or narration_len / summ_len < 0.8:
        return True
    return False


filter = videoClipFilter()
filter.add_filter_func(filter_func_aspect_ratio)
# filter.add_filter_func(filter_func_unsure_tags)
# filter.add_filter_func(filter_func_less)
filter.add_filter_func(filter_func_beta)
filter.add_filter_func(filter_func_loose)

filtered_data = filter.filter(data, origin_data, meta_data, beta_map, alpha)

narration end time: 569.9863905333334 summ end time: 1919.954361933333
False
narration start time: 270.72471053333334 summ start time: 1619.954361933333
True
narration end time: 1109.9740571999998 summ end time: 1919.954361933333
False
narration start time: 810.1254705333334 summ start time: 1619.954361933333
True
narration end time: 3539.8740571999997 summ end time: 1919.954361933333
True
narration start time: 3240.2800571999996 summ start time: 1619.954361933333
False
narration end time: 569.9960624 summ end time: 839.9320312
False
narration start time: 269.9973957333333 summ start time: 539.9653645333333
True
narration end time: 1330.9774186458335 summ end time: 300.0210286458333
True
narration start time: 1080.0210286458334 summ start time: 0.021028645833333335
False
narration end time: 300.07605720000004 summ end time: 1380.0577025908854
False
narration start time: 0.2968672 summ start time: 1080.0210286
True
narration end time: 570.0800572 summ end time: 1380.0577025908854
False


In [3]:
json.dump(filtered_data, open('filtered_data_v2.json', 'w'),indent=4)