Analyses how many conflicting concept combinations exist, i.e. identical concept combinations that occur in examples with different labels

In [1]:
import os
from typing import List, Dict

import numpy as np
import pandas as pd

# Root of the project checkout. The default is the original hard-coded
# location; set the LUKE_PROJECT_DIR environment variable to run the notebook
# from a different checkout without editing this cell.
project_dir = os.environ.get('LUKE_PROJECT_DIR', '/home/rp218/luke-for-roko')

# ## Load Data

# Alternative inputs (assign one of these to `path` to analyse a different run):
# path = os.path.join(project_dir, "Extracted_Concepts/final_dict_new_codex_no_pruning.pkl")
# path = os.path.join(project_dir, "cmd/final_dict_old_codex.pkl")
path = os.path.join(project_dir, "cmd/birds_flowers_final_dict_new_codex.pkl")
# NOTE(review): unpickling can execute arbitrary code - only load files that
# come from a trusted source.
luke_output = pd.read_pickle(path)

# Keep only the per-example fields used by the conflict analysis and put them
# into one DataFrame (one column per key).
labels_keys = ['id', 'label', 'concepts']
labels_dict = {key: luke_output[key] for key in labels_keys}
labels_df_filtered = pd.DataFrame.from_dict(labels_dict)

# Work on a copy so later cells cannot modify the frame printed here.
labels = labels_df_filtered.copy()
print(labels_df_filtered)

# The explanations are kept in their own single-column frame.
concepts_text = pd.DataFrame.from_dict({"explanations": luke_output["explanations"]})

                                    id   label  \
0               Artic_Tern_0024_143268    bird   
1     Northern_Waterthrush_0084_177239    bird   
2              Common_Tern_0067_149540    bird   
3               Horned_Lark_0012_74511    bird   
4          Northern_Flicker_0006_28290    bird   
...                                ...     ...   
1995                       image_00479  flower   
1996                       image_08137  flower   
1997                       image_04682  flower   
1998                       image_00690  flower   
1999                       image_02314  flower   

                                               concepts  
0     [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...  
1     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...  
2     [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...  
3     [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...  
4     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...  
...                                                

In [2]:
# Number of distinct example ids in the loaded data (np.unique returns a flat
# array, so its size is the number of unique values).
print(np.unique(luke_output['id']).size)

2000


In [3]:
class ConceptCounter:
    """Counts, for every concept combination, how many examples of each label carry it.

    The store maps a concept combination (as a tuple) to a list of counts with
    one slot per label, in the order the labels were given to the constructor.
    """

    def __init__(self, labels: List[str]):
        """Create an empty counter.

        Args:
            labels: All labels that may be passed to ``add_val``. Any iterable
                is accepted; the position of a label determines its slot in
                every count list.
        """
        self.store = {}
        # Materialise into a list so that any iterable of labels (for example
        # the array returned by np.unique) supports the .index() lookup used in
        # add_val, and so that later changes to the caller's sequence cannot
        # shift the slot positions.
        self.labels = list(labels)

    def add_val(self, concepts: np.ndarray, label: str):
        """Record one example with the given concept combination and label.

        Args:
            concepts: Sequence of concept values; converted to a tuple so it can
                be used as a dictionary key.
            label: Label of the example; must be one of the constructor labels.

        Raises:
            ValueError: If ``label`` is not one of the known labels. The store
                is left unchanged in that case.
        """
        key = tuple(concepts)
        # Resolve the slot first so an unknown label fails before the store is touched.
        slot = self.labels.index(label)

        counts = self.store.setdefault(key, [0] * len(self.labels))
        counts[slot] += 1

    def get(self) -> Dict[tuple, List[int]]:
        """Return the underlying mapping from concept tuple to per-label counts.

        The returned dictionary is the live internal store, not a copy.
        """
        return self.store

In [4]:
def pretty_print(cnt_arr: List[int], str_labels: List[str]):
    """Print the non-zero per-label counts and the resulting minimum error count.

    Args:
        cnt_arr: Counts, one per label, positionally aligned with ``str_labels``.
        str_labels: Label names used when printing each non-zero count.

    The second printed line reports the total count minus the largest single
    count.
    """
    counts = np.array(cnt_arr)
    (present,) = counts.nonzero()

    # One "label: count," fragment per non-zero slot, all on a single line.
    print("Values: ", end="")
    for pos in present:
        print(str_labels[pos], ": ", counts[pos], ",", sep="", end="")
    print()

    total = sum(counts)
    largest = max(counts)
    print(f"Results in at least {total - largest} incorrectly classified examples")
    print()

In [7]:
# iloc[:] selects every row; presumably kept as a hook for restricting the
# analysis to a subset of examples - TODO confirm.
test_labels = labels.iloc[:]

# Count, per concept combination, how many examples of each label carry it.
# Iterating label by label keeps the insertion order of the store (and hence
# the order in which conflicts are reported below) stable.
str_labels = list(np.unique(test_labels.label))
c_counter = ConceptCounter(str_labels)
for label in str_labels:
    print(f"Running for {label}")
    out_df = test_labels[test_labels.label == label]
    for concepts in out_df["concepts"]:
        c_counter.add_val(concepts, label)

results = c_counter.get()
conflicts = []
for k, v in results.items():
    vals = np.array(v)
    # A combination conflicts when examples of more than one label share it.
    if np.count_nonzero(vals) > 1:
        # Everything except the largest label count is reported as
        # necessarily misclassified for this combination.
        incorrectly_classified_cnt = int(vals.sum() - vals.max())
        conflicts.append(
            {"concept_combination": k, "count_arr": vals, "incorrectly_classified_cnt": incorrectly_classified_cnt})

for conflict in conflicts:
    print(f"Conflict with combination {conflict['concept_combination']}")
    pretty_print(conflict["count_arr"], str_labels)

print("Total number of examples which must be incorrectly classified")
print(sum(conflict["incorrectly_classified_cnt"] for conflict in conflicts))

Running for bird
Running for flower
Conflict with combination (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
Values: bird: 254,flower: 233,
Results in at least 233 incorrectly classified examples

Conflict with combination (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
Values: bird: 11,flower: 2,
Results in at least 2 incorrectly classified examples

Conflict with combination (0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
Values: bi