-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathtest_training.py
82 lines (61 loc) · 1.99 KB
/
test_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import itertools
import pandas
import numpy
import pytest
from microesc import train, features, urbansound8k
@pytest.mark.skip(reason="fails")
def test_generator_fake_loader(monkeypatch):
    """Check the batch shapes produced by ``train.dataframe_generator``.

    A stub loader that returns all-zero feature arrays is injected, so only
    the generator's batching logic is exercised, not real feature loading.
    The dataset itself is still loaded via ``urbansound8k.load_dataset()``
    after pointing ``urbansound8k.default_path`` at ``data/UrbanSound8K/``
    (relative to the working directory).

    Skipped with reason "fails"; the cause is not recorded here.
    TODO: investigate and re-enable.
    """
    dataset_path = 'data/UrbanSound8K/'
    # Patch through monkeypatch so the module-level default is restored after
    # this test instead of leaking into every test that runs afterwards.
    monkeypatch.setattr(urbansound8k, 'default_path', dataset_path)
    data = urbansound8k.load_dataset()
    folds, _ = urbansound8k.folds(data)

    data_length = 16
    batch_size = 8
    frames = 72
    bands = 32
    n_classes = 10

    def zero_loader(s):
        # Ignores the sample it is given; returns a (bands, frames, 1) zero array.
        return numpy.zeros((bands, frames, 1))

    # NOTE(review): folds[0][0] is presumably the training split of the first
    # fold — confirm against urbansound8k.folds.
    fold = folds[0][0]
    X = fold[0:data_length]
    Y = fold.classID[0:data_length]

    g = train.dataframe_generator(X, Y, loader=zero_loader,
                                  batchsize=batch_size, n_classes=n_classes)

    # 16 rows at batch_size 8 is only 2 full batches, so asking for 3 means the
    # test expects the generator to keep yielding past a single pass.
    n_batches = 3
    batches = list(itertools.islice(g, n_batches))
    assert len(batches) == n_batches
    assert len(batches[0]) == 2  # (X, y)
    assert batches[0][0].shape == (batch_size, bands, frames, 1)
    assert batches[0][1].shape == (batch_size, n_classes)
def test_windows_shorter_than_window():
    """A clip shorter than one analysis window must yield exactly one window.

    Presumably one window spans ``window_frames * frame_samples`` = 16384
    samples, while the 0.4 s clip at 16 kHz is only 6400 samples long.
    """
    frame_samples = 256
    window_frames = 64
    fs = 16000
    # Convert the duration to a whole sample count once, so the value passed
    # to sample_windows and the value asserted against are the same int
    # (int() on a float product could silently truncate by one sample).
    length = round(0.4 * fs)
    w = list(features.sample_windows(length, frame_samples, window_frames))
    assert len(w) == 1, len(w)
    # Element [1] of a window is presumably its end sample: the last window
    # must reach the end of the clip.
    assert w[-1][1] == length
def test_window_typical():
    """A typical 4 s clip at 16 kHz is split into 8 windows.

    The expected count of 8 pins ``features.sample_windows``' current
    behaviour for these frame/window sizes.
    """
    frame_samples = 256
    window_frames = 64
    fs = 16000
    # Convert the duration to a whole sample count once, so the value passed
    # to sample_windows and the value asserted against are the same int.
    length = round(4.0 * fs)
    w = list(features.sample_windows(length, frame_samples, window_frames))
    assert len(w) == 8, len(w)
    # Element [1] of a window is presumably its end sample: the last window
    # must reach the end of the clip.
    assert w[-1][1] == length
def _test_predict_windowed():
t = test[0:10]
sbcnn16k32_settings = dict(
feature='mels',
samplerate=16000,
n_mels=32,
fmin=0,
fmax=8000,
n_fft=512,
hop_length=256,
augmentations=5,
)
def load_sample32(sample):
return features.load_sample(sample, sbcnn16k32_settings, window_frames=72, feature_dir='../../scratch/aug')
mean_m = features.predict_voted(sbcnn16k32_settings, model, t, loader=load_sample32, method='mean')