Skip to content

Commit

Permalink
fix and add tests to reflect changes
Browse files Browse the repository at this point in the history
  • Loading branch information
tyarkoni committed Jan 1, 2017
1 parent 7de408b commit a12cd74
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 41 deletions.
67 changes: 45 additions & 22 deletions pliers/tests/test_converters.py
Expand Up @@ -9,11 +9,12 @@
GoogleSpeechAPIConverter,
IBMSpeechAPIConverter)
from pliers.converters.google import GoogleVisionAPITextConverter
from pliers.converters.iterators import ComplexTextIterator
from pliers.stimuli.video import VideoStim, VideoFrameStim, DerivedVideoStim
from pliers.stimuli.text import TextStim, ComplexTextStim
from pliers.stimuli.audio import AudioStim
from pliers.stimuli.image import ImageStim

from pliers import config
import numpy as np
import math
import pytest
Expand All @@ -25,10 +26,10 @@ def test_video_to_audio_converter():
filename = join(get_test_data_path(), 'video', 'small.mp4')
video = VideoStim(filename)
conv = VideoToAudioConverter()
audio = conv.transform(video)
audio = conv.convert(video)
assert audio.name == 'small.mp4->small.wav'
assert isinstance(audio.source_stim, VideoStim)
assert audio.source_stim.name == 'small.mp4'
assert audio.history.source_class == 'VideoStim'
assert audio.history.source_file == filename
assert splitext(video.filename)[0] == splitext(audio.filename)[0]
assert np.isclose(video.duration, audio.duration, 1e-2)

Expand All @@ -46,7 +47,8 @@ def test_derived_video_converter():
assert len(derived.elements) == math.ceil(video.n_frames / 3.0)
first = next(f for f in derived)
assert type(first) == VideoFrameStim
assert first.name == 'small.mp4_0'
print(first.name)
assert first.name == 'small.mp4->frame[0]'
assert first.duration == 3 * (1 / 30.0)

# Should refilter from original frames
Expand Down Expand Up @@ -101,9 +103,9 @@ def test_ibmAPI_converter():
stim = AudioStim(join(audio_dir, 'homer.wav'))
conv = IBMSpeechAPIConverter()
out_stim = conv.transform(stim)
assert type(out_stim) == ComplexTextStim
assert isinstance(out_stim, ComplexTextStim)
first_word = next(w for w in out_stim)
assert type(first_word) == TextStim
assert isinstance(first_word, TextStim)
assert first_word.duration > 0
assert first_word.onset != None

Expand All @@ -117,9 +119,9 @@ def test_tesseract_converter():
stim = ImageStim(join(image_dir, 'button.jpg'))
conv = TesseractConverter()
out_stim = conv.transform(stim)
assert out_stim.name == 'button.jpg->Exit'
assert isinstance(out_stim.source_stim, ImageStim)
assert out_stim.source_stim.name == 'button.jpg'
assert out_stim.name == 'button.jpg->text[Exit]'
assert out_stim.history.source_class == 'ImageStim'
assert out_stim.history.source_name == 'button.jpg'


@pytest.mark.skipif("'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ")
Expand All @@ -143,33 +145,42 @@ def test_get_converter():


def test_converter_memoization():

cache_value = config.cache_converters
config.cache_converters = True

filename = join(get_test_data_path(), 'video', 'small.mp4')
video = VideoStim(filename)
conv = VideoToAudioConverter()

memory.clear()
def convert(stim):
start_time = time.time()
stim = conv.convert(stim)
return time.time() - start_time

# Time taken first time through
start_time = time.time()
audio1 = conv.convert(video)
convert_time = time.time() - start_time
memory.clear()

start_time = time.time()
audio2 = conv.convert(video)
cache_time = time.time() - start_time
convert_time = convert(video)
cache_time = convert(video)

# TODO: implement saner checking than this
# Converting should be at least twice as slow as retrieving from cache
assert convert_time >= cache_time * 2

# After clearing the cache, checks should fail
memory.clear()
start_time = time.time()
audio2 = conv.convert(video)
cache_time = time.time() - start_time
cache_time = convert(video)
assert convert_time <= cache_time * 2

# After clearing the cache, checks should fail
# When cache is disabled, check should also fail
config.cache_converters = False
conv = VideoToAudioConverter()
cache_time = convert(video)
assert convert_time <= cache_time * 2

config.cache_converters = cache_value


