-
Notifications
You must be signed in to change notification settings - Fork 10
/
mit.py
288 lines (235 loc) · 9.29 KB
/
mit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
from collections import namedtuple
import re
import glob
import os.path
import numpy as np
import scipy.io.wavfile as wavfile
import scipy.signal as signal
import math
import paths
from collections import namedtuple
from minimum_phase import minimum_phase
# Samplerate of the MIT KEMAR dataset, in Hz (asserted in to_coords).
SR = 44100
# Highest representable frequency at SR, in Hz.
NYQUIST = SR // 2
def to_coords(f):
    """Parse elevation/azimuth from a KEMAR filename and load its samples.

    Filenames look like ``L{elev}e{az}a.wav`` (elevation may be negative).
    Returns ``(az, elev, data)`` where ``data`` is the wav payload scaled
    from int16 into floats in [-1.0, 1.0).

    On any failure the offending path is printed and the error re-raised.
    """
    try:
        bn = os.path.basename(f)
        # Raw string with an escaped dot: the original pattern's bare "."
        # matched any character before "wav".
        x = re.match(r"L(.*)e(.*)a\.wav", bn)
        az, elev = int(x[2]), int(x[1])
        sr, data = wavfile.read(f)
        assert sr == 44100
        # Scale int16 samples into [-1.0, 1.0).
        data = data / 32768.0
        return az, elev, data
    except Exception:
        # Identify which file broke the pipeline, then propagate.
        # (Was a bare except, which also swallowed KeyboardInterrupt.)
        print(f)
        raise
def load_dataset():
    """Returns the dataset as a dict elev -> azimuth -> f32 array."""
    d = {}
    files = glob.glob(os.path.join(paths.data_path, "elev*", "L*.wav"), recursive=True)
    for f in files:
        az, elev, data = to_coords(f)
        # One sub-dict per elevation, keyed by azimuth.
        x = d.setdefault(elev, {})
        x[az] = data
    # Two elevations are the minimum needed to compute an elevation step in
    # extract_elevation_metadata.  The previous check used "> 2", which
    # contradicted its own message by actually requiring three.
    assert len(d) >= 2, "Must have at least 2 elevations"
    print(f"Have {len(d)} elevations")
    return d
# Summary of the elevation grid: the sorted angles, their (uniform) spacing
# in degrees, the extreme angles, and how many elevations there are.
ElevationMetadata = namedtuple(
    "ElevationMetadata",
    "elev_angles degs_per_elev elev_min elev_max num_elevs",
)
def extract_elevation_metadata(dataset):
    """Derive the elevation grid (angles, spacing, extremes) from the dataset keys."""
    angles = sorted(dataset)
    step = angles[1] - angles[0]
    print("Checking that elevations are equidistant")
    for lo, hi in zip(angles, angles[1:]):
        assert hi - lo == step, f"Elevation must be equal to {step}"
    return ElevationMetadata(
        elev_angles=angles,
        degs_per_elev=step,
        elev_min=angles[0],
        elev_max=angles[-1],
        num_elevs=len(dataset),
    )
def unfold_azimuths(dataset):
    """Returns a [[azimuths], [azimuths]...] list, where the outer list is per elevation.

    This is useful for processing, which doesn't have to worry about iterating
    the dict right.  Elevations and azimuths both come out in ascending order.
    """
    print("Unfolding azimuths")
    return [
        [az_map[az] for az in sorted(az_map)]
        for _elev, az_map in sorted(dataset.items())
    ]
def map_azimuths(azimuths, fn):
    """Apply fn to every entry of the nested per-elevation/per-azimuth list."""
    return [[fn(entry) for entry in elev] for elev in azimuths]
def magnitude_response(array):
    """FFT magnitude of `array`, zero-padded out to SR bins (so 1 bin == 1 Hz)."""
    # Make sure to really overdo this; we want more than 256 bins.
    # Match the samplerate of the dataset, and each bin is then 1 hz, which is
    # convenient.
    spectrum = np.fft.fft(array, n=SR)
    return np.abs(spectrum)
def convert_to_magnitude_responses(azimuths):
    """Replace each time-domain HRIR in the set with its magnitude spectrum."""
    print("Building magnitude responses")
    converted = map_azimuths(azimuths, magnitude_response)
    return converted
def equalize_power(azimuths):
    """Equalize the set against its own average power response.

    Averages the power spectrum over every (elevation, azimuth) entry, then
    multiplies each entry by the inverse of that average so the set as a
    whole is spectrally flattened below `stop_at` Hz.
    """
    print("Equalizing power response")
    # Accumulator for the summed power spectrum; one slot per FFT bin.
    presp = np.zeros(len(azimuths[0][0]), dtype=np.float64)
    # Flat list of every (elevation index, azimuth index) pair.
    indices = [(i, j) for i in range(len(azimuths)) for j in range(len(azimuths[i]))]
    new_az = [[None] * len(i) for i in azimuths]
    # Total number of entries, for the averaging step.
    c = sum([len(i) for i in azimuths])
    for i, j in indices:
        presp += azimuths[i][j] ** 2
    average_power = presp / c
    # For frequencies above this value, set the average to 1; this allows
    # the dataset to emphasize ahead/behind.
    stop_at = 5000
    # We apply this twice, once before averaging and once after. This is because it makes the averaging step easy,
    # since any value set to 1.0 before going to log-magnitude comes out to 0.
    # (The slice leaves the first stop_at bins AND the last stop_at bins
    # untouched -- the latter being the mirrored negative-frequency copy.)
    average_power[stop_at : len(average_power) - stop_at] = 1.0
    # Now, we want to limit this filter so that it's not doing insane things. Following logic was roughly borrowed
    # from the matlab scripts distributed with the MIT kemar dataset.
    avg_power_log = 20 * np.log10(np.abs(average_power))
    # Only used by the commented-out limiter below.
    db_range = 20
    # Recall that everything after stop_at is 0.0 db. We don't want to count it.
    # NOTE(review): ~2*stop_at bins are nonzero (low band plus its mirror),
    # so dividing by stop_at makes this roughly twice the true mean of the
    # nonzero bins -- confirm this is intended.
    offset = sum(avg_power_log) / stop_at
    avg_power_log -= offset
    # avg_power_log = np.minimum(np.maximum(avg_power_log, -db_range), db_range)
    # Inverse filter; the sqrt halves the dB correction (power -> magnitude).
    filter = np.sqrt(10 ** (-avg_power_log / 20))
    for i, j in indices:
        new_az[i][j] = azimuths[i][j] * filter
        # Pin the DC bin to unity gain.
        new_az[i][j][0] = 1.0
    return new_az
def clamp_responses(azimuths):
    """Clamp every magnitude response into [min_gain, max_gain].

    The gains are linear: a floor of 0.0 (magnitudes are non-negative anyway)
    and a ceiling of +6 dB.
    """
    # The old message claimed "-60 db and 3 db", matching neither constant.
    print("Clamping responses to be between 0 and +6 db")
    min_gain = 0.0  # 10**(-40/20)
    max_gain = 10 ** (6 / 20)
    print(f"Min {min_gain} max {max_gain}")

    def normalize(a):
        # np.clip is the single-call form of np.maximum(np.minimum(...)).
        return np.clip(a, min_gain, max_gain)

    return map_azimuths(azimuths, normalize)
def conv_minimum_phase(azimuths):
    """Convert each magnitude response to a minimum-phase impulse response."""
    print("Converting to minimum phase")
    # Copy each entry before handing it to minimum_phase.
    return map_azimuths(azimuths, lambda resp: minimum_phase(np.array(resp)))
def truncate_hrir(azimuths, hrir_length):
    """Truncate every HRIR to hrir_length points under a Blackman-Harris half-window.

    The window is the decaying right half of an odd-length Blackman-Harris
    window, so the first retained sample is scaled by exactly 1.0 and the
    tail fades smoothly toward zero.
    """
    print(f"Windowing to {hrir_length} points")
    # We use blackman-harris because the WDL likes it for its resampler, so
    # proceeding under the assumption that it's good enough for us too.
    # signal.windows is the supported home of the window functions; the bare
    # signal.blackmanharris alias was removed in SciPy 1.13.
    window = signal.windows.blackmanharris(hrir_length * 2 - 1)[-hrir_length:]
    assert len(window) == hrir_length
    assert window[0] == 1.0
    return [[entry[:hrir_length] * window for entry in elev] for elev in azimuths]
def compute_dc(azimuths):
    """Largest DC gain (sum of samples) across every HRIR in the set."""
    return max(np.sum(entry) for elev in azimuths for entry in elev)
# Second-order Butterworth highpass at 100 Hz, as (b, a) coefficient arrays;
# shared by run_base_filter to strip sub-bass from every HRIR.
base_butter = signal.butter(N=2, btype="highpass", fs=SR, Wn=100)
def run_base_filter(a):
    """Apply the module-level 100 Hz Butterworth highpass to one HRIR."""
    b_coefs, a_coefs = base_butter
    return signal.lfilter(b=b_coefs, a=a_coefs, x=a)
