# Prediction analysis

In [1]:
import json
import evaluate
from collections import defaultdict, Counter
from tqdm import tqdm as tqdm

In [2]:
# Inputs: model predictions (true vs. predicted captions) and ChatAug captions.
data_path = 'outputs/preds_notag_noaug.json'
aug_data_path = 'chataug.json'


def _read_json(path):
    """Open the file at `path` and return its parsed JSON content."""
    with open(path) as handle:
        return json.load(handle)


data = _read_json(data_path)
aug_data = _read_json(aug_data_path)

# Caption-similarity metrics loaded through the `evaluate` library.
meteor = evaluate.load('meteor')
google_bleu = evaluate.load('google_bleu')

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [3]:
# Per-caption statistics comparing ground-truth captions with model predictions.
true_captions = data['eval_true_captions']
pred_captions = data['eval_pred_captions']
# zip() would silently drop unmatched captions; fail loudly instead.
assert len(true_captions) == len(pred_captions), (
    f"{len(true_captions)} true captions vs {len(pred_captions)} predicted captions"
)
N = len(true_captions)

# Case-sensitive template opening that many captions share.
LOW_QUALITY_PREFIX = "The low quality recording"

total_len_true, total_len_pred = 0, 0
word_count_true, word_count_pred = Counter(), Counter()
low_quality_true, low_quality_pred = 0, 0
total_male, total_female = 0, 0
# male_female_confound: true caption has "female" (and no "male") but prediction has "male".
# female_male_confound: true caption has "male" (and no "female") but prediction has "female".
male_female_confound, female_male_confound = 0, 0
# Infinite sentinels: the first caption always becomes both current best and worst.
max_score, min_score, max_id, min_id = float('-inf'), float('inf'), 0, 0
f2w_true, f2w_pred = [], []

for i, (true, pred) in tqdm(enumerate(zip(true_captions, pred_captions)), total=N):

    # vocabulary: whitespace tokens, so counts are case- and punctuation-sensitive
    sptrue, sppred = true.split(), pred.split()
    word_count_true.update(sptrue)
    word_count_pred.update(sppred)

    # first two words, to spot templated openings
    f2w_true.append(" ".join(sptrue[:2]))
    f2w_pred.append(" ".join(sppred[:2]))

    # length is measured in characters, not words
    total_len_true += len(true)
    total_len_pred += len(pred)

    # count captions that open with the low-quality template
    low_quality_true += true.startswith(LOW_QUALITY_PREFIX)
    low_quality_pred += pred.startswith(LOW_QUALITY_PREFIX)

    # male / female: exact whole-token matches only ("Male" or "male." do not count)
    male_in_true = "male" in sptrue
    female_in_true = "female" in sptrue
    total_male += male_in_true
    total_female += female_in_true
    if female_in_true and not male_in_true:
        male_female_confound += "male" in sppred
    elif male_in_true and not female_in_true:
        female_male_confound += "female" in sppred

    # metrics: per-caption Google BLEU + METEOR; their sum ranks best/worst predictions
    gleu_score = google_bleu.compute(predictions=[pred], references=[true])['google_bleu']
    meteor_score = meteor.compute(predictions=[pred], references=[true])['meteor']
    score = gleu_score + meteor_score

    # strict comparisons keep the earliest caption on ties
    if score < min_score:
        min_score, min_id = score, i
    if score > max_score:
        max_score, max_id = score, i

average_len_true = total_len_true / N if N else 0.0
average_len_pred = total_len_pred / N if N else 0.0

top_n = 10
most_common_true = dict(word_count_true.most_common(top_n))
most_common_true_string = ", ".join(f"{key}: {value}" for key, value in most_common_true.items())
most_common_pred = dict(word_count_pred.most_common(top_n))
most_common_pred_string = ", ".join(f"{key}: {value}" for key, value in most_common_pred.items())

print("\n Pred vs. true stats\n","-"*50,"\n")

print(f"Average length true captions (characters): {average_len_true:.3f}")
print(f"Average length pred captions (characters): {average_len_pred:.3f}\n")

print(f"Vocabulary true captions: {len(word_count_true)}")
print(f"Vocabulary pred captions: {len(word_count_pred)}\n")

print(f"{top_n} most common words true:\n {most_common_true_string}")
print(f"{top_n} most common words pred:\n {most_common_pred_string}\n")

