# Finetuning Transformer embeddings
- Create Siamese network with transformer archetecture
- Finetune pretrained models to determine song similarity
- The pretrained models considered are
    - Wav2VecBertModel from https://huggingface.co/docs/transformers/v4.44.1/en/model_doc/wav2vec2-bert
    - Wav2Vec2BertForSequenceClassification model from https://huggingface.co/docs/transformers/v4.44.1/en/model_doc/wav2vec2-bert#transformers.Wav2Vec2BertForSequenceClassification
    - HubertForSequenceClassification model from https://huggingface.co/docs/transformers/v4.44.1/en/model_doc/hubert#transformers.HubertForSequenceClassification
    - ASTModel from https://huggingface.co/docs/transformers/en/model_doc/audio-spectrogram-transformer#transformers.ASTModel

In [1]:
import pandas as pd
import numpy as np
import librosa
import librosa.display
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
#from dotenv import dotenv_values 
#import spotipy
#from spotipy.oauth2 import SpotifyClientCredentials
import pickle as pkl

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import train_test_split

from transformers import  HubertForSequenceClassification, Wav2Vec2BertModel, Wav2Vec2BertForSequenceClassification, AutoProcessor, AutoFeatureExtractor, ASTModel

2024-08-27 19:17:18.633850: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-27 19:17:18.633970: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-27 19:17:18.940433: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Define Dataset classes for each model

In [2]:
# Dataset class for the Wav2Vec2BertModel embedding
class W2VBertModelDataset(Dataset):

    def __init__(self, audio, sr=16000):

        self.audio = audio
        self.sr = sr
        self.feature_extractor = AutoProcessor.from_pretrained("hf-audio/wav2vec2-bert-CV16-en")
    
    def __len__(self):
        return self.audio.shape[0]
    
    def __getitem__(self, idx):

        # Load audio data and select idx'th example and get [0] to get audio from (y, sr) tuple
        anchor = self.audio['processed_audio'].values[idx][0]#.astype(np.float16)
        positive = self.audio['augmented_audio'].values[idx][0]#.astype(np.float16)
        negative = self.audio['diff_processed_audio'].values[idx][0]#.astype(np.float16)

        # Preprocess data
        anchor_mel = self.feature_extractor(anchor, sampling_rate=self.sr, return_tensors="pt").input_features
        positive_mel = self.feature_extractor(positive, sampling_rate=self.sr, return_tensors="pt").input_features
        negative_mel = self.feature_extractor(negative, sampling_rate=self.sr, return_tensors="pt").input_features

        return anchor_mel, positive_mel, negative_mel

In [3]:
# Dataset class for the Wav2Vec2BertSequence model
class W2VBSeqDataset(Dataset):

    def __init__(self, audio, sr=16000):

        self.audio = audio
        self.sr = sr
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
    
    def __len__(self):
        return self.audio.shape[0]
    
    def __getitem__(self, idx):

        # Load audio data and select idx'th example and get [0] to get audio from (y, sr) tuple
        anchor = self.audio['processed_audio'].values[idx][0]#.astype(np.float16)
        positive = self.audio['augmented_audio'].values[idx][0]#.astype(np.float16)
        negative = self.audio['diff_processed_audio'].values[idx][0]#.astype(np.float16)

        # Preprocess data
        anchor_mel = self.feature_extractor(anchor, sampling_rate=self.sr, return_tensors="pt").input_features
        positive_mel = self.feature_extractor(positive, sampling_rate=self.sr, return_tensors="pt").input_features
        negative_mel = self.feature_extractor(negative, sampling_rate=self.sr, return_tensors="pt").input_features

        return anchor_mel, positive_mel, negative_mel

In [4]:
# Dataset class for the Hubert model
class HubertDataset(Dataset):

    def __init__(self, audio):

        self.audio = audio
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
    
    def __len__(self):
        return self.audio.shape[0]
    
    def __getitem__(self, idx):

        # Load audio data and select idx'th example and get [0] to get audio from (y, sr) tuple
        anchor = self.audio['processed_audio'].values[idx][0]#.astype(np.float16)
        positive = self.audio['augmented_audio'].values[idx][0]#.astype(np.float16)
        negative = self.audio['diff_processed_audio'].values[idx][0]#.astype(np.float16)
    
        anchor = self.feature_extractor(anchor, sampling_rate=16000, return_tensors="pt").input_values
        positive = self.feature_extractor(positive, sampling_rate=16000, return_tensors="pt").input_values
        negative = self.feature_extractor(negative, sampling_rate=16000, return_tensors="pt").input_values

        return anchor, positive, negative



