Skip to content

Commit

Permalink
Merge afbe485 into 7e92ad5
Browse files Browse the repository at this point in the history
  • Loading branch information
rbroc committed Mar 18, 2020
2 parents 7e92ad5 + afbe485 commit af1675c
Show file tree
Hide file tree
Showing 10 changed files with 250 additions and 13 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ before_install:
- conda create -q -n test-env python=$PYTHON_VERSION
- source activate test-env
install:
- python -m pip install --upgrade pip wheel
- sudo apt-get install libboost-python-dev
- pip install --upgrade --ignore-installed setuptools
- pip install -r requirements.txt --upgrade
- pip install -r optional-dependencies.txt --upgrade
- pip install --upgrade coveralls pytest-cov

before_script:
- python -m pliers.support.setup_yamnet
- python -m pliers.support.download
- python -m spacy download en_core_web_sm
script:
Expand Down
1 change: 1 addition & 0 deletions optional-dependencies.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pytesseract
python-twitter
scikit-learn
seaborn
soundfile
spacy
SpeechRecognition>=3.6.0
tensorflow>=1.0.0
Expand Down
4 changes: 3 additions & 1 deletion pliers/extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
TempoExtractor,
BeatTrackExtractor,
HarmonicExtractor,
PercussiveExtractor)
PercussiveExtractor,
AudiosetLabelExtractor)
from .image import (BrightnessExtractor, SaliencyExtractor, SharpnessExtractor,
VibranceExtractor, FaceRecognitionFaceEncodingsExtractor,
FaceRecognitionFaceLandmarksExtractor,
Expand Down Expand Up @@ -138,6 +139,7 @@
'BeatTrackExtractor',
'HarmonicExtractor',
'PercussiveExtractor',
'AudiosetLabelExtractor',
'PretrainedBertEncodingExtractor',
'WordCounterExtractor'
]
129 changes: 128 additions & 1 deletion pliers/extractors/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
from pliers.stimuli.text import ComplexTextStim
from pliers.extractors.base import Extractor, ExtractorResult
from pliers.utils import attempt_to_import, verify_dependencies, listify
from pliers.support.exceptions import MissingDependencyError
from pliers.support.setup_yamnet import YAMNET_PATH
import numpy as np
from scipy import fft
import pandas as pd
import soundfile as sf
from abc import ABCMeta
from os import path
import sys
import logging

librosa = attempt_to_import('librosa')

tf = attempt_to_import('tensorflow')

class AudioExtractor(Extractor):

Expand Down Expand Up @@ -488,4 +495,124 @@ class PercussiveExtractor(LibrosaFeatureExtractor):
For details on argument specification visit:
https://librosa.github.io/librosa/effect.html.'''


_feature = 'percussive'


