# Finetuning Transformer embeddings
- Create Siamese network with transformer archetecture
- Finetune pretrained models to determine song similarity

In [1]:
import pandas as pd
import numpy as np
import librosa
import librosa.display
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
#from dotenv import dotenv_values 
#import spotipy
#from spotipy.oauth2 import SpotifyClientCredentials
import pickle as pkl

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import train_test_split

from transformers import  HubertForSequenceClassification, Wav2Vec2BertModel, Wav2Vec2BertForSequenceClassification, AutoProcessor, AutoFeatureExtractor

2024-08-24 17:55:55.739185: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-24 17:55:55.739305: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-24 17:55:55.908299: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Dataset class for the Wav2Vec2Bert models
class SpectrogramDataset(Dataset):

    def __init__(self, audio, sr=16000, n_mels=128):

        self.audio = audio
        self.sr = sr
        self.n_mels = n_mels
    
    
    def __len__(self):
        return self.audio.shape[0]
    
    def __getitem__(self, idx):

        # Load audio data and select idx'th example and get [0] to get audio from (y, sr) tuple
        anchor = self.audio['processed_audio'].values[idx][0]#.astype(np.float16)
        positive = self.audio['augmented_audio'].values[idx][0]#.astype(np.float16)
        negative = self.audio['diff_processed_audio'].values[idx][0]#.astype(np.float16)

        # Compute mel spectrograms
        anchor_mel = self._process_audio(anchor)
        positive_mel = self._process_audio(positive)
        negative_mel = self._process_audio(negative)

        return anchor_mel, positive_mel, negative_mel
    
    # Convert raw audio to mel spectrogram
    def _process_audio(self, y):
        mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=self.sr, n_mels=self.n_mels)
        mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)
        mel_spectrogram_db = torch.tensor(mel_spectrogram_db, dtype=torch.float32).unsqueeze(0)

        return mel_spectrogram_db

In [3]:
# Dataset class for the Hubert model
class RawAudioDataset(Dataset):

    def __init__(self, audio):

        self.audio = audio
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
    
    def __len__(self):
        return self.audio.shape[0]
    
    def __getitem__(self, idx):

        # Load audio data and select idx'th example and get [0] to get audio from (y, sr) tuple
        anchor = self.audio['processed_audio'].values[idx][0]#.astype(np.float16)
        positive = self.audio['augmented_audio'].values[idx][0]#.astype(np.float16)
        negative = self.audio['diff_processed_audio'].values[idx][0]#.astype(np.float16)
    
        anchor = self.feature_extractor(anchor, sampling_rate=16000, return_tensors="pt")
        positive = self.feature_extractor(positive, sampling_rate=16000, return_tensors="pt")
        negative = self.feature_extractor(negative, sampling_rate=16000, return_tensors="pt")

        return anchor, positive, negative



In [4]:
# Use the Wav2VecBertModel from https://huggingface.co/docs/transformers/v4.44.1/en/model_doc/wav2vec2-bert
# with an added fully connected layer to create embedding for Siamese network

class W2VBertModelEmbedding(nn.Module):
    def __init__(self, input_len = 313, embedding_dim=128):
        super(W2VBertModelEmbedding, self).__init__()

        # Load pre-trained Wav2Vec model
        self.model = Wav2Vec2BertModel.from_pretrained("hf-audio/wav2vec2-bert-CV16-en")
        self.model.feature_projection.layer_norm = nn.LayerNorm((input_len,), eps=1e-05, elementwise_affine=True)
        self.model.feature_projection.projection = nn.Linear(in_features=input_len, out_features=1024, bias=True)
        
        # Add a fully connected layer to project the hidden states to the desired embedding dimension
        self.fc = nn.Linear(self.model.config.hidden_size, embedding_dim)

    def forward(self, x):
        # Extract the last hidden state from the DistilHuBERT model
        x = self.model(x).last_hidden_state
        
        # Apply the fully connected layer to reduce the dimension
        x = self.fc(x.mean(dim=1))  # Take mean over the time dimension
        
        # Normalize the output embeddings
        return F.normalize(x, p=2, dim=1)

In [5]:
# Use only the encoder from Wav2VecBertModel from https://huggingface.co/docs/transformers/v4.44.1/en/model_doc/wav2vec2-bert
# with an added fully connected layer to create embedding for Siamese network