@pytest.mark.skipif("'WIT_AI_API_KEY' not in os.environ")
def test_multistep_converter():
Expand All @@ -181,12 +192,24 @@ def test_multistep_converter():
first_word = next(w for w in text)
assert type(first_word) == TextStim


@pytest.mark.skipif("'WIT_AI_API_KEY' not in os.environ")
def test_stim_history_tracking():
video = VideoStim(join(get_test_data_path(), 'video', 'obama_speech.mp4'))
assert str(video.history) == 'VideoStim'
assert video.history is None
conv = VideoToAudioConverter()
stim = conv.convert(video)
assert str(stim.history) == 'VideoStim->VideoToAudioConverter/AudioStim'
conv = WitTranscriptionConverter()
stim = conv.convert(stim)
assert str(stim.history) == 'VideoStim->VideoToAudioConverter/AudioStim->WitTranscriptionConverter/ComplexTextStim'


def test_stim_iteration_converter():
    """Check the output of ComplexTextIterator on a ComplexTextStim.

    Loads the 'scandal.txt' text fixture, wraps it in a ComplexTextStim,
    runs it through ComplexTextIterator, and verifies the number of
    returned elements as well as the type, text, and history string of
    the second element.
    """
    textfile = join(get_test_data_path(), 'text', 'scandal.txt')
    # Read through a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(textfile) as f:
        text = f.read().strip()
    stim = ComplexTextStim(text=text)
    words = ComplexTextIterator().transform(stim)
    assert len(words) == 231
    assert isinstance(words[1], TextStim)
    assert words[1].text == 'Sherlock'
    assert str(words[1].history) == 'ComplexTextStim->ComplexTextIterator/TextStim'
19 changes: 10 additions & 9 deletions pliers/tests/test_extractors.py
Expand Up @@ -73,8 +73,8 @@ def test_implicit_stim_conversion2():

@pytest.mark.skipif("'WIT_AI_API_KEY' not in os.environ")
def test_implicit_stim_conversion3():
audio_dir = join(get_test_data_path(), 'video')
stim = VideoStim(join(audio_dir, 'obama_speech.mp4'))
video_dir = join(get_test_data_path(), 'video')
stim = VideoStim(join(video_dir, 'obama_speech.mp4'))
ext = LengthExtractor()
result = ext.extract(stim)
first_word = result[0].to_df()
Expand All @@ -89,8 +89,7 @@ def test_text_extractor():
td = DictionaryExtractor(join(TEXT_DIR, 'test_lexical_dictionary.txt'),
variables=['length', 'frequency'])
assert td.data.shape == (7, 2)
timeline = td.extract(stim)
result = timeline[2].to_df()
result = td.extract(stim)[2].to_df()
assert np.isnan(result.iloc[0, 1])
assert result.shape == (1, 4)
assert np.isclose(result['frequency'][0], 11.729, 1e-5)
Expand Down Expand Up @@ -143,7 +142,7 @@ def test_stft_extractor():
audio_dir = join(get_test_data_path(), 'audio')
stim = AudioStim(join(audio_dir, 'barber.wav'))
ext = STFTAudioExtractor(frame_size=1., spectrogram=False,
bins=[(100, 300), (300, 3000), (3000, 20000)])
freq_bins=[(100, 300), (300, 3000), (3000, 20000)])
result = ext.extract(stim)
df = result.to_df()
assert df.shape == (557, 5)
Expand Down Expand Up @@ -259,9 +258,10 @@ def test_merge_extractor_results_by_features():
de_names = ['Extractor1', 'Extractor2', 'Extractor3']
results = [de.extract(stim, name) for name in de_names]
df = ExtractorResult.merge_features(results)
assert df.shape == (177, 10)
assert df.shape == (177, 13)
assert df.columns.levels[1].unique().tolist() == ['duration', 0, 1, 2, '']
assert df.columns.levels[0].unique().tolist() == de_names + ['onset', 'stim']
cols = cols = ['onset', 'class', 'filename', 'history', 'stim']
assert df.columns.levels[0].unique().tolist() == de_names + cols


