-
Notifications
You must be signed in to change notification settings - Fork 174
/
test_model.py
64 lines (45 loc) · 1.42 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import pytest
import torch
from openunmix import model
from openunmix import umxse
from openunmix import umxhq
from openunmix import umx
from openunmix import umxl
@pytest.fixture(params=[10, 100])
def nb_frames(request):
    """Yield each parametrized frame count (10, then 100) as an ``int``.

    The value is used as the last dimension of the ``spectrogram`` fixture.
    """
    frame_count = request.param
    return int(frame_count)
@pytest.fixture(params=[1, 2, 3])
def nb_channels(request):
    """Yield each parametrized channel count (1, 2 and 3).

    The value is used as the second dimension of the ``spectrogram`` fixture
    and is passed to the model as ``nb_channels``.
    """
    channel_count = request.param
    return channel_count
@pytest.fixture(params=[1, 5])
def nb_samples(request):
    """Yield each parametrized sample count (1 and 5).

    The value is used as the first dimension of the ``spectrogram`` fixture.
    """
    sample_count = request.param
    return sample_count
@pytest.fixture(params=[111, 1024])
def nb_bins(request):
    """Yield each parametrized bin count (111 and 1024).

    The value is used as the third dimension of the ``spectrogram`` fixture
    and is passed to the model as ``nb_bins``.
    """
    bin_count = request.param
    return bin_count
@pytest.fixture
def spectrogram(request, nb_samples, nb_channels, nb_bins, nb_frames):
    """Build a random input tensor for the model under test.

    Returns a ``torch.rand`` tensor of shape
    ``(nb_samples, nb_channels, nb_bins, nb_frames)``, one per combination of
    the parametrized size fixtures.
    """
    shape = (nb_samples, nb_channels, nb_bins, nb_frames)
    return torch.rand(shape)
@pytest.fixture(params=[True, False])
def unidirectional(request):
    """Yield ``True`` and then ``False`` for the model's ``unidirectional`` flag."""
    flag = request.param
    return flag
@pytest.fixture(params=[32])
def hidden_size(request):
    """Yield the single parametrized ``hidden_size`` value (32) for the model."""
    size = request.param
    return size
def test_shape(spectrogram, nb_bins, nb_channels, unidirectional, hidden_size):
    """Check that ``model.OpenUnmix`` returns an output shaped like its input.

    A single-layer ``model.OpenUnmix`` is built from the parametrized
    fixtures, put into eval mode and run once on the random ``spectrogram``.
    The test passes when the output shape equals the input shape.
    """
    unmix = model.OpenUnmix(
        nb_bins=nb_bins,
        nb_channels=nb_channels,
        unidirectional=unidirectional,
        nb_layers=1,  # keep the model small so the test runs quickly
        hidden_size=hidden_size,
    )
    unmix.eval()
    # Only the output shape is inspected, so skip autograd bookkeeping.
    with torch.no_grad():
        Y = unmix(spectrogram)
    assert spectrogram.shape == Y.shape
@pytest.mark.parametrize("model_fn", [umx, umxhq, umxse, umxl])
def test_model_loading(model_fn):
    """Check that each model factory yields a callable with the expected output shape.

    ``model_fn`` is called with ``niter=0`` and ``pretrained=True``; the
    returned object is applied to a random ``(1, 2, 4096)`` tensor. The test
    passes when index 0 along the output's second dimension has the same
    shape as the input.

    NOTE(review): ``pretrained=True`` presumably fetches pretrained weights,
    which may require network access -- confirm for offline CI.
    """
    X = torch.rand((1, 2, 4096))
    # Named ``separator`` (not ``model``) so the ``model`` module imported at
    # file level is not shadowed inside this test.
    separator = model_fn(niter=0, pretrained=True)
    # Only the output shape is inspected, so skip autograd bookkeeping.
    with torch.no_grad():
        Y = separator(X)
    assert Y[:, 0, ...].shape == X.shape