# Preds analysis

In [13]:
import json
import evaluate
from collections import defaultdict, Counter
from tqdm import tqdm as tqdm
from musiccaps import load_musiccaps
import string
import numpy as np
import random

## Helper functions

In [2]:
meteor = evaluate.load('meteor')
google_bleu = evaluate.load('google_bleu')

ds = load_musiccaps(
    "./music_data",
    sampling_rate=16000,
    limit=None,
    num_proc=8,
    writer_batch_size=1000,
    return_without_audio=True,
)

def clean_text_for_aspect_metrics(caption):
    table = str.maketrans('','', string.punctuation)
    caption.replace("-"," ")
    # split the sentences into words
    desc = caption.split()
    #converts to lower case
    desc = [word.lower() for word in desc]
    #remove punctuation from each token
    desc = [word.translate(table) for word in desc]
    #remove hanging 's and a 
    #desc = [word for word in desc if(len(word)>1)]
    #remove tokens with numbers in them
    #desc = [word for word in desc if(word.isalpha())]
    #convert back to string
    caption = ' '.join(desc)
    return caption

# get a list of music-related words to use for evaluation
aspects = set()
for x in ds:
    aspect_str = x["aspect_list"]
    for t in "[]\"'":
        aspect_str = aspect_str.replace(t, "")
    aspects.update(aspect_str.split(", "))
# clean aspects
aspects = {clean_text_for_aspect_metrics(a) for a in aspects if len(a) > 2}
    
def wrap_in_space(s):
    return ' ' + s + ' '
    
# filter
all_captions = clean_text_for_aspect_metrics(' '.join(ds[i]['caption'] for i in range(len(ds))))
aspect_counts = {a: all_captions.count(wrap_in_space(a)) for a in aspects}
aspects = {a for a in aspects if aspect_counts[a] > 10}
aspects -= {'the'}

def compute_aspects_metric(true, pred):
    true = wrap_in_space(clean_text_for_aspect_metrics(true))
    pred = wrap_in_space(clean_text_for_aspect_metrics(pred))
    
    aspects_in_true = {a for a in aspects if wrap_in_space(a) in true}
    aspects_in_pred = {a for a in aspects if wrap_in_space(a) in pred}
    
    #print(aspects_in_true)
    #print(aspects_in_pred)
    
    precision = len(aspects_in_true&aspects_in_pred)/np.maximum(len(aspects_in_pred),1)
    recall = len(aspects_in_true&aspects_in_pred)/np.maximum(len(aspects_in_true), 1)
    
    return precision, recall

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
Using custom data configuration google--MusicCaps-bedc2a0fd7888f2f
Reusing dataset csv (/root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a)