class AudiosetLabelExtractor(AudioExtractor):

    ''' Extract probability of 521 audio event classes based on AudioSet
    corpus using a YAMNet architecture. Code available at:
    https://github.com/tensorflow/models/tree/master/research/audioset/yamnet

    Args:
        hop_size (float): step (in seconds) between the onsets of consecutive
            audio segments on which label extraction is performed. It is
            written to the model parameters as PATCH_HOP_SECONDS.
        top_n (int): specifies how many of the highest label probabilities are
            returned. If None, all labels (or all in labels) are returned.
            Top_n and labels are mutually exclusive arguments.
        labels (list): specifies subset of labels for which probabilities
            are to be returned. If None, all labels (or top_n) are returned.
            Unknown labels are dropped with a warning. The full list of
            labels is available in the audioset/yamnet repository (see
            yamnet_class_map.csv).
        weights_path (optional): full path to model weights file. If not
            provided, weights from pretrained YAMNet module are used.
        yamnet_kwargs (optional): Optional named arguments that modify input
            parameters for the model (see params.py file in yamnet
            repository).

    Raises:
        MissingDependencyError: if tensorflow or yamnet cannot be imported.
        ValueError: if both top_n and labels are passed.
    '''

    _log_attributes = ('hop_size', 'top_n', 'labels', 'weights_path',
                       'yamnet_kwargs')

    def __init__(self, hop_size=0.1, top_n=None, labels=None,
                 weights_path=None, **yamnet_kwargs):
        verify_dependencies(['tensorflow'])
        try:
            # Make the downloaded yamnet sources importable. Only add the
            # entry once so repeated instantiation does not grow sys.path.
            yamnet_dir = str(YAMNET_PATH)
            if yamnet_dir not in sys.path:
                sys.path.insert(0, yamnet_dir)
            self.yamnet = attempt_to_import('yamnet')
            verify_dependencies(['yamnet'])
        except MissingDependencyError:
            msg = ('Yamnet could not be imported. To download and set up '
                   'yamnet, run:\n\tpython -m pliers.support.setup_yamnet')
            raise MissingDependencyError(dependencies=None,
                                         custom_message=msg)
        if top_n and labels:
            raise ValueError('Top_n and labels are mutually exclusive '
                             'arguments. Reinstantiate the extractor setting '
                             'top_n or labels to None (or leaving it '
                             'unspecified).')

        MODULE_PATH = path.dirname(self.yamnet.__file__)
        LABELS_PATH = path.join(MODULE_PATH, 'yamnet_class_map.csv')
        self.weights_path = weights_path or path.join(MODULE_PATH,
                                                      'yamnet.h5')
        self.hop_size = hop_size
        self.yamnet_kwargs = yamnet_kwargs or {}
        # NOTE(review): params is the object exposed by the yamnet module, so
        # the attributes set below look shared between all extractor
        # instances in the process - confirm before relying on isolation.
        self.params = self.yamnet.params
        self.params.PATCH_HOP_SECONDS = hop_size
        for par, v in self.yamnet_kwargs.items():
            setattr(self.params, par, v)
        if self.params.PATCH_WINDOW_SECONDS != 0.96:
            logging.warning('Custom values for PATCH_WINDOW_SECONDS were '
                'passed. YAMNet was trained on windows of 0.96s. Different '
                'values might yield unreliable results.')

        self.top_n = top_n
        all_labels = pd.read_csv(LABELS_PATH)['display_name'].tolist()
        if labels is not None:
            requested = set(labels)
            missing = list(requested - set(all_labels))
            if missing:
                logging.warning(f'Labels {missing} do not exist. Dropping.')
            # Build names and column indices together, both in class-map
            # order, so that column i of the sliced predictions in _extract
            # always corresponds to self.labels[i].
            selected = [(i, l) for i, l in enumerate(all_labels)
                        if l in requested]
            self.label_idx = [i for i, _ in selected]
            self.labels = [l for _, l in selected]
        else:
            self.labels = all_labels
            self.label_idx = list(range(len(all_labels)))
        super(AudiosetLabelExtractor, self).__init__()

    def _extract(self, stim):
        ''' Run YAMNet on the stimulus and return label probabilities.

        Args:
            stim: audio stimulus; its data, sampling_rate and duration
                attributes are read.

        Returns:
            ExtractorResult with one row per analysis window and one feature
            per retained label, ordered by decreasing mean probability.

        Raises:
            ValueError: if the stimulus sampling rate is lower than twice
                MEL_MAX_HZ.
        '''
        self.params.SAMPLE_RATE = stim.sampling_rate

        if self.params.SAMPLE_RATE < 2 * self.params.MEL_MAX_HZ:
            raise ValueError(
                'The sampling rate of your stimulus '
                f'({self.params.SAMPLE_RATE}Hz) must be at least twice the '
                f'value of MEL_MAX_HZ ({self.params.MEL_MAX_HZ}Hz). Upsample'
                ' your audio stimulus (recommended) or pass a lower value of '
                'MEL_MAX_HZ when initializing the extractor.')

        if self.params.SAMPLE_RATE != 16000:
            logging.warning(
                'The sampling rate of the stimulus is '
                f'{self.params.SAMPLE_RATE}Hz. YAMNet was trained on '
                ' audio sampled at 16000Hz. This should not impact '
                'predictions, but you can resample the input using '
                'AudioResamplingFilter for full conformity '
                'to training.')
        if self.params.MEL_MIN_HZ != 125 or self.params.MEL_MAX_HZ != 7500:
            logging.warning(
                'Custom values for MEL_MIN_HZ and MEL_MAX_HZ '
                'were passed. Changing these defaults might affect '
                'model performance.')

        model = self.yamnet.yamnet_frames_model(self.params)
        model.load_weights(self.weights_path)
        preds, _ = model.predict_on_batch(np.reshape(stim.data, [1, -1]))
        # Depending on the TensorFlow version the prediction is either an
        # eager tensor or already a NumPy array; accept both.
        if hasattr(preds, 'numpy'):
            preds = preds.numpy()
        preds = np.asarray(preds)[:, self.label_idx]

        # argsort is ascending: keep the last nr_lab columns and reverse them
        # so the most probable label (on average) comes first.
        nr_lab = self.top_n or len(self.labels)
        idx = np.mean(preds, axis=0).argsort()
        preds = np.fliplr(preds[:, idx][:, -nr_lab:])
        labels = [self.labels[i] for i in idx][-nr_lab:][::-1]

        hop = self.params.PATCH_HOP_SECONDS
        window = self.params.PATCH_WINDOW_SECONDS
        stft_window = self.params.STFT_WINDOW_SECONDS
        stft_hop = self.params.STFT_HOP_SECONDS
        # Duration assigned to every row: patch window plus the STFT window
        # minus one STFT hop.
        dur = window + stft_window - stft_hop
        onsets = np.arange(start=0, stop=stim.duration - dur, step=hop)

        return ExtractorResult(preds, stim, self, features=labels,
                               onsets=onsets, durations=[dur]*len(onsets),
                               orders=list(range(len(onsets))))
