# Preds analysis

In [89]:
import json
import evaluate
from collections import defaultdict

In [90]:
data_path = 'outputs/preds_notag_noaug.json'
with open('outputs/preds_notag_noaug.json') as f:
    data = json.load(f)
    
meteor = evaluate.load('meteor')
google_bleu = evaluate.load('google_bleu')

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [97]:
N = len(data['eval_true_captions'])

average_len_true, average_len_pred = 0, 0
word_count_true, word_count_pred = defaultdict(int), defaultdict(int)
low_quality_true, low_quality_pred = 0, 0
total_male, total_female = 0, 0
male_female_confound, female_male_confound  = 0, 0
max_score, min_score, max_id, min_id = 0, 2, 0, 0


for i, (true, pred) in enumerate(zip(data['eval_true_captions'], data['eval_pred_captions'])):
    
    # vocabulary
    sptrue, sppred = true.split(), pred.split()
    for w in sptrue:
        word_count_true[w] += 1
    for w in sppred:
        word_count_pred[w] += 1
    
    # average length
    average_len_true += (1./N)*len(true)
    average_len_pred += (1./N)*len(pred)
    
    # count captions that start with "low quality recording" 
    low_quality_true += (true[:25]=="The low quality recording")
    low_quality_pred += (pred[:25]=="The low quality recording")
    
    # male / female
    male_in_true = any("male"==word for word in sptrue)
    female_in_true = any("female"==word for word in sptrue)
    total_male += male_in_true
    total_female += female_in_true
    if (not male_in_true) and female_in_true:
        male_female_confound += any("male"==word for word in sppred)
    elif (not female_in_true) and male_in_true:
        female_male_confound += any("female"==word for word in sppred)
        
    # metrics
    gleu_score = google_bleu.compute(predictions=[pred], references=[true])['google_bleu']
    meteor_score = meteor.compute(predictions=[pred], references=[true])['meteor']
    
    if gleu_score+meteor_score < min_score:
        min_score = gleu_score+meteor_score
        min_id = i
        
    if gleu_score+meteor_score > max_score:
        max_score = gleu_score+meteor_score
        max_id = i
        
top_n = 8
most_common_true = {k: v for k, v in sorted(word_count_true.items(), key=lambda item: -item[1])[:top_n]}
most_common_true_string = ", ".join([f"{key}: {value}" for key, value in most_common_true.items()])
most_common_pred = {k: v for k, v in sorted(word_count_pred.items(), key=lambda item: -item[1])[:top_n]}
most_common_pred_string = ", ".join([f"{key}: {value}" for key, value in most_common_pred.items()])
        
print("\n Pred vs. true stats\n","-"*50,"\n")
        
print(f"Average length true captions: {average_len_true:.3f}")
print(f"Average length pred captions: {average_len_pred:.3f}\n")

print(f"Vocabulary true captions: {len(word_count_true)}")
print(f"Vocabulary pred captions: {len(word_count_pred)}\n")


print(f"Most common words true:\n {most_common_true_string}")
print(f"Most common words pred:\n {most_common_pred_string}\n")

print(f"{low_quality_true} true captions start with 'the low quality recording'")
print(f"{low_quality_pred} predicted captions start with 'the low quality recording'\n")

print(f"Captions with 'male': {total_male}")
print(f"Captions with 'female': {total_female}")
print(f"Captions where true was male but predicted female: {female_male_confound}")
print(f"Captions where true was female but predicted male: {male_female_confound}\n")

print(f"Best prediction (score sum {max_score:.2f})")
print(f"TRUE: {data['eval_true_captions'][max_id]}")
print(f"PRED: {data['eval_pred_captions'][max_id]}\n")

print(f"Worst prediction (score sum {min_score:.2f})")
print(f"TRUE: {data['eval_true_captions'][min_id]}")
print(f"PRED: {data['eval_pred_captions'][min_id]}\n")



 Pred vs. true stats
 -------------------------------------------------- 

Average length true captions: 287.156
Average length pred captions: 254.927

Vocabulary true captions: 3140
Vocabulary pred captions: 1794

Most common words true:
 a: 1576, is: 1216, the: 926, and: 882, The: 832, song: 550, This: 536, in: 530
Most common words pred:
 a: 1282, and: 1007, is: 797, The: 781, the: 663, of: 515, recording: 472, song: 387

114 true captions start with 'the low quality recording'
321 predicted captions start with 'the low quality recording'

Captions with 'male': 182
Captions with 'female': 86
Captions where true was male but predicted female: 25
Captions where true was female but predicted male: 28

Best prediction (score sum 1.25)
TRUE: The low quality recording features a live performance of a traditional song that consists of harmonizing male vocals singing over plucked strings melody and wooden percussion. It sounds passionate and soulful, but the recording is very noisy and in 