#### Introduction

This notebook demonstrates the usage of the class `MisalignSRL`. Given one narration clip (A) from fho_main.json (a subset of Ego4D), this class is mainly used to select other narration clips (A_mis_v, A_mis_n, A_mis_vn) whose semantic roles (verb, noun, or both) are misaligned with those of the given clip.

Practically, the main method `get_misaligned_samples` can be invoked at https://github.com/shanestorks/TRAVEl/blob/8fad8f174a4c43b712b6c25dff33019790361ea2/travel/data/ego4d/__init__.py#L552 to generate A_mis_v, A_mis_n, A_mis_vn.

#### Usage
- download the index files from <https://prism.eecs.umich.edu/yayuanli/z_web/dataset/Ego4D_Mistake/v1/> and set the following path variables:
    - `fho_main_path`: "fho_main.json"
    - `narration_mapping_fho2srl_df_path`: "narration_mapping_fho2srl_df.csv"
    - `narration_df_path`: "egoclip_narrations_exploed_groupby_no,txt.csv"
    - `fho_narration_df_rows_path`: "fho_narration_df_rows.json"
    - `group_df_path`: "egoclip_groups_groupby_no,txt.csv"
- run this notebook

#### TODOs
- [ ] end-to-end test in TRAVEl
    - [ ] Need to mount the dataset and annotation used in TRAVEl
- [ ] load visual media (e.g., an image) from the sampled narration clips
    - [ ] discuss what images are appropriate
    - [ ] Need random access to a clip given `video_uid` and `narration_timestamp_sec` in `Ego4dFHOMainDataset`



#### Limitations
- some (~20%) narration clips in egoclip.csv are wasted due to imperfect SRL (AllenNLP) and grouping algorithm (string matching)
- some misaligned samples may be too straightforward and simple. For example, for a given clip A with narration "pour sauce", there could be an A_mis_n with narration "pour detergent". It is not clear whether such samples are helpful for specific tasks.


In [None]:
import pprint
import json
import ast
import time
from tqdm import tqdm
import pandas as pd
import numpy as np
import random

# utils.py
def convert_to_list(string):
    """Parse a stringified Python literal (e.g. ``"[1, 2]"``) into its value.

    Intended as a ``converters`` callback for ``pd.read_csv`` on columns that
    store lists as text.

    Args:
        string: the text to parse with ``ast.literal_eval``.

    Returns:
        The parsed literal, or an empty list when ``string`` is not a valid
        Python literal.
    """
    try:
        return ast.literal_eval(string)
    # Only the failures `ast.literal_eval` is documented to raise are treated
    # as "unparseable"; KeyboardInterrupt / SystemExit must still propagate.
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return []  # Return an empty list in case of error


