Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make the output of "feature_window" function chronological #68

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
allosaurus/pm/test.py

# Translations
*.mo
Expand Down Expand Up @@ -125,9 +126,12 @@ venv.bak/
.dmypy.json
dmypy.json

# audio files
allosaurus/*.wav

# Pyre type checker
.pyre/
.idea/
allosaurus/pretrained/*
allosaurus.egg-info
test_model/*
test_model/*
2 changes: 1 addition & 1 deletion allosaurus/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,4 @@ def recognize(self, filename, lang_id='ipa', topk=1, emit=1.0, timestamp=False):
batch_lprobs = tensor_batch_lprobs.detach().numpy()

token = self.lm.compute(batch_lprobs[0], lang_id, topk, emit=emit, timestamp=timestamp)
return token
return token
5 changes: 2 additions & 3 deletions allosaurus/pm/mfcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def compute(self, audio):

# make sample rate consistent
audio = resample_audio(audio, self.sample_rate)

# validate sample rate
assert self.config.sample_rate == audio.sample_rate, " sample rate of audio is "+str(audio.sample_rate)+" , but model is "+str(self.config.sample_rate)

Expand All @@ -70,6 +70,5 @@ def compute(self, audio):

# subsampling and windowing
if self.feature_window == 3:
feat = feature_window(feat)

feat = feature_window_ordered(feat)
return feat
35 changes: 35 additions & 0 deletions allosaurus/pm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,45 @@ def feature_cmvn(feature):


def feature_window(feature, window_size=3):
"""
chunks a given array based on the window_size (3) so the length of the 2nd dimensions is 3x the original.
given [[1 2 3]
[3 4 5]
[6 7 8]]
it turns into
[[6 7 8 1 2 3 3 4 5]]
the function rolls the array so that the last is at the start and the first is at the end. it concatonates them and then removes the repeated elements. This creates and offset and aligns the audio data so that it is not out of time with the phones
"""

assert window_size == 3, "only window size 3 is supported"

feature = np.concatenate((np.roll(feature, 1, axis=0), feature, np.roll(feature, -1, axis=0)), axis=1)
feature = feature[::3, ]

return feature

def feature_window_ordered(feature, window_size=3):
"""
chunks a given 2D array (feature) into a different 2D array of with a shfted array where the 2nd dimension is 3x the original length
e.g. given
[[1, 2, 3],
[3, 4, 5],
[6, 7, 8]]
to
[[1, 2, 3, 1, 2, 3, 3, 4, 5],
[6, 7, 8, 6, 7, 8, 6, 7, 8]]

it repeats the first element (in this case 1, 2, 3) in order to shift the remaining elements so that it lines up the timing for the phones to be decoded
"""
assert window_size == 3, "Window_size must equall 3"

shape = feature.shape

trailing_els = (3-(shape[0] + 1)%3)%3

windowed = np.full((shape[0] + 1 + trailing_els, shape[1]), feature[-1])
windowed[0] = feature[0]
windowed[1:shape[0] + 1] = feature

windowed.shape = (windowed.size // (shape[1] * 3), shape[1] * 3 )
return windowed