In [1]:
%%capture
#!pip install -U yt-dlp==2023.1.6 matplotlib==3.6.0 datasets[audio] rich
!pip install --upgrade git+https://github.com/huggingface/transformers.git
!pip install torch==1.13
!pip install evaluate

In [2]:
from musiccaps import load_musiccaps

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split, Subset

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

from rich import print as printr
from tqdm.auto import tqdm, trange
import matplotlib.pyplot as plt

import itertools
import math
import json
import random
from collections import defaultdict
from pathlib import Path
import evaluate
import itertools

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Append-mode handle to the run log.
# NOTE(review): nothing in the visible notebook writes to or closes `f`.
f = open('logs.txt', mode='a')

# Load musiccaps

In [3]:
def filter_muscaps_with_embeddings(ds, embeddings):
    """Drop the dataset rows that have no precomputed embedding.

    Some clips could not be downloaded and therefore were never embedded;
    those rows are removed so that captions and embeddings can be paired.

    Args:
        ds: dataset exposing a ``"ytid"`` column, ``__len__`` and ``select``.
        embeddings: mapping from ``ytid`` to that clip's embedding.

    Returns:
        The sub-dataset (original row order preserved) whose ``ytid`` is a
        key of ``embeddings``.

    Raises:
        AssertionError: if the filtered dataset and ``embeddings`` differ in
            size, i.e. some embedding has no row (or a ytid is duplicated).
    """
    # Read the id column once instead of materialising every full row.
    ytids = ds["ytid"]
    # A concrete list (not a one-shot generator) is what `select` is handed.
    keep_indices = [i for i, ytid in enumerate(ytids) if ytid in embeddings]
    filtered = ds.select(keep_indices)
    assert len(filtered) == len(embeddings), (
        f"{len(filtered)} dataset rows vs {len(embeddings)} embeddings after filtering"
    )
    return filtered

# Load the MusicCaps metadata through the project-local `load_musiccaps` helper.
# NOTE(review): `return_without_audio=True` presumably skips loading the audio
# itself (only captions/ids are used below) - confirm against the helper.
ds = load_musiccaps(
    "./music_data",
    sampling_rate=16000,
    limit=None,
    num_proc=8,
    writer_batch_size=1000,
    return_without_audio=True,
)
# `embeddings.npy` stores a pickled mapping ytid -> embedding; `.item()` unwraps
# the 0-d object array back into that mapping.
# SECURITY: allow_pickle=True executes arbitrary pickled code - only load a
# file you produced yourself or otherwise trust.
embeddings = np.load("embeddings.npy", allow_pickle=True).item()

Using custom data configuration google--MusicCaps-bedc2a0fd7888f2f
Reusing dataset csv (/root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a)


In [4]:
from collections import Counter

# Build a vocabulary of music-related words to use for evaluation.
# `aspect_list` holds the textual form of a list, so the bracket and quote
# characters are removed before the string is split on ", ".
_list_syntax_chars = str.maketrans("", "", "[]\"'")
aspects = []
for row in ds:
    cleaned = row["aspect_list"].translate(_list_syntax_chars)
    aspects.extend(cleaned.split(", "))

# Keep only the aspects that occur reasonably often (at least 25 times).
aspect_counts = Counter(aspects)
aspects = {tag for tag, n_seen in aspect_counts.most_common() if n_seen >= 25}
len(aspects)

378

In [5]:
class CaptionEmbedding(Dataset):
    """Torch ``Dataset`` of paired ``(caption, embedding)`` items.

    Items are ordered by ``ytid``. Each stored embedding has 512 entries; the
    item returned to the caller is the element-wise mean of its two
    256-entry halves.
    """

    def __init__(self, muscaps_ds, embeddings):
        """Pair every caption with the embedding of the same clip.

        Args:
            muscaps_ds: dataset with ``"ytid"`` and ``"caption"`` columns.
            embeddings: mapping ``ytid -> numpy embedding``.
        """
        filtered = filter_muscaps_with_embeddings(muscaps_ds, embeddings)
        sorted_ds = filtered.sort("ytid")
        self.captions = sorted_ds["caption"]
        # Fetch each embedding by the ytid of the row it belongs to, instead of
        # sorting the dict separately and trusting both sorts to agree: this
        # guarantees that caption i and embedding i describe the same clip.
        paired_embs = [embeddings[ytid] for ytid in sorted_ds["ytid"]]
        self.embeddings = torch.from_numpy(np.stack(paired_embs)).to(device)

    def __len__(self):
        """Number of (caption, embedding) pairs."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(caption, embedding)`` where the embedding has 256 entries."""
        emb = self.embeddings[idx]
        assert len(emb) == 512
        # Collapse the two 256-wide halves into one vector by averaging them.
        emb = (emb[:256] + emb[256:]) / 2

        return self.captions[idx], emb

In [6]:
with open('musiccaps_split.json', 'r') as fp:
    musiccaps_split = json.load(fp)

train_ytids, valid_ytids, test_ytids = musiccaps_split['train'], musiccaps_split['valid'], musiccaps_split['test']


def _select_split(ytids):
    """Return ``(sub_dataset, sub_embeddings)`` restricted to the given ytids."""
    # Set membership is O(1); testing against the raw JSON list scans the
    # whole list for every dataset row and every embedding key.
    ytid_set = set(ytids)
    split_ds = ds.filter(lambda x: x['ytid'] in ytid_set)
    split_embeddings = {ytid: e for ytid, e in embeddings.items() if ytid in ytid_set}
    return split_ds, split_embeddings


train_ds, train_embeddings = _select_split(train_ytids)
valid_ds, valid_embeddings = _select_split(valid_ytids)
test_ds, test_embeddings = _select_split(test_ytids)

training_data = CaptionEmbedding(muscaps_ds=train_ds, embeddings=train_embeddings)
valid_data = CaptionEmbedding(muscaps_ds=valid_ds, embeddings=valid_embeddings)
test_data = CaptionEmbedding(muscaps_ds=test_ds, embeddings=test_embeddings)
    

Loading cached processed dataset at /root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-65e36bf881a776b6.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-b43b658bbefe199e.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-f4182c0212b9d53a.arrow


# Define model

In [7]:
class ResidualLinear(nn.Module):
    """Residual layer computing ``x + Linear(ReLU(x))`` at a fixed width ``dim``."""

    def __init__(self, dim):
        super().__init__()
        # Square projection so its output can be added back onto the input.
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        activated = nn.functional.relu(x)
        residual = self.fc(activated)
        return x + residual
    
def B2T(in_dim=256, hidden_dim=768, dropouts=(0.6, 0.4)):
    """Build the network that projects a music embedding into token-embedding space.

    Layout: ``Linear(in_dim, hidden_dim)``, a ``ResidualLinear``, then one
    ``Dropout(p)`` + ``ResidualLinear`` pair per entry of ``dropouts``. With
    the defaults this is exactly the original six-layer stack, so existing
    checkpoints (state-dict keys ``0``, ``1``, ``3``, ``5``) still load.

    Args:
        in_dim: width of the incoming music embedding.
        hidden_dim: output width; must equal the width of the language
            model's token embeddings, since the result replaces one of them.
        dropouts: dropout probability placed before each additional
            residual layer.

    Returns:
        An ``nn.Sequential`` mapping ``(..., in_dim)`` to ``(..., hidden_dim)``.
    """
    layers = [nn.Linear(in_dim, hidden_dim), ResidualLinear(hidden_dim)]
    for p in dropouts:
        layers.append(nn.Dropout(p))
        layers.append(ResidualLinear(hidden_dim))
    return nn.Sequential(*layers)

In [8]:
def tokenize(captions_batch):
    """Turn a batch of captions into padded model inputs and loss targets.

    Every caption becomes ``[placeholder_id, *caption_tokens, eos_id]``; the
    placeholder marks the slot that later receives the music embedding.

    Returns:
        ``(input_ids, input_ids_target)``, both on ``device``. The target is
        padded with ``mask_id``; the input is identical except that the
        padding positions hold ``eos_id``.
    """
    token_lists = tokenizer(captions_batch)["input_ids"]

    # Prepend the music-embedding placeholder and terminate with eos.
    sequences = []
    for tokens in token_lists:
        sequences.append(torch.tensor([placeholder_id, *tokens, eos_id]))

    # Pad with mask_id (-100): the cross-entropy loss ignores that index.
    targets = torch.nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=mask_id
    ).to(device)

    # -100 cannot be looked up in the token-embedding table, so the model input
    # uses another token at the padded positions (which one should not matter).
    model_inputs = targets.masked_fill(targets == mask_id, eos_id)

    return model_inputs, targets


