# Preds analysis

In [2]:
!pip install datasets
!pip install evaluate

Collecting evaluate
  Using cached evaluate-0.4.0-py3-none-any.whl (81 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.0


In [21]:
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting absl-py
  Using cached absl_py-1.4.0-py3-none-any.whl (126 kB)
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25ldone
[?25h  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24936 sha256=916455fe14c84c8b5b5217b4edf2041c22209e93f15ac491a1a3e5763bb2c049
  Stored in directory: /Users/corinacaraconcea/Library/Caches/pip/wheels/b0/3f/ac/cc3bc304f50c77ef38d79d8e4e2684313de39af543cb4eb3da
Successfully built rouge_score
Installing collected packages: absl-py, rouge_score
Successfully installed absl-py-1.4.0 rouge_score-0.1.2


In [107]:
import json
import evaluate
from collections import defaultdict, Counter
from tqdm import tqdm as tqdm
from musiccaps import load_musiccaps
import string
import numpy as np
import random
import pandas as pd
import re
import string
import itertools

## Helper functions

In [142]:
meteor = evaluate.load('meteor')
google_bleu = evaluate.load('google_bleu')
rouge = evaluate.load('rouge')

ds = load_musiccaps(
    "./music_data",
    sampling_rate=16000,
    limit=None,
    num_proc=8,
    writer_batch_size=1000,
    return_without_audio=True,
)

def clean_text_for_aspect_metrics(caption):
    table = str.maketrans('','', string.punctuation)
    caption.replace("-"," ")
    # split the sentences into words
    desc = caption.split()
    #converts to lower case
    desc = [word.lower() for word in desc]
    #remove punctuation from each token
    desc = [word.translate(table) for word in desc]
    #remove hanging 's and a 
    #desc = [word for word in desc if(len(word)>1)]
    #remove tokens with numbers in them
    #desc = [word for word in desc if(word.isalpha())]
    #convert back to string
    caption = ' '.join(desc)
    return caption

def preprocessing_remove_unk(text_input):
    # remove punctuations
    desc = re.sub(r'[^\w\s]',' ',text_input)
    table = str.maketrans('','',string.punctuation)

    # turn uppercase letters into lowercase ones
    desc = text_input.lower()

    # split into words
    desc = desc.split(' ')
    
    try: 
        # remove <unk> tokens
        desc.remove("<unk>")
    except ValueError:
        desc = desc
    try: 
        # remove <unk> tokens
        desc.remove("unk")
    except ValueError:
        desc = desc
        
    # remove the punctuations
    text_no_punctuation = [word.translate(table) for word in desc]

    # join the caption words
    caption = ' '.join(desc)
    
    return caption

# get a list of music-related words to use for evaluation
aspects = set()
for x in ds:
    aspect_str = x["aspect_list"]
    for t in "[]\"'":
        aspect_str = aspect_str.replace(t, "")
    aspects.update(aspect_str.split(", "))
# clean aspects
aspects = {clean_text_for_aspect_metrics(a) for a in aspects if len(a) > 2}
    
def wrap_in_space(s):
    return ' ' + s + ' '
    
# filter
all_captions = clean_text_for_aspect_metrics(' '.join(ds[i]['caption'] for i in range(len(ds))))
aspect_counts = {a: all_captions.count(wrap_in_space(a)) for a in aspects}
aspects = {a for a in aspects if aspect_counts[a] > 10}
aspects -= {'the'}

def compute_aspects_metric(true, pred):
    true = wrap_in_space(clean_text_for_aspect_metrics(true))
    pred = wrap_in_space(clean_text_for_aspect_metrics(pred))
    
    aspects_in_true = {a for a in aspects if wrap_in_space(a) in true}
    aspects_in_pred = {a for a in aspects if wrap_in_space(a) in pred}
    
    #print(aspects_in_true)
    #print(aspects_in_pred)
    
    precision = len(aspects_in_true&aspects_in_pred)/np.maximum(len(aspects_in_pred),1)
    recall = len(aspects_in_true&aspects_in_pred)/np.maximum(len(aspects_in_true), 1)
    
    return precision, recall

[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/corinacaraconcea/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/corinacaraconcea/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /Users/corinacaraconcea/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [143]:
def import_outputs(data_path):
    
    with open(data_path) as f:
        data = json.load(f)

    # multiple references for one captions
    true_captions = data['eval_true_captions']
    if isinstance(true_captions, list) and any(isinstance(item, list) for item in true_captions):
        true_captions = list(itertools.chain(*true_captions))
    # single prediction
    predicted_captions = data['eval_pred_captions']
    if isinstance(predicted_captions, list) and any(isinstance(item, list) for item in predicted_captions):
        predicted_captions = list(itertools.chain(*predicted_captions))


    return true_captions,predicted_captions,

In [144]:
def compute_metrics(true_captions,predicted_captions):
    true_clean = []
    pred_clean = []
    for i, (true, pred) in tqdm(enumerate(zip(true_captions, predicted_captions))):
        # preprocess captions and predictions to remove punctuations and <unk> tokens
        true = preprocessing_remove_unk(true)
        true_clean.append(true)
        pred = preprocessing_remove_unk(pred)
        pred_clean.append(pred)
        with open('clean_pred.txt','w') as f:
            for i in range(len(pred_clean)):
                f.write(pred_clean[i]+'\n')
        with open('clean_caption.txt','w') as f:
            for i in range(len(true_clean)):
                f.write(true_clean[i]+'\n')

    total_google_bleu = google_bleu.compute(predictions = pred_clean,references = true_clean)
    total_rouge = rouge.compute(predictions = pred_clean,references = true_clean)
    total_meteor = meteor.compute(predictions = pred_clean,references = true_clean)
    
    return total_google_bleu, total_rouge, total_meteor


In [145]:
def print_metrics(data_paths,methods):
    for i, (data_path, method) in enumerate(zip(data_paths,methods)):
       true_captions,predicted_captions = import_outputs(data_path)
       google_bleu_score, rouge_score, meteor_score = compute_metrics(true_captions,predicted_captions)

       print(method,"test google BLEU score:",str(np.round(google_bleu_score["google_bleu"],3)))
       print(method,"test google ROUGE score:",str(np.round(rouge_score["rouge1"],3)))
       print(method,"test google METEOR score:",str(np.round(meteor_score["meteor"],3)))

In [148]:
data_paths = ["outputs/preds_lstm_attn_summaries.json",
              "outputs/preds_lstm_no_attn_summaries.json",
              "outputs/preds_lstm_attn_no_tag_no_aug.json",
              "outputs/preds_lstm_no_attn_no_tag_no_aug.json",
              "outputs/preds_gpt2_summarized_notag_noaug.json",
              "outputs/preds_gpt2_summarized.json",
              "outputs/preds_gpt2_notag_chataug.json",
              "outputs/preds_gpt2_notag_noaug.json"]

methods = ["LSTM with attention and summarized dataset",
           "LSTM without attention and summarized dataset",
           "LSTM with attention no ChatAug",
           "LSTM no attention no ChatAug",
           "GPT-2 no tag no aug summarized dataset",
           "GPT-2 ?? summarized",
           "GPT-2 no tag ChatAug",
           "GPT-2 no tag no ChatAug"]

In [149]:
print_metrics(data_paths,methods)

1647it [00:00, 1810.24it/s]


LSTM with attention and summarized dataset test google BLEU score: 0.089
LSTM with attention and summarized dataset google ROUGE score: 0.305
LSTM with attention and summarized dataset google METEOR score: 0.205


1647it [00:00, 1895.88it/s]


LSTM without attention and summarized dataset test google BLEU score: 0.092
LSTM without attention and summarized dataset google ROUGE score: 0.312
LSTM without attention and summarized dataset google METEOR score: 0.207


549it [00:00, 3212.08it/s]


LSTM with attention no ChatAug test google BLEU score: 0.083
LSTM with attention no ChatAug google ROUGE score: 0.275
LSTM with attention no ChatAug google METEOR score: 0.183


549it [00:00, 3139.35it/s]


LSTM no attention no ChatAug test google BLEU score: 0.084
LSTM no attention no ChatAug google ROUGE score: 0.276
LSTM no attention no ChatAug google METEOR score: 0.185


550it [00:00, 3473.61it/s]


GPT-2 no tag no aug summarized dataset test google BLEU score: 0.087
GPT-2 no tag no aug summarized dataset google ROUGE score: 0.267
GPT-2 no tag no aug summarized dataset google METEOR score: 0.205


550it [00:00, 3365.30it/s]


GPT-2 ?? summarized test google BLEU score: 0.086
GPT-2 ?? summarized google ROUGE score: 0.267
GPT-2 ?? summarized google METEOR score: 0.202


550it [00:00, 3072.73it/s]


GPT-2 no tag ChatAug test google BLEU score: 0.097
GPT-2 no tag ChatAug google ROUGE score: 0.282
GPT-2 no tag ChatAug google METEOR score: 0.221


550it [00:00, 3019.70it/s]


GPT-2 no tag no ChatAug test google BLEU score: 0.091
GPT-2 no tag no ChatAug google ROUGE score: 0.273
GPT-2 no tag no ChatAug google METEOR score: 0.21


In [22]:
def print_stats(data_true, data_pred):

    N = len(data_true)
    
    assert len(data_pred) == N

    average_len_true, average_len_pred = 0, 0
    word_count_true, word_count_pred = defaultdict(int), defaultdict(int)
    low_quality_true, low_quality_pred = 0, 0
    total_male, total_female = 0, 0
    male_female_confound, female_male_confound  = 0, 0
    max_score, min_score, max_id, min_id = 0, 2, 0, 0
    f2w_true, f2w_pred = [], []
    total_gleu_score, total_meteor_score = 0, 0
    aspect_precision, aspect_recall = [], []

    for i, (true, pred) in tqdm(enumerate(zip(data_true, data_pred))):

        # vocabulary
        sptrue, sppred = true.split(), pred.split()
        for w in sptrue:
            word_count_true[w] += 1
        for w in sppred:
            word_count_pred[w] += 1

        f2w_true.append(" ".join(sptrue[:2]))
        f2w_pred.append(" ".join(sppred[:2]))

        # average length
        average_len_true += (1./N)*len(true)
        average_len_pred += (1./N)*len(pred)

        # count captions that start with "low quality recording" 
        low_quality_true += (true[1:25]=="he low quality recording")
        low_quality_pred += (pred[1:25]=="he low quality recording")

        # male / female
        male_in_true = any("male"==word for word in sptrue)
        female_in_true = any("female"==word for word in sptrue)
        total_male += male_in_true
        total_female += female_in_true
        if (not male_in_true) and female_in_true:
            male_female_confound += any("male"==word for word in sppred)
        elif (not female_in_true) and male_in_true:
            female_male_confound += any("female"==word for word in sppred)

        # metrics
        gleu_score = google_bleu.compute(predictions=[pred], references=[true])['google_bleu']
        meteor_score = meteor.compute(predictions=[pred], references=[true])['meteor']

        if gleu_score+meteor_score < min_score:
            min_score = gleu_score+meteor_score
            min_id = i

        if gleu_score+meteor_score > max_score:
            max_score = gleu_score+meteor_score
            max_id = i
            
        precision, recall = compute_aspects_metric(true, pred)
        aspect_precision.append(precision)
        aspect_recall.append(recall)
        
    aspect_precision, aspect_recall = np.array(aspect_precision), np.array(aspect_recall)

    top_n = 10
    most_common_true = {k: v for k, v in sorted(word_count_true.items(), key=lambda item: -item[1])[:top_n]}
    most_common_true_string = ", ".join([f"{key}: {value}" for key, value in most_common_true.items()])
    most_common_pred = {k: v for k, v in sorted(word_count_pred.items(), key=lambda item: -item[1])[:top_n]}
    most_common_pred_string = ", ".join([f"{key}: {value}" for key, value in most_common_pred.items()])
    
    total_gleu_score = google_bleu.compute(predictions=data_pred, references=data_true)['google_bleu']
    total_meteor_score = meteor.compute(predictions=data_pred, references=data_true)['meteor']

    n_shuffles = 10
    shuffled_gleu_score, shuffled_meteor_score = 0, 0
    for _ in tqdm(range(n_shuffles)):
        data_true_shuffled = sorted(data_true, key=lambda k: random.random())
        shuffled_gleu_score += 1./n_shuffles * google_bleu.compute(predictions=data_pred, references=data_true_shuffled)['google_bleu']
        shuffled_meteor_score += 1./n_shuffles * meteor.compute(predictions=data_pred, references=data_true_shuffled)['meteor']
    spec_meteor = total_meteor_score-shuffled_meteor_score
    spec_gleu = total_gleu_score-shuffled_gleu_score

    print("\n Pred vs. true stats\n","-"*50,"\n")
    
    print(f"Test GLEU score: {total_gleu_score:.4f}")
    print(f"Test METEOR score: {total_meteor_score:.4f}")
    print(f"Shuffled test GLEU score: {shuffled_gleu_score:.4f}")
    print(f"Shuffled test METEOR score: {shuffled_meteor_score:.4f}")
    print(f"Test Spec-GLEU score: {spec_gleu:.4f}")
    print(f"Test Spec-METEOR score: {spec_meteor:.4f}")
    
    print(f"Aspect precision: {aspect_precision.mean():.3f}±{aspect_precision.std():.3f}")
    print(f"Aspect recall: {aspect_recall.mean():.3f}±{aspect_recall.std():.3f}")

    print(f"Average length true captions: {average_len_true:.3f}")
    print(f"Average length pred captions: {average_len_pred:.3f}\n")

    print(f"Vocabulary true captions: {len(word_count_true)}")
    print(f"Vocabulary pred captions: {len(word_count_pred)}\n")

    print(f"{top_n} most common words true:\n {most_common_true_string}")
    print(f"{top_n} most common words pred:\n {most_common_pred_string}\n")

    f2wc_true, f2wc_pred  = Counter(f2w_true), Counter(f2w_pred)
    print(f"{top_n} most common first two words true:\n {f2wc_true.most_common(top_n)}")
    print(f"{top_n} most common first two words pred:\n {f2wc_pred.most_common(top_n)}\n")

    print(f"{low_quality_true} true captions start with 'the low quality recording'")
    print(f"{low_quality_pred} predicted captions start with 'the low quality recording'\n")

    print(f"Captions with 'male': {total_male}")
    print(f"Captions with 'female': {total_female}")
    print(f"Captions where true was male but predicted female: {female_male_confound}")
    print(f"Captions where true was female but predicted male: {male_female_confound}\n")

    print(f"Best prediction (score sum {max_score:.2f})")
    print(f"TRUE: {data_true[max_id]}")
    print(f"PRED: {data_pred[max_id]}\n")

    print(f"Worst prediction (score sum {min_score:.2f})")
    print(f"TRUE: {data_true[min_id]}")
    print(f"PRED: {data_pred[min_id]}\n")


## Results

In [15]:
data_path = 'outputs/preds_gpt2_notag_chataug.json'

with open(data_path) as f:
    data = json.load(f)

print_stats(data['eval_true_captions'], data['eval_pred_captions'])

550it [00:04, 137.49it/s]
100%|██████████| 25/25 [00:32<00:00,  1.28s/it]


 Pred vs. true stats
 -------------------------------------------------- 

Test GLEU score: 0.0930
Test METEOR score: 0.2209
Shuffled test GLEU score: 0.0773
Shuffled test METEOR score: 0.1898
Test Spec-GLEU score: 0.0157
Test Spec-METEOR score: 0.0311
Aspect precision: 0.154±0.152
Aspect recall: 0.181±0.165
Average length true captions: 287.156
Average length pred captions: 259.371

Vocabulary true captions: 3140
Vocabulary pred captions: 1457

10 most common words true:
 a: 1576, is: 1216, the: 926, and: 882, The: 832, song: 550, This: 536, in: 530, of: 496, with: 422
10 most common words pred:
 and: 1161, a: 1148, The: 927, is: 745, recording: 612, song: 577, the: 536, of: 509, quality: 460, features: 445

10 most common first two words true:
 [('The low', 115), ('This is', 113), ('This song', 46), ('A male', 33), ('The song', 31), ('This music', 24), ('This audio', 22), ('A female', 19), ('Someone is', 14), ('This clip', 10)]
10 most common first two words pred:
 [('The low', 412)




In [16]:
data_path = 'outputs/preds_gpt2_notag_noaug.json'

with open(data_path) as f:
    data = json.load(f)

print_stats(data['eval_true_captions'], data['eval_pred_captions'])

550it [00:03, 139.40it/s]
100%|██████████| 25/25 [00:31<00:00,  1.28s/it]


 Pred vs. true stats
 -------------------------------------------------- 

Test GLEU score: 0.0874
Test METEOR score: 0.2095
Shuffled test GLEU score: 0.0760
Shuffled test METEOR score: 0.1821
Test Spec-GLEU score: 0.0114
Test Spec-METEOR score: 0.0274
Aspect precision: 0.147±0.148
Aspect recall: 0.164±0.155
Average length true captions: 287.156
Average length pred captions: 256.922

Vocabulary true captions: 3140
Vocabulary pred captions: 1911

10 most common words true:
 a: 1576, is: 1216, the: 926, and: 882, The: 832, song: 550, This: 536, in: 530, of: 496, with: 422
10 most common words pred:
 a: 1257, and: 1001, is: 804, The: 770, the: 641, of: 537, recording: 469, song: 391, features: 382, quality: 352

10 most common first two words true:
 [('The low', 115), ('This is', 113), ('This song', 46), ('A male', 33), ('The song', 31), ('This music', 24), ('This audio', 22), ('A female', 19), ('Someone is', 14), ('This clip', 10)]
10 most common first two words pred:
 [('The low', 320)




In [24]:
import re

In [50]:
data_path = 'outputs/preds_gpt2_notag_chataug.json'

with open(data_path) as f:
    data = json.load(f)
    
data['tracks_ids'] = [x[1:] for x in data['tracks_ids']]

json.dump(dict(
    data
), open(data_path, 'w'))

In [42]:
lstm_data.keys()

dict_keys(['predictions', 'true_captions', 'audio_paths'])

In [26]:
data_path = 'outputs/preds_lstm_notag_noaug.json'

with open(data_path) as f:
    lstm_data = json.load(f)

# 340 is missing in lstm and also cleaning for <sos> and <eos>
print_stats(data['eval_true_captions'][:339]+data['eval_true_captions'][340:], 
            [re.sub(r'[^\w\s]','',c[6:-6]).lower() for c in lstm_data['predictions']])

549it [00:03, 144.04it/s]
100%|██████████| 10/10 [00:10<00:00,  1.01s/it]


 Pred vs. true stats
 -------------------------------------------------- 

Test GLEU score: 0.0649
Test METEOR score: 0.1638
Shuffled test GLEU score: 0.0592
Shuffled test METEOR score: 0.1518
Test Spec-GLEU score: 0.0056
Test Spec-METEOR score: 0.0120
Aspect precision: 0.152±0.152
Aspect recall: 0.162±0.145
Average length true captions: 287.148
Average length pred captions: 210.716

Vocabulary true captions: 3139
Vocabulary pred captions: 216

10 most common words true:
 a: 1574, is: 1215, the: 925, and: 879, The: 831, song: 549, This: 536, in: 529, of: 494, with: 422
10 most common words pred:
 the: 1395, and: 1231, is: 1047, a: 1012, song: 671, of: 551, guitar: 417, bass: 397, quality: 392, recording: 386

10 most common first two words true:
 [('The low', 114), ('This is', 113), ('This song', 46), ('A male', 33), ('The song', 31), ('This music', 24), ('This audio', 22), ('A female', 19), ('Someone is', 14), ('This clip', 10)]
10 most common first two words pred:
 [('the low', 284)




In [31]:
true_nounk, pred_nounk = [], []
for true, pred in zip(data['eval_true_captions'][:339]+data['eval_true_captions'][340:], 
            [c[6:-6] for c in lstm_data['predictions']]):
    if not "<unk>" in pred:
        true_nounk.append(re.sub(r'[^\w\s]','',true).lower())
        pred_nounk.append(pred)

In [35]:
len(pred_nounk)

483

In [33]:
print_stats(true_nounk, pred_nounk)

483it [00:03, 146.93it/s]
100%|██████████| 10/10 [00:07<00:00,  1.27it/s]


 Pred vs. true stats
 -------------------------------------------------- 

Test GLEU score: 0.0841
Test METEOR score: 0.1865
Shuffled test GLEU score: 0.0758
Shuffled test METEOR score: 0.1702
Test Spec-GLEU score: 0.0084
Test Spec-METEOR score: 0.0163
Aspect precision: 0.155±0.154
Aspect recall: 0.167±0.148
Average length true captions: 281.199
Average length pred captions: 210.284

Vocabulary true captions: 2083
Vocabulary pred captions: 205

10 most common words true:
 a: 1545, the: 1534, is: 1076, and: 787, song: 576, this: 573, in: 473, of: 429, with: 386, playing: 360
10 most common words pred:
 the: 1123, and: 1122, a: 888, is: 794, song: 612, of: 478, quality: 384, recording: 383, passionate: 374, bass: 368

10 most common first two words true:
 [('the low', 103), ('this is', 97), ('this song', 41), ('a male', 32), ('the song', 28), ('this music', 20), ('a female', 19), ('this audio', 17), ('someone is', 12), ('this clip', 9)]
10 most common first two words pred:
 [('the low',




## ChatAug data

In [7]:
aug_data_path = 'chataug.json'
with open(aug_data_path) as f:
    aug_data = json.load(f)

In [8]:
with open('musiccaps_split.json', 'r') as fp:
    musiccaps_split = json.load(fp)
    
aug_data = {k: v for k, v in aug_data.items() if k in musiccaps_split['train']}
N = len(aug_data)

average_len_aug = 0
word_count_aug = defaultdict(int)
low_quality_aug = 0
f2w_aug = []
top_n = 10

for aug in tqdm(aug_data.values()):
    
    # vocabulary
    spaug = aug.split()
    for w in spaug:
        word_count_aug[w] += 1
        
    f2w_aug.append(" ".join(spaug[:2]))
    
    # average length
    average_len_aug += (1./N)*len(aug)
    
most_common_aug = {k: v for k, v in sorted(word_count_aug.items(), key=lambda item: -item[1])[:top_n]}
most_common_aug_string = ", ".join([f"{key}: {value}" for key, value in most_common_aug.items()])
f2wc_aug = Counter(f2w_aug)

print(f"Average length ChatAug captions: {average_len_aug:.3f}\n")
print(f"Vocabulary ChatAug captions: {len(word_count_aug)}\n")
print(f"{top_n} most common words pred:\n {most_common_aug_string}\n")
print(f"{top_n} most common first two words ChatAug:\n {f2wc_aug.most_common(top_n)}\n")

  0%|          | 0/4417 [00:00<?, ?it/s]


AttributeError: 'list' object has no attribute 'split'