# Fine-Tuning DistilHuBERT
- Fine-tune the transformer architecture.
- Perform binary classification on similar and dissimilar pairs.

In [None]:
import pandas as pd
import numpy as np
import librosa
import librosa.display
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from dotenv import dotenv_values 
import spotipy
from spotipy.oauth2 import SpotifyClientCredentials
import pickle as pkl

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset

from scipy.spatial.distance import euclidean
from sklearn.model_selection import train_test_split

from transformers import Wav2Vec2Processor, DistilHuBERTForSequenceClassification, Wav2Vec2Model

In [None]:
# Triplet dataset backed by several pickled DataFrames; one file is held in memory at a time.
class SpectrogramDataset(Dataset):
    """Serve (anchor, positive, negative) mel-spectrogram triplets from pickled batches.

    Every path in ``file_paths`` is read with ``pd.read_pickle`` and must provide
    the columns ``processed_audio``, ``augmented_audio`` and
    ``diff_processed_audio``. Element ``[0]`` of each cell is used as the waveform.

    Indices are global across all files, in the order the files are listed.
    Random access is supported: the file that owns an index is (re)loaded on
    demand, so a shuffling sampler is correct but may reload files frequently.

    NOTE: ``pd.read_pickle`` unpickles arbitrary objects; only point this class
    at trusted files.

    Args:
        file_paths: Sequence of pickle file paths.
        sr: Sampling rate handed to the mel-spectrogram computation.
        n_mels: Number of mel bands.
    """

    def __init__(self, file_paths, sr=22050, n_mels=128):
        self.file_paths = file_paths
        self.sr = sr
        self.n_mels = n_mels

        # Row counts are read once here so that __len__ is cheap and __getitem__
        # can translate a global index into (file, local row).
        self._file_lengths = [pd.read_pickle(path).shape[0] for path in self.file_paths]
        # _file_ends[k] is the exclusive global end index of file k.
        self._file_ends = np.cumsum(self._file_lengths)

        self._loaded_file_index = None
        self.current_file_index = 0
        self.load_current_file()

    def load_current_file(self):
        """Load the file at ``current_file_index`` and advance that index by one."""
        file_index = self.current_file_index
        self.current_data = pd.read_pickle(self.file_paths[file_index])
        self.current_anchors = self.current_data['processed_audio'].values
        self.current_positives = self.current_data['augmented_audio'].values
        self.current_negatives = self.current_data['diff_processed_audio'].values
        self.current_file_length = len(self.current_anchors)
        # Remember which file is resident so __getitem__ can skip needless reloads.
        self._loaded_file_index = file_index
        self.current_file_index = file_index + 1

    def __len__(self):
        """Return the total number of rows across all files."""
        return int(sum(self._file_lengths))

    def __getitem__(self, idx):
        """Return the triplet at global index ``idx``.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError("Index out of range")

        # Number of files whose end is <= idx == index of the file owning idx.
        file_index = int(np.searchsorted(self._file_ends, idx, side='right'))
        if file_index != self._loaded_file_index:
            self.current_file_index = file_index
            self.load_current_file()

        file_start = int(self._file_ends[file_index - 1]) if file_index else 0
        local_idx = idx - file_start

        # Each cell is indexed with [0] to take the audio out of the stored (y, sr) tuple.
        anchor = self.current_anchors[local_idx][0]
        positive = self.current_positives[local_idx][0]
        negative = self.current_negatives[local_idx][0]

        return (
            self._process_audio(anchor),
            self._process_audio(positive),
            self._process_audio(negative),
        )

    def _process_audio(self, y):
        """Convert a waveform to a dB-scaled mel spectrogram tensor.

        The spectrogram is scaled relative to its own maximum (``ref=np.max``)
        and returned as a float32 tensor with one extra leading dimension.
        """
        mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=self.sr, n_mels=self.n_mels)
        mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)
        return torch.tensor(mel_spectrogram_db, dtype=torch.float32).unsqueeze(0)

In [None]:
class SimilarityModel(nn.Module):
    """Pretrained speech encoder with a linear head producing 128-d embeddings.

    ``forward`` embeds an (anchor, positive, negative) triplet with shared
    weights, which is the input layout expected by a triplet loss.

    Args:
        pretrained_model_name: Hugging Face checkpoint to load. Defaults to the
            published DistilHuBERT checkpoint, ``ntu-spml/distilhubert``.
    """

    def __init__(self, pretrained_model_name='ntu-spml/distilhubert'):
        super().__init__()
        # AutoModel selects the architecture from the checkpoint's own config.
        # DistilHuBERT is a HuBERT checkpoint, so forcing Wav2Vec2Model onto it
        # loads weights into the wrong architecture. Imported locally so this
        # class does not depend on which names the notebook imported.
        from transformers import AutoModel

        self.model = AutoModel.from_pretrained(pretrained_model_name)
        self.fc = nn.Linear(self.model.config.hidden_size, 128)

    def _embed(self, audio):
        """Encode one batch, mean-pool over dim 1, and project to 128 dims."""
        # NOTE(review): inputs go straight into the encoder, which presumably
        # expects raw waveform batches - confirm against what the dataset yields.
        hidden = self.model(audio).last_hidden_state
        return self.fc(hidden.mean(dim=1))

    def forward(self, anchor, positive, negative):
        """Return the embeddings of ``anchor``, ``positive`` and ``negative``, in that order."""
        return self._embed(anchor), self._embed(positive), self._embed(negative)

# Example initialization: builds the encoder with its default checkpoint, which
# loads pretrained weights via from_pretrained. `model` is rebound by later cells.
model = SimilarityModel()

In [None]:
# Pickled batch files 1 and 2
file_paths = [
    f'../data/augmented_audio/batch_{batch_no}_augmented.pkl'
    for batch_no in range(1, 3)
]

# Split at file level (instead of row level) into training and validation sets
train_files, val_files = train_test_split(
    file_paths,
    test_size=0.2,
    random_state=123,
)

# One dataset per split
train_dataset = SpectrogramDataset(train_files)
val_dataset = SpectrogramDataset(val_files)

# Batches of 32; only the training loader shuffles
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

In [None]:
# Choose model, loss, and optimizer.
# SimilarityModel (defined above) returns the (anchor, positive, negative)
# embeddings that TripletMarginLoss consumes. The previous ResNetEmbedding
# reference is not defined anywhere in this notebook and raised NameError.
model = SimilarityModel()
criterion = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Per-epoch loss history
train_losses = []
val_losses = []

num_epochs = 1
# Prefer the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

In [None]:
# Merge datasets to get the Spotify links for the most similar and most different songs

In [None]:
from transformers import Wav2Vec2Processor, DistilHuBERTForSequenceClassification
from datasets import load_dataset
import torch

processor = Wav2Vec2Processor.from_pretrained("facebook/distilhubert-base")
model = DistilHuBERTForSequenceClassification.from_pretrained("facebook/distilhubert-base", num_labels=2)

In [None]:
def preprocess_function(examples):
    audio = examples["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt", padding=True)
    return inputs

dataset = load_dataset("path/to/your/dataset")
dataset = dataset.map(preprocess_function)


In [None]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=processor.feature_extractor,
)

trainer.train()


In [None]:
def measure_similarity(file1, file2, model, processor, sampling_rate=16000):
    """Return the Euclidean distance between the model logits of two audio files.

    A smaller value means the model's outputs for the two files are closer.

    Args:
        file1: Path to the first audio file.
        file2: Path to the second audio file.
        model: Model whose output exposes ``.logits``.
        processor: Callable that turns a waveform into model inputs.
        sampling_rate: Rate both files are resampled to when loaded.

    Returns:
        The pairwise distance as a Python float.
    """
    # NOTE(review): the default of 16000 assumes the usual HuBERT-family rate -
    # confirm it matches the processor in use.
    # Inputs must live on the same device as the model's weights.
    device = next(model.parameters()).device

    def _logits(path):
        """Load one file, preprocess it, and run it through the model."""
        # librosa.load replaces load_audio, which this notebook never defines.
        y, sr = librosa.load(path, sr=sampling_rate)
        inputs = processor(y, sampling_rate=sr, return_tensors="pt", padding=True)
        inputs = {key: value.to(device) for key, value in inputs.items()}
        return model(**inputs).logits

    # Switch to eval mode so dropout does not perturb the result, then restore
    # whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs1 = _logits(file1)
            outputs2 = _logits(file2)
    finally:
        model.train(was_training)

    euclidean_distance = torch.nn.functional.pairwise_distance(outputs1, outputs2)
    return euclidean_distance.item()

# Demo run on two local files; the printed value is a Euclidean distance,
# so lower means more alike.
similarity_score = measure_similarity(
    file1='audio1.wav',
    file2='audio2.wav',
    model=model,
    processor=processor,
)
print(f"Similarity score: {similarity_score}")
