# Open-Unmix Separation

Note: this notebook must be run on a CUDA-enabled device.

In [None]:
import os
import time
import torch
from tqdm import tqdm
import openunmix
import numpy as np
import soundfile as sf
from IPython.display import display, Audio

### Separation Function

In [None]:
def separate(in_file, model, gpu=True):
    """
    Run Open-Unmix source separation model on a WAV file to isolate stems.

    Inference is only performed on CUDA. If the file is not a WAV file or no
    GPU is available, a message is printed and None is returned, so callers
    can skip the input instead of crashing.

    :param in_file: (str) path of mixture WAV file to separate
    :param model: (Separator) open unmix model object
    :param gpu: (bool) whether a gpu is available for use
    :return: (list[np.ndarray] | None) one array per model target, in the
        order the model emits its targets (NaNs replaced by 0), or None if
        the input could not be processed
    """
    # only process wav files (case-insensitive extension check)
    if not in_file.lower().endswith(".wav"):
        print("Invalid input file type. Please use WAV files only.")
        return None

    # checked before reading the file so no time is spent loading audio that
    # cannot be processed; previously this path fell through to an unbound
    # `out_stems` and raised UnboundLocalError
    if not gpu:
        print("Please enable CUDA for inference.")
        return None

    # always_2d=True yields (num_samples, num_channels) even for mono files,
    # so the transpose below always produces (num_channels, num_samples)
    y, sr = sf.read(in_file, always_2d=True)

    # warn (rather than fail) when the file's rate differs from the model's,
    # since the model would otherwise silently separate at the wrong rate
    model_sr = getattr(model, "sample_rate", None)
    if model_sr is not None and float(model_sr) != sr:
        print(f"Warning: {in_file} has sample rate {sr} Hz but the model expects {float(model_sr):g} Hz.")

    # send model to gpu
    model.to('cuda')

    # convert to a contiguous float32 tensor of shape (1, num_channels, num_samples)
    x = torch.from_numpy(np.ascontiguousarray(y.T[np.newaxis], dtype=np.float32)).to('cuda')
    with torch.no_grad():
        out = model(x)

    # out[0] holds the estimates for the single batch item, one entry per target
    out_stems = []
    for est in out[0]:
        est_stem = est.detach().cpu().numpy()  # convert to numpy array
        est_stem[np.isnan(est_stem)] = 0  # convert nan to 0
        out_stems.append(est_stem)
    return out_stems

### Separate

In [None]:
# Stem names used for the output files. The writing loops pair these with the
# separator's outputs by position, so the order is expected to match the
# model's target order.
STEMS = [
    "vocals",
    "drums",
    "bass",
    "other",
]
# Sample rate (Hz) at which separated stems are written.
SAMPLE_RATE = 44_100

In [None]:
# load the UMX-HQ separator
# (an equivalent alternative is torch.hub.load('sigsep/open-unmix-pytorch', 'umxhq'))
MODEL = openunmix.umxhq()
# put the model in inference mode explicitly: torch.no_grad() disables gradient
# tracking but does not switch dropout / batch-norm layers to eval behaviour
MODEL.eval()

In [None]:
# Stems are written at SAMPLE_RATE, so the model must operate at the same rate.
# An explicit raise is used because `assert` statements are stripped under `python -O`.
if float(MODEL.sample_rate) != SAMPLE_RATE:
    raise ValueError(
        f"Model sample rate {float(MODEL.sample_rate):g} Hz does not match SAMPLE_RATE {SAMPLE_RATE} Hz."
    )

#### Stereo Mixtures

In [None]:
# stereo run: mixtures are read from INPUT_DIR, separated stems go to OUTPUT_DIR
INPUT_DIR, OUTPUT_DIR = (
    "../data/musdb18hq/test/",
    "../data/output/umxhq/stereo/test/",
)

In [None]:
# get every track directory inside INPUT_DIR; the list is sorted because
# os.listdir returns entries in arbitrary order, and a sorted list makes the
# processing order (and progress output) reproducible between runs
print("Loading list of files...")
file_list = sorted(
    entry for entry in os.listdir(INPUT_DIR)
    if os.path.isdir(os.path.join(INPUT_DIR, entry))
)
print(f"There are {len(file_list)} files in the input directory.")

In [None]:
# make sure OUTPUT_DIR (and any missing parents) exists; exist_ok=True keeps
# re-runs from failing when the directory is already there
print("Creating output directory, if it does not already exist...")
os.makedirs(name=OUTPUT_DIR, exist_ok=True)

In [None]:
# separate each track's mixture.wav and write one WAV per stem into
# OUTPUT_DIR/<track>/<stem>.wav
print("Beginning to process files...")
for file in tqdm(file_list):
    out = separate(os.path.join(INPUT_DIR, file, 'mixture.wav'), MODEL)
    if out is None:
        # separate() signals an unusable input by returning None; skip the
        # track instead of crashing on indexing None
        continue
    # only create the track's output directory once there is something to write
    out_dir = os.path.join(OUTPUT_DIR, file)
    os.makedirs(out_dir, exist_ok=True)
    # stems are paired with the model outputs by position
    for stem, est in zip(STEMS, out):
        out_path = os.path.join(out_dir, stem + '.wav')
        sf.write(out_path, est.T, SAMPLE_RATE)  # transpose so it is (num_samples, num_channels)
print("Processing complete!")

#### Binaural Mixtures

In [None]:
# binaural run: mixtures are read from INPUT_DIR, separated stems go to OUTPUT_DIR
INPUT_DIR, OUTPUT_DIR = (
    "../data/binaural_musdb18/test/",
    "../data/output/umxhq/binaural/test/",
)

In [None]:
# get every track directory inside INPUT_DIR; the list is sorted because
# os.listdir returns entries in arbitrary order, and a sorted list makes the
# processing order (and progress output) reproducible between runs
print("Loading list of files...")
file_list = sorted(
    entry for entry in os.listdir(INPUT_DIR)
    if os.path.isdir(os.path.join(INPUT_DIR, entry))
)
print(f"There are {len(file_list)} files in the input directory.")

In [None]:
# make sure OUTPUT_DIR (and any missing parents) exists; exist_ok=True keeps
# re-runs from failing when the directory is already there
print("Creating output directory, if it does not already exist...")
os.makedirs(name=OUTPUT_DIR, exist_ok=True)

In [None]:
# separate each track's mixture.wav and write one WAV per stem into
# OUTPUT_DIR/<track>/<stem>.wav
print("Beginning to process files...")
for file in tqdm(file_list):
    out = separate(os.path.join(INPUT_DIR, file, 'mixture.wav'), MODEL)
    if out is None:
        # separate() signals an unusable input by returning None; skip the
        # track instead of crashing on indexing None
        continue
    # only create the track's output directory once there is something to write
    out_dir = os.path.join(OUTPUT_DIR, file)
    os.makedirs(out_dir, exist_ok=True)
    # stems are paired with the model outputs by position
    for stem, est in zip(STEMS, out):
        out_path = os.path.join(out_dir, stem + '.wav')
        sf.write(out_path, est.T, SAMPLE_RATE)  # transpose so it is (num_samples, num_channels)
print("Processing complete!")