# Musical Source Separation with Limited Data
## LAMIR 2024 Hackathon

Authored by Richa Namballa

Based on the late-breaking demo:
> Namballa, R., Morais, G., \& Fuentes, M. (2024). "Musical Source Separation of Brazilian Percussion." In _Extended Abstracts for the Late-Breaking Demo Session of the 25th International Society for Music Information Retrieval Conference_.

**Musical source separation (MSS)** is a central task of music information retrieval (MIR) that aims to “de-mix” audio into its corresponding instrument stems. It has applications in both music research and music production, since the separated stems can be analyzed or reused.

For a more detailed introduction to source separation itself, please refer to the [**Open Source Tools & Data for Music Source Separation**](https://source-separation.github.io/tutorial/landing.html) tutorial written by Ethan Manilow, Prem Seetharaman, and Justin Salamon.

Some source separation models, such as [**Demucs**](https://github.com/adefossez/demucs), achieve state-of-the-art performance when separating musical mixtures into four stems: _drums_, _bass_, _vocals_ and _other_. However, most source separation systems are trained on Western instruments only, which precludes their application to more culturally diverse music.

### Datasets

There are many source separation datasets available for model training, such as [**Slakh2100**](http://www.slakh.com/). One of the most popular 4-stem MSS datasets is [**MUSDB18**](https://sigsep.github.io/datasets/musdb.html#musdb18-compressed-stems), which contains 150 full-length audio tracks. Even datasets that offer a larger variety of stems, such as [**MoisesDB**](https://github.com/moises-ai/moises-db), focus on Eurocentric instruments. Creating new MSS datasets is challenging because of the time and money required to record and mix high-quality stems, so the lack of instrumental diversity is unsurprising. Before investing significant resources in constructing new datasets, we investigate the feasibility of building an MSS system from artificial mixtures created with an existing non-Western dataset.

We use the [**Brazilian Rhythmic Instruments Dataset**](https://zenodo.org/records/14051323) (BRID), a dataset typically used for beat tracking. For this demo, we chose the _surdo_ as the target source to separate from the mixture. The surdo is a large tom-like drum that plays a characteristic pattern repeated throughout a piece. This repetition, together with its low-pitched timbre, makes it an easier target than the other percussion instruments.

#### Libraries

In [None]:
import os
import numpy as np
from tqdm import tqdm
from datetime import datetime
import pickle
import random

import torch
from torch.utils.data import DataLoader

from spectrogram import generate_spectrograms
from unet import UNet
from dataset import SeparationDataset
from utils import plot_loss, Spec2Audio
from separate import separate

from IPython.display import Audio

### SynBRID Dataset

For this task, we created `syn_brid` (SynBRID), a set of artificial mixtures built by combining BRID solo tracks. In total, we generated 100 mixtures for training, 10 for validation, and 30 for testing. For each mixture, we provide the mix (`mixture.wav`) and the surdo stem (`surdo.wav`).
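The exact recipe used to build SynBRID is not described in this notebook. As a rough illustration only, a mixture can be formed by applying a random gain to each solo track, summing them, and peak-normalizing the result so it does not clip; the same scale factor is applied to the surdo stem so it stays consistent with the mixture. The helper `make_mixture`, the gain range, and the peak level below are all hypothetical choices, not the authors' procedure.

```python
import numpy as np

def make_mixture(tracks, target_name, gain_db_range=(-6.0, 0.0), peak=0.9, seed=None):
    """Hypothetical sketch: mix mono solo tracks (dict of name -> 1-D array).

    Returns (mixture, target_stem), both trimmed to the shortest track and
    scaled by the same factor so the target remains aligned with the mixture.
    """
    rng = np.random.default_rng(seed)
    length = min(len(audio) for audio in tracks.values())
    scaled = {}
    for name, audio in tracks.items():
        gain_db = rng.uniform(*gain_db_range)  # random level for each instrument
        scaled[name] = audio[:length] * 10 ** (gain_db / 20)
    mixture = np.sum(list(scaled.values()), axis=0)
    # peak-normalize so the mixture can be written to .wav without clipping
    scale = peak / max(np.max(np.abs(mixture)), 1e-12)
    return mixture * scale, scaled[target_name] * scale

# usage with synthetic stand-in signals (two seconds at 8192 Hz)
sr = 8192
t = np.arange(2 * sr) / sr
demo_tracks = {
    "surdo": np.sin(2 * np.pi * 60 * t),
    "other": 0.3 * np.random.default_rng(0).standard_normal(2 * sr),
}
demo_mix, demo_surdo = make_mixture(demo_tracks, "surdo", seed=0)
```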

In [None]:
import tarfile

# extract .wav files from syn_brid.tar.gz
# skip extraction if the folder already exists
if not os.path.isdir('syn_brid/'):
    with tarfile.open('syn_brid.tar.gz') as f:
        # extract compressed files
        f.extractall('./')

### Data Preprocessing

The source separation model operates on the magnitude spectrograms of the mixture and the target stem. For computational efficiency, we use a low sample rate of 8192 Hz.
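`generate_spectrograms` comes from the accompanying `spectrogram.py`, whose code is not shown here. As a minimal sketch of the idea, the snippet below computes a magnitude and a phase spectrogram with `torch.stft`; the Hann window and the default `center=True` padding are assumptions and may differ from what the helper actually does. The phase is kept because the estimated magnitude is presumably recombined with the mixture phase to resynthesize audio (the dataset also yields a `phase` tensor in the training loop).

```python
import torch

def magnitude_phase(audio, fft_size=1024, hop_size=768):
    """Sketch: return (magnitude, phase), each of shape (fft_size // 2 + 1, n_frames)."""
    x = torch.as_tensor(audio, dtype=torch.float32)
    window = torch.hann_window(fft_size)  # assumed window type
    spec = torch.stft(x, n_fft=fft_size, hop_length=hop_size,
                      window=window, return_complex=True)
    return spec.abs(), spec.angle()

# two seconds of a 60 Hz tone at the notebook's 8192 Hz sample rate
sr = 8192
t = torch.arange(2 * sr) / sr
mag, phase = magnitude_phase(torch.sin(2 * torch.pi * 60 * t))
print(mag.shape)  # torch.Size([513, 22])
```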

In [None]:
SAMPLE_RATE = 8192
FFT_SIZE = 1024
HOP_SIZE = 768
PATCH_SIZE = 128

TARGET_SOURCE = "surdo"
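These settings fix the time and frequency resolution of the spectrograms. The quick calculation below assumes that `PATCH_SIZE` counts STFT frames and that all `FFT_SIZE // 2 + 1` frequency bins are kept; both are assumptions about the helper modules, which are not shown here. `stft_resolution` is a hypothetical helper written for this illustration.

```python
def stft_resolution(sample_rate, fft_size, hop_size, patch_size):
    """Summarize the time/frequency resolution implied by the STFT settings."""
    return {
        "n_bins": fft_size // 2 + 1,                     # one-sided spectrum
        "bin_width_hz": sample_rate / fft_size,          # spacing between frequency bins
        "nyquist_hz": sample_rate / 2,                   # highest representable frequency
        "window_s": fft_size / sample_rate,              # analysis window length
        "hop_s": hop_size / sample_rate,                 # time between consecutive frames
        "overlap": 1 - hop_size / fft_size,              # fraction shared by adjacent frames
        "patch_s": patch_size * hop_size / sample_rate,  # approx. audio covered by one patch
    }

res = stft_resolution(8192, 1024, 768, 128)  # the values set in the cell above
print(res)
# {'n_bins': 513, 'bin_width_hz': 8.0, 'nyquist_hz': 4096.0, 'window_s': 0.125,
#  'hop_s': 0.09375, 'overlap': 0.25, 'patch_s': 12.0}
```

So each 128-frame patch spans roughly 12 s of audio, and the 8192 Hz sample rate keeps content up to 4096 Hz, which is likely sufficient for a low-pitched target like the surdo.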

In [None]:
# generate the spectrograms for each fold of the dataset
print("\n>>> TRAINING DATA <<<")
generate_spectrograms('./syn_brid/train', './spec/train', TARGET_SOURCE, SAMPLE_RATE, FFT_SIZE, HOP_SIZE)
print("\n>>> VALIDATION DATA <<<")
generate_spectrograms('./syn_brid/val', './spec/val', TARGET_SOURCE, SAMPLE_RATE, FFT_SIZE, HOP_SIZE)
print("\n>>> TESTING DATA <<<")
generate_spectrograms('./syn_brid/test', './spec/test', TARGET_SOURCE, SAMPLE_RATE, FFT_SIZE, HOP_SIZE)

### Test the Base Model

We pretrained the source separation model on the _bass_ stem of the MUSDB18 dataset for 1000 epochs with a learning rate of `1e-4`. Let's listen to what happens when we use this bass (base) model directly on a SynBRID mixture to try to separate the surdo.

In [None]:
DEVICE_TYPE = "cuda"
IN_CHANNELS = 1

In [None]:
DEVICE = torch.device(DEVICE_TYPE)

In [None]:
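print("Loading base model...")
# initialize model
base_model = UNet(IN_CHANNELS)
# load weights (map_location lets GPU-saved weights load on the selected device)
bass_weights = torch.load('best_weights_bass.pth', map_location=DEVICE, weights_only=True)
base_model.load_state_dict(bass_weights)
# move to the selected device and switch to inference mode
base_model = base_model.to(DEVICE).eval()
print("Base model loaded successfully!")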

In [None]:
# set a seed to choose a random mixture from the test set
test_seed = 14
random.seed(test_seed)
# sort the listing: os.listdir order is arbitrary, so sorting makes the seed reproducible
files = sorted(f for f in os.listdir('./syn_brid/test') if not f.startswith('.'))
test_mixture = os.path.join('syn_brid', 'test', random.choice(files), 'mixture.wav')

In [None]:
stem = separate(test_mixture, base_model, DEVICE, FFT_SIZE, HOP_SIZE, SAMPLE_RATE, PATCH_SIZE)

In [None]:
Audio(stem, rate=SAMPLE_RATE)

Hmm, not too great... can we do better?

### Fine-tuning the Base Model

_Fine-tuning_ is a _transfer learning_ method in which we "tune" (or adjust) the weights of a pretrained model so that it works on new data. When fine-tuning, you can either continue training all of the parameters on the new data or "freeze" the earlier layers so that the initial feature extraction stays the same.
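In PyTorch, freezing amounts to setting `requires_grad = False` on the parameters you want to keep fixed and passing only the trainable ones to the optimizer. The internals of the provided `UNet` are not shown here, so the sketch below uses a small stand-in `nn.Sequential` and a hypothetical helper, `freeze_first_children`, that freezes the first `n` child modules. For the real model, inspect `ft_model.named_children()` to decide which blocks count as "earlier" layers.

```python
import torch
from torch import nn

def freeze_first_children(model: nn.Module, n: int) -> None:
    """Hypothetical helper: freeze the first n child modules, train the rest."""
    for i, child in enumerate(model.children()):
        keep_training = i >= n
        for param in child.parameters():
            param.requires_grad = keep_training

# small stand-in for the real UNet (its layers are not shown in this notebook)
demo_model = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1),  # child 0
    nn.ReLU(),                      # child 1 (no parameters)
    nn.Conv2d(8, 8, 3, padding=1),  # child 2
    nn.Conv2d(8, 1, 1),             # child 3
)
freeze_first_children(demo_model, 2)  # freezes child 0 (child 1 has no parameters)

# hand only the still-trainable parameters to the optimizer
trainable = [p for p in demo_model.parameters() if p.requires_grad]
demo_optimizer = torch.optim.Adam(trainable, lr=1e-6)
print(sum(p.numel() for p in trainable))  # 593 of 673 parameters remain trainable
```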

We have provided some basic code that you can use as a starting point to improve the surdo separation model. The `SeparationDataset` class takes an argument, `pct_files`, which sets the fraction (between 0 and 1) of each data subset to use. For example, if you set `pct_files=0.5` for `train_dataset`, you will fine-tune the model on 50 of the 100 training spectrograms.
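The implementation of `SeparationDataset` is not shown here, so we can't say exactly how it picks files. A common approach, sketched below as an assumption, is to shuffle the file list with a seeded random generator and keep the first `round(pct_files * N)` entries: the same `random_state` always yields the same subset, while a different seed yields a different subset of the same size. `select_subset` and the file names are hypothetical.

```python
import random

def select_subset(files, pct_files, random_state=None):
    """Hypothetical sketch: reproducibly keep a fraction of the files."""
    if not 0.0 < pct_files <= 1.0:
        raise ValueError("pct_files must be in (0, 1]")
    files = sorted(files)                  # fixed order before shuffling
    rng = random.Random(random_state)      # local RNG, leaves the global state alone
    rng.shuffle(files)
    n_keep = max(1, round(pct_files * len(files)))
    return files[:n_keep]

train_files = [f"track_{i:03d}" for i in range(100)]  # stand-in for 100 training mixtures
subset = select_subset(train_files, 0.5, random_state=42)
print(len(subset))  # 50
```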

**Task**: Your task is to see how much you can improve the surdo separation model with the smallest amount of data (i.e., the lowest value of `pct_files` for `train_dataset`). Feel free to experiment and make modifications to the training code and other scripts.

Some suggestions to improve your model:
* Test different values of the hyperparameters (the learning rate, number of epochs, etc.)
* Add more fine-tuning layers.
* Try freezing the earlier layers of the model to only fine-tune the parameters of the later layers.
* Use different `random_state` seeds.
* Be creative!
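As a middle ground between fine-tuning every parameter and freezing layers outright, earlier layers can be given a smaller learning rate than later ones through PyTorch optimizer parameter groups. The sketch below again uses a stand-in `nn.Sequential` because the `UNet` structure is not shown; the split point and the 10x ratio are arbitrary choices for illustration, not tuned values.

```python
import torch
from torch import nn

demo_model = nn.Sequential(                   # stand-in for the real UNet
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 1, 1),
)
early, late = demo_model[:2], demo_model[2:]  # slicing an nn.Sequential gives nn.Sequential

base_lr = 1e-5
demo_optimizer = torch.optim.Adam([
    {"params": early.parameters(), "lr": base_lr / 10},  # earlier layers change slowly
    {"params": late.parameters(), "lr": base_lr},        # later layers adapt faster
])
for group in demo_optimizer.param_groups:
    n_params = sum(p.numel() for p in group["params"])
    print(f"lr={group['lr']:.0e}, parameters={n_params}")
```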



In [None]:
BATCH_SIZE = 4
NUM_WORKERS = 2
LEARNING_RATE = 1e-6
NUM_EPOCHS = 500
PERCENT_TRAIN = 1.
PERCENT_VAL = 1.

#### Set-Up Data Loaders

In [None]:
RANDOM_STATE = 42

In [None]:
train_dataset = SeparationDataset('./spec/train/', TARGET_SOURCE,
                                  pct_files=PERCENT_TRAIN,
                                  patch_size=PATCH_SIZE, random_state=RANDOM_STATE)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS, shuffle=True)

val_dataset = SeparationDataset('./spec/val/', TARGET_SOURCE,
                                pct_files=PERCENT_VAL,
                                patch_size=PATCH_SIZE, random_state=RANDOM_STATE)
# no need to shuffle for validation: the loss is averaged over every batch
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                            num_workers=NUM_WORKERS, shuffle=False)

### Training

Run this code to fine-tune the model. Make sure that you are connected to a GPU runtime: `DEVICE_TYPE` is set to `"cuda"` above, so the code expects a CUDA-capable GPU, which also speeds up training considerably.

In [None]:
# get timestamp for saving checkpoints and history
t_stamp = datetime.now().strftime("%y%m%d_%I%M%S%p")

# create checkpoint directory
os.makedirs('./checkpoint', exist_ok=True)

model_name = f"{t_stamp}_FT_{TARGET_SOURCE}.pth"
history_name = f"{t_stamp}_FT_history.pkl"

In [None]:
# load the bass base model weights to have a starting point for training
print("Loading base model for fine-tuning...")
ft_model = UNet(IN_CHANNELS)
ft_model.load_state_dict(torch.load('best_weights_bass.pth', map_location=DEVICE, weights_only=True))
print("Base model ready for fine-tuning!")

In [None]:
# fine-tune all model parameters
for param in ft_model.parameters():
    param.requires_grad = True

In [None]:
# send model to the selected device
ft_model = ft_model.to(DEVICE)

# track the best model by validation loss
best_model = None
best_val_loss = float('inf')

# initialize training
# mean absolute error loss
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(ft_model.parameters(), lr=LEARNING_RATE)

# gradient scaler for mixed-precision training
# (autocast below uses float32, so training effectively runs in full precision)
scaler = torch.GradScaler(DEVICE_TYPE)

# save history of metrics
history = {'train_loss': [], 'val_loss': []}

In [None]:
# training loop
for epoch in range(NUM_EPOCHS):
    # loss values within epoch
    train_loss_epoch, val_loss_epoch = [], []

    # TRAINING
    # enable training
    ft_model.train()

    # progress bar
    pbar = tqdm(train_dataloader)
    pbar.set_description("Training")

    for idx, (mix, stem, phase) in enumerate(pbar):
        # send data to device
        mix = mix.to(DEVICE)
        stem = stem.to(DEVICE)

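        # forward pass under autocast; dtype=torch.float32 keeps full precision
        # (use torch.float16 on CUDA for actual mixed precision)
        with torch.autocast(device_type=DEVICE_TYPE, dtype=torch.float32):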
            output = ft_model(mix)
            loss = criterion(output, stem)

        train_loss_epoch.append(loss.item())

        pbar.set_postfix({"Loss": loss.item()}, refresh=True)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # compute avg training loss for this epoch
    history['train_loss'].append(np.mean(train_loss_epoch))

    # VALIDATION
    ft_model.eval()

    with torch.no_grad():

        for mix, stem, phase in val_dataloader:
            # send data to device
            mix = mix.to(DEVICE)
            stem = stem.to(DEVICE)

            # forward
            preds = ft_model(mix)
            loss = criterion(preds, stem)

            val_loss_epoch.append(loss.item())

    # compute avg validation loss for this epoch
    history['val_loss'].append(np.mean(val_loss_epoch))

    # log summary for epoch
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}: " +
             f"Training Loss: {history['train_loss'][-1]:.6f}, " +
             f"Validation Loss: {history['val_loss'][-1]:.6f}\n")

    # check if the model improved on the validation dataset
    if history['val_loss'][-1] < best_val_loss:
        best_model = ft_model
        torch.save(best_model.state_dict(),
                   os.path.join('checkpoint/', model_name))
        best_val_loss = history['val_loss'][-1]

print("Training completed.")

In [None]:
# save training history
with open(os.path.join('./checkpoint/', history_name), 'wb') as f:
    pickle.dump(history, f)
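Because the history is pickled, you can reload it later (for example, to compare runs with different `pct_files` values) without retraining. The sketch below assumes the `{'train_loss': [...], 'val_loss': [...]}` layout used above; `best_epoch` is a hypothetical helper, and `demo_history` holds dummy values so the real `history` is not overwritten. Since the checkpoint is saved only on a strict improvement, the first minimum of the validation loss matches the epoch of the saved weights.

```python
import os
import pickle
import tempfile

def best_epoch(history):
    """Return (1-based epoch, loss) of the first lowest validation loss."""
    val = history["val_loss"]
    idx = min(range(len(val)), key=val.__getitem__)
    return idx + 1, val[idx]

# dummy history with the same structure as the one saved above
demo_history = {"train_loss": [0.30, 0.22, 0.18, 0.15],
                "val_loss": [0.28, 0.21, 0.23, 0.24]}
path = os.path.join(tempfile.mkdtemp(), "history.pkl")
with open(path, "wb") as f:
    pickle.dump(demo_history, f)

with open(path, "rb") as f:
    loaded = pickle.load(f)
print(best_epoch(loaded))  # (2, 0.21)
```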

In [None]:
# plot loss curve
plot_loss(history, save_path=f'./checkpoint/{t_stamp}_loss.png')

### Test the Fine-tuned Model

Load the fine-tuned model and run it on the same test mixture from before.

We also recommend running the model on the entire test set and computing standard MSS performance metrics, such as the [Source-to-Distortion Ratio (SDR)](https://lightning.ai/docs/torchmetrics/stable/audio/signal_distortion_ratio.html). Make sure to compute the SDR of the base model on the surdo test set as well, so that you have a baseline for comparison. A higher SDR (in dB) indicates better separation.

In [None]:
ckpt_path = f'./checkpoint/{model_name}'

print("Loading fine-tuned model...")
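# initialize model
ft_model = UNet(IN_CHANNELS)
# load the fine-tuned weights into the fine-tuned model (not the base model)
surdo_weights = torch.load(ckpt_path, map_location=DEVICE, weights_only=True)
ft_model.load_state_dict(surdo_weights)
# move to the selected device and switch to inference mode
ft_model = ft_model.to(DEVICE).eval()
print("Fine-tuned model loaded successfully!")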

In [None]:
surdo_stem = separate(test_mixture, ft_model, DEVICE, FFT_SIZE, HOP_SIZE, SAMPLE_RATE, PATCH_SIZE)

In [None]:
Audio(surdo_stem, rate=SAMPLE_RATE)

In [None]:
# !pip install torchmetrics
# from torchmetrics.audio import SignalDistortionRatio
# ...
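If you would rather not add a dependency, a simplified SDR, `10 * log10(sum(s**2) / sum((s - s_hat)**2))`, is easy to compute with NumPy. Note that it is not identical to the BSS Eval-style SDR computed by `torchmetrics`' `SignalDistortionRatio`, which first allows a short distortion filter on the estimate, so numbers from the two should not be compared directly. `simple_sdr` is a hypothetical helper and assumes mono signals at the same sample rate; to evaluate the test set, you could loop over the `syn_brid/test` folders, run `separate` on each `mixture.wav`, and compare the result with the matching `surdo.wav` loaded at `SAMPLE_RATE`.

```python
import numpy as np

def simple_sdr(reference, estimate, eps=1e-8):
    """Simplified SDR in dB: 10 * log10(||s||^2 / ||s - s_hat||^2).

    Assumes mono signals; both are flattened and trimmed to a common length,
    since the separated stem may be slightly shorter or longer than the reference.
    """
    s = np.asarray(reference, dtype=np.float64).ravel()
    s_hat = np.asarray(estimate, dtype=np.float64).ravel()
    n = min(len(s), len(s_hat))
    s, s_hat = s[:n], s_hat[:n]
    num = np.sum(s ** 2) + eps
    den = np.sum((s - s_hat) ** 2) + eps
    return 10 * np.log10(num / den)

reference = np.sin(2 * np.pi * 60 * np.arange(8192) / 8192)
print(round(simple_sdr(reference, 0.5 * reference), 2))  # 6.02: the error has 1/4 of the energy
```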