diff --git a/.github/workflows/test_unittests.yml b/.github/workflows/test_unittests.yml
index efcbae05..3542d57a 100644
--- a/.github/workflows/test_unittests.yml
+++ b/.github/workflows/test_unittests.yml
@@ -10,7 +10,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.7, 3.8]
+        python-version: [3.7]
         pytorch-version: ["1.9.0"]
 
 # Timeout: https://stackoverflow.com/a/59076067/4521646
diff --git a/README.md b/README.md
index c72dfac0..7e18b277 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,7 @@ This repository contains the PyTorch (1.8+) implementation of __Open-Unmix__, a
 
 ## ⭐️ News
 
+- 03/07/2021: We added `umxl`, a model trained on extra data that significantly improves performance, especially generalization.
 - 14/02/2021: We released the new version of open-unmix as a python package. This comes with: a fully differentiable version of [norbert](https://github.com/sigsep/norbert), improved audio loading pipeline and large number of bug fixes. See [release notes](https://github.com/sigsep/open-unmix-pytorch/releases/) for further info.
 - 06/05/2020: We added a pre-trained speech enhancement model `umxse` provided by Sony.
 
@@ -82,6 +83,10 @@ docker run -v ~/Music/:/data -it faroit/open-unmix-pytorch umx "/data/track1.wav
 
 We provide three core pre-trained music separation models. All three models are end-to-end models that take waveform inputs and output the separated waveforms.
 
+* __`umxl`__ trained on a private dataset of compressed stems. __Note that the weights are only licensed for non-commercial use (CC BY-NC-SA 4.0).__
+
+  [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5069601.svg)](https://doi.org/10.5281/zenodo.5069601)
+
 * __`umxhq` (default)__ trained on [MUSDB18-HQ](https://sigsep.github.io/datasets/musdb.html#uncompressed-wav) which comprises the same tracks as in MUSDB18 but un-compressed which yield in a full bandwidth of 22050 Hz.
 
   [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3370489.svg)](https://doi.org/10.5281/zenodo.3370489)
 
@@ -185,13 +190,13 @@ Note that
 
 #### Scores (Median of frames, Median of tracks)
 
-|target|SDR |SIR | SAR | ISR | SDR | SIR | SAR | ISR |
-|------|-----|-----|-----|-----|-----|-----|-----|-----|
-|`model`|UMX |UMX |UMX |UMX |UMXHQ|UMXHQ|UMXHQ|UMXHQ|
-|vocals|6.32 |13.33| 6.52|11.93| 6.25|12.95| 6.50|12.70|
-|bass |5.23 |10.93| 6.34| 9.23| 5.07|10.35| 6.02| 9.71|
-|drums |5.73 |11.12| 6.02|10.51| 6.04|11.65| 5.93|11.17|
-|other |4.02 |6.59 | 4.74| 9.31| 4.28| 7.10| 4.62| 8.78|
+|target|SDR |SIR | SAR | ISR | SDR | SIR | SAR | ISR | SDR |SIR | SAR | ISR |
+|------|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|
+|`model`|UMX |UMX |UMX |UMX |UMXHQ|UMXHQ|UMXHQ|UMXHQ|UMXL |UMXL |UMXL |UMXL |
+|vocals|6.32 |13.33| 6.52|11.93| 6.25|12.95| 6.50|12.70|__7.21__ |14.65|7.38 |13.01|
+|bass |5.23 |10.93| 6.34| 9.23| 5.07|10.35| 6.02| 9.71|__6.02__ |11.44|6.52 |9.05 |
+|drums |5.73 |11.12| 6.02|10.51| 6.04|11.65| 5.93|11.17|__7.15__ |11.24|7.07 |12.26|
+|other |4.02 |6.59 | 4.74| 9.31| 4.28| 7.10| 4.62| 8.78|__4.89__ |8.20 |5.88 |10.42|
 
 ## Training
diff --git a/hubconf.py b/hubconf.py
index 64ffc552..669017fd 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -14,3 +14,6 @@
 from openunmix import umx_spec
 
 from openunmix import umx
+
+from openunmix import umxl_spec
+from openunmix import umxl
diff --git a/openunmix/__init__.py b/openunmix/__init__.py
index ddd6be99..dc3fbb8a 100644
--- a/openunmix/__init__.py
+++ b/openunmix/__init__.py
@@ -259,3 +259,88 @@ def umx(
     ).to(device)
 
     return separator
+
+
+def umxl_spec(targets=None, device="cpu", pretrained=True):
+    from .model import OpenUnmix
+
+    # set urls for weights
+    target_urls = {
+        "bass": "https://zenodo.org/api/files/f8209c3e-ba60-48cf-8e79-71ae65beca61/bass-2ca1ce51.pth",
+        "drums": "https://zenodo.org/api/files/f8209c3e-ba60-48cf-8e79-71ae65beca61/drums-69e0ebd4.pth",
+        "other": "https://zenodo.org/api/files/f8209c3e-ba60-48cf-8e79-71ae65beca61/other-c8c5b3e6.pth",
+        "vocals": "https://zenodo.org/api/files/f8209c3e-ba60-48cf-8e79-71ae65beca61/vocals-bccbd9aa.pth",
+    }
+
+    if targets is None:
+        targets = ["vocals", "drums", "bass", "other"]
+
+    # determine the maximum bin count for a 16 kHz bandwidth model
+    max_bin = utils.bandwidth_to_max_bin(rate=44100.0, n_fft=4096, bandwidth=16000)
+
+    target_models = {}
+    for target in targets:
+        # create the open-unmix model for this target
+        target_unmix = OpenUnmix(
+            nb_bins=4096 // 2 + 1, nb_channels=2, hidden_size=1024, max_bin=max_bin
+        )
+
+        # load pre-trained weights and switch the model to inference mode
+        if pretrained:
+            state_dict = torch.hub.load_state_dict_from_url(
+                target_urls[target], map_location=device
+            )
+            target_unmix.load_state_dict(state_dict, strict=False)
+            target_unmix.eval()
+
+        target_unmix.to(device)
+        target_models[target] = target_unmix
+    return target_models
+
+
+def umxl(
+    targets=None,
+    residual=False,
+    niter=1,
+    device="cpu",
+    pretrained=True,
+    filterbank="torch",
+):
+    """
+    Open Unmix Extra (UMX-L), a 2-channel/stereo BLSTM model trained on a private
+    dataset of ~400h of multi-track audio.
+
+    Args:
+        targets (list of str): select the targets to be separated, from
+            ['vocals', 'drums', 'bass', 'other'].
+            If you don't select all of them, you probably want to
+            activate the `residual=True` option.
+            Defaults to all available targets per model.
+        pretrained (bool): If True, returns a model pre-trained on the private
+            multi-track dataset described above
+        residual (bool): if True, a "garbage" target is created
+        niter (int): the number of post-processing iterations, defaults to 1
+        device (str): selects device to be used for inference
+        filterbank (str): filterbank implementation method.
+            Supported are `['torch', 'asteroid']`. `torch` is about 30% faster
+            compared to `asteroid` on large FFT sizes such as 4096. However,
+            asteroid's stft can be exported to onnx, which makes it practical
+            for deployment.
+
+    """
+
+    from .model import Separator
+
+    target_models = umxl_spec(targets=targets, device=device, pretrained=pretrained)
+    separator = Separator(
+        target_models=target_models,
+        niter=niter,
+        residual=residual,
+        n_fft=4096,
+        n_hop=1024,
+        nb_channels=2,
+        sample_rate=44100.0,
+        filterbank=filterbank,
+    ).to(device)
+
+    return separator
diff --git a/openunmix/cli.py b/openunmix/cli.py
index 08d10b5c..da8822ca 100644
--- a/openunmix/cli.py
+++ b/openunmix/cli.py
@@ -25,7 +25,7 @@ def separate():
         "--model",
         default="umxhq",
         type=str,
-        help="path to mode base directory of pretrained models",
+        help="path to model base directory of pretrained models, defaults to UMX-HQ",
     )
 
     parser.add_argument(
diff --git a/setup.py b/setup.py
index 5a8a5ac8..e05fbe52 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
 from setuptools import setup, find_packages
 
-umx_version = "1.1.2"
+umx_version = "1.2.0"
 
 with open("README.md", encoding="utf-8") as fh:
     long_description = fh.read()
diff --git a/tests/test_model.py b/tests/test_model.py
index ed56f991..c7f931e2 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -2,6 +2,10 @@
 import torch
 
 from openunmix import model
+from openunmix import umxse
+from openunmix import umxhq
+from openunmix import umx
+from openunmix import umxl
 
 
 @pytest.fixture(params=[10, 100])
@@ -50,3 +54,11 @@ def test_shape(spectrogram, nb_bins, nb_channels, unidirectional, hidden_size):
     unmix.eval()
     Y = unmix(spectrogram)
    assert spectrogram.shape == Y.shape
+
+
+@pytest.mark.parametrize("model_fn", [umx, umxhq, umxse, umxl])
+def test_model_loading(model_fn):
+    X = torch.rand((1, 2, 4096))
+    model = model_fn(niter=0, pretrained=True)
+    Y = model(X)
+    assert Y[:, 0, ...].shape == X.shape
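
For reviewers who want to exercise the new entry point, a minimal usage sketch follows. It relies only on the API added in this diff (`openunmix.umxl` and the output shape asserted by `test_model_loading`); the commented torch.hub call and its repo path are an assumption based on the new `hubconf.py` entries.

```python
import torch

from openunmix import umxl

# build the UMX-L separator introduced in this diff; with pretrained=True
# the per-target weights are fetched from the Zenodo URLs in umxl_spec()
separator = umxl(niter=1, residual=False, device="cpu", pretrained=True)

# one second of stereo audio at the model's 44.1 kHz sample rate:
# (nb_samples, nb_channels, nb_timesteps)
audio = torch.rand((1, 2, 44100))

# estimates has shape (nb_samples, nb_targets, nb_channels, nb_timesteps),
# consistent with the assertion in the new test_model_loading test
with torch.no_grad():
    estimates = separator(audio)

# the model should also be loadable through torch.hub via the new
# hubconf.py entries (repo path assumed):
# separator = torch.hub.load("sigsep/open-unmix-pytorch", "umxl")
```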