def remove_base(azimuths):
    """Strip sub-100 Hz content from every HRIR via the shared highpass."""
    return [[run_base_filter(entry) for entry in elev] for elev in azimuths]
def emphasize_behind(azimuths):
    """Adds a lowpass filter which emphasises when sounds are behind the listener.

    The values here were determined through listening to it, and have no real
    logic beyond that.
    """
    lp_freq = 2000  # frequency for the lowpass stopband's start.
    lp_stopfreq = NYQUIST  # The frequency that the filter's stopband fully starts at.
    lp_stopband_db = (
        -1.5
    )  # maximum rolloff for the stopband, when the source is straight behind.
    print("Emphasizing sources behind the listener with a lowpass filter")
    az_out = [[None] * len(i) for i in azimuths]
    for i, elev in enumerate(azimuths):
        for j, entry in enumerate(elev):
            # find percent around the circle
            percent = j / len(elev)
            # back_relative is 0 for straight behind, 0.5 for the side, 1 for straight in front.
            # It doesn't encode which side.
            # (percent == 0.5, i.e. halfway around, maps to back_relative == 0.)
            back_relative = 2 * abs(0.5 - percent)
            # Leave a little bit of slack in this if statement for floating point error.
            if back_relative >= 0.49:
                # It's in front; don't do anything (but copy, so every output
                # entry is a fresh array).
                az_out[i][j] = np.array(entry)
                continue
            # Convert to a scale factor for the stopband: 1.0 for straight
            # behind, 0.0 at the side transition.  (The original comment here
            # had the two endpoints swapped.)
            scale_factor = 1.0 - back_relative / 0.5
            # The exponent here would make it less abrupt at the side
            # transition points; currently disabled.
            # scale_factor = scale_factor**2
            # Then, take that to a stopband attenuation in dB.
            # Note that numpy will choke if this is too low.
            stopband_gain = -0.1 + scale_factor * lp_stopband_db
            # Now design and run the filter:
            # Note that gstop is actually positive, so we have to flip it.
            b, a = signal.iirdesign(
                wp=lp_freq,
                ws=lp_stopfreq,
                fs=SR,
                gpass=0.1,
                gstop=-stopband_gain,
                ftype="butterworth",
            )
            filtered = signal.lfilter(b, a, entry)
            az_out[i][j] = filtered
    return az_out
# this is the data that we need to write out.
HrtfData = namedtuple(
    "HrtfData",
    [
        "num_elevs",  # Number of elevations in the dataset.
        "elev_increment",  # Increment of the elevation in degrees.
        "elev_min",  # Min elevation angle in degrees.
        # num_elevs-length list holding the azimuth count for each elevation.
        # For now, we assume azimuths are equally distributed.
        "num_azimuths",
        "azimuths",  # The azimuths themselves as an array of arrays of arrays.
        "impulse_length",  # Number of data points in the set.
    ],
)
def compute_hrtf_data():
    """Run the full processing pipeline over the dataset and return HrtfData.

    Pipeline: load wavs -> magnitude spectra -> power equalization -> clamp ->
    minimum phase -> highpass -> window/truncate to 32 points -> rear lowpass
    emphasis.  The DC gain is printed at several stages for sanity checking.
    """
    hrir_length_final = 32
    # Fail loudly on any numpy floating-point warning (overflow, div-by-zero...).
    np.seterr(all="raise")
    dataset = load_dataset()
    elev_meta = extract_elevation_metadata(dataset)
    azimuths = unfold_azimuths(dataset)
    print("Initial dc", compute_dc(azimuths))
    azimuths = convert_to_magnitude_responses(azimuths)
    azimuths = equalize_power(azimuths)
    azimuths = clamp_responses(azimuths)
    azimuths = conv_minimum_phase(azimuths)
    print("Minimum phase dc", compute_dc(azimuths))
    azimuths = remove_base(azimuths)
    azimuths = truncate_hrir(azimuths, hrir_length_final)
    print("Truncated dc", compute_dc(azimuths))
    azimuths = emphasize_behind(azimuths)
    num_azs = [len(i) for i in azimuths]
    impulse_length = len(azimuths[0][0])
    assert impulse_length == hrir_length_final
    ret = HrtfData(
        num_elevs=elev_meta.num_elevs,
        elev_min=elev_meta.elev_min,
        elev_increment=elev_meta.degs_per_elev,
        num_azimuths=num_azs,
        azimuths=azimuths,
        impulse_length=impulse_length,
    )
    # Print everything except the bulky azimuth arrays, plus one sample HRIR.
    print(ret._replace(azimuths=None))
    print(azimuths[0][0])
    return ret