In [2]:
import json
import random

with open('../data/important_words_SHAP.json', 'r') as f:
    important_words_data = json.load(f)

with open('../data/rationale_tokens_list.json', 'r') as f:
    rationale_tokens_data = json.load(f)

with open("../data/annotated_labels.json", "r") as f:
    annotated_labels = json.load(f)

with open("../data/classified_labels.json", "r") as f:
    classified_labels = json.load(f)

In [5]:
high_overlap_correct_samples = []
high_overlap_misclassified_samples = []
low_overlap_correct_samples = []
low_overlap_misclassified_samples = []

total_count = 0

for i in range(len(annotated_labels)):
    if annotated_labels[i] == 1:
        rationale_tokens = set(rationale_tokens_data[i])
        important_tokens = set(important_words_data[i])
        
        intersection = rationale_tokens & important_tokens
        overlap_rate = len(intersection) / len(rationale_tokens) if rationale_tokens else 0

        correct_classification = annotated_labels[i] == classified_labels[i]
        
        sample = {
            "index": i,
            "annotated_label": annotated_labels[i],
            "classified_label": classified_labels[i],
            "overlap_rate": overlap_rate,
            "rationale_tokens": list(rationale_tokens),
            "shap_tokens": list(important_tokens),
            "intersection": list(intersection)
        }

        if overlap_rate > 0.7:
            if correct_classification:
                high_overlap_correct_samples.append(sample)
            else:
                high_overlap_misclassified_samples.append(sample)

        elif overlap_rate < 0.3:
            if correct_classification:
                low_overlap_correct_samples.append(sample)
            else:
                low_overlap_misclassified_samples.append(sample)

        total_count += 1

def sample_cases(samples_list, n=5):
    return random.sample(samples_list, min(n, len(samples_list)))

high_correct_sampled = sample_cases(high_overlap_correct_samples)
high_misclassified_sampled = sample_cases(high_overlap_misclassified_samples)
low_correct_sampled = sample_cases(low_overlap_correct_samples)
low_misclassified_sampled = sample_cases(low_overlap_misclassified_samples)

output_data = {
    "High Overlap Correct": high_correct_sampled,
    "High Overlap Misclassified": high_misclassified_sampled,
    "Low Overlap Correct": low_correct_sampled,
    "Low Overlap Misclassified": low_misclassified_sampled
}

# Save to JSON file
# with open("sampled_cases.json", "w") as f:
#     json.dump(output_data, f, indent=4)

# print("\nSamples saved to 'sampled_cases.json'")

# Print results category by category
def print_samples(samples, name):
    print(f"\nCategory: {name}")
    for sample in samples:
        print(sample)

# Print each category
print_samples(high_correct_sampled, "High Overlap Correct")
print_samples(high_misclassified_sampled, "High Overlap Misclassified")
print_samples(low_correct_sampled, "Low Overlap Correct")
print_samples(low_misclassified_sampled, "Low Overlap Misclassified")



Category: High Overlap Correct
{'index': 1270, 'annotated_label': 1, 'classified_label': 1, 'overlap_rate': 0.75, 'rationale_tokens': ['graves', 'mass', 'moslem', 'unmarked'], 'shap_tokens': ['would', 'graves', 'moslem', 'unmarked'], 'intersection': ['graves', 'moslem', 'unmarked']}
{'index': 9277, 'annotated_label': 1, 'classified_label': 1, 'overlap_rate': 1.0, 'rationale_tokens': ['nigger'], 'shap_tokens': ['the', 'punch', 'nigger', 'deserves'], 'intersection': ['nigger']}
{'index': 273, 'annotated_label': 1, 'classified_label': 1, 'overlap_rate': 1.0, 'rationale_tokens': ['muzrat'], 'shap_tokens': ['muzrat', 'brexit', 'will'], 'intersection': ['muzrat']}
{'index': 4930, 'annotated_label': 1, 'classified_label': 1, 'overlap_rate': 1.0, 'rationale_tokens': ['faggots', 'lover', 'nigger'], 'shap_tokens': ['us', 'lover', 'nigger', 'faggots'], 'intersection': ['faggots', 'nigger', 'lover']}
{'index': 4261, 'annotated_label': 1, 'classified_label': 1, 'overlap_rate': 1.0, 'rationale_toke