f2wc_true, f2wc_pred = Counter(f2w_true), Counter(f2w_pred)
print(f"{top_n} most common first two words true:\n {f2wc_true.most_common(top_n)}")
print(f"{top_n} most common first two words pred:\n {f2wc_pred.most_common(top_n)}\n")

print(f"{low_quality_true} true captions start with '{LOW_QUALITY_PREFIX}'")
print(f"{low_quality_pred} predicted captions start with '{LOW_QUALITY_PREFIX}'\n")

print(f"Captions with 'male': {total_male}")
print(f"Captions with 'female': {total_female}")
print(f"Captions where true was male but predicted female: {female_male_confound}")
print(f"Captions where true was female but predicted male: {male_female_confound}\n")

# Best/worst examples only exist when there is at least one caption.
if N:
    print(f"Best prediction (score sum {max_score:.2f})")
    print(f"TRUE: {true_captions[max_id]}")
    print(f"PRED: {pred_captions[max_id]}\n")

    print(f"Worst prediction (score sum {min_score:.2f})")
    print(f"TRUE: {true_captions[min_id]}")
    print(f"PRED: {pred_captions[min_id]}\n")


550it [00:08, 61.60it/s]


 Pred vs. true stats
 -------------------------------------------------- 

Average length true captions: 287.156
Average length pred captions: 254.927

Vocabulary true captions: 3140
Vocabulary pred captions: 1794

10 most common words true:
 a: 1576, is: 1216, the: 926, and: 882, The: 832, song: 550, This: 536, in: 530, of: 496, with: 422
10 most common words pred:
 a: 1282, and: 1007, is: 797, The: 781, the: 663, of: 515, recording: 472, song: 387, features: 384, in: 372

10 most common first two words true:
 [('The low', 115), ('This is', 113), ('This song', 46), ('A male', 33), ('The song', 31), ('This music', 24), ('This audio', 22), ('A female', 19), ('Someone is', 14), ('This clip', 10)]
10 most common first two words pred:
 [('The low', 321), ('This is', 93), ('This song', 29), ('This music', 28), ('The Electro', 18), ('A male', 16), ('The song', 16), ('This audio', 12), ('The Disco', 2), ('The Ambient', 2)]

114 true captions start with 'the low quality recording'
321 predict




In [10]:
# ChatAug caption statistics, restricted to the MusicCaps test split so they are
# comparable with the test-set predictions.
with open('musiccaps_split.json', 'r') as fp:
    musiccaps_split = json.load(fp)

# A set makes each membership check O(1) instead of scanning the split list.
test_ids = set(musiccaps_split['test'])
aug_data = {k: v for k, v in aug_data.items() if k in test_ids}
N = len(aug_data)

total_len_aug = 0
word_count_aug = Counter()
low_quality_aug = 0
f2w_aug = []
top_n = 10

for aug in tqdm(aug_data.values()):

    # vocabulary: whitespace tokens, so counts are case- and punctuation-sensitive
    spaug = aug.split()
    word_count_aug.update(spaug)

    # first two words, to spot templated openings
    f2w_aug.append(" ".join(spaug[:2]))

    # length is measured in characters, not words
    total_len_aug += len(aug)

    # same case-sensitive template check as for the true/pred captions
    low_quality_aug += aug.startswith("The low quality recording")

average_len_aug = total_len_aug / N if N else 0.0
most_common_aug = dict(word_count_aug.most_common(top_n))
most_common_aug_string = ", ".join(f"{key}: {value}" for key, value in most_common_aug.items())
f2wc_aug = Counter(f2w_aug)

print(f"Average length ChatAug captions (characters): {average_len_aug:.3f}\n")
print(f"Vocabulary ChatAug captions: {len(word_count_aug)}\n")
print(f"{top_n} most common words ChatAug:\n {most_common_aug_string}\n")
print(f"{top_n} most common first two words ChatAug:\n {f2wc_aug.most_common(top_n)}\n")
print(f"{low_quality_aug} ChatAug captions start with 'The low quality recording'\n")

100%|██████████| 552/552 [00:00<00:00, 65774.31it/s]

Average length ChatAug captions: 279.203

Vocabulary ChatAug captions: 3679

10 most common words pred:
 a: 1673, and: 1046, the: 903, The: 780, is: 662, of: 408, with: 380, in: 374, by: 329, an: 276

10 most common first two words ChatAug:
 [('The music', 35), ('The track', 32), ('The recording', 23), ('In this', 19), ('The song', 18), ('A male', 15), ('The piece', 13), ('The instrumental', 13), ('This instrumental', 11), ('This is', 9)]