In [5]:
# Dataset class for the Wav2Vec2BertModel embedding
class ASTModelDataset(Dataset):

    def __init__(self, audio, sr=16000):

        self.audio = audio
        self.sr = sr
        self.feature_extractor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
    
    def __len__(self):
        return self.audio.shape[0]
    
    def __getitem__(self, idx):

        # Load audio data and select idx'th example and get [0] to get audio from (y, sr) tuple
        anchor = self.audio['processed_audio'].values[idx][0]#.astype(np.float16)
        positive = self.audio['augmented_audio'].values[idx][0]#.astype(np.float16)
        negative = self.audio['diff_processed_audio'].values[idx][0]#.astype(np.float16)

        # Compute mel spectrograms
        anchor_mel = self.feature_extractor(anchor, sampling_rate=self.sr, return_tensors="pt").input_values
        positive_mel = self.feature_extractor(positive, sampling_rate=self.sr, return_tensors="pt").input_values
        negative_mel = self.feature_extractor(negative, sampling_rate=self.sr, return_tensors="pt").input_values

        return anchor_mel, positive_mel, negative_mel

## Define model classes

In [7]:
# Use the Wav2VecBertModel from https://huggingface.co/docs/transformers/v4.44.1/en/model_doc/wav2vec2-bert
# with an added fully connected layer to create embedding for Siamese network

class W2VBertModelEmbedding(nn.Module):
    def __init__(self, input_len = 160, embedding_dim=128):
        super(W2VBertModelEmbedding, self).__init__()

        # Load pre-trained Wav2Vec model
        self.model = Wav2Vec2BertModel.from_pretrained("hf-audio/wav2vec2-bert-CV16-en")
        self.model.feature_projection.layer_norm = nn.LayerNorm((input_len,), eps=1e-05, elementwise_affine=True)
        self.model.feature_projection.projection = nn.Linear(in_features=input_len, out_features=1024, bias=True)
        
        # Add a fully connected layer to project the hidden states to the desired embedding dimension
        self.fc = nn.Linear(self.model.config.hidden_size, embedding_dim)
        
        #self.relu = nn.ReLU()

    def forward(self, x):
        # Extract the last hidden state from the model
        x = self.model(x).last_hidden_state
        
        x = self.fc(x[:,0,:])
        
        #x = self.relu(x)
        
        # Normalize the output embeddings
        return F.normalize(x, p=2, dim=1)

In [8]:
# Use the Wav2Vec2BertForSequenceClassification model from https://huggingface.co/docs/transformers/v4.44.1/en/model_doc/wav2vec2-bert#transformers.Wav2Vec2BertForSequenceClassification
# to create embedding for Siamese network
class W2VBSeqEmbedding(nn.Module):
    def __init__(self, input_len = 160, embedding_dim=128):
        super(W2VBSeqEmbedding, self).__init__()

        # Load pre-trained Wav2Vec model
        self.model = Wav2Vec2BertForSequenceClassification.from_pretrained("facebook/w2v-bert-2.0")
        self.model.wav2vec2_bert.feature_projection.layer_norm = nn.LayerNorm((input_len,), eps=1e-05, elementwise_affine=True)
        self.model.wav2vec2_bert.feature_projection.projection = nn.Linear(in_features=input_len, out_features=1024, bias=True)
        
        # Add a fully connected layer to project the hidden states to the desired embedding dimension
        self.model.classifier = nn.Linear(self.model.classifier.in_features, embedding_dim)
        
    def forward(self, x):
        
        x = self.model(x).logits
        
        # Normalize the output embeddings
        return F.normalize(x, p=2, dim=1)

In [9]:
# Use the HubertForSequenceClassification model from https://huggingface.co/docs/transformers/v4.44.1/en/model_doc/hubert#transformers.HubertForSequenceClassification
# to create embedding for Siamese network