def read_large_csv(file_path, columns_str2list=None, nrows=None, chunk_size=100, ):
    """Read a large CSV file in chunks while showing a progress bar.

    Args:
        file_path: path of the CSV file. Its first column is used as the index.
        columns_str2list: names of columns whose cells hold stringified Python
            literals (e.g. "[1, 2]"); each of them is parsed with
            `convert_to_list`. `None` or an empty list disables the conversion.
        nrows: maximum number of data rows to read. `None` reads the whole file.
        chunk_size: number of rows read per chunk.

    Returns:
        A single DataFrame holding all chunks concatenated in file order.
    """
    # Count the lines only to size the progress bar. The count is an estimate:
    # a quoted cell containing a newline spans several lines but is one row.
    with open(file_path, 'r') as csv_file:
        total_lines = sum(1 for _ in csv_file)
    data_rows = max(total_lines - 1, 0)  # the first line is the header
    if nrows is not None:
        data_rows = min(data_rows, nrows)
    total_chunks = -(-data_rows // chunk_size)  # ceiling division

    converters = None
    if columns_str2list:
        converters = {cname: convert_to_list for cname in columns_str2list}

    reader = pd.read_csv(file_path,
                         chunksize=chunk_size,
                         index_col=0,
                         converters=converters,
                         nrows=nrows)
    # Use tqdm to show progress while the chunks are being read
    chunks = list(tqdm(reader,
                       total=total_chunks,
                       unit=f'chunk ({chunk_size} videos per chunk)'))

    # Concatenate all chunks into a single DataFrame
    return pd.concat(chunks, axis=0)



class MisalignSRL:
    """Select narration clips whose semantic roles are misaligned with a given clip.

    The class loads several pre-computed index files once (see `__init__`) and
    then answers `get_misaligned_samples(clip)` queries against them.
    """

    def __init__(self, fho_main_path, narration_mapping_fho2srl_df_path, narration_df_path, fho_narration_df_rows_path, group_df_path):
        """Load all index files into memory.

        fho_main_path: the path to json file (fho_main) with structure
            {"videos":
                [{"annotated_intervals": [{"clip_uid": str, ... ,
                                            "narrated_actions": [{"narration_text": str,
                                                                "narration_timestamp_sec"}, ...]
                                            }, ...
                ],
                "video_metadata": {},
                "video_uid": str}]},
            "date": xx,
            "description": xx,
            "metadata": xx}

        narration_mapping_fho2srl_df_path: the path to csv file with columns ["video_uid", "narration_text", "narration_timestamp_sec", "srl_index", "fho_index"]

        narration_df_path: the path to csv file with example row:
            video_uid                 26202090-684d-4be8-b3cc-de04da827e91
            video_dur                                          3127.233333
            narration_source                              narration_pass_1
            narration_ind                                                7
            narration_time                                         40.8375
            clip_start                                           40.583986
            clip_end                                             41.090933
            clip_text                C takes his hand out of the paper bag
            tag_verb                                                  [93]
            tag_noun                                        [321, 12, 349]
            ARG0                                                         C
            V                                                        takes
            ARG1                                                  his hand
            valid_tag_noun                                           [321]
            valid_txt_noun                                        ['hand']
            valid_tag_verb                                            [93]
            valid_txt_verb                                        ['take']
            index                                                   322348
            valid_txt_verb_single                                     take
            valid_txt_noun_single                                     hand

        fho_narration_df_rows_path: the path to json file with structure: a list of dict. An example dict:
                    {
                      'video_index': 1, # the index of the video in the fho_main.json
                      'interval_index': 0, # the index of the interval in the fho_main.json
                      'action_index': 3, # the index of the action in the fho_main.json
                      'narration_timestamp_sec': 35.4317396,
                      'video_uid': '26202090-684d-4be8-b3cc-de04da827e91',
                      'narration_text': '#C C takes a steel bowl out of the '
                                        'paper bag',}

        group_df_path: the path to csv/parquet file with columns ["txt_verb"(str), "txt_noun"(str), "narration_index"(list), "narration_indices"(list), "mismatch_noun"(list), "mismatch_verb"(list), "mismatch_verb_noun"(list)].
        """
        print("Loading fho_main_json ...")
        start_time = time.time()
        with open(fho_main_path, "r") as file:
            fho_main_json = json.load(file)
        print(f"Loading fho_main.json took {time.time() - start_time} seconds.")

        print("Loading narration_df ...")
        start_time = time.time()
        narration_df = pd.read_csv(narration_df_path, index_col=0)
        print(f"Loading narration_df took {time.time() - start_time} seconds.")

        print("Loading narration_mapping_fho2srl_df ...")
        start_time = time.time()
        narration_mapping_fho2srl_df = pd.read_csv(narration_mapping_fho2srl_df_path, index_col=0)
        # "srl_index" is stored as text in the csv; turn each cell back into a list of int
        narration_mapping_fho2srl_df["srl_index"] = narration_mapping_fho2srl_df["srl_index"].apply(ast.literal_eval)
        print(f"Loading narration_mapping_fho2srl_df took {time.time() - start_time} seconds.")

        print("Loading fho_narration_df_rows ...")
        start_time = time.time()
        with open(fho_narration_df_rows_path, 'r') as f:
            fho_narration_df_rows = json.load(f)
        print(f"Loading fho_narration_df_rows took {time.time() - start_time} seconds.")

        print("Loading group_df ...")
        start_time = time.time()
        # The list-valued columns ("narration_index", "narration_indices", "mismatch_noun",
        # "mismatch_verb", "mismatch_verb_noun") are deliberately left as strings here and
        # parsed on demand in `get_misaligned_samples`.
        group_df = read_large_csv(group_df_path,
                                nrows=None,
                                )
        print(f"Loading group_df took {time.time() - start_time} seconds.")

        self.fho_main_json = fho_main_json
        self.narration_mapping_fho2srl_df = narration_mapping_fho2srl_df
        self.narration_df = narration_df
        self.fho_narration_df_rows = fho_narration_df_rows
        self.group_df = group_df

        self.type_name_col_name_map = {"MisalignSRL_V": "mismatch_verb", "MisalignSRL_ARG1": "mismatch_noun", "MisalignSRL_V_ARG1": "mismatch_verb_noun"} # human readable name -> column name in group_df

        # srl_index -> fho_index lookup, so each query avoids scanning the whole mapping table
        self._srl_to_fho_index = self._build_srl_to_fho_index()

    def _build_srl_to_fho_index(self):
        """Return a dict mapping each srl_index to the fho_index of the first mapping row listing it."""
        lookup = {}
        mapping_df = self.narration_mapping_fho2srl_df
        for srl_indices, fho_index in zip(mapping_df["srl_index"], mapping_df["fho_index"]):
            for srl_index in srl_indices:
                # setdefault keeps the first row that contains this srl_index
                lookup.setdefault(srl_index, fho_index)
        return lookup

    def get_misaligned_samples(self, clip):
        """Sample one misaligned narration clip per misalignment type for `clip`.

        clip: a dict-like object that corresponds to one action clip in fho_main.json.
            Only `clip["video_uid"]` and `clip["narration_timestamp_sec"]` are read.

        return:
            mistake_example_meta_dict: {misalignsrl_type -> one entry of `fho_narration_df_rows`},
            which locates an action clip in fho_main.json
            (fho_main_json["videos"][video_index]["annotated_intervals"][interval_index]["narrated_actions"][action_index]).
            The value is None for every type for which no sample could be found.
        """
        # mistake_example_meta_dict: misalignsrl_type -> fho_narration_df_rows. To be returned.
        mistake_example_meta_dict = {type_name: None for type_name in self.type_name_col_name_map}

        # find the map_row (the row in `narration_mapping_fho2srl_df`) by matching `video_uid` and `narration_timestamp_sec`
        mapping_df = self.narration_mapping_fho2srl_df
        match_video_uid = mapping_df["video_uid"] == clip["video_uid"]
        match_narration_timestamp_sec = mapping_df["narration_timestamp_sec"] == clip["narration_timestamp_sec"]
        map_row = mapping_df[match_video_uid & match_narration_timestamp_sec]

        # return if no map_row found. Meaning, this clip does not have misalignsrl sample in current group_df.
        if len(map_row) == 0:
            return mistake_example_meta_dict

        # `srl_index` values are indices into `narration_df`. A narration without any of them cannot be grouped.
        srl_candidates = map_row["srl_index"].values[0]
        if len(srl_candidates) == 0:
            return mistake_example_meta_dict
        # randomly pick one. one narration could be in multiple groups. E.g., "pick up a bag of clothes" could be in "pick up bag" and "pick up cloth"
        srl_index = random.choice(srl_candidates)

        # find the group in `group_df` to which the `srl_narration_row` belongs
        srl_narration_row = self.narration_df.iloc[srl_index]
        match_txt_verb = self.group_df["txt_verb"] == srl_narration_row["valid_txt_verb_single"]
        match_txt_noun = self.group_df["txt_noun"] == srl_narration_row["valid_txt_noun_single"]
        target_group = self.group_df[match_txt_verb & match_txt_noun]

        # the (verb, noun) pair may be absent from `group_df`; then there is nothing to sample from
        if len(target_group) == 0:
            return mistake_example_meta_dict

        # Instances created without `__init__` do not have the lookup yet; build it on first use.
        srl_to_fho_index = getattr(self, "_srl_to_fho_index", None)
        if srl_to_fho_index is None:
            srl_to_fho_index = self._srl_to_fho_index = self._build_srl_to_fho_index()

        # fill in the mistake_example_meta_dict
        for misalignsrl_type, col_name in self.type_name_col_name_map.items():
            # index_list: list of int. the srl index in the srl narration df.
            index_list = target_group[col_name].iloc[0]
            # TODO: this parsing can be avoided by saving group_df as parquet file instead of csv
            if not isinstance(index_list, list):
                index_list = ast.literal_eval(index_list)

            # shuffle the index_list so that different given `clip` is less likely to get the same sample for a misalignsrl_type. For example, for given `clip`s "cut carrot" and "pour sauce", the first sample for "MisalignSRL_VN" could be the same -- "pick up cap"
            shuffled_indices = np.random.permutation(index_list).tolist()

            # Take the first candidate that also exists in fho_main.json.
            # A candidate may be missing since current group_df is made from `egoclip.csv` (refer to: https://github.com/facebookresearch/EgoVLPv2), which contains narration clips over whole Ego4D while the fho_main.json is a subset of Ego4D (with bbox annotation).
            for candidate_srl_index in shuffled_indices:
                fho_index = srl_to_fho_index.get(candidate_srl_index)
                if fho_index is not None:
                    mistake_example_meta_dict[misalignsrl_type] = self.fho_narration_df_rows[fho_index]
                    break

        # For each misalignsrl_type, if None, it means no sample is found in the group. (This could be improved making group_df.parquet from fho_main.json instead of from egoclip)
        return mistake_example_meta_dict

    def get_clip_info_from_fho_main_index(self, video_index, big_clip_index, narration_clip_index):
        """Return the narrated-action dict of fho_main.json addressed by the three indices."""
        narration_clip_info = self.fho_main_json["videos"][video_index]["annotated_intervals"][big_clip_index]["narrated_actions"][narration_clip_index]

        return narration_clip_info

    # (need Peter's input.) for each misalignsrl_type, given the information of the narration clip in fho_main.json, we can load visual media (image). Such information of a narration clip should have been loaded the same way as one item (`clip`) in `ego4d` above. I.E., need to identify the index of this clip in `ego4d` so that the media can be loaded efficiently using existing code.

        



In [None]:
# Locations of the index files (see "Usage" at the top of the notebook); adjust them to your machine.
fho_main_path = "/home/yayuanli/fun/mistake_detection/fine_grained_action_mistake_detection/dataset/fho_main.json"
narration_mapping_fho2srl_df_path = "narration_mapping_fho2srl_df.csv"
narration_df_path = "/z/home/yayuanli/dat/Ego4D_Mistake/v1/egoclip_narrations_exploed_groupby_no,txt.csv"
fho_narration_df_rows_path = "fho_narration_df_rows.json"
group_df_path = "/z/home/yayuanli/dat/Ego4D_Mistake/v1/egoclip_groups_groupby_no,txt.csv"

# Loading all index files happens here, once.
misalignsrl = MisalignSRL(
    fho_main_path=fho_main_path,
    narration_mapping_fho2srl_df_path=narration_mapping_fho2srl_df_path,
    narration_df_path=narration_df_path,
    fho_narration_df_rows_path=fho_narration_df_rows_path,
    group_df_path=group_df_path,
)

Loading fho_main_json ...
Loading fho_main.json took 21.797735691070557 seconds.
Loading narration_df ...
Loading narration_df took 5.178931713104248 seconds.
Loading narration_mapping_fho2srl_df ...
Loading narration_mapping_fho2srl_df took 0.24354338645935059 seconds.
Loading fho_narration_df_rows ...
Loading fho_narration_df_rows took 0.19900894165039062 seconds.
Loading group_df ...


  full_bar = Bar(frac,
100%|██████████| 450/449.88 [01:49<00:00,  4.12chunk (100 videos / chunk)/s]


Loading group_df took 116.41300749778748 seconds.


In [None]:
# usage example: query the misaligned samples of one hand-picked narration clip
video_index = 1
big_clip_index = 0
narration_clip_index = 10
# Copy the action dict so that adding `video_uid` below does not modify the entry stored in `misalignsrl.fho_main_json`
clip = misalignsrl.get_clip_info_from_fho_main_index(video_index, big_clip_index, narration_clip_index).copy()
clip.update({"video_uid": misalignsrl.fho_main_json["videos"][video_index]["video_uid"]})
pprint.pprint(f"example clip: {clip['narration_text']}")

# misaligned_sample: misalignsrl_type -> entry of fho_narration_df_rows (or None when nothing was found)
misaligned_sample = misalignsrl.get_misaligned_samples(clip)
print(misaligned_sample)

# To fetch the full fho_main.json entry of one returned sample (it may be None):
# sample = misaligned_sample["MisalignSRL_V"]
# misalignsrl.get_clip_info_from_fho_main_index(sample["video_index"], sample["interval_index"], sample["action_index"])

'example clip: #C C takes off the lid of the plastic container'
{'MisalignSRL_V': {'video_uid': '4abd8edc-4751-4a47-9808-696d960b7557', 'narration_text': '#C C fills the mould container with sand', 'narration_timestamp_sec': 5732.351891933333, 'video_index': 507, 'interval_index': 1, 'action_index': 9}, 'MisalignSRL_ARG1': {'video_uid': 'ac259c29-f40c-4afb-a4b3-b910dcff46ff', 'narration_text': '#C C takes a pinch of clay mold on the table', 'narration_timestamp_sec': 1481.92476762133, 'video_index': 809, 'interval_index': 0, 'action_index': 26}, 'MisalignSRL_V_ARG1': {'video_uid': '8e58a7b3-43ef-406d-ad5d-901f83418261', 'narration_text': '#C C dips the paint brush into a cup of water on the  table with her right hand.', 'narration_timestamp_sec': 6831.789344600002, 'video_index': 202, 'interval_index': 1, 'action_index': 64}}


In [None]:
# usage example: walk through fho_main.json and stop at the first narration clip
# for which at least one misaligned sample exists, printing what was found
found_sample = False
for video_index, video_dict in enumerate(misalignsrl.fho_main_json["videos"]):
    for big_clip_index, big_clip_dict in enumerate(video_dict["annotated_intervals"]):
        for narration_clip_index, action_dict in enumerate(big_clip_dict["narrated_actions"]):
            # work on a copy that additionally carries the `video_uid` of the enclosing video
            clip = {**action_dict, "video_uid": video_dict["video_uid"]}
            misaligned_samples = misalignsrl.get_misaligned_samples(clip)

            # nothing useful for this clip: move on to the next one
            if all(sample is None for sample in misaligned_samples.values()):
                continue

            print(f"video_index: {video_index}, big_clip_index: {big_clip_index}, narration_clip_index: {narration_clip_index}")
            pprint.pprint(f"{clip=}")
            print()
            pprint.pprint(f"{misaligned_samples=}")
            print()

            for misalignsrl_type, misaligned_sample in misaligned_samples.items():
                if misaligned_sample is None:
                    continue
                narration_clip_info = misalignsrl.get_clip_info_from_fho_main_index(
                    misaligned_sample["video_index"],
                    misaligned_sample["interval_index"],
                    misaligned_sample["action_index"],
                )

                pprint.pprint(f"{misalignsrl_type=}")
                pprint.pprint(f"{narration_clip_info=}")
                print(f"===============\n")
                found_sample = True

            if found_sample:
                break
        if found_sample:
            break
    if found_sample:
        break

video_index: 1, big_clip_index: 0, narration_clip_index: 2
 "'start_sec': 30.654361933333334, 'end_sec': 38.654361933333334, "
 "'start_frame': 919, 'end_frame': 1159, 'is_valid_action': True, "
 "'is_partial': False, 'clip_start_sec': 30.654361933333334, 'clip_end_sec': "
 "38.654361933333334, 'clip_start_frame': 919, 'clip_end_frame': 1159, "
 "'narration_timestamp_sec': 34.666419600000005, "
 "'clip_narration_timestamp_sec': 34.666419600000005, 'narration_text': '#C C "
 "opens a paper bag', 'narration_annotation_uid': "
 "'e5866a1d-dbad-4da9-8776-0cb23082d744', 'structured_verb': 'open', "
 "'freeform_verb': None, 'state_transition': "
 "'activate_-_[object_is_transformed_to_enable_access_to_one_or_more_objects_inside_(e.g.,_open),_or_enables_an_object’s_function_(e.g.,_turn_power_on)]', "
 "'critical_frames': {'pre_45': 984, 'pre_30': 999, 'pre_15': 1014, "
 "'post_frame': 1063, 'contact_frame': 1037, 'pre_frame': 1029, 'pnr_frame': "
 "1049}, 'clip_critical_frames': {'pre_45': 98