Skip to content

Commit

Permalink
Merge pull request #43 from rundherum/refactor-selectivedataextractor
Browse files Browse the repository at this point in the history
refactor(SelectiveDataExtractor): remove dependency on NamesExtractor…
  • Loading branch information
fabianbalsiger committed Oct 11, 2021
2 parents c87cc56 + 77a1622 commit 081dbe0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
1 change: 1 addition & 0 deletions pymia/data/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def subject_index_to_str(subject_index, nb_subjects):
KEY_SUBJECT_FILES = 'subject_files' #:
KEY_CATEGORIES = 'categories' #:
KEY_PLACEHOLDER_NAMES = '{}_names' #:
KEY_PLACEHOLDER_NAMES_SELECTED = '{}_names_selected' #:
KEY_PLACEHOLDER_PROPERTIES = '{}_properties' #:
KEY_PLACEHOLDER_FILES = '{}_files' #:
KEY_FILE_ROOT = 'file_root' #:
Expand Down
18 changes: 11 additions & 7 deletions pymia/data/extraction/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,15 +208,14 @@ class SelectiveDataExtractor(Extractor):
def __init__(self, selection=None, category: str = defs.KEY_LABELS) -> None:
"""Extracts data of a given category selectively.
Adds :obj:`category` as key to :obj:`extracted`.
Adds :obj:`category` as key to :obj:`extracted`, as well as
- :const:`pymia.data.definition.KEY_PLACEHOLDER_NAMES_SELECTED` with :obj:`selection` content
Args:
selection (str, tuple): Entries (e.g., "T1", "T2") within the category to select.
If selection is None, the class has the same behaviour as the DataExtractor and selects all entries.
category (str): The category (e.g. "images") to extract data from.
Note:
Requires results of :class:`NamesExtractor` in :obj:`extracted`.
"""
super().__init__()
self.subject_entries = None
Expand All @@ -226,10 +225,14 @@ def __init__(self, selection=None, category: str = defs.KEY_LABELS) -> None:
self.selection = selection
self.category = category

self.names_extractor = None # used in case that the names of the entries of the category are not extracted

def extract(self, reader: rd.Reader, params: dict, extracted: dict) -> None:
"""see :meth:`.Extractor.extract`"""
if defs.KEY_PLACEHOLDER_NAMES.format(self.category) not in extracted:
raise ValueError('selection of labels requires label_names to be extracted (use NamesExtractor)')
if self.names_extractor is None:
self.names_extractor = NamesExtractor(cache=True, categories=(self.category, ))
self.names_extractor.extract(reader, {}, extracted)

if self.subject_entries is None:
self.subject_entries = reader.get_subject_entries()
Expand All @@ -242,13 +245,14 @@ def extract(self, reader: rd.Reader, params: dict, extracted: dict) -> None:

index_str = self.subject_entries[subject_index]
data = reader.read('{}/{}'.format(defs.LOC_DATA_PLACEHOLDER.format(self.category), index_str), index_expr)
label_names = extracted[defs.KEY_PLACEHOLDER_NAMES.format(self.category)] # type: list
entry_names = extracted[defs.KEY_PLACEHOLDER_NAMES.format(self.category)] # type: list

if self.selection is None:
extracted[self.category] = data
else:
selection_indices = np.array([label_names.index(s) for s in self.selection])
selection_indices = np.array([entry_names.index(s) for s in self.selection])
extracted[self.category] = np.take(data, selection_indices, axis=-1)
extracted[defs.KEY_PLACEHOLDER_NAMES_SELECTED.format(self.category)] = list(self.selection)


class RandomDataExtractor(Extractor):
Expand Down

0 comments on commit 081dbe0

Please sign in to comment.