Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #390 from rbroc/prob_extractor
add MetricExtractor
- Loading branch information
Showing
8 changed files
with
179 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
''' | ||
Extractors that operate on Miscellaneous Stims. | ||
''' | ||
|
||
from pliers.stimuli.misc import SeriesStim | ||
from pliers.extractors.base import Extractor, ExtractorResult | ||
from pliers.utils import listify, isiterable | ||
import scipy | ||
import numpy as np | ||
import pandas as pd | ||
from importlib import import_module | ||
import logging | ||
|
||
class MetricExtractor(Extractor): | ||
''' Extracts summary metrics from 1D-array using numpy, scipy or custom | ||
functions | ||
Args: | ||
functions (str, functions or list): function or string referring to absolute | ||
import path for a function (e.g. 'numpy.mean'). Function must operate | ||
on 1-dimensional numpy arrays and return a scalar. A list of | ||
functions or import strings may also be passed. | ||
var_names (list): optional list of custom alias names for each metric | ||
subset_idx (list): subset of Series index labels to compute metric on. | ||
kwargs: named arguments for function call | ||
''' | ||
|
||
_input_type = SeriesStim | ||
_log_attributes = ('functions', 'subset_idx') | ||
|
||
def __init__(self, functions=None, var_names=None, | ||
subset_idx=None, **kwargs): | ||
functions = listify(functions) | ||
if var_names is not None: | ||
var_names = listify(var_names) | ||
if len(var_names) != len(functions): | ||
raise ValueError('Length or var_names must match number of ' | ||
'functions') | ||
for idx, f in enumerate(functions): | ||
if isinstance(f, str): | ||
try: | ||
f_mod, f_func = f.rsplit('.', 1) | ||
functions[idx] = getattr(import_module(f_mod), | ||
f_func) | ||
except: | ||
raise ValueError(f"{f} is not a valid function") | ||
if var_names is None: | ||
var_names = [f.__name__ for f in functions] | ||
self.var_names = var_names | ||
|
||
self.functions = functions | ||
self.kwargs = kwargs | ||
self.subset_idx = subset_idx | ||
super(MetricExtractor, self).__init__() | ||
|
||
def _extract(self, stim): | ||
outputs = [] | ||
if self.subset_idx is not None: | ||
idx_diff = set(self.subset_idx) - set(stim.data.index) | ||
idx_int = set(self.subset_idx) & set(stim.data.index) | ||
if idx_diff: | ||
logging.warning(f'{idx_diff} not in index.') | ||
if not idx_int: | ||
raise ValueError('No valid index') | ||
series = stim.data[idx_int] | ||
else: | ||
series = stim.data | ||
for f in self.functions: | ||
metrics = f(series, **self.kwargs) | ||
if isiterable(metrics): | ||
metrics = np.array(metrics) | ||
outputs.append(metrics) | ||
return ExtractorResult([outputs], stim, self, self.var_names) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
label value confound | ||
0 label0 0.6058707709809897 0 | ||
1 label1 0.11605248333572672 1 | ||
2 label2 0.632688166774749 2 | ||
3 label3 -0.7788613741732395 3 | ||
4 label4 -1.1927231656449686 4 | ||
5 label5 -1.5145981635335528 5 | ||
6 label6 0.8169464941251849 6 | ||
7 label7 0.4493482710553274 7 | ||
8 label8 0.8962510727747143 8 | ||
9 label9 0.07410810811610649 9 | ||
10 label10 -0.5231485242320049 10 | ||
11 label11 0.01839890045779546 11 | ||
12 label12 -0.30973596547014837 12 | ||
13 label13 -0.6225272742055447 13 | ||
14 label14 -0.19706489417523174 14 | ||
15 label15 -0.3855099705813389 15 | ||
16 label16 -1.7729569214611904 16 | ||
17 label17 1.1227098161292115 17 | ||
18 label18 -0.38865077349594085 18 | ||
19 label19 0.33001487459132467 19 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from pliers.extractors import (MetricExtractor, BertLMExtractor, | ||
merge_results) | ||
from pliers.stimuli import SeriesStim, ComplexTextStim | ||
from pliers.tests.utils import get_test_data_path | ||
import numpy as np | ||
import scipy | ||
from pathlib import Path | ||
import pytest | ||
|
||
def test_metric_extractor(): | ||
|
||
def dummy(array): | ||
return array[0] | ||
|
||
def dummy_list(array): | ||
return array[0], array[1] | ||
|
||
f = Path(get_test_data_path(), 'text', 'test_lexical_dictionary.txt') | ||
stim = SeriesStim(data=np.linspace(1., 4., 20), onset=2., duration=.5) | ||
stim_file = SeriesStim(filename=f, column='frequency', sep='\t', | ||
index_col='text') | ||
|
||
ext_single = MetricExtractor(functions='numpy.mean') | ||
ext_idx = MetricExtractor(functions='numpy.mean', | ||
subset_idx=['for', 'testing', 'text']) | ||
ext_multiple = MetricExtractor(functions=['numpy.mean', 'numpy.min', | ||
scipy.stats.entropy, dummy, | ||
dummy_list]) | ||
ext_names = MetricExtractor(functions=['numpy.mean', 'numpy.min', | ||
scipy.stats.entropy, dummy, | ||
dummy_list, 'tensorflow.reduce_mean'], | ||
var_names=['mean', 'min', 'entropy', | ||
'custom1', 'custom2', 'tf_mean']) | ||
|
||
r = ext_single.transform(stim) | ||
r_file = ext_single.transform(stim_file) | ||
r_file_idx = ext_idx.transform(stim_file) | ||
r_multiple = ext_multiple.transform(stim) | ||
r_names = ext_names.transform(stim) | ||
|
||
r_df = r.to_df() | ||
r_file_df = r_file.to_df() | ||
r_file_idx_df = r_file_idx.to_df() | ||
r_multiple_df = r_multiple.to_df() | ||
r_long = r_multiple.to_df(format='long') | ||
r_names_df = r_names.to_df() | ||
|
||
for res in [r_df, r_file_df, r_multiple_df]: | ||
assert res.shape[0] == 1 | ||
assert r_long.shape[0] == len(ext_multiple.functions) | ||
assert r_df['onset'][0] == 2 | ||
assert r_df['duration'][0] == .5 | ||
assert r_df['mean'][0] == 2.5 | ||
assert np.isclose(r_file_df['mean'][0], 11.388, rtol=0.001) | ||
assert np.isclose(r_file_idx_df['mean'][0], 12.582, rtol=0.001) | ||
assert all([m in r_multiple_df.columns for m in ['mean', 'entropy']]) | ||
assert r_multiple_df['amin'][0] == 1. | ||
assert r_multiple_df['dummy'][0] == 1. | ||
assert r_multiple_df['dummy_list'][0][0] == np.linspace(1., 4., 20)[0] | ||
assert r_multiple_df['dummy_list'][0][1] == np.linspace(1., 4., 20)[1] | ||
assert type(r_multiple_df['dummy_list'][0]) == np.ndarray | ||
assert r_names_df.columns[-3] == 'custom1' | ||
assert r_names_df.columns[-2] == 'custom2' | ||
assert r_names_df.columns[-1] == 'tf_mean' | ||
assert np.isclose(r_names_df['mean'][0], r_names_df['tf_mean'][0]) | ||
|
||
def test_metric_er_as_stim(): | ||
stim = ComplexTextStim(text = 'This is [MASK] test') | ||
ext_bert = BertLMExtractor(return_softmax=True) | ||
ext_metric = MetricExtractor(functions='numpy.sum') | ||
r = ext_metric.transform(ext_bert.transform(stim)) | ||
df = merge_results(r, extractor_names=False) | ||
assert np.isclose(df['sum'][0], 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters