In [None]:
# !pip install -U yt-dlp==2023.1.6 matplotlib==3.6.0 datasets[audio]
# !pip install transformers
# !pip install rich

In [1]:
from musiccaps import load_musiccaps
import numpy as np
from rich import print as printr
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import torch
from torch.utils.data import DataLoader, Dataset, random_split
import math
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn as nn
from tqdm.auto import tqdm
import itertools

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

We keep only the dataset entries for which we have embeddings. These are the Harmonic CNN embeddings produced by `feature_extraction.ipynb`.

In [2]:
def filter_muscaps_with_embeddings(ds, embeddings):
    '''Drop the dataset rows whose clip has no embedding.

    Some clips weren't downloaded, so they could not be embedded; the rows
    describing them are removed here.

    Args:
        ds: dataset supporting ``ds['ytid']`` (the id column), ``len(ds)`` and
            ``ds.select(indices)``.
        embeddings: mapping whose keys are the ytids that were embedded.

    Returns:
        The result of ``ds.select`` over the rows whose ytid is a key of
        ``embeddings``, in their original order.

    Raises:
        AssertionError: if the number of kept rows differs from the number
            of embeddings (e.g. an embedding whose ytid is not in ``ds``).
    '''
    # Read the id column once: indexing ds[i] builds a full row per lookup,
    # while membership in the mapping is a constant-time check.
    keep_indices = [
        row_idx for row_idx, ytid in enumerate(ds['ytid'])
        if ytid in embeddings
    ]
    ds = ds.select(keep_indices)
    assert len(ds) == len(embeddings), (
        f'kept {len(ds)} rows but have {len(embeddings)} embeddings'
    )
    return ds

In [3]:
# Load MusicCaps through the project helper from musiccaps.py.
# NOTE(review): return_without_audio=True presumably skips attaching the audio
# itself, and limit=None presumably means "all rows" — confirm in musiccaps.py.
ds = load_musiccaps(
    './music_data',
    sampling_rate=16000,
    limit=None,
    num_proc=8,
    writer_batch_size=1000,
    return_without_audio=True
)
# embeddings.npy stores a pickled Python object, unwrapped with .item(); the
# cells below use it as a mapping keyed by ytid.
# NOTE(security): allow_pickle=True can execute arbitrary code while loading —
# only load files you produced yourself.
embeddings = np.load('embeddings.npy', allow_pickle=True).item()