2 changes: 1 addition & 1 deletion pliers/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas as pd
import numpy as np
from pliers.transformers import Transformer
from pliers.utils import isgenerator, flatten
from pliers.utils import isgenerator, flatten, listify
from pandas.api.types import is_numeric_dtype


Expand Down
9 changes: 5 additions & 4 deletions pliers/filters/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pliers.stimuli import AudioStim
from pliers.utils import attempt_to_import, verify_dependencies
from .base import Filter, TemporalTrimmingFilter
from copy import deepcopy

librosa = attempt_to_import('librosa')

Expand Down Expand Up @@ -43,11 +44,11 @@ def __init__(self, target_sr=44100, resample_type='kaiser_best',
super(AudioResamplingFilter, self).__init__()

def _filter(self, stim):
stim.data = librosa.core.resample(y=stim.data,
resampled_stim = deepcopy(stim)
resampled_stim.data = librosa.core.resample(y=stim.data,
orig_sr=stim.sampling_rate,
target_sr=self.target_sr,
resample_type=self.resample_type,
**self.librosa_kwargs)
stim.sampling_rate = self.target_sr

return stim
resampled_stim.sampling_rate = self.target_sr
return resampled_stim
4 changes: 3 additions & 1 deletion pliers/support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

from .decorators import requires_nltk_corpus
from .download import download_nltk_data
from .setup_yamnet import setup_yamnet

__all__ = [
'requires_nltk_corpus',
'download_nltk_data'
'download_nltk_data',
'setup_yamnet'
]
4 changes: 2 additions & 2 deletions pliers/support/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ class MissingDependencyError(PliersError):
"""

def __init__(self, dependencies, message=MISSING_DEPENDENCY_MESSAGE,
*args, **kwargs):
msg = message % ', '.join(dependencies)
custom_message=None, *args, **kwargs):
msg = custom_message or message % ', '.join(dependencies)
super(MissingDependencyError, self).__init__(msg, *args, **kwargs)
43 changes: 43 additions & 0 deletions pliers/support/setup_yamnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os
from io import BytesIO
from zipfile import ZipFile
from urllib import request
from pathlib import Path
import sys
import runpy
import shutil

# Directory inside the user's home folder used for pliers downloads.
PLIERS_DATA_PATH = Path.home() / 'pliers_data'
# Directory that setup_yamnet() populates with the YAMNet sources and weights.
YAMNET_PATH = PLIERS_DATA_PATH / 'yamnet'

def setup_yamnet():
    ''' Download the YAMNet sources and pretrained weights, then run the
    test script shipped with YAMNet.

    The sources are taken from an archive of the tensorflow/models
    repository and placed in YAMNET_PATH; the weights file is stored in the
    same directory. Nothing is downloaded when the weights file is already
    present. The YAMNet test script is executed on every call, with
    YAMNET_PATH temporarily used as working directory.
    '''
    repo_url = 'https://github.com/tensorflow/models/archive/master.zip'
    model_url = 'https://storage.googleapis.com/audioset/yamnet.h5'

    tmp_dir = PLIERS_DATA_PATH / 'yamnet_tmp'
    tmp_yamnet_dir = (tmp_dir / 'models-master' / 'research' / 'audioset'
                      / 'yamnet')
    model_filename = YAMNET_PATH / model_url.split('/')[-1]

    if not model_filename.exists():
        PLIERS_DATA_PATH.mkdir(parents=True, exist_ok=True)

        # An earlier run may have fetched the sources but failed on the
        # weights; moving onto an existing directory would raise, so the
        # repository is only downloaded when it is not there yet.
        if not YAMNET_PATH.exists():
            print('Downloading model repository...\n')
            try:
                with request.urlopen(repo_url) as z:
                    with ZipFile(BytesIO(z.read())) as zfile:
                        zfile.extractall(str(tmp_dir))
                shutil.move(str(tmp_yamnet_dir), str(YAMNET_PATH))
            finally:
                # Never leave the extracted archive behind, even on failure.
                shutil.rmtree(str(tmp_dir), ignore_errors=True)
            # Sum file sizes: stat() on the directory itself would not
            # report the size of its contents.
            size = sum(f.stat().st_size for f in YAMNET_PATH.rglob('*')
                       if f.is_file())
            print(f'Model repository downloaded at {str(YAMNET_PATH)} '
                  f', size: {size} bytes\n')

        # Download to a temporary name first so that an interrupted transfer
        # is not mistaken for a complete weights file on the next run.
        partial = model_filename.with_name(model_filename.name + '.part')
        request.urlretrieve(model_url, str(partial))
        os.replace(str(partial), str(model_filename))
        print('Model file downloaded.\n')

    print(YAMNET_PATH)
    test_path = YAMNET_PATH / 'yamnet_test.py'
    sys.path.insert(0, str(YAMNET_PATH))
    previous_cwd = os.getcwd()
    os.chdir(str(YAMNET_PATH))
    try:
        runpy.run_path(str(test_path), run_name='__main__')
    finally:
        # Restore the caller's working directory, also when the test script
        # terminates through an exception.
        os.chdir(previous_cwd)

# Allows running the setup with: python -m pliers.support.setup_yamnet
if __name__ == '__main__':
    setup_yamnet()
64 changes: 62 additions & 2 deletions pliers/tests/extractors/test_audio_extractors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from os.path import join
from os import environ
from ..utils import get_test_data_path
from pliers.extractors import (LibrosaFeatureExtractor,
STFTAudioExtractor,
Expand All @@ -23,13 +24,17 @@
TempoExtractor,
BeatTrackExtractor,
HarmonicExtractor,
PercussiveExtractor)
PercussiveExtractor,
AudiosetLabelExtractor)
from pliers.stimuli import (ComplexTextStim, AudioStim,
TranscribedAudioCompoundStim)
from pliers.filters import AudioResamplingFilter
from pliers.utils import attempt_to_import, verify_dependencies
import numpy as np
import pytest

AUDIO_DIR = join(get_test_data_path(), 'audio')

tf = attempt_to_import('tensorflow')

def test_stft_extractor():
stim = AudioStim(join(AUDIO_DIR, 'barber.wav'), onset=4.2)
Expand Down Expand Up @@ -352,3 +357,58 @@ def test_percussion_extractor():
assert np.isclose(df['onset'][29], 1.346757)
assert np.isclose(df['duration'][29], 0.04644)
assert np.isclose(df['percussive'][29], 0.004497, rtol=1e-4)


@pytest.mark.parametrize('hop_size', [0.1, 1])
@pytest.mark.parametrize('top_n', [5, 10])
@pytest.mark.parametrize('target_sr', [22000, 14000])
def test_audioset_extractor(hop_size, top_n, target_sr):
    ''' Check output shape, ordering, timing and argument validation of
    AudiosetLabelExtractor for several hop sizes, top_n values and
    resampling rates. '''
    verify_dependencies(['tensorflow'])

    def compute_expected_length(stim, ext):
        # Mirrors the onset computation performed by the extractor.
        stft_par = ext.params.STFT_WINDOW_SECONDS - ext.params.STFT_HOP_SECONDS
        tot_window = ext.params.PATCH_WINDOW_SECONDS + stft_par
        ons = np.arange(start=0, stop=stim.duration - tot_window,
                        step=hop_size)
        return len(ons)

    audio_stim = AudioStim(join(AUDIO_DIR, 'crowd.mp3'))
    audio_filter = AudioResamplingFilter(target_sr=target_sr)
    audio_resampled = audio_filter.transform(audio_stim)

    # test with defaults and 44100 stimulus
    ext = AudiosetLabelExtractor(hop_size=hop_size)
    r_orig = ext.transform(audio_stim).to_df()
    assert r_orig.shape[0] == compute_expected_length(audio_stim, ext)
    assert r_orig.shape[1] == 525
    # Features are sorted by mean probability, so the first one is the max.
    assert np.argmax(r_orig.to_numpy()[:, 4:].mean(axis=0)) == 0
    # The duration is a sum of floats: compare with a tolerance.
    assert np.isclose(r_orig['duration'][0], .975)
    assert all(np.isclose(r_orig['onset'][i] - r_orig['onset'][i-1], hop_size)
               for i in range(1, r_orig.shape[0]))

    # test resampled audio length and errors; the extractor rejects sampling
    # rates below twice MEL_MAX_HZ, so use the same threshold here
    if target_sr >= 2 * ext.params.MEL_MAX_HZ:
        r_resampled = ext.transform(audio_resampled).to_df()
        assert r_orig.shape[0] == r_resampled.shape[0]
    else:
        with pytest.raises(ValueError) as sr_error:
            ext.transform(audio_resampled)
        assert all(substr in str(sr_error.value)
                   for substr in ['Upsample', str(target_sr)])

    # test top_n option
    ext_top_n = AudiosetLabelExtractor(top_n=top_n)
    r_top_n = ext_top_n.transform(audio_stim).to_df()
    assert r_top_n.shape[1] == ext_top_n.top_n + 4
    assert np.argmax(r_top_n.to_numpy()[:, 4:].mean(axis=0)) == 0

    # test label subset
    labels = ['Speech', 'Silence', 'Harmonic', 'Bark', 'Music', 'Bell',
              'Steam', 'Rain']
    ext_labels_only = AudiosetLabelExtractor(labels=labels)
    r_labels_only = ext_labels_only.transform(audio_stim).to_df()
    assert r_labels_only.shape[1] == len(labels) + 4

    # test top_n/labels error
    with pytest.raises(ValueError) as err:
        AudiosetLabelExtractor(top_n=10, labels=labels)
    assert 'Top_n and labels are mutually exclusive' in str(err.value)

0 comments on commit af1675c

Please sign in to comment.