# Fine-Tuning a Transformer for Audio Similarity Analysis
- Load and fine-tune the [Audio Spectrogram Transformer](https://arxiv.org/abs/2104.01778) from [HuggingFace](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593) for song-similarity analysis.
- Create a dataset class that feeds spectrograms into the Transformer architecture.
- Train with a triplet loss to
    1. minimize the Euclidean distance between the anchor and positive embeddings, and
    2. maximize the Euclidean distance between the anchor and negative embeddings.
- Output model weights for deployment on songs with real covers.

#### Import Dependencies

In [1]:
import pandas as pd
import numpy as np
import librosa
import librosa.display
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from dotenv import dotenv_values 
import pickle as pkl

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

from scipy.spatial.distance import euclidean
from sklearn.model_selection import train_test_split
from transformers import Wav2Vec2Model
from transformers import AutoProcessor, AutoModel, ASTModel, AutoFeatureExtractor, AutoModelForAudioClassification

#### Create Dataset Class

In [11]:
class SpectrogramDataset(Dataset):
    """Triplet dataset yielding (anchor, positive, negative) AST input features.

    Each path in ``file_paths`` points to a pickle holding a table of rows that
    is indexed with ``.iloc`` (presumably a pandas DataFrame). The table has the
    columns ``processed_audio`` (anchor), ``augmented_audio`` (positive) and
    ``diff_processed_audio`` (negative). Element 0 of each cell is used as the
    raw waveform.

    Args:
        file_paths: Pickle files to index; every row of every file is one sample.
        transform: Optional callable applied to each of the three feature
            tensors. Any falsy value (the default ``False``) disables it.
        sr: Sampling rate of the stored waveforms.
        target_sr: Sampling rate the waveforms are resampled to before
            feature extraction.
        n_mels: Number of mel bands; only used by ``_get_log_mel_spectrogram``.
    """

    def __init__(self, file_paths, transform=False, sr=22050, target_sr=16000, n_mels=128):
        self.file_paths = file_paths
        self.data_index = self._build_index()
        self.sr = sr
        self.target_sr = target_sr
        self.n_mels = n_mels
        self.transform = transform
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

    def _build_index(self):
        """Return a flat list of ``(file_idx, row_idx)`` pairs over all files."""
        index = []
        for file_idx, file_path in enumerate(self.file_paths):
            # NOTE(review): pickle can execute arbitrary code -- only load trusted files.
            with open(file_path, 'rb') as f:
                data = pkl.load(f)
            index.extend((file_idx, i) for i in range(len(data)))
        return index

    def _get_log_mel_spectrogram(self, y):
        """Return the dB-scaled mel spectrogram of ``y`` as a float32 tensor.

        A leading axis of size 1 is added with ``unsqueeze(0)``. Not used by
        ``__getitem__``, which relies on the HuggingFace feature extractor instead.
        """
        mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=self.sr, n_mels=self.n_mels)
        # Convert power to decibels relative to the spectrogram's maximum.
        log_mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)
        log_mel_spectrogram = torch.tensor(log_mel_spectrogram, dtype=torch.float32).unsqueeze(0)
        return log_mel_spectrogram

    def _featurize(self, waveform):
        """Resample ``waveform`` to ``target_sr`` and run the feature extractor.

        Returns the extractor's ``input_values`` with the leading batch axis
        squeezed away. The result is passed through ``self.transform`` when one
        is set.
        """
        resampled = librosa.resample(waveform, orig_sr=self.sr, target_sr=self.target_sr)
        features = self.feature_extractor(resampled, sampling_rate=self.target_sr, return_tensors='pt')
        values = features['input_values'].squeeze(0)
        if self.transform:
            values = self.transform(values)
        return values

    def __len__(self):
        return len(self.data_index)

    def __getitem__(self, idx):
        """Return the ``(anchor, positive, negative)`` feature tensors for sample ``idx``."""
        file_idx, data_idx = self.data_index[idx]

        # NOTE(review): the whole pickle is re-read for every sample, which is
        # costly. Consider caching if the files fit in memory. pickle can also
        # execute arbitrary code -- only load trusted files.
        with open(self.file_paths[file_idx], 'rb') as f:
            data = pkl.load(f)
        row = data.iloc[data_idx]

        # Each cell presumably holds a (waveform, sr) pair; only the waveform is used.
        anchor_mel = self._featurize(row['processed_audio'][0])
        positive_mel = self._featurize(row['augmented_audio'][0])
        negative_mel = self._featurize(row['diff_processed_audio'][0])
        return anchor_mel, positive_mel, negative_mel


### Split Data and Instantiate Dataset Class and DataLoaders

In [12]:
# Pickled batches batch_1 ... batch_9 of pre-augmented audio triplets.
file_paths = [
    f'/kaggle/input/augmented-audio-10k/batch_{i}_augmented.pkl'
    for i in range(1, 10)
]

# Split whole files (instead of individual rows) into train/validation, so rows
# from a single file never land in both splits.
train_files, val_files = train_test_split(file_paths, test_size=0.2, random_state=123)

# One dataset per split (train built first, then validation).
train_dataset, val_dataset = SpectrogramDataset(train_files), SpectrogramDataset(val_files)

# Both loaders share batching/worker settings; only training is shuffled.
loader_options = dict(batch_size=4, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)  # , drop_last=True

In [13]:
#test = train_loader.dataset[0][0]

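### Declaring the Model
- Define the architecture: a pretrained Audio Spectrogram Transformer (AST) backbone whose [CLS] token is projected by a linear layer to a 128-dimensional, L2-normalized embedding
- Choose the loss function, optimizer, device, etc.

**Note:**
- `Wav2Vec2Model` (an alternative backbone left commented out in the model class) expects input of shape [batch size, sequence length] or [batch size, channels, sequence length]; `ASTModel` instead takes the spectrogram `input_values` produced by the feature extractor.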

In [23]:
class AudioSpecTransformerModel(torch.nn.Module):
    """Audio Spectrogram Transformer backbone with a linear embedding head.

    The backbone's last hidden state at position 0 (the [CLS] token) is
    projected to ``embedding_dim`` and L2-normalized, so the embeddings have
    unit length.

    Args:
        pretrained_model_name: HuggingFace checkpoint the AST backbone is loaded from.
        embedding_dim: Size of the output embedding.
        dropout_rate: Probability for ``self.dropout`` (see note in ``__init__``).
    """

    def __init__(self, pretrained_model_name="MIT/ast-finetuned-audioset-10-10-0.4593", embedding_dim=128, dropout_rate=0.5):
        super().__init__()

        # Honour the argument instead of a hard-coded checkpoint name; the
        # default is unchanged, so existing callers load the same weights.
        self.model = ASTModel.from_pretrained(pretrained_model_name)
        # NOTE(review): this layer is never applied in ``forward``, so
        # ``dropout_rate`` currently has no effect. Wiring it in would change
        # training behaviour, so it is left unused here.
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fc = torch.nn.Linear(self.model.config.hidden_size, embedding_dim)

    def forward(self, x):
        """Embed a batch of AST input features.

        Args:
            x: Feature-extractor ``input_values``, presumably shaped
                (batch, frames, mel_bins) -- TODO confirm.

        Returns:
            Tensor of shape (batch, embedding_dim) with unit L2 norm per row.
        """
        hidden = self.model(x).last_hidden_state
        cls_embedding = hidden[:, 0, :]  # position 0 is treated as the [CLS] token
        return F.normalize(self.fc(cls_embedding), p=2, dim=1)

In [None]:
# Model, triplet objective (margin 1.0, Euclidean distance) and optimizer.
model = AudioSpecTransformerModel()
criterion = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-4)

# Per-epoch average losses, appended to by the training loop.
train_losses, val_losses, baseline_losses = [], [], []

num_epochs = 15
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

### Training the Model

In [17]:
def save_checkpoint(model, optimizer, epoch, file_path="distilhubert_checkpoint.pth"):
    """Write model/optimizer state and the epoch index to ``file_path``.

    The saved dict has the keys ``model_state_dict``, ``optimizer_state_dict``
    and ``epoch``, which is the layout ``load_checkpoint`` reads back.
    """
    torch.save(
        dict(
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            epoch=epoch,
        ),
        file_path,
    )
    print(f"Checkpoint saved at epoch {epoch} to {file_path}")

def load_checkpoint(model, optimizer, file_path="/kaggle/working/distilhubert_checkpoint.pth"):
    """Restore model and optimizer state from a checkpoint, if one exists.

    Args:
        model: Module whose state dict is overwritten in place.
        optimizer: Optimizer whose state dict is overwritten in place.
        file_path: Checkpoint containing ``model_state_dict``,
            ``optimizer_state_dict`` and ``epoch`` entries.

    Returns:
        The epoch to resume from (saved epoch + 1), or 0 if ``file_path`` is
        not an existing file.
    """
    if not os.path.isfile(file_path):
        print("No checkpoint found. Starting from scratch.")
        return 0

    # Map to CPU so a checkpoint written on a GPU machine also loads on a
    # CPU-only one. load_state_dict then copies the values onto whatever
    # devices the model's and optimizer's parameters already live on.
    checkpoint = torch.load(file_path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
    return start_epoch

In [None]:
checkpoint_path = "/kaggle/working/distilhubert_checkpoint.pth"
save_frequency = 5  # Save every 5 epochs

# Resume from a saved checkpoint if one exists.
# NOTE(review): the loss histories are not stored in the checkpoint, so after a
# resume they only cover epochs from start_epoch onwards.
start_epoch = load_checkpoint(model, optimizer, checkpoint_path)

for epoch in range(start_epoch, num_epochs):
    # ---- Training ----
    model.train()
    running_train_loss = 0.0
    # Iterate the progress bar itself so it advances and closes automatically.
    pbar = tqdm(train_loader, desc=f"Training {epoch+1}/{num_epochs}", unit="batch")

    for anchors, positives, negatives in pbar:
        # Batches are stacked feature-extractor outputs, presumably
        # (batch, frames, mel_bins) -- TODO confirm.
        anchors, positives, negatives = anchors.to(device), positives.to(device), negatives.to(device)
        optimizer.zero_grad()

        anchor_embeddings = model(anchors)
        positive_embeddings = model(positives)
        negative_embeddings = model(negatives)

        loss = criterion(anchor_embeddings, positive_embeddings, negative_embeddings)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is per sample even when
        # the last batch is smaller.
        running_train_loss += loss.item() * anchors.size(0)

    train_loss = running_train_loss / len(train_loader.dataset)
    train_losses.append(train_loss)

    # ---- Validation ----
    model.eval()
    running_val_loss = 0.0
    running_baseline_loss = 0.0
    # A single, labelled progress bar (previously a second anonymous bar was
    # also created, so validation showed two bars).
    val_pbar = tqdm(val_loader, desc=f"Validation {epoch+1}/{num_epochs}", unit="batch")

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for anchors, positives, negatives in val_pbar:
            anchors, positives, negatives = anchors.to(device), positives.to(device), negatives.to(device)

            anchor_embeddings = model(anchors)
            positive_embeddings = model(positives)
            negative_embeddings = model(negatives)

            loss = criterion(anchor_embeddings, positive_embeddings, negative_embeddings)
            running_val_loss += loss.item() * anchors.size(0)

            # Baseline: the same triplet loss applied directly to the input
            # features (L2-normalized along dim 1), i.e. without the model.
            baseline_loss = criterion(F.normalize(anchors),
                                      F.normalize(positives),
                                      F.normalize(negatives)).item()
            running_baseline_loss += baseline_loss * anchors.size(0)

    # Per-sample averages over the whole validation set.
    val_loss = running_val_loss / len(val_loader.dataset)
    val_losses.append(val_loss)
    baseline_avg_loss = running_baseline_loss / len(val_loader.dataset)
    baseline_losses.append(baseline_avg_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Baseline Loss: {baseline_avg_loss:.4f}")

    # Checkpoint periodically, and always after the final epoch so the last
    # weights are not lost when num_epochs is not a multiple of save_frequency.
    if (epoch + 1) % save_frequency == 0 or epoch + 1 == num_epochs:
        save_checkpoint(model, optimizer, epoch, checkpoint_path)

    with open('training_logs.pkl', 'wb') as f:
        pkl.dump((train_losses, val_losses), f)

In [None]:
# Plot loss curves for training.
# The loss lists only hold epochs run in this session (they are not restored
# from the checkpoint) and may differ in length if a run was interrupted.
# Each curve therefore gets its own x-range starting at start_epoch + 1, rather
# than assuming a full 1..num_epochs run; a mismatch makes matplotlib raise.
plt.figure(figsize=(10, 5))
for losses, label in (
    (train_losses, 'Training Loss'),
    (val_losses, 'Validation Loss'),
    (baseline_losses, 'Baseline Loss'),
):
    epochs = range(start_epoch + 1, start_epoch + len(losses) + 1)
    plt.plot(epochs, losses, label=label)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.tight_layout()
plt.savefig('distilhubert-loss-plot.png')
plt.show()


In [None]:
# Save only the learned parameters (the state_dict). This is the more portable
# option: reloading needs the model class plus this file, not a pickled module.
# NOTE(review): the file name says "distilhubert" but the model is an AST.
torch.save(obj=model.state_dict(), f='distilhubert_weights.pth')

In [None]:
# Pickle the entire module object for deployment.
# NOTE(review): loading this file requires AudioSpecTransformerModel to be
# importable under the same name; the state_dict file is more portable.
torch.save(obj=model, f='distilhubert_model.pth')