In [1]:
# Import relevant packages
import numpy as np
import json, pickle, os
import pandas as pd

In [2]:
banned_qtypes = {"Object", "Action After", "Action Before", "Action While",                                                                               
                 "Action Between", "First Action", "Last Action",                                                                                         
                 "Object After", "Object Before", "Object While", "Object Between", "Action"}

subq_type_ordering = ["Object Exists", "Relation Exists", "Interaction",                                                                        
                      "Interaction Temporal Localization", "Exists Temporal Localization",                                                                                            
                      "First/Last", "Longest/Shortest Action", "Choose", "Equals"]   

collapsed_qtypes = {"Object Exists" : "Object Exists",                                                                                                    
                    "Relation Exists" : "Relation Exists",                                                                                                
                    "Interaction" : "Interaction",                                                                                                        
                    "Object" : "Object",                                                                                                                  
                    "Action" : "Action",                                                                                                                  
                    "Interaction After" : "Interaction Temporal Localization",                                                                            
                    "Interaction Before" : "Interaction Temporal Localization",                                                                           
                    "Interaction While" : "Interaction Temporal Localization",                                                                            
                    "Interaction Between" : "Interaction Temporal Localization",                                                                          
                    "Exists After" : "Exists Temporal Localization",                                                                                      
                    "Exists Before" : "Exists Temporal Localization",                                                                                     
                    "Exists While" : "Exists Temporal Localization",                                                                                      
                    "Exists Between" : "Exists Temporal Localization",                                                                                    
                    "Object After" : "Object Temporal Localization",                                                                                      
                    "Object Before" : "Object Temporal Localization",                                                                                     
                    "Object While" : "Object Temporal Localization",                                                                                      
                    "Object Between" : "Object Temporal Localization",                                                                                    
                    "Action After" : "Action Temporal Localization",                                                                                      
                    "Action Before" : "Action Temporal Localization",                                                                                     
                    "Action While" : "Action Temporal Localization",                                                                                      
                    "Action Between" : "Action Temporal Localization",                                                                                    
                    "And" : "Conjunction",                                                                                                                
                    "Xor" : "Conjunction",                                                                                                                
                    "Choose" : "Choose",      
                    "Object Equals" : "Equals",                                                                                                           
                    "Action Equals" : "Equals",                                                                                                           
                    "First Object" : "First/Last",                                                                                                        
                    "Last Object" : "First/Last",                                                                                                         
                    "First Action" : "First/Last",                                                                                                        
                    "Last Action" : "First/Last",                                                                                                         
                    "Longest Action" : "Longest/Shortest Action",                                                                                         
                    "Shortest Action" : "Longest/Shortest Action",                                                                                        
                    "Longer Choose" : "Choose",                                                                                                           
                    "Shorter Choose" : "Choose"}               

In [5]:
# First get a list of legal root types and print them out
if False:
    model_name = 'hcrn'
    folder = f'final_balanced_test_{model_name}'
    files = os.listdir(folder)

    root_types = set()
    for file in files:
        with open(f'{folder}/{file}', 'r') as f:
            data = json.load(f)

        if file.split(".")[1] == "txt":
            continue

        for key, value in data.items():
            subquestions = value['subquestion']
            top_level_q = list(subquestions.keys())[-1]
            top_info = subquestions[top_level_q]
            top_answer = top_info["answer"]

            if top_answer is None:
                continue
            if top_info["type"] in banned_qtypes:
                continue

            root_types.add(collapsed_qtypes[top_info['type']])
    root_types = list(root_types)
    print(root_types)

['Longest/Shortest Action', 'Interaction Temporal Localization', 'Conjunction', 'Interaction', 'Equals', 'Choose', 'First/Last']


In [7]:
from collections import Counter

# First get a list of legal root types and print them out
model_name = 'hcrn'
folder = f'final_balanced_test_{model_name}'
files = os.listdir(folder)

root_to_subqtypes = {root_type : Counter() for root_type in root_types}
for file in files:
    with open(f'{folder}/{file}', 'r') as f:
        data = json.load(f)
    
    if file.split(".")[1] == "txt":
        continue
        
    for key, value in data.items():
        subquestions = value['subquestion']
        top_level_q = list(subquestions.keys())[-1]
        top_info = subquestions[top_level_q]
        top_answer = top_info["answer"]
        
        if top_answer is None:
            continue
        if top_info["type"] in banned_qtypes:
            continue

        root_type = collapsed_qtypes[top_info['type']]
        root_to_subqtypes[root_type]['Overall'] += 1
        
        seen_types = set()
        for question, info in subquestions.items():
            if question == top_level_q:
                continue            
            if info["type"] in banned_qtypes:
                continue  
            answer = info["answer"] 
            if answer is None:
                continue
                
            qtype = collapsed_qtypes[info["type"]]
            if qtype in seen_types:
                continue
            seen_types.add(qtype)
            
            root_to_subqtypes[root_type][qtype] += 1

root_to_percentages = {}
for root_type, counts in root_to_subqtypes.items():
    root_to_percentages[root_type] = {}
    total_hierarchies = counts["Overall"]
    for qtype, count in counts.items():
        if qtype not in subq_type_ordering:
            continue
        root_to_percentages[root_type][qtype] = count / total_hierarchies
    
print(root_to_percentages)

