# Fine Tuning Resnet
- Prepare data for training/validation. Create dataloader
- Load in resnet model
- Create architecture for fine tuning including pytorch/tensorflow boilerplate

In [None]:
#!pip install -q --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu118

In [1]:
!pip install -q spotipy

In [1]:
import pandas as pd
import numpy as np
import librosa
import librosa.display
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from dotenv import dotenv_values 
#import spotipy
#from spotipy.oauth2 import SpotifyClientCredentials
import pickle as pkl

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

from scipy.spatial.distance import euclidean
from sklearn.model_selection import train_test_split
from transformers import Wav2Vec2Model

## Data Preparation
- Import data
- Calculate mel specs
- Create data set class for spectrograms/chromagrams/tempograms

In [None]:
# Read the first augmented-audio batch and preview its first two rows
explore_df = pd.read_pickle(
    filepath_or_buffer='/Users/reggiebain/erdos/song-similarity-erdos-old/data/augmented_audio/batch_1_augmented.pkl'
)
explore_df.head(n=2)

In [None]:
def make_mel_spectrogram(y, sr, n_mels=128, fmax=8000):
    """Return the mel spectrogram of ``y`` converted to decibels.

    The power spectrogram is computed with
    ``librosa.feature.melspectrogram`` and converted with
    ``librosa.power_to_db`` using ``ref=np.max``.

    Args:
        y: Audio samples, forwarded to librosa as ``y``.
        sr: Sampling rate, forwarded to librosa as ``sr``.
        n_mels: Number of mel bands requested from librosa.
        fmax: Upper frequency bound forwarded to librosa.

    Returns:
        The array returned by ``librosa.power_to_db``.
    """
    mel_power = librosa.feature.melspectrogram(
        y=y,
        sr=sr,
        n_mels=n_mels,
        fmax=fmax,
    )
    return librosa.power_to_db(mel_power, ref=np.max)

In [None]:
# `working_df` is not defined anywhere earlier in this notebook. The three
# audio columns read below are the ones indexed on `explore_df` later in the
# notebook, so start from a copy of that frame.
# NOTE(review): assumes `explore_df` is the intended source — confirm.
working_df = explore_df.copy()

# Each audio cell is indexed as (y, sr): element 0 is the waveform and
# element 1 the sampling rate handed to make_mel_spectrogram.
# (progress-bar description, source column, new spectrogram column)
_spectrogram_jobs = (
    ("Making Anchor Mel Spectrograms...", 'processed_audio', 'anchors'),
    ("Making Similar Mel Spectrograms...", 'augmented_audio', 'positives'),
    ("Making Differet Mel Spectrograms...", 'diff_processed_audio', 'negatives'),
)
for _desc, _source_col, _target_col in _spectrogram_jobs:
    # tqdm.pandas must be re-registered so each pass gets its own description
    tqdm.pandas(desc=_desc)
    working_df[_target_col] = working_df[_source_col].progress_apply(
        lambda pair: make_mel_spectrogram(pair[0], pair[1])
    )

# Create new dataframe holding only the spectrogram triplets
dataset_df = working_df[['anchors', 'positives', 'negatives']]
dataset_df.head(2)

In [None]:
# Visual sanity check: draw the first anchor spectrogram
test_mel_spec = dataset_df['anchors'].iloc[0]

fig, ax = plt.subplots(figsize=(10, 4))
mesh = librosa.display.specshow(
    test_mel_spec,
    sr=22050,
    x_axis='time',
    y_axis='mel',
    fmax=8000,
    ax=ax,
)
fig.colorbar(mesh, ax=ax, format='%+2.0f dB')
ax.set_title('Mel Spectrogram')
fig.tight_layout()
plt.show()

#### Create Dataset Class

In [41]:
# Dataset class that does include batching
class SpectrogramDataset(Dataset):
    """Triplet dataset backed by several pickled frames, one held in memory at a time.

    Each pickle is read with ``pd.read_pickle`` and must provide the columns
    ``processed_audio`` (anchor), ``augmented_audio`` (positive) and
    ``diff_processed_audio`` (negative). Every cell is indexed with ``[0]``
    to obtain the waveform (the notebook stores ``(y, sr)`` pairs there).

    A global index is mapped to ``(file, row)`` through cumulative row
    counts, so items can be requested in any order. Only one file is kept
    in memory; asking for a row of a different file reloads that file, so
    shuffled access re-reads files often.

    Args:
        file_paths: Sequence of pickle paths.
        sr: Sampling rate passed to librosa when building spectrograms.
        n_mels: Number of mel bands passed to librosa.
    """

    def __init__(self, file_paths, sr=22050, n_mels=128):
        self.file_paths = file_paths
        self.current_file_index = 0
        self.current_file_length = 0
        self.current_data = None
        self.sr = sr
        self.n_mels = n_mels
        # Row counts are read once up front instead of unpickling every
        # file on each len() call.
        self._file_lengths = [pd.read_pickle(path).shape[0] for path in self.file_paths]
        # _file_ends[i] is the global index one past the last row of file i.
        self._file_ends = np.cumsum(self._file_lengths)
        # Position (in file_paths) of the file currently held in memory.
        self._loaded_file = None
        self.load_current_file()

    def load_current_file(self):
        """Load the file at ``current_file_index`` and advance that index.

        When the index is past the last file, the in-memory data is cleared
        instead.
        """
        # Load data from the current .pkl file if we have any files left
        if self.current_file_index < len(self.file_paths):
            print(f"Loading file {self.file_paths[self.current_file_index]}")
            self.current_data = pd.read_pickle(self.file_paths[self.current_file_index])
            self.current_anchors = self.current_data['processed_audio'].values
            self.current_positives = self.current_data['augmented_audio'].values
            self.current_negatives = self.current_data['diff_processed_audio'].values
            self.current_file_length = len(self.current_anchors)
            print(f"Loaded {self.current_file_length} samples from file {self.file_paths[self.current_file_index]}")
            self._loaded_file = self.current_file_index
            self.current_file_index += 1
        # If no remaining files, set to zero so we don't do anything
        else:
            self.current_data = None
            self.current_file_length = 0
            self._loaded_file = None
            print('No more files to load.')

    def __len__(self):
        """Return the total number of rows across all pickle files."""
        return int(sum(self._file_lengths))

    def _locate(self, idx):
        """Map a global index to ``(file position, row within that file)``.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
        """
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"Index {idx} out of range for dataset of length {total}")
        # First file whose end offset is strictly greater than idx.
        file_idx = int(np.searchsorted(self._file_ends, idx, side='right'))
        start = int(self._file_ends[file_idx - 1]) if file_idx else 0
        return file_idx, int(idx - start)

    def __getitem__(self, idx):
        """Return ``(anchor, positive, negative)`` spectrogram tensors for ``idx``.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
        """
        file_idx, local_idx = self._locate(idx)
        if file_idx != self._loaded_file:
            # load_current_file() always loads the file at current_file_index
            self.current_file_index = file_idx
            self.load_current_file()

        # Select the row and take [0] to get the audio from the (y, sr) tuple
        anchor = self.current_anchors[local_idx][0]
        positive = self.current_positives[local_idx][0]
        negative = self.current_negatives[local_idx][0]

        # Compute mel spectrograms
        anchor_mel = self._process_audio(anchor)
        positive_mel = self._process_audio(positive)
        negative_mel = self._process_audio(negative)

        return anchor_mel, positive_mel, negative_mel

    # Convert raw audio to mel spectrogram
    def _process_audio(self, y):
        """Return the dB mel spectrogram of ``y`` as a float32 tensor with a leading axis of size 1."""
        mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=self.sr, n_mels=self.n_mels)
        mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)
        mel_spectrogram_db = torch.tensor(mel_spectrogram_db, dtype=torch.float32).unsqueeze(0)
        return mel_spectrogram_db

In [2]:
class SpectrogramDataset(Dataset):
    """Random-access triplet dataset over several pickled frames.

    At construction every pickle is opened once to build ``data_index``, a
    flat list of ``(file position, row position)`` pairs. ``__getitem__``
    looks a sample up through that list, so any access order is valid.

    Each pickle must unpickle to an object supporting ``len()`` and
    ``.iloc`` with the columns ``processed_audio``, ``augmented_audio`` and
    ``diff_processed_audio`` (presumably a pandas DataFrame — confirm).
    Every cell is indexed with ``[0]`` to obtain the waveform.

    Args:
        file_paths: Sequence of pickle paths.
        transform: Optional callable applied to each of the three
            spectrogram tensors. Any falsy value (the default ``False``)
            disables it.
        sr: Sampling rate passed to librosa.
        n_mels: Number of mel bands passed to librosa.
    """

    def __init__(self, file_paths, transform=False, sr=22050, n_mels=128):
        self.file_paths = file_paths
        self.sr = sr
        self.n_mels = n_mels
        self.transform = transform
        # Most recently unpickled file, kept so consecutive samples from the
        # same file do not each re-read the whole pickle.
        self._cached_file_idx = None
        self._cached_data = None
        self.data_index = self._build_index()

    def _build_index(self):
        """Return the list of ``(file position, row position)`` pairs for all files."""
        index = []
        for file_idx, file_path in enumerate(self.file_paths):
            with open(file_path, 'rb') as f:
                n_rows = len(pkl.load(f))
            index.extend((file_idx, row_idx) for row_idx in range(n_rows))
        return index

    def _load_file(self, file_idx):
        """Return the unpickled contents of file ``file_idx``, reusing the cached copy when possible."""
        if file_idx != self._cached_file_idx:
            with open(self.file_paths[file_idx], 'rb') as f:
                self._cached_data = pkl.load(f)
            self._cached_file_idx = file_idx
        return self._cached_data

    def _get_log_mel_spectrogram(self, y):
        """Return the dB mel spectrogram of ``y`` as a float32 tensor with a leading axis of size 1."""
        # Convert to mel spectrogram
        mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=self.sr, n_mels=self.n_mels)
        # Convert to log scale
        log_mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)
        log_mel_spectrogram = torch.tensor(log_mel_spectrogram, dtype=torch.float32).unsqueeze(0)
        return log_mel_spectrogram

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data_index)

    def __getitem__(self, idx):
        """Return ``(anchor, positive, negative)`` spectrogram tensors for sample ``idx``."""
        file_idx, data_idx = self.data_index[idx]
        row = self._load_file(file_idx).iloc[data_idx]

        anchor = row['processed_audio'][0]  # (y, sr)
        positive = row['augmented_audio'][0]
        negative = row['diff_processed_audio'][0]

        # Convert to log mel spectrograms
        anchor_mel = self._get_log_mel_spectrogram(anchor)
        positive_mel = self._get_log_mel_spectrogram(positive)
        negative_mel = self._get_log_mel_spectrogram(negative)

        # Apply the transform to the tensors that are actually returned
        # (the original referenced undefined names here).
        if self.transform:
            anchor_mel = self.transform(anchor_mel)
            positive_mel = self.transform(positive_mel)
            negative_mel = self.transform(negative_mel)

        return anchor_mel, positive_mel, negative_mel


### Split Data and Instantiate Dataset Class

In [7]:
# One pickle per batch, batches 1 through 9
file_paths = [
    f'/Users/reggiebain/erdos/song-similarity-erdos-old/data/augmented_audio/batch_{batch_num}_augmented.pkl'
    for batch_num in range(1, 10)
]

# Split whole files (instead of individual rows) into training/validation
train_files, val_files = train_test_split(
    file_paths,
    test_size=0.2,
    random_state=123,
)

# One dataset per split
train_dataset = SpectrogramDataset(train_files)
val_dataset = SpectrogramDataset(val_files)

# Batched loaders; only the training loader shuffles
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, pin_memory=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False, pin_memory=True)

### Declaring the Model
- Define architecture: default resnet with adjusted first conv layer and final FC layer to set num params
- Choose loss function, optimizer, device, etc.

In [8]:
# Define pretrained resnet from Torch Vision resnet 18
class ResNetEmbedding(nn.Module):
    """ResNet-18 adapted to single-channel input, producing L2-normalised embeddings.

    Two layers of the torchvision model are replaced:

    * ``conv1`` takes 1 input channel instead of 3. Its weights are
      initialised from the pretrained RGB filters summed over the channel
      axis, which matches feeding the same plane to all three original
      channels. A randomly initialised replacement would throw away the
      pretrained stem.
    * ``fc`` maps to ``embedding_dim`` outputs instead of the default
      classification head.

    Args:
        embedding_dim: Size of the output embedding.
        dropout_rate: Currently unused (the dropout layer is disabled); kept
            so existing callers keep working.
    """

    def __init__(self, embedding_dim=128, dropout_rate=0.5):
        super(ResNetEmbedding, self).__init__()
        self.resnet = models.resnet18(weights='DEFAULT')

        # Swap the first layer for a 1-channel version with the same
        # kernel/stride/padding, seeded from the pretrained 3-channel filters.
        pretrained_conv1 = self.resnet.conv1
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            self.resnet.conv1.weight.copy_(pretrained_conv1.weight.sum(dim=1, keepdim=True))

        #self.dropout = nn.Dropout(p=dropout_rate)
        # Set the last fully connected layer to "embedding_dim" outputs
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embedding_dim)

    def forward(self, x):
        """Return the backbone output for ``x``, L2-normalised along dim 1."""
        x = self.resnet(x)
        return F.normalize(x, p=2, dim=1)

In [None]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import torch.nn as nn

class WhisperEmbedding(nn.Module):
    """Embedding model built on the encoder of a pretrained Whisper checkpoint.

    The encoder's last hidden states are mean-pooled over dim 1, projected
    to ``embedding_dim`` with a linear layer and L2-normalised.

    Args:
        embedding_dim: Size of the output embedding.
        pretrained_model_name: Checkpoint name passed to ``from_pretrained``.
    """

    def __init__(self, embedding_dim=128, pretrained_model_name="openai/whisper-tiny"):
        super(WhisperEmbedding, self).__init__()
        self.whisper = WhisperForConditionalGeneration.from_pretrained(pretrained_model_name)
        # d_model is the Whisper config's hidden width
        self.fc = nn.Linear(self.whisper.config.d_model, embedding_dim)

    def forward(self, x):
        """Return normalised embeddings for ``x``.

        ``x`` is passed straight to the Whisper encoder; presumably it must
        be Whisper input features as produced by ``WhisperProcessor`` —
        TODO confirm against the caller.
        """
        # WhisperForConditionalGeneration has no `.encoder` attribute (the
        # encoder sits under the wrapped base model); get_encoder() is the
        # public accessor.
        encoder = self.whisper.get_encoder()
        outputs = encoder(x)
        # Get the mean of the hidden states to create an embedding
        embedding = outputs.last_hidden_state.mean(dim=1)
        # Pass through a linear layer to get the final embedding
        embedding = self.fc(embedding)
        # Normalize the embedding
        return torch.nn.functional.normalize(embedding, p=2, dim=1)

In [9]:
# Choose model, loss, and optimizer
model = ResNetEmbedding()
criterion = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)

# Freeze the whole backbone first ...
for param in model.resnet.parameters():
    param.requires_grad = False

# ... then re-enable training for:
#   - conv1: ResNetEmbedding replaces it with a new 1-channel layer, so it
#     is not one of the pretrained layers and must be allowed to adapt
#     (it was previously frozen along with the rest of the backbone)
#   - layer4: the last residual block
#   - fc: the new embedding head
for trainable_module in (model.resnet.conv1, model.resnet.layer4, model.resnet.fc):
    for param in trainable_module.parameters():
        param.requires_grad = True

# Use a smaller learning rate for fine-tuning; only trainable params are optimised
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

# Per-epoch loss histories
train_losses = []
val_losses = []
baseline_losses = []

num_epochs = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

ResNetEmbedding(
  (resnet): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tr

### Training Model
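- Go into training mode, zero the gradients with `optimizer.zero_grad()`, run the forward pass on each triplet, then backpropagate the triplet loss and step the optimizer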

In [10]:
# Loop over epochs. Each epoch runs one training pass and one validation
# pass, records the per-sample average losses, and rewrites the log file.
for epoch in range(num_epochs):
    model.train()
    running_train_loss = 0.0

    # Iterating the tqdm wrapper advances the bar by one per batch and the
    # `with` block closes it when the pass ends (or is interrupted).
    with tqdm(train_loader, desc=f"Training {epoch+1}/{num_epochs}", unit="batch") as pbar:
        for anchors, positives, negatives in pbar:
            anchors, positives, negatives = anchors.to(device), positives.to(device), negatives.to(device)
            optimizer.zero_grad()

            anchor_embeddings = model(anchors)
            positive_embeddings = model(positives)
            negative_embeddings = model(negatives)

            loss = criterion(anchor_embeddings, positive_embeddings, negative_embeddings)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by the batch size so the division by
            # the dataset size below gives a per-sample average.
            running_train_loss += loss.item() * anchors.size(0)

    train_loss = running_train_loss / len(train_loader.dataset)
    train_losses.append(train_loss)

    # Turn on validation/eval mode
    model.eval()
    running_val_loss = 0.0
    baseline_loss = 0.0
    # Turn off gradient tracking since we're in validation
    with torch.no_grad(), tqdm(val_loader, desc=f"Validation {epoch+1}/{num_epochs}", unit="batch") as val_pbar:
        for anchors, positives, negatives in val_pbar:
            anchors, positives, negatives = anchors.to(device), positives.to(device), negatives.to(device)

            anchor_embeddings = model(anchors)
            positive_embeddings = model(positives)
            negative_embeddings = model(negatives)

            loss = criterion(anchor_embeddings, positive_embeddings, negative_embeddings)
            batch_size = anchors.size(0)
            running_val_loss += loss.item() * batch_size

            # Baseline: the same loss on the raw inputs (no model). It is
            # weighted by batch size like the model loss; it was previously
            # added unweighted and then divided by the sample count, which
            # made it too small by about the batch size.
            baseline_loss += criterion(anchors, positives, negatives).item() * batch_size

    # Per-sample averages over the whole validation set
    val_loss = running_val_loss / len(val_loader.dataset)
    val_losses.append(val_loss)

    baseline_avg_loss = baseline_loss / len(val_loader.dataset)
    baseline_losses.append(baseline_avg_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Baseline Loss: {baseline_avg_loss:.4f}")

    # Rewritten every epoch so the log survives an interrupted run
    with open('training_logs.pkl', 'wb') as f:
        pkl.dump((train_losses, val_losses), f)

Training 1/5:   0%|          | 0/110 [02:02<?, ?batch/s]


KeyboardInterrupt: 

In [None]:
# Plot loss curves for training.
# Each history is plotted against its own length rather than a fixed
# range(1, num_epochs + 1): if training stopped early the lists are shorter
# than num_epochs (and possibly of different lengths), and a fixed x-axis
# would make plt.plot raise on the shape mismatch.
plt.figure(figsize=(10, 5))
for loss_history, curve_label in (
    (train_losses, 'Training Loss'),
    (val_losses, 'Validation Loss'),
    (baseline_losses, 'Baseline Loss'),
):
    plt.plot(range(1, len(loss_history) + 1), loss_history, label=curve_label)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.tight_layout()
plt.savefig('resnet-loss-plot.png')
plt.show()


In [None]:
# Persist only the parameters (state dict), not the pickled model object
torch.save(obj=model.state_dict(), f='resnet18_model_weights.pth')

In [20]:
# Persist the whole model object for use in deployment
torch.save(obj=model, f='resnet18_model.pth')

### Deploy on any 2 songs
- Use model to calculate embeddings (using eval mode specifically and with no gradient updating)
- NOTE: Right now, we have to have the input data in the correct format: a spectrogram/chromagram/tempogram (generically called "gram"). So for any deployment, we'll have to do preprocessing in the streamlit app for example. OR we can have a set of say 10-15 sample songs you can compare where we've already done all of the calculations.
- **Similarity values key:**
    - 0.5 to 1: Very similar. Perhaps the same song.
    - 0 to 0.5: Somewhat similar. Share some key characteristics
    - -1 to 0: Low to no similarity. Different songs.


In [None]:
# How to load the model later using just the state dictionary
model = ResNetEmbedding()  # Make sure this matches the architecture you used

# map_location remaps the saved tensors onto the current device, so weights
# saved on a GPU machine can still be loaded on a CPU-only one.
state_dict = torch.load('resnet18_model_weights.pth', map_location=device)
model.load_state_dict(state_dict)

# If using a GPU
model.to(device)

# A freshly constructed module starts in training mode; switch to eval mode
# so BatchNorm uses its stored running statistics during inference.
model.eval()

In [None]:
def extract_embedding(model, audio_data_clip, sr=22050, use_model=True):
    """Return the model embedding (or the raw spectrogram tensor) for one clip.

    Args:
        model: Module called on the spectrogram tensor when ``use_model`` is
            true.
        audio_data_clip: Either an audio array, or a path (``str`` /
            ``os.PathLike``) to an audio file, which is loaded with
            ``librosa.load`` at ``sr``.
        sr: Sampling rate used for loading and for the mel spectrogram.
        use_model: When false, skip the model and return the spectrogram
            tensor itself.

    Returns:
        ``model(mel_tensor)`` when ``use_model`` is true, otherwise
        ``mel_tensor``: the dB mel spectrogram as a float32 tensor with two
        leading axes of size 1, on the notebook-level ``device``.
    """
    # The notebook calls this with file paths as well as preloaded arrays,
    # so accept both.
    if isinstance(audio_data_clip, (str, os.PathLike)):
        y, sr = librosa.load(audio_data_clip, sr=sr)
    else:
        y = audio_data_clip

    mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=sr)
    mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)

    # Convert to tensor and move to the appropriate device
    mel_tensor = torch.tensor(mel_spectrogram_db, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    if not use_model:
        return mel_tensor

    # Run in eval mode so BatchNorm uses its running statistics instead of
    # this one-clip batch, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            embedding = model(mel_tensor)
    finally:
        model.train(was_training)
    return embedding


def compute_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two single-sample tensors as a float.

    Inputs with two or more dims are flattened to ``(batch, features)``
    before the comparison; inputs with fewer dims are treated as one
    sample. A ``(1, D)`` embedding therefore behaves as before, while a 1-D
    vector or a ``(1, C, H, W)`` tensor now also yields a single score.

    Args:
        embedding1: First tensor.
        embedding2: Second tensor, with the same number of elements per
            sample as ``embedding1``.

    Raises:
        Whatever ``Tensor.item()`` raises when the batch holds more than
        one sample, since the result is then not a single value.
    """
    def _as_rows(tensor):
        # (features,) or scalar -> (1, features); (batch, ...) -> (batch, features)
        if tensor.dim() < 2:
            return tensor.reshape(1, -1)
        return tensor.flatten(start_dim=1)

    cosine_sim = F.cosine_similarity(_as_rows(embedding1), _as_rows(embedding2), dim=1)
    return cosine_sim.item()  # Convert to a Python float


In [None]:
# Read the first augmented-audio batch (Kaggle input path) and preview two rows
explore_df = pd.read_pickle(
    filepath_or_buffer='/kaggle/input/augmented-audio-10k/batch_1_augmented.pkl'
)
explore_df.head(n=2)

In [None]:
# Sanity check on known data: anchor clip vs. the "different song" clip
# stored under index label 10000
y1 = explore_df['processed_audio'][10000][0]
y2 = explore_df['diff_processed_audio'][10000][0]
mel1 = extract_embedding(model, y1)
mel2 = extract_embedding(model, y2)
compute_cosine_similarity(mel1, mel2)

In [None]:
# Sanity check on known data: anchor clip vs. its augmented version
# stored under index label 10000
y1 = explore_df['processed_audio'][10000][0]
y3 = explore_df['augmented_audio'][10000][0]
mel1 = extract_embedding(model, y1)
mel3 = extract_embedding(model, y3)
compute_cosine_similarity(mel1, mel3)

In [None]:
# TripletMarginLoss expects (anchor, positive, negative). In the cells above
# mel3 comes from 'augmented_audio' (the positive) and mel2 from
# 'diff_processed_audio' (the negative), so mel3 must come second.
criterion(mel1, mel3, mel2)

In [None]:
# Audio files to compare: an anchor recording, another recording of the same
# song, and a recording of a different song (replace with your own paths)
anchor = '/Users/reggiebain/erdos/song-similarity-erdos-old/data/coversongs/covers32k/A_Whiter_Shade_Of_Pale/annie_lennox+Medusa+03-A_Whiter_Shade_Of_Pale.mp3'
positive = '/Users/reggiebain/erdos/song-similarity-erdos-old/data/coversongs/covers32k/A_Whiter_Shade_Of_Pale/procol_harum+Greatest_Hits+2-A_Whiter_Shade_Of_Pale.mp3'
negative = '/Users/reggiebain/erdos/song-similarity-erdos-old/data/coversongs/covers32k/Abracadabra/steve_miller_band+Steve_Miller_Band_Live_+09-Abracadabra.mp3'

# One embedding per song, in anchor / positive / negative order
embedding1, embedding2, embedding3 = (
    extract_embedding(model, song) for song in (anchor, positive, negative)
)

In [None]:
# Cosine similarity of the anchor against the cover and against the other song
cover_similarity = compute_cosine_similarity(embedding1, embedding2)
other_similarity = compute_cosine_similarity(embedding1, embedding3)
print(f"Cosine Similarity between song and known cover: {cover_similarity:.4f}")
print(f"Cosine Similarity between song and random other song: {other_similarity:.4f}")

In [None]:
# Similarity between the cover recording and the unrelated song
compute_cosine_similarity(embedding1=embedding2, embedding2=embedding3)

In [None]:
# Embedding of the anchor clip stored under index label 10000
embed1 = extract_embedding(model, explore_df['processed_audio'][10000][0])

In [None]:
# compute_cosine_similarity() takes no `use_model` argument, so the original
# calls raised TypeError. The "(raw)" numbers are meant to skip the model,
# so compare the mel-spectrogram tensors directly.
# NOTE(review): cropping both spectrograms to the shorter one's frame count
# is an assumption made so songs of different lengths are comparable —
# confirm the intended alignment.
def _raw_spectrogram(song_path):
    # Load the waveform here so this cell does not depend on
    # extract_embedding accepting a path.
    audio, _ = librosa.load(song_path, sr=22050)
    return extract_embedding(model, audio, use_model=False)


def _raw_similarity(spec_a, spec_b):
    # Crop along the last (time) axis, then flatten each to (1, features)
    n_frames = min(spec_a.shape[-1], spec_b.shape[-1])
    flat_a = spec_a[..., :n_frames].flatten(start_dim=1)
    flat_b = spec_b[..., :n_frames].flatten(start_dim=1)
    return compute_cosine_similarity(flat_a, flat_b)


raw_anchor, raw_positive, raw_negative = (
    _raw_spectrogram(song) for song in (anchor, positive, negative)
)
print(f"Cosine Similarity between song and known cover (raw): {_raw_similarity(raw_anchor, raw_positive):.4f}")
print(f"Cosine Similarity between song and random other song (raw): {_raw_similarity(raw_anchor, raw_negative):.4f}")

In [None]:
# Triplet loss for the three songs, in (anchor, positive, negative) order
songs_triplet_loss = criterion(embedding1, embedding2, embedding3)
print(f"Triplet Loss between songs: {songs_triplet_loss}")

In [None]:
# Baseline: take the raw spectrogram tensors (use_model=False) for the same
# three songs, in anchor / positive / negative order
baseline_anchor, baseline_positive, baseline_negative = (
    extract_embedding(model, song, use_model=False)
    for song in (anchor, positive, negative)
)
# (disabled) print(f"Triplet loss with baseline: {criterion(baseline_anchor, baseline_positive, baseline_negative)}")

In [None]:
# Dimensions of the raw anchor spectrogram tensor
baseline_anchor.size()

In [None]:
# Dimensions of the raw positive spectrogram tensor
baseline_positive.size()

In [None]:
# Dimensions after L2-normalising the raw positive spectrogram along dim 1
F.normalize(baseline_positive, p=2, dim=1).size()