def test_merge_extractor_results_by_stims():
Expand All @@ -286,7 +286,8 @@ def test_merge_extractor_results():
results = [de.extract(stim1, name) for name in de_names]
results += [de.extract(stim2, name) for name in de_names]
df = merge_results(results)
assert df.shape == (355, 10)
assert df.columns.levels[0].unique().tolist() == de_names + ['onset', 'stim']
assert df.shape == (355, 13)
cols = ['onset', 'class', 'filename', 'history', 'stim']
assert df.columns.levels[0].unique().tolist() == de_names + cols
assert df.columns.levels[1].unique().tolist() == ['duration', 0, 1, 2, '']
assert set(df.index.levels[1].unique()) == set(['obama.jpg', 'apple.jpg'])
12 changes: 6 additions & 6 deletions pliers/tests/test_graph.py
Expand Up @@ -47,7 +47,7 @@ def test_graph_smoke_test():
stim = ImageStim(filename)
nodes = [(BrightnessExtractor(), 'brightness')]
graph = Graph(nodes)
result = graph.extract([stim])
result = graph.extract(stim)
brightness = result[('BrightnessExtractor', 'brightness')].values[0]
assert_almost_equal(brightness, 0.556134, 5)

Expand Down Expand Up @@ -81,7 +81,7 @@ def test_small_pipeline():
assert history.shape == (2, 8)
assert history.iloc[0]['result_class'] == 'TextStim'
result = merge_results(result)
assert (0, 'button.jpg->Exit') in result.index
assert (0, 'button.jpg->text[Exit]') in result.index
assert ('LengthExtractor', 'text_length') in result.columns
assert result[('LengthExtractor', 'text_length')].values[0] == 4

Expand All @@ -94,16 +94,16 @@ def test_big_pipeline():
[(TesseractConverter(), 'visual_text',
[(LengthExtractor(), 'visual_text_length')]),
(VibranceExtractor(), 'visual_vibrance')])]
audio_nodes = [(VideoToAudioConverter(), 'audio',
audio_nodes = [(VideoToAudioConverter(), 'audio',
[(WitTranscriptionConverter(), 'audio_text',
[(LengthExtractor(), 'audio_text_length')])])]
graph = Graph()
graph.add_children(visual_nodes)
graph.add_children(audio_nodes)
result = graph.extract(video)
print(result)
assert ('LengthExtractor', 'text_length') in result.columns
assert ('VibranceExtractor', 'vibrance') in result.columns
assert not result[('onset', '')].isnull().any()
assert 'obama_speech.mp4_obama_speech.wav_today' in result.index.get_level_values(1)
assert 'obama_speech.mp4_90' in result.index.get_level_values(1)
print(result)
assert 'text[together]' in result.index.get_level_values(1)
assert 'obama_speech.mp4->frame[90]' in result.index.get_level_values(1)
2 changes: 1 addition & 1 deletion pliers/tests/test_io.py
Expand Up @@ -24,7 +24,7 @@ def test_convert_to_long():
audio_dir = join(get_test_data_path(), 'audio')
stim = AudioStim(join(audio_dir, 'barber.wav'))
ext = STFTAudioExtractor(frame_size=1., spectrogram=False,
bins=[(100, 300), (300, 3000), (3000, 20000)])
freq_bins=[(100, 300), (300, 3000), (3000, 20000)])
timeline = ext.extract(stim)
long_timeline = to_long_format(timeline)
assert long_timeline.shape == (timeline.to_df().shape[0] * 3, 4)
Expand Down
2 changes: 1 addition & 1 deletion pliers/tests/test_stims.py
Expand Up @@ -142,7 +142,7 @@ def test_compound_stim():
video = VideoStim(filename)
text = ComplexTextStim(text="The quick brown fox jumped...")
stim = CompoundStim([audio, image1, image2, video, text])
assert len(stim.stims) == 5
assert len(stim.elements) == 5
assert isinstance(stim.video, VideoStim)
assert isinstance(stim.complex_text, ComplexTextStim)
assert isinstance(stim.image, ImageStim)
Expand Down
5 changes: 3 additions & 2 deletions pliers/tests/test_transformers.py
@@ -1,7 +1,8 @@
from pliers.transformers import get_transformer, TransformationHistory
from pliers.transformers import get_transformer
from pliers.extractors import Extractor
from pliers.extractors.audio import STFTAudioExtractor
from pliers.tests.utils import get_test_data_path, DummyExtractor
from pliers.stimuli import TransformationLog
from pliers.stimuli.image import ImageStim
from os.path import join

Expand All @@ -16,7 +17,7 @@ def test_transformation_history():
img = ImageStim(join(get_test_data_path(), 'image', 'apple.jpg'))
ext = DummyExtractor('giraffe')
res = ext.extract(img).history
assert isinstance(res, TransformationHistory)
assert isinstance(res, TransformationLog)
df = res.to_df()
assert df.shape == (1, 8)
assert list(df.columns) == ['source_name', 'source_file', 'source_class',
Expand Down

0 comments on commit a12cd74

Please sign in to comment.