def transform_input_ids(music_embedding, input_ids, input_ids_target):
    """Swap the placeholder slot for the projected music embedding.

    Both id tensors are modified in place: their placeholder column is
    overwritten with ``eos_id``. The ids are then embedded with the language
    model's token-embedding table and the placeholder position of the result
    is replaced by ``b2t(music_embedding)``.

    Returns:
        ``(inputs_embeds, input_ids_target)`` - the embedded inputs and the
        (mutated) target tensor that was passed in.
    """
    slot = 0  # 1 if using <bos>, otherwise 0
    for ids in (input_ids, input_ids_target):
        assert (ids[:, slot] == placeholder_id).all()

    # The target at the music slot is eos (mask_id would be the alternative).
    input_ids_target[:, slot] = eos_id
    # Any valid id works here; it only has to survive the embedding lookup.
    input_ids[:, slot] = eos_id
    token_embeds = model.transformer.wte(input_ids)

    # Overwrite the dummy token embedding with the projected music embedding.
    token_embeds[:, slot] = b2t(music_embedding)

    return token_embeds, input_ids_target


def strip_eos(pred):
    """
    Remove eos markers from predicted captions.

    One leading "<|endoftext|>" is dropped, then each caption is cut at the
    next "<|endoftext|>" (if any), discarding the marker and all text after it.
    (the hf decoder can only skip eos tokens, it cannot stop at the first one)
    """
    eos_text = "<|endoftext|>"
    cleaned = []
    for text in pred:
        if text.startswith(eos_text):
            text = text[len(eos_text):]
        head, _, _ = text.partition(eos_text)
        cleaned.append(head)
    return cleaned


def eval(caption_batch, embedding_batch, rm_eos=False, **kwargs):
    """Generate one caption per music embedding.

    Args:
        caption_batch: list of reference captions; used only to build id
            tensors of the right batch size - the caption text itself is
            never shown to the model.
        embedding_batch: batch of music embeddings fed to ``b2t``.
        rm_eos: when true, pass the decoded text through ``strip_eos``.
        **kwargs: forwarded unchanged to ``model.generate``.

    Returns:
        List of decoded captions (newlines removed, surrounding whitespace
        stripped).

    Note:
        The language model is left in eval mode, as before. ``b2t`` is put in
        eval mode only for the duration of the call and then restored.
    """
    model.eval()
    # `b2t` contains dropout layers; without eval mode they would randomly zero
    # parts of the music embedding while generating.
    b2t_was_training = b2t.training
    b2t.eval()
    try:
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            input_ids, input_ids_target = tokenize(caption_batch)
            inputs_embeds, _ = transform_input_ids(embedding_batch, input_ids, input_ids_target)

            # only include <bos> (optional) and music_embedding, don't include true caption
            inputs_embeds = inputs_embeds[:, :1]
            output_ids = model.generate(inputs_embeds=inputs_embeds, **kwargs)
    finally:
        # Hand `b2t` back in the mode the caller had it in.
        b2t.train(b2t_was_training)
    pred = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
    pred = [p.replace("\n", "").strip() for p in pred]
    return strip_eos(pred) if rm_eos else pred

In [9]:
# Pretrained language model that produces the captions.
model_name = 'gpt2' # gpt2, gpt2-medium, gpt2-large, gpt2-xl
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.config.is_decoder = True # not sure if necessary

mask_id = -100 # don't change, this is fixed in torch cross-entropy loss!
eos_id = tokenizer.eos_token_id
# Sentinel marking the position that receives the music embedding; it is
# replaced before any embedding lookup, so it never has to be a valid token id.
placeholder_id = -200
# Follow `device` instead of hard-coding .cuda(), which raises on CPU-only
# machines even though `device` already falls back to the CPU.
b2t = B2T().to(device)

# Generate eval captions

In [10]:
# Checkpoint read with torch.load in the generation cell; it is expected to
# hold the state dicts under the keys 'model' and 'b2t'.
MODEL_PATH = "saved-models/best_notag_noaug.pt"

In [11]:
# One example per batch, in dataset order, so predictions line up with references.
eval_train_dataloader = DataLoader(training_data, batch_size=1, shuffle=False)
eval_valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=False)
eval_test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)

# Keyword arguments for `eval`: `rm_eos` is consumed by `eval` itself, every
# other entry is forwarded to `model.generate`.
generation_params = dict(
    max_new_tokens=64,
    num_beams=4,
    do_sample=True,
    temperature=0.9,
    bos_token_id=eos_id,
    eos_token_id=eos_id,
    # Pad with eos rather than mask_id: -100 is only the loss ignore index, not
    # a token id, so it can be neither embedded nor decoded. Anything after the
    # first eos is discarded by strip_eos anyway.
    pad_token_id=eos_id,
    early_stopping=True,
    rm_eos=True,
)

