In [18]:
# !pip install -U yt-dlp==2023.1.6 matplotlib==3.6.0 datasets[audio] rich

# install newest transformers build to be able to pass `inputs_embeds` through generate()
# !pip install --upgrade git+https://github.com/huggingface/transformers.git

**Relevant huggingface gpt2 code**

- https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
- https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
- https://github.com/huggingface/transformers/issues/6535

# Load musiccaps

In [1]:
from musiccaps import load_musiccaps
import numpy as np
from rich import print as printr
from torch.utils.data import DataLoader, Dataset, random_split
import torch
from tqdm.auto import tqdm, trange
import matplotlib.pyplot as plt
import itertools
import math
from rich import print as printr
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
def filter_muscaps_with_embeddings(ds, embeddings):
    '''Drop MusicCaps rows whose clip has no precomputed embedding.

    Some clips weren't downloaded so we couldn't embed them; this keeps only
    the rows whose ``ytid`` is a key of ``embeddings``.

    Args:
        ds: MusicCaps dataset supporting ``ds['ytid']`` column access and
            ``ds.select(indices)`` (e.g. a Hugging Face ``datasets.Dataset``).
        embeddings: mapping ytid -> embedding.

    Returns:
        The filtered dataset.

    Raises:
        AssertionError: if the number of kept rows differs from the number of
            embeddings (e.g. duplicate ytids in ``ds``, or embeddings for clips
            that are not in ``ds``).
    '''
    # Read the ytid column once (instead of materialising every row) and test
    # membership against the dict directly. The previous version rebuilt a set
    # of excluded ids for every row, which was O(n^2).
    keep_indices = [i for i, ytid in enumerate(ds['ytid']) if ytid in embeddings]
    ds = ds.select(keep_indices)
    assert len(ds) == len(embeddings), (
        f'{len(ds)} rows kept but there are {len(embeddings)} embeddings'
    )
    return ds

In [3]:
# MusicCaps metadata only (no audio decoding), cached under ./music_data.
ds = load_musiccaps(
    './music_data',
    limit=None,
    return_without_audio=True,
    sampling_rate=16000,
    num_proc=8,
    writer_batch_size=1000,
)

# embeddings.npy stores a pickled dict (ytid -> embedding vector); np.load gives
# back a 0-d object array and .item() unwraps the dict.
# NOTE: allow_pickle=True can execute code embedded in the file -- only load
# an embeddings.npy that you produced yourself.
embeddings = np.load('embeddings.npy', allow_pickle=True).item()