class HubertEmbedding(nn.Module):
    def __init__(self, embedding_dim=128):
        super(HubertEmbedding, self).__init__()
        
        # Load pre-trained Hubert model
        self.model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
        
        # Add a fully connected layer to project the hidden states to the desired embedding dimension
        self.model.classifier = nn.Linear(self.model.classifier.in_features, embedding_dim)

    def forward(self, x):
        # Extract the last hidden state from the DistilHuBERT model
        x = self.model(x.squeeze(1)).logits#.last_hidden_state
        
        # Apply the fully connected layer to reduce the dimension
        #x = self.fc(x.mean(dim=1))  # Take mean over the time dimension
        
        # Normalize the output embeddings
        return F.normalize(x, p=2, dim=1)

In [10]:
class AudioSpecTransformerModel(torch.nn.Module):
    def __init__(self, embedding_dim=128):
        super(AudioSpecTransformerModel, self).__init__()
        
        self.model = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

        self.fc = torch.nn.Linear(self.model.config.hidden_size, embedding_dim)

    def forward(self, x):
        outputs = self.model(x).last_hidden_state
        
        # Get [cls] token embedding for classification/summary of embedding
        x = self.fc(outputs[:, 0, :])
        
        return F.normalize(x, p=2, dim=1)

## Make helper functions

In [12]:

def make_train_loaders(model, dir_path = '/kaggle/input/augmented-music/combined_batch_augmented.pkl', dataset_size = 10000, train_batch_size = 16, val_batch_size = 16):

    '''
        Creates DataLoaders to be used in training and validation

        Arguments: 
            model -- model to be trained
            dataset_size -- int, size of dataset to be imported
            dir_path -- path to dataset
            train(val)_batch_size -- size of training (validation) batch sizes
        Returns:
            train(val)_loader -- DataLoader of training (validation) samples
    '''
    

    audio_data = pd.read_pickle(dir_path).iloc[:dataset_size]
    
    # Make sure the requested dataset size is no bigger than the given dataset
    assert dataset_size <= len(audio_data), 'dataset_size is larger than the imported dataset'
    
    # Make sure the dataset has 'processed_audio', 'augmented_audio', and 'diff_processed_audio' columns
    assert 'processed_audio' and 'augmented_audio' and 'diff_processed_audio' in audio_data.columns, 'DataFrame must contain columns \'processed_audio\', \'augmented_audio\', and \'diff_processed_audio\''

    # Split the data into training and validation sets
    train_data, val_data = train_test_split(audio_data, test_size=0.2, random_state=123)

    try:
        if isinstance(model, W2VBSeqEmbedding):
            train_dataset = W2VBSeqDataset(train_data)
            val_dataset = W2VBSeqDataset(val_data)
        elif isinstance(model, W2VBertModelEmbedding):
            train_dataset = W2VBertModelDataset(train_data)
            val_dataset = W2VBertModelDataset(val_data)            
        elif isinstance(model, HubertEmbedding):
            train_dataset = HubertDataset(train_data)
            val_dataset = HubertDataset(val_data)
        elif isinstance(model, AudioSpecTransformerModel):
            train_dataset = ASTModelDataset(train_data)
            val_dataset = ASTModelDataset(val_data)
    except NameError:
        None


    # Define dataloaders
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False)

    return train_loader, val_loader

## Set up the model by freezing all but the deepest layers

