# Music Source Separation with Pretrained Models
To separate your own music files, see:
- `egs/tutorials/conv-tasnet/separate_music.ipynb` for Conv-TasNet
- `egs/tutorials/mm-dense-lstm/separate_music.ipynb` for MMDenseLSTM
- `egs/tutorials/umx/separate_music.ipynb` for Open-Unmix
- `egs/tutorials/xumx/separate_music.ipynb` for CrossNet-Open-Unmix
- `egs/tutorials/d3net/separate_music.ipynb` for D3Net

In [None]:
%%bash
repo_url="https://github.com/tky823/DNN-based_source_separation.git"
tutorial_dir="./DNN-based_source_separation/egs/tutorials/"

# Fetch the repository that ships the models and the tutorial requirements
git clone "$repo_url"

# Install torch & torchaudio as pinned by the tutorials
cd "$tutorial_dir"
pip install -r "requirements.txt"

In [None]:
%%bash
# Download the music dataset
wget "https://zenodo.org/api/files/1ff52183-071a-4a59-923f-7a31c4762d43/MUSDB18-7-STEMS.zip"
unzip "./MUSDB18-7-STEMS.zip"

# Convert each .stem.mp4 into one .wav file per audio stream
cd "./train" || exit 1

# Output file name for each stream index of a .stem.mp4 (index 0 -> mixture, 1 -> drums, ...),
# matching the original -map 0:0 ... 0:4 mapping.
outputs=(mixture drums bass other vocals)

for stem in *.stem.mp4 ; do
    # If the glob matched nothing, bash leaves the literal pattern; skip it.
    [ -e "$stem" ] || continue

    # Strip the literal ".stem.mp4" suffix; quoting keeps spaces in track names intact.
    name="${stem%.stem.mp4}"
    echo "$stem"

    # -p makes reruns safe when the folder already exists.
    mkdir -p "$name"
    for index in "${!outputs[@]}"; do
        # -y overwrites existing wavs instead of stopping at ffmpeg's overwrite prompt on rerun.
        ffmpeg -y -loglevel panic -i "$stem" -map "0:${index}" -vn "${name}/${outputs[$index]}.wav"
    done
done

In [None]:
import sys

# Make the cloned repository's `src` directory importable so that the
# `models.*` imports in the following cells resolve.
src_dir = "/content/DNN-based_source_separation/src"
sys.path.append(src_dir)

In [None]:
import IPython.display as ipd
import torch
import torchaudio

In [None]:
from models.conv_tasnet import ConvTasNet
from models.mm_dense_lstm import MMDenseLSTM, ParallelMMDenseLSTM
from models.umx import OpenUnmix, ParallelOpenUnmix
from models.xumx import CrossNetOpenUnmix
from models.d3net import D3Net, ParallelD3Net

In [None]:
# Track folder used by the demo cells below. It is expected under /content/train/,
# where the conversion cell writes one folder per track (assumes the bash cells ran from /content, as on Colab).
name = "ANiMAL - Rockshow"

In [None]:
# Load the four isolated stems of the selected track and play each one.
# torchaudio.load returns (waveform, sample_rate); the waveform is laid out as
# (channels, frames) with torchaudio's default channels_first=True.
stem_names = ("bass", "drums", "other", "vocals")
stem_waveforms = {}
stem_sample_rates = {}
for stem in stem_names:
    stem_waveforms[stem], stem_sample_rates[stem] = torchaudio.load(
        "/content/train/{}/{}.wav".format(name, stem)
    )

# The stems are summed into a mixture in the next cell, which only makes sense
# if they share one sample rate; fail loudly instead of silently keeping the last one.
if len(set(stem_sample_rates.values())) != 1:
    raise ValueError("Stems have mismatched sample rates: {}".format(stem_sample_rates))
sample_rate = stem_sample_rates["vocals"]

# Later cells refer to the stems by these names.
waveform_bass = stem_waveforms["bass"]
waveform_drums = stem_waveforms["drums"]
waveform_other = stem_waveforms["other"]
waveform_vocals = stem_waveforms["vocals"]

for stem in stem_names:
    print(stem)
    display(ipd.Audio(stem_waveforms[stem], rate=sample_rate))

In [None]:
# Rebuild the mixture as the plain sum of the four stems (same order as the stems were loaded).
mixture = sum((waveform_bass, waveform_drums, waveform_other, waveform_vocals))
display(ipd.Audio(mixture, rate=sample_rate))

In [None]:
# Separate the mixture with the pretrained Conv-TasNet (waveform in, one waveform per source out).
model = ConvTasNet.build_from_pretrained(task="musdb18", sample_rate=sample_rate)
model.eval()