class W2VBertEncoderEmbedding(nn.Module):
    def __init__(self, input_len = 313, embedding_dim=128):
        super(W2VBertEncoderEmbedding, self).__init__()

        # Load pre-trained Wav2Vec model
        self.feature_proj = Wav2Vec2BertModel.from_pretrained("hf-audio/wav2vec2-bert-CV16-en").feature_projection
        self.encoder = Wav2Vec2BertModel.from_pretrained("hf-audio/wav2vec2-bert-CV16-en").encoder
        self.feature_proj.layer_norm = nn.LayerNorm((input_len,), eps=1e-05, elementwise_affine=True)
        self.feature_proj.projection = nn.Linear(in_features=input_len, out_features=1024, bias=True)
        
        # Add a fully connected layer to project the hidden states to the desired embedding dimension
        self.fc = nn.Linear(self.encoder.config.hidden_size, embedding_dim)

    def forward(self, x):
        # Extract the last hidden state from the DistilHuBERT model
        x = self.feature_proj(x)[0]
        x = self.encoder(x).last_hidden_state
        
        # Apply the fully connected layer to reduce the dimension
        x = self.fc(x.mean(dim=1))  # Take mean over the time dimension
        
        # Normalize the output embeddings
        return F.normalize(x, p=2, dim=1)

In [6]:
# Use the Wav2Vec2BertForSequenceClassification model from https://huggingface.co/docs/transformers/v4.44.1/en/model_doc/wav2vec2-bert#transformers.Wav2Vec2BertForSequenceClassification
# to create embedding for Siamese network
class W2VBSeqEmbedding(nn.Module):
    def __init__(self, input_len = 313, embedding_dim=128):
        super(W2VBSeqEmbedding, self).__init__()

        # Load pre-trained Wav2Vec model
        self.model = Wav2Vec2BertForSequenceClassification.from_pretrained("facebook/w2v-bert-2.0")
        self.model.wav2vec2_bert.feature_projection.layer_norm = nn.LayerNorm((input_len,), eps=1e-05, elementwise_affine=True)
        self.model.wav2vec2_bert.feature_projection.projection = nn.Linear(in_features=input_len, out_features=1024, bias=True)
        
        # Add a fully connected layer to project the hidden states to the desired embedding dimension
        self.model.classifier = nn.Linear(self.model.classifier.in_features, embedding_dim)

    def forward(self, x):
        # Extract the last hidden state from the DistilHuBERT model
        x = self.model(x).logits#.last_hidden_state
        
        # Apply the fully connected layer to reduce the dimension
        #x = self.fc(x.mean(dim=1))  # Take mean over the time dimension
        
        # Normalize the output embeddings
        return F.normalize(x, p=2, dim=1)

In [7]:
# Use the HubertForSequenceClassification model from https://huggingface.co/docs/transformers/v4.44.1/en/model_doc/hubert#transformers.HubertForSequenceClassification
# to create embedding for Siamese network

class HubertEmbedding(nn.Module):
    def __init__(self, input_len = 313, embedding_dim=128):
        super(HubertEmbedding, self).__init__()
        
        # Load pre-trained Hubert model
        self.model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
        
        # Add a fully connected layer to project the hidden states to the desired embedding dimension
        self.model.classifier = nn.Linear(self.model.classifier.in_features, embedding_dim)

    def forward(self, x):
        # Extract the last hidden state from the DistilHuBERT model
        x = self.model(x.input_values.squeeze(1)).logits#.last_hidden_state
        
        # Apply the fully connected layer to reduce the dimension
        #x = self.fc(x.mean(dim=1))  # Take mean over the time dimension
        
        # Normalize the output embeddings
        return F.normalize(x, p=2, dim=1)

In [8]:
# Define pretrained resnet from Torch Vision resnet 18
class ResNetEmbedding(nn.Module):
    def __init__(self, embedding_dim=128, dropout_rate=0.8):
        # get resnet super class
        super(ResNetEmbedding, self).__init__()
        self.resnet = models.resnet18(weights='DEFAULT')
        # Change structure of first layer to take non RGB images, rest of params same as default
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.dropout = nn.Dropout(p=dropout_rate)
        # Set the last fully connected to a set dimension "embedding_dim" instead of default 1000
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embedding_dim)

    def forward(self, x):
        x = self.resnet(x)
        return F.normalize(x, p=2, dim=1)

In [12]:
# Make sure the dataset has 'processed_audio', 'augmented_audio', and 'diff_processed_audio' columns
def make_train_loaders(model, dir_path = '/kaggle/input/augmented-music/combined_batch_augmented.pkl', train_batch_size = 16, val_batch_size = 16):

    '''
        Creates DataLoaders to be used in training and validation

        Arguments: 
            model -- model to be trained
            dir_path -- path to dataset
            spec_or_raw -- string, must be 'spec' or 'raw' depending on wether the data should be prepared as a spectrogram or as raw audio data
            train(val)_batch_size -- size of training (validation) batch sizes
        Returns:
            train(val)_loader -- DataLoader of training (validation) samples
    '''
    assert isinstance(model, (ResNetEmbedding, W2VBertModelEmbedding,W2VBSeqEmbedding, HubertEmbedding)), 'model must be of class type ResNetEmbedding, W2VBertModelEmbedding,W2VBSeqEmbedding, or HubertEmbedding'

    audio_data = pd.read_pickle(dir_path).iloc[:1000]

    assert 'processed_audio' and 'augmented_audio' and 'diff_processed_audio' in audio_data.columns, 'DataFrame must contain columns \'processed_audio\', \'augmented_audio\', and \'diff_processed_audio\''

    # Split the data into training and validation sets
    train_data, val_data = train_test_split(audio_data, test_size=0.2, random_state=123)

    try:
        if isinstance(model, (ResNetEmbedding, W2VBertModelEmbedding,W2VBSeqEmbedding)):
            train_dataset = SpectrogramDataset(train_data)
            val_dataset = SpectrogramDataset(val_data)
        elif isinstance(model, HubertEmbedding):
            train_dataset = RawAudioDataset(train_data)
            val_dataset = RawAudioDataset(val_data)
    except NameError:
        None


    # Define dataloaders
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False)

    return train_loader, val_loader

