In [18]:
# !pip install -U yt-dlp==2023.1.6 matplotlib==3.6.0 datasets[audio] rich

# install newest transformers build to be able to pass `inputs_embeds` through generate()
# !pip install --upgrade git+https://github.com/huggingface/transformers.git

**Relevant huggingface gpt2 code**

- https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
- https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
- https://github.com/huggingface/transformers/issues/6535

# Load musiccaps

In [1]:
from musiccaps import load_musiccaps
import numpy as np
from rich import print as printr
from torch.utils.data import DataLoader, Dataset, random_split
import torch
from tqdm.auto import tqdm, trange
import matplotlib.pyplot as plt
import itertools
import math
from rich import print as printr
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
def filter_muscaps_with_embeddings(ds, embeddings):
    '''Some clips weren't downloaded so we couldn't embed them, get rid of that'''
    exclude_ids = set()
    for i in range(len(ds)):
        if ds[i]['ytid'] not in embeddings.keys():
            exclude_ids.add(i)
    ds = ds.select(
        (
            i for i in range(len(ds)) 
            if i not in set(exclude_ids)
        )
    )
    assert len(ds) == len(embeddings)
    return ds

In [3]:
ds = load_musiccaps(
    './music_data',
    sampling_rate=16000,
    limit=None,
    num_proc=8,
    writer_batch_size=1000,
    return_without_audio=True
)
embeddings = np.load('embeddings.npy', allow_pickle=True).item()

Using custom data configuration google--MusicCaps-7925612b943f961b
Found cached dataset csv (/home/dominik/.cache/huggingface/datasets/google___csv/google--MusicCaps-7925612b943f961b/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


In [4]:
class CaptionEmbedding(Dataset):
    '''Returns a torch Dataset of paired captions and embeddings'''
    def __init__(self, muscaps_ds, embeddings):
        ds = filter_muscaps_with_embeddings(muscaps_ds, embeddings)
        self.captions = ds.sort(column='ytid')['caption']
        sorted_embs = [value for _, value in sorted(embeddings.items())]
        self.embeddings = torch.from_numpy(np.stack(sorted_embs)).to(device)

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        return self.captions[idx], self.embeddings[idx]

In [5]:
dataset = CaptionEmbedding(muscaps_ds=ds, embeddings=embeddings)

# quick check did not mess up ordering of caption-embedding pairs
# for cap, emb in tqdm(dataset):
#     for i in range(len(ds)):
#         if cap == ds[i]['caption']:
#             assert torch.allclose(emb,torch.from_numpy(embeddings[ds[i]['ytid']]).to(device))



In [6]:
aspects = []
for x in ds:
    aspect_str = x['aspect_list']
    for t in ('[]"\''):
        aspect_str = aspect_str.replace(t, '')
    aspects.extend(aspect_str.split(', '))
    
from collections import Counter
aspects = {s for s, count in Counter(aspects).most_common() if count >= 20}
len(aspects)

471

# Training

### Tokenization

target should be:

`"<bos> <mask> caption <eos> <mask...>"` (first element is dropped in transformer.forward)

input should be:

`"<bos> <music-emb> caption <eos> <pad...>"`

where

- `<bos>` = `<eos>` (for gpt2)
- `<mask>` is -100
- `<pad>` is arbitrary
- `<music-emb>` is the encoded music

In [19]:
def tokenize(captions_batch):
    input_ids = tokenizer(captions_batch)['input_ids']
    eos = tokenizer.eos_token_id
     # wrap in eos token (see https://github.com/huggingface/transformers/issues/2026)
    input_ids = [torch.tensor([eos] + x + [eos]) for x in input_ids]
    # pad with -100, this index is masked in the cross-entropy loss
    # (see https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)
    input_ids_target = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=-100).to(device)
    
    # index -100 isn't valid as model input however, since the token embedding lookup fails
    # so we need a second version as input with -100 replaced with another token (shouldn't matter which)
    input_ids = input_ids_target.clone()
    input_ids[input_ids==-100] = eos
    
    # the model input will be prefixed with the music embedding, 
    # so we need to prefix the target too with some token to get the shapes to match
    # (maybe masked -100 is better)
    input_ids_target = torch.cat([
        torch.full((len(input_ids_target), 1), fill_value=-100, device=device),
        input_ids_target,
    ], dim=1)
    
    return input_ids, input_ids_target