Using custom data configuration google--MusicCaps-7925612b943f961b
Found cached dataset csv (/Users/alexandrasouly/.cache/huggingface/datasets/google___csv/google--MusicCaps-7925612b943f961b/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


We create a PyTorch `Dataset` that yields (caption, embedding) pairs. This makes it easier to switch to different embeddings later, to create train/test splits, and to batch with `DataLoader`s.

In [4]:
class CaptionEmbedding(Dataset):
    '''Torch Dataset of paired captions and embeddings.

    Items are ordered by ytid. ``dataset[i]`` returns the tuple
    ``(caption, embedding)`` for the i-th clip in that order.

    Args:
        muscaps_ds: MusicCaps dataset with ``'ytid'`` and ``'caption'`` columns.
        embeddings: mapping from ytid to a numpy embedding; all embeddings
            must share one shape so they can be stacked.
    '''
    def __init__(self, muscaps_ds, embeddings):
        ds = filter_muscaps_with_embeddings(muscaps_ds, embeddings)
        # The column is passed positionally: newer `datasets` releases name
        # this parameter `column_names`, so `column=` would fail there.
        ds = ds.sort('ytid')
        self.captions = ds['caption']
        # Fetch every embedding through its own row's ytid. The pairing is then
        # correct by construction instead of depending on the dataset sort and
        # Python's sorted() producing exactly the same order.
        paired_embs = [embeddings[ytid] for ytid in ds['ytid']]
        self.embeddings = torch.from_numpy(np.stack(paired_embs)).to(device)

    def __len__(self):
        '''Number of (caption, embedding) pairs.'''
        return len(self.captions)

    def __getitem__(self, idx):
        '''Return the ``(caption, embedding)`` pair stored at ``idx``.'''
        return self.captions[idx], self.embeddings[idx]



In [5]:
class B2T(nn.Module):
    '''Three-layer MLP (Linear-ReLU-Linear-ReLU-Linear) applied to the last axis.

    Args:
        in_dim: size of the input's last dimension (default 512).
        hidden_dim: width of the two hidden layers (default 768).
        out_dim: size of the output's last dimension (default 768).

    The defaults reproduce the original fixed 512 -> 768 -> 768 -> 768 network,
    and the layers stay under ``self.main`` so state_dict keys are unchanged.
    '''

    def __init__(self, in_dim=512, hidden_dim=768, out_dim=768):
        super().__init__()

        self.main = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        '''Map ``x`` of shape ``(..., in_dim)`` to shape ``(..., out_dim)``.'''
        return self.main(x)

In [6]:
dataset = CaptionEmbedding(muscaps_ds=ds, embeddings=embeddings)
# Optional (slow, quadratic) check that caption-embedding pairs stayed aligned:
# for cap, emb in dataset:
#     for i in range(len(ds)):
#         if cap == ds[i]['caption']:
#             assert torch.allclose(emb,torch.from_numpy(embeddings[ds[i]['ytid']]))

# Seed the split so a kernel restart reproduces the same train/test partition;
# without a generator, random_split draws from the global RNG and the test set
# changes on every run.
SPLIT_SEED = 0
train_size = math.floor(0.8*len(dataset))
test_size = len(dataset) - train_size
training_data, test_data = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
batch_size = 3
train_dataloader = DataLoader(training_data, batch_size, shuffle=True)
# Held-out data is only read, never trained on, so a fixed order is enough.
test_dataloader = DataLoader(test_data, batch_size, shuffle=False)



In [7]:
# The captioning model, its image preprocessor and its tokenizer all come from
# the same pretrained checkpoint.
CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"
model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
feature_extractor = ViTFeatureExtractor.from_pretrained(CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

# Projection network applied to the music embeddings (see B2T), moved to the
# same device as the captioning model.
b2t = B2T().to(device)
model = model.to(device)

# Two optimiser choices: train only the b2t layers (commented line), or train
# b2t together with the whole pretrained model at a much smaller learning rate.
# opt = torch.optim.Adam([*b2t.parameters()], lr=0.0001) # , *model.decoder.parameters()]
param_groups = [
    dict(params=b2t.parameters(), lr=1e-4),
    dict(params=model.parameters(), lr=1e-6),
]
opt = torch.optim.Adam(param_groups)

# Running record of the per-step training loss.
losses = []
# All-zero stand-in image batch: the model's forward still expects pixel
# values even though the encoder output is replaced later in the notebook.
fake_pixel_values = torch.zeros((batch_size, 3, 224, 224), device=device)



In [None]:
# Bind the encoder class's own forward instead of reading `model.encoder.forward`.
# A later cell replaces that instance attribute with `patched_forward`; if this
# cell were re-run afterwards, reading the attribute would capture the patch
# itself and `patched_forward` would then call itself recursively.
encoder_forward = type(model.encoder).forward.__get__(model.encoder)

In [8]:

def patched_forward(*args, **kwargs):
    '''Encoder forward replacement that substitutes the music embedding.

    The original encoder forward is still called so that the returned object
    has the type the rest of the HuggingFace model expects; its
    ``last_hidden_state`` is then overwritten with the projected embedding.
    Reads the notebook-level globals ``encoder_forward``, ``b2t`` and ``EMBS``
    (``EMBS`` is set by the training loop before each model call).
    '''
    encoder_output = encoder_forward(*args, **kwargs)
    projected = b2t(EMBS)
    # NOTE(review): 197 is presumably the ViT token count for the 224x224 fake
    # images (patches plus a class token) — confirm against the checkpoint config.
    encoder_output.last_hidden_state = projected.repeat(1, 197, 1)
    return encoder_output

# The pretrained model runs a vision transformer in its encoder forward; we
# bypass its output and feed the decoder our music embeddings instead.
model.encoder.forward = patched_forward

In [10]:
num_epochs = 5
for epoch in tqdm(range(num_epochs)):
    # Iterate the loader itself so each epoch visits every training batch once.
    # Calling next(iter(train_dataloader)) per step would build a fresh
    # iterator each time and only ever consume its first batch.
    # EMBS must stay a notebook-level name: patched_forward reads it.
    for step, (captions, EMBS) in enumerate(tqdm(train_dataloader)):

        tokenized = tokenizer(captions, padding='longest', return_tensors='pt')
        captions_tok = tokenized['input_ids'].to(device)
        # Exclude padding from the loss: label positions set to -100 are
        # ignored by the model's cross-entropy.
        captions_tok[tokenized['attention_mask'].to(device) == 0] = -100

        # The last batch of an epoch can hold fewer than batch_size items, so
        # cut the placeholder images down to the size of the current batch.
        loss = model(fake_pixel_values[:len(captions)], labels=captions_tok).loss

        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

        if step % 20 == 0:
            # Caption only the first clip of the batch as a qualitative check.
            EMBS = EMBS[0:1]
            fake_eval_pixel_values = torch.zeros((1,3, 224, 224)).to(device)
            # PREDICTION1: plain beam search.
            output_ids = model.generate(fake_eval_pixel_values, max_length=128, num_beams=2)
            printr('[blue bold] PREDICTION1: ' + tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip())
            # PREDICTION2: beam search combined with sampling.
            output_ids = model.generate(fake_eval_pixel_values, max_length=128, num_beams=4, do_sample=True, temperature=0.8)
            printr('[blue bold] PREDICTION2: ' + tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip())
            printr('[green bold] TRUE CAPTION: ' + captions[0])
            print()


        if step % 200 == 199:
            # Show the loss curve accumulated so far.
            plt.plot(losses)
            plt.show()

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/1465 [00:00<?, ?it/s]










KeyboardInterrupt: 