Using custom data configuration google--MusicCaps-7925612b943f961b
Found cached dataset csv (/home/dominik/.cache/huggingface/datasets/google___csv/google--MusicCaps-7925612b943f961b/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


In [4]:
class CaptionEmbedding(Dataset):
    '''Torch Dataset of paired (caption, embedding) items, ordered by ytid.

    Each caption is looked up by the ytid of the embedding it is paired with.
    This avoids relying on the dataset's own sort (``ds.sort``) agreeing
    element-for-element with Python's ``sorted`` over the embedding keys.
    '''
    def __init__(self, muscaps_ds, embeddings):
        ds = filter_muscaps_with_embeddings(muscaps_ds, embeddings)
        caption_by_ytid = dict(zip(ds['ytid'], ds['caption']))
        # Same order as sorted(embeddings.items()) used before: dict keys are unique.
        sorted_ytids = sorted(embeddings)
        # A KeyError here means an embedding has no caption row (should not
        # happen after filtering unless ds contains duplicate ytids).
        self.captions = [caption_by_ytid[ytid] for ytid in sorted_ytids]
        self.embeddings = torch.from_numpy(
            np.stack([embeddings[ytid] for ytid in sorted_ytids])
        ).to(device)

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        '''Return ``(caption, embedding)`` for item ``idx``; the embedding is already on ``device``.'''
        return self.captions[idx], self.embeddings[idx]

In [5]:
dataset = CaptionEmbedding(muscaps_ds=ds, embeddings=embeddings)

# Quick check that we did not mess up the ordering of caption-embedding pairs.
# Checking a small sample keeps this cheap (the exhaustive version was O(n^2)).
# A caption could in principle belong to several clips, so accept any clip with
# that caption whose stored embedding matches.
ytids_by_caption = {}
for ytid, caption in zip(ds['ytid'], ds['caption']):
    ytids_by_caption.setdefault(caption, []).append(ytid)

for idx in range(min(50, len(dataset))):
    cap, emb = dataset[idx]
    assert any(
        ytid in embeddings
        and torch.allclose(emb, torch.from_numpy(embeddings[ytid]).to(device))
        for ytid in ytids_by_caption[cap]
    ), f'caption/embedding mismatch at index {idx}: {cap[:80]!r}'


In [6]:
# get a list of music-related words to use for evaluation
import ast
from collections import Counter

MIN_ASPECT_COUNT = 25  # only pick aspects that show up somewhat frequently


def parse_aspect_list(aspect_str):
    '''Parse one ``aspect_list`` cell into a list of non-empty aspect strings.

    The cell appears to be the text of a Python list literal (it contains
    brackets and quotes separated by ``', '``), so ``ast.literal_eval`` is
    tried first. Unlike stripping quote characters, this keeps apostrophes and
    commas *inside* an aspect intact, and turns ``"[]"`` into ``[]`` instead of
    a spurious ``''`` aspect. If the cell is not a valid literal, it falls back
    to stripping ``[]"'`` and splitting on commas.
    '''
    try:
        parsed = ast.literal_eval(aspect_str)
    except (ValueError, SyntaxError):
        parsed = None
    if not isinstance(parsed, (list, tuple)):
        for t in '[]"\'':
            aspect_str = aspect_str.replace(t, '')
        parsed = aspect_str.split(',')
    return [a.strip() for a in map(str, parsed) if a.strip()]


aspect_counts = Counter(
    aspect
    for aspect_str in ds['aspect_list']
    for aspect in parse_aspect_list(aspect_str)
)
aspects = {s for s, count in aspect_counts.items() if count >= MIN_ASPECT_COUNT}
len(aspects)

378

# Training

### Tokenization

The target (labels) should be:

`"<bos> <mask> caption <eos> <mask...>"` (the first label is never scored: `GPT2LMHeadModel.forward` shifts the labels one position left before computing the loss)

input should be:

`"<bos> <music-emb> caption <eos> <pad...>"`

where

- `<bos>` = `<eos>` (for gpt2, see https://github.com/huggingface/transformers/issues/2026)
- `<mask>` is -100 (masked in cross-entropy, see https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)
- `<pad>` is arbitrary (the code uses `<eos>`), because those positions are `<mask>` in the target and never contribute to the loss
- `<music-emb>` is the encoded music

In [21]:
class B2T(nn.Module):
    '''MLP that projects a music embedding into the language model's
    token-embedding space.

    Args:
        in_dim: size of the incoming music embedding (default 512).
        hidden_dim: width of the first hidden layer (default 1024).
        out_dim: must equal the language model's embedding width
            (``model.config.n_embd``): 768 for ``gpt2``, larger for
            gpt2-medium/large/xl.

    The defaults reproduce the original hard-coded architecture, and the layer
    indices inside ``self.main`` are unchanged, so existing state dicts still load.
    '''

    def __init__(self, in_dim=512, hidden_dim=1024, out_dim=768):
        super().__init__()

        self.main = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, x):
        '''Map ``x`` of shape (..., in_dim) to (..., out_dim).'''
        return self.main(x)

In [40]:
model_name = 'gpt2' # gpt2, gpt2-medium, gpt2-large, gpt2-xl
seed = 0

torch.manual_seed(seed)  # reproducible B2T initialisation
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# follow `device` like everything else (`.cuda()` crashed on CPU-only machines)
b2t = B2T().to(device)
assert b2t.main[-1].out_features == model.config.n_embd, (
    f'B2T outputs {b2t.main[-1].out_features} dims but {model_name} '
    f'expects {model.config.n_embd}'
)

# Only b2t and the first block's attention are optimised. Freeze the rest of
# GPT-2 so backward() doesn't compute (and keep) gradients for weights that are
# never updated. Gradients still flow *through* the frozen layers to b2t.
model.requires_grad_(False)
model.transformer.h[0].attn.requires_grad_(True)

opt = torch.optim.AdamW([
    {'params': b2t.parameters(), 'lr': 0.00025},
    # gpt2 layer starts with lr 0; disable AdamW weight decay for gpt2 layer finetuning!
    {'params': model.transformer.h[0].attn.parameters(), 'lr': 0, 'weight_decay': 0},
])

train_frac = 0.8
batch_size = 32
num_epochs = 50

losses = []
dataset = CaptionEmbedding(muscaps_ds=ds, embeddings=embeddings)
# seeded generator -> the same train/test split on every run
training_data, test_data = random_split(
    dataset,
    [train_frac, 1 - train_frac],
    generator=torch.Generator().manual_seed(seed),
)
train_dataloader = DataLoader(training_data, batch_size, shuffle=True)
eval_train_dataloader = DataLoader(training_data, 1, shuffle=True)
eval_test_dataloader = DataLoader(test_data, 1, shuffle=True)

In [41]:
mask_id = -100  # ignore_index of torch's cross-entropy loss -- must stay -100!
eos_id = tokenizer.eos_token_id
placeholder_id = -200  # marks the slot later filled with the music embedding


def tokenize(captions_batch):
    '''Tokenize captions into ``(input_ids, input_ids_target)`` on ``device``.

    Each caption becomes ``[eos, placeholder, *caption_tokens, eos]``. The
    target tensor is right-padded with ``mask_id`` so padding is ignored by the
    loss. The input tensor is a copy in which every ``mask_id`` is swapped for
    ``eos_id``, because -100 is not a valid index for the token-embedding lookup.
    '''
    encoded = tokenizer(captions_batch)['input_ids']

    sequences = [
        torch.tensor([eos_id, placeholder_id, *token_ids, eos_id])
        for token_ids in encoded
    ]
    input_ids_target = torch.nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=mask_id
    ).to(device)

    # any valid token id works for the padded input positions; eos is used
    input_ids = input_ids_target.masked_fill(input_ids_target == mask_id, eos_id)
    return input_ids, input_ids_target

