Skip to content

Commit

Permalink
Merge pull request #390 from rbroc/prob_extractor
Browse files Browse the repository at this point in the history
add MetricExtractor
  • Loading branch information
tyarkoni committed May 4, 2020
2 parents 4656fa2 + 9ef93e5 commit a2a9994
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 6 deletions.
4 changes: 3 additions & 1 deletion pliers/extractors/__init__.py
Expand Up @@ -58,6 +58,7 @@
VibranceExtractor, FaceRecognitionFaceEncodingsExtractor,
FaceRecognitionFaceLandmarksExtractor,
FaceRecognitionFaceLocationsExtractor)
from .misc import MetricExtractor
from .models import TensorFlowKerasApplicationExtractor
from .text import (ComplexTextExtractor, DictionaryExtractor,
PredefinedDictionaryExtractor, LengthExtractor,
Expand Down Expand Up @@ -146,5 +147,6 @@
'BertLMExtractor',
'BertSentimentExtractor',
'AudiosetLabelExtractor',
'WordCounterExtractor'
'WordCounterExtractor',
'MetricExtractor'
]
74 changes: 74 additions & 0 deletions pliers/extractors/misc.py
@@ -0,0 +1,74 @@
'''
Extractors that operate on Miscellaneous Stims.
'''

from pliers.stimuli.misc import SeriesStim
from pliers.extractors.base import Extractor, ExtractorResult
from pliers.utils import listify, isiterable
import scipy
import numpy as np
import pandas as pd
from importlib import import_module
import logging

class MetricExtractor(Extractor):

    ''' Extracts summary metrics from a 1D array using numpy, scipy or
    custom functions.

    Args:
        functions (str, function or list): function, or string referring to
            the absolute import path for a function (e.g. 'numpy.mean').
            Each function must operate on a 1-dimensional numpy array/Series
            and return a scalar (or an iterable of scalars). A list of
            functions or import strings may also be passed.
        var_names (list): optional list of custom alias names for each
            metric. Must have the same length as ``functions``.
        subset_idx (list): subset of Series index labels to compute the
            metrics on. Labels absent from the Series index are dropped
            with a warning.
        kwargs: named arguments forwarded to every function call.
    '''

    _input_type = SeriesStim
    _log_attributes = ('functions', 'subset_idx')

    def __init__(self, functions=None, var_names=None,
                 subset_idx=None, **kwargs):
        functions = listify(functions)
        if var_names is not None:
            var_names = listify(var_names)
            if len(var_names) != len(functions):
                raise ValueError('Length of var_names must match number of '
                                 'functions')
        # Resolve dotted import strings (e.g. 'numpy.mean') to callables.
        for idx, f in enumerate(functions):
            if isinstance(f, str):
                try:
                    f_mod, f_func = f.rsplit('.', 1)
                    functions[idx] = getattr(import_module(f_mod), f_func)
                except (ValueError, ImportError, AttributeError) as e:
                    # Narrow exception set (was a bare except) and chain the
                    # original cause for easier debugging.
                    raise ValueError(f"{f} is not a valid function") from e
        if var_names is None:
            # Default aliases: each function's own __name__.
            var_names = [f.__name__ for f in functions]
        self.var_names = var_names
        self.functions = functions
        self.kwargs = kwargs
        self.subset_idx = subset_idx
        super(MetricExtractor, self).__init__()

    def _extract(self, stim):
        # Returns one row of metrics (one column per function/var_name).
        outputs = []
        if self.subset_idx is not None:
            idx_diff = set(self.subset_idx) - set(stim.data.index)
            # Use an order-preserving list rather than a set intersection:
            # indexing a Series with a set is deprecated in pandas and a
            # set would yield arbitrary label order.
            idx_int = [i for i in self.subset_idx if i in stim.data.index]
            if idx_diff:
                logging.warning(f'{idx_diff} not in index.')
            if not idx_int:
                raise ValueError('No valid index')
            series = stim.data[idx_int]
        else:
            series = stim.data
        for f in self.functions:
            metrics = f(series, **self.kwargs)
            if isiterable(metrics):
                # Normalize iterable outputs to a numpy array so results
                # serialize consistently.
                metrics = np.array(metrics)
            outputs.append(metrics)
        return ExtractorResult([outputs], stim, self, self.var_names)