In [10]:
#Freeze layers for finetuning
def setup_model(embedding):
    '''
    Freeze all but the last few layers of the model.
    Argument:
        embedding -- model, must be one of the embedding types defined above
    Returns:
        None
    '''
    try:
        if isinstance(embedding, ResNetEmbedding):
            # Freeze all the layers
            for param in embedding.resnet.parameters():
                param.requires_grad = False

            # Turn back on last residual block
            for param in embedding.resnet.layer4.parameters():
                param.requires_grad = True

            # Turn back on fully connected layer
            for param in embedding.resnet.fc.parameters():
                param.requires_grad = True
    except NameError:
        None

    try:
        if isinstance(embedding, W2VBertModelEmbedding):
            # Freeze all the layers
            for param in embedding.model.parameters():
                param.requires_grad = False

            # Turn back on adapter
            for param in embedding.model.adapter.parameters():
                param.requires_grad = True
    except NameError:
        None

    try:
        if isinstance(embedding, W2VBSeqEmbedding):

            # Freeze all the layers
            for param in embedding.model.parameters():
                param.requires_grad = False

            # Turn back on last feed forward layer
            for param in embedding.model.wav2vec2_bert.encoder.layers[23].ffn2.parameters():
                param.requires_grad = True

            # Turn back on final layer norm
            for param in embedding.model.wav2vec2_bert.encoder.layers[23].final_layer_norm.parameters():
                param.requires_grad = True
                
            # Turn back on projector
            for param in embedding.model.projector.parameters():
                param.requires_grad = True
                
            # Turn back on classifier
            for param in embedding.model.classifier.parameters():
                param.requires_grad = True
    except NameError:
         None

    try:
        if isinstance(embedding, HubertEmbedding):
            # Freeze all the layers
            for param in embedding.model.parameters():
                param.requires_grad = False

            # Turn back on last feed forward layer
            for param in embedding.model.hubert.encoder.layers[11].feed_forward.parameters():
                param.requires_grad = True

            # Turn back on final layer norm
            for param in embedding.model.hubert.encoder.layers[11].final_layer_norm.parameters():
                param.requires_grad = True
                
            # Turn back on projector
            for param in embedding.model.projector.parameters():
                param.requires_grad = True
                
            # Turn back on classifier
            for param in embedding.model.classifier.parameters():
                param.requires_grad = True
    except NameError:
        None

In [11]:
# Choose a model from one of the classes defined above

#model = W2VBertModelEmbedding()
#model = W2VBSeqEmbedding()
#model = HubertEmbedding()
model = ResNetEmbedding()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 134MB/s] 


In [18]:
train_loader, val_loader = make_train_loaders(model, 
                                              '/kaggle/input/augmented-music/batch_1_augmented_16000Hz.pkl', 
                                              train_batch_size = 16)

In [14]:
print('This model has', str(sum(p.numel() for p in model.parameters() if p.requires_grad)), 'trainable parameters.')

This model has 11235904 trainable parameters.


In [15]:
setup_model(model)

In [16]:
print('This model now has', str(sum(p.numel() for p in model.parameters() if p.requires_grad)), 'trainable parameters.')

This model now has 8459392 trainable parameters.


In [19]:
criterion = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)
train_losses = []
val_losses = []
baseline_losses = []

num_epochs = 1

In [20]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

