In [None]:
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.9.0-py3-none-any.whl (462 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m462.8/462.8 KB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting huggingface-hub<1.0.0,>=0.2.0
  Downloading huggingface_hub-0.12.1-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.3/190.3 KB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess
  Downloading multiprocess-0.70.14-py38-none-any.whl (132 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.0/132.0 KB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash
  Downloading xxhash-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (213 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m213.0/213.0 KB[0m

In [None]:
import numpy as np
import string
import os
import torch
import torchvision
import pandas as pd
from datasets import load_dataset
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import StepLR

import warnings
# UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() 
# or sourceTensor.clone().detach().requires_grad_(True) rather than torch.tensor(sourceTensor)
warnings.simplefilter("ignore", category=UserWarning)

## Data loading and cleanup

In [None]:
# Load embeddings, which is a dictionary (ytid -> 512-dim feature vector).
# NOTE(review): allow_pickle=True can execute arbitrary code from the file;
# only load embeddings.npy from a trusted source.
try:
    embeddings = np.load("embeddings.npy", allow_pickle=True).item()
    ids = sorted(embeddings.keys())
    print(f"Loaded {len(ids)} feature vectors.")
except Exception as err:
    # Report the actual cause; a bare `except:` hid it (and also caught
    # KeyboardInterrupt).
    print(f"Error loading embeddings: {err!r}")

# Load MusicCaps captions
try:
    ds = load_dataset('google/MusicCaps', split='train')
    df = pd.DataFrame({'ytid': ds['ytid'], 'caption': ds['caption'], 'is_eval': ds['is_audioset_eval']})
    # Discard entries without an embedding. copy() makes df_captions an
    # independent frame, so later column assignments don't target a slice of df.
    df_captions = df[df.ytid.isin(ids)].copy()
    print(f"Loaded {df.shape[0]} captions, using {df_captions.shape[0]} ")
except Exception as err:
    # If the embeddings failed above, this shows the NameError for `ids`
    # instead of an unexplained "Error loading captions".
    print(f"Error loading captions: {err!r}")

# Translation table that deletes every ASCII punctuation character; built once.
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)

def cleaning_text(caption):
    """Normalize a caption into space-separated lowercase alphabetic words.

    Hyphens become spaces (so "low-quality" -> "low quality"). Punctuation is
    stripped, and tokens that are a single character or not purely alphabetic
    (e.g. containing digits) are dropped.

    Args:
        caption: raw caption text.

    Returns:
        The cleaned caption string (possibly empty).
    """
    # str.replace returns a new string; the result must be kept, otherwise
    # hyphenated words get glued together when punctuation is removed below.
    caption = caption.replace("-", " ")
    words = (word.lower().translate(_PUNCT_TABLE) for word in caption.split())
    # Remove hanging 's / a, and any token still containing digits or symbols.
    return ' '.join(word for word in words if len(word) > 1 and word.isalpha())

def preprocess_captions(captions):
    """Clean every caption with ``cleaning_text``, overwriting the values in place.

    Args:
        captions: dict mapping an id (a ytid in this notebook) to raw caption text.

    Returns:
        The same dict object, now holding the cleaned captions.
    """
    for key in captions:
        captions[key] = cleaning_text(captions[key])
    return captions

# Map every MusicCaps ytid (all rows of df) to its cleaned caption text.
captions = preprocess_captions(dict(zip(df.ytid, df.caption)))

# define a vocabulary
def text_vocabulary(descriptions):
    """Collect the set of distinct whitespace-separated tokens across all captions.

    Args:
        descriptions: mapping whose values are caption strings.

    Returns:
        set: every token found, plus the special tokens '<start>', '<end>',
        ':', ',' and ';'.
    """
    special_tokens = {'<start>', '<end>', ':', ',', ';'}
    corpus_tokens = {token for text in descriptions.values() for token in text.split()}
    return special_tokens | corpus_tokens

# Force <pad> to index 0 (convention), then sort the remaining tokens.
# Iterating a set of strings yields an order that depends on PYTHONHASHSEED,
# so an unsorted vocab would remap every token id on each kernel restart and
# silently invalidate saved models / cached tokenizations.
vocab = ['<pad>'] + sorted(text_vocabulary(captions))

word_to_idx = {word: idx for idx, word in enumerate(vocab)}
idx_to_word = dict(enumerate(vocab))

def tokenize(caption):
    """Clean a raw caption and map it to vocabulary ids framed by <start>/<end>.

    Args:
        caption: raw caption text.

    Returns:
        list[int]: ``word_to_idx`` ids for ['<start>', *cleaned words, '<end>'].

    Raises:
        KeyError: if a cleaned word is not in ``word_to_idx``.
    """
    words = cleaning_text(caption).split()
    return [word_to_idx['<start>'], *(word_to_idx[w] for w in words), word_to_idx['<end>']]

# Attach token ids as a new column. assign() returns a new frame, so this never
# writes into a filtered slice of `df` (which raised SettingWithCopyWarning).
df_captions = df_captions.assign(tokenized_caption=df_captions['caption'].apply(tokenize))
# Split on MusicCaps' AudioSet-eval flag.
train_df = df_captions[~df_captions.is_eval]
eval_df = df_captions[df_captions.is_eval]
df_captions.head()

Loaded 5493 feature vectors.




Loaded 5521 captions, using 5493 


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_captions['tokenized_caption'] = df_captions['caption'].apply(tokenize)


Unnamed: 0,ytid,caption,is_eval,tokenized_caption
0,-0Gj8-vB1q4,The low quality recording features a ballad so...,True,"[343, 5267, 3761, 311, 2156, 1632, 5558, 3133,..."
1,-0SdAVK79lg,This song features an electric guitar as the m...,False,"[343, 602, 3133, 1632, 1171, 3100, 941, 665, 5..."
2,-0vPFx-wRRI,a male voice is singing a melody with changing...,True,"[343, 4470, 1719, 2579, 2369, 910, 804, 5582, ..."
3,-0xzrMun0Rs,This song contains digital drums playing a sim...,True,"[343, 602, 3133, 860, 1544, 4688, 2165, 4729, ..."
4,-1LrH01Ei1w,This song features a rubber instrument being p...,False,"[343, 602, 3133, 1632, 834, 3105, 2250, 4711, ..."


In [None]:
# Define the audio captioning dataset
class AudioCaptionDataset(Dataset):
    """Map-style dataset pairing each caption record with its audio embedding.

    Args:
        captions: indexable sequence of record dicts, each with at least the keys
            'ytid' and 'tokenized_caption' (a list of token ids).
        embeddings: mapping from ytid to that clip's embedding (returned as-is).
    """

    def __init__(self, captions, embeddings):
        self.captions = captions
        self.embeddings = embeddings

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        record = self.captions[idx]
        embedding = self.embeddings[record['ytid']]
        return {
            "embedding": embedding,
            # int64 tensor of token ids (same result as the legacy LongTensor ctor)
            "tokenized_caption": torch.tensor(record['tokenized_caption'], dtype=torch.long),
        }

# Define the collate function for the audio captioning dataset
def collate_fn_try(batch, padding_value=None):
    """Collate dataset items into (embeddings, padded token ids) batch tensors.

    Args:
        batch: list of dicts with 'embedding' (a numpy array) and
            'tokenized_caption' (a LongTensor or a sequence of ints).
        padding_value: token id used to right-pad captions; defaults to
            ``word_to_idx["<pad>"]``.

    Returns:
        (padded_embeddings, padded_captions). Embeddings go through
        pad_sequence (for equal-length 1-D vectors this is a plain stack);
        captions are right-padded to the longest caption in the batch.
    """
    if padding_value is None:
        padding_value = word_to_idx["<pad>"]
    embeddings = [torch.from_numpy(item['embedding']) for item in batch]
    # as_tensor reuses an existing LongTensor instead of copy-constructing it
    # with torch.tensor(), which emitted the "To copy construct" UserWarning.
    captions = [torch.as_tensor(item['tokenized_caption'], dtype=torch.long) for item in batch]
    padded_embeddings = nn.utils.rnn.pad_sequence(embeddings, batch_first=True)
    padded_captions = nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=padding_value)
    return padded_embeddings, padded_captions

In [None]:
class LSTMModel(nn.Module):
    """Caption decoder: an LSTM whose first input step is the audio feature vector.

    The audio features are prepended to the embedded token sequence, so their
    size must equal ``hidden_size``.

    Args:
        input_size: vocabulary size for the token embedding table.
        hidden_size: token-embedding size and LSTM hidden size.
        output_size: number of logits produced per step (vocabulary size).
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability applied to the token embeddings.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0.4):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Submodules are created in this order so RNG-driven default inits are reproducible.
        self.dropout = nn.Dropout(p=dropout)
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        # Custom init: small uniform weights and a zero output bias.
        nn.init.uniform_(self.embedding.weight, -0.1, 0.1)
        nn.init.uniform_(self.fc.weight, -0.1, 0.1)
        nn.init.zeros_(self.fc.bias)

    def forward(self, input, features):
        """Run the decoder over a token sequence conditioned on audio features.

        Args:
            input: token ids, assumed (batch, seq_len) LongTensor.
            features: audio features, assumed (batch, hidden_size).

        Returns:
            (logits, hidden): logits of shape (batch, seq_len, output_size), one
            step per input token (the training loop scores step t against the
            token after input[:, t]); hidden is the LSTM's final (h_n, c_n).
        """
        token_emb = self.dropout(self.embedding(input))
        # Time step 0 is the audio feature vector; steps 1.. are the tokens.
        lstm_in = torch.cat([features.unsqueeze(1), token_emb], dim=1)
        lstm_out, hidden = self.lstm(lstm_in)
        # Drop the output produced at the feature step so logits align with `input`.
        logits = self.fc(lstm_out)[:, 1:, :]
        return logits, hidden


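The following snippet is an earlier, unused variant of `LSTMModel`, kept for reference only (it is not executed). Instead of prepending the audio features as the first LSTM time step, it accepts an optional `hidden` tensor and uses it as the first layer's initial hidden state. Note that it reads the globals `num_layers`, `hidden_size` and `device` rather than its own attributes.

```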
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden=None):
        c0 = torch.zeros(self.num_layers, input.size(0), self.hidden_size).to(input.device)
        if hidden is None:
            h0 = torch.zeros(self.num_layers, input.size(0), self.hidden_size).to(input.device)
        else:
            if num_layers == 1: h0 = hidden.unsqueeze(0)
            else: h0 = torch.cat((hidden.unsqueeze(0),
                            torch.zeros(num_layers-1, input.size(0), hidden_size).to(device)), dim=0)
        hidden = (h0, c0)

        embedded = self.embedding(input)
        output, hidden = self.lstm(embedded, hidden)
        output = self.fc(output)

        return output, hidden
```

## Model training

In [None]:
# Instantiate the model
SEED = 0
# Seed before building the model so weight init, dropout masks and DataLoader
# shuffling are reproducible across runs.
torch.manual_seed(SEED)

vocab_size = len(vocab)
input_size = vocab_size
# LSTMModel feeds the audio embedding in as an LSTM input step, so hidden_size
# must equal the embedding dimension.
hidden_size = 512
output_size = vocab_size
num_layers = 2
assert next(iter(embeddings.values())).shape[-1] == hidden_size, \
    "hidden_size must match the audio-embedding dimension"
model = LSTMModel(input_size, hidden_size, output_size, num_layers)

# Instantiate the datasets, each holding only the embeddings its split needs
train_dataset = AudioCaptionDataset(train_df.to_dict('records'),
                {ytid: embeddings[ytid] for ytid in train_df.ytid.unique()})
eval_dataset = AudioCaptionDataset(eval_df.to_dict('records'),
                {ytid: embeddings[ytid] for ytid in eval_df.ytid.unique()})

# Training configuration
lr = 5e-4
batch_size = 64
num_epochs = 50
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Padding positions are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=word_to_idx['<pad>'])
criterion = loss_fn  # alias of the formerly duplicated loss object; the loops use loss_fn
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)  # halves the LR every 10 scheduler steps

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_try)
# Evaluation order does not change the averaged loss, so keep it fixed.
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn_try)

In [None]:
# Pull one collated training batch and report the shape of each tensor in it.
batch = next(iter(train_dataloader))
print(*(f"{name} batch, shape={tensor.shape}" for name, tensor in zip(("Embeddings", "Captions"), batch)), sep="\n")

Embeddings batch, shape=torch.Size([64, 512])
Captions batch, shape=torch.Size([64, 133])


In [None]:
def generate_caption(i, max_caption_length=64, show_true_caption=True, show_ytid=True):
    """Greedily decode a caption for row ``i`` of ``df_captions`` and print it.

    Args:
        i: positional row index into ``df_captions``.
        max_caption_length: maximum decoded length in tokens, counting <start>.
        show_true_caption: also print the reference caption.
        show_ytid: also print the clip's YouTube URL.
    """
    ytid = df_captions.iloc[i]["ytid"]
    true_caption = df_captions.iloc[i]["caption"]
    features = torch.from_numpy(embeddings[ytid]).unsqueeze(0).to(device, dtype=torch.float)

    start_idx, end_idx = word_to_idx['<start>'], word_to_idx['<end>']
    # Begin from <start> alone, exactly as training sequences begin.
    tokens = [start_idx]

    was_training = model.training
    model.eval()
    with torch.no_grad():
        while tokens[-1] != end_idx and len(tokens) < max_caption_length:
            caption = torch.tensor([tokens], dtype=torch.long, device=device)
            # Feed the whole prefix: the last output step predicts the next token.
            # Dropping the newest token (caption[:, :-1]) conditions every step on a
            # stale prefix, which makes each word be predicted twice.
            logits, _ = model(caption, features)
            tokens.append(logits[0, -1].argmax().item())
    model.train(was_training)

    # Strip <start>, and the trailing <end> only if decoding actually produced it
    # (otherwise the last real word would be lost at max_caption_length).
    words = tokens[1:-1] if tokens[-1] == end_idx else tokens[1:]
    predicted_caption = ' '.join(idx_to_word[idx] for idx in words)

    if show_ytid:
        print(f"https://www.youtube.com/watch?v={ytid}")
    if show_true_caption:
        print(f"True caption: {true_caption}")
    print(f"Predicted caption: {predicted_caption}")
  

In [None]:
# Train the model.
# Both losses are averaged over non-<pad> target tokens (the tokens loss_fn
# averages over), so the train and eval numbers are directly comparable.
pad_idx = word_to_idx['<pad>']

def _batch_loss(x, caption_batch):
    """Teacher-forced loss for one batch.

    Returns:
        (loss, n_tokens): loss_fn's mean over non-<pad> targets, and the number
        of such targets (used to weight the epoch average).
    """
    x = x.to(device, dtype=torch.float)
    caption_batch = caption_batch.to(device, dtype=torch.long)
    # Feed tokens [0, T-2] and score predictions against tokens [1, T-1].
    logits, _ = model(caption_batch[:, :-1], x)
    targets = caption_batch[:, 1:]
    loss = loss_fn(logits.reshape(-1, vocab_size), targets.reshape(-1))
    n_tokens = int((targets != pad_idx).sum())
    return loss, n_tokens

for epoch in range(num_epochs):
    model.train()  # set model to train mode
    train_loss_sum, train_tokens = 0.0, 0
    # `caption_batch` (not `captions`) so the global captions dict isn't clobbered.
    for x, caption_batch in train_dataloader:
        optimizer.zero_grad()
        loss, n_tokens = _batch_loss(x, caption_batch)
        loss.backward()
        optimizer.step()
        train_loss_sum += loss.item() * n_tokens
        train_tokens += n_tokens
    scheduler.step()
    train_loss = train_loss_sum / max(train_tokens, 1)

    model.eval()  # set model to eval mode
    eval_loss_sum, eval_tokens = 0.0, 0
    with torch.no_grad():
        for x, caption_batch in eval_dataloader:
            loss, n_tokens = _batch_loss(x, caption_batch)
            eval_loss_sum += loss.item() * n_tokens
            eval_tokens += n_tokens
    eval_loss = eval_loss_sum / max(eval_tokens, 1)

    # Report epoch averages: at this point `loss` holds the last *eval* batch's
    # loss, so it must not be printed as the training loss.
    print(f"Epoch {epoch}, train loss {train_loss:.4f}, eval loss {eval_loss:.4f}")
    generate_caption(0, show_ytid=(epoch==0), show_true_caption=(epoch==0)) # eval song
    generate_caption(1, show_ytid=(epoch==0), show_true_caption=(epoch==0)) # train song
    print("")

Epoch 0, train loss 5.7788, eval loss 5.8476
https://www.youtube.com/watch?v=-0Gj8-vB1q4
True caption: The low quality recording features a ballad song that contains sustained strings, mellow piano melody and soft female vocal singing over it. It sounds sad and soulful, like something you would hear at Sunday services.
Predicted caption: this this is is is is is is the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the
https://www.youtube.com/watch?v=-0SdAVK79lg
True caption: This song features an electric guitar as the main instrument. The guitar plays a descending run in the beginning then plays an arpeggiated chord followed by a double stop hammer on to a higher note and a descending slide followed by a descending chord run. The percussion plays a simple beat using rim shots. The percussion plays in common time. The bass plays o

In [None]:
# Number of trainable parameters in the model (displayed as the cell output).
sum(param.numel() for param in model.parameters() if param.requires_grad)

10547246

### Model assessment
Predict with the trained LSTM model

In [None]:
# Caption one clip (row 3 of df_captions) with the trained model, showing its URL and reference caption.
i = 3
generate_caption(i, show_true_caption=True, show_ytid=True)

https://www.youtube.com/watch?v=-0xzrMun0Rs
True caption: This song contains digital drums playing a simple groove along with two guitars. One strumming chords along with the snare the other one playing a melody on top. An e-bass is playing the footnote while a piano is playing a major and minor chord progression. A trumpet is playing a loud melody alongside the guitar. All the instruments sound flat and are being played by a keyboard. There are little bongo hits in the background panned to the left side of the speakers. Apart from the music you can hear eating sounds and a stomach rumbling. This song may be playing for an advertisement.
Predicted caption: the the song tempo is is medium fast tempo with with steady keyboard drumming accompaniment rhythm and and other other percussion instruments playing the the song song is is emotional spirited and and romantic the the audio audio quality quality is is poor poor
