Skip to content

Commit

Permalink
Merge 6baf160 into c789bb2
Browse files Browse the repository at this point in the history
  • Loading branch information
rbroc committed Feb 26, 2020
2 parents c789bb2 + 6baf160 commit 3cd2715
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 9 deletions.
7 changes: 7 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,18 @@ before_install:
- conda create -q -n test-env python=$PYTHON_VERSION
- source activate test-env
install:
- python -m pip install --upgrade pip wheel
- sudo apt-get install libboost-python-dev
- pip install --upgrade --ignore-installed setuptools
- pip install -r requirements.txt --upgrade
- pip install -r optional-dependencies.txt --upgrade
- pip install --upgrade coveralls pytest-cov
- git clone --depth 1 -b master https://github.com/tensorflow/models
- cd models/research/audioset/yamnet
- curl -O https://storage.googleapis.com/audioset/yamnet.h5
- python yamnet_test.py
- export PYTHONPATH=$PYTHONPATH:$PWD
- cd /home/travis/build/tyarkoni/pliers

before_script:
- python -m pliers.support.download
Expand Down
2 changes: 2 additions & 0 deletions optional-dependencies.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ pygraphviz
pysrt
pytesseract
python-twitter
resampy
scikit-learn
seaborn
soundfile
spacy
SpeechRecognition>=3.6.0
tensorflow>=1.0.0
Expand Down
3 changes: 2 additions & 1 deletion pliers/extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
TempoExtractor,
BeatTrackExtractor,
HarmonicExtractor,
PercussiveExtractor)
PercussiveExtractor,
AudiosetLabelExtractor)
from .image import (BrightnessExtractor, SaliencyExtractor, SharpnessExtractor,
VibranceExtractor, FaceRecognitionFaceEncodingsExtractor,
FaceRecognitionFaceLandmarksExtractor,
Expand Down
143 changes: 142 additions & 1 deletion pliers/extractors/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,38 @@
from pliers.stimuli.text import ComplexTextStim
from pliers.extractors.base import Extractor, ExtractorResult
from pliers.utils import attempt_to_import, verify_dependencies, listify
from pliers.support.exceptions import MissingDependencyError
import numpy as np
from scipy import fft
import pandas as pd
import soundfile as sf
from abc import ABCMeta
from os import path
import logging

librosa = attempt_to_import('librosa')

yamnet = attempt_to_import('yamnet')
tf = attempt_to_import('tensorflow')

YAMNET_INSTALL_MESSAGE = '''
yamnet cannot be imported. To download and set up yamnet, open a terminal
window and do the following
- cd DOWNLOAD_PATH (path where you want yamnet to be downloaded)
- git clone --depth 1 -b master https://github.com/tensorflow/models
- cd models/research/audioset/yamnet
- curl -O https://storage.googleapis.com/audioset/yamnet.h5
- python yamnet_test.py
If you're a Mac or Linux user, do:
- open ~/.bash_profile
- add "export PYTHONPATH=$PYTHONPATH:DOWNLOAD_PATH/models/research/audioset/yamnet"
to the end of the file
- save and close
If you're a Windows user, do:
- set PYTHONPATH=%PYTHONPATH%;DOWNLOAD_PATH/models/research/audioset/yamnet
- To make the change permanent, you have to add this line to your autoexec.bat
'''

class AudioExtractor(Extractor):

Expand Down Expand Up @@ -488,4 +514,119 @@ class PercussiveExtractor(LibrosaFeatureExtractor):
For details on argument specification visit:
https://librosa.github.io/librosa/effect.html.'''


_feature = 'percussive'


class AudiosetLabelExtractor(AudioExtractor):

''' Extract probability of 521 audio event classes based on AudioSet
corpus using a YAMNet architecture. Code available at:
https://github.com/tensorflow/models/tree/master/research/audioset/yamnet
Args:
hop_size (float): size of the audio segment (in seconds) on which label
extraction is performed.
top_n (int): specifies how many of the highest label probabilities are
returned. If not defined, returns probabilities for all 521 labels.
label_subset (list): specifies subset of labels for which probabilities
are to be returned. A comprehensive list of labels is available in the
audioset/yamnet repository (see yamnet_class_map.csv).
weights_path (optional): full path to model weights file. If not provided,
weights from pretrained YAMNet module are used.
yamnet_kwargs (optional): Optional named arguments that modify input
parameters for the model (see params.py file in yamnet repository)
'''

_log_attributes = ('hop_size', 'top_n', 'label_subset', 'weights_path',
'yamnet_kwargs')

def __init__(self, hop_size=0.1, top_n=None, label_subset=None,
weights_path=None, **yamnet_kwargs):
try:
verify_dependencies(['yamnet'])
except MissingDependencyError:
raise MissingDependencyError(dependencies=None,
custom_message=YAMNET_INSTALL_MESSAGE)
verify_dependencies(['tensorflow'])

MODULE_PATH = path.dirname(yamnet.__file__)
LABELS_PATH = path.join(MODULE_PATH, 'yamnet_class_map.csv')
self.weights_path = weights_path or path.join(MODULE_PATH, 'yamnet.h5')
self.hop_size = hop_size
self.yamnet_kwargs = yamnet_kwargs
self.params = yamnet.params
self.params.PATCH_HOP_SECONDS = hop_size
for par, v in self.yamnet_kwargs.items():
if par in self.params.__dict__:
setattr(self.params, par, v)
self.labels = pd.read_csv(LABELS_PATH)['display_name'].tolist()
self.label_subset = label_subset
self.top_n = top_n
if self.label_subset:
for l in self.label_subset:
if l not in self.labels:
logging.warning('''Label {} does not exist.
Dropping.'''.format(l))
super(AudiosetLabelExtractor, self).__init__()

def _extract(self, stim):
params = self.params
params.SAMPLE_RATE = stim.sampling_rate

if params.SAMPLE_RATE >= 2 * params.MEL_MAX_HZ:
if params.SAMPLE_RATE != 16000:
logging.warning(
'The sampling rate of the stimulus is '
f'{params.SAMPLE_RATE}Hz. '
'YAMNet was trained on audio sampled at 16000Hz. '
'This should not impact predictions, but you can resample '
'the input using AudioResamplingFilter for full conformity '
'to training.')
if params.MEL_MIN_HZ != 125 or params.MEL_MAX_HZ != 7500:
logging.warning(
'Custom values for MEL_MIN_HZ and MEL_MAX_HZ '
'were passed. Changing these defaults might affect '
'model performance.')
else:
raise ValueError(
f'The sampling rate of your stimulus ({params.SAMPLE_RATE}Hz)'
' must be at least twice the value of MEL_MAX_HZ '
f'({params.MEL_MAX_HZ}Hz). '
'Upsample your audio stimulus (recommended) or pass a lower '
'value of MEL_MAX_HZ when initializing this extractor.')

labels = self.labels
model = yamnet.yamnet_frames_model(params)
model.load_weights(self.weights_path)
preds, _ = model.predict_on_batch(np.reshape(stim.data, [1,-1]))
preds = preds.numpy()

if self.label_subset:
for l in self.label_subset:
print(l)
print(np.where(labels == l))
label_subset_idx = [idx for idx, lab in enumerate(labels)
if lab in self.label_subset]
preds = preds[:,label_subset_idx]
labels = self.label_subset

nr_lab = self.top_n or len(labels)
if nr_lab > len(labels):
raise ValueError(
f'Value of top_n ({self.top_n}) exceeds number of '
f'labels ({len(labels)}). Reinstantiate this extractor using '
'suitable parameters.''')
idx = np.mean(preds,axis=0).argsort()
preds = np.fliplr(preds[:,idx][:,-nr_lab:])
labels = [labels[i] for i in idx][-nr_lab:][::-1]

durations = params.PATCH_HOP_SECONDS
onsets = np.arange(start=0, stop=stim.duration, step=durations)
onsets = onsets[onsets + params.PATCH_WINDOW_SECONDS < stim.duration]

return ExtractorResult(preds, stim, self, features=labels,
onsets=onsets, durations=durations,
orders=list(range(len(onsets))))

# Add pointer to installation instructions (install models, run test install - link, add to pythonpath)
9 changes: 5 additions & 4 deletions pliers/filters/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pliers.stimuli import AudioStim
from pliers.utils import attempt_to_import, verify_dependencies
from .base import Filter, TemporalTrimmingFilter
from copy import deepcopy

librosa = attempt_to_import('librosa')

Expand Down Expand Up @@ -43,11 +44,11 @@ def __init__(self, target_sr=44100, resample_type='kaiser_best',
super(AudioResamplingFilter, self).__init__()

def _filter(self, stim):
stim.data = librosa.core.resample(y=stim.data,
resampled_stim = deepcopy(stim)
resampled_stim.data = librosa.core.resample(y=stim.data,
orig_sr=stim.sampling_rate,
target_sr=self.target_sr,
resample_type=self.resample_type,
**self.librosa_kwargs)
stim.sampling_rate = self.target_sr

return stim
resampled_stim.sampling_rate = self.target_sr
return resampled_stim
7 changes: 5 additions & 2 deletions pliers/support/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class MissingDependencyError(PliersError):
"""

def __init__(self, dependencies, message=MISSING_DEPENDENCY_MESSAGE,
*args, **kwargs):
msg = message % ', '.join(dependencies)
custom_message=None, *args, **kwargs):
if custom_message:
msg = custom_message
else:
msg = message % ', '.join(dependencies)
super(MissingDependencyError, self).__init__(msg, *args, **kwargs)
66 changes: 65 additions & 1 deletion pliers/tests/extractors/test_audio_extractors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from os.path import join
from ..utils import get_test_data_path
from pliers import set_option
from pliers.extractors import (LibrosaFeatureExtractor,
STFTAudioExtractor,
MeanAmplitudeExtractor,
Expand All @@ -23,10 +24,13 @@
TempoExtractor,
BeatTrackExtractor,
HarmonicExtractor,
PercussiveExtractor)
PercussiveExtractor,
AudiosetLabelExtractor)
from pliers.stimuli import (ComplexTextStim, AudioStim,
TranscribedAudioCompoundStim)
from pliers.filters import AudioResamplingFilter
import numpy as np
import pytest

AUDIO_DIR = join(get_test_data_path(), 'audio')

Expand Down Expand Up @@ -352,3 +356,63 @@ def test_percussion_extractor():
assert np.isclose(df['onset'][29], 1.346757)
assert np.isclose(df['duration'][29], 0.04644)
assert np.isclose(df['percussive'][29], 0.004497, rtol=1e-4)


@pytest.mark.parametrize('hop_size', [0.1, 1])
@pytest.mark.parametrize('top_n', [5, 10])
@pytest.mark.parametrize('target_sr', [22000, 14000])
def test_audioset_extractor(hop_size, top_n, target_sr):

def compute_expected_length(stim, ext):
bins = int(stim.duration / ext.hop_size)
nr_incomplete = ext.params.PATCH_WINDOW_SECONDS / ext.hop_size - 1
exp_length = int(bins - nr_incomplete)
return exp_length

audio_stim = AudioStim(join(AUDIO_DIR, 'crowd.mp3'))
audio_filter = AudioResamplingFilter(target_sr=target_sr)
audio_resampled = audio_filter.transform(audio_stim)

# test with defaults and 44100 stimulus
ext = AudiosetLabelExtractor(hop_size=hop_size)
r_orig = ext.transform(audio_stim).to_df()
assert r_orig.shape[0] == compute_expected_length(audio_stim, ext)
assert r_orig.shape[1] == 525
assert np.argmax(r_orig.to_numpy()[:,4:].mean(axis=0)) == 0
assert r_orig['duration'][0] == ext.hop_size
assert all([np.isclose(r_orig['onset'][i] - r_orig['onset'][i-1], hop_size)
for i in range(1,r_orig.shape[0])])

# test resampled audio length and errors
if target_sr >= 14500:
r_resampled = ext.transform(audio_resampled).to_df()
assert r_orig.shape[0] == r_resampled.shape[0]
else:
with pytest.raises(ValueError) as sr_error:
ext.transform(audio_resampled)
assert all([substr in str(sr_error.value)
for substr in ['Upsample' , str(target_sr)]])

# test top_n option
ext_top_n = AudiosetLabelExtractor(top_n=top_n)
r_top_n = ext_top_n.transform(audio_stim).to_df()
assert r_top_n.shape[1] == ext_top_n.top_n + 4
assert np.argmax(r_top_n.to_numpy()[:,4:].mean(axis=0)) == 0

# test label subset
labels = ['Speech', 'Silence', 'Harmonic', 'Bark', 'Music', 'Bell',
'Steam', 'Rain']
ext_labels_only = AudiosetLabelExtractor(label_subset=labels)
r_labels_only = ext_labels_only.transform(audio_stim).to_df()
assert r_labels_only.shape[1] == len(labels) + 4

# test top_n/labels combination
ext_labels_top_n = AudiosetLabelExtractor(top_n=top_n, label_subset=labels)
if top_n > len(labels):
with pytest.raises(ValueError) as labels_error:
ext_labels_top_n.transform(audio_stim)
assert all([val in str(labels_error.value)
for val in [str(top_n), str(len(labels))]])
else:
r_labels_top_n = ext_labels_top_n.transform(audio_stim).to_df()
assert r_labels_top_n.shape[1] == top_n + 4

0 comments on commit 3cd2715

Please sign in to comment.