In [27]:
def print_stats(data_true, data_pred):

    N = len(data_true)
    
    assert len(data_pred) == N

    average_len_true, average_len_pred = 0, 0
    word_count_true, word_count_pred = defaultdict(int), defaultdict(int)
    low_quality_true, low_quality_pred = 0, 0
    total_male, total_female = 0, 0
    male_female_confound, female_male_confound  = 0, 0
    max_score, min_score, max_id, min_id = 0, 2, 0, 0
    f2w_true, f2w_pred = [], []
    total_gleu_score, total_meteor_score = 0, 0
    aspect_precision, aspect_recall = [], []

    for i, (true, pred) in tqdm(enumerate(zip(data_true, data_pred))):

        # vocabulary
        sptrue, sppred = true.split(), pred.split()
        for w in sptrue:
            word_count_true[w] += 1
        for w in sppred:
            word_count_pred[w] += 1

        f2w_true.append(" ".join(sptrue[:2]))
        f2w_pred.append(" ".join(sppred[:2]))

        # average length
        average_len_true += (1./N)*len(true)
        average_len_pred += (1./N)*len(pred)

        # count captions that start with "low quality recording" 
        low_quality_true += (true[1:25]=="he low quality recording")
        low_quality_pred += (pred[1:25]=="he low quality recording")

        # male / female
        male_in_true = any("male"==word for word in sptrue)
        female_in_true = any("female"==word for word in sptrue)
        total_male += male_in_true
        total_female += female_in_true
        if (not male_in_true) and female_in_true:
            male_female_confound += any("male"==word for word in sppred)
        elif (not female_in_true) and male_in_true:
            female_male_confound += any("female"==word for word in sppred)

        # metrics
        gleu_score = google_bleu.compute(predictions=[pred], references=[true])['google_bleu']
        meteor_score = meteor.compute(predictions=[pred], references=[true])['meteor']

        if gleu_score+meteor_score < min_score:
            min_score = gleu_score+meteor_score
            min_id = i

        if gleu_score+meteor_score > max_score:
            max_score = gleu_score+meteor_score
            max_id = i
            
        precision, recall = compute_aspects_metric(true, pred)
        aspect_precision.append(precision)
        aspect_recall.append(recall)
        
    aspect_precision, aspect_recall = np.array(aspect_precision), np.array(aspect_recall)

    top_n = 10
    most_common_true = {k: v for k, v in sorted(word_count_true.items(), key=lambda item: -item[1])[:top_n]}
    most_common_true_string = ", ".join([f"{key}: {value}" for key, value in most_common_true.items()])
    most_common_pred = {k: v for k, v in sorted(word_count_pred.items(), key=lambda item: -item[1])[:top_n]}
    most_common_pred_string = ", ".join([f"{key}: {value}" for key, value in most_common_pred.items()])
    
    total_gleu_score = google_bleu.compute(predictions=data_pred, references=data_true)['google_bleu']
    total_meteor_score = meteor.compute(predictions=data_pred, references=data_true)['meteor']

    data_true_shuffled = sorted(data_true, key=lambda k: random.random())
    
    n_shuffles = 10
    shuffled_gleu_score, shuffled_meteor_score = 0, 0
    for _ in tqdm(range(n_shuffles)):
        shuffled_gleu_score += 1./n_shuffles * google_bleu.compute(predictions=data_pred, references=data_true_shuffled)['google_bleu']
        shuffled_meteor_score += 1./n_shuffles * meteor.compute(predictions=data_pred, references=data_true_shuffled)['meteor']
    spec_meteor = total_meteor_score-shuffled_meteor_score
    spec_gleu = total_gleu_score-shuffled_gleu_score

    print("\n Pred vs. true stats\n","-"*50,"\n")
    
    print(f"Test GLEU score: {total_gleu_score:.4f}")
    print(f"Test METEOR score: {total_meteor_score:.4f}")
    print(f"Shuffled test GLEU score: {shuffled_gleu_score:.4f}")
    print(f"Shuffled test METEOR score: {shuffled_meteor_score:.4f}")
    print(f"Test Spec-GLEU score: {spec_gleu:.4f}")
    print(f"Test Spec-METEOR score: {spec_meteor:.4f}")
    
    print(f"Aspect precision: {aspect_precision.mean():.3f}±{aspect_precision.std():.3f}")
    print(f"Aspect recall: {aspect_recall.mean():.3f}±{aspect_recall.std():.3f}")

    print(f"Average length true captions: {average_len_true:.3f}")
    print(f"Average length pred captions: {average_len_pred:.3f}\n")

    print(f"Vocabulary true captions: {len(word_count_true)}")
    print(f"Vocabulary pred captions: {len(word_count_pred)}\n")

    print(f"{top_n} most common words true:\n {most_common_true_string}")
    print(f"{top_n} most common words pred:\n {most_common_pred_string}\n")

    f2wc_true, f2wc_pred  = Counter(f2w_true), Counter(f2w_pred)
    print(f"{top_n} most common first two words true:\n {f2wc_true.most_common(top_n)}")
    print(f"{top_n} most common first two words pred:\n {f2wc_pred.most_common(top_n)}\n")

    print(f"{low_quality_true} true captions start with 'the low quality recording'")
    print(f"{low_quality_pred} predicted captions start with 'the low quality recording'\n")

    print(f"Captions with 'male': {total_male}")
    print(f"Captions with 'female': {total_female}")
    print(f"Captions where true was male but predicted female: {female_male_confound}")
    print(f"Captions where true was female but predicted male: {male_female_confound}\n")

    print(f"Best prediction (score sum {max_score:.2f})")
    print(f"TRUE: {data_true[max_id]}")
    print(f"PRED: {data_pred[max_id]}\n")

    print(f"Worst prediction (score sum {min_score:.2f})")
    print(f"TRUE: {data_true[min_id]}")
    print(f"PRED: {data_pred[min_id]}\n")


## Results

In [28]:
data_path = 'outputs/preds_gpt2_notag_chataug.json'

with open(data_path) as f:
    data = json.load(f)

print_stats(data['eval_true_captions'], data['eval_pred_captions'])

550it [00:05, 107.43it/s]
100%|██████████| 10/10 [00:11<00:00,  1.18s/it]


 Pred vs. true stats
 -------------------------------------------------- 

Test GLEU score: 0.0915
Test METEOR score: 0.2123
Shuffled test GLEU score: 0.0769
Shuffled test METEOR score: 0.1849
Test Spec-GLEU score: 0.0147
Test Spec-METEOR score: 0.0274
Aspect precision: 0.148±0.152
Aspect recall: 0.166±0.157
Average length true captions: 287.156
Average length pred captions: 248.364

Vocabulary true captions: 3140
Vocabulary pred captions: 1363

10 most common words true:
 a: 1576, is: 1216, the: 926, and: 882, The: 832, song: 550, This: 536, in: 530, of: 496, with: 422
10 most common words pred:
 and: 1214, a: 1027, The: 861, is: 620, recording: 619, of: 514, song: 498, features: 485, quality: 479, the: 466

10 most common first two words true:
 [('The low', 115), ('This is', 113), ('This song', 46), ('A male', 33), ('The song', 31), ('This music', 24), ('This audio', 22), ('A female', 19), ('Someone is', 14), ('This clip', 10)]