2 changes: 1 addition & 1 deletion pliers/extractors/text.py
Expand Up @@ -729,7 +729,7 @@ def _postprocess(self, stims, preds, tok, wds, ons, dur):
sub_idx = out_idx
out_idx = [idx for idx in out_idx if idx in sub_idx]
feat = self.tokenizer.convert_ids_to_tokens(out_idx)
feat = [f.capitalize() for f in feat]
feat = [f.capitalize() if len(f)==len(f.encode()) else f for f in feat]
data = [listify(p) for p in preds[0,self.mask_pos,out_idx]]
if self.return_masked_word:
feat, data = self._return_masked_word(preds, feat, data)
Expand Down
21 changes: 21 additions & 0 deletions pliers/tests/data/vector/vector_df.txt
@@ -0,0 +1,21 @@
label value confound
0 label0 0.6058707709809897 0
1 label1 0.11605248333572672 1
2 label2 0.632688166774749 2
3 label3 -0.7788613741732395 3
4 label4 -1.1927231656449686 4
5 label5 -1.5145981635335528 5
6 label6 0.8169464941251849 6
7 label7 0.4493482710553274 7
8 label8 0.8962510727747143 8
9 label9 0.07410810811610649 9
10 label10 -0.5231485242320049 10
11 label11 0.01839890045779546 11
12 label12 -0.30973596547014837 12
13 label13 -0.6225272742055447 13
14 label14 -0.19706489417523174 14
15 label15 -0.3855099705813389 15
16 label16 -1.7729569214611904 16
17 label17 1.1227098161292115 17
18 label18 -0.38865077349594085 18
19 label19 0.33001487459132467 19
73 changes: 73 additions & 0 deletions pliers/tests/extractors/test_misc_extractors.py
@@ -0,0 +1,73 @@
from pathlib import Path

import numpy as np
import pytest
import scipy
import scipy.stats

from pliers.extractors import (MetricExtractor, BertLMExtractor,
                               merge_results)
from pliers.stimuli import SeriesStim, ComplexTextStim
from pliers.tests.utils import get_test_data_path

def test_metric_extractor():
    """End-to-end checks for MetricExtractor: string and callable functions,
    in-memory and file-backed SeriesStims, index subsetting, and custom
    variable names."""

    def dummy(array):
        # Scalar-returning custom metric: first element.
        return array[0]

    def dummy_list(array):
        # Iterable-returning custom metric: first two elements.
        return array[0], array[1]

    f = Path(get_test_data_path(), 'text', 'test_lexical_dictionary.txt')
    stim = SeriesStim(data=np.linspace(1., 4., 20), onset=2., duration=.5)
    stim_file = SeriesStim(filename=f, column='frequency', sep='\t',
                           index_col='text')

    ext_single = MetricExtractor(functions='numpy.mean')
    ext_idx = MetricExtractor(functions='numpy.mean',
                              subset_idx=['for', 'testing', 'text'])
    ext_multiple = MetricExtractor(functions=['numpy.mean', 'numpy.min',
                                              scipy.stats.entropy, dummy,
                                              dummy_list])
    ext_names = MetricExtractor(functions=['numpy.mean', 'numpy.min',
                                           scipy.stats.entropy, dummy,
                                           dummy_list, 'tensorflow.reduce_mean'],
                                var_names=['mean', 'min', 'entropy',
                                           'custom1', 'custom2', 'tf_mean'])

    r = ext_single.transform(stim)
    r_file = ext_single.transform(stim_file)
    r_file_idx = ext_idx.transform(stim_file)
    r_multiple = ext_multiple.transform(stim)
    r_names = ext_names.transform(stim)

    r_df = r.to_df()
    r_file_df = r_file.to_df()
    r_file_idx_df = r_file_idx.to_df()
    r_multiple_df = r_multiple.to_df()
    r_long = r_multiple.to_df(format='long')
    r_names_df = r_names.to_df()

    # Wide format: one row per stim; long format: one row per metric.
    for res in [r_df, r_file_df, r_multiple_df]:
        assert res.shape[0] == 1
    assert r_long.shape[0] == len(ext_multiple.functions)
    assert r_df['onset'][0] == 2
    assert r_df['duration'][0] == .5
    assert r_df['mean'][0] == 2.5
    assert np.isclose(r_file_df['mean'][0], 11.388, rtol=0.001)
    assert np.isclose(r_file_idx_df['mean'][0], 12.582, rtol=0.001)
    assert all([m in r_multiple_df.columns for m in ['mean', 'entropy']])
    assert r_multiple_df['amin'][0] == 1.
    assert r_multiple_df['dummy'][0] == 1.
    assert r_multiple_df['dummy_list'][0][0] == np.linspace(1., 4., 20)[0]
    assert r_multiple_df['dummy_list'][0][1] == np.linspace(1., 4., 20)[1]
    # isinstance is the idiomatic type check (was: type(...) == np.ndarray).
    assert isinstance(r_multiple_df['dummy_list'][0], np.ndarray)
    assert r_names_df.columns[-3] == 'custom1'
    assert r_names_df.columns[-2] == 'custom2'
    assert r_names_df.columns[-1] == 'tf_mean'
    assert np.isclose(r_names_df['mean'][0], r_names_df['tf_mean'][0])