In [13]:
#Freeze layers for finetuning
def setup_model(embedding):
    '''
    Freeze all but the last few layers of the model.
    Argument:
        embedding -- model, must be one of the embedding types defined above
    Returns:
        None
    '''

    try:
        if isinstance(embedding, W2VBertModelEmbedding):
            # Freeze all the layers
            for param in embedding.model.parameters():
                param.requires_grad = False

            # Turn back on adapter
            for param in embedding.model.adapter.parameters():
                param.requires_grad = True
                
            # Turn back on last fully connected layer
            for param in embedding.fc.parameters():
                param.requires_grad = True
    except NameError:
        None

    try:
        if isinstance(embedding, W2VBSeqEmbedding):

            # Freeze all the layers
            for param in embedding.model.parameters():
                param.requires_grad = False

            for i in range(21,23):
                # Turn back on ith encoder layer
                for param in embedding.model.wav2vec2_bert.encoder.layers[i].parameters():
                    param.requires_grad = True
                
            # Turn back on projector
            for param in embedding.model.projector.parameters():
                param.requires_grad = True
                
            # Turn back on classifier
            for param in embedding.model.classifier.parameters():
                param.requires_grad = True
    except NameError:
         None

    try:
        if isinstance(embedding, HubertEmbedding):
            # Freeze all the layers
            for param in embedding.model.parameters():
                param.requires_grad = False

            for param in embedding.model.hubert.encoder.pos_conv_embed.conv.parametrizations.weight:
                param.requires_grad = False
                
            # Turn back on last layer
            for param in embedding.model.hubert.encoder.layers[11].parameters():
                param.requires_grad = True
                
            # Turn back on projector
            for param in embedding.model.projector.parameters():
                param.requires_grad = True
                
            # Turn back on classifier
            for param in embedding.model.classifier.parameters():
                param.requires_grad = True
    except NameError:
        None
        
    try:
        if isinstance(embedding, AudioSpecTransformerModel):
            # Freeze all the layers
            for param in embedding.model.parameters():
                param.requires_grad = False

            for i in range(7,11):
                # Turn back on ith encoder layer
                for param in embedding.model.encoder.layer[i].parameters():
                    param.requires_grad = True
                
            # Turn back on projector
            for param in embedding.model.layernorm.parameters():
                param.requires_grad = True
                
            # Turn back on last fully connected layer
            for param in embedding.fc.parameters():
                param.requires_grad = True
    except NameError:
        None
    


## Choose a model, define DataLoaders, choose optimizer, etc

In [14]:
# Choose a model from one of the classes defined above

model = AudioSpecTransformerModel()
#model = W2VBertModelEmbedding()
#model = W2VBSeqEmbedding()
#model = HubertEmbedding()

config.json:   0%|          | 0.00/1.87k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.32G [00:00<?, ?B/s]

Some weights of Wav2Vec2BertForSequenceClassification were not initialized from the model checkpoint at facebook/w2v-bert-2.0 and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
train_loader, val_loader = make_train_loaders(model, 
                                              '/kaggle/input/augmented-music/batch_1_augmented_16000Hz.pkl',
                                              dataset_size = 1000,
                                              train_batch_size = 16,
                                              val_batch_size = 16)

preprocessor_config.json:   0%|          | 0.00/275 [00:00<?, ?B/s]

In [16]:
print('Before set up, this model has', str(sum(p.numel() for p in model.parameters() if p.requires_grad)), 'trainable parameters.')
setup_model(model)
print('After set up, this model now has', str(sum(p.numel() for p in model.parameters() if p.requires_grad)), 'trainable parameters.')

Before set up, this model has 581378752 trainable parameters.
After set up, this model now has 49246208 trainable parameters.


In [17]:
criterion = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)
train_losses = []
val_losses = []
baseline_losses = []

num_epochs = 3

In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

W2VBSeqEmbedding(
  (model): Wav2Vec2BertForSequenceClassification(
    (wav2vec2_bert): Wav2Vec2BertModel(
      (feature_projection): Wav2Vec2BertFeatureProjection(
        (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
        (projection): Linear(in_features=160, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): Wav2Vec2BertEncoder(
        (dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-23): 24 x Wav2Vec2BertEncoderLayer(
            (ffn1_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (ffn1): Wav2Vec2BertFeedForward(
              (intermediate_dropout): Dropout(p=0.0, inplace=False)
              (intermediate_dense): Linear(in_features=1024, out_features=4096, bias=True)
              (intermediate_act_fn): SiLU()
              (output_dense): Linear(in_features=4096, out_features=1024, bias=True)
              (output_dropout): Dropout

## Train model

In [None]:

scaler = torch.cuda.amp.GradScaler()
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    running_train_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Training {epoch+1}/{num_epochs}", unit="sample", position=0, leave=True)
    #train_loader = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
    # Loop over batches using dataloaders
    for anchors, positives, negatives in train_loader:
        anchors, positives, negatives = anchors.to(device), positives.to(device), negatives.to(device)

        optimizer.zero_grad()
            
        with torch.autocast(device.type):
            try:
                if isinstance(model, (W2VBertModelEmbedding, W2VBSeqEmbedding, AudioSpecTransformerModel)):
                    anchor_embeddings = model(anchors.squeeze())
                    positive_embeddings = model(positives.squeeze())
                    negative_embeddings = model(negatives.squeeze())
                else:
                    anchor_embeddings = model(anchors)
                    positive_embeddings = model(positives)
                    negative_embeddings = model(negatives)
            except NameError:
                None

            loss = criterion(anchor_embeddings, positive_embeddings, negative_embeddings)
    
        del anchor_embeddings, positive_embeddings, negative_embeddings
    
        # Scales the loss, and calls backward() to create scaled gradients
        scaler.scale(loss).backward()
            
        # Unscales gradients and calls or skips optimizer.step()
        scaler.step(optimizer)
        
        # Updates the scale for next iteration
        scaler.update()
        #loss.backward()
        #optimizer.step()
        
        running_train_loss += loss.item() * anchors.size(0)
        
        # Update the progress bar by the current batch size
        pbar.set_postfix({'loss': loss.item()})
        pbar.update(1)  # Increment the progress bar
        #del loss
          
    train_loss = running_train_loss / len(train_loader.dataset)
    train_losses.append(train_loss)

    # Turn on validation/eval mode
    model.eval()
    running_val_loss = 0.0 
    running_baseline_loss = 0.0   
    val_pbar = tqdm(val_loader, desc=f"Validation {epoch+1}/{num_epochs}", unit="batch", position=0, leave=True)
    # Turn off gradient updates since we're in validation
    with torch.no_grad():
        # Batch loop 
        for anchors, positives, negatives in val_loader:
            anchors, positives, negatives = anchors.to(device), positives.to(device), negatives.to(device)
            
            try:
                if isinstance(model, (W2VBertModelEmbedding, W2VBSeqEmbedding, AudioSpecTransformerModel)):
                    anchor_embeddings = model(anchors.squeeze())
                    positive_embeddings = model(positives.squeeze())
                    negative_embeddings = model(negatives.squeeze())
                else:
                    anchor_embeddings = model(anchors)
                    positive_embeddings = model(positives)
                    negative_embeddings = model(negatives)
            except NameError:
                None
            
            loss = criterion(anchor_embeddings, positive_embeddings, negative_embeddings)

            # Add to running val loss
            running_val_loss += loss.item() * anchors.size(0)
            
            # baseline loss
            baseline_loss = criterion(F.normalize(anchors), 
                                               F.normalize(positives), 
                                               F.normalize(negatives)).item()
            running_baseline_loss += baseline_loss*anchors.size(0)
            
            # Update the validation progress bar
            val_pbar.update(1)  # Increment the progress bar
            #val_pbar.update(anchors.size(0))
            #val_pbar.set_postfix(loss=loss.item())
    
    # Calculate average validation loss over the entire dataset
    val_loss = running_val_loss / len(val_loader.dataset)
    val_losses.append(val_loss)
    # Do the same for the baseline
    baseline_avg_loss = running_baseline_loss / len(val_loader.dataset)
    baseline_losses.append(baseline_avg_loss)

    
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Baseline Loss: {baseline_avg_loss:.4f}")
    
    with open('training_logs.pkl', 'wb') as f:
        pkl.dump((train_losses, val_losses, baseline_losses), f)

Validation 1/3: 100%|██████████| 13/13 [02:43<00:00, 11.03s/batch]loss=0.436]

Epoch [1/3], Train Loss: 0.9927, Val Loss: 0.9815, Baseline Loss: 0.7368


Training 1/3: 100%|██████████| 50/50 [08:29<00:00, 10.19s/sample, loss=0.436]
Validation 1/3: 100%|██████████| 13/13 [08:35<00:00, 39.66s/batch]loss=1.13] 
Validation 2/3:  23%|██▎       | 3/13 [00:38<02:08, 12.83s/batch]

In [None]:
# Plot loss curves for training
epochs = range(1, num_epochs + 1)

plt.figure(figsize=(10, 5))
plt.plot(epochs, train_losses, label='Training Loss')
plt.plot(epochs, val_losses, label='Validation Loss')
plt.plot(epochs, baseline_losses, label='Baseline Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.tight_layout()
plt.savefig('audio-spec-loss-plot.png')
plt.show()

In [None]:
model.eval()

a,p,n = next(iter(train_loader))

with torch.no_grad():
    out1 = model(a[0].to(device))
    out2 = model(p[0].to(device))
    out3 = model(n[0].to(device))

In [None]:
torch.nn.functional.cosine_similarity(out1,out2)

In [None]:
torch.nn.functional.cosine_similarity(out1,out3)