10 most common first two words pred:
 [('The low', 432)




In [29]:
data_path = 'outputs/preds_gpt2_notag_noaug.json'

with open(data_path) as f:
    data = json.load(f)

print_stats(data['eval_true_captions'], data['eval_pred_captions'])

550it [00:04, 132.76it/s]
100%|██████████| 10/10 [00:12<00:00,  1.26s/it]


 Pred vs. true stats
 -------------------------------------------------- 

Test GLEU score: 0.0915
Test METEOR score: 0.2155
Shuffled test GLEU score: 0.0756
Shuffled test METEOR score: 0.1812
Test Spec-GLEU score: 0.0159
Test Spec-METEOR score: 0.0343
Aspect precision: 0.154±0.163
Aspect recall: 0.167±0.166
Average length true captions: 287.156
Average length pred captions: 254.927

Vocabulary true captions: 3140
Vocabulary pred captions: 1794

10 most common words true:
 a: 1576, is: 1216, the: 926, and: 882, The: 832, song: 550, This: 536, in: 530, of: 496, with: 422
10 most common words pred:
 a: 1282, and: 1007, is: 797, The: 781, the: 663, of: 515, recording: 472, song: 387, features: 384, in: 372

10 most common first two words true:
 [('The low', 115), ('This is', 113), ('This song', 46), ('A male', 33), ('The song', 31), ('This music', 24), ('This audio', 22), ('A female', 19), ('Someone is', 14), ('This clip', 10)]
10 most common first two words pred:
 [('The low', 321), ('T




In [30]:
data_path = 'outputs/preds_lstm_notag_noaug.json'

with open(data_path) as f:
    data = json.load(f)

# [6:-6] to strip <sos> and <eos>
# x[0] for true captions because each element is ['caption'] instead of 'caption'
print_stats([x[0][6:-6] for x in data['true_captions']], [x[6:-6] for x in data['predictions']])

549it [00:04, 136.38it/s]
100%|██████████| 10/10 [00:08<00:00,  1.13it/s]


 Pred vs. true stats
 -------------------------------------------------- 

Test GLEU score: 0.0836
Test METEOR score: 0.1860
Shuffled test GLEU score: 0.0794
Shuffled test METEOR score: 0.1746
Test Spec-GLEU score: 0.0042
Test Spec-METEOR score: 0.0114
Aspect precision: 0.152±0.152
Aspect recall: 0.162±0.145
Average length true captions: 278.344
Average length pred captions: 211.441

Vocabulary true captions: 1768
Vocabulary pred captions: 216

10 most common words true:
 the: 1755, a: 1739, is: 1215, and: 881, this: 667, song: 639, in: 558, <unk>: 515, of: 496, with: 423
10 most common words pred:
 the: 1395, and: 1231, is: 1047, a: 1012, song: 671, of: 551, guitar: 417, bass: 397, quality: 392, recording: 386

10 most common first two words true:
 [('the low', 118), ('this is', 112), ('this song', 46), ('a male', 33), ('the song', 31), ('this music', 24), ('this audio', 22), ('a female', 19), ('someone is', 15), ('this clip', 10)]
10 most common first two words pred:
 [('the low', 2




## ChatAug data

In [8]:
aug_data_path = 'chataug.json'
with open(aug_data_path) as f:
    aug_data = json.load(f)

In [9]:
with open('musiccaps_split.json', 'r') as fp:
    musiccaps_split = json.load(fp)
    
aug_data = {k: v for k, v in aug_data.items() if k in musiccaps_split['train']}
N = len(aug_data)

average_len_aug = 0
word_count_aug = defaultdict(int)
low_quality_aug = 0
f2w_aug = []
top_n = 10

for aug in tqdm(aug_data.values()):
    
    # vocabulary
    spaug = aug.split()
    for w in spaug:
        word_count_aug[w] += 1
        
    f2w_aug.append(" ".join(spaug[:2]))
    
    # average length
    average_len_aug += (1./N)*len(aug)
    
most_common_aug = {k: v for k, v in sorted(word_count_aug.items(), key=lambda item: -item[1])[:top_n]}
most_common_aug_string = ", ".join([f"{key}: {value}" for key, value in most_common_aug.items()])
f2wc_aug = Counter(f2w_aug)

print(f"Average length ChatAug captions: {average_len_aug:.3f}\n")
print(f"Vocabulary ChatAug captions: {len(word_count_aug)}\n")
print(f"{top_n} most common words pred:\n {most_common_aug_string}\n")
print(f"{top_n} most common first two words ChatAug:\n {f2wc_aug.most_common(top_n)}\n")

100%|██████████| 4417/4417 [00:00<00:00, 88489.03it/s]

Average length ChatAug captions: 277.987

Vocabulary ChatAug captions: 10392

10 most common words pred:
 a: 13110, and: 8359, the: 7485, The: 6265, is: 5313, of: 3510, with: 3224, in: 3156, by: 2473, an: 2205

10 most common first two words ChatAug:
 [('The recording', 243), ('The music', 229), ('The track', 189), ('The instrumental', 165), ('A male', 142), ('In this', 140), ('The song', 125), ('The piece', 94), ('This instrumental', 93), ('The composition', 68)]