meteor = evaluate.load('meteor')
google_bleu = evaluate.load("google_bleu")

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [20]:
# Peek at the first ten reference captions of the test split.
for reference_caption in test_ds['caption'][:10]:
    print(f"\n {reference_caption}")


 This song features an electric guitar as the main instrument. The guitar plays a descending run in the beginning then plays an arpeggiated chord followed by a double stop hammer on to a higher note and a descending slide followed by a descending chord run. The percussion plays a simple beat using rim shots. The percussion plays in common time. The bass plays only one note on the first count of each bar. The piano plays backing chords. There are no voices in this song. The mood of this song is relaxing. This song can be played in a coffee shop.

 low fidelity audio from a live performance featuring a solo direct input acoustic guitar strumming airy, suspended open chords. Also present are occasional ambient sounds, perhaps papers being shuffled.

 This middle eastern folk song features a male voice. This is accompanied by a string instrument called the oud playing the melody in between lines. A variety of middle-eastern percussion instruments are played in the background. A tambourine

In [None]:
eval_true_captions, eval_pred_captions = [], []

# map_location keeps a checkpoint that was saved on a GPU loadable when
# `device` has fallen back to the CPU.
data_dict = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(data_dict['model'])
b2t.load_state_dict(data_dict['b2t'])

# Generation samples (do_sample=True): fix the seed so the saved predictions
# and the metrics computed from them are reproducible.
torch.manual_seed(0)

for i, (caption_batch, embedding_batch) in enumerate(tqdm(eval_test_dataloader)):
    pred = eval(list(caption_batch), embedding_batch, **generation_params)
    # extend (rather than taking element 0) keeps references and predictions
    # aligned even if the dataloader's batch size is ever raised above 1.
    eval_true_captions.extend(caption_batch)
    eval_pred_captions.extend(pred)
    if i < 10:
        print("\n", pred)

  0%|          | 0/550 [00:00<?, ?it/s]


 ['This is an instrumental rock music piece. The piece is being performed on a clean sounding electric guitar. There is an easygoing guitar solo being played. The atmosphere is groovy. This piece could be used as an advertisement jingle.']

 ['This song features an electric guitar playing arpeggiated chords. There are open strings being played in the arpeggios. A chorus effect is added to the guitar sound. At the end of the song, one guitar chord is played. There are no voices in this song. This is an instrumental song. There']

 ['The low quality recording features a traditional song that consists of passionate female vocal singing over wooden percussive elements, strings and flute melody. It sounds passionate, emotional and soulful and the recording is noisy.']

 ['The low quality recording features a lullaby sung by passionate male vocalists over an arpeggiated acoustic guitar melody. It sounds mellow, soft, sad and so emotional that you forget about how noisy the recording actuall

In [None]:
# Persist references and predictions side by side for later inspection.
# Create the target directory if needed, and use a context manager so the file
# is flushed and closed (the bare open() call leaked the handle).
Path('outputs').mkdir(parents=True, exist_ok=True)
with open('outputs/preds.json', 'w') as fp:
    json.dump(dict(
        eval_true_captions=eval_true_captions,
        eval_pred_captions=eval_pred_captions
    ), fp)

In [None]:
# Score the generated captions against their references with both metrics.
scoring_inputs = dict(predictions=eval_pred_captions, references=eval_true_captions)
gleu_score = google_bleu.compute(**scoring_inputs)['google_bleu']
meteor_score = meteor.compute(**scoring_inputs)['meteor']

gleu_score, meteor_score

In [None]:
# Chance baseline: score the predictions against randomly permuted references.
# The per-trial differences are collected, because a bare expression inside a
# loop is neither displayed nor stored and every result would be lost.
shuffled_deltas = []
for _ in range(50):
    eval_true_shuffled = random.sample(eval_true_captions, k=len(eval_true_captions))

    shuffled_gleu_score = google_bleu.compute(predictions=eval_pred_captions, references=eval_true_shuffled)['google_bleu']
    shuffled_meteor_score = meteor.compute(predictions=eval_pred_captions, references=eval_true_shuffled)['meteor']

    shuffled_deltas.append((shuffled_gleu_score - gleu_score, shuffled_meteor_score - meteor_score))

# Mean (gleu, meteor) change relative to the correctly paired scores; clearly
# negative values mean the captions carry clip-specific information.
tuple(np.mean(shuffled_deltas, axis=0))