ResNetEmbedding(
  (resnet): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tr

In [25]:

scaler = torch.cuda.amp.GradScaler()
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    running_train_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Training {epoch+1}/{num_epochs}", unit="sample", position=0, leave=True)
    #train_loader = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
    # Loop over batches using dataloaders
    for anchors, positives, negatives in train_loader:
        anchors, positives, negatives = anchors.to(device), positives.to(device), negatives.to(device)

        optimizer.zero_grad()
            
        with torch.autocast(device.type):
            try:
                if isinstance(model, (W2VBertModelEmbedding,W2VBSeqEmbedding)):
                    anchor_embeddings = model(anchors.squeeze())
                    positive_embeddings = model(positives.squeeze())
                    negative_embeddings = model(negatives.squeeze())
                elif isinstance(model, HubertEmbedding):
                    anchor_embeddings = model(anchors)
                    positive_embeddings = model(positives)
                    negative_embeddings = model(negatives)
                elif isinstance(model, ResNetEmbedding):
                    anchor_embeddings = model(anchors)
                    positive_embeddings = model(positives)
                    negative_embeddings = model(negatives)
            except NameError:
                None

            loss = criterion(anchor_embeddings, positive_embeddings, negative_embeddings)
    
        del anchor_embeddings, positive_embeddings, negative_embeddings
    
        # Scales the loss, and calls backward() to create scaled gradients
        scaler.scale(loss).backward()
            
        # Unscales gradients and calls or skips optimizer.step()
        scaler.step(optimizer)
        
        # Updates the scale for next iteration
        scaler.update()
        #loss.backward()
        #optimizer.step()
        
        try:
            if isinstance(model, (W2VBertModelEmbedding,W2VBSeqEmbedding, ResNetEmbedding)):
                running_train_loss += loss.item() * anchors.size(0)
            elif isinstance(model, HubertEmbedding):
                running_train_loss += loss.item() * anchors.input_values.size(0)
        except NameError:
            None
        # Update the progress bar by the current batch size
        pbar.set_postfix({'loss': loss.item()})
        pbar.update(1)  # Increment the progress bar
        #del loss
          
    train_loss = running_train_loss / len(train_loader.dataset)
    train_losses.append(train_loss)

    # Turn on validation/eval mode
    model.eval()
    running_val_loss = 0.0 
    running_baseline_loss = 0.0   
    val_pbar = tqdm(val_loader, desc=f"Validation {epoch+1}/{num_epochs}", unit="batch", position=0, leave=True)
    # Turn off gradient updates since we're in validation
    with torch.no_grad():
        # Batch loop 
        for anchors, positives, negatives in val_loader:
            anchors, positives, negatives = anchors.to(device), positives.to(device), negatives.to(device)
            
            try:
                if isinstance(model, (W2VBertModelEmbedding,W2VBSeqEmbedding)):
                    anchor_embeddings = model(anchors.squeeze())
                    positive_embeddings = model(positives.squeeze())
                    negative_embeddings = model(negatives.squeeze())
                elif isinstance(model, HubertEmbedding):
                    anchor_embeddings = model(anchors)
                    positive_embeddings = model(positives)
                    negative_embeddings = model(negatives)
                elif isinstance(model, ResNetEmbedding):
                    anchor_embeddings = model(anchors)
                    positive_embeddings = model(positives)
                    negative_embeddings = model(negatives)
            except NameError:
                None
            
            loss = criterion(anchor_embeddings, positive_embeddings, negative_embeddings)

            # Add to running val loss
            running_val_loss += loss.item() * anchors.size(0)
            
            # baseline loss
            baseline_loss = criterion(F.normalize(anchors), 
                                               F.normalize(positives), 
                                               F.normalize(negatives)).item()
            running_baseline_loss += baseline_loss*anchors.size(0)
            
            # Update the validation progress bar
            val_pbar.update(1)  # Increment the progress bar
            #val_pbar.update(anchors.size(0))
            #val_pbar.set_postfix(loss=loss.item())
    
    # Calculate average validation loss over the entire dataset
    val_loss = running_val_loss / len(val_loader.dataset)
    val_losses.append(val_loss)
    # Do the same for the baseline
    baseline_avg_loss = running_baseline_loss / len(val_loader.dataset)
    baseline_losses.append(baseline_avg_loss)

    
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Baseline Loss: {baseline_avg_loss:.4f}")
    
    with open('training_logs.pkl', 'wb') as f:
        pkl.dump((train_losses, val_losses, baseline_losses), f)

Training 1/1: 100%|██████████| 50/50 [04:05<00:00,  4.91s/sample, loss=0.625]
Validation 1/1: 100%|██████████| 13/13 [04:07<00:00, 19.01s/batch]loss=0.378]
Validation 1/1: 100%|██████████| 13/13 [00:16<00:00,  1.05s/batch]

Epoch [1/1], Train Loss: 0.5380, Val Loss: 0.7557, Baseline Loss: 0.9980


In [15]:
model.eval()

with torch.no_grad():
    out1 = model(train_dataset[0][0].to(device))
    out2 = model(train_dataset[0][1].to(device))
    out3 = model(train_dataset[0][2].to(device))

In [16]:
torch.nn.functional.cosine_similarity(out1,out2)

tensor([0.9950], device='cuda:0')

In [17]:
torch.nn.functional.cosine_similarity(out1,out3)

tensor([0.9946], device='cuda:0')