Skip to content

Commit

Permalink
prepare umxl release (#87)
Browse files Browse the repository at this point in the history
* prepare umxl release

* add urls

* add test for pretrained models

* black

* streamline test

* rephrase news
  • Loading branch information
faroit committed Jul 5, 2021
1 parent b3d2afb commit 01b6f6b
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_unittests.yml
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8]
python-version: [3.7]
pytorch-version: ["1.9.0"]

# Timeout: https://stackoverflow.com/a/59076067/4521646
Expand Down
19 changes: 12 additions & 7 deletions README.md
Expand Up @@ -12,6 +12,7 @@ This repository contains the PyTorch (1.8+) implementation of __Open-Unmix__, a

## ⭐️ News

- 03/07/2021: We added `umxl`, a model that was trained on extra data which significantly improves the performance, especially generalization.
- 14/02/2021: We released the new version of open-unmix as a python package. This comes with: a fully differentiable version of [norbert](https://github.com/sigsep/norbert), improved audio loading pipeline and large number of bug fixes. See [release notes](https://github.com/sigsep/open-unmix-pytorch/releases/) for further info.

- 06/05/2020: We added a pre-trained speech enhancement model `umxse` provided by Sony.
Expand Down Expand Up @@ -82,6 +83,10 @@ docker run -v ~/Music/:/data -it faroit/open-unmix-pytorch umx "/data/track1.wav

We provide three core pre-trained music separation models. All three models are end-to-end models that take waveform inputs and output the separated waveforms.

* __`umxl`__ trained on a private dataset of compressed multi-track stems. __Note, that the weights are only licensed for non-commercial use (CC BY-NC-SA 4.0).__

[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5069601.svg)](https://doi.org/10.5281/zenodo.5069601)

* __`umxhq` (default)__ trained on [MUSDB18-HQ](https://sigsep.github.io/datasets/musdb.html#uncompressed-wav) which comprises the same tracks as in MUSDB18 but un-compressed, which yields a full bandwidth of 22050 Hz.

[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3370489.svg)](https://doi.org/10.5281/zenodo.3370489)
Expand Down Expand Up @@ -185,13 +190,13 @@ Note that

#### Scores (Median of frames, Median of tracks)

|target|SDR |SIR | SAR | ISR | SDR | SIR | SAR | ISR |
|------|-----|-----|-----|-----|-----|-----|-----|-----|
|`model`|UMX |UMX |UMX |UMX |UMXHQ|UMXHQ|UMXHQ|UMXHQ|
|vocals|6.32 |13.33| 6.52|11.93| 6.25|12.95| 6.50|12.70|
|bass |5.23 |10.93| 6.34| 9.23| 5.07|10.35| 6.02| 9.71|
|drums |5.73 |11.12| 6.02|10.51| 6.04|11.65| 5.93|11.17|
|other |4.02 |6.59 | 4.74| 9.31| 4.28| 7.10| 4.62| 8.78|
|target|SDR |SIR | SAR | ISR | SDR | SIR | SAR | ISR | SDR |SIR | SAR | ISR |
|------|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|
|`model`|UMX |UMX |UMX |UMX |UMXHQ|UMXHQ|UMXHQ|UMXHQ|UMXL |UMXL |UMXL |UMXL |
|vocals|6.32 |13.33| 6.52|11.93| 6.25|12.95| 6.50|12.70|__7.21__ |14.65|7.38 |13.01|
|bass |5.23 |10.93| 6.34| 9.23| 5.07|10.35| 6.02| 9.71|__6.02__ |11.44|6.52 |9.05 |
|drums |5.73 |11.12| 6.02|10.51| 6.04|11.65| 5.93|11.17|__7.15__ |11.24|7.07 |12.26|
|other |4.02 |6.59 | 4.74| 9.31| 4.28| 7.10| 4.62| 8.78|__4.89__ |8.20 |5.88 |10.42|

## Training

Expand Down
3 changes: 3 additions & 0 deletions hubconf.py
Expand Up @@ -14,3 +14,6 @@

from openunmix import umx_spec
from openunmix import umx

from openunmix import umxl_spec
from openunmix import umxl
85 changes: 85 additions & 0 deletions openunmix/__init__.py
Expand Up @@ -259,3 +259,88 @@ def umx(
).to(device)

return separator


def umxl_spec(targets=None, device="cpu", pretrained=True):
    """Build the per-target UMX-L spectrogram models.

    Args:
        targets (list of str or None): targets to build models for; any
            subset of ``['vocals', 'drums', 'bass', 'other']``. When
            ``None``, all four targets are used.
        device (str): device the weights are mapped to and the models are
            moved to, e.g. ``"cpu"`` or ``"cuda"``.
        pretrained (bool): when ``True``, download the UMX-L weights from
            Zenodo for each target and put the model in eval mode.

    Returns:
        dict: mapping of target name -> ``OpenUnmix`` model on ``device``.
    """
    from .model import OpenUnmix

    # set urls for weights: Zenodo-hosted UMX-L checkpoints, one per target
    target_urls = {
        "bass": "https://zenodo.org/api/files/f8209c3e-ba60-48cf-8e79-71ae65beca61/bass-2ca1ce51.pth",
        "drums": "https://zenodo.org/api/files/f8209c3e-ba60-48cf-8e79-71ae65beca61/drums-69e0ebd4.pth",
        "other": "https://zenodo.org/api/files/f8209c3e-ba60-48cf-8e79-71ae65beca61/other-c8c5b3e6.pth",
        "vocals": "https://zenodo.org/api/files/f8209c3e-ba60-48cf-8e79-71ae65beca61/vocals-bccbd9aa.pth",
    }

    if targets is None:
        targets = ["vocals", "drums", "bass", "other"]

    # determine the maximum bin count for a 16khz bandwidth model
    max_bin = utils.bandwidth_to_max_bin(rate=44100.0, n_fft=4096, bandwidth=16000)

    target_models = {}
    for target in targets:
        # load open unmix model (full-band spectrogram input: 4096-point FFT)
        target_unmix = OpenUnmix(
            nb_bins=4096 // 2 + 1, nb_channels=2, hidden_size=1024, max_bin=max_bin
        )

        # download (cached by torch.hub) and load the pretrained weights;
        # strict=False tolerates checkpoint keys absent from the module
        if pretrained:
            state_dict = torch.hub.load_state_dict_from_url(
                target_urls[target], map_location=device
            )
            target_unmix.load_state_dict(state_dict, strict=False)
            target_unmix.eval()

        target_unmix.to(device)
        target_models[target] = target_unmix
    return target_models


def umxl(
    targets=None,
    residual=False,
    niter=1,
    device="cpu",
    pretrained=True,
    filterbank="torch",
):
    """
    Open Unmix Extra (UMX-L), 2-channel/stereo BLSTM Model trained on a private dataset
    of ~400h of multi-track audio.

    Args:
        targets (str): select the targets for the source to be separated.
            a list including: ['vocals', 'drums', 'bass', 'other'].
            If you don't pick them all, you probably want to
            activate the `residual=True` option.
            Defaults to all available targets per model.
        pretrained (bool): If True, loads the UMX-L weights, pre-trained on
            the private multi-track dataset described above
        residual (bool): if True, a "garbage" target is created
        niter (int): the number of post-processing iterations, defaults to 1
        device (str): selects device to be used for inference
        filterbank (str): filterbank implementation method.
            Supported are `['torch', 'asteroid']`. `torch` is about 30% faster
            compared to `asteroid` on large FFT sizes such as 4096. However,
            asteroid's stft can be exported to onnx, which makes it practical
            for deployment.

    Returns:
        Separator: separator wrapping one model per target, moved to `device`.
    """

    from .model import Separator

    # build (and optionally download) the per-target UMX-L models
    target_models = umxl_spec(targets=targets, device=device, pretrained=pretrained)
    separator = Separator(
        target_models=target_models,
        niter=niter,
        residual=residual,
        n_fft=4096,
        n_hop=1024,
        nb_channels=2,
        sample_rate=44100.0,
        filterbank=filterbank,
    ).to(device)

    return separator
2 changes: 1 addition & 1 deletion openunmix/cli.py
Expand Up @@ -25,7 +25,7 @@ def separate():
"--model",
default="umxhq",
type=str,
help="path to mode base directory of pretrained models",
help="path to model base directory of pretrained models, defaults to UMX-HQ",
)

parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

umx_version = "1.1.2"
umx_version = "1.2.0"

with open("README.md", encoding="utf-8") as fh:
long_description = fh.read()
Expand Down
12 changes: 12 additions & 0 deletions tests/test_model.py
Expand Up @@ -2,6 +2,10 @@
import torch

from openunmix import model
from openunmix import umxse
from openunmix import umxhq
from openunmix import umx
from openunmix import umxl


@pytest.fixture(params=[10, 100])
Expand Down Expand Up @@ -50,3 +54,11 @@ def test_shape(spectrogram, nb_bins, nb_channels, unidirectional, hidden_size):
unmix.eval()
Y = unmix(spectrogram)
assert spectrogram.shape == Y.shape


@pytest.mark.parametrize("model_fn", [umx, umxhq, umxse, umxl])
def test_model_loading(model_fn):
    """Smoke-test: each pretrained separator loads and keeps the input shape per target."""
    # short random stereo signal: (batch=1, channels=2, samples=4096)
    audio = torch.rand((1, 2, 4096))
    separator = model_fn(niter=0, pretrained=True)
    estimates = separator(audio)
    # estimates are (batch, target, channels, samples); each target matches the input
    assert estimates[:, 0, ...].shape == audio.shape

0 comments on commit 01b6f6b

Please sign in to comment.