# Demucs Separation

In [None]:
import os
import time
import torch
import numpy as np
import soundfile as sf
from demucs import pretrained
from demucs.apply import apply_model
from IPython.display import display, Audio

### Separation Function

In [None]:
def separate(in_file, model, gpu=None):
    """
    Run a Demucs source separation model on a WAV file to isolate its stems.

    :param in_file: (str) path of the mixture WAV file to separate
    :param model: (demucs.BagOfModels) demucs model object
    :param gpu: (torch.device) device to run inference on, if one is available;
        ``None`` runs the model on the CPU
    :return: (list[np.ndarray] | None) one ``(num_samples, num_channels)`` array
        per source produced by the model, in the model's output order (this
        notebook assumes drums, bass, other, vocals), or ``None`` when
        ``in_file`` is not a WAV file
    """
    # only process wav files (any capitalisation of the extension is accepted)
    if not in_file.lower().endswith(".wav"):
        print("Invalid input file type. Please use WAV files only")
        return None

    # soundfile returns shape (num_samples,) for mono audio and
    # (num_samples, num_channels) for multi-channel audio
    y, sr = sf.read(in_file)

    # the audio is not resampled here, so flag inputs the model was not built for
    model_sr = getattr(model, "samplerate", None)
    if model_sr is not None and sr != model_sr:
        print(f"Warning: {in_file} is sampled at {sr} Hz but the model expects "
              f"{model_sr} Hz; the audio is not resampled.")

    if y.ndim == 1:
        # the demucs network expects two channels of audio, so duplicate the
        # mono signal into a stereo track. Stacking on the LAST axis keeps the
        # (num_samples, num_channels) layout soundfile uses for stereo input.
        y = np.stack([y, y], axis=-1)

    # (num_samples, num_channels) -> (1, num_channels, num_samples) float32 tensor
    x = torch.from_numpy(np.ascontiguousarray(y.T, dtype=np.float32)).unsqueeze(0)

    # output is [1, S, C, T] where S is the number of sources
    if gpu is not None:
        out = apply_model(model, x, progress=True, device=gpu)
    else:
        out = apply_model(model, x, progress=True)

    # convert every source back to a (num_samples, num_channels) numpy array;
    # detach/cpu make the conversion safe wherever the output tensor lives
    return [stem.detach().cpu().numpy().T for stem in out[0]]

### Separating Standard Binaural Mixtures

In [None]:
# input mixtures (standard binaural test split) and destination for the stems
INPUT_DIR, OUTPUT_DIR = (
    "../data/binaural_musdb18/standard/test/",
    "../data/output/htdemucs_ft/standard/test/",
)

In [None]:
# stem file names, in the order the separated sources are written out
STEMS = "drums bass other vocals".split()
# sample rate (Hz) used when writing the stem WAV files
SAMPLE_RATE = 44_100

In [None]:
# mps check: use the Apple-silicon GPU when available, otherwise the CPU
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
else:
    # define the name anyway so the processing loop can pass it to separate(),
    # which treats None as "run on the CPU"; leaving it undefined raised a
    # NameError on every machine without MPS
    mps_device = None
    print("MPS device not found.")

In [None]:
# load the pretrained 'htdemucs_ft' model used to separate every track below
print("Loading pretrained model...")
MODEL = pretrained.get_model(name="htdemucs_ft")
print("Model loaded successfully.")

In [None]:
# get all of the track directories in the input directory. The list is sorted
# so tracks are processed in a reproducible order (os.listdir order is
# arbitrary), and hidden directories such as .ipynb_checkpoints are skipped
# because they are not tracks and have no mixture.wav
print("Loading list of files...")
file_list = sorted(
    entry
    for entry in os.listdir(INPUT_DIR)
    if not entry.startswith(".") and os.path.isdir(os.path.join(INPUT_DIR, entry))
)
print(f"There are {len(file_list)} files in the input directory.")

