# Preds analysis

In [2]:
!pip install datasets
!pip install evaluate

Collecting evaluate
  Using cached evaluate-0.4.0-py3-none-any.whl (81 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.0


In [21]:
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting absl-py
  Using cached absl_py-1.4.0-py3-none-any.whl (126 kB)
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25ldone
[?25h  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24936 sha256=916455fe14c84c8b5b5217b4edf2041c22209e93f15ac491a1a3e5763bb2c049
  Stored in directory: /Users/corinacaraconcea/Library/Caches/pip/wheels/b0/3f/ac/cc3bc304f50c77ef38d79d8e4e2684313de39af543cb4eb3da
Successfully built rouge_score
Installing collected packages: absl-py, rouge_score
Successfully installed absl-py-1.4.0 rouge_score-0.1.2


In [107]:
import json
import evaluate
from collections import defaultdict, Counter
from tqdm import tqdm as tqdm
from musiccaps import load_musiccaps
import string
import numpy as np
import random
import pandas as pd
import re
import string
import itertools

## Helper functions

In [142]:
meteor = evaluate.load('meteor')
google_bleu = evaluate.load('google_bleu')
rouge = evaluate.load('rouge')

ds = load_musiccaps(
    "./music_data",
    sampling_rate=16000,
    limit=None,
    num_proc=8,
    writer_batch_size=1000,
    return_without_audio=True,
)

def clean_text_for_aspect_metrics(caption):
    table = str.maketrans('','', string.punctuation)
    caption.replace("-"," ")
    # split the sentences into words
    desc = caption.split()
    #converts to lower case
    desc = [word.lower() for word in desc]
    #remove punctuation from each token
    desc = [word.translate(table) for word in desc]
    #remove hanging 's and a 
    #desc = [word for word in desc if(len(word)>1)]
    #remove tokens with numbers in them
    #desc = [word for word in desc if(word.isalpha())]
    #convert back to string
    caption = ' '.join(desc)
    return caption

def preprocessing_remove_unk(text_input):
    # remove punctuations
    desc = re.sub(r'[^\w\s]',' ',text_input)
    table = str.maketrans('','',string.punctuation)

    # turn uppercase letters into lowercase ones
    desc = text_input.lower()

    # split into words
    desc = desc.split(' ')
    
    try: 
        # remove <unk> tokens
        desc.remove("<unk>")
    except ValueError:
        desc = desc
    try: 
        # remove <unk> tokens
        desc.remove("unk")
    except ValueError:
        desc = desc
        
    # remove the punctuations
    text_no_punctuation = [word.translate(table) for word in desc]

    # join the caption words
    caption = ' '.join(desc)
    
    return caption

# get a list of music-related words to use for evaluation
aspects = set()
for x in ds:
    aspect_str = x["aspect_list"]
    for t in "[]\"'":
        aspect_str = aspect_str.replace(t, "")
    aspects.update(aspect_str.split(", "))
# clean aspects
aspects = {clean_text_for_aspect_metrics(a) for a in aspects if len(a) > 2}
    
def wrap_in_space(s):
    return ' ' + s + ' '
    
# filter
all_captions = clean_text_for_aspect_metrics(' '.join(ds[i]['caption'] for i in range(len(ds))))
aspect_counts = {a: all_captions.count(wrap_in_space(a)) for a in aspects}
aspects = {a for a in aspects if aspect_counts[a] > 10}
aspects -= {'the'}

def compute_aspects_metric(true, pred):
    true = wrap_in_space(clean_text_for_aspect_metrics(true))
    pred = wrap_in_space(clean_text_for_aspect_metrics(pred))
    
    aspects_in_true = {a for a in aspects if wrap_in_space(a) in true}
    aspects_in_pred = {a for a in aspects if wrap_in_space(a) in pred}
    
    #print(aspects_in_true)
    #print(aspects_in_pred)
    
    precision = len(aspects_in_true&aspects_in_pred)/np.maximum(len(aspects_in_pred),1)
    recall = len(aspects_in_true&aspects_in_pred)/np.maximum(len(aspects_in_true), 1)
    
    return precision, recall

[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/corinacaraconcea/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/corinacaraconcea/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /Users/corinacaraconcea/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [143]:
def import_outputs(data_path):
    
    with open(data_path) as f:
        data = json.load(f)

    # multiple references for one captions
    true_captions = data['eval_true_captions']
    if isinstance(true_captions, list) and any(isinstance(item, list) for item in true_captions):
        true_captions = list(itertools.chain(*true_captions))
    # single prediction
    predicted_captions = data['eval_pred_captions']
    if isinstance(predicted_captions, list) and any(isinstance(item, list) for item in predicted_captions):
        predicted_captions = list(itertools.chain(*predicted_captions))


    return true_captions,predicted_captions,

In [144]:
def compute_metrics(true_captions,predicted_captions):
    true_clean = []
    pred_clean = []
    for i, (true, pred) in tqdm(enumerate(zip(true_captions, predicted_captions))):
        # preprocess captions and predictions to remove punctuations and <unk> tokens
        true = preprocessing_remove_unk(true)
        true_clean.append(true)
        pred = preprocessing_remove_unk(pred)
        pred_clean.append(pred)
        with open('clean_pred.txt','w') as f:
            for i in range(len(pred_clean)):
                f.write(pred_clean[i]+'\n')
        with open('clean_caption.txt','w') as f:
            for i in range(len(true_clean)):
                f.write(true_clean[i]+'\n')

    total_google_bleu = google_bleu.compute(predictions = pred_clean,references = true_clean)
    total_rouge = rouge.compute(predictions = pred_clean,references = true_clean)
    total_meteor = meteor.compute(predictions = pred_clean,references = true_clean)
    
    return total_google_bleu, total_rouge, total_meteor


In [145]:
def print_metrics(data_paths,methods):
    for i, (data_path, method) in enumerate(zip(data_paths,methods)):
       true_captions,predicted_captions = import_outputs(data_path)
       google_bleu_score, rouge_score, meteor_score = compute_metrics(true_captions,predicted_captions)

       print(method,"test google BLEU score:",str(np.round(google_bleu_score["google_bleu"],3)))
       print(method,"test google ROUGE score:",str(np.round(rouge_score["rouge1"],3)))
       print(method,"test google METEOR score:",str(np.round(meteor_score["meteor"],3)))

In [148]:
data_paths = ["outputs/preds_lstm_attn_summaries.json",
              "outputs/preds_lstm_no_attn_summaries.json",
              "outputs/preds_lstm_attn_no_tag_no_aug.json",
              "outputs/preds_lstm_no_attn_no_tag_no_aug.json",
              "outputs/preds_gpt2_summarized_notag_noaug.json",
              "outputs/preds_gpt2_summarized.json",
              "outputs/preds_gpt2_notag_chataug.json",
              "outputs/preds_gpt2_notag_noaug.json"]

methods = ["LSTM with attention and summarized dataset",
           "LSTM without attention and summarized dataset",
           "LSTM with attention no ChatAug",
           "LSTM no attention no ChatAug",
           "GPT-2 no tag no aug summarized dataset",
           "GPT-2 ?? summarized",
           "GPT-2 no tag ChatAug",
           "GPT-2 no tag no ChatAug"]

In [149]:
print_metrics(data_paths,methods)

1647it [00:00, 1810.24it/s]


LSTM with attention and summarized dataset test google BLEU score: 0.089
LSTM with attention and summarized dataset google ROUGE score: 0.305
LSTM with attention and summarized dataset google METEOR score: 0.205


1647it [00:00, 1895.88it/s]


LSTM without attention and summarized dataset test google BLEU score: 0.092
LSTM without attention and summarized dataset google ROUGE score: 0.312
LSTM without attention and summarized dataset google METEOR score: 0.207


549it [00:00, 3212.08it/s]


LSTM with attention no ChatAug test google BLEU score: 0.083
LSTM with attention no ChatAug google ROUGE score: 0.275
LSTM with attention no ChatAug google METEOR score: 0.183


549it [00:00, 3139.35it/s]


LSTM no attention no ChatAug test google BLEU score: 0.084
LSTM no attention no ChatAug google ROUGE score: 0.276
LSTM no attention no ChatAug google METEOR score: 0.185


550it [00:00, 3473.61it/s]


GPT-2 no tag no aug summarized dataset test google BLEU score: 0.087
GPT-2 no tag no aug summarized dataset google ROUGE score: 0.267
GPT-2 no tag no aug summarized dataset google METEOR score: 0.205


550it [00:00, 3365.30it/s]


GPT-2 ?? summarized test google BLEU score: 0.086
GPT-2 ?? summarized google ROUGE score: 0.267
GPT-2 ?? summarized google METEOR score: 0.202


550it [00:00, 3072.73it/s]


GPT-2 no tag ChatAug test google BLEU score: 0.097
GPT-2 no tag ChatAug google ROUGE score: 0.282
GPT-2 no tag ChatAug google METEOR score: 0.221


550it [00:00, 3019.70it/s]


GPT-2 no tag no ChatAug test google BLEU score: 0.091
GPT-2 no tag no ChatAug google ROUGE score: 0.273
GPT-2 no tag no ChatAug google METEOR score: 0.21