{'Longest/Shortest Action': {}, 'Interaction Temporal Localization': {'Object Exists': 1.0, 'Relation Exists': 0.9833204034134988, 'Interaction': 1.0, 'Exists Temporal Localization': 0.9985259891388674, 'Interaction Temporal Localization': 0.24092319627618308, 'First/Last': 0.22787044220325833, 'Longest/Shortest Action': 0.006904577191621412}, 'Conjunction': {'Object Exists': 1.0, 'Relation Exists': 0.9999048645906006, 'Interaction': 1.0, 'First/Last': 0.46305574934990806, 'Exists Temporal Localization': 0.7997716750174415, 'Interaction Temporal Localization': 0.7874357835986554, 'Longest/Shortest Action': 0.004122534407306399}, 'Interaction': {'Object Exists': 1.0, 'Relation Exists': 0.6915880377418839, 'First/Last': 0.07108587877818646, 'Interaction': 0.01639213177674716, 'Exists Temporal Localization': 0.016552055013593477, 'Longest/Shortest Action': 0.00015992323684631377}, 'Equals': {'Object Exists': 1.0, 'Relation Exists': 0.9943244050582495, 'Interaction': 0.8002091008662751, 'E

In [8]:
# Produce the dataframe for the chosen root type
#model_name = 'hme'
#folder = f'final_balanced_test_{model_name}'
#files = os.listdir(folder)

df_list = {root_type : [] for root_type in root_to_percentages}
for file in files:
    with open(f'{folder}/{file}', 'r') as f:
        data = json.load(f)
        
    if file.split(".")[1] == "txt":
        continue
        
    for key, value in data.items():
        subquestions = value["subquestion"]
        hierarchy_dict = {}

        # First check if the top-level
        top_level_q = list(subquestions.keys())[-1]
        top_info = subquestions[top_level_q]
        top_answer = top_info["answer"]
        
        if top_answer is None:
            continue
        if "prediction" not in top_info:
            continue
        if top_info["type"] in banned_qtypes:
            continue

        root_type = collapsed_qtypes[top_info['type']]            
        top_pred = top_info["prediction"]
        top_correct = 1 if top_pred == top_answer else 0

        hierarchy_dict = {"top_correct":top_correct}
        
        subquestion_counts = {}
        for qtype in subq_type_ordering:
            if qtype in root_to_subqtypes[root_type]:
                subquestion_counts[qtype] = {"total":0, "correct":0}
        
        for question, info in subquestions.items():
            if question == top_level_q:
                continue            
            if info["type"] in banned_qtypes:
                continue  
            if "prediction" not in info:
                continue
            answer = info["answer"] 
            if answer is None:
                continue
                
            pred = info["prediction"]                                                                                                                        
            qtype = collapsed_qtypes[info["type"]]
            subquestion_counts[qtype]['total'] += 1
            
            if answer == pred:                                                                                                                               
                subquestion_counts[qtype]['correct'] += 1

        total = 0
        for qtype, counts in subquestion_counts.items():
            total += counts["total"]
        
        for qtype, counts in subquestion_counts.items():                        
            if counts['total'] > 0:
                hierarchy_dict[qtype] = (counts['total'] - counts['correct']) / counts['total']
            else:
                hierarchy_dict[qtype] = 0

        df_list[root_type].append(hierarchy_dict)                            

In [8]:
print(root_types)

root_type = 'First/Last'
df = pd.DataFrame(df_list[root_type])
print("Original shape: ", df.shape)
df = df.drop_duplicates()
print("Non-duplicate: ", df.shape)

print(root_to_percentages[root_type])

cols = []
for qtype in root_to_percentages[root_type]:
    cols.append(qtype)

X = df[cols]
y = df['top_correct']

['Interaction Temporal Localization', 'Conjunction', 'Longest/Shortest Action', 'Interaction', 'Choose', 'First/Last', 'Equals']
Original shape:  (62405, 7)
Non-duplicate:  (428, 7)
{'Object Exists': 1.0, 'Relation Exists': 0.9994071082908694, 'Interaction': 0.9009710604749543, 'Exists Temporal Localization': 0.9118193763420184, 'First/Last': 0.24803704771977053, 'Longest/Shortest Action': 0.012514822292728264}


In [9]:
import statsmodels.api as sm
logit_model=sm.Logit(y,X)
result=logit_model.fit()
print(result.summary2())

Optimization terminated successfully.
         Current function value: 0.670105
         Iterations 4
                               Results: Logit
Model:                   Logit                Pseudo R-squared:     0.012   
Dependent Variable:      top_correct          AIC:                  585.6098
Date:                    2021-11-18 07:43     BIC:                  609.9645
No. Observations:        428                  Log-Likelihood:       -286.80 
Df Model:                5                    LL-Null:              -290.24 
Df Residuals:            422                  LLR p-value:          0.23085 
Converged:               1.0000               Scale:                1.0000  
No. Iterations:          4.0000                                             
----------------------------------------------------------------------------
                              Coef.  Std.Err.    z    P>|z|   [0.025  0.975]
----------------------------------------------------------------------------
Objec

In [22]:
df.head()

Unnamed: 0,top_correct,Object Exists Accuracy,Relation Exists Accuracy,Interaction Accuracy,Interaction Temporal Localization Accuracy,Exists Temporal Localization Accuracy,First/Last Accuracy,Longest/Shortest Action Accuracy,Choose Accuracy,Equals Accuracy
0,0,1.0,1.0,0.5,1.0,0.666667,1.0,1.0,1.0,1.0
2,0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
3,0,1.0,0.5,0.666667,1.0,0.777778,1.0,1.0,1.0,1.0
4,1,1.0,1.0,0.5,1.0,0.0,1.0,1.0,1.0,1.0
5,1,1.0,1.0,0.5,1.0,0.333333,1.0,1.0,1.0,1.0
