**Relevant huggingface gpt2 code**

- https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
- https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
- https://github.com/huggingface/transformers/issues/6535
- bos/eos discussion: https://github.com/huggingface/transformers/issues/3311

**Some options for our main model**

- different gpt2 sizes
- gpt2 self-attention vs gpt2 cross-attention (as used for image captioning)
- which gpt2 layers to finetune?
- first pretrain on labels, then captions? or at the same time with different prompt/`<bos>` token?
- make b2t output several 768-dimensional vectors (instead of a single one) for gpt2's self-attention to attend to

In [1]:
#!pip install -U yt-dlp==2023.1.6 matplotlib==3.6.0 datasets[audio] rich

# install newest transformers build to be able to pass `inputs_embeds` through generate()
!pip install --upgrade git+https://github.com/huggingface/transformers.git

Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-vxl0pxdr
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-vxl0pxdr
  Resolved https://github.com/huggingface/transformers.git to commit a9bd5df16a46356463f2712dd8f6c109fa83d6f9
  Installing build dependencies ... [?25l\^C
[?25canceled
[31mERROR: Operation cancelled by user[0m[31m
[0m

In [2]:
!pip install torch==1.13

Collecting torch==1.13
^C
[31mERROR: Operation cancelled by user[0m[31m
[0m

In [3]:
from musiccaps import load_musiccaps

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split, Subset

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

from rich import print as printr
from tqdm.auto import tqdm, trange
import matplotlib.pyplot as plt

import itertools
import math
import json
import random
from collections import defaultdict
from pathlib import Path

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this append-mode handle is never written to or closed in the cells
# visible here (logging goes through `wlog`, which opens its own file), and the name
# `f` is rebound later by a `with open(...) as f` block. Looks like a leftover;
# confirm nothing else uses it before removing.
f = open('logs.txt','a')

In [4]:
torch.__version__

'1.12.1+cu116'

# Load musiccaps

In [5]:
def filter_muscaps_with_embeddings(ds, embeddings):
    """Drop dataset rows that have no precomputed embedding.

    Some clips weren't downloaded so we couldn't embed them; their rows are
    removed so that captions and embeddings can be paired one-to-one.

    Args:
        ds: dataset supporting ``len()``, integer indexing to a row with a
            ``"ytid"`` field, and ``select(indices)``.
        embeddings: mapping keyed by ytid.

    Returns:
        ``ds.select(...)`` restricted to the rows whose ytid is a key of
        ``embeddings``.

    Raises:
        AssertionError: if the number of kept rows differs from
            ``len(embeddings)`` (e.g. an embedding without a matching row).
    """
    # Single pass over the rows; membership is checked against the mapping
    # directly instead of rebuilding an exclusion set for every index.
    keep_indices = [i for i in range(len(ds)) if ds[i]["ytid"] in embeddings]
    ds = ds.select(keep_indices)
    assert len(ds) == len(embeddings), (
        f"kept {len(ds)} rows but got {len(embeddings)} embeddings"
    )
    return ds


In [6]:
ls /datasets/beat2tweet/

embeddings.npy     embeddings_75.npy  tag_scores_51.npy  tags_26.json
embeddings_00.npy  embeddings_76.npy  tag_scores_52.npy  tags_27.json
embeddings_01.npy  embeddings_77.npy  tag_scores_53.npy  tags_28.json
embeddings_02.npy  embeddings_78.npy  tag_scores_54.npy  tags_29.json
embeddings_03.npy  embeddings_79.npy  tag_scores_55.npy  tags_30.json
embeddings_04.npy  embeddings_80.npy  tag_scores_56.npy  tags_31.json
embeddings_05.npy  embeddings_81.npy  tag_scores_57.npy  tags_32.json
embeddings_06.npy  embeddings_82.npy  tag_scores_58.npy  tags_33.json
embeddings_07.npy  embeddings_83.npy  tag_scores_59.npy  tags_34.json
embeddings_08.npy  embeddings_84.npy  tag_scores_60.npy  tags_36.json
embeddings_09.npy  embeddings_85.npy  tag_scores_61.npy  tags_37.json
embeddings_10.npy  embeddings_86.npy  tag_scores_62.npy  tags_38.json
embeddings_11.npy  embeddings_87.npy  tag_scores_63.npy  tags_39.json
embeddings_12.npy  embeddings_88.npy  tag_scores_64.npy  tags_40.json
embeddings_13.npy  e

In [7]:
# Load the MusicCaps dataset via the project helper.
# NOTE(review): `return_without_audio=True` presumably skips returning the audio
# column (only captions/metadata are used below) — confirm in `musiccaps.py`.
ds = load_musiccaps(
    "./music_data",
    sampling_rate=16000,
    limit=None,
    num_proc=8,
    writer_batch_size=1000,
    return_without_audio=True,
)
# The .npy file stores a pickled Python object; `.item()` unwraps it. It is used
# below as a mapping keyed by ytid.
# NOTE(review): `allow_pickle=True` executes arbitrary pickle payloads — only load
# files from a trusted source.
embeddings = np.load("embeddings.npy", allow_pickle=True).item()

Using custom data configuration google--MusicCaps-bedc2a0fd7888f2f


Downloading and preparing dataset csv/google--MusicCaps to /root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/2.94M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

0 tables [00:00, ? tables/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/google___csv/google--MusicCaps-bedc2a0fd7888f2f/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a. Subsequent calls will reuse this data.


  return pd.read_csv(xopen(filepath_or_buffer, "rb", use_auth_token=use_auth_token), **kwargs)


In [10]:
# get a list of music-related words to use for evaluation
aspects = []
for x in ds:
    aspect_str = x["aspect_list"]
    for t in "[]\"'":
        aspect_str = aspect_str.replace(t, "")
    aspects.extend(aspect_str.split(", "))

from collections import Counter

# only pick aspects that show up somewhat frequently
aspects = {s for s, count in Counter(aspects).most_common() if count >= 25}
len(aspects)

378

In [14]:
class CaptionEmbedding(Dataset):
    """Torch Dataset yielding ``(caption, embedding)`` pairs.

    Captions and embeddings are both ordered by ytid so that index ``i`` refers
    to the same clip in both. Every stored embedding must have length 512; the
    item returned is the mean of its two 256-wide halves.
    """

    def __init__(self, muscaps_ds, embeddings):
        # Rows without an embedding are dropped first so the two lists line up.
        paired_ds = filter_muscaps_with_embeddings(muscaps_ds, embeddings)
        self.captions = paired_ds.sort("ytid")["caption"]

        ordered = [embeddings[ytid] for ytid in sorted(embeddings)]
        stacked = np.stack(ordered)
        self.embeddings = torch.from_numpy(stacked).to(device)

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        full = self.embeddings[idx]
        assert len(full) == 512
        first_half, second_half = full[:256], full[256:]
        return self.captions[idx], (first_half + second_half) / 2

# Load Jamendo tag/embedding dataset

In [15]:
# The 50 Jamendo tags; the position in this array is the column index of the
# corresponding score in a tag-score matrix.
JAMENDO_TAGS = np.array([
    # genre
    'genre---alternative', 'genre---ambient', 'genre---atmospheric', 'genre---chillout',
    'genre---classical', 'genre---dance', 'genre---downtempo', 'genre---easylistening',
    'genre---electronic', 'genre---experimental', 'genre---folk', 'genre---funk',
    'genre---hiphop', 'genre---house', 'genre---indie', 'genre---instrumentalpop',
    'genre---jazz', 'genre---lounge', 'genre---metal', 'genre---newage',
    'genre---orchestral', 'genre---pop', 'genre---popfolk', 'genre---poprock',
    'genre---reggae', 'genre---rock', 'genre---soundtrack', 'genre---techno',
    'genre---trance', 'genre---triphop', 'genre---world',
    # instrument
    'instrument---acousticguitar', 'instrument---bass', 'instrument---computer',
    'instrument---drummachine', 'instrument---drums', 'instrument---electricguitar',
    'instrument---electricpiano', 'instrument---guitar', 'instrument---keyboard',
    'instrument---piano', 'instrument---strings', 'instrument---synthesizer',
    'instrument---violin', 'instrument---voice',
    # mood/theme
    'mood/theme---emotional', 'mood/theme---energetic', 'mood/theme---film',
    'mood/theme---happy', 'mood/theme---relaxing',
])


def get_top_tags(scores, k=3, threshold=.4):
    """Return up to ``k`` tag names whose averaged score exceeds ``threshold``.

    Args:
        scores: array of shape ``(2, 50)``; the two rows are averaged.
        k: maximum number of tags to return.
        threshold: a tag is kept only if its averaged score is strictly greater.

    Returns:
        Array of tag names from ``JAMENDO_TAGS``, highest score first
        (possibly empty).
    """
    assert scores.shape == (2, 50)
    mean_scores = (scores[0] + scores[1]) / 2
    candidates = np.flatnonzero(mean_scores > threshold)
    # Rank the surviving candidates by descending score and keep the best k.
    best_first = np.argsort(-mean_scores[candidates])[:k]
    return JAMENDO_TAGS[candidates[best_first]]

In [13]:
# Per-clip dictionaries, filled shard by shard below.
jam_tags = {}        # key -> annotated tags (from tags_XX.json)
jam_pred_tags = {}   # key -> tags selected from the predicted scores by get_top_tags
jam_embeddings = {}  # key -> embedding (from embeddings_XX.npy)
jam_scores = {}      # key -> tag scores (from tag_scores_XX.npy)

jam_embeddings_dir = Path("jam_embeddings/")

for i in range(100):
    tags_path = jam_embeddings_dir / f"tags_{i:02d}.json"
    try:
        # Use a dedicated name for the handle so the module-level `f` is not clobbered.
        with open(tags_path) as tags_file:
            jam_tags.update(json.load(tags_file))
    except FileNotFoundError as e:
        # Some shard indices do not exist; report and skip them. Any other
        # failure (e.g. a corrupt JSON file) is raised instead of being skipped.
        print(e)
        continue
    # NOTE(review): allow_pickle=True executes pickle payloads — trusted files only.
    data_dict = np.load(jam_embeddings_dir / f"embeddings_{i:02d}.npy", allow_pickle=True)
    jam_embeddings.update(data_dict.item())
    data_dict = np.load(jam_embeddings_dir / f"tag_scores_{i:02d}.npy", allow_pickle=True)
    jam_scores.update(data_dict.item())

# Turn the raw scores into (at most 3) predicted tags per clip.
for clip_key, clip_scores in jam_scores.items():
    jam_pred_tags[clip_key] = get_top_tags(clip_scores, k=3, threshold=.3)

pred_tag_counts = np.array([len(v) for v in jam_pred_tags.values()])
print(f"avg number of pred tags = {(pred_tag_counts).mean()}, fraction of samples with 0 pred tags = {(pred_tag_counts == 0).mean()}")

[Errno 2] No such file or directory: 'jam_embeddings/tags_35.json'
avg number of pred tags = 1.7163339214411737, fraction of samples with 0 pred tags = 0.06175875197325657


In [13]:
# Same statistics as for the predicted tags, this time for the annotated Jamendo tags.
# (The variable name `pred_tag_counts` is reused here and now holds annotated-tag counts.)
pred_tag_counts = np.array(list(map(len, jam_tags.values())))
print(f'avg number of jamendo tags = {pred_tag_counts.mean()}, fraction of samples with 0 jamendo tags = {(pred_tag_counts == 0).mean()}')

avg number of jamendo tags = 4.659095552047543, fraction of samples with 0 jamendo tags = 0.0


In [14]:
class JamendoTagDataset(Dataset):
    """Torch Dataset yielding ``(tag_caption, embedding)`` pairs for Jamendo clips.

    The caption is built from the union of a clip's annotated and predicted
    tags, grouped by category, e.g.
    ``"genre: pop, rock; instrument: piano; mood/theme: happy"``.

    Args:
        key_subset: sequence of clip keys this dataset covers.
        jam_tags: mapping key -> iterable of ``"category---name"`` tags.
        jam_pred_tags: mapping key -> iterable of predicted tags in the same
            format (may be a NumPy array, as produced by ``get_top_tags``).
            Keys without predictions are treated as having none.
        jam_embeddings: mapping key -> array of shape ``(2, 256)``.
        shuffle_order: if true, tag names within a category are put in random
            order on every access; otherwise they are sorted.
    """

    def __init__(self, key_subset, jam_tags, jam_pred_tags, jam_embeddings, shuffle_order):
        self.keys = key_subset
        self.jam_tags = jam_tags
        self.jam_pred_tags = jam_pred_tags
        self.jam_embeddings = jam_embeddings
        self.shuffle_order = shuffle_order

    def __len__(self):
        return len(self.keys)

    def _tag_caption(self, key):
        """Build the ``"category: a, b; category: c"`` caption for one clip."""
        # Union via sets of plain strings. The previous `list + array` expression
        # does not concatenate when the predicted tags are a NumPy array (it
        # raises or broadcasts element-wise).
        annotated = set(map(str, self.jam_tags[key]))
        predicted = set(map(str, self.jam_pred_tags.get(key, ())))

        categories = defaultdict(set)
        for tag in annotated | predicted:
            assert '---' in tag
            category, _, name = tag.partition('---')
            categories[category].add(name)

        parts = []
        for category in sorted(categories):
            # Sort first so that shuffling starts from a well-defined order and
            # is reproducible under a seeded `random`.
            names = sorted(categories[category])
            if self.shuffle_order:
                names = random.sample(names, len(names))
            parts.append(category + ': ' + ', '.join(names))
        return '; '.join(parts)

    def __getitem__(self, idx):
        key = self.keys[idx]
        tags_cap = self._tag_caption(key)

        emb = self.jam_embeddings[key]
        assert emb.shape == (2, 256)
        emb = (emb[0] + emb[1]) / 2  # average the two rows into one 256-d vector

        return tags_cap, torch.from_numpy(emb).to(device)

# Select dataset

In [15]:
use_jam_tag_dataset = True

if use_jam_tag_dataset:
    train_frac = 0.9
    # Fixed seed + sorted song ids: the train/test membership is identical after
    # every kernel restart, so a checkpoint evaluated in a later session is never
    # tested on songs it was trained on.
    split_seed = 0

    keys_with_many_tags = [k for k, v in jam_tags.items() if len(v)>3]
    all_keys = keys_with_many_tags # USE ONLY SONGS WITH MORE THAN 3 TAGS

    # Split on the song-id prefix (text before the first "_") so that all keys
    # sharing a song id end up on the same side of the split.
    song_ids = sorted({x.split("_")[0] for x in all_keys})
    split_rng = random.Random(split_seed)
    train_song_ids = set(split_rng.sample(song_ids, k=int(train_frac*len(song_ids))))

    train_keys = {k for k in all_keys if k.split("_")[0] in train_song_ids}
    test_keys = set(all_keys) - train_keys

    # NOTE: this branch defines no `valid_data`; only the MusicCaps branch does.
    training_data = JamendoTagDataset(sorted(train_keys), jam_tags, jam_pred_tags, jam_embeddings, shuffle_order=False)
    test_data = JamendoTagDataset(sorted(test_keys), jam_tags, jam_pred_tags, jam_embeddings, shuffle_order=False)
else:
    with open('musiccaps_split.json', 'r') as fp:
        musiccaps_split = json.load(fp)

    train_ytids, valid_ytids, test_ytids = musiccaps_split['train'], musiccaps_split['valid'], musiccaps_split['test']
    # Set copies for O(1) membership tests in the filters below.
    train_ytid_set, valid_ytid_set, test_ytid_set = set(train_ytids), set(valid_ytids), set(test_ytids)

    train_ds = ds.filter(lambda x: x['ytid'] in train_ytid_set)
    valid_ds = ds.filter(lambda x: x['ytid'] in valid_ytid_set)
    test_ds = ds.filter(lambda x: x['ytid'] in test_ytid_set)

    train_embeddings = {ytid: e for ytid, e in embeddings.items() if ytid in train_ytid_set}
    valid_embeddings = {ytid: e for ytid, e in embeddings.items() if ytid in valid_ytid_set}
    test_embeddings = {ytid: e for ytid, e in embeddings.items() if ytid in test_ytid_set}

    training_data = CaptionEmbedding(muscaps_ds=train_ds, embeddings=train_embeddings)
    valid_data = CaptionEmbedding(muscaps_ds=valid_ds, embeddings=valid_embeddings)
    test_data = CaptionEmbedding(muscaps_ds=test_ds, embeddings=test_embeddings)

# Training

### Tokenization

target should be:

`"<bos> <pad> caption <eos> <mask...>"` (first element is dropped in transformer.forward)

input should be:

`"<bos> <music-emb> caption <eos> <pad...>"` (last element is dropped in transformer.forward)

where

- `<bos>` = `<eos>` (for gpt2, see https://github.com/huggingface/transformers/issues/2026)
- `<mask>` is -100 (masked in cross-entropy, see https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)
- `<pad>` is arbitrary (but needs to be valid embedding index)
- `<music-emb>` is the encoded music

to use a `<bos>` token, prepend it in `tokenize()`, set `music_emb_ind = 1`, and update the caption slicing in `eval`

In [16]:
class ResidualLinear(nn.Module):
    """Residual block of constant width: ``x + Linear(relu(x))``."""

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.fc(torch.nn.functional.relu(x))


def B2T(in_dim=256, out_dim=768, dropouts=(0.6, 0.4)):
    """Build the network that maps a music embedding to one token embedding.

    Layout: ``Linear(in_dim, out_dim)``, then a ``ResidualLinear(out_dim)``,
    then one ``Dropout(p)`` + ``ResidualLinear(out_dim)`` pair per entry of
    ``dropouts``. The defaults reproduce the original fixed architecture
    (256 -> 768, dropout 0.6 and 0.4), including the layer indices used in
    saved state dicts.

    Args:
        in_dim: size of the incoming music embedding.
        out_dim: size of the produced vector; it is written into the language
            model's input embeddings, so it has to match their width.
        dropouts: dropout probabilities, one per additional residual block.

    Returns:
        An ``nn.Sequential`` module.
    """
    layers = [nn.Linear(in_dim, out_dim), ResidualLinear(out_dim)]
    for p in dropouts:
        layers.append(nn.Dropout(p))
        layers.append(ResidualLinear(out_dim))
    return nn.Sequential(*layers)

In [17]:
def tokenize(captions_batch):
    """Tokenize captions into a padded (model input, loss target) pair.

    Every caption becomes ``[placeholder_id, *caption_tokens, eos_id]``; the
    placeholder marks the slot that later receives the music embedding.

    Relies on the module-level ``tokenizer``, ``placeholder_id``, ``eos_id``,
    ``mask_id`` and ``device``.

    Returns:
        ``(input_ids, input_ids_target)``, two tensors of identical shape on
        ``device``. The target is padded with ``mask_id`` (ignored by the
        cross-entropy loss); the input carries ``eos_id`` in those positions
        because ``mask_id`` is not a valid index for the embedding lookup
        (which token is used there shouldn't matter).
    """
    token_lists = tokenizer(captions_batch)["input_ids"]

    # wrap each caption: placeholder for the music embedding in front, eos at the end
    framed = []
    for caption_tokens in token_lists:
        framed.append(torch.tensor([placeholder_id, *caption_tokens, eos_id]))

    targets = torch.nn.utils.rnn.pad_sequence(
        framed, batch_first=True, padding_value=mask_id
    ).to(device)

    # separate copy for the model input, with the loss-mask value swapped for eos
    model_inputs = targets.masked_fill(targets == mask_id, eos_id)

    return model_inputs, targets


def transform_input_ids(music_embedding, input_ids, input_ids_target):
    """Turn token ids into input embeddings with the music embedding spliced in.

    Both id tensors must carry ``placeholder_id`` in the music slot and are
    modified in place: the slot is overwritten with ``eos_id``.

    Relies on the module-level ``model``, ``b2t``, ``placeholder_id`` and
    ``eos_id``.

    Returns:
        ``(inputs_embeds, input_ids_target)`` where ``inputs_embeds`` is the
        ``model.transformer.wte`` lookup of ``input_ids`` with the music slot
        replaced by ``b2t(music_embedding)``.
    """
    music_emb_ind = 0  # 1 if using <bos>, otherwise 0

    for ids in (input_ids, input_ids_target):
        assert (ids[:, music_emb_ind] == placeholder_id).all()

    # The placeholder is not a valid embedding index, so put eos there: in the
    # target (alternative: mask_id) and, temporarily, in the input for the lookup.
    for ids in (input_ids_target, input_ids):
        ids[:, music_emb_ind] = eos_id

    inputs_embeds = model.transformer.wte(input_ids)
    inputs_embeds[:, music_emb_ind] = b2t(music_embedding)  # insert music embedding

    return inputs_embeds, input_ids_target


def strip_eos(pred):
    """
    Cut predicted captions at the end-of-text marker.

    One leading ``<|endoftext|>`` (if present) is removed, then everything from
    the next ``<|endoftext|>`` onwards is discarded
    (hf decoding can only skip eos tokens, it cannot stop at the first one).
    """
    eos = "<|endoftext|>"
    cleaned = []
    for text in pred:
        if text.startswith(eos):
            text = text[len(eos):]
        # partition() yields the whole string as head when the marker is absent
        head, _, _ = text.partition(eos)
        cleaned.append(head)
    return cleaned


def eval(caption_batch, embedding_batch, rm_eos=False, **kwargs):
    """Generate captions for a batch of music embeddings.

    NOTE: this function shadows the builtin ``eval`` in this notebook.

    Args:
        caption_batch: reference captions; only used to build correctly shaped
            inputs, the true caption tokens are not fed to the model.
        embedding_batch: music embeddings passed through ``b2t``.
        rm_eos: if true, cut the decoded text at the end-of-text marker via
            ``strip_eos``.
        **kwargs: forwarded to ``model.generate``.

    Returns:
        List of decoded strings with newlines removed and surrounding
        whitespace stripped.

    Side effects:
        Leaves ``model`` in eval mode. ``b2t`` is switched to eval mode for the
        duration of the call (so its dropout layers are inactive) and restored
        to its previous mode afterwards.
    """
    model.eval()
    b2t_was_training = b2t.training
    b2t.eval()
    try:
        # No gradients are needed for generation; skip building the graph
        # for the embedding lookup and b2t.
        with torch.no_grad():
            input_ids, input_ids_target = tokenize(caption_batch)
            inputs_embeds, _ = transform_input_ids(embedding_batch, input_ids, input_ids_target)

            # only include <bos> (optional) and music_embedding, don't include true caption
            inputs_embeds = inputs_embeds[:, :1]
            output_ids = model.generate(inputs_embeds=inputs_embeds, **kwargs)
    finally:
        b2t.train(b2t_was_training)

    pred = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
    pred = [p.replace("\n", "").strip() for p in pred]
    return strip_eos(pred) if rm_eos else pred

In [18]:
def update_step(inputs_embeds, input_ids_target, apply_grad):
    """Run one forward/backward pass and optionally step the optimizer.

    Gradients accumulate across calls until a call with ``apply_grad`` true,
    which applies ``opt.step()`` and then clears them.

    Args:
        inputs_embeds: input embeddings for the language model.
        input_ids_target: labels for the language-modelling loss.
        apply_grad: whether to apply (and reset) the accumulated gradients.

    Returns:
        The loss of this batch as a Python float.
    """
    model.train()
    # Call the module itself rather than `.forward` so registered hooks run.
    loss = model(inputs_embeds=inputs_embeds, labels=input_ids_target).loss
    loss.backward()

    if apply_grad:
        opt.step()
        opt.zero_grad()

    return loss.item()

def eval_step(string_info=""):
    """Print and log one true/predicted caption pair from the train and test loaders.

    A fresh iterator is created per loader, so one batch is drawn from each.
    ``string_info`` is accepted but not used here.
    """
    splits = (
        ("TRAIN", eval_train_dataloader, ""),
        # the last log entry carries an extra newline to separate evaluation rounds
        ("TEST", eval_test_dataloader, "\n"),
    )
    for split, loader, log_suffix in splits:
        caption_batch, embedding_batch = next(iter(loader))
        pred = eval(caption_batch, embedding_batch, **generation_params)
        true_caption, pred_caption = caption_batch[0], pred[0]

        printr(f'[green bold]{split} TRUE: ' + true_caption)
        printr(f'[yellow]{split} PRED: ' + pred_caption)
        wlog(f'{split} TRUE: ' + true_caption)
        wlog(f'{split} PRED: ' + pred_caption + log_suffix)
    print()

In [19]:
model_name = 'gpt2' # gpt2, gpt2-medium, gpt2-large, gpt2-xl
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.config.is_decoder = True # not sure if necessary

mask_id = -100 # don't change, this is fixed in torch cross-entropy loss!
eos_id = tokenizer.eos_token_id
placeholder_id = -200  # marks the slot that receives the music embedding; never reaches the model
# Follow `device` instead of hard-coding CUDA so the notebook also runs on CPU-only machines.
b2t = B2T().to(device)
b2t_lr = 3e-4

opt = torch.optim.AdamW([
    # b2t needs to be the first parameter group
    # (the training loop later raises the lr of every group after the first)
    {'params': b2t.parameters(), 'lr': b2t_lr},

    # disable AdamW weight decay for gpt2 layer finetuning
    # lr starts at 0, i.e. these layers are frozen until finetuning is switched on
    {'params': model.transformer.h[1].parameters(), 'lr': 0, 'weight_decay': 0},
    {'params': model.transformer.h[2].parameters(), 'lr': 0, 'weight_decay': 0},
])


gpt2_finetune_lr = 1e-5
batch_size = 32
num_epochs = 5
epoch = 0
gradient_acc_fact = 4
gpt2_finetune_start_epoch = 1
load_pretrain = False

generation_params = dict(
    max_new_tokens=64,
    num_beams=4,
    do_sample=True,
    temperature=1.0,
    bos_token_id=eos_id,
    eos_token_id=eos_id,
    # Padding is written into the generated sequences and later decoded, so it must
    # be a real vocabulary id. -100 (`mask_id`) is only meaningful as the loss
    # ignore-index. Trailing eos tokens are removed by `strip_eos`.
    pad_token_id=eos_id,
    early_stopping=True,
    rm_eos=True,  # consumed by `eval`, not by `model.generate`
)

losses = []
train_dataloader = DataLoader(training_data, batch_size, shuffle=True)
# batch size 1: used to draw single examples for qualitative checks
eval_train_dataloader = DataLoader(training_data, 1, shuffle=True)
eval_test_dataloader = DataLoader(test_data, 1, shuffle=True)

checkpoint_dir = Path('checkpoints')
checkpoint_dir.mkdir(exist_ok=True)

# Tag used in log, plot and checkpoint file names to tell runs apart.
model_name_info = "small" if model_name == 'gpt2' else model_name
pretraining_info = "tags" if use_jam_tag_dataset else ("yes" if load_pretrain else "no")
string_info = f"{model_name_info}_{pretraining_info}_{gpt2_finetune_lr}_{gpt2_finetune_start_epoch}_{b2t_lr}"

def wlog(s):
    """Append ``s`` plus a newline to this run's log file ``logs_{string_info}.txt``."""
    # The context manager closes the handle even if the write fails.
    with open(f"logs_{string_info}.txt", 'a') as log_file:
        log_file.write(s+"\n")

string_info

'small_tags_1e-05_1_0.0003'

In [None]:
# Starting from the module-level `epoch` lets a re-run of this cell continue from the
# epoch the previous run stopped in.
for epoch in tqdm(range(epoch, num_epochs)):
    wlog(f"\nEpoch {epoch}")

    # Checkpoint at the start of every epoch. The path does not contain the epoch,
    # so each save overwrites the previous one (use f"chkp_{epoch}.pt" to keep all).
    # The optimizer is stored as a state_dict (plain tensors/values) rather than as a
    # pickled object.
    torch.save({
        "model": model.state_dict(),
        "b2t": b2t.state_dict(),
        "opt": opt.state_dict(),
    }, checkpoint_dir / f"chkp_{string_info}.pt")

    if epoch == gpt2_finetune_start_epoch:
        print("Started finetuning gpt2")
        wlog("Started finetuning gpt2")
        # Group 0 is b2t; every later group holds gpt2 layers that were created with lr=0.
        for pg in opt.param_groups[1:]:
            pg["lr"] = gpt2_finetune_lr

    num_steps = len(train_dataloader)
    for step, (caption_batch, embedding_batch) in enumerate(tqdm(train_dataloader)):
        # tokenize and prepare inputs for forward
        input_ids, input_ids_target = tokenize(list(caption_batch))
        inputs_embeds, input_ids_target = transform_input_ids(
            embedding_batch, input_ids, input_ids_target
        )

        # Step once per `gradient_acc_fact` accumulated batches (not right after the
        # very first batch), and on the last batch of the epoch so that no leftover
        # gradients carry over into the next epoch.
        is_last_step = step + 1 == num_steps
        apply_grad_cond = (step + 1) % gradient_acc_fact == 0 or is_last_step
        losses.append(
            update_step(inputs_embeds, input_ids_target, apply_grad=apply_grad_cond)
        )

        if step % 1000 == 0:
            eval_step(string_info=string_info)
            wlog(f"Loss {np.mean(losses[-1000:])}\n")
            plt.plot(losses)
            plt.savefig(f"plot_{string_info}.png")
            plt.show()

# Generate eval captions

In [None]:
eval_true_captions = []
eval_pred_captions = []

# Evaluate the per-epoch checkpoints `chkp_{j}.pt` that exist, in epoch order.
# If there are none, fall back to the single checkpoint of the current run.
per_epoch_paths = (checkpoint_dir / f"chkp_{j}.pt" for j in range(1000))
checkpoint_paths = [p for p in per_epoch_paths if p.exists()]
if not checkpoint_paths:
    checkpoint_paths = [checkpoint_dir / f"chkp_{string_info}.pt"]

for checkpoint_path in checkpoint_paths:
    # load checkpoint (map_location lets GPU checkpoints load on the current device)
    data_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(data_dict['model'])
    b2t.load_state_dict(data_dict['b2t'])

    # generate a bunch of captions with this checkpoint
    # for some reason hf generate() breaks atm when using batched captions, idk why
    true, preds = [], []
    for i, (caption_batch, embedding_batch) in enumerate(tqdm(eval_test_dataloader)):
        # `generation_params` is the only generation config defined in this notebook.
        pred = eval(caption_batch, embedding_batch, **generation_params)
        true.append(caption_batch[0])
        preds.append(pred[0])
        if i >= 300:
            break

    eval_true_captions.append(true)
    eval_pred_captions.append(preds)

In [None]:
# Persist the generated and reference captions. The context manager makes sure the
# file is flushed and closed.
with open('preds.json', 'w') as preds_file:
    json.dump(dict(
        eval_true_captions=eval_true_captions,
        eval_pred_captions=eval_pred_captions
    ), preds_file)

In [None]:
import evaluate

# Google-BLEU between predicted and true captions, one score per evaluated
# checkpoint (i.e. per entry of eval_pred_captions / eval_true_captions).
google_bleu = evaluate.load("google_bleu")
gleu = [
    google_bleu.compute(predictions=preds, references=refs)['google_bleu']
    for preds, refs in zip(eval_pred_captions, eval_true_captions)
]

plt.plot(gleu)