In [42]:
def transform_input_ids(music_embedding, input_ids, input_ids_target):
    '''Splice the projected music embedding into the token embeddings.

    Expects the layout produced by ``tokenize``: column 1 of every row holds
    ``placeholder_id``. In the model inputs that position is replaced by
    ``b2t(music_embedding)``. In the targets it is set to ``mask_id``, so the
    loss ignores it.

    The argument tensors are not modified in place: calling this twice on the
    same tokenized batch works (before, the second call failed the
    placeholder asserts).

    Returns:
        (inputs_embeds, input_ids_target): embeddings for
        ``model(inputs_embeds=...)`` and the matching label tensor.
    '''
    assert (input_ids[:, 1] == placeholder_id).all()
    assert (input_ids_target[:, 1] == placeholder_id).all()

    input_ids_target = input_ids_target.clone()
    input_ids_target[:, 1] = mask_id

    input_ids = input_ids.clone()
    input_ids[:, 1] = eos_id  # temp placeholder to make the embedding lookup work
    inputs_embeds = model.transformer.wte(input_ids)

    # overwrite the placeholder's embedding with the projected music embedding
    inputs_embeds[:, 1] = b2t(music_embedding)

    return inputs_embeds, input_ids_target

In [43]:
@torch.no_grad()
def manual_generate_single(inputs_embeds, max_length, do_sample):
    """ Autoregressively generate ``max_length`` tokens after the embedded prompt.

    Args:
        inputs_embeds: prompt embeddings of shape (batch, prompt_len, n_embd),
            as accepted by ``model(inputs_embeds=...)``.
        max_length: number of new tokens to generate (no early stop at eos).
        do_sample: sample from the softmax if True, otherwise greedy argmax.

    Returns:
        (token_ids, ppl): ``token_ids`` has shape (batch, max_length). ``ppl``
        has shape (batch,) and holds each sequence's perplexity
        exp(-mean log p) under the model.

    Raises:
        ValueError: if ``max_length < 1``.
    """
    if max_length < 1:
        raise ValueError(f'max_length must be >= 1, got {max_length}')

    result = []
    log_probs = []
    past_key_values = None
    next_embeds = inputs_embeds

    for _ in range(max_length):
        # Reuse the key/value cache so each step only feeds the newest position
        # instead of re-running the whole (growing) prefix every step.
        out = model(
            inputs_embeds=next_embeds,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = out.past_key_values
        logits = out.logits[:, -1]

        distr = torch.distributions.Categorical(logits=logits)
        token_inds = distr.sample() if do_sample else logits.argmax(-1)
        log_probs.append(distr.log_prob(token_inds))
        result.append(token_inds)

        next_embeds = model.transformer.wte(token_inds).unsqueeze(1)

    log_probs = torch.stack(log_probs, dim=1)  # (batch, max_length)
    # Average over the generated tokens (last dim). The previous code divided by
    # len(log_probs), i.e. the batch size. Categorical.log_prob is a natural
    # log, so the matching perplexity base is e, not 2.
    ppl = torch.exp(-log_probs.mean(dim=-1))

    return torch.stack(result, dim=1), ppl

@torch.no_grad()
def manual_generate(inputs_embeds, iters, max_length, do_sample):
    """ Draw ``iters`` candidates with manual_generate_single and keep, for each
    batch row, the candidate with the LOWEST perplexity (the most likely one).

    The previous version took the argmax, i.e. the least likely candidate,
    while calling the result ``best_preds``. With ``iters=1`` both choices
    are identical.

    Returns:
        (best_preds, best_ppls): token ids of shape (batch, max_length) and the
        matching per-row perplexities of shape (batch,).

    Raises:
        ValueError: if ``iters < 1``.
    """
    if iters < 1:
        raise ValueError(f'iters must be >= 1, got {iters}')

    preds = []
    ppls = []
    for _ in range(iters):
        pred, ppl = manual_generate_single(inputs_embeds, max_length, do_sample)
        preds.append(pred)
        ppls.append(ppl)

    preds = torch.stack(preds)  # (iters, batch, max_length)
    ppls = torch.stack(ppls)    # (iters, batch)

    best_ppls, best_inds = ppls.min(dim=0)
    rows = torch.arange(preds.shape[1], device=preds.device)
    best_preds = preds[best_inds, rows]

    return best_preds, best_ppls

In [44]:
@torch.no_grad()
def eval(caption_batch, embedding_batch, use_manual_generation=False, **kwargs):
    """Generate captions for a batch of music embeddings (puts ``model`` in eval mode).

    Only the first two positions, ``<bos> <music-emb>``, are used as the prompt.
    Previously the full tokenized caption was passed, so generation was
    conditioned on the ground-truth caption it was supposed to predict.
    ``caption_batch`` is only used to build that prompt layout via ``tokenize``.

    Args:
        caption_batch: list of caption strings (one per embedding).
        embedding_batch: music embeddings matching ``caption_batch``.
        use_manual_generation: use ``manual_generate`` instead of ``model.generate``.
        **kwargs: forwarded to the chosen generation function.

    Returns:
        List of decoded predictions, each cut at the first eos token, with
        newlines removed and surrounding whitespace stripped.

    Note: the name shadows the builtin ``eval`` inside this notebook.
    """
    model.eval()
    input_ids, input_ids_target = tokenize(caption_batch)
    inputs_embeds, _ = transform_input_ids(
        embedding_batch,
        input_ids,
        input_ids_target
    )
    prompt_embeds = inputs_embeds[:, :2]  # <bos> <music-emb>

    if use_manual_generation:
        output_ids, _ppl = manual_generate(prompt_embeds, **kwargs)
    else:
        output_ids = model.generate(inputs_embeds=prompt_embeds, **kwargs)
    pred = tokenizer.batch_decode(output_ids, skip_special_tokens=False)

    # Manual generation does not stop at eos, so drop everything after the first one.
    pred = [p.split(tokenizer.eos_token)[0] for p in pred]
    pred = [p.replace('\n', '').strip() for p in pred]
    return pred

In [45]:
# Both generation paths produce up to 48 *new* tokens after the prompt.
# HF generate uses `max_new_tokens`: whether `max_length` includes a prompt
# passed as `inputs_embeds` depends on the transformers version.
generation_params_hf = dict(
    max_new_tokens=48,
    num_beams=4,
    do_sample=True,
    temperature=0.95,
    pad_token_id=tokenizer.eos_token_id,
    use_manual_generation=False
)

generation_params_ours = dict(
    max_length=48,  # manual_generate: number of new tokens to generate
    iters=1,
    do_sample=True,
    use_manual_generation=True
)

In [None]:
def print_eval_samples(split, dataloader):
    '''Print one ground-truth caption from ``dataloader`` next to an HF-generate
    prediction (PRED-A) and a manual-generation prediction (PRED-B).'''
    caption_batch, embedding_batch = next(iter(dataloader))
    pred_a = eval(caption_batch, embedding_batch, **generation_params_hf)
    printr(f'[green bold]{split} TRUE: ' + caption_batch[0])
    printr(f'[blue]{split} PRED-A: ' + pred_a[0])
    pred_b = eval(caption_batch, embedding_batch, **generation_params_ours)
    printr(f'[blue]{split} PRED-B: ' + pred_b[0])


for epoch in tqdm(range(num_epochs)):

    # the gpt2 attention layer starts with lr 0; unfreeze it from the third epoch on
    if epoch > 1:
        opt.param_groups[1]['lr'] = 5e-5

    for step, (caption_batch, embedding_batch) in enumerate(tqdm(train_dataloader)):
        # tokenize and prepare inputs for forward
        input_ids, input_ids_target = tokenize(caption_batch)
        inputs_embeds, input_ids_target = transform_input_ids(
            embedding_batch,
            input_ids,
            input_ids_target
        )

        model.train()  # eval() below switches the model to eval mode
        loss = model(inputs_embeds=inputs_embeds, labels=input_ids_target).loss
        # Clear the previous step's gradients; without this they accumulated
        # over every step and each update used the running sum.
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

        if step % 100 == 0:
            print_eval_samples('TRAIN', eval_train_dataloader)
            print_eval_samples('TEST', eval_test_dataloader)
            print('\n')

        if step % 200 == 199:
            plt.plot(losses)
            plt.xlabel('training step')
            plt.ylabel('loss')
            plt.show()

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/138 [00:00<?, ?it/s]





In [None]:
# Manual generation needs max_length/iters/do_sample; calling eval with only
# use_manual_generation=True raised a TypeError.
caption_batch, embedding_batch = next(iter(eval_test_dataloader))
pred = eval(caption_batch, embedding_batch, **generation_params_ours)
printr('[green bold]TEST TRUE: ' + caption_batch[0])
printr('[blue]TEST PRED-B: ' + pred[0])