diff --git a/eegdash/api.py b/eegdash/api.py index ca98d3e6..934eb67c 100644 --- a/eegdash/api.py +++ b/eegdash/api.py @@ -212,17 +212,22 @@ def exist(self, query: dict[str, Any]) -> bool: return doc is not None def _validate_input(self, record: dict[str, Any]) -> dict[str, Any]: - """Internal method to validate the input record against the expected schema. + """Validate the input record against the expected schema. Parameters ---------- - record: dict + record : dict A dictionary representing the EEG data record to be validated. Returns ------- - dict: - Returns the record itself on success, or raises a ValueError if the record is invalid. + dict + The record itself on success. + + Raises + ------ + ValueError + If the record is missing required keys or has values of the wrong type. """ input_types = { @@ -252,20 +257,44 @@ def _validate_input(self, record: dict[str, Any]) -> dict[str, Any]: return record def _build_query_from_kwargs(self, **kwargs) -> dict[str, Any]: - """Internal helper to build a validated MongoDB query from keyword args. + """Build a validated MongoDB query from keyword arguments. + + This delegates to the module-level builder used across the package. + + Parameters + ---------- + **kwargs + Keyword arguments to convert into a MongoDB query. + + Returns + ------- + dict + A MongoDB query dictionary. - This delegates to the module-level builder used across the package and - is exposed here for testing and convenience. """ return build_query_from_kwargs(**kwargs) - # --- Query merging and conflict detection helpers --- - def _extract_simple_constraint(self, query: dict[str, Any], key: str): + def _extract_simple_constraint( + self, query: dict[str, Any], key: str + ) -> tuple[str, Any] | None: """Extract a simple constraint for a given key from a query dict. - Supports only top-level equality (key: value) and $in (key: {"$in": [...]}) - constraints. Returns a tuple (kind, value) where kind is "eq" or "in". If the - key is not present or uses other operators, returns None. + Supports top-level equality (e.g., ``{'subject': '01'}``) and ``$in`` + (e.g., ``{'subject': {'$in': ['01', '02']}}``) constraints. + + Parameters + ---------- + query : dict + The MongoDB query dictionary. + key : str + The key for which to extract the constraint. + + Returns + ------- + tuple or None + A tuple of (kind, value) where kind is "eq" or "in", or None if the + constraint is not present or unsupported. + """ if not isinstance(query, dict) or key not in query: return None @@ -275,16 +304,28 @@ def _extract_simple_constraint(self, query: dict[str, Any], key: str): return ("in", list(val["$in"])) return None # unsupported operator shape for conflict checking else: - return ("eq", val) + return "eq", val def _raise_if_conflicting_constraints( self, raw_query: dict[str, Any], kwargs_query: dict[str, Any] ) -> None: - """Raise ValueError if both query sources define incompatible constraints. + """Raise ValueError if query sources have incompatible constraints. + + Checks for mutually exclusive constraints on the same field to avoid + silent empty results. + + Parameters + ---------- + raw_query : dict + The raw MongoDB query dictionary. + kwargs_query : dict + The query dictionary built from keyword arguments. + + Raises + ------ + ValueError + If conflicting constraints are found. - We conservatively check only top-level fields with simple equality or $in - constraints. 
If a field appears in both queries and constraints are mutually - exclusive, raise an explicit error to avoid silent empty result sets. """ if not raw_query or not kwargs_query: return @@ -388,12 +429,31 @@ def add_bids_dataset( logger.info("Upserted: %s", result.upserted_count) logger.info("Errors: %s ", result.bulk_api_result.get("writeErrors", [])) - def _add_request(self, record: dict): - """Internal helper method to create a MongoDB insertion request for a record.""" + def _add_request(self, record: dict) -> InsertOne: + """Create a MongoDB insertion request for a record. + + Parameters + ---------- + record : dict + The record to insert. + + Returns + ------- + InsertOne + A PyMongo ``InsertOne`` object. + + """ return InsertOne(record) - def add(self, record: dict): - """Add a single record to the MongoDB collection.""" + def add(self, record: dict) -> None: + """Add a single record to the MongoDB collection. + + Parameters + ---------- + record : dict + The record to add. + + """ try: self.__collection.insert_one(record) except ValueError as e: @@ -405,11 +465,23 @@ def add(self, record: dict): ) logger.debug("Add operation failed", exc_info=exc) - def _update_request(self, record: dict): - """Internal helper method to create a MongoDB update request for a record.""" + def _update_request(self, record: dict) -> UpdateOne: + """Create a MongoDB update request for a record. + + Parameters + ---------- + record : dict + The record to update. + + Returns + ------- + UpdateOne + A PyMongo ``UpdateOne`` object. + + """ return UpdateOne({"data_name": record["data_name"]}, {"$set": record}) - def update(self, record: dict): + def update(self, record: dict) -> None: """Update a single record in the MongoDB collection. Parameters @@ -429,58 +501,81 @@ def update(self, record: dict): logger.debug("Update operation failed", exc_info=exc) def exists(self, query: dict[str, Any]) -> bool: - """Alias for :meth:`exist` provided for API clarity.""" + """Check if at least one record matches the query. + + This is an alias for :meth:`exist`. + + Parameters + ---------- + query : dict + MongoDB query to check for existence. + + Returns + ------- + bool + True if a matching record exists, False otherwise. + + """ return self.exist(query) - def remove_field(self, record, field): - """Remove a specific field from a record in the MongoDB collection. + def remove_field(self, record: dict, field: str) -> None: + """Remove a field from a specific record in the MongoDB collection. Parameters ---------- record : dict - Record identifying object with ``data_name``. + Record-identifying object with a ``data_name`` key. field : str - Field name to remove. + The name of the field to remove. """ self.__collection.update_one( {"data_name": record["data_name"]}, {"$unset": {field: 1}} ) - def remove_field_from_db(self, field): - """Remove a field from all records (destructive). + def remove_field_from_db(self, field: str) -> None: + """Remove a field from all records in the database. + + .. warning:: + This is a destructive operation and cannot be undone. Parameters ---------- field : str - Field name to remove from every document. + The name of the field to remove from all documents. """ self.__collection.update_many({}, {"$unset": {field: 1}}) @property def collection(self): - """Return the MongoDB collection object.""" - return self.__collection + """The underlying PyMongo ``Collection`` object. - def close(self): - """Backward-compatibility no-op; connections are managed globally. 
+ Returns + ------- + pymongo.collection.Collection + The collection object used for database interactions. - Notes - ----- - Connections are managed by :class:`MongoConnectionManager`. Use - :meth:`close_all_connections` to explicitly close all clients. + """ + return self.__collection + def close(self) -> None: + """Close the MongoDB connection. + + .. deprecated:: 0.1 + Connections are now managed globally by :class:`MongoConnectionManager`. + This method is a no-op and will be removed in a future version. + Use :meth:`EEGDash.close_all_connections` to close all clients. """ # Individual instances no longer close the shared client pass @classmethod - def close_all_connections(cls): - """Close all MongoDB client connections managed by the singleton.""" + def close_all_connections(cls) -> None: + """Close all MongoDB client connections managed by the singleton manager.""" MongoConnectionManager.close_all() - def __del__(self): + def __del__(self) -> None: """Destructor; no explicit action needed due to global connection manager.""" # No longer needed since we're using singleton pattern pass @@ -775,45 +870,30 @@ def _find_local_bids_records( ) -> list[dict]: """Discover local BIDS EEG files and build minimal records. - This helper enumerates EEG recordings under ``dataset_root`` via - ``mne_bids.find_matching_paths`` and applies entity filters to produce a - list of records suitable for ``EEGDashBaseDataset``. No network access - is performed and files are not read. + Enumerates EEG recordings under ``dataset_root`` using + ``mne_bids.find_matching_paths`` and applies entity filters to produce + records suitable for :class:`EEGDashBaseDataset`. No network access is + performed, and files are not read. Parameters ---------- dataset_root : Path - Local dataset directory. May be the plain dataset folder (e.g., - ``ds005509``) or a suffixed cache variant (e.g., - ``ds005509-bdf-mini``). - filters : dict of {str, Any} - Query filters. Must include ``'dataset'`` with the dataset id (without - local suffixes). May include BIDS entities ``'subject'``, - ``'session'``, ``'task'``, and ``'run'``. Each value can be a scalar - or a sequence of scalars. + Local dataset directory (e.g., ``/path/to/cache/ds005509``). + filters : dict + Query filters. Must include ``'dataset'`` and may include BIDS + entities like ``'subject'``, ``'session'``, etc. Returns ------- - records : list of dict - One record per matched EEG file with at least: - - - ``'data_name'`` - - ``'dataset'`` (dataset id, without suffixes) - - ``'bidspath'`` (normalized to start with the dataset id) - - ``'subject'``, ``'session'``, ``'task'``, ``'run'`` (may be None) - - ``'bidsdependencies'`` (empty list) - - ``'modality'`` (``"eeg"``) - - ``'sampling_frequency'``, ``'nchans'``, ``'ntimes'`` (minimal - defaults for offline usage) + list of dict + A list of records, one for each matched EEG file. Each record + contains BIDS entities, paths, and minimal metadata for offline use. Notes ----- - - Matching uses ``datatypes=['eeg']`` and ``suffixes=['eeg']``. - - ``bidspath`` is constructed as - `` / `` to ensure the - first path component is the dataset id (without local cache suffixes). - - Minimal defaults are set for ``sampling_frequency``, ``nchans``, and - ``ntimes`` to satisfy dataset length requirements offline. + Matching is performed for ``datatypes=['eeg']`` and ``suffixes=['eeg']``. + The ``bidspath`` is normalized to ensure it starts with the dataset ID, + even for suffixed cache directories. 
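+
+        Examples
+        --------
+        A minimal, hypothetical ``filters`` mapping (the dataset ID and
+        entity values below are illustrative only):
+
+        >>> filters = {"dataset": "ds005509", "subject": ["01", "02"], "task": "RestingState"}
+        >>> # passed together with the local dataset root, e.g. Path(cache_dir) / "ds005509"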
""" dataset_id = filters["dataset"] @@ -875,10 +955,22 @@ def _find_local_bids_records( return records_out def _find_key_in_nested_dict(self, data: Any, target_key: str) -> Any: - """Recursively search for target_key in nested dicts/lists with normalized matching. + """Recursively search for a key in nested dicts/lists. + + Performs a case-insensitive and underscore/hyphen-agnostic search. + + Parameters + ---------- + data : Any + The nested data structure (dicts, lists) to search. + target_key : str + The key to search for. + + Returns + ------- + Any + The value of the first matching key, or None if not found. - This makes lookups tolerant to naming differences like "p-factor" vs "p_factor". - Returns the first match or None. """ norm_target = normalize_key(target_key) if isinstance(data, dict): @@ -901,23 +993,25 @@ def _find_datasets( description_fields: list[str], base_dataset_kwargs: dict, ) -> list[EEGDashBaseDataset]: - """Helper method to find datasets in the MongoDB collection that satisfy the - given query and return them as a list of EEGDashBaseDataset objects. + """Find and construct datasets from a MongoDB query. + + Queries the database, then creates a list of + :class:`EEGDashBaseDataset` objects from the results. Parameters ---------- - query : dict - The query object, as in EEGDash.find(). - description_fields : list[str] - A list of fields to be extracted from the dataset records and included in - the returned dataset description(s). - kwargs: additional keyword arguments to be passed to the EEGDashBaseDataset - constructor. + query : dict, optional + The MongoDB query to execute. + description_fields : list of str + Fields to extract from each record for the dataset description. + base_dataset_kwargs : dict + Additional keyword arguments to pass to the + :class:`EEGDashBaseDataset` constructor. Returns ------- - list : - A list of EEGDashBaseDataset objects that match the query. + list of EEGDashBaseDataset + A list of dataset objects matching the query. """ datasets: list[EEGDashBaseDataset] = [] diff --git a/eegdash/bids_eeg_metadata.py b/eegdash/bids_eeg_metadata.py index 150aed34..bc8648fa 100644 --- a/eegdash/bids_eeg_metadata.py +++ b/eegdash/bids_eeg_metadata.py @@ -33,12 +33,30 @@ def build_query_from_kwargs(**kwargs) -> dict[str, Any]: - """Build and validate a MongoDB query from user-friendly keyword arguments. + """Build and validate a MongoDB query from keyword arguments. + + This function converts user-friendly keyword arguments into a valid + MongoDB query dictionary. It handles scalar values as exact matches and + list-like values as ``$in`` queries. It also performs validation to + reject unsupported fields and empty values. + + Parameters + ---------- + **kwargs + Keyword arguments representing query filters. Allowed keys are defined + in ``eegdash.const.ALLOWED_QUERY_FIELDS``. + + Returns + ------- + dict + A MongoDB query dictionary. + + Raises + ------ + ValueError + If an unsupported query field is provided, or if a value is None or + an empty string/list. - Improvements: - - Reject None values and empty/whitespace-only strings - - For list/tuple/set values: strip strings, drop None/empties, deduplicate, and use `$in` - - Preserve scalars as exact matches """ # 1. 
Validate that all provided keys are allowed for querying unknown_fields = set(kwargs.keys()) - ALLOWED_QUERY_FIELDS @@ -89,24 +107,29 @@ def build_query_from_kwargs(**kwargs) -> dict[str, Any]: def load_eeg_attrs_from_bids_file(bids_dataset, bids_file: str) -> dict[str, Any]: - """Build the metadata record for a given BIDS file (single recording) in a BIDS dataset. + """Build a metadata record for a BIDS file. - Attributes are at least the ones defined in data_config attributes (set to None if missing), - but are typically a superset, and include, among others, the paths to relevant - meta-data files needed to load and interpret the file in question. + Extracts metadata attributes from a single BIDS EEG file within a given + BIDS dataset. The extracted attributes include BIDS entities, file paths, + and technical metadata required for database indexing. Parameters ---------- bids_dataset : EEGBIDSDataset The BIDS dataset object containing the file. bids_file : str - The path to the BIDS file within the dataset. + The path to the BIDS file to process. Returns ------- - dict: - A dictionary representing the metadata record for the given file. This is the - same format as the records stored in the database. + dict + A dictionary of metadata attributes for the file, suitable for + insertion into the database. + + Raises + ------ + ValueError + If ``bids_file`` is not found in the ``bids_dataset``. """ if bids_file not in bids_dataset.files: @@ -198,11 +221,23 @@ def load_eeg_attrs_from_bids_file(bids_dataset, bids_file: str) -> dict[str, Any def normalize_key(key: str) -> str: - """Normalize a metadata key for robust matching. + """Normalize a string key for robust matching. + + Converts the key to lowercase, replaces non-alphanumeric characters with + underscores, and removes leading/trailing underscores. This allows for + tolerant matching of keys that may have different capitalization or + separators (e.g., "p-factor" becomes "p_factor"). + + Parameters + ---------- + key : str + The key to normalize. + + Returns + ------- + str + The normalized key. - Lowercase and replace non-alphanumeric characters with underscores, then strip - leading/trailing underscores. This allows tolerant matching such as - "p-factor" ≈ "p_factor" ≈ "P Factor". """ return re.sub(r"[^a-z0-9]+", "_", str(key).lower()).strip("_") @@ -212,27 +247,27 @@ def merge_participants_fields( participants_row: dict[str, Any] | None, description_fields: list[str] | None = None, ) -> dict[str, Any]: - """Merge participants.tsv fields into a dataset description dictionary. + """Merge fields from a participants.tsv row into a description dict. - - Preserves existing entries in ``description`` (no overwrites). - - Fills requested ``description_fields`` first, preserving their original names. - - Adds all remaining participants columns generically using normalized keys - unless a matching requested field already captured them. + Enriches a description dictionary with data from a subject's row in + ``participants.tsv``. It avoids overwriting existing keys in the + description. Parameters ---------- description : dict - Current description to be enriched in-place and returned. - participants_row : dict | None - A mapping of participants.tsv columns for the current subject. - description_fields : list[str] | None - Optional list of requested description fields. When provided, matching is - performed by normalized names; the original requested field names are kept. + The description dictionary to enrich. 
+ participants_row : dict or None + A dictionary representing a row from ``participants.tsv``. If None, + the original description is returned unchanged. + description_fields : list of str, optional + A list of specific fields to include in the description. Matching is + done using normalized keys. Returns ------- dict - The enriched description (same object as input for convenience). + The enriched description dictionary. """ if not isinstance(description, dict) or not isinstance(participants_row, dict): @@ -272,10 +307,26 @@ def participants_row_for_subject( subject: str, id_columns: tuple[str, ...] = ("participant_id", "participant", "subject"), ) -> pd.Series | None: - """Load participants.tsv and return the row for a subject. + """Load participants.tsv and return the row for a specific subject. + + Searches for a subject's data in the ``participants.tsv`` file within a + BIDS dataset. It can identify the subject with or without the "sub-" + prefix. + + Parameters + ---------- + bids_root : str or Path + The root directory of the BIDS dataset. + subject : str + The subject identifier (e.g., "01" or "sub-01"). + id_columns : tuple of str, default ("participant_id", "participant", "subject") + A tuple of column names to search for the subject identifier. + + Returns + ------- + pandas.Series or None + A pandas Series containing the subject's data if found, otherwise None. - - Accepts either "01" or "sub-01" as the subject identifier. - - Returns a pandas Series for the first matching row, or None if not found. """ try: participants_tsv = Path(bids_root) / "participants.tsv" @@ -311,9 +362,28 @@ def participants_extras_from_tsv( id_columns: tuple[str, ...] = ("participant_id", "participant", "subject"), na_like: tuple[str, ...] = ("", "n/a", "na", "nan", "unknown", "none"), ) -> dict[str, Any]: - """Return non-identifier, non-empty participants.tsv fields for a subject. + """Extract additional participant information from participants.tsv. + + Retrieves all non-identifier and non-empty fields for a subject from + the ``participants.tsv`` file. + + Parameters + ---------- + bids_root : str or Path + The root directory of the BIDS dataset. + subject : str + The subject identifier. + id_columns : tuple of str, default ("participant_id", "participant", "subject") + Column names to be treated as identifiers and excluded from the + output. + na_like : tuple of str, default ("", "n/a", "na", "nan", "unknown", "none") + Values to be considered as "Not Available" and excluded. + + Returns + ------- + dict + A dictionary of extra participant information. - Uses vectorized pandas operations to drop id columns and NA-like values. """ row = participants_row_for_subject(bids_root, subject, id_columns=id_columns) if row is None: @@ -331,10 +401,21 @@ def attach_participants_extras( description: Any, extras: dict[str, Any], ) -> None: - """Attach extras to Raw.info and dataset description without overwriting. + """Attach extra participant data to a raw object and its description. + + Updates the ``raw.info['subject_info']`` and the description object + (dict or pandas Series) with extra data from ``participants.tsv``. + It does not overwrite existing keys. + + Parameters + ---------- + raw : mne.io.Raw + The MNE Raw object to be updated. + description : dict or pandas.Series + The description object to be updated. + extras : dict + A dictionary of extra participant information to attach. - - Adds to ``raw.info['subject_info']['participants_extras']``. 
- - Adds to ``description`` if dict or pandas Series (only missing keys). """ if not extras: return @@ -375,9 +456,28 @@ def enrich_from_participants( raw: Any, description: Any, ) -> dict[str, Any]: - """Convenience wrapper: read participants.tsv and attach extras for this subject. + """Read participants.tsv and attach extra info for the subject. + + This is a convenience function that finds the subject from the + ``bidspath``, retrieves extra information from ``participants.tsv``, + and attaches it to the raw object and its description. + + Parameters + ---------- + bids_root : str or Path + The root directory of the BIDS dataset. + bidspath : mne_bids.BIDSPath + The BIDSPath object for the current data file. + raw : mne.io.Raw + The MNE Raw object to be updated. + description : dict or pandas.Series + The description object to be updated. + + Returns + ------- + dict + The dictionary of extras that were attached. - Returns the extras dictionary for further use if needed. """ subject = getattr(bidspath, "subject", None) if not subject: diff --git a/eegdash/const.py b/eegdash/const.py index 897575fe..447c89d9 100644 --- a/eegdash/const.py +++ b/eegdash/const.py @@ -28,6 +28,8 @@ "nchans", "ntimes", } +"""set: A set of field names that are permitted in database queries constructed +via :func:`~eegdash.api.EEGDash.find` with keyword arguments.""" RELEASE_TO_OPENNEURO_DATASET_MAP = { "R11": "ds005516", @@ -42,6 +44,8 @@ "R2": "ds005506", "R1": "ds005505", } +"""dict: A mapping from Healthy Brain Network (HBN) release identifiers (e.g., "R11") +to their corresponding OpenNeuro dataset identifiers (e.g., "ds005516").""" SUBJECT_MINI_RELEASE_MAP = { "R11": [ @@ -287,6 +291,9 @@ "NDARFW972KFQ", ], } +"""dict: A mapping from HBN release identifiers to a list of subject IDs. +This is used to select a small, representative subset of subjects for creating +"mini" datasets for testing and demonstration purposes.""" config = { "required_fields": ["data_name"], @@ -322,3 +329,21 @@ ], "accepted_query_fields": ["data_name", "dataset"], } +"""dict: A global configuration dictionary for the EEGDash package. + +Keys +---- +required_fields : list + Fields that must be present in every database record. +attributes : dict + A schema defining the expected primary attributes and their types for a + database record. +description_fields : list + A list of fields considered to be descriptive metadata for a recording, + which can be used for filtering and display. +bids_dependencies_files : list + A list of BIDS metadata filenames that are relevant for interpreting an + EEG recording. +accepted_query_fields : list + Fields that are accepted for lightweight existence checks in the database. +""" diff --git a/eegdash/data_utils.py b/eegdash/data_utils.py index 2cc18ca8..199d94d2 100644 --- a/eegdash/data_utils.py +++ b/eegdash/data_utils.py @@ -37,10 +37,26 @@ class EEGDashBaseDataset(BaseDataset): - """A single EEG recording hosted on AWS S3 and cached locally upon first access. + """A single EEG recording dataset. + + Represents a single EEG recording, typically hosted on a remote server (like AWS S3) + and cached locally upon first access. This class is a subclass of + :class:`braindecode.datasets.BaseDataset` and can be used with braindecode's + preprocessing and training pipelines. + + Parameters + ---------- + record : dict + A fully resolved metadata record for the data to load. + cache_dir : str + The local directory where the data will be cached. 
+ s3_bucket : str, optional + The S3 bucket to download data from. If not provided, defaults to the + OpenNeuro bucket. + **kwargs + Additional keyword arguments passed to the + :class:`braindecode.datasets.BaseDataset` constructor. - This is a subclass of braindecode's BaseDataset, which can consequently be used in - conjunction with the preprocessing and training pipelines of braindecode. """ _AWS_BUCKET = "s3://openneuro.org" @@ -52,20 +68,6 @@ def __init__( s3_bucket: str | None = None, **kwargs, ): - """Create a new EEGDashBaseDataset instance. Users do not usually need to call this - directly -- instead use the EEGDashDataset class to load a collection of these - recordings from a local BIDS folder or using a database query. - - Parameters - ---------- - record : dict - A fully resolved metadata record for the data to load. - cache_dir : str - A local directory where the data will be cached. - kwargs : dict - Additional keyword arguments to pass to the BaseDataset constructor. - - """ super().__init__(None, **kwargs) self.record = record self.cache_dir = Path(cache_dir) @@ -121,14 +123,12 @@ def __init__( self._raw = None def _get_raw_bids_args(self) -> dict[str, Any]: - """Helper to restrict the metadata record to the fields needed to locate a BIDS - recording. - """ + """Extract BIDS-related arguments from the metadata record.""" desired_fields = ["subject", "session", "task", "run"] return {k: self.record[k] for k in desired_fields if self.record[k]} def _ensure_raw(self) -> None: - """Download the S3 file and BIDS dependencies if not already cached.""" + """Ensure the raw data file and its dependencies are cached locally.""" # TO-DO: remove this once is fixed on the our side # for the competition if not self.s3_open_neuro: @@ -190,42 +190,53 @@ def __len__(self) -> int: return len(self._raw) @property - def raw(self): - """Return the MNE Raw object for this recording. This will perform the actual - retrieval if not yet done so. + def raw(self) -> BaseRaw: + """The MNE Raw object for this recording. + + Accessing this property triggers the download and caching of the data + if it has not been accessed before. + + Returns + ------- + mne.io.BaseRaw + The loaded MNE Raw object. + """ if self._raw is None: self._ensure_raw() return self._raw @raw.setter - def raw(self, raw): + def raw(self, raw: BaseRaw): self._raw = raw class EEGDashBaseRaw(BaseRaw): - """Wrapper around the MNE BaseRaw class that automatically fetches the data from S3 - (when _read_segment is called) and caches it locally. Currently for internal use. + """MNE BaseRaw wrapper for automatic S3 data fetching. + + This class extends :class:`mne.io.BaseRaw` to automatically fetch data + from an S3 bucket and cache it locally when data is first accessed. + It is intended for internal use within the EEGDash ecosystem. Parameters ---------- - input_fname : path-like - Path to the S3 file + input_fname : str + The path to the file on the S3 bucket (relative to the bucket root). metadata : dict - The metadata record for the recording (e.g., from the database). - preload : bool - Whether to pre-loaded the data before the first access. - cache_dir : str - Local path under which the data will be cached. - bids_dependencies : list - List of additional BIDS metadata files that should be downloaded and cached - alongside the main recording file. - verbose : str | int | None - Optionally the verbosity level for MNE logging (see MNE documentation for possible values). 
+ The metadata record for the recording, containing information like + sampling frequency, channel names, etc. + preload : bool, default False + If True, preload the data into memory. + cache_dir : str, optional + Local directory for caching data. If None, a default directory is used. + bids_dependencies : list of str, default [] + A list of BIDS metadata files to download alongside the main recording. + verbose : str, int, or None, default None + The MNE verbosity level. See Also -------- - mne.io.Raw : Documentation of attributes and methods. + mne.io.Raw : The base class for Raw objects in MNE. """ @@ -241,7 +252,6 @@ def __init__( bids_dependencies: list[str] = [], verbose: Any = None, ): - """Get to work with S3 endpoint first, no caching""" # Create a simple RawArray sfreq = metadata["sfreq"] # Sampling frequency n_times = metadata["n_times"] @@ -277,6 +287,7 @@ def __init__( def _read_segment( self, start=0, stop=None, sel=None, data_buffer=None, *, verbose=None ): + """Read a segment of data, downloading if necessary.""" if not os.path.exists(self.filecache): # not preload if self.bids_dependencies: # this is use only to sidecars for now downloader.download_dependencies( @@ -297,22 +308,23 @@ def _read_segment( return super()._read_segment(start, stop, sel, data_buffer, verbose=verbose) def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): - """Read a chunk of data from the file.""" + """Read a chunk of data from a local file.""" _read_segments_file(self, data, idx, fi, start, stop, cals, mult, dtype=" bool: - """Check if the dataset is EEG.""" + """Check if the BIDS dataset contains EEG data. + + Returns + ------- + bool + True if the dataset's modality is EEG, False otherwise. + + """ return self.get_bids_file_attribute("modality", self.files[0]).lower() == "eeg" def _get_recordings(self, layout: BIDSLayout) -> list[str]: @@ -370,14 +389,12 @@ def _get_recordings(self, layout: BIDSLayout) -> list[str]: return files def _get_relative_bidspath(self, filename: str) -> str: - """Make the given file path relative to the BIDS directory.""" + """Make a file path relative to the BIDS parent directory.""" bids_parent_dir = self.bidsdir.parent.absolute() return str(Path(filename).relative_to(bids_parent_dir)) def _get_property_from_filename(self, property: str, filename: str) -> str: - """Parse a property out of a BIDS-compliant filename. Returns an empty string - if not found. 
- """ + """Parse a BIDS entity from a filename.""" import platform if platform.system() == "Windows": @@ -387,159 +404,106 @@ def _get_property_from_filename(self, property: str, filename: str) -> str: return lookup.group(1) if lookup else "" def _merge_json_inheritance(self, json_files: list[str | Path]) -> dict: - """Internal helper to merge list of json files found by get_bids_file_inheritance, - expecting the order (from left to right) is from lowest - level to highest level, and return a merged dictionary - """ + """Merge a list of JSON files according to BIDS inheritance.""" json_files.reverse() json_dict = {} for f in json_files: - json_dict.update(json.load(open(f))) # FIXME: should close file + with open(f) as fp: + json_dict.update(json.load(fp)) return json_dict def _get_bids_file_inheritance( self, path: str | Path, basename: str, extension: str ) -> list[Path]: - """Get all file paths that apply to the basename file in the specified directory - and that end with the specified suffix, recursively searching parent directories - (following the BIDS inheritance principle in the order of lowest level first). - - Parameters - ---------- - path : str | Path - The directory path to search for files. - basename : str - BIDS file basename without _eeg.set extension for example - extension : str - Only consider files that end with the specified suffix; e.g. channels.tsv - - Returns - ------- - list[Path] - A list of file paths that match the given basename and extension. - - """ + """Find all applicable metadata files using BIDS inheritance.""" top_level_files = ["README", "dataset_description.json", "participants.tsv"] bids_files = [] - # check if path is str object if isinstance(path, str): path = Path(path) - if not path.exists: - raise ValueError("path {path} does not exist") + if not path.exists(): + raise ValueError(f"path {path} does not exist") - # check if file is in current path for file in os.listdir(path): - # target_file = path / f"{cur_file_basename}_{extension}" - if os.path.isfile(path / file): - # check if file has extension extension - # check if file basename has extension - if file.endswith(extension): - filepath = path / file - bids_files.append(filepath) - - # check if file is in top level directory + if os.path.isfile(path / file) and file.endswith(extension): + bids_files.append(path / file) + if any(file in os.listdir(path) for file in top_level_files): return bids_files else: - # call get_bids_file_inheritance recursively with parent directory bids_files.extend( self._get_bids_file_inheritance(path.parent, basename, extension) ) return bids_files def get_bids_metadata_files( - self, filepath: str | Path, metadata_file_extension: list[str] + self, filepath: str | Path, metadata_file_extension: str ) -> list[Path]: - """Retrieve all metadata file paths that apply to a given data file path and that - end with a specific suffix (following the BIDS inheritance principle). + """Retrieve all metadata files that apply to a given data file. + + Follows the BIDS inheritance principle to find all relevant metadata + files (e.g., ``channels.tsv``, ``eeg.json``) for a specific recording. Parameters ---------- - filepath: str | Path - The filepath to get the associated metadata files for. + filepath : str or Path + The path to the data file. metadata_file_extension : str - Consider only metadata files that end with the specified suffix, - e.g., channels.tsv or eeg.json + The extension of the metadata file to search for (e.g., "channels.tsv"). 
Returns ------- - list[Path]: - A list of filepaths for all matching metadata files + list of Path + A list of paths to the matching metadata files. """ if isinstance(filepath, str): filepath = Path(filepath) - if not filepath.exists: - raise ValueError("filepath {filepath} does not exist") + if not filepath.exists(): + raise ValueError(f"filepath {filepath} does not exist") path, filename = os.path.split(filepath) basename = filename[: filename.rfind("_")] - # metadata files meta_files = self._get_bids_file_inheritance( path, basename, metadata_file_extension ) return meta_files def _scan_directory(self, directory: str, extension: str) -> list[Path]: - """Return a list of file paths that end with the given extension in the specified - directory. Ignores certain special directories like .git, .datalad, derivatives, - and code. - """ + """Scan a directory for files with a given extension.""" result_files = [] directory_to_ignore = [".git", ".datalad", "derivatives", "code"] with os.scandir(directory) as entries: for entry in entries: if entry.is_file() and entry.name.endswith(extension): - result_files.append(entry.path) - elif entry.is_dir(): - # check that entry path doesn't contain any name in ignore list - if not any(name in entry.name for name in directory_to_ignore): - result_files.append(entry.path) # Add directory to scan later + result_files.append(Path(entry.path)) + elif entry.is_dir() and not any( + name in entry.name for name in directory_to_ignore + ): + result_files.append(Path(entry.path)) return result_files def _get_files_with_extension_parallel( self, directory: str, extension: str = ".set", max_workers: int = -1 ) -> list[Path]: - """Efficiently scan a directory and its subdirectories for files that end with - the given extension. - - Parameters - ---------- - directory : str - The root directory to scan for files. - extension : str - Only consider files that end with this suffix, e.g. '.set'. - max_workers : int - Optionally specify the maximum number of worker threads to use for parallel scanning. - Defaults to all available CPU cores if set to -1. - - Returns - ------- - list[Path]: - A list of filepaths for all matching metadata files - - """ + """Scan a directory tree in parallel for files with a given extension.""" result_files = [] dirs_to_scan = [directory] - # Use joblib.Parallel and delayed to parallelize directory scanning while dirs_to_scan: logger.info( f"Directories to scan: {len(dirs_to_scan)}, files: {dirs_to_scan}" ) - # Run the scan_directory function in parallel across directories results = Parallel(n_jobs=max_workers, prefer="threads", verbose=1)( delayed(self._scan_directory)(d, extension) for d in dirs_to_scan ) - # Reset the directories to scan and process the results dirs_to_scan = [] for res in results: for path in res: if os.path.isdir(path): - dirs_to_scan.append(path) # Queue up subdirectories to scan + dirs_to_scan.append(path) else: - result_files.append(path) # Add files to the final result + result_files.append(path) logger.info(f"Found {len(result_files)} files.") return result_files @@ -547,19 +511,29 @@ def _get_files_with_extension_parallel( def load_and_preprocess_raw( self, raw_file: str, preprocess: bool = False ) -> np.ndarray: - """Utility function to load a raw data file with MNE and apply some simple - (hardcoded) preprocessing and return as a numpy array. Not meant for purposes - other than testing or debugging. + """Load and optionally preprocess a raw data file. 
+ + This is a utility function for testing or debugging, not for general use. + + Parameters + ---------- + raw_file : str + Path to the raw EEGLAB file (.set). + preprocess : bool, default False + If True, apply a high-pass filter, notch filter, and resample the data. + + Returns + ------- + numpy.ndarray + The loaded and processed data as a NumPy array. + """ logger.info(f"Loading raw data from {raw_file}") EEG = mne.io.read_raw_eeglab(raw_file, preload=True, verbose="error") if preprocess: - # highpass filter EEG = EEG.filter(l_freq=0.25, h_freq=25, verbose=False) - # remove 60Hz line noise EEG = EEG.notch_filter(freqs=(60), verbose=False) - # bring to common sampling rate sfreq = 128 if EEG.info["sfreq"] != sfreq: EEG = EEG.resample(sfreq) @@ -570,26 +544,35 @@ def load_and_preprocess_raw( raise ValueError("Expect raw data to be CxT dimension") return mat_data - def get_files(self) -> list[Path]: - """Get all EEG recording file paths (with valid extensions) in the BIDS folder.""" + def get_files(self) -> list[str]: + """Get all EEG recording file paths in the BIDS dataset. + + Returns + ------- + list of str + A list of file paths for all valid EEG recordings. + + """ return self.files def resolve_bids_json(self, json_files: list[str]) -> dict: - """Resolve the BIDS JSON files and return a dictionary of the resolved values. + """Resolve BIDS JSON inheritance and merge files. Parameters ---------- - json_files : list - A list of JSON file paths to resolve in order of leaf level first. + json_files : list of str + A list of JSON file paths, ordered from the lowest (most specific) + to highest level of the BIDS hierarchy. Returns ------- - dict: A dictionary of the resolved values. + dict + A dictionary containing the merged JSON data. """ - if len(json_files) == 0: + if not json_files: raise ValueError("No JSON files provided") - json_files.reverse() # TODO undeterministic + json_files.reverse() json_dict = {} for json_file in json_files: @@ -598,8 +581,20 @@ def resolve_bids_json(self, json_files: list[str]) -> dict: return json_dict def get_bids_file_attribute(self, attribute: str, data_filepath: str) -> Any: - """Retrieve a specific attribute from the BIDS file metadata applicable - to the provided recording file path. + """Retrieve a specific attribute from BIDS metadata. + + Parameters + ---------- + attribute : str + The name of the attribute to retrieve (e.g., "sfreq", "subject"). + data_filepath : str + The path to the data file. + + Returns + ------- + Any + The value of the requested attribute, or None if not found. + """ entities = self.layout.parse_file_entities(data_filepath) bidsfile = self.layout.get(**entities)[0] @@ -618,21 +613,59 @@ def get_bids_file_attribute(self, attribute: str, data_filepath: str) -> Any: return attribute_value def channel_labels(self, data_filepath: str) -> list[str]: - """Get a list of channel labels for the given data file path.""" + """Get a list of channel labels from channels.tsv. + + Parameters + ---------- + data_filepath : str + The path to the data file. + + Returns + ------- + list of str + A list of channel names. + + """ channels_tsv = pd.read_csv( self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t" ) return channels_tsv["name"].tolist() def channel_types(self, data_filepath: str) -> list[str]: - """Get a list of channel types for the given data file path.""" + """Get a list of channel types from channels.tsv. + + Parameters + ---------- + data_filepath : str + The path to the data file. 
+ + Returns + ------- + list of str + A list of channel types. + + """ channels_tsv = pd.read_csv( self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t" ) return channels_tsv["type"].tolist() def num_times(self, data_filepath: str) -> int: - """Get the approximate number of time points in the EEG recording based on the BIDS metadata.""" + """Get the number of time points in the recording. + + Calculated from ``SamplingFrequency`` and ``RecordingDuration`` in eeg.json. + + Parameters + ---------- + data_filepath : str + The path to the data file. + + Returns + ------- + int + The approximate number of time points. + + """ eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json") eeg_json_dict = self._merge_json_inheritance(eeg_jsons) return int( @@ -640,38 +673,71 @@ def num_times(self, data_filepath: str) -> int: ) def subject_participant_tsv(self, data_filepath: str) -> dict[str, Any]: - """Get BIDS participants.tsv record for the subject to which the given file - path corresponds, as a dictionary. + """Get the participants.tsv record for a subject. + + Parameters + ---------- + data_filepath : str + The path to a data file belonging to the subject. + + Returns + ------- + dict + A dictionary of the subject's information from participants.tsv. + """ - participants_tsv = pd.read_csv( - self.get_bids_metadata_files(data_filepath, "participants.tsv")[0], sep="\t" - ) - # if participants_tsv is not empty + participants_tsv_path = self.get_bids_metadata_files( + data_filepath, "participants.tsv" + )[0] + participants_tsv = pd.read_csv(participants_tsv_path, sep="\t") if participants_tsv.empty: return {} - # set 'participant_id' as index participants_tsv.set_index("participant_id", inplace=True) subject = f"sub-{self.get_bids_file_attribute('subject', data_filepath)}" return participants_tsv.loc[subject].to_dict() def eeg_json(self, data_filepath: str) -> dict[str, Any]: - """Get BIDS eeg.json metadata for the given data file path.""" + """Get the merged eeg.json metadata for a data file. + + Parameters + ---------- + data_filepath : str + The path to the data file. + + Returns + ------- + dict + The merged eeg.json metadata. + + """ eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json") - eeg_json_dict = self._merge_json_inheritance(eeg_jsons) - return eeg_json_dict + return self._merge_json_inheritance(eeg_jsons) def channel_tsv(self, data_filepath: str) -> dict[str, Any]: - """Get BIDS channels.tsv metadata for the given data file path, as a dictionary - of lists and/or single values. + """Get the channels.tsv metadata as a dictionary. + + Parameters + ---------- + data_filepath : str + The path to the data file. + + Returns + ------- + dict + The channels.tsv data, with columns as keys. + """ - channels_tsv = pd.read_csv( - self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t" - ) - channel_tsv = channels_tsv.to_dict() - # 'name' and 'type' now have a dictionary of index-value. 
Convert them to list + channels_tsv_path = self.get_bids_metadata_files(data_filepath, "channels.tsv")[ + 0 + ] + channels_tsv = pd.read_csv(channels_tsv_path, sep="\t") + channel_tsv_dict = channels_tsv.to_dict() for list_field in ["name", "type", "units"]: - channel_tsv[list_field] = list(channel_tsv[list_field].values()) - return channel_tsv + if list_field in channel_tsv_dict: + channel_tsv_dict[list_field] = list( + channel_tsv_dict[list_field].values() + ) + return channel_tsv_dict __all__ = ["EEGDashBaseDataset", "EEGBIDSDataset", "EEGDashBaseRaw"] diff --git a/eegdash/dataset/dataset.py b/eegdash/dataset/dataset.py index 480f2762..e26e32fe 100644 --- a/eegdash/dataset/dataset.py +++ b/eegdash/dataset/dataset.py @@ -12,26 +12,48 @@ class EEGChallengeDataset(EEGDashDataset): - """EEG 2025 Challenge dataset helper. + """A dataset helper for the EEG 2025 Challenge. - This class provides a convenient wrapper around :class:`EEGDashDataset` - configured for the EEG 2025 Challenge releases. It maps a given - ``release`` to its corresponding OpenNeuro dataset and optionally restricts - to the official "mini" subject subset. + This class simplifies access to the EEG 2025 Challenge datasets. It is a + specialized version of :class:`~eegdash.api.EEGDashDataset` that is + pre-configured for the challenge's data releases. It automatically maps a + release name (e.g., "R1") to the corresponding OpenNeuro dataset and handles + the selection of subject subsets (e.g., "mini" release). Parameters ---------- release : str - Release name. One of ["R1", ..., "R11"]. + The name of the challenge release to load. Must be one of the keys in + :const:`~eegdash.const.RELEASE_TO_OPENNEURO_DATASET_MAP` + (e.g., "R1", "R2", ..., "R11"). + cache_dir : str + The local directory where the dataset will be downloaded and cached. mini : bool, default True - If True, restrict subjects to the challenge mini subset. - query : dict | None - Additional MongoDB-style filters to AND with the release selection. - Must not contain the key ``dataset``. - s3_bucket : str | None, default "s3://nmdatasets/NeurIPS25" - Base S3 bucket used to locate the challenge data. + If True, the dataset is restricted to the official "mini" subset of + subjects for the specified release. If False, all subjects for the + release are included. + query : dict, optional + An additional MongoDB-style query to apply as a filter. This query is + combined with the release and subject filters using a logical AND. + The query must not contain the ``dataset`` key, as this is determined + by the ``release`` parameter. + s3_bucket : str, optional + The base S3 bucket URI where the challenge data is stored. Defaults to + the official challenge bucket. **kwargs - Passed through to :class:`EEGDashDataset`. + Additional keyword arguments that are passed directly to the + :class:`~eegdash.api.EEGDashDataset` constructor. + + Raises + ------ + ValueError + If the specified ``release`` is unknown, or if the ``query`` argument + contains a ``dataset`` key. Also raised if ``mini`` is True and a + requested subject is not part of the official mini-release subset. + + See Also + -------- + EEGDashDataset : The base class for creating datasets from queries. 
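+
+    Examples
+    --------
+    A minimal usage sketch; the cache path is illustrative, and data are only
+    downloaded when a recording is first accessed:
+
+    >>> from eegdash.dataset.dataset import EEGChallengeDataset
+    >>> ds = EEGChallengeDataset(release="R1", cache_dir="eegdash_cache", mini=True)  # doctest: +SKIP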
""" diff --git a/eegdash/dataset/registry.py b/eegdash/dataset/registry.py index 5eeacb1e..0d40db9c 100644 --- a/eegdash/dataset/registry.py +++ b/eegdash/dataset/registry.py @@ -14,7 +14,35 @@ def register_openneuro_datasets( namespace: Dict[str, Any] | None = None, add_to_all: bool = True, ) -> Dict[str, type]: - """Dynamically create dataset classes from a summary file.""" + """Dynamically create and register dataset classes from a summary file. + + This function reads a CSV file containing summaries of OpenNeuro datasets + and dynamically creates a Python class for each dataset. These classes + inherit from a specified base class and are pre-configured with the + dataset's ID. + + Parameters + ---------- + summary_file : str or pathlib.Path + The path to the CSV file containing the dataset summaries. + base_class : type, optional + The base class from which the new dataset classes will inherit. If not + provided, :class:`eegdash.api.EEGDashDataset` is used. + namespace : dict, optional + The namespace (e.g., `globals()`) into which the newly created classes + will be injected. Defaults to the local `globals()` of this module. + add_to_all : bool, default True + If True, the names of the newly created classes are added to the + `__all__` list of the target namespace, making them importable with + `from ... import *`. + + Returns + ------- + dict[str, type] + A dictionary mapping the names of the registered classes to the class + types themselves. + + """ if base_class is None: from ..api import EEGDashDataset as base_class # lazy import @@ -84,8 +112,28 @@ def __init__( return registered -def _generate_rich_docstring(dataset_id: str, row_series: pd.Series, base_class) -> str: - """Generate a comprehensive docstring for a dataset class.""" +def _generate_rich_docstring( + dataset_id: str, row_series: pd.Series, base_class: type +) -> str: + """Generate a comprehensive, well-formatted docstring for a dataset class. + + Parameters + ---------- + dataset_id : str + The identifier of the dataset (e.g., "ds002718"). + row_series : pandas.Series + A pandas Series containing the metadata for the dataset, extracted + from the summary CSV file. + base_class : type + The base class from which the new dataset class inherits. Used to + generate the "See Also" section of the docstring. + + Returns + ------- + str + A formatted docstring. + + """ # Extract metadata with safe defaults n_subjects = row_series.get("n_subjects", "Unknown") n_records = row_series.get("n_records", "Unknown") @@ -173,7 +221,24 @@ def _generate_rich_docstring(dataset_id: str, row_series: pd.Series, base_class) def _markdown_table(row_series: pd.Series) -> str: - """Create a reStructuredText grid table from a pandas Series.""" + """Create a reStructuredText grid table from a pandas Series. + + This helper function takes a pandas Series containing dataset metadata + and formats it into a reStructuredText grid table for inclusion in + docstrings. + + Parameters + ---------- + row_series : pandas.Series + A Series where each index is a metadata field and each value is the + corresponding metadata value. + + Returns + ------- + str + A string containing the formatted reStructuredText table. 
+ + """ if row_series.empty: return "" dataset_id = row_series["dataset"] diff --git a/eegdash/downloader.py b/eegdash/downloader.py index efbbc12d..70eb31d3 100644 --- a/eegdash/downloader.py +++ b/eegdash/downloader.py @@ -17,18 +17,62 @@ from fsspec.callbacks import TqdmCallback -def get_s3_filesystem(): - """Returns an S3FileSystem object.""" +def get_s3_filesystem() -> s3fs.S3FileSystem: + """Get an anonymous S3 filesystem object. + + Initializes and returns an ``s3fs.S3FileSystem`` for anonymous access + to public S3 buckets, configured for the 'us-east-2' region. + + Returns + ------- + s3fs.S3FileSystem + An S3 filesystem object. + + """ return s3fs.S3FileSystem(anon=True, client_kwargs={"region_name": "us-east-2"}) def get_s3path(s3_bucket: str, filepath: str) -> str: - """Helper to form an AWS S3 URI for the given relative filepath.""" + """Construct an S3 URI from a bucket and file path. + + Parameters + ---------- + s3_bucket : str + The S3 bucket name (e.g., "s3://my-bucket"). + filepath : str + The path to the file within the bucket. + + Returns + ------- + str + The full S3 URI (e.g., "s3://my-bucket/path/to/file"). + + """ return f"{s3_bucket}/{filepath}" -def download_s3_file(s3_path: str, local_path: Path, s3_open_neuro: bool): - """Download function that gets the raw EEG data from S3.""" +def download_s3_file(s3_path: str, local_path: Path, s3_open_neuro: bool) -> Path: + """Download a single file from S3 to a local path. + + Handles the download of a raw EEG data file from an S3 bucket, caching it + at the specified local path. Creates parent directories if they do not exist. + + Parameters + ---------- + s3_path : str + The full S3 URI of the file to download. + local_path : pathlib.Path + The local file path where the downloaded file will be saved. + s3_open_neuro : bool + A flag indicating if the S3 bucket is the OpenNeuro main bucket, which + may affect path handling. + + Returns + ------- + pathlib.Path + The local path to the downloaded file. + + """ filesystem = get_s3_filesystem() if not s3_open_neuro: s3_path = re.sub(r"(^|/)ds\d{6}/", r"\1", s3_path, count=1) @@ -51,8 +95,31 @@ def download_dependencies( dataset_folder: Path, record: dict[str, Any], s3_open_neuro: bool, -): - """Download all BIDS dependency files from S3 and cache them locally.""" +) -> None: + """Download all BIDS dependency files from S3. + + Iterates through a list of BIDS dependency files, downloads each from the + specified S3 bucket, and caches them in the appropriate local directory + structure. + + Parameters + ---------- + s3_bucket : str + The S3 bucket to download from. + bids_dependencies : list of str + A list of dependency file paths relative to the S3 bucket root. + bids_dependencies_original : list of str + The original dependency paths, used for resolving local cache paths. + cache_dir : pathlib.Path + The root directory for caching. + dataset_folder : pathlib.Path + The specific folder for the dataset within the cache directory. + record : dict + The metadata record for the main data file, used to resolve paths. + s3_open_neuro : bool + Flag for OpenNeuro-specific path handling. 
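+
+    Examples
+    --------
+    A hypothetical call for a single sidecar file (bucket, paths, and record
+    values are illustrative only):
+
+    >>> from pathlib import Path
+    >>> download_dependencies(
+    ...     s3_bucket="s3://openneuro.org",
+    ...     bids_dependencies=["ds005505/participants.tsv"],
+    ...     bids_dependencies_original=["ds005505/participants.tsv"],
+    ...     cache_dir=Path("eegdash_cache"),
+    ...     dataset_folder=Path("eegdash_cache/ds005505"),
+    ...     record={"dataset": "ds005505"},
+    ...     s3_open_neuro=True,
+    ... )  # doctest: +SKIP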
+ + """ filesystem = get_s3_filesystem() for i, dep in enumerate(bids_dependencies): if not s3_open_neuro: @@ -78,8 +145,27 @@ def download_dependencies( _filesystem_get(filesystem=filesystem, s3path=s3path, filepath=filepath) -def _filesystem_get(filesystem: s3fs.S3FileSystem, s3path: str, filepath: Path): - """Helper to download a file from S3 with a progress bar.""" +def _filesystem_get(filesystem: s3fs.S3FileSystem, s3path: str, filepath: Path) -> Path: + """Perform the file download using fsspec with a progress bar. + + Internal helper function that wraps the ``filesystem.get`` call to include + a TQDM progress bar. + + Parameters + ---------- + filesystem : s3fs.S3FileSystem + The filesystem object to use for the download. + s3path : str + The full S3 URI of the source file. + filepath : pathlib.Path + The local destination path. + + Returns + ------- + pathlib.Path + The local path to the downloaded file. + + """ info = filesystem.info(s3path) size = info.get("size") or info.get("Size") diff --git a/eegdash/features/datasets.py b/eegdash/features/datasets.py index f4b021cc..9e933c59 100644 --- a/eegdash/features/datasets.py +++ b/eegdash/features/datasets.py @@ -20,20 +20,34 @@ class FeaturesDataset(EEGWindowsDataset): - """Returns samples from a pandas DataFrame object along with a target. + """A dataset of features extracted from EEG windows. - Dataset which serves samples from a pandas DataFrame object along with a - target. The target is unique for the dataset, and is obtained through the - `description` attribute. + This class holds features in a pandas DataFrame and provides an interface + compatible with braindecode's dataset structure. Each row in the feature + DataFrame corresponds to a single sample (e.g., an EEG window). Parameters ---------- - features : a pandas DataFrame - Tabular data. - description : dict | pandas.Series | None - Holds additional description about the continuous signal / subject. - transform : callable | None - On-the-fly transform applied to the example before it is returned. + features : pandas.DataFrame + A DataFrame where each row is a sample and each column is a feature. + metadata : pandas.DataFrame, optional + A DataFrame containing metadata for each sample, indexed consistently + with `features`. Must include columns 'i_window_in_trial', + 'i_start_in_trial', 'i_stop_in_trial', and 'target'. + description : dict or pandas.Series, optional + Additional high-level information about the dataset (e.g., subject ID). + transform : callable, optional + A function or transform to apply to the feature data on-the-fly. + raw_info : dict, optional + Information about the original raw recording, for provenance. + raw_preproc_kwargs : dict, optional + Keyword arguments used for preprocessing the raw data. + window_kwargs : dict, optional + Keyword arguments used for windowing the data. + window_preproc_kwargs : dict, optional + Keyword arguments used for preprocessing the windowed data. + features_kwargs : dict, optional + Keyword arguments used for feature extraction. """ @@ -65,7 +79,21 @@ def __init__( ].to_numpy() self.y = metadata.loc[:, "target"].to_list() - def __getitem__(self, index): + def __getitem__(self, index: int) -> tuple[np.ndarray, int, list]: + """Get a single sample from the dataset. + + Parameters + ---------- + index : int + The index of the sample to retrieve. + + Returns + ------- + tuple + A tuple containing the feature vector (X), the target (y), and the + cropping indices. 
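+
+        Examples
+        --------
+        A hypothetical access pattern, assuming ``ds`` is an already
+        constructed :class:`FeaturesDataset`:
+
+        >>> X, y, crop_inds = ds[0]  # doctest: +SKIP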
+ + """ crop_inds = self.crop_inds[index].tolist() X = self.features.iloc[index].to_numpy() X = X.copy() @@ -75,18 +103,27 @@ def __getitem__(self, index): y = self.y[index] return X, y, crop_inds - def __len__(self): + def __len__(self) -> int: + """Return the number of samples in the dataset. + + Returns + ------- + int + The total number of feature samples. + + """ return len(self.features.index) def _compute_stats( ds: FeaturesDataset, - return_count=False, - return_mean=False, - return_var=False, - ddof=1, - numeric_only=False, -): + return_count: bool = False, + return_mean: bool = False, + return_var: bool = False, + ddof: int = 1, + numeric_only: bool = False, +) -> tuple: + """Compute statistics for a single FeaturesDataset.""" res = [] if return_count: res.append(ds.features.count(numeric_only=numeric_only)) @@ -97,7 +134,14 @@ def _compute_stats( return tuple(res) -def _pooled_var(counts, means, variances, ddof, ddof_in=None): +def _pooled_var( + counts: np.ndarray, + means: np.ndarray, + variances: np.ndarray, + ddof: int, + ddof_in: int | None = None, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute pooled variance across multiple datasets.""" if ddof_in is None: ddof_in = ddof count = counts.sum(axis=0) @@ -110,17 +154,20 @@ def _pooled_var(counts, means, variances, ddof, ddof_in=None): class FeaturesConcatDataset(BaseConcatDataset): - """A base class for concatenated datasets. + """A concatenated dataset of `FeaturesDataset` objects. - Holds either mne.Raw or mne.Epoch in self.datasets and has - a pandas DataFrame with additional description. + This class holds a list of :class:`FeaturesDataset` instances and allows + them to be treated as a single, larger dataset. It provides methods for + + splitting, saving, and performing DataFrame-like operations (e.g., `mean`, + `var`, `fillna`) across all contained datasets. Parameters ---------- - list_of_ds : list - list of BaseDataset, BaseConcatDataset or WindowsDataset - target_transform : callable | None - Optional function to call on targets before returning them. + list_of_ds : list of FeaturesDataset + A list of :class:`FeaturesDataset` objects to concatenate. + target_transform : callable, optional + A function to apply to the target values before they are returned. """ @@ -140,26 +187,28 @@ def split( self, by: str | list[int] | list[list[int]] | dict[str, list[int]], ) -> dict[str, FeaturesConcatDataset]: - """Split the dataset based on information listed in its description. + """Split the dataset into subsets. - The format could be based on a DataFrame or based on indices. + The splitting can be done based on a column in the description + DataFrame or by providing explicit indices for each split. Parameters ---------- - by : str | list | dict - If ``by`` is a string, splitting is performed based on the - description DataFrame column with this name. - If ``by`` is a (list of) list of integers, the position in the first - list corresponds to the split id and the integers to the - datapoints of that split. - If a dict then each key will be used in the returned - splits dict and each value should be a list of int. + by : str or list or dict + - If a string, splits are created for each unique value in the + description column `by`. + - If a list of integers, a single split is created containing the + datasets at the specified indices. + - If a list of lists of integers, multiple splits are created, one + for each sublist of indices. 
+ - If a dictionary, keys are used as split names and values are + lists of dataset indices. Returns ------- - splits : dict - A dictionary with the name of the split (a string) as key and the - dataset as value. + dict[str, FeaturesConcatDataset] + A dictionary where keys are split names and values are the new + :class:`FeaturesConcatDataset` subsets. """ if isinstance(by, str): @@ -184,14 +233,21 @@ def split( } def get_metadata(self) -> pd.DataFrame: - """Concatenate the metadata and description of the wrapped Epochs. + """Get the metadata of all datasets as a single DataFrame. + + Concatenates the metadata from all contained datasets and adds columns + from their `description` attributes. Returns ------- - metadata : pd.DataFrame - DataFrame containing as many rows as there are windows in the - BaseConcatDataset, with the metadata and description information - for each window. + pandas.DataFrame + A DataFrame containing the metadata for every sample in the + concatenated dataset. + + Raises + ------ + TypeError + If any of the contained datasets is not a :class:`FeaturesDataset`. """ if not all([isinstance(ds, FeaturesDataset) for ds in self.datasets]): @@ -202,60 +258,59 @@ def get_metadata(self) -> pd.DataFrame: all_dfs = list() for ds in self.datasets: - df = ds.metadata + df = ds.metadata.copy() for k, v in ds.description.items(): df[k] = v all_dfs.append(df) return pd.concat(all_dfs) - def save(self, path: str, overwrite: bool = False, offset: int = 0): - """Save datasets to files by creating one subdirectory for each dataset: - path/ - 0/ - 0-feat.parquet - metadata_df.pkl - description.json - raw-info.fif (if raw info was saved) - raw_preproc_kwargs.json (if raws were preprocessed) - window_kwargs.json (if this is a windowed dataset) - window_preproc_kwargs.json (if windows were preprocessed) - features_kwargs.json - 1/ - 1-feat.parquet - metadata_df.pkl - description.json - raw-info.fif (if raw info was saved) - raw_preproc_kwargs.json (if raws were preprocessed) - window_kwargs.json (if this is a windowed dataset) - window_preproc_kwargs.json (if windows were preprocessed) - features_kwargs.json + def save(self, path: str, overwrite: bool = False, offset: int = 0) -> None: + """Save the concatenated dataset to a directory. + + Creates a directory structure where each contained dataset is saved in + its own numbered subdirectory. + + .. code-block:: + + path/ + 0/ + 0-feat.parquet + metadata_df.pkl + description.json + ... + 1/ + 1-feat.parquet + ... Parameters ---------- path : str - Directory in which subdirectories are created to store - -feat.parquet and .json files to. - overwrite : bool - Whether to delete old subdirectories that will be saved to in this - call. - offset : int - If provided, the integer is added to the id of the dataset in the - concat. This is useful in the setting of very large datasets, where - one dataset has to be processed and saved at a time to account for - its original position. + The directory where the dataset will be saved. + overwrite : bool, default False + If True, any existing subdirectories that conflict with the new + ones will be removed. + offset : int, default 0 + An integer to add to the subdirectory names. Useful for saving + datasets in chunks. + + Raises + ------ + ValueError + If the dataset is empty. + FileExistsError + If a subdirectory already exists and `overwrite` is False. 
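+
+        Examples
+        --------
+        A usage sketch; ``concat_ds`` denotes an existing
+        :class:`FeaturesConcatDataset` and the target directory (which must
+        already exist) is illustrative:
+
+        >>> concat_ds.save("./saved_features", overwrite=True)  # doctest: +SKIP
+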
""" if len(self.datasets) == 0: raise ValueError("Expect at least one dataset") path_contents = os.listdir(path) - n_sub_dirs = len([os.path.isdir(e) for e in path_contents]) + n_sub_dirs = len([os.path.isdir(os.path.join(path, e)) for e in path_contents]) for i_ds, ds in enumerate(self.datasets): - # remove subdirectory from list of untouched files / subdirectories - if str(i_ds + offset) in path_contents: - path_contents.remove(str(i_ds + offset)) - # save_dir/i_ds/ - sub_dir = os.path.join(path, str(i_ds + offset)) + sub_dir_name = str(i_ds + offset) + if sub_dir_name in path_contents: + path_contents.remove(sub_dir_name) + sub_dir = os.path.join(path, sub_dir_name) if os.path.exists(sub_dir): if overwrite: shutil.rmtree(sub_dir) @@ -265,35 +320,21 @@ def save(self, path: str, overwrite: bool = False, offset: int = 0): f" a different directory, set overwrite=True, or " f"resolve manually." ) - # save_dir/{i_ds+offset}/ os.makedirs(sub_dir) - # save_dir/{i_ds+offset}/{i_ds+offset}-feat.parquet self._save_features(sub_dir, ds, i_ds, offset) - # save_dir/{i_ds+offset}/metadata_df.pkl self._save_metadata(sub_dir, ds) - # save_dir/{i_ds+offset}/description.json self._save_description(sub_dir, ds.description) - # save_dir/{i_ds+offset}/raw-info.fif self._save_raw_info(sub_dir, ds) - # save_dir/{i_ds+offset}/raw_preproc_kwargs.json - # save_dir/{i_ds+offset}/window_kwargs.json - # save_dir/{i_ds+offset}/window_preproc_kwargs.json - # save_dir/{i_ds+offset}/features_kwargs.json self._save_kwargs(sub_dir, ds) - if overwrite: - # the following will be True for all datasets preprocessed and - # stored in parallel with braindecode.preprocessing.preprocess - if i_ds + 1 + offset < n_sub_dirs: - logger.warning( - f"The number of saved datasets ({i_ds + 1 + offset}) " - f"does not match the number of existing " - f"subdirectories ({n_sub_dirs}). You may now " - f"encounter a mix of differently preprocessed " - f"datasets!", - UserWarning, - ) - # if path contains files or directories that were not touched, raise - # warning + if overwrite and i_ds + 1 + offset < n_sub_dirs: + logger.warning( + f"The number of saved datasets ({i_ds + 1 + offset}) " + f"does not match the number of existing " + f"subdirectories ({n_sub_dirs}). 
You may now " + f"encounter a mix of differently preprocessed " + f"datasets!", + UserWarning, + ) if path_contents: logger.warning( f"Chosen directory {path} contains other " @@ -301,20 +342,37 @@ def save(self, path: str, overwrite: bool = False, offset: int = 0): ) @staticmethod - def _save_features(sub_dir, ds, i_ds, offset): + def _save_features(sub_dir: str, ds: FeaturesDataset, i_ds: int, offset: int): + """Save the feature DataFrame to a Parquet file.""" parquet_file_name = f"{i_ds + offset}-feat.parquet" parquet_file_path = os.path.join(sub_dir, parquet_file_name) ds.features.to_parquet(parquet_file_path) @staticmethod - def _save_raw_info(sub_dir, ds): - if hasattr(ds, "raw_info"): + def _save_metadata(sub_dir: str, ds: FeaturesDataset): + """Save the metadata DataFrame to a pickle file.""" + metadata_file_name = "metadata_df.pkl" + metadata_file_path = os.path.join(sub_dir, metadata_file_name) + ds.metadata.to_pickle(metadata_file_path) + + @staticmethod + def _save_description(sub_dir: str, description: pd.Series): + """Save the description Series to a JSON file.""" + desc_file_name = "description.json" + desc_file_path = os.path.join(sub_dir, desc_file_name) + description.to_json(desc_file_path) + + @staticmethod + def _save_raw_info(sub_dir: str, ds: FeaturesDataset): + """Save the raw info dictionary to a FIF file if it exists.""" + if hasattr(ds, "raw_info") and ds.raw_info is not None: fif_file_name = "raw-info.fif" fif_file_path = os.path.join(sub_dir, fif_file_name) - ds.raw_info.save(fif_file_path) + ds.raw_info.save(fif_file_path, overwrite=True) @staticmethod - def _save_kwargs(sub_dir, ds): + def _save_kwargs(sub_dir: str, ds: FeaturesDataset): + """Save various keyword argument dictionaries to JSON files.""" for kwargs_name in [ "raw_preproc_kwargs", "window_kwargs", @@ -322,10 +380,10 @@ def _save_kwargs(sub_dir, ds): "features_kwargs", ]: if hasattr(ds, kwargs_name): - kwargs_file_name = ".".join([kwargs_name, "json"]) - kwargs_file_path = os.path.join(sub_dir, kwargs_file_name) kwargs = getattr(ds, kwargs_name) if kwargs is not None: + kwargs_file_name = ".".join([kwargs_name, "json"]) + kwargs_file_path = os.path.join(sub_dir, kwargs_file_name) with open(kwargs_file_path, "w") as f: json.dump(kwargs, f) @@ -334,7 +392,25 @@ def to_dataframe( include_metadata: bool | str | List[str] = False, include_target: bool = False, include_crop_inds: bool = False, - ): + ) -> pd.DataFrame: + """Convert the dataset to a single pandas DataFrame. + + Parameters + ---------- + include_metadata : bool or str or list of str, default False + If True, include all metadata columns. If a string or list of + strings, include only the specified metadata columns. + include_target : bool, default False + If True, include the 'target' column. + include_crop_inds : bool, default False + If True, include window cropping index columns. + + Returns + ------- + pandas.DataFrame + A DataFrame containing the features and requested metadata. 
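+
+        Examples
+        --------
+        A sketch of flattening an existing ``concat_ds`` (name illustrative)
+        into one table with targets attached:
+
+        >>> df = concat_ds.to_dataframe(include_target=True)  # doctest: +SKIP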
+ + """ if ( not isinstance(include_metadata, bool) or include_metadata @@ -343,7 +419,7 @@ def to_dataframe( include_dataset = False if isinstance(include_metadata, bool) and include_metadata: include_dataset = True - cols = self.datasets[0].metadata.columns + cols = self.datasets[0].metadata.columns.tolist() else: cols = include_metadata if isinstance(cols, bool) and not cols: @@ -352,13 +428,14 @@ def to_dataframe( cols = [cols] cols = set(cols) if include_crop_inds: - cols = { - "i_dataset", - "i_window_in_trial", - "i_start_in_trial", - "i_stop_in_trial", - *cols, - } + cols.update( + { + "i_dataset", + "i_window_in_trial", + "i_start_in_trial", + "i_stop_in_trial", + } + ) if include_target: cols.add("target") cols = list(cols) @@ -381,10 +458,26 @@ def to_dataframe( dataframes = [ds.features for ds in self.datasets] return pd.concat(dataframes, axis=0, ignore_index=True) - def _numeric_columns(self): + def _numeric_columns(self) -> pd.Index: + """Get the names of numeric columns from the feature DataFrames.""" return self.datasets[0].features.select_dtypes(include=np.number).columns - def count(self, numeric_only=False, n_jobs=1): + def count(self, numeric_only: bool = False, n_jobs: int = 1) -> pd.Series: + """Count non-NA cells for each feature column. + + Parameters + ---------- + numeric_only : bool, default False + Include only float, int, boolean columns. + n_jobs : int, default 1 + Number of jobs to run in parallel. + + Returns + ------- + pandas.Series + The count of non-NA cells for each column. + + """ stats = Parallel(n_jobs)( delayed(_compute_stats)(ds, return_count=True, numeric_only=numeric_only) for ds in self.datasets @@ -393,7 +486,22 @@ def count(self, numeric_only=False, n_jobs=1): count = counts.sum(axis=0) return pd.Series(count, index=self._numeric_columns()) - def mean(self, numeric_only=False, n_jobs=1): + def mean(self, numeric_only: bool = False, n_jobs: int = 1) -> pd.Series: + """Compute the mean for each feature column. + + Parameters + ---------- + numeric_only : bool, default False + Include only float, int, boolean columns. + n_jobs : int, default 1 + Number of jobs to run in parallel. + + Returns + ------- + pandas.Series + The mean of each column. + + """ stats = Parallel(n_jobs)( delayed(_compute_stats)( ds, return_count=True, return_mean=True, numeric_only=numeric_only @@ -405,7 +513,26 @@ def mean(self, numeric_only=False, n_jobs=1): mean = np.sum((counts / count) * means, axis=0) return pd.Series(mean, index=self._numeric_columns()) - def var(self, ddof=1, numeric_only=False, n_jobs=1): + def var( + self, ddof: int = 1, numeric_only: bool = False, n_jobs: int = 1 + ) -> pd.Series: + """Compute the variance for each feature column. + + Parameters + ---------- + ddof : int, default 1 + Delta Degrees of Freedom. The divisor used in calculations is N - ddof. + numeric_only : bool, default False + Include only float, int, boolean columns. + n_jobs : int, default 1 + Number of jobs to run in parallel. + + Returns + ------- + pandas.Series + The variance of each column. 
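+
+        Examples
+        --------
+        A sketch computing the pooled variance of the numeric feature
+        columns of an existing ``concat_ds`` (name illustrative):
+
+        >>> feature_var = concat_ds.var(ddof=1, n_jobs=2)  # doctest: +SKIP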
+ + """ stats = Parallel(n_jobs)( delayed(_compute_stats)( ds, @@ -425,12 +552,50 @@ def var(self, ddof=1, numeric_only=False, n_jobs=1): _, _, var = _pooled_var(counts, means, variances, ddof, ddof_in=0) return pd.Series(var, index=self._numeric_columns()) - def std(self, ddof=1, numeric_only=False, eps=0, n_jobs=1): + def std( + self, ddof: int = 1, numeric_only: bool = False, eps: float = 0, n_jobs: int = 1 + ) -> pd.Series: + """Compute the standard deviation for each feature column. + + Parameters + ---------- + ddof : int, default 1 + Delta Degrees of Freedom. + numeric_only : bool, default False + Include only float, int, boolean columns. + eps : float, default 0 + A small epsilon value to add to the variance before taking the + square root to avoid numerical instability. + n_jobs : int, default 1 + Number of jobs to run in parallel. + + Returns + ------- + pandas.Series + The standard deviation of each column. + + """ return np.sqrt( self.var(ddof=ddof, numeric_only=numeric_only, n_jobs=n_jobs) + eps ) - def zscore(self, ddof=1, numeric_only=False, eps=0, n_jobs=1): + def zscore( + self, ddof: int = 1, numeric_only: bool = False, eps: float = 0, n_jobs: int = 1 + ) -> None: + """Apply z-score normalization to numeric columns in-place. + + Parameters + ---------- + ddof : int, default 1 + Delta Degrees of Freedom for variance calculation. + numeric_only : bool, default False + Include only float, int, boolean columns. + eps : float, default 0 + Epsilon for numerical stability. + n_jobs : int, default 1 + Number of jobs to run in parallel for statistics computation. + + """ stats = Parallel(n_jobs)( delayed(_compute_stats)( ds, @@ -450,10 +615,13 @@ def zscore(self, ddof=1, numeric_only=False, eps=0, n_jobs=1): _, mean, var = _pooled_var(counts, means, variances, ddof, ddof_in=0) std = np.sqrt(var + eps) for ds in self.datasets: - ds.features = (ds.features - mean) / std + ds.features.loc[:, self._numeric_columns()] = ( + ds.features.loc[:, self._numeric_columns()] - mean + ) / std @staticmethod - def _enforce_inplace_operations(func_name, kwargs): + def _enforce_inplace_operations(func_name: str, kwargs: dict): + """Raise an error if 'inplace=False' is passed to a method.""" if "inplace" in kwargs and kwargs["inplace"] is False: raise ValueError( f"{func_name} only works inplace, please change " @@ -461,33 +629,49 @@ def _enforce_inplace_operations(func_name, kwargs): ) kwargs["inplace"] = True - def fillna(self, *args, **kwargs): + def fillna(self, *args, **kwargs) -> None: + """Fill NA/NaN values in-place. See :meth:`pandas.DataFrame.fillna`.""" FeaturesConcatDataset._enforce_inplace_operations("fillna", kwargs) for ds in self.datasets: ds.features.fillna(*args, **kwargs) - def replace(self, *args, **kwargs): + def replace(self, *args, **kwargs) -> None: + """Replace values in-place. See :meth:`pandas.DataFrame.replace`.""" FeaturesConcatDataset._enforce_inplace_operations("replace", kwargs) for ds in self.datasets: ds.features.replace(*args, **kwargs) - def interpolate(self, *args, **kwargs): + def interpolate(self, *args, **kwargs) -> None: + """Interpolate values in-place. See :meth:`pandas.DataFrame.interpolate`.""" FeaturesConcatDataset._enforce_inplace_operations("interpolate", kwargs) for ds in self.datasets: ds.features.interpolate(*args, **kwargs) - def dropna(self, *args, **kwargs): + def dropna(self, *args, **kwargs) -> None: + """Remove missing values in-place. 
See :meth:`pandas.DataFrame.dropna`.""" FeaturesConcatDataset._enforce_inplace_operations("dropna", kwargs) for ds in self.datasets: ds.features.dropna(*args, **kwargs) - def drop(self, *args, **kwargs): + def drop(self, *args, **kwargs) -> None: + """Drop specified labels from rows or columns in-place. See :meth:`pandas.DataFrame.drop`.""" FeaturesConcatDataset._enforce_inplace_operations("drop", kwargs) for ds in self.datasets: ds.features.drop(*args, **kwargs) - def join(self, concat_dataset: FeaturesConcatDataset, **kwargs): + def join(self, concat_dataset: FeaturesConcatDataset, **kwargs) -> None: + """Join columns with other FeaturesConcatDataset in-place. + + Parameters + ---------- + concat_dataset : FeaturesConcatDataset + The dataset to join with. Must have the same number of datasets, + and each corresponding dataset must have the same length. + **kwargs + Keyword arguments to pass to :meth:`pandas.DataFrame.join`. + + """ assert len(self.datasets) == len(concat_dataset.datasets) for ds1, ds2 in zip(self.datasets, concat_dataset.datasets): assert len(ds1) == len(ds2) - ds1.features.join(ds2, **kwargs) + ds1.features = ds1.features.join(ds2.features, **kwargs) diff --git a/eegdash/features/decorators.py b/eegdash/features/decorators.py index 841a96a0..687d57fe 100644 --- a/eegdash/features/decorators.py +++ b/eegdash/features/decorators.py @@ -12,6 +12,21 @@ class FeaturePredecessor: + """A decorator to specify parent extractors for a feature function. + + This decorator attaches a list of parent extractor types to a feature + extraction function. This information can be used to build a dependency + graph of features. + + Parameters + ---------- + *parent_extractor_type : list of Type + A list of feature extractor classes (subclasses of + :class:`~eegdash.features.extractors.FeatureExtractor`) that this + feature depends on. + + """ + def __init__(self, *parent_extractor_type: List[Type]): parent_cls = parent_extractor_type if not parent_cls: @@ -20,17 +35,58 @@ def __init__(self, *parent_extractor_type: List[Type]): assert issubclass(p_cls, FeatureExtractor) self.parent_extractor_type = parent_cls - def __call__(self, func: Callable): + def __call__(self, func: Callable) -> Callable: + """Apply the decorator to a function. + + Parameters + ---------- + func : callable + The feature extraction function to decorate. + + Returns + ------- + callable + The decorated function with the `parent_extractor_type` attribute + set. + + """ f = _get_underlying_func(func) f.parent_extractor_type = self.parent_extractor_type return func class FeatureKind: + """A decorator to specify the kind of a feature. + + This decorator attaches a "feature kind" (e.g., univariate, bivariate) + to a feature extraction function. + + Parameters + ---------- + feature_kind : MultivariateFeature + An instance of a feature kind class, such as + :class:`~eegdash.features.extractors.UnivariateFeature` or + :class:`~eegdash.features.extractors.BivariateFeature`. + + """ + def __init__(self, feature_kind: MultivariateFeature): self.feature_kind = feature_kind - def __call__(self, func): + def __call__(self, func: Callable) -> Callable: + """Apply the decorator to a function. + + Parameters + ---------- + func : callable + The feature extraction function to decorate. + + Returns + ------- + callable + The decorated function with the `feature_kind` attribute set. 
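+
+        Examples
+        --------
+        A schematic sketch of decorating a feature function directly with a
+        kind instance (the function itself is illustrative, not part of the
+        feature bank):
+
+        >>> import numpy as np
+        >>> from eegdash.features.decorators import FeatureKind
+        >>> from eegdash.features.extractors import UnivariateFeature
+        >>> @FeatureKind(UnivariateFeature())
+        ... def peak_to_peak(x):
+        ...     return np.ptp(x, axis=-1)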
+ + """ f = _get_underlying_func(func) f.feature_kind = self.feature_kind return func @@ -38,9 +94,33 @@ def __call__(self, func): # Syntax sugar univariate_feature = FeatureKind(UnivariateFeature()) +"""Decorator to mark a feature as univariate. + +This is a convenience instance of :class:`FeatureKind` pre-configured for +univariate features. +""" -def bivariate_feature(func, directed=False): +def bivariate_feature(func: Callable, directed: bool = False) -> Callable: + """Decorator to mark a feature as bivariate. + + This decorator specifies that the feature operates on pairs of channels. + + Parameters + ---------- + func : callable + The feature extraction function to decorate. + directed : bool, default False + If True, the feature is directed (e.g., connectivity from channel A + to B is different from B to A). If False, the feature is undirected. + + Returns + ------- + callable + The decorated function with the appropriate bivariate feature kind + attached. + + """ if directed: kind = DirectedBivariateFeature() else: @@ -49,3 +129,8 @@ def bivariate_feature(func, directed=False): multivariate_feature = FeatureKind(MultivariateFeature()) +"""Decorator to mark a feature as multivariate. + +This is a convenience instance of :class:`FeatureKind` pre-configured for +multivariate features, which operate on all channels simultaneously. +""" diff --git a/eegdash/features/extractors.py b/eegdash/features/extractors.py index e3a785ef..5df28663 100644 --- a/eegdash/features/extractors.py +++ b/eegdash/features/extractors.py @@ -7,7 +7,23 @@ from numba.core.dispatcher import Dispatcher -def _get_underlying_func(func): +def _get_underlying_func(func: Callable) -> Callable: + """Get the underlying function from a potential wrapper. + + This helper unwraps functions that might be wrapped by `functools.partial` + or `numba.dispatcher.Dispatcher`. + + Parameters + ---------- + func : callable + The function to unwrap. + + Returns + ------- + callable + The underlying Python function. + + """ f = func if isinstance(f, partial): f = f.func @@ -17,22 +33,46 @@ def _get_underlying_func(func): class TrainableFeature(ABC): + """Abstract base class for features that require training. + + This ABC defines the interface for feature extractors that need to be + fitted on data before they can be used. It includes methods for fitting + the feature extractor and for resetting its state. + """ + def __init__(self): self._is_trained = False self.clear() @abstractmethod def clear(self): + """Reset the internal state of the feature extractor.""" pass @abstractmethod def partial_fit(self, *x, y=None): + """Update the feature extractor's state with a batch of data. + + Parameters + ---------- + *x : tuple + The input data for fitting. + y : any, optional + The target data, if required for supervised training. + + """ pass def fit(self): + """Finalize the training of the feature extractor. + + This method should be called after all data has been seen via + `partial_fit`. It marks the feature as fitted. + """ self._is_fitted = True def __call__(self, *args, **kwargs): + """Check if the feature is fitted before execution.""" if not self._is_fitted: raise RuntimeError( f"{self.__class__} cannot be called, it has to be fitted first." @@ -40,6 +80,22 @@ def __call__(self, *args, **kwargs): class FeatureExtractor(TrainableFeature): + """A composite feature extractor that applies multiple feature functions. + + This class orchestrates the application of a dictionary of feature + extraction functions to input data. 
It can handle nested extractors, + pre-processing, and trainable features. + + Parameters + ---------- + feature_extractors : dict[str, callable] + A dictionary where keys are feature names and values are the feature + extraction functions or other `FeatureExtractor` instances. + **preprocess_kwargs + Keyword arguments to be passed to the `preprocess` method. + + """ + def __init__( self, feature_extractors: Dict[str, Callable], **preprocess_kwargs: Dict ): @@ -63,30 +119,64 @@ def __init__( if isinstance(fe, partial): self.features_kwargs[fn] = fe.keywords - def _validate_execution_tree(self, feature_extractors): + def _validate_execution_tree(self, feature_extractors: dict) -> dict: + """Validate the feature dependency graph.""" for fname, f in feature_extractors.items(): f = _get_underlying_func(f) pe_type = getattr(f, "parent_extractor_type", [FeatureExtractor]) - assert type(self) in pe_type + if type(self) not in pe_type: + raise TypeError( + f"Feature '{fname}' cannot be a child of {type(self).__name__}" + ) return feature_extractors - def _check_is_trainable(self, feature_extractors): - is_trainable = False + def _check_is_trainable(self, feature_extractors: dict) -> bool: + """Check if any of the contained features are trainable.""" for fname, f in feature_extractors.items(): if isinstance(f, FeatureExtractor): - is_trainable = f._is_trainable - else: - f = _get_underlying_func(f) - if isinstance(f, TrainableFeature): - is_trainable = True - if is_trainable: - break - return is_trainable + if f._is_trainable: + return True + elif isinstance(_get_underlying_func(f), TrainableFeature): + return True + return False def preprocess(self, *x, **kwargs): + """Apply pre-processing to the input data. + + Parameters + ---------- + *x : tuple + Input data. + **kwargs + Additional keyword arguments. + + Returns + ------- + tuple + The pre-processed data. + + """ return (*x,) - def __call__(self, *x, _batch_size=None, _ch_names=None): + def __call__(self, *x, _batch_size=None, _ch_names=None) -> dict: + """Apply all feature extractors to the input data. + + Parameters + ---------- + *x : tuple + Input data. + _batch_size : int, optional + The number of samples in the batch. + _ch_names : list of str, optional + The names of the channels in the input data. + + Returns + ------- + dict + A dictionary where keys are feature names and values are the + computed feature values. 
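+
+        Examples
+        --------
+        A minimal sketch with a single anonymous feature; the input shape
+        ``(batch, channels, times)`` and the channel names are illustrative:
+
+        >>> import numpy as np
+        >>> fe = FeatureExtractor(
+        ...     {"mean": lambda x: x.mean(axis=-1)}
+        ... )  # doctest: +SKIP
+        >>> fe(np.zeros((4, 2, 128)), _batch_size=4,
+        ...    _ch_names=["C3", "C4"])  # doctest: +SKIP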
+ + """ assert _batch_size is not None assert _ch_names is not None if self._is_trainable: @@ -100,59 +190,83 @@ def __call__(self, *x, _batch_size=None, _ch_names=None): r = f(*z, _batch_size=_batch_size, _ch_names=_ch_names) else: r = f(*z) - f = _get_underlying_func(f) - if hasattr(f, "feature_kind"): - r = f.feature_kind(r, _ch_names=_ch_names) + f_und = _get_underlying_func(f) + if hasattr(f_und, "feature_kind"): + r = f_und.feature_kind(r, _ch_names=_ch_names) if not isinstance(fname, str) or not fname: - if isinstance(f, FeatureExtractor) or not hasattr(f, "__name__"): - fname = "" - else: - fname = f.__name__ + fname = getattr(f_und, "__name__", "") if isinstance(r, dict): - if fname: - fname += "_" + prefix = f"{fname}_" if fname else "" for k, v in r.items(): - self._add_feature_to_dict(results_dict, fname + k, v, _batch_size) + self._add_feature_to_dict(results_dict, prefix + k, v, _batch_size) else: self._add_feature_to_dict(results_dict, fname, r, _batch_size) return results_dict - def _add_feature_to_dict(self, results_dict, name, value, batch_size): - if not isinstance(value, np.ndarray): - results_dict[name] = value - else: + def _add_feature_to_dict( + self, results_dict: dict, name: str, value: any, batch_size: int + ): + """Add a computed feature to the results dictionary.""" + if isinstance(value, np.ndarray): assert value.shape[0] == batch_size - results_dict[name] = value + results_dict[name] = value def clear(self): + """Clear the state of all trainable sub-features.""" if not self._is_trainable: return - for fname, f in self.feature_extractors_dict.items(): - f = _get_underlying_func(f) - if isinstance(f, TrainableFeature): - f.clear() + for f in self.feature_extractors_dict.values(): + if isinstance(_get_underlying_func(f), TrainableFeature): + _get_underlying_func(f).clear() def partial_fit(self, *x, y=None): + """Partially fit all trainable sub-features.""" if not self._is_trainable: return z = self.preprocess(*x, **self.preprocess_kwargs) - for fname, f in self.feature_extractors_dict.items(): - f = _get_underlying_func(f) - if isinstance(f, TrainableFeature): - f.partial_fit(*z, y=y) + if not isinstance(z, tuple): + z = (z,) + for f in self.feature_extractors_dict.values(): + if isinstance(_get_underlying_func(f), TrainableFeature): + _get_underlying_func(f).partial_fit(*z, y=y) def fit(self): + """Fit all trainable sub-features.""" if not self._is_trainable: return - for fname, f in self.feature_extractors_dict.items(): - f = _get_underlying_func(f) - if isinstance(f, TrainableFeature): + for f in self.feature_extractors_dict.values(): + if isinstance(_get_underlying_func(f), TrainableFeature): f.fit() super().fit() class MultivariateFeature: - def __call__(self, x, _ch_names=None): + """A mixin for features that operate on multiple channels. + + This class provides a `__call__` method that converts a feature array into + a dictionary with named features, where names are derived from channel + names. + """ + + def __call__( + self, x: np.ndarray, _ch_names: list[str] | None = None + ) -> dict | np.ndarray: + """Convert a feature array to a named dictionary. + + Parameters + ---------- + x : numpy.ndarray + The computed feature array. + _ch_names : list of str, optional + The list of channel names. + + Returns + ------- + dict or numpy.ndarray + A dictionary of named features, or the original array if feature + channel names cannot be generated. 
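+
+        Examples
+        --------
+        A self-contained sketch using the univariate kind, where the second
+        axis of ``x`` indexes channels:
+
+        >>> import numpy as np
+        >>> kind = UnivariateFeature()
+        >>> named = kind(np.zeros((8, 2)), _ch_names=["C3", "C4"])
+        >>> sorted(named)
+        ['C3', 'C4']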
+ + """ assert _ch_names is not None f_channels = self.feature_channel_names(_ch_names) if isinstance(x, dict): @@ -163,37 +277,66 @@ def __call__(self, x, _ch_names=None): return self._array_to_dict(x, f_channels) @staticmethod - def _array_to_dict(x, f_channels, name=""): + def _array_to_dict( + x: np.ndarray, f_channels: list[str], name: str = "" + ) -> dict | np.ndarray: + """Convert a numpy array to a dictionary with named keys.""" assert isinstance(x, np.ndarray) - if len(f_channels) == 0: - assert x.ndim == 1 - if name: - return {name: x} - return x - assert x.shape[1] == len(f_channels) + if not f_channels: + return {name: x} if name else x + assert x.shape[1] == len(f_channels), f"{x.shape[1]} != {len(f_channels)}" x = x.swapaxes(0, 1) - names = [f"{name}_{ch}" for ch in f_channels] if name else f_channels + prefix = f"{name}_" if name else "" + names = [f"{prefix}{ch}" for ch in f_channels] return dict(zip(names, x)) - def feature_channel_names(self, ch_names): + def feature_channel_names(self, ch_names: list[str]) -> list[str]: + """Generate feature names based on channel names. + + Parameters + ---------- + ch_names : list of str + The names of the input channels. + + Returns + ------- + list of str + The names for the output features. + + """ return [] class UnivariateFeature(MultivariateFeature): - def feature_channel_names(self, ch_names): + """A feature kind for operations applied to each channel independently.""" + + def feature_channel_names(self, ch_names: list[str]) -> list[str]: + """Return the channel names themselves as feature names.""" return ch_names class BivariateFeature(MultivariateFeature): - def __init__(self, *args, channel_pair_format="{}<>{}"): + """A feature kind for operations on pairs of channels. + + Parameters + ---------- + channel_pair_format : str, default="{}<>{}" + A format string used to create feature names from pairs of + channel names. + + """ + + def __init__(self, *args, channel_pair_format: str = "{}<>{}"): super().__init__(*args) self.channel_pair_format = channel_pair_format @staticmethod - def get_pair_iterators(n): + def get_pair_iterators(n: int) -> tuple[np.ndarray, np.ndarray]: + """Get indices for unique, unordered pairs of channels.""" return np.triu_indices(n, 1) - def feature_channel_names(self, ch_names): + def feature_channel_names(self, ch_names: list[str]) -> list[str]: + """Generate feature names for each pair of channels.""" return [ self.channel_pair_format.format(ch_names[i], ch_names[j]) for i, j in zip(*self.get_pair_iterators(len(ch_names))) @@ -201,8 +344,11 @@ def feature_channel_names(self, ch_names): class DirectedBivariateFeature(BivariateFeature): + """A feature kind for directed operations on pairs of channels.""" + @staticmethod - def get_pair_iterators(n): + def get_pair_iterators(n: int) -> list[np.ndarray]: + """Get indices for all ordered pairs of channels (excluding self-pairs).""" return [ np.append(a, b) for a, b in zip(np.tril_indices(n, -1), np.triu_indices(n, 1)) diff --git a/eegdash/features/inspect.py b/eegdash/features/inspect.py index 29a6d23f..8e379b58 100644 --- a/eegdash/features/inspect.py +++ b/eegdash/features/inspect.py @@ -5,7 +5,27 @@ from .extractors import FeatureExtractor, MultivariateFeature, _get_underlying_func -def get_feature_predecessors(feature_or_extractor: Callable): +def get_feature_predecessors(feature_or_extractor: Callable) -> list: + """Get the dependency hierarchy for a feature or feature extractor. 
+ + This function recursively traverses the `parent_extractor_type` attribute + of a feature or extractor to build a list representing its dependency + lineage. + + Parameters + ---------- + feature_or_extractor : callable + The feature function or :class:`FeatureExtractor` class to inspect. + + Returns + ------- + list + A nested list representing the dependency tree. For a simple linear + chain, this will be a flat list from the specific feature up to the + base `FeatureExtractor`. For multiple dependencies, it will contain + tuples of sub-dependencies. + + """ current = _get_underlying_func(feature_or_extractor) if current is FeatureExtractor: return [current] @@ -20,18 +40,59 @@ def get_feature_predecessors(feature_or_extractor: Callable): return [current, tuple(predecessors)] -def get_feature_kind(feature: Callable): +def get_feature_kind(feature: Callable) -> MultivariateFeature: + """Get the 'kind' of a feature function. + + The feature kind (e.g., univariate, bivariate) is typically attached by a + decorator. + + Parameters + ---------- + feature : callable + The feature function to inspect. + + Returns + ------- + MultivariateFeature + An instance of the feature kind (e.g., `UnivariateFeature()`). + + """ return _get_underlying_func(feature).feature_kind -def get_all_features(): +def get_all_features() -> list[tuple[str, Callable]]: + """Get a list of all available feature functions. + + Scans the `eegdash.features.feature_bank` module for functions that have + been decorated to have a `feature_kind` attribute. + + Returns + ------- + list[tuple[str, callable]] + A list of (name, function) tuples for all discovered features. + + """ + def isfeature(x): return hasattr(_get_underlying_func(x), "feature_kind") return inspect.getmembers(feature_bank, isfeature) -def get_all_feature_extractors(): +def get_all_feature_extractors() -> list[tuple[str, type[FeatureExtractor]]]: + """Get a list of all available `FeatureExtractor` classes. + + Scans the `eegdash.features.feature_bank` module for all classes that + subclass :class:`~eegdash.features.extractors.FeatureExtractor`. + + Returns + ------- + list[tuple[str, type[FeatureExtractor]]] + A list of (name, class) tuples for all discovered feature extractors, + including the base `FeatureExtractor` itself. + + """ + def isfeatureextractor(x): return inspect.isclass(x) and issubclass(x, FeatureExtractor) @@ -41,7 +102,19 @@ def isfeatureextractor(x): ] -def get_all_feature_kinds(): +def get_all_feature_kinds() -> list[tuple[str, type[MultivariateFeature]]]: + """Get a list of all available feature 'kind' classes. + + Scans the `eegdash.features.extractors` module for all classes that + subclass :class:`~eegdash.features.extractors.MultivariateFeature`. + + Returns + ------- + list[tuple[str, type[MultivariateFeature]]] + A list of (name, class) tuples for all discovered feature kinds. + + """ + def isfeaturekind(x): return inspect.isclass(x) and issubclass(x, MultivariateFeature) diff --git a/eegdash/features/serialization.py b/eegdash/features/serialization.py index 39f09d4c..5aeb1787 100644 --- a/eegdash/features/serialization.py +++ b/eegdash/features/serialization.py @@ -1,7 +1,8 @@ """Convenience functions for storing and loading features datasets. 
-See Also: - https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229 +See Also +-------- +https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229 """ @@ -16,34 +17,40 @@ from .datasets import FeaturesConcatDataset, FeaturesDataset -def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1): - """Load a stored features dataset from files. +def load_features_concat_dataset( + path: str | Path, ids_to_load: list[int] | None = None, n_jobs: int = 1 +) -> FeaturesConcatDataset: + """Load a stored `FeaturesConcatDataset` from a directory. + + This function reconstructs a :class:`FeaturesConcatDataset` by loading + individual :class:`FeaturesDataset` instances from subdirectories within + the given path. It uses joblib for parallel loading. Parameters ---------- - path: str | pathlib.Path - Path to the directory of the .fif / -epo.fif and .json files. - ids_to_load: list of int | None - Ids of specific files to load. - n_jobs: int - Number of jobs to be used to read files in parallel. + path : str or pathlib.Path + The path to the directory where the dataset was saved. This directory + should contain subdirectories (e.g., "0", "1", "2", ...) for each + individual dataset. + ids_to_load : list of int, optional + A list of specific dataset IDs (subdirectory names) to load. If None, + all subdirectories in the path will be loaded. + n_jobs : int, default 1 + The number of jobs to use for parallel loading. -1 means using all + processors. Returns ------- - concat_dataset: eegdash.features.datasets.FeaturesConcatDataset - A concatenation of multiple eegdash.features.datasets.FeaturesDataset - instances loaded from the given directory. + eegdash.features.datasets.FeaturesConcatDataset + A concatenated dataset containing the loaded `FeaturesDataset` instances. """ # Make sure we always work with a pathlib.Path path = Path(path) - # else we have a dataset saved in the new way with subdirectories in path - # for every dataset with description.json and -feat.parquet, - # target_name.json, raw_preproc_kwargs.json, window_kwargs.json, - # window_preproc_kwargs.json, features_kwargs.json if ids_to_load is None: - ids_to_load = [p.name for p in path.iterdir()] + # Get all subdirectories and sort them numerically + ids_to_load = [p.name for p in path.iterdir() if p.is_dir()] ids_to_load = sorted(ids_to_load, key=lambda i: int(i)) ids_to_load = [str(i) for i in ids_to_load] @@ -51,7 +58,26 @@ def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1): return FeaturesConcatDataset(datasets) -def _load_parallel(path, i): +def _load_parallel(path: Path, i: str) -> FeaturesDataset: + """Load a single `FeaturesDataset` from its subdirectory. + + This is a helper function for `load_features_concat_dataset` that handles + the loading of one dataset's files (features, metadata, descriptions, etc.). + + Parameters + ---------- + path : pathlib.Path + The root directory of the saved `FeaturesConcatDataset`. + i : str + The identifier of the dataset to load, corresponding to its + subdirectory name. + + Returns + ------- + eegdash.features.datasets.FeaturesDataset + The loaded dataset instance. 
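+
+    Examples
+    --------
+    A sketch of loading the first sub-dataset from a previously saved
+    directory (the path is illustrative):
+
+    >>> from pathlib import Path
+    >>> ds = _load_parallel(Path("./saved_features"), "0")  # doctest: +SKIP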
+ + """ sub_dir = path / i parquet_name_pattern = "{}-feat.parquet" diff --git a/eegdash/features/utils.py b/eegdash/features/utils.py index 170a514a..5c311496 100644 --- a/eegdash/features/utils.py +++ b/eegdash/features/utils.py @@ -22,7 +22,28 @@ def _extract_features_from_windowsdataset( win_ds: EEGWindowsDataset | WindowsDataset, feature_extractor: FeatureExtractor, batch_size: int = 512, -): +) -> FeaturesDataset: + """Extract features from a single `WindowsDataset`. + + This is a helper function that iterates through a `WindowsDataset` in + batches, applies a `FeatureExtractor`, and returns the results as a + `FeaturesDataset`. + + Parameters + ---------- + win_ds : EEGWindowsDataset or WindowsDataset + The windowed dataset to extract features from. + feature_extractor : FeatureExtractor + The feature extractor instance to apply. + batch_size : int, default 512 + The number of windows to process in each batch. + + Returns + ------- + FeaturesDataset + A new dataset containing the extracted features and associated metadata. + + """ metadata = win_ds.metadata if not win_ds.targets_from == "metadata": metadata = copy.deepcopy(metadata) @@ -51,18 +72,16 @@ def _extract_features_from_windowsdataset( features_dict[k].extend(v) features_df = pd.DataFrame(features_dict) if not win_ds.targets_from == "metadata": - metadata.set_index("orig_index", drop=False, inplace=True) metadata.reset_index(drop=True, inplace=True) - metadata.drop("orig_index", axis=1, inplace=True) + metadata.drop("orig_index", axis=1, inplace=True, errors="ignore") - # FUTURE: truly support WindowsDataset objects return FeaturesDataset( features_df, metadata=metadata, description=win_ds.description, raw_info=win_ds.raw.info, - raw_preproc_kwargs=win_ds.raw_preproc_kwargs, - window_kwargs=win_ds.window_kwargs, + raw_preproc_kwargs=getattr(win_ds, "raw_preproc_kwargs", None), + window_kwargs=getattr(win_ds, "window_kwargs", None), features_kwargs=feature_extractor.features_kwargs, ) @@ -73,7 +92,34 @@ def extract_features( *, batch_size: int = 512, n_jobs: int = 1, -): +) -> FeaturesConcatDataset: + """Extract features from a concatenated dataset of windows. + + This function applies a feature extractor to each `WindowsDataset` within a + `BaseConcatDataset` in parallel and returns a `FeaturesConcatDataset` + with the results. + + Parameters + ---------- + concat_dataset : BaseConcatDataset + A concatenated dataset of `WindowsDataset` or `EEGWindowsDataset` + instances. + features : FeatureExtractor or dict or list + The feature extractor(s) to apply. Can be a `FeatureExtractor` + instance, a dictionary of named feature functions, or a list of + feature functions. + batch_size : int, default 512 + The size of batches to use for feature extraction. + n_jobs : int, default 1 + The number of parallel jobs to use for extracting features from the + datasets. + + Returns + ------- + FeaturesConcatDataset + A new concatenated dataset containing the extracted features. + + """ if isinstance(features, list): features = dict(enumerate(features)) if not isinstance(features, FeatureExtractor): @@ -97,7 +143,28 @@ def fit_feature_extractors( concat_dataset: BaseConcatDataset, features: FeatureExtractor | Dict[str, Callable] | List[Callable], batch_size: int = 8192, -): +) -> FeatureExtractor: + """Fit trainable feature extractors on a dataset. + + If the provided feature extractor (or any of its sub-extractors) is + trainable (i.e., subclasses `TrainableFeature`), this function iterates + through the dataset to fit it. 
+ + Parameters + ---------- + concat_dataset : BaseConcatDataset + The dataset to use for fitting the feature extractors. + features : FeatureExtractor or dict or list + The feature extractor(s) to fit. + batch_size : int, default 8192 + The batch size to use when iterating through the dataset for fitting. + + Returns + ------- + FeatureExtractor + The fitted feature extractor. + + """ if isinstance(features, list): features = dict(enumerate(features)) if not isinstance(features, FeatureExtractor): diff --git a/eegdash/hbn/preprocessing.py b/eegdash/hbn/preprocessing.py index 357d2cc7..102fefce 100644 --- a/eegdash/hbn/preprocessing.py +++ b/eegdash/hbn/preprocessing.py @@ -18,27 +18,47 @@ class hbn_ec_ec_reannotation(Preprocessor): - """Preprocessor to reannotate the raw data for eyes open and eyes closed events. + """Preprocessor to reannotate HBN data for eyes-open/eyes-closed events. - This processor is designed for HBN datasets. + This preprocessor is specifically designed for Healthy Brain Network (HBN) + datasets. It identifies existing annotations for "instructed_toCloseEyes" + and "instructed_toOpenEyes" and creates new, regularly spaced annotations + for "eyes_closed" and "eyes_open" segments, respectively. + + This is useful for creating windowed datasets based on these new, more + precise event markers. + + Notes + ----- + This class inherits from :class:`braindecode.preprocessing.Preprocessor` + and is intended to be used within a braindecode preprocessing pipeline. """ def __init__(self): super().__init__(fn=self.transform, apply_on_array=False) - def transform(self, raw): - """Reannotate the raw data to create new events for eyes open and eyes closed + def transform(self, raw: mne.io.Raw) -> mne.io.Raw: + """Create new annotations for eyes-open and eyes-closed periods. - This function modifies the raw MNE object by creating new events based on - the existing annotations for "instructed_toCloseEyes" and "instructed_toOpenEyes". - It generates new events every 2 seconds within specified time ranges after - the original events, and replaces the existing annotations with these new events. + This function finds the original "instructed_to..." annotations and + generates new annotations every 2 seconds within specific time ranges + relative to the original markers: + - "eyes_closed": 15s to 29s after "instructed_toCloseEyes" + - "eyes_open": 5s to 19s after "instructed_toOpenEyes" + + The original annotations in the `mne.io.Raw` object are replaced by + this new set of annotations. Parameters ---------- raw : mne.io.Raw - The raw MNE object containing EEG data and annotations. + The raw MNE object containing the HBN data and original annotations. + + Returns + ------- + mne.io.Raw + The raw MNE object with the modified annotations. 
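+
+        Examples
+        --------
+        A sketch of re-annotating a single recording; ``raw`` stands for an
+        existing HBN recording loaded with MNE (illustrative):
+
+        >>> raw = hbn_ec_ec_reannotation().transform(raw)  # doctest: +SKIP
+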
""" events, event_id = mne.events_from_annotations(raw) @@ -48,15 +68,27 @@ def transform(self, raw): # Create new events array for 2-second segments new_events = [] sfreq = raw.info["sfreq"] - for event in events[events[:, 2] == event_id["instructed_toCloseEyes"]]: - # For each original event, create events every 2 seconds from 15s to 29s after - start_times = event[0] + np.arange(15, 29, 2) * sfreq - new_events.extend([[int(t), 0, 1] for t in start_times]) - for event in events[events[:, 2] == event_id["instructed_toOpenEyes"]]: - # For each original event, create events every 2 seconds from 5s to 19s after - start_times = event[0] + np.arange(5, 19, 2) * sfreq - new_events.extend([[int(t), 0, 2] for t in start_times]) + close_event_id = event_id.get("instructed_toCloseEyes") + if close_event_id: + for event in events[events[:, 2] == close_event_id]: + # For each original event, create events every 2s from 15s to 29s after + start_times = event[0] + np.arange(15, 29, 2) * sfreq + new_events.extend([[int(t), 0, 1] for t in start_times]) + + open_event_id = event_id.get("instructed_toOpenEyes") + if open_event_id: + for event in events[events[:, 2] == open_event_id]: + # For each original event, create events every 2s from 5s to 19s after + start_times = event[0] + np.arange(5, 19, 2) * sfreq + new_events.extend([[int(t), 0, 2] for t in start_times]) + + if not new_events: + logger.warning( + "Could not find 'instructed_toCloseEyes' or 'instructed_toOpenEyes' " + "annotations. No new events created." + ) + return raw # replace events in raw new_events = np.array(new_events) @@ -65,6 +97,7 @@ def transform(self, raw): events=new_events, event_desc={1: "eyes_closed", 2: "eyes_open"}, sfreq=raw.info["sfreq"], + orig_time=raw.info.get("meas_date"), ) raw.set_annotations(annot_from_events) diff --git a/eegdash/hbn/windows.py b/eegdash/hbn/windows.py index bba77731..8a2ee3d5 100644 --- a/eegdash/hbn/windows.py +++ b/eegdash/hbn/windows.py @@ -21,7 +21,25 @@ def build_trial_table(events_df: pd.DataFrame) -> pd.DataFrame: - """One row per contrast trial with stimulus/response metrics.""" + """Build a table of contrast trials from an events DataFrame. + + This function processes a DataFrame of events (typically from a BIDS + `events.tsv` file) to identify contrast trials and extract relevant + metrics like stimulus onset, response onset, and reaction times. + + Parameters + ---------- + events_df : pandas.DataFrame + A DataFrame containing event information, with at least "onset" and + "value" columns. + + Returns + ------- + pandas.DataFrame + A DataFrame where each row represents a single contrast trial, with + columns for onsets, reaction times, and response correctness. 
+ + """ events_df = events_df.copy() events_df["onset"] = pd.to_numeric(events_df["onset"], errors="raise") events_df = events_df.sort_values("onset", kind="mergesort").reset_index(drop=True) @@ -92,12 +110,13 @@ def build_trial_table(events_df: pd.DataFrame) -> pd.DataFrame: return pd.DataFrame(rows) -# Aux functions to inject the annot def _to_float_or_none(x): + """Safely convert a value to float or None.""" return None if pd.isna(x) else float(x) def _to_int_or_none(x): + """Safely convert a value to int or None.""" if pd.isna(x): return None if isinstance(x, (bool, np.bool_)): @@ -106,22 +125,55 @@ def _to_int_or_none(x): return int(x) try: return int(x) - except Exception: + except (ValueError, TypeError): return None def _to_str_or_none(x): + """Safely convert a value to string or None.""" return None if (x is None or (isinstance(x, float) and np.isnan(x))) else str(x) def annotate_trials_with_target( - raw, - target_field="rt_from_stimulus", - epoch_length=2.0, - require_stimulus=True, - require_response=True, -): - """Create 'contrast_trial_start' annotations with float target in extras.""" + raw: mne.io.Raw, + target_field: str = "rt_from_stimulus", + epoch_length: float = 2.0, + require_stimulus: bool = True, + require_response: bool = True, +) -> mne.io.Raw: + """Create trial annotations with a specified target value. + + This function reads the BIDS events file associated with the `raw` object, + builds a trial table, and creates new MNE annotations for each trial. + The annotations are labeled "contrast_trial_start" and their `extras` + dictionary is populated with trial metrics, including a "target" key. + + Parameters + ---------- + raw : mne.io.Raw + The raw data object. Must have a single associated file name from + which the BIDS path can be derived. + target_field : str, default "rt_from_stimulus" + The column from the trial table to use as the "target" value in the + annotation extras. + epoch_length : float, default 2.0 + The duration to set for each new annotation. + require_stimulus : bool, default True + If True, only include trials that have a recorded stimulus event. + require_response : bool, default True + If True, only include trials that have a recorded response event. + + Returns + ------- + mne.io.Raw + The `raw` object with the new annotations set. + + Raises + ------ + KeyError + If `target_field` is not a valid column in the built trial table. + + """ fnames = raw.filenames assert len(fnames) == 1, "Expected a single filename" bids_path = get_bids_path_from_fname(fnames[0]) @@ -152,7 +204,6 @@ def annotate_trials_with_target( extras = [] for i, v in enumerate(targets): row = trials.iloc[i] - extras.append( { "target": _to_float_or_none(v), @@ -169,14 +220,39 @@ def annotate_trials_with_target( onset=onsets, duration=durations, description=descs, - orig_time=raw.info["meas_date"], + orig_time=raw.info.get("meas_date"), extras=extras, ) raw.set_annotations(new_ann, verbose=False) return raw -def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor"): +def add_aux_anchors( + raw: mne.io.Raw, + stim_desc: str = "stimulus_anchor", + resp_desc: str = "response_anchor", +) -> mne.io.Raw: + """Add auxiliary annotations for stimulus and response onsets. + + This function inspects existing "contrast_trial_start" annotations and + adds new, zero-duration "anchor" annotations at the precise onsets of + stimuli and responses for each trial. 
+ + Parameters + ---------- + raw : mne.io.Raw + The raw data object with "contrast_trial_start" annotations. + stim_desc : str, default "stimulus_anchor" + The description for the new stimulus annotations. + resp_desc : str, default "response_anchor" + The description for the new response annotations. + + Returns + ------- + mne.io.Raw + The `raw` object with the auxiliary annotations added. + + """ ann = raw.annotations mask = ann.description == "contrast_trial_start" if not np.any(mask): @@ -189,28 +265,24 @@ def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor ex = ann.extras[idx] if ann.extras is not None else {} t0 = float(ann.onset[idx]) - stim_t = ex["stimulus_onset"] - resp_t = ex["response_onset"] + stim_t = ex.get("stimulus_onset") + resp_t = ex.get("response_onset") if stim_t is None or (isinstance(stim_t, float) and np.isnan(stim_t)): - rtt = ex["rt_from_trialstart"] - rts = ex["rt_from_stimulus"] + rtt = ex.get("rt_from_trialstart") + rts = ex.get("rt_from_stimulus") if rtt is not None and rts is not None: stim_t = t0 + float(rtt) - float(rts) if resp_t is None or (isinstance(resp_t, float) and np.isnan(resp_t)): - rtt = ex["rt_from_trialstart"] + rtt = ex.get("rt_from_trialstart") if rtt is not None: resp_t = t0 + float(rtt) - if (stim_t is not None) and not ( - isinstance(stim_t, float) and np.isnan(stim_t) - ): + if stim_t is not None and not (isinstance(stim_t, float) and np.isnan(stim_t)): stim_onsets.append(float(stim_t)) stim_extras.append(dict(ex, anchor="stimulus")) - if (resp_t is not None) and not ( - isinstance(resp_t, float) and np.isnan(resp_t) - ): + if resp_t is not None and not (isinstance(resp_t, float) and np.isnan(resp_t)): resp_onsets.append(float(resp_t)) resp_extras.append(dict(ex, anchor="response")) @@ -220,7 +292,7 @@ def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor onset=new_onsets, duration=np.zeros_like(new_onsets, dtype=float), description=[stim_desc] * len(stim_onsets) + [resp_desc] * len(resp_onsets), - orig_time=raw.info["meas_date"], + orig_time=raw.info.get("meas_date"), extras=stim_extras + resp_extras, ) raw.set_annotations(ann + aux, verbose=False) @@ -228,10 +300,10 @@ def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor def add_extras_columns( - windows_concat_ds, - original_concat_ds, - desc="contrast_trial_start", - keys=( + windows_concat_ds: BaseConcatDataset, + original_concat_ds: BaseConcatDataset, + desc: str = "contrast_trial_start", + keys: tuple = ( "target", "rt_from_stimulus", "rt_from_trialstart", @@ -240,7 +312,31 @@ def add_extras_columns( "correct", "response_type", ), -): +) -> BaseConcatDataset: + """Add columns from annotation extras to a windowed dataset's metadata. + + This function propagates trial-level information stored in the `extras` + of annotations to the `metadata` DataFrame of a `WindowsDataset`. + + Parameters + ---------- + windows_concat_ds : BaseConcatDataset + The windowed dataset whose metadata will be updated. + original_concat_ds : BaseConcatDataset + The original (non-windowed) dataset containing the raw data and + annotations with the `extras` to be added. + desc : str, default "contrast_trial_start" + The description of the annotations to source the extras from. + keys : tuple, default (...) + The keys to extract from each annotation's `extras` dictionary and + add as columns to the metadata. + + Returns + ------- + BaseConcatDataset + The `windows_concat_ds` with updated metadata. 
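+
+    Examples
+    --------
+    A sketch of propagating trial-level targets after windowing;
+    ``windows_ds`` and ``raw_ds`` are existing datasets (names
+    illustrative):
+
+    >>> windows_ds = add_extras_columns(
+    ...     windows_ds, raw_ds, desc="contrast_trial_start"
+    ... )  # doctest: +SKIP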
+ + """ float_cols = { "target", "rt_from_stimulus", @@ -292,7 +388,6 @@ def add_extras_columns( else: # response_type ser = pd.Series(vals, index=md.index, dtype="string") - # Replace the whole column to avoid dtype conflicts md[k] = ser win_ds.metadata = md.reset_index(drop=True) @@ -303,7 +398,25 @@ def add_extras_columns( return windows_concat_ds -def keep_only_recordings_with(desc, concat_ds): +def keep_only_recordings_with( + desc: str, concat_ds: BaseConcatDataset +) -> BaseConcatDataset: + """Filter a concatenated dataset to keep only recordings with a specific annotation. + + Parameters + ---------- + desc : str + The description of the annotation that must be present in a recording + for it to be kept. + concat_ds : BaseConcatDataset + The concatenated dataset to filter. + + Returns + ------- + BaseConcatDataset + A new concatenated dataset containing only the filtered recordings. + + """ kept = [] for ds in concat_ds.datasets: if np.any(ds.raw.annotations.description == desc): diff --git a/eegdash/logging.py b/eegdash/logging.py index fa62708c..0e92bc42 100644 --- a/eegdash/logging.py +++ b/eegdash/logging.py @@ -29,6 +29,25 @@ # Now, get your package-specific logger. It will inherit the # configuration from the root logger we just set up. logger = logging.getLogger("eegdash") +"""The primary logger for the EEGDash package. + +This logger is configured to use :class:`rich.logging.RichHandler` for +formatted, colorful output in the console. It inherits its base configuration +from the root logger, which is set to the ``INFO`` level. + +Examples +-------- +Usage in other modules: + +.. code-block:: python + + from .logging import logger + + logger.info("This is an informational message.") + logger.warning("This is a warning.") + logger.error("This is an error.") +""" + logger.setLevel(logging.INFO) diff --git a/eegdash/mongodb.py b/eegdash/mongodb.py index d734530c..897e5643 100644 --- a/eegdash/mongodb.py +++ b/eegdash/mongodb.py @@ -4,50 +4,63 @@ """MongoDB connection and operations management. -This module provides thread-safe MongoDB connection management and high-level database -operations for the EEGDash metadata database. It includes methods for finding, adding, -and updating EEG data records with proper connection pooling and error handling. +This module provides a thread-safe singleton manager for MongoDB connections, +ensuring that connections to the database are handled efficiently and consistently +across the application. """ import threading from pymongo import MongoClient - -# MongoDB Operations -# These methods provide a high-level interface to interact with the MongoDB -# collection, allowing users to find, add, and update EEG data records. -# - find: -# - exist: -# - add_request: -# - add: -# - update_request: -# - remove_field: -# - remove_field_from_db: -# - close: Close the MongoDB connection. -# - __del__: Destructor to close the MongoDB connection. +from pymongo.collection import Collection +from pymongo.database import Database class MongoConnectionManager: - """Singleton class to manage MongoDB client connections.""" + """A thread-safe singleton to manage MongoDB client connections. + + This class ensures that only one connection instance is created for each + unique combination of a connection string and staging flag. It provides + class methods to get a client and to close all active connections. 
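+
+    A minimal usage sketch (the connection string is a placeholder, not a
+    real cluster):
+
+    .. code-block:: python
+
+        client, db, records = MongoConnectionManager.get_client(
+            "mongodb://localhost:27017", is_staging=False
+        )
+        # ... use the records collection ...
+        MongoConnectionManager.close_all()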
+ + Attributes + ---------- + _instances : dict + A dictionary to store singleton instances, mapping a + (connection_string, is_staging) tuple to a (client, db, collection) + tuple. + _lock : threading.Lock + A lock to ensure thread-safe instantiation of clients. + + """ - _instances = {} + _instances: dict[tuple[str, bool], tuple[MongoClient, Database, Collection]] = {} _lock = threading.Lock() @classmethod - def get_client(cls, connection_string: str, is_staging: bool = False): - """Get or create a MongoDB client for the given connection string and staging flag. + def get_client( + cls, connection_string: str, is_staging: bool = False + ) -> tuple[MongoClient, Database, Collection]: + """Get or create a MongoDB client for the given connection parameters. + + This method returns a cached client if one already exists for the given + connection string and staging flag. Otherwise, it creates a new client, + connects to the appropriate database ("eegdash" or "eegdashstaging"), + and returns the client, database, and "records" collection. Parameters ---------- connection_string : str - The MongoDB connection string - is_staging : bool - Whether to use staging database + The MongoDB connection string. + is_staging : bool, default False + If True, connect to the staging database ("eegdashstaging"). + Otherwise, connect to the production database ("eegdash"). Returns ------- - tuple - A tuple of (client, database, collection) + tuple[MongoClient, Database, Collection] + A tuple containing the connected MongoClient instance, the Database + object, and the Collection object for the "records" collection. """ # Create a unique key based on connection string and staging flag @@ -66,8 +79,12 @@ def get_client(cls, connection_string: str, is_staging: bool = False): return cls._instances[key] @classmethod - def close_all(cls): - """Close all MongoDB client connections.""" + def close_all(cls) -> None: + """Close all managed MongoDB client connections. + + This method iterates through all cached client instances and closes + their connections. It also clears the instance cache. + """ with cls._lock: for client, _, _ in cls._instances.values(): try: diff --git a/eegdash/paths.py b/eegdash/paths.py index 495169a8..25655827 100644 --- a/eegdash/paths.py +++ b/eegdash/paths.py @@ -18,12 +18,21 @@ def get_default_cache_dir() -> Path: - """Resolve a consistent default cache directory for EEGDash. + """Resolve the default cache directory for EEGDash data. + + The function determines the cache directory based on the following + priority order: + 1. The path specified by the ``EEGDASH_CACHE_DIR`` environment variable. + 2. The path specified by the ``MNE_DATA`` configuration in the MNE-Python + config file. + 3. A hidden directory named ``.eegdash_cache`` in the current working + directory. + + Returns + ------- + pathlib.Path + The resolved, absolute path to the default cache directory. - Priority order: - 1) Environment variable ``EEGDASH_CACHE_DIR`` if set. - 2) MNE config ``MNE_DATA`` if set (aligns with tests and ecosystem caches). - 3) ``.eegdash_cache`` under the current working directory. """ # 1) Explicit env var wins env_dir = os.environ.get("EEGDASH_CACHE_DIR") diff --git a/eegdash/utils.py b/eegdash/utils.py index 60f3466a..0a9c520e 100644 --- a/eegdash/utils.py +++ b/eegdash/utils.py @@ -11,7 +11,22 @@ from mne.utils import get_config, set_config, use_log_level -def _init_mongo_client(): +def _init_mongo_client() -> None: + """Initialize the default MongoDB connection URI in the MNE config. 
+ + This function checks if the ``EEGDASH_DB_URI`` is already set in the + MNE-Python configuration. If it is not set, this function sets it to the + default public EEGDash MongoDB Atlas cluster URI. + + The operation is performed with MNE's logging level temporarily set to + "ERROR" to suppress verbose output. + + Notes + ----- + This is an internal helper function and is not intended for direct use + by end-users. + + """ with use_log_level("ERROR"): if get_config("EEGDASH_DB_URI") is None: set_config(