In [None]:
print("Creating output directory, if it does not already exist...")
# exist_ok turns re-running this cell into a no-op once the directory exists
os.makedirs(name=OUTPUT_DIR, exist_ok=True)

In [None]:
# iterate through each track directory, separating its mixture into stems
print("Beginning to process files...")
start_time = time.perf_counter()
for file in file_list:
    print(f"\n{file}")
    in_path = os.path.join(INPUT_DIR, file, "mixture.wav")
    if not os.path.isfile(in_path):
        # skip a malformed track instead of aborting the whole batch
        print(f"Skipping {file}: no mixture.wav found.")
        continue

    out = separate(in_path, MODEL, mps_device)
    if out is None:
        # separate() returns None for inputs it refuses to process
        continue

    out_dir = os.path.join(OUTPUT_DIR, file)
    os.makedirs(out_dir, exist_ok=True)
    # assumes separate() returns the stems in the order listed in STEMS
    for stem_name, stem_audio in zip(STEMS, out):
        sf.write(os.path.join(out_dir, stem_name + ".wav"), stem_audio, SAMPLE_RATE)
end_time = time.perf_counter()
print("Processing complete!")
print(f"Time: {end_time - start_time} seconds.")

### Separating Random Binaural Mixtures

In [None]:
# input mixtures (random binaural test split) and destination for the stems
INPUT_DIR, OUTPUT_DIR = (
    "../data/binaural_musdb18/random/test/",
    "../data/output/htdemucs_ft/random/test/",
)

In [None]:
# stem file names, in the order the separated sources are written out
STEMS = "drums bass other vocals".split()
# sample rate (Hz) used when writing the stem WAV files
SAMPLE_RATE = 44_100

In [None]:
# mps check: use the Apple-silicon GPU when available, otherwise the CPU
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
else:
    # define the name anyway so the processing loop can pass it to separate(),
    # which treats None as "run on the CPU"; leaving it undefined raised a
    # NameError on every machine without MPS
    mps_device = None
    print("MPS device not found.")

In [None]:
# load the pretrained 'htdemucs_ft' model used to separate every track below
print("Loading pretrained model...")
MODEL = pretrained.get_model(name="htdemucs_ft")
print("Model loaded successfully.")

In [None]:
# get all of the track directories in the input directory. The list is sorted
# so tracks are processed in a reproducible order (os.listdir order is
# arbitrary), and hidden directories such as .ipynb_checkpoints are skipped
# because they are not tracks and have no mixture.wav
print("Loading list of files...")
file_list = sorted(
    entry
    for entry in os.listdir(INPUT_DIR)
    if not entry.startswith(".") and os.path.isdir(os.path.join(INPUT_DIR, entry))
)
print(f"There are {len(file_list)} files in the input directory.")

In [None]:
print("Creating output directory, if it does not already exist...")
# exist_ok turns re-running this cell into a no-op once the directory exists
os.makedirs(name=OUTPUT_DIR, exist_ok=True)

In [None]:
# iterate through each track directory, separating its mixture into stems
print("Beginning to process files...")
# perf_counter is monotonic, so the elapsed time is immune to clock changes
start_time = time.perf_counter()
for file in file_list:
    print(f"\n{file}")
    in_path = os.path.join(INPUT_DIR, file, "mixture.wav")
    if not os.path.isfile(in_path):
        # skip a malformed track instead of aborting the whole batch
        print(f"Skipping {file}: no mixture.wav found.")
        continue

    out = separate(in_path, MODEL, mps_device)
    if out is None:
        # separate() returns None for inputs it refuses to process
        continue

    out_dir = os.path.join(OUTPUT_DIR, file)
    os.makedirs(out_dir, exist_ok=True)
    # assumes separate() returns the stems in the order listed in STEMS
    for stem_name, stem_audio in zip(STEMS, out):
        sf.write(os.path.join(out_dir, stem_name + ".wav"), stem_audio, SAMPLE_RATE)
end_time = time.perf_counter()
print("Processing complete!")
print(f"Time: {end_time - start_time} seconds.")