In [20]:
class B2T(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.main = nn.Sequential(
            nn.Linear(512, 768),
            nn.ReLU(),
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Linear(768, 768),
        )
        
    def forward(self, x):
        return self.main(x)

In [21]:
model_name = 'gpt2' # gpt2, gpt2-medium, gpt2-large, gpt2-xl
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
b2t = B2T().cuda()

opt = torch.optim.AdamW([
    {'params': b2t.parameters(), 'lr': 0.0005},
])

generation_params = dict(
    max_length=32,
    num_beams=4,
    do_sample=True,
    temperature=0.95,
    pad_token_id=tokenizer.eos_token_id,
)

losses = []

dataset = CaptionEmbedding(muscaps_ds=ds, embeddings=embeddings)

train_frac = 0.8
training_data, test_data = random_split(dataset, [train_frac, 1-train_frac])
batch_size = 8
num_epochs = 50
train_dataloader = DataLoader(training_data, batch_size, shuffle=True)
eval_train_dataloader = DataLoader(training_data, 1, shuffle=True)
eval_test_dataloader = DataLoader(test_data, 1, shuffle=True)

In [84]:
@torch.no_grad()
def manual_generate_single(inputs_embeds, max_len, sample):
    result = []
    log_probs = []

    for i in range(max_len):
        logits = model.forward(inputs_embeds=inputs_embeds).logits[:, -1]

        distr = torch.distributions.Categorical(logits=logits)
        token_inds = distr.sample() if sample else logits.argmax(-1)
        log_probs.append(distr.log_prob(token_inds))
        
        result.append(token_inds)

        inputs_embeds = torch.cat([
            inputs_embeds,
            model.transformer.wte(token_inds).unsqueeze(1)
        ], dim=1)
        
    log_probs = torch.stack(log_probs, dim=1)
    ppl = 2**(-(1/len(log_probs))*log_probs.sum(-1))
        
    return torch.stack(result, dim=1), ppl

@torch.no_grad()
def manual_generate(iters, *args, **kwargs):
    preds = []
    ppls = []
    
    for i in range(iters):
        pred, ppl = manual_generate_single(*args, **kwargs)
        preds.append(pred)
        ppls.append(ppl)
    
    preds = torch.stack(preds)
    ppls = torch.stack(ppls)
    
    max_ppl_inds = ppls.argmax(0)
    best_preds = preds[max_ppl_inds, np.arange(preds.shape[1])]
    
    return ppls.max(0), best_preds

In [28]:
def eval(data_loader, use_manual_generation=False):
    model.eval()
    caption_batch, embedding_batch = next(iter(eval_train_dataloader))
    input_ids, input_ids_target = tokenize(caption_batch)
    
    inputs_embeds = torch.cat([
        b2t(embedding_batch).unsqueeze(1),
        model.transformer.wte(input_ids),
    ], dim=1)
    
    if use_manual_generation:
        output_ids, ppl = manual_generate(1, inputs_embeds, 64, True)
    else:
        output_ids = model.generate(inputs_embeds=inputs_embeds, **generation_params)
    pred = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    printr('[blue bold]PRED: ' + pred.replace('\n', ' '))
    printr('[green bold]TRUE: ' + caption_batch[0])

In [29]:
for epoch in tqdm(range(num_epochs)):
    for step, (caption_batch, embedding_batch) in enumerate(tqdm(train_dataloader)):
        input_ids, input_ids_target = tokenize(caption_batch)
        
        inputs_embeds = torch.cat([
            b2t(embedding_batch).unsqueeze(1),
            model.transformer.wte(input_ids),
        ], dim=1)
        
        model.train()
        loss = model.forward(inputs_embeds=inputs_embeds, labels=input_ids_target).loss
        loss.backward()
        opt.step()
        losses.append(loss.item())
        
        if step % 100 == 0:
            eval(eval_train_dataloader, use_manual_generation=True)
            eval(eval_train_dataloader)
            eval(eval_test_dataloader, use_manual_generation=True)
            eval(eval_test_dataloader)
            print('\n')

        if step % 200 == 199:
            plt.plot(losses)
            plt.show()

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/550 [00:00<?, ?it/s]

MarkupError: closing tag '[/note]' at position 313 doesn't match any open tag

---

# random experiments

In [29]:
seq = torch.tensor(tokenizer.encode('List of 10 cute animals: dog, cat, '))
inputs_embeds = model.transformer.wte(seq.cuda()).unsqueeze(0)
#inputs_embeds = torch.cat([weird_word_embedding, inputs_embeds], dim=1)

In [9]:
encoded = tokenizer.batch_encode_plus(['List of 10 cute animals: dog, cat, '], return_tensors='pt')
input_ids = encoded['input_ids'].cuda()
attention_mask = encoded['attention_mask'].cuda()
inputs_embeds = model.transformer.wte(input_ids)

In [42]:
# if you don't pass input_ids here, the output is the same, minus the prompt
outputs = model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_params)
tokenizer.batch_decode(outputs, skip_special_tokens=True)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['ursine, horse, elephant, goat, fish, sheep, pig, chicken, piglet, mouse, rat, rat-tat-tat']

In [34]:
outputs = model.generate(input_ids, **generation_params)
tokenizer.batch_decode(outputs, skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['List of 10 cute animals: dog, cat, urchin, fish, frog, rabbit, snake, squid, and bird.\n\n1. Dog']