# (channels, time) -> (1, 1, channels, time).
# NOTE(review): the exact input layout is defined in models.conv_tasnet — confirm there.
# Named `mixture_batch` rather than `input` to avoid shadowing the builtin.
mixture_batch = mixture.unsqueeze(dim=0).unsqueeze(dim=1)
with torch.no_grad():
    # Standardise over the time axis (dim=3); the statistics are reused to undo the scaling.
    mean = mixture_batch.mean(dim=3, keepdim=True)
    # Clamp so a silent (constant) channel cannot cause a division by zero and NaN output.
    std = mixture_batch.std(dim=3, keepdim=True).clamp(min=1e-8)
    output = model((mixture_batch - mean) / std)
    output = std * output + mean

# Drop the batch axis, then split into one tensor per entry of model.sources (matched by index).
output = output.squeeze(dim=0)
estimated = torch.unbind(output, dim=0)

for idx, target in enumerate(model.sources):
    print(target)
    display(ipd.Audio(estimated[idx], rate=sample_rate))

In [None]:
# Separate the mixture with the pretrained ParallelMMDenseLSTM.
model = ParallelMMDenseLSTM.build_from_pretrained(task="musdb18", sample_rate=sample_rate)
# The wrapper is built from the model's n_fft / hop_length / window_fn, presumably to run an
# STFT around the model so it accepts a waveform — confirm in models.mm_dense_lstm.
wrapper_model = ParallelMMDenseLSTM.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()

# (channels, time) -> (1, 1, channels, time); named to avoid shadowing the builtin `input`.
mixture_batch = mixture.unsqueeze(dim=0).unsqueeze(dim=1)
with torch.no_grad():
    output = wrapper_model(mixture_batch)

# Drop the batch axis, then split into one tensor per entry of wrapper_model.sources.
output = output.squeeze(dim=0)
estimated = torch.unbind(output, dim=0)

for idx, target in enumerate(wrapper_model.sources):
    print(target)
    display(ipd.Audio(estimated[idx], rate=sample_rate))

In [None]:
# Separate the mixture with the pretrained ParallelOpenUnmix.
model = ParallelOpenUnmix.build_from_pretrained(task="musdb18", sample_rate=sample_rate)
# The wrapper is built from the model's n_fft / hop_length / window_fn, presumably to run an
# STFT around the model so it accepts a waveform — confirm in models.umx.
wrapper_model = ParallelOpenUnmix.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()

# (channels, time) -> (1, 1, channels, time); named to avoid shadowing the builtin `input`.
mixture_batch = mixture.unsqueeze(dim=0).unsqueeze(dim=1)
with torch.no_grad():
    output = wrapper_model(mixture_batch)

# Drop the batch axis, then split into one tensor per entry of wrapper_model.sources.
output = output.squeeze(dim=0)
estimated = torch.unbind(output, dim=0)

for idx, target in enumerate(wrapper_model.sources):
    print(target)
    display(ipd.Audio(estimated[idx], rate=sample_rate))

In [None]:
# Separate the mixture with the pretrained CrossNetOpenUnmix.
model = CrossNetOpenUnmix.build_from_pretrained(task="musdb18", sample_rate=sample_rate)
# The wrapper is built from the model's n_fft / hop_length / window_fn, presumably to run an
# STFT around the model so it accepts a waveform — confirm in models.xumx.
wrapper_model = CrossNetOpenUnmix.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()

# (channels, time) -> (1, 1, channels, time); named to avoid shadowing the builtin `input`.
mixture_batch = mixture.unsqueeze(dim=0).unsqueeze(dim=1)
with torch.no_grad():
    output = wrapper_model(mixture_batch)

# Drop the batch axis, then split into one tensor per entry of wrapper_model.sources.
output = output.squeeze(dim=0)
estimated = torch.unbind(output, dim=0)

for idx, target in enumerate(wrapper_model.sources):
    print(target)
    display(ipd.Audio(estimated[idx], rate=sample_rate))

In [None]:
# Separate the mixture with the pretrained ParallelD3Net.
model = ParallelD3Net.build_from_pretrained(task="musdb18", sample_rate=sample_rate)
# The wrapper is built from the model's n_fft / hop_length / window_fn, presumably to run an
# STFT around the model so it accepts a waveform — confirm in models.d3net.
wrapper_model = ParallelD3Net.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()

# (channels, time) -> (1, 1, channels, time); named to avoid shadowing the builtin `input`.
mixture_batch = mixture.unsqueeze(dim=0).unsqueeze(dim=1)
with torch.no_grad():
    output = wrapper_model(mixture_batch)

# Drop the batch axis, then split into one tensor per entry of wrapper_model.sources.
output = output.squeeze(dim=0)
estimated = torch.unbind(output, dim=0)

for idx, target in enumerate(wrapper_model.sources):
    print(target)
    display(ipd.Audio(estimated[idx], rate=sample_rate))