def test_metric_er_as_stim():
    """MetricExtractor should accept another extractor's output (converted
    to a SeriesStim) as input."""
    stim = ComplexTextStim(text='This is [MASK] test')
    ext_bert = BertLMExtractor(return_softmax=True)
    ext_metric = MetricExtractor(functions='numpy.sum')
    r = ext_metric.transform(ext_bert.transform(stim))
    df = merge_results(r, extractor_names=False)
    # Softmaxed token probabilities should sum to ~1.
    assert np.isclose(df['sum'][0], 1)
7 changes: 4 additions & 3 deletions pliers/tests/extractors/test_text_extractors.py
Expand Up @@ -2,12 +2,10 @@
from pathlib import Path
from os import environ
import shutil

import numpy as np
import pytest
import spacy
from transformers import BertTokenizer

from pliers import config
from pliers.extractors import (DictionaryExtractor,
PartOfSpeechExtractor,
Expand Down Expand Up @@ -446,7 +444,7 @@ def test_bert_LM_extractor():
assert res_topn.shape[1] == 104
assert all([res_topn.iloc[:,3][0] > res_topn.iloc[:,i][0] for i in range(4,103)])

# Check threshold and return_softmax
# Check threshold and range
tknz = BertTokenizer.from_pretrained('bert-base-uncased')
vocab = tknz.vocab.keys()
for v in vocab:
Expand All @@ -473,6 +471,9 @@ def test_bert_LM_extractor():
assert 'true_word_score' in res_return_mask.columns
assert res_return_mask['sequence'][0] == 'This is not a tokenized sentence .'

# Make sure no non-ascii tokens are dropped
assert res.shape[1] == len(vocab) + 4

# remove variables
del ext_target, res, res_file, res_target, res_topn, \
res_threshold, res_default, res_return_mask
Expand Down
1 change: 1 addition & 0 deletions pliers/tests/test_stims.py
Expand Up @@ -303,6 +303,7 @@ def test_save():
text_stim = TextStim(text='hello')
audio_stim = AudioStim(join(get_test_data_path(), 'audio', 'crowd.mp3'))
image_stim = ImageStim(join(get_test_data_path(), 'image', 'apple.jpg'))

# Video gives travis problems
stims = [complextext_stim, text_stim, audio_stim, image_stim]
for s in stims:
Expand Down
3 changes: 2 additions & 1 deletion pliers/utils/base.py
Expand Up @@ -7,6 +7,7 @@
from itertools import islice

from tqdm import tqdm
import pandas as pd

from pliers import config
from pliers.support.exceptions import MissingDependencyError
Expand Down Expand Up @@ -79,7 +80,7 @@ def __get__(self, owner_self, owner_cls):

def isiterable(obj):
    ''' Returns True if the object is one of allowable iterable types
    (list, tuple, pandas Series, generator, or tqdm progress wrapper). '''
    # Note: the scraped diff showed both the old and new return lines;
    # only the updated version (including pd.Series) is kept, since a
    # duplicated first return would shadow it.
    return isinstance(obj, (list, tuple, pd.Series, GeneratorType, tqdm))


def isgenerator(obj):
Expand Down

0 comments on commit a2a9994

Please sign in to comment.