\ No newline at end of file
diff --git a/docs/source/_static/logos/ucsd_white.png b/docs/source/_static/logos/ucsd_white.png
deleted file mode 100644
index 038ad03c..00000000
Binary files a/docs/source/_static/logos/ucsd_white.png and /dev/null differ
diff --git a/docs/source/_static/logos/ucsd_white.svg b/docs/source/_static/logos/ucsd_white.svg
new file mode 100644
index 00000000..b6199fdc
--- /dev/null
+++ b/docs/source/_static/logos/ucsd_white.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/source/api/api_core.rst b/docs/source/api/api_core.rst
index 2896efb0..68987a2b 100644
--- a/docs/source/api/api_core.rst
+++ b/docs/source/api/api_core.rst
@@ -76,4 +76,4 @@ API Reference
hbn
mongodb
paths
- utils
\ No newline at end of file
+ utils
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 34aaf796..0a357598 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -12,13 +12,15 @@
from sphinx_gallery.sorting import ExplicitOrder, FileNameSortKey
from tabulate import tabulate
+sys.path.insert(0, os.path.abspath(".."))
+
import eegdash
# -- Project information -----------------------------------------------------
project = "EEG Dash"
copyright = f"2025–{datetime.now(tz=timezone.utc).year}, {project} Developers"
-author = "Arnaud Delorme"
+author = "Bruno Aristimunha and Arnaud Delorme"
release = eegdash.__version__
version = ".".join(release.split(".")[:2])
@@ -44,6 +46,7 @@
"sphinx_sitemap",
"sphinx_copybutton",
"sphinx.ext.graphviz",
+ "sphinx_time_estimation",
]
templates_path = ["_templates"]
@@ -103,8 +106,8 @@
"navbar_end": ["theme-switcher", "navbar-icon-links"],
"footer_start": ["copyright"],
"logo": {
- "image_light": "_static/eegdash_long.png",
- "image_dark": "_static/eegdash_long.png",
+ "image_light": "_static/eegdash_long_white.svg",
+ "image_dark": "_static/eegdash_long_dark.svg",
"alt_text": "EEG Dash Logo",
},
"external_links": [
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 632007c6..3a32e2b6 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -5,53 +5,76 @@ EEGDASH Homepage
.. title:: EEG Dash
-.. raw:: html
-
-
EEG Dash Homepage
==================
-.. image:: _static/logos/eegdash.png
- :alt: EEG Dash Logo
- :class: logo mainlogo
- :align: center
- :scale: 40%
+.. raw:: html
+
+
+
+.. raw:: html
+ EEG Dash
+
+
+.. image:: _static/logos/eegdash.svg
+ :alt: EEG Dash Logo
+ :class: logo mainlogo only-dark
+ :align: center
+ :scale: 50%
+
+.. image:: _static/logos/eegdash.svg
+ :alt: EEG Dash Logo
+ :class: logo mainlogo only-light
+ :align: center
+ :scale: 50%
.. rst-class:: h4 text-center font-weight-light my-4
-The EEG-DaSh data archive will establish a data-sharing resource for MEEG (EEG, MEG) data, enabling
-large-scale computational advancements to preserve and share scientific data from publicly funded
-research for machine learning and deep learning applications.
+
+ The EEG-DaSh data archive is a data-sharing resource for MEEG (EEG, MEG) data. It preserves and
+ shares scientific data from publicly funded research, enabling large-scale machine learning and
+ deep learning applications.
.. rst-class:: text-center
-**Note:** The "DaSh" in EEG-DaSh stands for **Data Share**.
+ The "DaSh" in EEG-DaSh stands for **Data Share**.
-The EEG-DaSh data archive is a collaborative effort led by the University of California, San Diego (UCSD) and Ben-Gurion University of the Negev (BGU) and partially funded by the National Science Foundation (NSF). All are welcome to contribute to the https://github.com/sccn/EEGDash project.
+ The EEG-DaSh data archive is a collaborative effort led by the University of California, San Diego (UCSD) and Ben-Gurion University of the Negev (BGU) and partially funded by the National Science Foundation (NSF). All are welcome to contribute to the https://github.com/sccn/EEGDash project.
-The archive is currently still in :bdg-danger:`beta testing` mode, so be kind.
+ The archive is currently still in :bdg-danger:`beta testing` mode, so be kind.
.. raw:: html
Institutions
-.. list-table::
- :width: 100%
- :class: borderless logos-row
-
- * - .. image:: _static/logos/ucsd_white.png
- :alt: University of California, San Diego (UCSD)
- :class: logo mainlogo
- :align: center
- :width: 100%
-
- - .. image:: _static/logos/bgu_white.png
- :alt: Ben-Gurion University of the Negev (BGU)
- :class: logo mainlogo
- :align: center
- :width: 100%
+.. image:: _static/logos/ucsd_white.svg
+ :alt: UCSD
+ :class: logo mainlogo only-dark flex-logo
+ :width: 45%
+ :align: left
+
+
+.. image:: _static/logos/ucsd_dark.svg
+ :alt: UCSD
+ :class: logo mainlogo only-light flex-logo
+ :align: left
+ :width: 45%
+
+
+.. image:: _static/logos/bgu_dark.svg
+ :alt: Ben-Gurion University of the Negev (BGU)
+ :class: logo mainlogo only-dark flex-logo
+ :align: right
+ :width: 40%
+
+.. image:: _static/logos/bgu_white.svg
+ :alt: Ben-Gurion University of the Negev (BGU)
+ :class: logo mainlogo only-light flex-logo
+ :align: right
+ :width: 40%
+
.. toctree::
:hidden:
diff --git a/docs/sphinx_time_estimation.py b/docs/sphinx_time_estimation.py
new file mode 100644
index 00000000..a99472cb
--- /dev/null
+++ b/docs/sphinx_time_estimation.py
@@ -0,0 +1,106 @@
+from __future__ import annotations
+
+import math
+import re
+
+from docutils import nodes
+
+SKIP_CONTAINER_CLASSES = {
+ "sphx-glr-script-out",
+ "sphx-glr-single-img",
+ "sphx-glr-thumbnail",
+ "sphx-glr-horizontal",
+}
+
+
+class TextExtractor(nodes.NodeVisitor):
+ def __init__(self, document):
+ super().__init__(document)
+ self.text = []
+
+ def visit_Text(self, node):
+ self.text.append(node.astext())
+
+ def visit_literal_block(self, node):
+ # Don't visit the children of literal blocks (i.e., code blocks)
+ raise nodes.SkipNode
+
+ def visit_figure(self, node):
+ raise nodes.SkipNode
+
+ def visit_image(self, node):
+ raise nodes.SkipNode
+
+ def visit_container(self, node):
+ classes = set(node.get("classes", ()))
+ if classes & SKIP_CONTAINER_CLASSES:
+ raise nodes.SkipNode
+
+ def unknown_visit(self, node):
+ """Pass for all other nodes."""
+ pass
+
+
+EXAMPLE_PREFIX = "generated/auto_examples/"
+
+
+def _should_calculate(pagename: str) -> bool:
+ if not pagename:
+ return False
+ if not pagename.startswith(EXAMPLE_PREFIX):
+ return False
+ if pagename.endswith("/sg_execution_times"):
+ return False
+ if pagename == "generated/auto_examples/index":
+ return False
+ return True
+
+
+def html_page_context(app, pagename, templatename, context, doctree):
+ """Add estimated reading time directly under tutorial titles."""
+ if not doctree or not _should_calculate(pagename):
+ context.pop("reading_time", None)
+ return
+
+ visitor = TextExtractor(doctree)
+ doctree.walk(visitor)
+
+ full_text = " ".join(visitor.text)
+ word_count = len(re.findall(r"\w+", full_text))
+
+ wpm = 200 # Median reading speed
+ reading_time = math.ceil(word_count / wpm) if wpm > 0 else 0
+
+ if reading_time <= 0:
+ context.pop("reading_time", None)
+ return
+
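+    # Expose the estimate to Jinja templates through the HTML page context so
+    # a theme can also render it in a custom location if desired.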
+ context["reading_time"] = reading_time
+
+ body = context.get("body")
+    if not isinstance(body, str) or "</h1>" not in body:
+        return
+
+    minutes_label = "minute" if reading_time == 1 else "minutes"
+    # The wrapper markup below is a minimal sketch; the class names are
+    # illustrative and can be adapted to the theme's styling.
+    badge_html = (
+        '<p class="reading-time-badge">'
+        '<span class="reading-time-label">Estimated reading time: </span>'
+        f'<span class="reading-time-value">{reading_time} {minutes_label}</span>'
+        "</p>"
+    )
+
+    insert_at = body.find("</h1>")
+    if insert_at == -1:
+        return
+
+    # Insert the badge right after the closing </h1> of the page title
+    # (len("</h1>") == 5).
+    context["body"] = body[: insert_at + 5] + badge_html + body[insert_at + 5 :]
+
+
+def setup(app):
+ """Setup the Sphinx extension."""
+ app.connect("html-page-context", html_page_context)
+ return {
+ "version": "0.1",
+ "parallel_read_safe": True,
+ "parallel_write_safe": True,
+ }
diff --git a/eegdash/api.py b/eegdash/api.py
index ca98d3e6..934eb67c 100644
--- a/eegdash/api.py
+++ b/eegdash/api.py
@@ -212,17 +212,22 @@ def exist(self, query: dict[str, Any]) -> bool:
return doc is not None
def _validate_input(self, record: dict[str, Any]) -> dict[str, Any]:
- """Internal method to validate the input record against the expected schema.
+ """Validate the input record against the expected schema.
Parameters
----------
- record: dict
+ record : dict
A dictionary representing the EEG data record to be validated.
Returns
-------
- dict:
- Returns the record itself on success, or raises a ValueError if the record is invalid.
+ dict
+ The record itself on success.
+
+ Raises
+ ------
+ ValueError
+ If the record is missing required keys or has values of the wrong type.
"""
input_types = {
@@ -252,20 +257,44 @@ def _validate_input(self, record: dict[str, Any]) -> dict[str, Any]:
return record
def _build_query_from_kwargs(self, **kwargs) -> dict[str, Any]:
- """Internal helper to build a validated MongoDB query from keyword args.
+ """Build a validated MongoDB query from keyword arguments.
+
+ This delegates to the module-level builder used across the package.
+
+ Parameters
+ ----------
+ **kwargs
+ Keyword arguments to convert into a MongoDB query.
+
+ Returns
+ -------
+ dict
+ A MongoDB query dictionary.
- This delegates to the module-level builder used across the package and
- is exposed here for testing and convenience.
"""
return build_query_from_kwargs(**kwargs)
- # --- Query merging and conflict detection helpers ---
- def _extract_simple_constraint(self, query: dict[str, Any], key: str):
+ def _extract_simple_constraint(
+ self, query: dict[str, Any], key: str
+ ) -> tuple[str, Any] | None:
"""Extract a simple constraint for a given key from a query dict.
- Supports only top-level equality (key: value) and $in (key: {"$in": [...]})
- constraints. Returns a tuple (kind, value) where kind is "eq" or "in". If the
- key is not present or uses other operators, returns None.
+ Supports top-level equality (e.g., ``{'subject': '01'}``) and ``$in``
+ (e.g., ``{'subject': {'$in': ['01', '02']}}``) constraints.
+
+ Parameters
+ ----------
+ query : dict
+ The MongoDB query dictionary.
+ key : str
+ The key for which to extract the constraint.
+
+ Returns
+ -------
+ tuple or None
+ A tuple of (kind, value) where kind is "eq" or "in", or None if the
+ constraint is not present or unsupported.
+
"""
if not isinstance(query, dict) or key not in query:
return None
@@ -275,16 +304,28 @@ def _extract_simple_constraint(self, query: dict[str, Any], key: str):
return ("in", list(val["$in"]))
return None # unsupported operator shape for conflict checking
else:
- return ("eq", val)
+ return "eq", val
def _raise_if_conflicting_constraints(
self, raw_query: dict[str, Any], kwargs_query: dict[str, Any]
) -> None:
- """Raise ValueError if both query sources define incompatible constraints.
+ """Raise ValueError if query sources have incompatible constraints.
+
+ Checks for mutually exclusive constraints on the same field to avoid
+ silent empty results.
+
+ Parameters
+ ----------
+ raw_query : dict
+ The raw MongoDB query dictionary.
+ kwargs_query : dict
+ The query dictionary built from keyword arguments.
+
+ Raises
+ ------
+ ValueError
+ If conflicting constraints are found.
- We conservatively check only top-level fields with simple equality or $in
- constraints. If a field appears in both queries and constraints are mutually
- exclusive, raise an explicit error to avoid silent empty result sets.
"""
if not raw_query or not kwargs_query:
return
@@ -388,12 +429,31 @@ def add_bids_dataset(
logger.info("Upserted: %s", result.upserted_count)
logger.info("Errors: %s ", result.bulk_api_result.get("writeErrors", []))
- def _add_request(self, record: dict):
- """Internal helper method to create a MongoDB insertion request for a record."""
+ def _add_request(self, record: dict) -> InsertOne:
+ """Create a MongoDB insertion request for a record.
+
+ Parameters
+ ----------
+ record : dict
+ The record to insert.
+
+ Returns
+ -------
+ InsertOne
+ A PyMongo ``InsertOne`` object.
+
+ """
return InsertOne(record)
- def add(self, record: dict):
- """Add a single record to the MongoDB collection."""
+ def add(self, record: dict) -> None:
+ """Add a single record to the MongoDB collection.
+
+ Parameters
+ ----------
+ record : dict
+ The record to add.
+
+ """
try:
self.__collection.insert_one(record)
except ValueError as e:
@@ -405,11 +465,23 @@ def add(self, record: dict):
)
logger.debug("Add operation failed", exc_info=exc)
- def _update_request(self, record: dict):
- """Internal helper method to create a MongoDB update request for a record."""
+ def _update_request(self, record: dict) -> UpdateOne:
+ """Create a MongoDB update request for a record.
+
+ Parameters
+ ----------
+ record : dict
+ The record to update.
+
+ Returns
+ -------
+ UpdateOne
+ A PyMongo ``UpdateOne`` object.
+
+ """
return UpdateOne({"data_name": record["data_name"]}, {"$set": record})
- def update(self, record: dict):
+ def update(self, record: dict) -> None:
"""Update a single record in the MongoDB collection.
Parameters
@@ -429,58 +501,81 @@ def update(self, record: dict):
logger.debug("Update operation failed", exc_info=exc)
def exists(self, query: dict[str, Any]) -> bool:
- """Alias for :meth:`exist` provided for API clarity."""
+ """Check if at least one record matches the query.
+
+ This is an alias for :meth:`exist`.
+
+ Parameters
+ ----------
+ query : dict
+ MongoDB query to check for existence.
+
+ Returns
+ -------
+ bool
+ True if a matching record exists, False otherwise.
+
+ """
return self.exist(query)
- def remove_field(self, record, field):
- """Remove a specific field from a record in the MongoDB collection.
+ def remove_field(self, record: dict, field: str) -> None:
+ """Remove a field from a specific record in the MongoDB collection.
Parameters
----------
record : dict
- Record identifying object with ``data_name``.
+ Record-identifying object with a ``data_name`` key.
field : str
- Field name to remove.
+ The name of the field to remove.
"""
self.__collection.update_one(
{"data_name": record["data_name"]}, {"$unset": {field: 1}}
)
- def remove_field_from_db(self, field):
- """Remove a field from all records (destructive).
+ def remove_field_from_db(self, field: str) -> None:
+ """Remove a field from all records in the database.
+
+ .. warning::
+ This is a destructive operation and cannot be undone.
Parameters
----------
field : str
- Field name to remove from every document.
+ The name of the field to remove from all documents.
"""
self.__collection.update_many({}, {"$unset": {field: 1}})
@property
def collection(self):
- """Return the MongoDB collection object."""
- return self.__collection
+ """The underlying PyMongo ``Collection`` object.
- def close(self):
- """Backward-compatibility no-op; connections are managed globally.
+ Returns
+ -------
+ pymongo.collection.Collection
+ The collection object used for database interactions.
- Notes
- -----
- Connections are managed by :class:`MongoConnectionManager`. Use
- :meth:`close_all_connections` to explicitly close all clients.
+ """
+ return self.__collection
+ def close(self) -> None:
+ """Close the MongoDB connection.
+
+ .. deprecated:: 0.1
+ Connections are now managed globally by :class:`MongoConnectionManager`.
+ This method is a no-op and will be removed in a future version.
+ Use :meth:`EEGDash.close_all_connections` to close all clients.
"""
# Individual instances no longer close the shared client
pass
@classmethod
- def close_all_connections(cls):
- """Close all MongoDB client connections managed by the singleton."""
+ def close_all_connections(cls) -> None:
+ """Close all MongoDB client connections managed by the singleton manager."""
MongoConnectionManager.close_all()
- def __del__(self):
+ def __del__(self) -> None:
"""Destructor; no explicit action needed due to global connection manager."""
# No longer needed since we're using singleton pattern
pass
@@ -775,45 +870,30 @@ def _find_local_bids_records(
) -> list[dict]:
"""Discover local BIDS EEG files and build minimal records.
- This helper enumerates EEG recordings under ``dataset_root`` via
- ``mne_bids.find_matching_paths`` and applies entity filters to produce a
- list of records suitable for ``EEGDashBaseDataset``. No network access
- is performed and files are not read.
+ Enumerates EEG recordings under ``dataset_root`` using
+ ``mne_bids.find_matching_paths`` and applies entity filters to produce
+ records suitable for :class:`EEGDashBaseDataset`. No network access is
+ performed, and files are not read.
Parameters
----------
dataset_root : Path
- Local dataset directory. May be the plain dataset folder (e.g.,
- ``ds005509``) or a suffixed cache variant (e.g.,
- ``ds005509-bdf-mini``).
- filters : dict of {str, Any}
- Query filters. Must include ``'dataset'`` with the dataset id (without
- local suffixes). May include BIDS entities ``'subject'``,
- ``'session'``, ``'task'``, and ``'run'``. Each value can be a scalar
- or a sequence of scalars.
+ Local dataset directory (e.g., ``/path/to/cache/ds005509``).
+ filters : dict
+ Query filters. Must include ``'dataset'`` and may include BIDS
+ entities like ``'subject'``, ``'session'``, etc.
Returns
-------
- records : list of dict
- One record per matched EEG file with at least:
-
- - ``'data_name'``
- - ``'dataset'`` (dataset id, without suffixes)
- - ``'bidspath'`` (normalized to start with the dataset id)
- - ``'subject'``, ``'session'``, ``'task'``, ``'run'`` (may be None)
- - ``'bidsdependencies'`` (empty list)
- - ``'modality'`` (``"eeg"``)
- - ``'sampling_frequency'``, ``'nchans'``, ``'ntimes'`` (minimal
- defaults for offline usage)
+ list of dict
+ A list of records, one for each matched EEG file. Each record
+ contains BIDS entities, paths, and minimal metadata for offline use.
Notes
-----
- - Matching uses ``datatypes=['eeg']`` and ``suffixes=['eeg']``.
- - ``bidspath`` is constructed as
- `` / `` to ensure the
- first path component is the dataset id (without local cache suffixes).
- - Minimal defaults are set for ``sampling_frequency``, ``nchans``, and
- ``ntimes`` to satisfy dataset length requirements offline.
+ Matching is performed for ``datatypes=['eeg']`` and ``suffixes=['eeg']``.
+ The ``bidspath`` is normalized to ensure it starts with the dataset ID,
+ even for suffixed cache directories.
"""
dataset_id = filters["dataset"]
@@ -875,10 +955,22 @@ def _find_local_bids_records(
return records_out
def _find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
- """Recursively search for target_key in nested dicts/lists with normalized matching.
+ """Recursively search for a key in nested dicts/lists.
+
+ Performs a case-insensitive and underscore/hyphen-agnostic search.
+
+ Parameters
+ ----------
+ data : Any
+ The nested data structure (dicts, lists) to search.
+ target_key : str
+ The key to search for.
+
+ Returns
+ -------
+ Any
+ The value of the first matching key, or None if not found.
- This makes lookups tolerant to naming differences like "p-factor" vs "p_factor".
- Returns the first match or None.
"""
norm_target = normalize_key(target_key)
if isinstance(data, dict):
@@ -901,23 +993,25 @@ def _find_datasets(
description_fields: list[str],
base_dataset_kwargs: dict,
) -> list[EEGDashBaseDataset]:
- """Helper method to find datasets in the MongoDB collection that satisfy the
- given query and return them as a list of EEGDashBaseDataset objects.
+ """Find and construct datasets from a MongoDB query.
+
+ Queries the database, then creates a list of
+ :class:`EEGDashBaseDataset` objects from the results.
Parameters
----------
- query : dict
- The query object, as in EEGDash.find().
- description_fields : list[str]
- A list of fields to be extracted from the dataset records and included in
- the returned dataset description(s).
- kwargs: additional keyword arguments to be passed to the EEGDashBaseDataset
- constructor.
+ query : dict, optional
+ The MongoDB query to execute.
+ description_fields : list of str
+ Fields to extract from each record for the dataset description.
+ base_dataset_kwargs : dict
+ Additional keyword arguments to pass to the
+ :class:`EEGDashBaseDataset` constructor.
Returns
-------
- list :
- A list of EEGDashBaseDataset objects that match the query.
+ list of EEGDashBaseDataset
+ A list of dataset objects matching the query.
"""
datasets: list[EEGDashBaseDataset] = []
diff --git a/eegdash/bids_eeg_metadata.py b/eegdash/bids_eeg_metadata.py
index 150aed34..bc8648fa 100644
--- a/eegdash/bids_eeg_metadata.py
+++ b/eegdash/bids_eeg_metadata.py
@@ -33,12 +33,30 @@
def build_query_from_kwargs(**kwargs) -> dict[str, Any]:
- """Build and validate a MongoDB query from user-friendly keyword arguments.
+ """Build and validate a MongoDB query from keyword arguments.
+
+ This function converts user-friendly keyword arguments into a valid
+ MongoDB query dictionary. It handles scalar values as exact matches and
+ list-like values as ``$in`` queries. It also performs validation to
+ reject unsupported fields and empty values.
+
+ Parameters
+ ----------
+ **kwargs
+ Keyword arguments representing query filters. Allowed keys are defined
+ in ``eegdash.const.ALLOWED_QUERY_FIELDS``.
+
+ Returns
+ -------
+ dict
+ A MongoDB query dictionary.
+
+ Raises
+ ------
+ ValueError
+ If an unsupported query field is provided, or if a value is None or
+ an empty string/list.
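+
+    Examples
+    --------
+    Scalars become exact matches and sequences become ``$in`` clauses; the
+    values below are illustrative::
+
+        query = build_query_from_kwargs(dataset="ds002718", subject=["01", "02"])
+        # roughly: {"dataset": "ds002718", "subject": {"$in": ["01", "02"]}}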
- Improvements:
- - Reject None values and empty/whitespace-only strings
- - For list/tuple/set values: strip strings, drop None/empties, deduplicate, and use `$in`
- - Preserve scalars as exact matches
"""
# 1. Validate that all provided keys are allowed for querying
unknown_fields = set(kwargs.keys()) - ALLOWED_QUERY_FIELDS
@@ -89,24 +107,29 @@ def build_query_from_kwargs(**kwargs) -> dict[str, Any]:
def load_eeg_attrs_from_bids_file(bids_dataset, bids_file: str) -> dict[str, Any]:
- """Build the metadata record for a given BIDS file (single recording) in a BIDS dataset.
+ """Build a metadata record for a BIDS file.
- Attributes are at least the ones defined in data_config attributes (set to None if missing),
- but are typically a superset, and include, among others, the paths to relevant
- meta-data files needed to load and interpret the file in question.
+ Extracts metadata attributes from a single BIDS EEG file within a given
+ BIDS dataset. The extracted attributes include BIDS entities, file paths,
+ and technical metadata required for database indexing.
Parameters
----------
bids_dataset : EEGBIDSDataset
The BIDS dataset object containing the file.
bids_file : str
- The path to the BIDS file within the dataset.
+ The path to the BIDS file to process.
Returns
-------
- dict:
- A dictionary representing the metadata record for the given file. This is the
- same format as the records stored in the database.
+ dict
+ A dictionary of metadata attributes for the file, suitable for
+ insertion into the database.
+
+ Raises
+ ------
+ ValueError
+ If ``bids_file`` is not found in the ``bids_dataset``.
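+
+    Examples
+    --------
+    A minimal sketch (the dataset object and file path are illustrative)::
+
+        record = load_eeg_attrs_from_bids_file(
+            bids_dataset, "ds002718/sub-012/eeg/sub-012_task-rest_eeg.set"
+        )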
"""
if bids_file not in bids_dataset.files:
@@ -198,11 +221,23 @@ def load_eeg_attrs_from_bids_file(bids_dataset, bids_file: str) -> dict[str, Any
def normalize_key(key: str) -> str:
- """Normalize a metadata key for robust matching.
+ """Normalize a string key for robust matching.
+
+ Converts the key to lowercase, replaces non-alphanumeric characters with
+ underscores, and removes leading/trailing underscores. This allows for
+ tolerant matching of keys that may have different capitalization or
+ separators (e.g., "p-factor" becomes "p_factor").
+
+ Parameters
+ ----------
+ key : str
+ The key to normalize.
+
+ Returns
+ -------
+ str
+ The normalized key.
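+
+    Examples
+    --------
+    Keys that differ only in case or separators normalize to the same value:
+
+    >>> normalize_key("P Factor")
+    'p_factor'
+    >>> normalize_key("p-factor")
+    'p_factor'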
- Lowercase and replace non-alphanumeric characters with underscores, then strip
- leading/trailing underscores. This allows tolerant matching such as
- "p-factor" ≈ "p_factor" ≈ "P Factor".
"""
return re.sub(r"[^a-z0-9]+", "_", str(key).lower()).strip("_")
@@ -212,27 +247,27 @@ def merge_participants_fields(
participants_row: dict[str, Any] | None,
description_fields: list[str] | None = None,
) -> dict[str, Any]:
- """Merge participants.tsv fields into a dataset description dictionary.
+ """Merge fields from a participants.tsv row into a description dict.
- - Preserves existing entries in ``description`` (no overwrites).
- - Fills requested ``description_fields`` first, preserving their original names.
- - Adds all remaining participants columns generically using normalized keys
- unless a matching requested field already captured them.
+ Enriches a description dictionary with data from a subject's row in
+ ``participants.tsv``. It avoids overwriting existing keys in the
+ description.
Parameters
----------
description : dict
- Current description to be enriched in-place and returned.
- participants_row : dict | None
- A mapping of participants.tsv columns for the current subject.
- description_fields : list[str] | None
- Optional list of requested description fields. When provided, matching is
- performed by normalized names; the original requested field names are kept.
+ The description dictionary to enrich.
+ participants_row : dict or None
+ A dictionary representing a row from ``participants.tsv``. If None,
+ the original description is returned unchanged.
+ description_fields : list of str, optional
+ A list of specific fields to include in the description. Matching is
+ done using normalized keys.
Returns
-------
dict
- The enriched description (same object as input for convenience).
+ The enriched description dictionary.
"""
if not isinstance(description, dict) or not isinstance(participants_row, dict):
@@ -272,10 +307,26 @@ def participants_row_for_subject(
subject: str,
id_columns: tuple[str, ...] = ("participant_id", "participant", "subject"),
) -> pd.Series | None:
- """Load participants.tsv and return the row for a subject.
+ """Load participants.tsv and return the row for a specific subject.
+
+ Searches for a subject's data in the ``participants.tsv`` file within a
+ BIDS dataset. It can identify the subject with or without the "sub-"
+ prefix.
+
+ Parameters
+ ----------
+ bids_root : str or Path
+ The root directory of the BIDS dataset.
+ subject : str
+ The subject identifier (e.g., "01" or "sub-01").
+ id_columns : tuple of str, default ("participant_id", "participant", "subject")
+ A tuple of column names to search for the subject identifier.
+
+ Returns
+ -------
+ pandas.Series or None
+ A pandas Series containing the subject's data if found, otherwise None.
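+
+    Examples
+    --------
+    Either form of the subject label is accepted (the path is illustrative)::
+
+        row = participants_row_for_subject("/data/ds002718", "sub-012")
+        row = participants_row_for_subject("/data/ds002718", "012")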
- - Accepts either "01" or "sub-01" as the subject identifier.
- - Returns a pandas Series for the first matching row, or None if not found.
"""
try:
participants_tsv = Path(bids_root) / "participants.tsv"
@@ -311,9 +362,28 @@ def participants_extras_from_tsv(
id_columns: tuple[str, ...] = ("participant_id", "participant", "subject"),
na_like: tuple[str, ...] = ("", "n/a", "na", "nan", "unknown", "none"),
) -> dict[str, Any]:
- """Return non-identifier, non-empty participants.tsv fields for a subject.
+ """Extract additional participant information from participants.tsv.
+
+ Retrieves all non-identifier and non-empty fields for a subject from
+ the ``participants.tsv`` file.
+
+ Parameters
+ ----------
+ bids_root : str or Path
+ The root directory of the BIDS dataset.
+ subject : str
+ The subject identifier.
+ id_columns : tuple of str, default ("participant_id", "participant", "subject")
+ Column names to be treated as identifiers and excluded from the
+ output.
+ na_like : tuple of str, default ("", "n/a", "na", "nan", "unknown", "none")
+ Values to be considered as "Not Available" and excluded.
+
+ Returns
+ -------
+ dict
+ A dictionary of extra participant information.
- Uses vectorized pandas operations to drop id columns and NA-like values.
"""
row = participants_row_for_subject(bids_root, subject, id_columns=id_columns)
if row is None:
@@ -331,10 +401,21 @@ def attach_participants_extras(
description: Any,
extras: dict[str, Any],
) -> None:
- """Attach extras to Raw.info and dataset description without overwriting.
+ """Attach extra participant data to a raw object and its description.
+
+ Updates the ``raw.info['subject_info']`` and the description object
+ (dict or pandas Series) with extra data from ``participants.tsv``.
+ It does not overwrite existing keys.
+
+ Parameters
+ ----------
+ raw : mne.io.Raw
+ The MNE Raw object to be updated.
+ description : dict or pandas.Series
+ The description object to be updated.
+ extras : dict
+ A dictionary of extra participant information to attach.
- - Adds to ``raw.info['subject_info']['participants_extras']``.
- - Adds to ``description`` if dict or pandas Series (only missing keys).
"""
if not extras:
return
@@ -375,9 +456,28 @@ def enrich_from_participants(
raw: Any,
description: Any,
) -> dict[str, Any]:
- """Convenience wrapper: read participants.tsv and attach extras for this subject.
+ """Read participants.tsv and attach extra info for the subject.
+
+ This is a convenience function that finds the subject from the
+ ``bidspath``, retrieves extra information from ``participants.tsv``,
+ and attaches it to the raw object and its description.
+
+ Parameters
+ ----------
+ bids_root : str or Path
+ The root directory of the BIDS dataset.
+ bidspath : mne_bids.BIDSPath
+ The BIDSPath object for the current data file.
+ raw : mne.io.Raw
+ The MNE Raw object to be updated.
+ description : dict or pandas.Series
+ The description object to be updated.
+
+ Returns
+ -------
+ dict
+ The dictionary of extras that were attached.
- Returns the extras dictionary for further use if needed.
"""
subject = getattr(bidspath, "subject", None)
if not subject:
diff --git a/eegdash/const.py b/eegdash/const.py
index 897575fe..447c89d9 100644
--- a/eegdash/const.py
+++ b/eegdash/const.py
@@ -28,6 +28,8 @@
"nchans",
"ntimes",
}
+"""set: A set of field names that are permitted in database queries constructed
+via :meth:`~eegdash.api.EEGDash.find` with keyword arguments."""
RELEASE_TO_OPENNEURO_DATASET_MAP = {
"R11": "ds005516",
@@ -42,6 +44,8 @@
"R2": "ds005506",
"R1": "ds005505",
}
+"""dict: A mapping from Healthy Brain Network (HBN) release identifiers (e.g., "R11")
+to their corresponding OpenNeuro dataset identifiers (e.g., "ds005516")."""
SUBJECT_MINI_RELEASE_MAP = {
"R11": [
@@ -287,6 +291,9 @@
"NDARFW972KFQ",
],
}
+"""dict: A mapping from HBN release identifiers to a list of subject IDs.
+This is used to select a small, representative subset of subjects for creating
+"mini" datasets for testing and demonstration purposes."""
config = {
"required_fields": ["data_name"],
@@ -322,3 +329,21 @@
],
"accepted_query_fields": ["data_name", "dataset"],
}
+"""dict: A global configuration dictionary for the EEGDash package.
+
+Keys
+----
+required_fields : list
+ Fields that must be present in every database record.
+attributes : dict
+ A schema defining the expected primary attributes and their types for a
+ database record.
+description_fields : list
+ A list of fields considered to be descriptive metadata for a recording,
+ which can be used for filtering and display.
+bids_dependencies_files : list
+ A list of BIDS metadata filenames that are relevant for interpreting an
+ EEG recording.
+accepted_query_fields : list
+ Fields that are accepted for lightweight existence checks in the database.
+"""
diff --git a/eegdash/data_utils.py b/eegdash/data_utils.py
index 2cc18ca8..199d94d2 100644
--- a/eegdash/data_utils.py
+++ b/eegdash/data_utils.py
@@ -37,10 +37,26 @@
class EEGDashBaseDataset(BaseDataset):
- """A single EEG recording hosted on AWS S3 and cached locally upon first access.
+ """A single EEG recording dataset.
+
+ Represents a single EEG recording, typically hosted on a remote server (like AWS S3)
+ and cached locally upon first access. This class is a subclass of
+ :class:`braindecode.datasets.BaseDataset` and can be used with braindecode's
+ preprocessing and training pipelines.
+
+ Parameters
+ ----------
+ record : dict
+ A fully resolved metadata record for the data to load.
+ cache_dir : str
+ The local directory where the data will be cached.
+ s3_bucket : str, optional
+ The S3 bucket to download data from. If not provided, defaults to the
+ OpenNeuro bucket.
+ **kwargs
+ Additional keyword arguments passed to the
+ :class:`braindecode.datasets.BaseDataset` constructor.
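+
+    Notes
+    -----
+    Users do not usually instantiate this class directly; use
+    :class:`EEGDashDataset` to load a collection of recordings from a local
+    BIDS folder or through a database query.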
- This is a subclass of braindecode's BaseDataset, which can consequently be used in
- conjunction with the preprocessing and training pipelines of braindecode.
"""
_AWS_BUCKET = "s3://openneuro.org"
@@ -52,20 +68,6 @@ def __init__(
s3_bucket: str | None = None,
**kwargs,
):
- """Create a new EEGDashBaseDataset instance. Users do not usually need to call this
- directly -- instead use the EEGDashDataset class to load a collection of these
- recordings from a local BIDS folder or using a database query.
-
- Parameters
- ----------
- record : dict
- A fully resolved metadata record for the data to load.
- cache_dir : str
- A local directory where the data will be cached.
- kwargs : dict
- Additional keyword arguments to pass to the BaseDataset constructor.
-
- """
super().__init__(None, **kwargs)
self.record = record
self.cache_dir = Path(cache_dir)
@@ -121,14 +123,12 @@ def __init__(
self._raw = None
def _get_raw_bids_args(self) -> dict[str, Any]:
- """Helper to restrict the metadata record to the fields needed to locate a BIDS
- recording.
- """
+ """Extract BIDS-related arguments from the metadata record."""
desired_fields = ["subject", "session", "task", "run"]
return {k: self.record[k] for k in desired_fields if self.record[k]}
def _ensure_raw(self) -> None:
- """Download the S3 file and BIDS dependencies if not already cached."""
+ """Ensure the raw data file and its dependencies are cached locally."""
# TO-DO: remove this once is fixed on the our side
# for the competition
if not self.s3_open_neuro:
@@ -190,42 +190,53 @@ def __len__(self) -> int:
return len(self._raw)
@property
- def raw(self):
- """Return the MNE Raw object for this recording. This will perform the actual
- retrieval if not yet done so.
+ def raw(self) -> BaseRaw:
+ """The MNE Raw object for this recording.
+
+ Accessing this property triggers the download and caching of the data
+ if it has not been accessed before.
+
+ Returns
+ -------
+ mne.io.BaseRaw
+ The loaded MNE Raw object.
+
"""
if self._raw is None:
self._ensure_raw()
return self._raw
@raw.setter
- def raw(self, raw):
+ def raw(self, raw: BaseRaw):
self._raw = raw
class EEGDashBaseRaw(BaseRaw):
- """Wrapper around the MNE BaseRaw class that automatically fetches the data from S3
- (when _read_segment is called) and caches it locally. Currently for internal use.
+ """MNE BaseRaw wrapper for automatic S3 data fetching.
+
+ This class extends :class:`mne.io.BaseRaw` to automatically fetch data
+ from an S3 bucket and cache it locally when data is first accessed.
+ It is intended for internal use within the EEGDash ecosystem.
Parameters
----------
- input_fname : path-like
- Path to the S3 file
+ input_fname : str
+        The path to the file in the S3 bucket (relative to the bucket root).
metadata : dict
- The metadata record for the recording (e.g., from the database).
- preload : bool
- Whether to pre-loaded the data before the first access.
- cache_dir : str
- Local path under which the data will be cached.
- bids_dependencies : list
- List of additional BIDS metadata files that should be downloaded and cached
- alongside the main recording file.
- verbose : str | int | None
- Optionally the verbosity level for MNE logging (see MNE documentation for possible values).
+ The metadata record for the recording, containing information like
+ sampling frequency, channel names, etc.
+ preload : bool, default False
+ If True, preload the data into memory.
+ cache_dir : str, optional
+ Local directory for caching data. If None, a default directory is used.
+ bids_dependencies : list of str, default []
+ A list of BIDS metadata files to download alongside the main recording.
+ verbose : str, int, or None, default None
+ The MNE verbosity level.
See Also
--------
- mne.io.Raw : Documentation of attributes and methods.
+ mne.io.Raw : The base class for Raw objects in MNE.
"""
@@ -241,7 +252,6 @@ def __init__(
bids_dependencies: list[str] = [],
verbose: Any = None,
):
- """Get to work with S3 endpoint first, no caching"""
# Create a simple RawArray
sfreq = metadata["sfreq"] # Sampling frequency
n_times = metadata["n_times"]
@@ -277,6 +287,7 @@ def __init__(
def _read_segment(
self, start=0, stop=None, sel=None, data_buffer=None, *, verbose=None
):
+ """Read a segment of data, downloading if necessary."""
if not os.path.exists(self.filecache): # not preload
if self.bids_dependencies: # this is use only to sidecars for now
downloader.download_dependencies(
@@ -297,22 +308,23 @@ def _read_segment(
return super()._read_segment(start, stop, sel, data_buffer, verbose=verbose)
def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
- """Read a chunk of data from the file."""
+ """Read a chunk of data from a local file."""
_read_segments_file(self, data, idx, fi, start, stop, cals, mult, dtype=" bool:
- """Check if the dataset is EEG."""
+ """Check if the BIDS dataset contains EEG data.
+
+ Returns
+ -------
+ bool
+ True if the dataset's modality is EEG, False otherwise.
+
+ """
return self.get_bids_file_attribute("modality", self.files[0]).lower() == "eeg"
def _get_recordings(self, layout: BIDSLayout) -> list[str]:
@@ -370,14 +389,12 @@ def _get_recordings(self, layout: BIDSLayout) -> list[str]:
return files
def _get_relative_bidspath(self, filename: str) -> str:
- """Make the given file path relative to the BIDS directory."""
+ """Make a file path relative to the BIDS parent directory."""
bids_parent_dir = self.bidsdir.parent.absolute()
return str(Path(filename).relative_to(bids_parent_dir))
def _get_property_from_filename(self, property: str, filename: str) -> str:
- """Parse a property out of a BIDS-compliant filename. Returns an empty string
- if not found.
- """
+ """Parse a BIDS entity from a filename."""
import platform
if platform.system() == "Windows":
@@ -387,159 +404,106 @@ def _get_property_from_filename(self, property: str, filename: str) -> str:
return lookup.group(1) if lookup else ""
def _merge_json_inheritance(self, json_files: list[str | Path]) -> dict:
- """Internal helper to merge list of json files found by get_bids_file_inheritance,
- expecting the order (from left to right) is from lowest
- level to highest level, and return a merged dictionary
- """
+ """Merge a list of JSON files according to BIDS inheritance."""
json_files.reverse()
json_dict = {}
for f in json_files:
- json_dict.update(json.load(open(f))) # FIXME: should close file
+ with open(f) as fp:
+ json_dict.update(json.load(fp))
return json_dict
def _get_bids_file_inheritance(
self, path: str | Path, basename: str, extension: str
) -> list[Path]:
- """Get all file paths that apply to the basename file in the specified directory
- and that end with the specified suffix, recursively searching parent directories
- (following the BIDS inheritance principle in the order of lowest level first).
-
- Parameters
- ----------
- path : str | Path
- The directory path to search for files.
- basename : str
- BIDS file basename without _eeg.set extension for example
- extension : str
- Only consider files that end with the specified suffix; e.g. channels.tsv
-
- Returns
- -------
- list[Path]
- A list of file paths that match the given basename and extension.
-
- """
+ """Find all applicable metadata files using BIDS inheritance."""
top_level_files = ["README", "dataset_description.json", "participants.tsv"]
bids_files = []
- # check if path is str object
if isinstance(path, str):
path = Path(path)
- if not path.exists:
- raise ValueError("path {path} does not exist")
+ if not path.exists():
+ raise ValueError(f"path {path} does not exist")
- # check if file is in current path
for file in os.listdir(path):
- # target_file = path / f"{cur_file_basename}_{extension}"
- if os.path.isfile(path / file):
- # check if file has extension extension
- # check if file basename has extension
- if file.endswith(extension):
- filepath = path / file
- bids_files.append(filepath)
-
- # check if file is in top level directory
+ if os.path.isfile(path / file) and file.endswith(extension):
+ bids_files.append(path / file)
+
if any(file in os.listdir(path) for file in top_level_files):
return bids_files
else:
- # call get_bids_file_inheritance recursively with parent directory
bids_files.extend(
self._get_bids_file_inheritance(path.parent, basename, extension)
)
return bids_files
def get_bids_metadata_files(
- self, filepath: str | Path, metadata_file_extension: list[str]
+ self, filepath: str | Path, metadata_file_extension: str
) -> list[Path]:
- """Retrieve all metadata file paths that apply to a given data file path and that
- end with a specific suffix (following the BIDS inheritance principle).
+ """Retrieve all metadata files that apply to a given data file.
+
+ Follows the BIDS inheritance principle to find all relevant metadata
+ files (e.g., ``channels.tsv``, ``eeg.json``) for a specific recording.
Parameters
----------
- filepath: str | Path
- The filepath to get the associated metadata files for.
+ filepath : str or Path
+ The path to the data file.
metadata_file_extension : str
- Consider only metadata files that end with the specified suffix,
- e.g., channels.tsv or eeg.json
+ The extension of the metadata file to search for (e.g., "channels.tsv").
Returns
-------
- list[Path]:
- A list of filepaths for all matching metadata files
+ list of Path
+ A list of paths to the matching metadata files.
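+
+        Examples
+        --------
+        For example (the instance and file path are illustrative)::
+
+            files = bids_dataset.get_bids_metadata_files(
+                "sub-01/eeg/sub-01_task-rest_eeg.set", "channels.tsv"
+            )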
"""
if isinstance(filepath, str):
filepath = Path(filepath)
- if not filepath.exists:
- raise ValueError("filepath {filepath} does not exist")
+ if not filepath.exists():
+ raise ValueError(f"filepath {filepath} does not exist")
path, filename = os.path.split(filepath)
basename = filename[: filename.rfind("_")]
- # metadata files
meta_files = self._get_bids_file_inheritance(
path, basename, metadata_file_extension
)
return meta_files
def _scan_directory(self, directory: str, extension: str) -> list[Path]:
- """Return a list of file paths that end with the given extension in the specified
- directory. Ignores certain special directories like .git, .datalad, derivatives,
- and code.
- """
+ """Scan a directory for files with a given extension."""
result_files = []
directory_to_ignore = [".git", ".datalad", "derivatives", "code"]
with os.scandir(directory) as entries:
for entry in entries:
if entry.is_file() and entry.name.endswith(extension):
- result_files.append(entry.path)
- elif entry.is_dir():
- # check that entry path doesn't contain any name in ignore list
- if not any(name in entry.name for name in directory_to_ignore):
- result_files.append(entry.path) # Add directory to scan later
+ result_files.append(Path(entry.path))
+ elif entry.is_dir() and not any(
+ name in entry.name for name in directory_to_ignore
+ ):
+ result_files.append(Path(entry.path))
return result_files
def _get_files_with_extension_parallel(
self, directory: str, extension: str = ".set", max_workers: int = -1
) -> list[Path]:
- """Efficiently scan a directory and its subdirectories for files that end with
- the given extension.
-
- Parameters
- ----------
- directory : str
- The root directory to scan for files.
- extension : str
- Only consider files that end with this suffix, e.g. '.set'.
- max_workers : int
- Optionally specify the maximum number of worker threads to use for parallel scanning.
- Defaults to all available CPU cores if set to -1.
-
- Returns
- -------
- list[Path]:
- A list of filepaths for all matching metadata files
-
- """
+ """Scan a directory tree in parallel for files with a given extension."""
result_files = []
dirs_to_scan = [directory]
- # Use joblib.Parallel and delayed to parallelize directory scanning
while dirs_to_scan:
logger.info(
f"Directories to scan: {len(dirs_to_scan)}, files: {dirs_to_scan}"
)
- # Run the scan_directory function in parallel across directories
results = Parallel(n_jobs=max_workers, prefer="threads", verbose=1)(
delayed(self._scan_directory)(d, extension) for d in dirs_to_scan
)
- # Reset the directories to scan and process the results
dirs_to_scan = []
for res in results:
for path in res:
if os.path.isdir(path):
- dirs_to_scan.append(path) # Queue up subdirectories to scan
+ dirs_to_scan.append(path)
else:
- result_files.append(path) # Add files to the final result
+ result_files.append(path)
logger.info(f"Found {len(result_files)} files.")
return result_files
@@ -547,19 +511,29 @@ def _get_files_with_extension_parallel(
def load_and_preprocess_raw(
self, raw_file: str, preprocess: bool = False
) -> np.ndarray:
- """Utility function to load a raw data file with MNE and apply some simple
- (hardcoded) preprocessing and return as a numpy array. Not meant for purposes
- other than testing or debugging.
+ """Load and optionally preprocess a raw data file.
+
+ This is a utility function for testing or debugging, not for general use.
+
+ Parameters
+ ----------
+ raw_file : str
+ Path to the raw EEGLAB file (.set).
+ preprocess : bool, default False
+ If True, apply a high-pass filter, notch filter, and resample the data.
+
+ Returns
+ -------
+ numpy.ndarray
+ The loaded and processed data as a NumPy array.
+
"""
logger.info(f"Loading raw data from {raw_file}")
EEG = mne.io.read_raw_eeglab(raw_file, preload=True, verbose="error")
if preprocess:
- # highpass filter
EEG = EEG.filter(l_freq=0.25, h_freq=25, verbose=False)
- # remove 60Hz line noise
EEG = EEG.notch_filter(freqs=(60), verbose=False)
- # bring to common sampling rate
sfreq = 128
if EEG.info["sfreq"] != sfreq:
EEG = EEG.resample(sfreq)
@@ -570,26 +544,35 @@ def load_and_preprocess_raw(
raise ValueError("Expect raw data to be CxT dimension")
return mat_data
- def get_files(self) -> list[Path]:
- """Get all EEG recording file paths (with valid extensions) in the BIDS folder."""
+ def get_files(self) -> list[str]:
+ """Get all EEG recording file paths in the BIDS dataset.
+
+ Returns
+ -------
+ list of str
+ A list of file paths for all valid EEG recordings.
+
+ """
return self.files
def resolve_bids_json(self, json_files: list[str]) -> dict:
- """Resolve the BIDS JSON files and return a dictionary of the resolved values.
+ """Resolve BIDS JSON inheritance and merge files.
Parameters
----------
- json_files : list
- A list of JSON file paths to resolve in order of leaf level first.
+ json_files : list of str
+ A list of JSON file paths, ordered from the lowest (most specific)
+            to the highest level of the BIDS hierarchy.
Returns
-------
- dict: A dictionary of the resolved values.
+ dict
+ A dictionary containing the merged JSON data.
"""
- if len(json_files) == 0:
+ if not json_files:
raise ValueError("No JSON files provided")
- json_files.reverse() # TODO undeterministic
+ json_files.reverse()
json_dict = {}
for json_file in json_files:
@@ -598,8 +581,20 @@ def resolve_bids_json(self, json_files: list[str]) -> dict:
return json_dict
def get_bids_file_attribute(self, attribute: str, data_filepath: str) -> Any:
- """Retrieve a specific attribute from the BIDS file metadata applicable
- to the provided recording file path.
+ """Retrieve a specific attribute from BIDS metadata.
+
+ Parameters
+ ----------
+ attribute : str
+ The name of the attribute to retrieve (e.g., "sfreq", "subject").
+ data_filepath : str
+ The path to the data file.
+
+ Returns
+ -------
+ Any
+ The value of the requested attribute, or None if not found.
+
"""
entities = self.layout.parse_file_entities(data_filepath)
bidsfile = self.layout.get(**entities)[0]
@@ -618,21 +613,59 @@ def get_bids_file_attribute(self, attribute: str, data_filepath: str) -> Any:
return attribute_value
def channel_labels(self, data_filepath: str) -> list[str]:
- """Get a list of channel labels for the given data file path."""
+ """Get a list of channel labels from channels.tsv.
+
+ Parameters
+ ----------
+ data_filepath : str
+ The path to the data file.
+
+ Returns
+ -------
+ list of str
+ A list of channel names.
+
+ """
channels_tsv = pd.read_csv(
self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
)
return channels_tsv["name"].tolist()
def channel_types(self, data_filepath: str) -> list[str]:
- """Get a list of channel types for the given data file path."""
+ """Get a list of channel types from channels.tsv.
+
+ Parameters
+ ----------
+ data_filepath : str
+ The path to the data file.
+
+ Returns
+ -------
+ list of str
+ A list of channel types.
+
+ """
channels_tsv = pd.read_csv(
self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
)
return channels_tsv["type"].tolist()
def num_times(self, data_filepath: str) -> int:
- """Get the approximate number of time points in the EEG recording based on the BIDS metadata."""
+ """Get the number of time points in the recording.
+
+ Calculated from ``SamplingFrequency`` and ``RecordingDuration`` in eeg.json.
+
+ Parameters
+ ----------
+ data_filepath : str
+ The path to the data file.
+
+ Returns
+ -------
+ int
+ The approximate number of time points.
+
+ """
eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
eeg_json_dict = self._merge_json_inheritance(eeg_jsons)
return int(
@@ -640,38 +673,71 @@ def num_times(self, data_filepath: str) -> int:
)
def subject_participant_tsv(self, data_filepath: str) -> dict[str, Any]:
- """Get BIDS participants.tsv record for the subject to which the given file
- path corresponds, as a dictionary.
+ """Get the participants.tsv record for a subject.
+
+ Parameters
+ ----------
+ data_filepath : str
+ The path to a data file belonging to the subject.
+
+ Returns
+ -------
+ dict
+ A dictionary of the subject's information from participants.tsv.
+
"""
- participants_tsv = pd.read_csv(
- self.get_bids_metadata_files(data_filepath, "participants.tsv")[0], sep="\t"
- )
- # if participants_tsv is not empty
+ participants_tsv_path = self.get_bids_metadata_files(
+ data_filepath, "participants.tsv"
+ )[0]
+ participants_tsv = pd.read_csv(participants_tsv_path, sep="\t")
if participants_tsv.empty:
return {}
- # set 'participant_id' as index
participants_tsv.set_index("participant_id", inplace=True)
subject = f"sub-{self.get_bids_file_attribute('subject', data_filepath)}"
return participants_tsv.loc[subject].to_dict()
def eeg_json(self, data_filepath: str) -> dict[str, Any]:
- """Get BIDS eeg.json metadata for the given data file path."""
+ """Get the merged eeg.json metadata for a data file.
+
+ Parameters
+ ----------
+ data_filepath : str
+ The path to the data file.
+
+ Returns
+ -------
+ dict
+ The merged eeg.json metadata.
+
+ """
eeg_jsons = self.get_bids_metadata_files(data_filepath, "eeg.json")
- eeg_json_dict = self._merge_json_inheritance(eeg_jsons)
- return eeg_json_dict
+ return self._merge_json_inheritance(eeg_jsons)
def channel_tsv(self, data_filepath: str) -> dict[str, Any]:
- """Get BIDS channels.tsv metadata for the given data file path, as a dictionary
- of lists and/or single values.
+ """Get the channels.tsv metadata as a dictionary.
+
+ Parameters
+ ----------
+ data_filepath : str
+ The path to the data file.
+
+ Returns
+ -------
+ dict
+ The channels.tsv data, with columns as keys.
+
"""
- channels_tsv = pd.read_csv(
- self.get_bids_metadata_files(data_filepath, "channels.tsv")[0], sep="\t"
- )
- channel_tsv = channels_tsv.to_dict()
- # 'name' and 'type' now have a dictionary of index-value. Convert them to list
+ channels_tsv_path = self.get_bids_metadata_files(data_filepath, "channels.tsv")[
+ 0
+ ]
+ channels_tsv = pd.read_csv(channels_tsv_path, sep="\t")
+ channel_tsv_dict = channels_tsv.to_dict()
for list_field in ["name", "type", "units"]:
- channel_tsv[list_field] = list(channel_tsv[list_field].values())
- return channel_tsv
+ if list_field in channel_tsv_dict:
+ channel_tsv_dict[list_field] = list(
+ channel_tsv_dict[list_field].values()
+ )
+ return channel_tsv_dict
__all__ = ["EEGDashBaseDataset", "EEGBIDSDataset", "EEGDashBaseRaw"]
diff --git a/eegdash/dataset/dataset.py b/eegdash/dataset/dataset.py
index 480f2762..e26e32fe 100644
--- a/eegdash/dataset/dataset.py
+++ b/eegdash/dataset/dataset.py
@@ -12,26 +12,48 @@
class EEGChallengeDataset(EEGDashDataset):
- """EEG 2025 Challenge dataset helper.
+ """A dataset helper for the EEG 2025 Challenge.
- This class provides a convenient wrapper around :class:`EEGDashDataset`
- configured for the EEG 2025 Challenge releases. It maps a given
- ``release`` to its corresponding OpenNeuro dataset and optionally restricts
- to the official "mini" subject subset.
+ This class simplifies access to the EEG 2025 Challenge datasets. It is a
+ specialized version of :class:`~eegdash.api.EEGDashDataset` that is
+ pre-configured for the challenge's data releases. It automatically maps a
+ release name (e.g., "R1") to the corresponding OpenNeuro dataset and handles
+ the selection of subject subsets (e.g., "mini" release).
Parameters
----------
release : str
- Release name. One of ["R1", ..., "R11"].
+ The name of the challenge release to load. Must be one of the keys in
+ :const:`~eegdash.const.RELEASE_TO_OPENNEURO_DATASET_MAP`
+ (e.g., "R1", "R2", ..., "R11").
+ cache_dir : str
+ The local directory where the dataset will be downloaded and cached.
mini : bool, default True
- If True, restrict subjects to the challenge mini subset.
- query : dict | None
- Additional MongoDB-style filters to AND with the release selection.
- Must not contain the key ``dataset``.
- s3_bucket : str | None, default "s3://nmdatasets/NeurIPS25"
- Base S3 bucket used to locate the challenge data.
+ If True, the dataset is restricted to the official "mini" subset of
+ subjects for the specified release. If False, all subjects for the
+ release are included.
+ query : dict, optional
+ An additional MongoDB-style query to apply as a filter. This query is
+ combined with the release and subject filters using a logical AND.
+ The query must not contain the ``dataset`` key, as this is determined
+ by the ``release`` parameter.
+ s3_bucket : str, optional
+ The base S3 bucket URI where the challenge data is stored. Defaults to
+ the official challenge bucket.
**kwargs
- Passed through to :class:`EEGDashDataset`.
+ Additional keyword arguments that are passed directly to the
+ :class:`~eegdash.api.EEGDashDataset` constructor.
+
+ Raises
+ ------
+ ValueError
+ If the specified ``release`` is unknown, or if the ``query`` argument
+ contains a ``dataset`` key. Also raised if ``mini`` is True and a
+ requested subject is not part of the official mini-release subset.
+
+ See Also
+ --------
+ EEGDashDataset : The base class for creating datasets from queries.
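+
+    Examples
+    --------
+    A typical instantiation (the cache directory is illustrative)::
+
+        dataset = EEGChallengeDataset(
+            release="R5", cache_dir="~/eegdash_cache", mini=True
+        )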
"""
diff --git a/eegdash/dataset/registry.py b/eegdash/dataset/registry.py
index 5eeacb1e..0d40db9c 100644
--- a/eegdash/dataset/registry.py
+++ b/eegdash/dataset/registry.py
@@ -14,7 +14,35 @@ def register_openneuro_datasets(
namespace: Dict[str, Any] | None = None,
add_to_all: bool = True,
) -> Dict[str, type]:
- """Dynamically create dataset classes from a summary file."""
+ """Dynamically create and register dataset classes from a summary file.
+
+ This function reads a CSV file containing summaries of OpenNeuro datasets
+ and dynamically creates a Python class for each dataset. These classes
+ inherit from a specified base class and are pre-configured with the
+ dataset's ID.
+
+ Parameters
+ ----------
+ summary_file : str or pathlib.Path
+ The path to the CSV file containing the dataset summaries.
+ base_class : type, optional
+ The base class from which the new dataset classes will inherit. If not
+ provided, :class:`eegdash.api.EEGDashDataset` is used.
+ namespace : dict, optional
+        The namespace (e.g., ``globals()``) into which the newly created classes
+        will be injected. Defaults to this module's ``globals()``.
+    add_to_all : bool, default True
+        If True, the names of the newly created classes are added to the
+        ``__all__`` list of the target namespace, making them importable with
+        ``from ... import *``.
+
+ Returns
+ -------
+ dict[str, type]
+ A dictionary mapping the names of the registered classes to the class
+ types themselves.
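+
+    Examples
+    --------
+    A minimal sketch (the summary file name is illustrative)::
+
+        registered = register_openneuro_datasets(
+            "dataset_summaries.csv", namespace=globals()
+        )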
+
+ """
if base_class is None:
from ..api import EEGDashDataset as base_class # lazy import
@@ -84,8 +112,28 @@ def __init__(
return registered
-def _generate_rich_docstring(dataset_id: str, row_series: pd.Series, base_class) -> str:
- """Generate a comprehensive docstring for a dataset class."""
+def _generate_rich_docstring(
+ dataset_id: str, row_series: pd.Series, base_class: type
+) -> str:
+ """Generate a comprehensive, well-formatted docstring for a dataset class.
+
+ Parameters
+ ----------
+ dataset_id : str
+ The identifier of the dataset (e.g., "ds002718").
+ row_series : pandas.Series
+ A pandas Series containing the metadata for the dataset, extracted
+ from the summary CSV file.
+ base_class : type
+ The base class from which the new dataset class inherits. Used to
+ generate the "See Also" section of the docstring.
+
+ Returns
+ -------
+ str
+ A formatted docstring.
+
+ """
# Extract metadata with safe defaults
n_subjects = row_series.get("n_subjects", "Unknown")
n_records = row_series.get("n_records", "Unknown")
@@ -173,7 +221,24 @@ def _generate_rich_docstring(dataset_id: str, row_series: pd.Series, base_class)
def _markdown_table(row_series: pd.Series) -> str:
- """Create a reStructuredText grid table from a pandas Series."""
+ """Create a reStructuredText grid table from a pandas Series.
+
+ This helper function takes a pandas Series containing dataset metadata
+ and formats it into a reStructuredText grid table for inclusion in
+ docstrings.
+
+ Parameters
+ ----------
+ row_series : pandas.Series
+ A Series where each index is a metadata field and each value is the
+ corresponding metadata value.
+
+ Returns
+ -------
+ str
+ A string containing the formatted reStructuredText table.
+
+ """
if row_series.empty:
return ""
dataset_id = row_series["dataset"]
diff --git a/eegdash/downloader.py b/eegdash/downloader.py
index efbbc12d..70eb31d3 100644
--- a/eegdash/downloader.py
+++ b/eegdash/downloader.py
@@ -17,18 +17,62 @@
from fsspec.callbacks import TqdmCallback
-def get_s3_filesystem():
- """Returns an S3FileSystem object."""
+def get_s3_filesystem() -> s3fs.S3FileSystem:
+ """Get an anonymous S3 filesystem object.
+
+ Initializes and returns an ``s3fs.S3FileSystem`` for anonymous access
+ to public S3 buckets, configured for the 'us-east-2' region.
+
+ Returns
+ -------
+ s3fs.S3FileSystem
+ An S3 filesystem object.
+
+ """
return s3fs.S3FileSystem(anon=True, client_kwargs={"region_name": "us-east-2"})
def get_s3path(s3_bucket: str, filepath: str) -> str:
- """Helper to form an AWS S3 URI for the given relative filepath."""
+ """Construct an S3 URI from a bucket and file path.
+
+ Parameters
+ ----------
+ s3_bucket : str
+ The S3 bucket name (e.g., "s3://my-bucket").
+ filepath : str
+ The path to the file within the bucket.
+
+ Returns
+ -------
+ str
+ The full S3 URI (e.g., "s3://my-bucket/path/to/file").
+
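+    Examples
+    --------
+    A small illustration; the bucket and path are hypothetical:
+
+    >>> get_s3path("s3://openneuro.org", "ds002718/dataset_description.json")
+    's3://openneuro.org/ds002718/dataset_description.json'
+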
+ """
return f"{s3_bucket}/{filepath}"
-def download_s3_file(s3_path: str, local_path: Path, s3_open_neuro: bool):
- """Download function that gets the raw EEG data from S3."""
+def download_s3_file(s3_path: str, local_path: Path, s3_open_neuro: bool) -> Path:
+ """Download a single file from S3 to a local path.
+
+ Handles the download of a raw EEG data file from an S3 bucket, caching it
+ at the specified local path. Creates parent directories if they do not exist.
+
+ Parameters
+ ----------
+ s3_path : str
+ The full S3 URI of the file to download.
+ local_path : pathlib.Path
+ The local file path where the downloaded file will be saved.
+ s3_open_neuro : bool
+ A flag indicating if the S3 bucket is the OpenNeuro main bucket, which
+ may affect path handling.
+
+ Returns
+ -------
+ pathlib.Path
+ The local path to the downloaded file.
+
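+    Examples
+    --------
+    A sketch of downloading one file; the URI and cache path are
+    hypothetical:
+
+    >>> from pathlib import Path
+    >>> download_s3_file(
+    ...     "s3://openneuro.org/ds002718/dataset_description.json",
+    ...     Path("cache/ds002718/dataset_description.json"),
+    ...     s3_open_neuro=True,
+    ... )  # doctest: +SKIP
+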
+ """
filesystem = get_s3_filesystem()
if not s3_open_neuro:
s3_path = re.sub(r"(^|/)ds\d{6}/", r"\1", s3_path, count=1)
@@ -51,8 +95,31 @@ def download_dependencies(
dataset_folder: Path,
record: dict[str, Any],
s3_open_neuro: bool,
-):
- """Download all BIDS dependency files from S3 and cache them locally."""
+) -> None:
+ """Download all BIDS dependency files from S3.
+
+ Iterates through a list of BIDS dependency files, downloads each from the
+ specified S3 bucket, and caches them in the appropriate local directory
+ structure.
+
+ Parameters
+ ----------
+ s3_bucket : str
+ The S3 bucket to download from.
+ bids_dependencies : list of str
+ A list of dependency file paths relative to the S3 bucket root.
+ bids_dependencies_original : list of str
+ The original dependency paths, used for resolving local cache paths.
+ cache_dir : pathlib.Path
+ The root directory for caching.
+ dataset_folder : pathlib.Path
+ The specific folder for the dataset within the cache directory.
+ record : dict
+ The metadata record for the main data file, used to resolve paths.
+ s3_open_neuro : bool
+ Flag for OpenNeuro-specific path handling.
+
+ """
filesystem = get_s3_filesystem()
for i, dep in enumerate(bids_dependencies):
if not s3_open_neuro:
@@ -78,8 +145,27 @@ def download_dependencies(
_filesystem_get(filesystem=filesystem, s3path=s3path, filepath=filepath)
-def _filesystem_get(filesystem: s3fs.S3FileSystem, s3path: str, filepath: Path):
- """Helper to download a file from S3 with a progress bar."""
+def _filesystem_get(filesystem: s3fs.S3FileSystem, s3path: str, filepath: Path) -> Path:
+ """Perform the file download using fsspec with a progress bar.
+
+ Internal helper function that wraps the ``filesystem.get`` call to include
+ a TQDM progress bar.
+
+ Parameters
+ ----------
+ filesystem : s3fs.S3FileSystem
+ The filesystem object to use for the download.
+ s3path : str
+ The full S3 URI of the source file.
+ filepath : pathlib.Path
+ The local destination path.
+
+ Returns
+ -------
+ pathlib.Path
+ The local path to the downloaded file.
+
+ """
info = filesystem.info(s3path)
size = info.get("size") or info.get("Size")
diff --git a/eegdash/features/datasets.py b/eegdash/features/datasets.py
index f4b021cc..9e933c59 100644
--- a/eegdash/features/datasets.py
+++ b/eegdash/features/datasets.py
@@ -20,20 +20,34 @@
class FeaturesDataset(EEGWindowsDataset):
- """Returns samples from a pandas DataFrame object along with a target.
+ """A dataset of features extracted from EEG windows.
- Dataset which serves samples from a pandas DataFrame object along with a
- target. The target is unique for the dataset, and is obtained through the
- `description` attribute.
+ This class holds features in a pandas DataFrame and provides an interface
+ compatible with braindecode's dataset structure. Each row in the feature
+ DataFrame corresponds to a single sample (e.g., an EEG window).
Parameters
----------
- features : a pandas DataFrame
- Tabular data.
- description : dict | pandas.Series | None
- Holds additional description about the continuous signal / subject.
- transform : callable | None
- On-the-fly transform applied to the example before it is returned.
+ features : pandas.DataFrame
+ A DataFrame where each row is a sample and each column is a feature.
+ metadata : pandas.DataFrame, optional
+ A DataFrame containing metadata for each sample, indexed consistently
+ with `features`. Must include columns 'i_window_in_trial',
+ 'i_start_in_trial', 'i_stop_in_trial', and 'target'.
+ description : dict or pandas.Series, optional
+ Additional high-level information about the dataset (e.g., subject ID).
+ transform : callable, optional
+ A function or transform to apply to the feature data on-the-fly.
+    raw_info : mne.Info, optional
+        Measurement information of the original raw recording, kept for provenance.
+ raw_preproc_kwargs : dict, optional
+ Keyword arguments used for preprocessing the raw data.
+ window_kwargs : dict, optional
+ Keyword arguments used for windowing the data.
+ window_preproc_kwargs : dict, optional
+ Keyword arguments used for preprocessing the windowed data.
+ features_kwargs : dict, optional
+ Keyword arguments used for feature extraction.
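+
+    Examples
+    --------
+    A minimal sketch with toy data; the feature column name is arbitrary:
+
+    >>> import pandas as pd
+    >>> feats = pd.DataFrame({"alpha_power": [0.5, 0.7]})
+    >>> meta = pd.DataFrame(
+    ...     {
+    ...         "i_window_in_trial": [0, 1],
+    ...         "i_start_in_trial": [0, 256],
+    ...         "i_stop_in_trial": [256, 512],
+    ...         "target": [0, 1],
+    ...     }
+    ... )
+    >>> ds = FeaturesDataset(feats, metadata=meta)  # doctest: +SKIP
+    >>> len(ds)  # doctest: +SKIP
+    2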
"""
@@ -65,7 +79,21 @@ def __init__(
].to_numpy()
self.y = metadata.loc[:, "target"].to_list()
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> tuple[np.ndarray, int, list]:
+ """Get a single sample from the dataset.
+
+ Parameters
+ ----------
+ index : int
+ The index of the sample to retrieve.
+
+ Returns
+ -------
+ tuple
+ A tuple containing the feature vector (X), the target (y), and the
+ cropping indices.
+
+ """
crop_inds = self.crop_inds[index].tolist()
X = self.features.iloc[index].to_numpy()
X = X.copy()
@@ -75,18 +103,27 @@ def __getitem__(self, index):
y = self.y[index]
return X, y, crop_inds
- def __len__(self):
+ def __len__(self) -> int:
+ """Return the number of samples in the dataset.
+
+ Returns
+ -------
+ int
+ The total number of feature samples.
+
+ """
return len(self.features.index)
def _compute_stats(
ds: FeaturesDataset,
- return_count=False,
- return_mean=False,
- return_var=False,
- ddof=1,
- numeric_only=False,
-):
+ return_count: bool = False,
+ return_mean: bool = False,
+ return_var: bool = False,
+ ddof: int = 1,
+ numeric_only: bool = False,
+) -> tuple:
+ """Compute statistics for a single FeaturesDataset."""
res = []
if return_count:
res.append(ds.features.count(numeric_only=numeric_only))
@@ -97,7 +134,14 @@ def _compute_stats(
return tuple(res)
-def _pooled_var(counts, means, variances, ddof, ddof_in=None):
+def _pooled_var(
+ counts: np.ndarray,
+ means: np.ndarray,
+ variances: np.ndarray,
+ ddof: int,
+ ddof_in: int | None = None,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Compute pooled variance across multiple datasets."""
if ddof_in is None:
ddof_in = ddof
count = counts.sum(axis=0)
@@ -110,17 +154,20 @@ def _pooled_var(counts, means, variances, ddof, ddof_in=None):
class FeaturesConcatDataset(BaseConcatDataset):
- """A base class for concatenated datasets.
+ """A concatenated dataset of `FeaturesDataset` objects.
- Holds either mne.Raw or mne.Epoch in self.datasets and has
- a pandas DataFrame with additional description.
+ This class holds a list of :class:`FeaturesDataset` instances and allows
+    them to be treated as a single, larger dataset. It provides methods for
+    splitting, saving, and performing DataFrame-like operations (e.g., `mean`,
+    `var`, `fillna`) across all contained datasets.
Parameters
----------
- list_of_ds : list
- list of BaseDataset, BaseConcatDataset or WindowsDataset
- target_transform : callable | None
- Optional function to call on targets before returning them.
+ list_of_ds : list of FeaturesDataset
+ A list of :class:`FeaturesDataset` objects to concatenate.
+ target_transform : callable, optional
+ A function to apply to the target values before they are returned.
"""
@@ -140,26 +187,28 @@ def split(
self,
by: str | list[int] | list[list[int]] | dict[str, list[int]],
) -> dict[str, FeaturesConcatDataset]:
- """Split the dataset based on information listed in its description.
+ """Split the dataset into subsets.
- The format could be based on a DataFrame or based on indices.
+ The splitting can be done based on a column in the description
+ DataFrame or by providing explicit indices for each split.
Parameters
----------
- by : str | list | dict
- If ``by`` is a string, splitting is performed based on the
- description DataFrame column with this name.
- If ``by`` is a (list of) list of integers, the position in the first
- list corresponds to the split id and the integers to the
- datapoints of that split.
- If a dict then each key will be used in the returned
- splits dict and each value should be a list of int.
+ by : str or list or dict
+ - If a string, splits are created for each unique value in the
+ description column `by`.
+ - If a list of integers, a single split is created containing the
+ datasets at the specified indices.
+ - If a list of lists of integers, multiple splits are created, one
+ for each sublist of indices.
+ - If a dictionary, keys are used as split names and values are
+ lists of dataset indices.
Returns
-------
- splits : dict
- A dictionary with the name of the split (a string) as key and the
- dataset as value.
+ dict[str, FeaturesConcatDataset]
+ A dictionary where keys are split names and values are the new
+ :class:`FeaturesConcatDataset` subsets.
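+
+        Examples
+        --------
+        A sketch, assuming ``concat_ds`` is a :class:`FeaturesConcatDataset`
+        whose description includes a ``"subject"`` column:
+
+        >>> splits = concat_ds.split("subject")  # doctest: +SKIP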
"""
if isinstance(by, str):
@@ -184,14 +233,21 @@ def split(
}
def get_metadata(self) -> pd.DataFrame:
- """Concatenate the metadata and description of the wrapped Epochs.
+ """Get the metadata of all datasets as a single DataFrame.
+
+ Concatenates the metadata from all contained datasets and adds columns
+ from their `description` attributes.
Returns
-------
- metadata : pd.DataFrame
- DataFrame containing as many rows as there are windows in the
- BaseConcatDataset, with the metadata and description information
- for each window.
+ pandas.DataFrame
+ A DataFrame containing the metadata for every sample in the
+ concatenated dataset.
+
+ Raises
+ ------
+ TypeError
+ If any of the contained datasets is not a :class:`FeaturesDataset`.
"""
if not all([isinstance(ds, FeaturesDataset) for ds in self.datasets]):
@@ -202,60 +258,59 @@ def get_metadata(self) -> pd.DataFrame:
all_dfs = list()
for ds in self.datasets:
- df = ds.metadata
+ df = ds.metadata.copy()
for k, v in ds.description.items():
df[k] = v
all_dfs.append(df)
return pd.concat(all_dfs)
- def save(self, path: str, overwrite: bool = False, offset: int = 0):
- """Save datasets to files by creating one subdirectory for each dataset:
- path/
- 0/
- 0-feat.parquet
- metadata_df.pkl
- description.json
- raw-info.fif (if raw info was saved)
- raw_preproc_kwargs.json (if raws were preprocessed)
- window_kwargs.json (if this is a windowed dataset)
- window_preproc_kwargs.json (if windows were preprocessed)
- features_kwargs.json
- 1/
- 1-feat.parquet
- metadata_df.pkl
- description.json
- raw-info.fif (if raw info was saved)
- raw_preproc_kwargs.json (if raws were preprocessed)
- window_kwargs.json (if this is a windowed dataset)
- window_preproc_kwargs.json (if windows were preprocessed)
- features_kwargs.json
+ def save(self, path: str, overwrite: bool = False, offset: int = 0) -> None:
+ """Save the concatenated dataset to a directory.
+
+ Creates a directory structure where each contained dataset is saved in
+ its own numbered subdirectory.
+
+ .. code-block::
+
+ path/
+ 0/
+ 0-feat.parquet
+ metadata_df.pkl
+ description.json
+ ...
+ 1/
+ 1-feat.parquet
+ ...
Parameters
----------
path : str
- Directory in which subdirectories are created to store
- -feat.parquet and .json files to.
- overwrite : bool
- Whether to delete old subdirectories that will be saved to in this
- call.
- offset : int
- If provided, the integer is added to the id of the dataset in the
- concat. This is useful in the setting of very large datasets, where
- one dataset has to be processed and saved at a time to account for
- its original position.
+ The directory where the dataset will be saved.
+ overwrite : bool, default False
+ If True, any existing subdirectories that conflict with the new
+ ones will be removed.
+ offset : int, default 0
+ An integer to add to the subdirectory names. Useful for saving
+ datasets in chunks.
+
+ Raises
+ ------
+ ValueError
+ If the dataset is empty.
+ FileExistsError
+ If a subdirectory already exists and `overwrite` is False.
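+
+        Examples
+        --------
+        A sketch; ``concat_ds`` and the target directory are hypothetical:
+
+        >>> concat_ds.save("features_dir", overwrite=True)  # doctest: +SKIP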
"""
if len(self.datasets) == 0:
raise ValueError("Expect at least one dataset")
path_contents = os.listdir(path)
- n_sub_dirs = len([os.path.isdir(e) for e in path_contents])
+        n_sub_dirs = sum(os.path.isdir(os.path.join(path, e)) for e in path_contents)
for i_ds, ds in enumerate(self.datasets):
- # remove subdirectory from list of untouched files / subdirectories
- if str(i_ds + offset) in path_contents:
- path_contents.remove(str(i_ds + offset))
- # save_dir/i_ds/
- sub_dir = os.path.join(path, str(i_ds + offset))
+ sub_dir_name = str(i_ds + offset)
+ if sub_dir_name in path_contents:
+ path_contents.remove(sub_dir_name)
+ sub_dir = os.path.join(path, sub_dir_name)
if os.path.exists(sub_dir):
if overwrite:
shutil.rmtree(sub_dir)
@@ -265,35 +320,21 @@ def save(self, path: str, overwrite: bool = False, offset: int = 0):
f" a different directory, set overwrite=True, or "
f"resolve manually."
)
- # save_dir/{i_ds+offset}/
os.makedirs(sub_dir)
- # save_dir/{i_ds+offset}/{i_ds+offset}-feat.parquet
self._save_features(sub_dir, ds, i_ds, offset)
- # save_dir/{i_ds+offset}/metadata_df.pkl
self._save_metadata(sub_dir, ds)
- # save_dir/{i_ds+offset}/description.json
self._save_description(sub_dir, ds.description)
- # save_dir/{i_ds+offset}/raw-info.fif
self._save_raw_info(sub_dir, ds)
- # save_dir/{i_ds+offset}/raw_preproc_kwargs.json
- # save_dir/{i_ds+offset}/window_kwargs.json
- # save_dir/{i_ds+offset}/window_preproc_kwargs.json
- # save_dir/{i_ds+offset}/features_kwargs.json
self._save_kwargs(sub_dir, ds)
- if overwrite:
- # the following will be True for all datasets preprocessed and
- # stored in parallel with braindecode.preprocessing.preprocess
- if i_ds + 1 + offset < n_sub_dirs:
- logger.warning(
- f"The number of saved datasets ({i_ds + 1 + offset}) "
- f"does not match the number of existing "
- f"subdirectories ({n_sub_dirs}). You may now "
- f"encounter a mix of differently preprocessed "
- f"datasets!",
- UserWarning,
- )
- # if path contains files or directories that were not touched, raise
- # warning
+ if overwrite and i_ds + 1 + offset < n_sub_dirs:
+ logger.warning(
+ f"The number of saved datasets ({i_ds + 1 + offset}) "
+ f"does not match the number of existing "
+ f"subdirectories ({n_sub_dirs}). You may now "
+ f"encounter a mix of differently preprocessed "
+ f"datasets!",
+ UserWarning,
+ )
if path_contents:
logger.warning(
f"Chosen directory {path} contains other "
@@ -301,20 +342,37 @@ def save(self, path: str, overwrite: bool = False, offset: int = 0):
)
@staticmethod
- def _save_features(sub_dir, ds, i_ds, offset):
+ def _save_features(sub_dir: str, ds: FeaturesDataset, i_ds: int, offset: int):
+ """Save the feature DataFrame to a Parquet file."""
parquet_file_name = f"{i_ds + offset}-feat.parquet"
parquet_file_path = os.path.join(sub_dir, parquet_file_name)
ds.features.to_parquet(parquet_file_path)
@staticmethod
- def _save_raw_info(sub_dir, ds):
- if hasattr(ds, "raw_info"):
+ def _save_metadata(sub_dir: str, ds: FeaturesDataset):
+ """Save the metadata DataFrame to a pickle file."""
+ metadata_file_name = "metadata_df.pkl"
+ metadata_file_path = os.path.join(sub_dir, metadata_file_name)
+ ds.metadata.to_pickle(metadata_file_path)
+
+ @staticmethod
+ def _save_description(sub_dir: str, description: pd.Series):
+ """Save the description Series to a JSON file."""
+ desc_file_name = "description.json"
+ desc_file_path = os.path.join(sub_dir, desc_file_name)
+ description.to_json(desc_file_path)
+
+ @staticmethod
+ def _save_raw_info(sub_dir: str, ds: FeaturesDataset):
+ """Save the raw info dictionary to a FIF file if it exists."""
+ if hasattr(ds, "raw_info") and ds.raw_info is not None:
fif_file_name = "raw-info.fif"
fif_file_path = os.path.join(sub_dir, fif_file_name)
- ds.raw_info.save(fif_file_path)
+ ds.raw_info.save(fif_file_path, overwrite=True)
@staticmethod
- def _save_kwargs(sub_dir, ds):
+ def _save_kwargs(sub_dir: str, ds: FeaturesDataset):
+ """Save various keyword argument dictionaries to JSON files."""
for kwargs_name in [
"raw_preproc_kwargs",
"window_kwargs",
@@ -322,10 +380,10 @@ def _save_kwargs(sub_dir, ds):
"features_kwargs",
]:
if hasattr(ds, kwargs_name):
- kwargs_file_name = ".".join([kwargs_name, "json"])
- kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
kwargs = getattr(ds, kwargs_name)
if kwargs is not None:
+ kwargs_file_name = ".".join([kwargs_name, "json"])
+ kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
with open(kwargs_file_path, "w") as f:
json.dump(kwargs, f)
@@ -334,7 +392,25 @@ def to_dataframe(
include_metadata: bool | str | List[str] = False,
include_target: bool = False,
include_crop_inds: bool = False,
- ):
+ ) -> pd.DataFrame:
+ """Convert the dataset to a single pandas DataFrame.
+
+ Parameters
+ ----------
+ include_metadata : bool or str or list of str, default False
+ If True, include all metadata columns. If a string or list of
+ strings, include only the specified metadata columns.
+ include_target : bool, default False
+ If True, include the 'target' column.
+ include_crop_inds : bool, default False
+ If True, include window cropping index columns.
+
+ Returns
+ -------
+ pandas.DataFrame
+ A DataFrame containing the features and requested metadata.
+
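+        Examples
+        --------
+        A sketch; ``concat_ds`` is a hypothetical :class:`FeaturesConcatDataset`:
+
+        >>> df = concat_ds.to_dataframe(include_target=True)  # doctest: +SKIP
+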
+ """
if (
not isinstance(include_metadata, bool)
or include_metadata
@@ -343,7 +419,7 @@ def to_dataframe(
include_dataset = False
if isinstance(include_metadata, bool) and include_metadata:
include_dataset = True
- cols = self.datasets[0].metadata.columns
+ cols = self.datasets[0].metadata.columns.tolist()
else:
cols = include_metadata
if isinstance(cols, bool) and not cols:
@@ -352,13 +428,14 @@ def to_dataframe(
cols = [cols]
cols = set(cols)
if include_crop_inds:
- cols = {
- "i_dataset",
- "i_window_in_trial",
- "i_start_in_trial",
- "i_stop_in_trial",
- *cols,
- }
+ cols.update(
+ {
+ "i_dataset",
+ "i_window_in_trial",
+ "i_start_in_trial",
+ "i_stop_in_trial",
+ }
+ )
if include_target:
cols.add("target")
cols = list(cols)
@@ -381,10 +458,26 @@ def to_dataframe(
dataframes = [ds.features for ds in self.datasets]
return pd.concat(dataframes, axis=0, ignore_index=True)
- def _numeric_columns(self):
+ def _numeric_columns(self) -> pd.Index:
+ """Get the names of numeric columns from the feature DataFrames."""
return self.datasets[0].features.select_dtypes(include=np.number).columns
- def count(self, numeric_only=False, n_jobs=1):
+ def count(self, numeric_only: bool = False, n_jobs: int = 1) -> pd.Series:
+ """Count non-NA cells for each feature column.
+
+ Parameters
+ ----------
+ numeric_only : bool, default False
+ Include only float, int, boolean columns.
+ n_jobs : int, default 1
+ Number of jobs to run in parallel.
+
+ Returns
+ -------
+ pandas.Series
+ The count of non-NA cells for each column.
+
+ """
stats = Parallel(n_jobs)(
delayed(_compute_stats)(ds, return_count=True, numeric_only=numeric_only)
for ds in self.datasets
@@ -393,7 +486,22 @@ def count(self, numeric_only=False, n_jobs=1):
count = counts.sum(axis=0)
return pd.Series(count, index=self._numeric_columns())
- def mean(self, numeric_only=False, n_jobs=1):
+ def mean(self, numeric_only: bool = False, n_jobs: int = 1) -> pd.Series:
+ """Compute the mean for each feature column.
+
+ Parameters
+ ----------
+ numeric_only : bool, default False
+ Include only float, int, boolean columns.
+ n_jobs : int, default 1
+ Number of jobs to run in parallel.
+
+ Returns
+ -------
+ pandas.Series
+ The mean of each column.
+
+ """
stats = Parallel(n_jobs)(
delayed(_compute_stats)(
ds, return_count=True, return_mean=True, numeric_only=numeric_only
@@ -405,7 +513,26 @@ def mean(self, numeric_only=False, n_jobs=1):
mean = np.sum((counts / count) * means, axis=0)
return pd.Series(mean, index=self._numeric_columns())
- def var(self, ddof=1, numeric_only=False, n_jobs=1):
+ def var(
+ self, ddof: int = 1, numeric_only: bool = False, n_jobs: int = 1
+ ) -> pd.Series:
+ """Compute the variance for each feature column.
+
+ Parameters
+ ----------
+ ddof : int, default 1
+ Delta Degrees of Freedom. The divisor used in calculations is N - ddof.
+ numeric_only : bool, default False
+ Include only float, int, boolean columns.
+ n_jobs : int, default 1
+ Number of jobs to run in parallel.
+
+ Returns
+ -------
+ pandas.Series
+ The variance of each column.
+
+ """
stats = Parallel(n_jobs)(
delayed(_compute_stats)(
ds,
@@ -425,12 +552,50 @@ def var(self, ddof=1, numeric_only=False, n_jobs=1):
_, _, var = _pooled_var(counts, means, variances, ddof, ddof_in=0)
return pd.Series(var, index=self._numeric_columns())
- def std(self, ddof=1, numeric_only=False, eps=0, n_jobs=1):
+ def std(
+ self, ddof: int = 1, numeric_only: bool = False, eps: float = 0, n_jobs: int = 1
+ ) -> pd.Series:
+ """Compute the standard deviation for each feature column.
+
+ Parameters
+ ----------
+ ddof : int, default 1
+ Delta Degrees of Freedom.
+ numeric_only : bool, default False
+ Include only float, int, boolean columns.
+ eps : float, default 0
+ A small epsilon value to add to the variance before taking the
+ square root to avoid numerical instability.
+ n_jobs : int, default 1
+ Number of jobs to run in parallel.
+
+ Returns
+ -------
+ pandas.Series
+ The standard deviation of each column.
+
+ """
return np.sqrt(
self.var(ddof=ddof, numeric_only=numeric_only, n_jobs=n_jobs) + eps
)
- def zscore(self, ddof=1, numeric_only=False, eps=0, n_jobs=1):
+ def zscore(
+ self, ddof: int = 1, numeric_only: bool = False, eps: float = 0, n_jobs: int = 1
+ ) -> None:
+ """Apply z-score normalization to numeric columns in-place.
+
+ Parameters
+ ----------
+ ddof : int, default 1
+ Delta Degrees of Freedom for variance calculation.
+ numeric_only : bool, default False
+ Include only float, int, boolean columns.
+ eps : float, default 0
+ Epsilon for numerical stability.
+ n_jobs : int, default 1
+ Number of jobs to run in parallel for statistics computation.
+
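+        Examples
+        --------
+        A sketch; ``concat_ds`` is a hypothetical :class:`FeaturesConcatDataset`:
+
+        >>> concat_ds.zscore(eps=1e-8)  # doctest: +SKIP
+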
+ """
stats = Parallel(n_jobs)(
delayed(_compute_stats)(
ds,
@@ -450,10 +615,13 @@ def zscore(self, ddof=1, numeric_only=False, eps=0, n_jobs=1):
_, mean, var = _pooled_var(counts, means, variances, ddof, ddof_in=0)
std = np.sqrt(var + eps)
for ds in self.datasets:
- ds.features = (ds.features - mean) / std
+ ds.features.loc[:, self._numeric_columns()] = (
+ ds.features.loc[:, self._numeric_columns()] - mean
+ ) / std
@staticmethod
- def _enforce_inplace_operations(func_name, kwargs):
+ def _enforce_inplace_operations(func_name: str, kwargs: dict):
+ """Raise an error if 'inplace=False' is passed to a method."""
if "inplace" in kwargs and kwargs["inplace"] is False:
raise ValueError(
f"{func_name} only works inplace, please change "
@@ -461,33 +629,49 @@ def _enforce_inplace_operations(func_name, kwargs):
)
kwargs["inplace"] = True
- def fillna(self, *args, **kwargs):
+ def fillna(self, *args, **kwargs) -> None:
+ """Fill NA/NaN values in-place. See :meth:`pandas.DataFrame.fillna`."""
FeaturesConcatDataset._enforce_inplace_operations("fillna", kwargs)
for ds in self.datasets:
ds.features.fillna(*args, **kwargs)
- def replace(self, *args, **kwargs):
+ def replace(self, *args, **kwargs) -> None:
+ """Replace values in-place. See :meth:`pandas.DataFrame.replace`."""
FeaturesConcatDataset._enforce_inplace_operations("replace", kwargs)
for ds in self.datasets:
ds.features.replace(*args, **kwargs)
- def interpolate(self, *args, **kwargs):
+ def interpolate(self, *args, **kwargs) -> None:
+ """Interpolate values in-place. See :meth:`pandas.DataFrame.interpolate`."""
FeaturesConcatDataset._enforce_inplace_operations("interpolate", kwargs)
for ds in self.datasets:
ds.features.interpolate(*args, **kwargs)
- def dropna(self, *args, **kwargs):
+ def dropna(self, *args, **kwargs) -> None:
+ """Remove missing values in-place. See :meth:`pandas.DataFrame.dropna`."""
FeaturesConcatDataset._enforce_inplace_operations("dropna", kwargs)
for ds in self.datasets:
ds.features.dropna(*args, **kwargs)
- def drop(self, *args, **kwargs):
+ def drop(self, *args, **kwargs) -> None:
+ """Drop specified labels from rows or columns in-place. See :meth:`pandas.DataFrame.drop`."""
FeaturesConcatDataset._enforce_inplace_operations("drop", kwargs)
for ds in self.datasets:
ds.features.drop(*args, **kwargs)
- def join(self, concat_dataset: FeaturesConcatDataset, **kwargs):
+ def join(self, concat_dataset: FeaturesConcatDataset, **kwargs) -> None:
+ """Join columns with other FeaturesConcatDataset in-place.
+
+ Parameters
+ ----------
+ concat_dataset : FeaturesConcatDataset
+ The dataset to join with. Must have the same number of datasets,
+ and each corresponding dataset must have the same length.
+ **kwargs
+ Keyword arguments to pass to :meth:`pandas.DataFrame.join`.
+
+ """
assert len(self.datasets) == len(concat_dataset.datasets)
for ds1, ds2 in zip(self.datasets, concat_dataset.datasets):
assert len(ds1) == len(ds2)
- ds1.features.join(ds2, **kwargs)
+ ds1.features = ds1.features.join(ds2.features, **kwargs)
diff --git a/eegdash/features/decorators.py b/eegdash/features/decorators.py
index 841a96a0..687d57fe 100644
--- a/eegdash/features/decorators.py
+++ b/eegdash/features/decorators.py
@@ -12,6 +12,21 @@
class FeaturePredecessor:
+ """A decorator to specify parent extractors for a feature function.
+
+ This decorator attaches a list of parent extractor types to a feature
+ extraction function. This information can be used to build a dependency
+ graph of features.
+
+ Parameters
+ ----------
+ *parent_extractor_type : list of Type
+ A list of feature extractor classes (subclasses of
+ :class:`~eegdash.features.extractors.FeatureExtractor`) that this
+ feature depends on.
+
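+    Examples
+    --------
+    A sketch; ``my_feature`` is a hypothetical feature function:
+
+    >>> @FeaturePredecessor(FeatureExtractor)  # doctest: +SKIP
+    ... def my_feature(x):
+    ...     return x.mean(axis=-1)
+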
+ """
+
def __init__(self, *parent_extractor_type: List[Type]):
parent_cls = parent_extractor_type
if not parent_cls:
@@ -20,17 +35,58 @@ def __init__(self, *parent_extractor_type: List[Type]):
assert issubclass(p_cls, FeatureExtractor)
self.parent_extractor_type = parent_cls
- def __call__(self, func: Callable):
+ def __call__(self, func: Callable) -> Callable:
+ """Apply the decorator to a function.
+
+ Parameters
+ ----------
+ func : callable
+ The feature extraction function to decorate.
+
+ Returns
+ -------
+ callable
+ The decorated function with the `parent_extractor_type` attribute
+ set.
+
+ """
f = _get_underlying_func(func)
f.parent_extractor_type = self.parent_extractor_type
return func
class FeatureKind:
+ """A decorator to specify the kind of a feature.
+
+ This decorator attaches a "feature kind" (e.g., univariate, bivariate)
+ to a feature extraction function.
+
+ Parameters
+ ----------
+ feature_kind : MultivariateFeature
+ An instance of a feature kind class, such as
+ :class:`~eegdash.features.extractors.UnivariateFeature` or
+ :class:`~eegdash.features.extractors.BivariateFeature`.
+
+ """
+
def __init__(self, feature_kind: MultivariateFeature):
self.feature_kind = feature_kind
- def __call__(self, func):
+ def __call__(self, func: Callable) -> Callable:
+ """Apply the decorator to a function.
+
+ Parameters
+ ----------
+ func : callable
+ The feature extraction function to decorate.
+
+ Returns
+ -------
+ callable
+ The decorated function with the `feature_kind` attribute set.
+
+ """
f = _get_underlying_func(func)
f.feature_kind = self.feature_kind
return func
@@ -38,9 +94,33 @@ def __call__(self, func):
# Syntax sugar
univariate_feature = FeatureKind(UnivariateFeature())
+"""Decorator to mark a feature as univariate.
+
+This is a convenience instance of :class:`FeatureKind` pre-configured for
+univariate features.
+"""
-def bivariate_feature(func, directed=False):
+def bivariate_feature(func: Callable, directed: bool = False) -> Callable:
+ """Decorator to mark a feature as bivariate.
+
+ This decorator specifies that the feature operates on pairs of channels.
+
+ Parameters
+ ----------
+ func : callable
+ The feature extraction function to decorate.
+ directed : bool, default False
+ If True, the feature is directed (e.g., connectivity from channel A
+ to B is different from B to A). If False, the feature is undirected.
+
+ Returns
+ -------
+ callable
+ The decorated function with the appropriate bivariate feature kind
+ attached.
+
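+    Examples
+    --------
+    A sketch; ``my_connectivity`` is a hypothetical feature function:
+
+    >>> @bivariate_feature  # doctest: +SKIP
+    ... def my_connectivity(x):
+    ...     ...
+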
+ """
if directed:
kind = DirectedBivariateFeature()
else:
@@ -49,3 +129,8 @@ def bivariate_feature(func, directed=False):
multivariate_feature = FeatureKind(MultivariateFeature())
+"""Decorator to mark a feature as multivariate.
+
+This is a convenience instance of :class:`FeatureKind` pre-configured for
+multivariate features, which operate on all channels simultaneously.
+"""
diff --git a/eegdash/features/extractors.py b/eegdash/features/extractors.py
index e3a785ef..5df28663 100644
--- a/eegdash/features/extractors.py
+++ b/eegdash/features/extractors.py
@@ -7,7 +7,23 @@
from numba.core.dispatcher import Dispatcher
-def _get_underlying_func(func):
+def _get_underlying_func(func: Callable) -> Callable:
+ """Get the underlying function from a potential wrapper.
+
+ This helper unwraps functions that might be wrapped by `functools.partial`
+ or `numba.dispatcher.Dispatcher`.
+
+ Parameters
+ ----------
+ func : callable
+ The function to unwrap.
+
+ Returns
+ -------
+ callable
+ The underlying Python function.
+
+ """
f = func
if isinstance(f, partial):
f = f.func
@@ -17,22 +33,46 @@ def _get_underlying_func(func):
class TrainableFeature(ABC):
+ """Abstract base class for features that require training.
+
+ This ABC defines the interface for feature extractors that need to be
+ fitted on data before they can be used. It includes methods for fitting
+ the feature extractor and for resetting its state.
+ """
+
def __init__(self):
self._is_trained = False
self.clear()
@abstractmethod
def clear(self):
+ """Reset the internal state of the feature extractor."""
pass
@abstractmethod
def partial_fit(self, *x, y=None):
+ """Update the feature extractor's state with a batch of data.
+
+ Parameters
+ ----------
+ *x : tuple
+ The input data for fitting.
+ y : any, optional
+ The target data, if required for supervised training.
+
+ """
pass
def fit(self):
+ """Finalize the training of the feature extractor.
+
+ This method should be called after all data has been seen via
+ `partial_fit`. It marks the feature as fitted.
+ """
self._is_fitted = True
def __call__(self, *args, **kwargs):
+ """Check if the feature is fitted before execution."""
if not self._is_fitted:
raise RuntimeError(
f"{self.__class__} cannot be called, it has to be fitted first."
@@ -40,6 +80,22 @@ def __call__(self, *args, **kwargs):
class FeatureExtractor(TrainableFeature):
+ """A composite feature extractor that applies multiple feature functions.
+
+ This class orchestrates the application of a dictionary of feature
+ extraction functions to input data. It can handle nested extractors,
+ pre-processing, and trainable features.
+
+ Parameters
+ ----------
+ feature_extractors : dict[str, callable]
+ A dictionary where keys are feature names and values are the feature
+ extraction functions or other `FeatureExtractor` instances.
+ **preprocess_kwargs
+ Keyword arguments to be passed to the `preprocess` method.
+
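+    Examples
+    --------
+    A minimal sketch; the feature function is an arbitrary stand-in:
+
+    >>> fe = FeatureExtractor({"mean": lambda x: x.mean(axis=-1)})  # doctest: +SKIP
+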
+ """
+
def __init__(
self, feature_extractors: Dict[str, Callable], **preprocess_kwargs: Dict
):
@@ -63,30 +119,64 @@ def __init__(
if isinstance(fe, partial):
self.features_kwargs[fn] = fe.keywords
- def _validate_execution_tree(self, feature_extractors):
+ def _validate_execution_tree(self, feature_extractors: dict) -> dict:
+ """Validate the feature dependency graph."""
for fname, f in feature_extractors.items():
f = _get_underlying_func(f)
pe_type = getattr(f, "parent_extractor_type", [FeatureExtractor])
- assert type(self) in pe_type
+ if type(self) not in pe_type:
+ raise TypeError(
+ f"Feature '{fname}' cannot be a child of {type(self).__name__}"
+ )
return feature_extractors
- def _check_is_trainable(self, feature_extractors):
- is_trainable = False
+ def _check_is_trainable(self, feature_extractors: dict) -> bool:
+ """Check if any of the contained features are trainable."""
for fname, f in feature_extractors.items():
if isinstance(f, FeatureExtractor):
- is_trainable = f._is_trainable
- else:
- f = _get_underlying_func(f)
- if isinstance(f, TrainableFeature):
- is_trainable = True
- if is_trainable:
- break
- return is_trainable
+ if f._is_trainable:
+ return True
+ elif isinstance(_get_underlying_func(f), TrainableFeature):
+ return True
+ return False
def preprocess(self, *x, **kwargs):
+ """Apply pre-processing to the input data.
+
+ Parameters
+ ----------
+ *x : tuple
+ Input data.
+ **kwargs
+ Additional keyword arguments.
+
+ Returns
+ -------
+ tuple
+ The pre-processed data.
+
+ """
return (*x,)
- def __call__(self, *x, _batch_size=None, _ch_names=None):
+ def __call__(self, *x, _batch_size=None, _ch_names=None) -> dict:
+ """Apply all feature extractors to the input data.
+
+ Parameters
+ ----------
+ *x : tuple
+ Input data.
+ _batch_size : int, optional
+ The number of samples in the batch.
+ _ch_names : list of str, optional
+ The names of the channels in the input data.
+
+ Returns
+ -------
+ dict
+ A dictionary where keys are feature names and values are the
+ computed feature values.
+
+ """
assert _batch_size is not None
assert _ch_names is not None
if self._is_trainable:
@@ -100,59 +190,83 @@ def __call__(self, *x, _batch_size=None, _ch_names=None):
r = f(*z, _batch_size=_batch_size, _ch_names=_ch_names)
else:
r = f(*z)
- f = _get_underlying_func(f)
- if hasattr(f, "feature_kind"):
- r = f.feature_kind(r, _ch_names=_ch_names)
+ f_und = _get_underlying_func(f)
+ if hasattr(f_und, "feature_kind"):
+ r = f_und.feature_kind(r, _ch_names=_ch_names)
if not isinstance(fname, str) or not fname:
- if isinstance(f, FeatureExtractor) or not hasattr(f, "__name__"):
- fname = ""
- else:
- fname = f.__name__
+ fname = getattr(f_und, "__name__", "")
if isinstance(r, dict):
- if fname:
- fname += "_"
+ prefix = f"{fname}_" if fname else ""
for k, v in r.items():
- self._add_feature_to_dict(results_dict, fname + k, v, _batch_size)
+ self._add_feature_to_dict(results_dict, prefix + k, v, _batch_size)
else:
self._add_feature_to_dict(results_dict, fname, r, _batch_size)
return results_dict
- def _add_feature_to_dict(self, results_dict, name, value, batch_size):
- if not isinstance(value, np.ndarray):
- results_dict[name] = value
- else:
+ def _add_feature_to_dict(
+ self, results_dict: dict, name: str, value: any, batch_size: int
+ ):
+ """Add a computed feature to the results dictionary."""
+ if isinstance(value, np.ndarray):
assert value.shape[0] == batch_size
- results_dict[name] = value
+ results_dict[name] = value
def clear(self):
+ """Clear the state of all trainable sub-features."""
if not self._is_trainable:
return
- for fname, f in self.feature_extractors_dict.items():
- f = _get_underlying_func(f)
- if isinstance(f, TrainableFeature):
- f.clear()
+ for f in self.feature_extractors_dict.values():
+ if isinstance(_get_underlying_func(f), TrainableFeature):
+ _get_underlying_func(f).clear()
def partial_fit(self, *x, y=None):
+ """Partially fit all trainable sub-features."""
if not self._is_trainable:
return
z = self.preprocess(*x, **self.preprocess_kwargs)
- for fname, f in self.feature_extractors_dict.items():
- f = _get_underlying_func(f)
- if isinstance(f, TrainableFeature):
- f.partial_fit(*z, y=y)
+ if not isinstance(z, tuple):
+ z = (z,)
+ for f in self.feature_extractors_dict.values():
+ if isinstance(_get_underlying_func(f), TrainableFeature):
+ _get_underlying_func(f).partial_fit(*z, y=y)
def fit(self):
+ """Fit all trainable sub-features."""
if not self._is_trainable:
return
- for fname, f in self.feature_extractors_dict.items():
- f = _get_underlying_func(f)
- if isinstance(f, TrainableFeature):
+        for f in self.feature_extractors_dict.values():
+            f = _get_underlying_func(f)
+            if isinstance(f, TrainableFeature):
f.fit()
super().fit()
class MultivariateFeature:
- def __call__(self, x, _ch_names=None):
+ """A mixin for features that operate on multiple channels.
+
+ This class provides a `__call__` method that converts a feature array into
+ a dictionary with named features, where names are derived from channel
+ names.
+ """
+
+ def __call__(
+ self, x: np.ndarray, _ch_names: list[str] | None = None
+ ) -> dict | np.ndarray:
+ """Convert a feature array to a named dictionary.
+
+ Parameters
+ ----------
+ x : numpy.ndarray
+ The computed feature array.
+ _ch_names : list of str, optional
+ The list of channel names.
+
+ Returns
+ -------
+ dict or numpy.ndarray
+ A dictionary of named features, or the original array if feature
+ channel names cannot be generated.
+
+ """
assert _ch_names is not None
f_channels = self.feature_channel_names(_ch_names)
if isinstance(x, dict):
@@ -163,37 +277,66 @@ def __call__(self, x, _ch_names=None):
return self._array_to_dict(x, f_channels)
@staticmethod
- def _array_to_dict(x, f_channels, name=""):
+ def _array_to_dict(
+ x: np.ndarray, f_channels: list[str], name: str = ""
+ ) -> dict | np.ndarray:
+ """Convert a numpy array to a dictionary with named keys."""
assert isinstance(x, np.ndarray)
- if len(f_channels) == 0:
- assert x.ndim == 1
- if name:
- return {name: x}
- return x
- assert x.shape[1] == len(f_channels)
+ if not f_channels:
+ return {name: x} if name else x
+ assert x.shape[1] == len(f_channels), f"{x.shape[1]} != {len(f_channels)}"
x = x.swapaxes(0, 1)
- names = [f"{name}_{ch}" for ch in f_channels] if name else f_channels
+ prefix = f"{name}_" if name else ""
+ names = [f"{prefix}{ch}" for ch in f_channels]
return dict(zip(names, x))
- def feature_channel_names(self, ch_names):
+ def feature_channel_names(self, ch_names: list[str]) -> list[str]:
+ """Generate feature names based on channel names.
+
+ Parameters
+ ----------
+ ch_names : list of str
+ The names of the input channels.
+
+ Returns
+ -------
+ list of str
+ The names for the output features.
+
+ """
return []
class UnivariateFeature(MultivariateFeature):
- def feature_channel_names(self, ch_names):
+ """A feature kind for operations applied to each channel independently."""
+
+ def feature_channel_names(self, ch_names: list[str]) -> list[str]:
+ """Return the channel names themselves as feature names."""
return ch_names
class BivariateFeature(MultivariateFeature):
- def __init__(self, *args, channel_pair_format="{}<>{}"):
+ """A feature kind for operations on pairs of channels.
+
+ Parameters
+ ----------
+ channel_pair_format : str, default="{}<>{}"
+ A format string used to create feature names from pairs of
+ channel names.
+
+ """
+
+ def __init__(self, *args, channel_pair_format: str = "{}<>{}"):
super().__init__(*args)
self.channel_pair_format = channel_pair_format
@staticmethod
- def get_pair_iterators(n):
+ def get_pair_iterators(n: int) -> tuple[np.ndarray, np.ndarray]:
+ """Get indices for unique, unordered pairs of channels."""
return np.triu_indices(n, 1)
- def feature_channel_names(self, ch_names):
+ def feature_channel_names(self, ch_names: list[str]) -> list[str]:
+ """Generate feature names for each pair of channels."""
return [
self.channel_pair_format.format(ch_names[i], ch_names[j])
for i, j in zip(*self.get_pair_iterators(len(ch_names)))
@@ -201,8 +344,11 @@ def feature_channel_names(self, ch_names):
class DirectedBivariateFeature(BivariateFeature):
+ """A feature kind for directed operations on pairs of channels."""
+
@staticmethod
- def get_pair_iterators(n):
+ def get_pair_iterators(n: int) -> list[np.ndarray]:
+ """Get indices for all ordered pairs of channels (excluding self-pairs)."""
return [
np.append(a, b)
for a, b in zip(np.tril_indices(n, -1), np.triu_indices(n, 1))
diff --git a/eegdash/features/inspect.py b/eegdash/features/inspect.py
index 29a6d23f..8e379b58 100644
--- a/eegdash/features/inspect.py
+++ b/eegdash/features/inspect.py
@@ -5,7 +5,27 @@
from .extractors import FeatureExtractor, MultivariateFeature, _get_underlying_func
-def get_feature_predecessors(feature_or_extractor: Callable):
+def get_feature_predecessors(feature_or_extractor: Callable) -> list:
+ """Get the dependency hierarchy for a feature or feature extractor.
+
+ This function recursively traverses the `parent_extractor_type` attribute
+ of a feature or extractor to build a list representing its dependency
+ lineage.
+
+ Parameters
+ ----------
+ feature_or_extractor : callable
+ The feature function or :class:`FeatureExtractor` class to inspect.
+
+ Returns
+ -------
+ list
+ A nested list representing the dependency tree. For a simple linear
+ chain, this will be a flat list from the specific feature up to the
+ base `FeatureExtractor`. For multiple dependencies, it will contain
+ tuples of sub-dependencies.
+
+ """
current = _get_underlying_func(feature_or_extractor)
if current is FeatureExtractor:
return [current]
@@ -20,18 +40,59 @@ def get_feature_predecessors(feature_or_extractor: Callable):
return [current, tuple(predecessors)]
-def get_feature_kind(feature: Callable):
+def get_feature_kind(feature: Callable) -> MultivariateFeature:
+ """Get the 'kind' of a feature function.
+
+ The feature kind (e.g., univariate, bivariate) is typically attached by a
+ decorator.
+
+ Parameters
+ ----------
+ feature : callable
+ The feature function to inspect.
+
+ Returns
+ -------
+ MultivariateFeature
+ An instance of the feature kind (e.g., `UnivariateFeature()`).
+
+ """
return _get_underlying_func(feature).feature_kind
-def get_all_features():
+def get_all_features() -> list[tuple[str, Callable]]:
+ """Get a list of all available feature functions.
+
+ Scans the `eegdash.features.feature_bank` module for functions that have
+ been decorated to have a `feature_kind` attribute.
+
+ Returns
+ -------
+ list[tuple[str, callable]]
+ A list of (name, function) tuples for all discovered features.
+
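+    Examples
+    --------
+    A sketch listing the names of the discovered feature functions:
+
+    >>> names = [name for name, _ in get_all_features()]  # doctest: +SKIP
+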
+ """
+
def isfeature(x):
return hasattr(_get_underlying_func(x), "feature_kind")
return inspect.getmembers(feature_bank, isfeature)
-def get_all_feature_extractors():
+def get_all_feature_extractors() -> list[tuple[str, type[FeatureExtractor]]]:
+ """Get a list of all available `FeatureExtractor` classes.
+
+ Scans the `eegdash.features.feature_bank` module for all classes that
+ subclass :class:`~eegdash.features.extractors.FeatureExtractor`.
+
+ Returns
+ -------
+ list[tuple[str, type[FeatureExtractor]]]
+ A list of (name, class) tuples for all discovered feature extractors,
+ including the base `FeatureExtractor` itself.
+
+ """
+
def isfeatureextractor(x):
return inspect.isclass(x) and issubclass(x, FeatureExtractor)
@@ -41,7 +102,19 @@ def isfeatureextractor(x):
]
-def get_all_feature_kinds():
+def get_all_feature_kinds() -> list[tuple[str, type[MultivariateFeature]]]:
+ """Get a list of all available feature 'kind' classes.
+
+ Scans the `eegdash.features.extractors` module for all classes that
+ subclass :class:`~eegdash.features.extractors.MultivariateFeature`.
+
+ Returns
+ -------
+ list[tuple[str, type[MultivariateFeature]]]
+ A list of (name, class) tuples for all discovered feature kinds.
+
+ """
+
def isfeaturekind(x):
return inspect.isclass(x) and issubclass(x, MultivariateFeature)
diff --git a/eegdash/features/serialization.py b/eegdash/features/serialization.py
index 39f09d4c..5aeb1787 100644
--- a/eegdash/features/serialization.py
+++ b/eegdash/features/serialization.py
@@ -1,7 +1,8 @@
"""Convenience functions for storing and loading features datasets.
-See Also:
- https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229
+See Also
+--------
+https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229
"""
@@ -16,34 +17,40 @@
from .datasets import FeaturesConcatDataset, FeaturesDataset
-def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
- """Load a stored features dataset from files.
+def load_features_concat_dataset(
+ path: str | Path, ids_to_load: list[int] | None = None, n_jobs: int = 1
+) -> FeaturesConcatDataset:
+ """Load a stored `FeaturesConcatDataset` from a directory.
+
+ This function reconstructs a :class:`FeaturesConcatDataset` by loading
+ individual :class:`FeaturesDataset` instances from subdirectories within
+ the given path. It uses joblib for parallel loading.
Parameters
----------
- path: str | pathlib.Path
- Path to the directory of the .fif / -epo.fif and .json files.
- ids_to_load: list of int | None
- Ids of specific files to load.
- n_jobs: int
- Number of jobs to be used to read files in parallel.
+ path : str or pathlib.Path
+ The path to the directory where the dataset was saved. This directory
+ should contain subdirectories (e.g., "0", "1", "2", ...) for each
+ individual dataset.
+ ids_to_load : list of int, optional
+ A list of specific dataset IDs (subdirectory names) to load. If None,
+ all subdirectories in the path will be loaded.
+ n_jobs : int, default 1
+ The number of jobs to use for parallel loading. -1 means using all
+ processors.
Returns
-------
- concat_dataset: eegdash.features.datasets.FeaturesConcatDataset
- A concatenation of multiple eegdash.features.datasets.FeaturesDataset
- instances loaded from the given directory.
+ eegdash.features.datasets.FeaturesConcatDataset
+ A concatenated dataset containing the loaded `FeaturesDataset` instances.
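+
+    Examples
+    --------
+    A sketch; the directory is hypothetical and is assumed to have been
+    written by :meth:`FeaturesConcatDataset.save`:
+
+    >>> concat_ds = load_features_concat_dataset("features_dir", n_jobs=2)  # doctest: +SKIP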
"""
# Make sure we always work with a pathlib.Path
path = Path(path)
- # else we have a dataset saved in the new way with subdirectories in path
- # for every dataset with description.json and -feat.parquet,
- # target_name.json, raw_preproc_kwargs.json, window_kwargs.json,
- # window_preproc_kwargs.json, features_kwargs.json
if ids_to_load is None:
- ids_to_load = [p.name for p in path.iterdir()]
+ # Get all subdirectories and sort them numerically
+ ids_to_load = [p.name for p in path.iterdir() if p.is_dir()]
ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
ids_to_load = [str(i) for i in ids_to_load]
@@ -51,7 +58,26 @@ def load_features_concat_dataset(path, ids_to_load=None, n_jobs=1):
return FeaturesConcatDataset(datasets)
-def _load_parallel(path, i):
+def _load_parallel(path: Path, i: str) -> FeaturesDataset:
+ """Load a single `FeaturesDataset` from its subdirectory.
+
+ This is a helper function for `load_features_concat_dataset` that handles
+ the loading of one dataset's files (features, metadata, descriptions, etc.).
+
+ Parameters
+ ----------
+ path : pathlib.Path
+ The root directory of the saved `FeaturesConcatDataset`.
+ i : str
+ The identifier of the dataset to load, corresponding to its
+ subdirectory name.
+
+ Returns
+ -------
+ eegdash.features.datasets.FeaturesDataset
+ The loaded dataset instance.
+
+ """
sub_dir = path / i
parquet_name_pattern = "{}-feat.parquet"
diff --git a/eegdash/features/utils.py b/eegdash/features/utils.py
index 170a514a..5c311496 100644
--- a/eegdash/features/utils.py
+++ b/eegdash/features/utils.py
@@ -22,7 +22,28 @@ def _extract_features_from_windowsdataset(
win_ds: EEGWindowsDataset | WindowsDataset,
feature_extractor: FeatureExtractor,
batch_size: int = 512,
-):
+) -> FeaturesDataset:
+ """Extract features from a single `WindowsDataset`.
+
+ This is a helper function that iterates through a `WindowsDataset` in
+ batches, applies a `FeatureExtractor`, and returns the results as a
+ `FeaturesDataset`.
+
+ Parameters
+ ----------
+ win_ds : EEGWindowsDataset or WindowsDataset
+ The windowed dataset to extract features from.
+ feature_extractor : FeatureExtractor
+ The feature extractor instance to apply.
+ batch_size : int, default 512
+ The number of windows to process in each batch.
+
+ Returns
+ -------
+ FeaturesDataset
+ A new dataset containing the extracted features and associated metadata.
+
+ """
metadata = win_ds.metadata
if not win_ds.targets_from == "metadata":
metadata = copy.deepcopy(metadata)
@@ -51,18 +72,16 @@ def _extract_features_from_windowsdataset(
features_dict[k].extend(v)
features_df = pd.DataFrame(features_dict)
if not win_ds.targets_from == "metadata":
- metadata.set_index("orig_index", drop=False, inplace=True)
metadata.reset_index(drop=True, inplace=True)
- metadata.drop("orig_index", axis=1, inplace=True)
+ metadata.drop("orig_index", axis=1, inplace=True, errors="ignore")
- # FUTURE: truly support WindowsDataset objects
return FeaturesDataset(
features_df,
metadata=metadata,
description=win_ds.description,
raw_info=win_ds.raw.info,
- raw_preproc_kwargs=win_ds.raw_preproc_kwargs,
- window_kwargs=win_ds.window_kwargs,
+ raw_preproc_kwargs=getattr(win_ds, "raw_preproc_kwargs", None),
+ window_kwargs=getattr(win_ds, "window_kwargs", None),
features_kwargs=feature_extractor.features_kwargs,
)
@@ -73,7 +92,34 @@ def extract_features(
*,
batch_size: int = 512,
n_jobs: int = 1,
-):
+) -> FeaturesConcatDataset:
+ """Extract features from a concatenated dataset of windows.
+
+ This function applies a feature extractor to each `WindowsDataset` within a
+ `BaseConcatDataset` in parallel and returns a `FeaturesConcatDataset`
+ with the results.
+
+ Parameters
+ ----------
+ concat_dataset : BaseConcatDataset
+ A concatenated dataset of `WindowsDataset` or `EEGWindowsDataset`
+ instances.
+ features : FeatureExtractor or dict or list
+ The feature extractor(s) to apply. Can be a `FeatureExtractor`
+ instance, a dictionary of named feature functions, or a list of
+ feature functions.
+ batch_size : int, default 512
+ The size of batches to use for feature extraction.
+ n_jobs : int, default 1
+ The number of parallel jobs to use for extracting features from the
+ datasets.
+
+ Returns
+ -------
+ FeaturesConcatDataset
+ A new concatenated dataset containing the extracted features.
+
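+    Examples
+    --------
+    A sketch; ``windows_ds`` is a hypothetical windowed dataset and the
+    feature function is an arbitrary stand-in:
+
+    >>> feats = extract_features(
+    ...     windows_ds, {"mean": lambda x: x.mean(axis=-1)}, batch_size=256
+    ... )  # doctest: +SKIP
+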
+ """
if isinstance(features, list):
features = dict(enumerate(features))
if not isinstance(features, FeatureExtractor):
@@ -97,7 +143,28 @@ def fit_feature_extractors(
concat_dataset: BaseConcatDataset,
features: FeatureExtractor | Dict[str, Callable] | List[Callable],
batch_size: int = 8192,
-):
+) -> FeatureExtractor:
+ """Fit trainable feature extractors on a dataset.
+
+ If the provided feature extractor (or any of its sub-extractors) is
+ trainable (i.e., subclasses `TrainableFeature`), this function iterates
+ through the dataset to fit it.
+
+ Parameters
+ ----------
+ concat_dataset : BaseConcatDataset
+ The dataset to use for fitting the feature extractors.
+ features : FeatureExtractor or dict or list
+ The feature extractor(s) to fit.
+ batch_size : int, default 8192
+ The batch size to use when iterating through the dataset for fitting.
+
+ Returns
+ -------
+ FeatureExtractor
+ The fitted feature extractor.
+
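+    Examples
+    --------
+    A sketch; ``windows_ds`` and ``trainable_features`` are hypothetical:
+
+    >>> fitted = fit_feature_extractors(windows_ds, trainable_features)  # doctest: +SKIP
+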
+ """
if isinstance(features, list):
features = dict(enumerate(features))
if not isinstance(features, FeatureExtractor):
diff --git a/eegdash/hbn/preprocessing.py b/eegdash/hbn/preprocessing.py
index 357d2cc7..102fefce 100644
--- a/eegdash/hbn/preprocessing.py
+++ b/eegdash/hbn/preprocessing.py
@@ -18,27 +18,47 @@
class hbn_ec_ec_reannotation(Preprocessor):
- """Preprocessor to reannotate the raw data for eyes open and eyes closed events.
+ """Preprocessor to reannotate HBN data for eyes-open/eyes-closed events.
- This processor is designed for HBN datasets.
+ This preprocessor is specifically designed for Healthy Brain Network (HBN)
+ datasets. It identifies existing annotations for "instructed_toCloseEyes"
+ and "instructed_toOpenEyes" and creates new, regularly spaced annotations
+ for "eyes_closed" and "eyes_open" segments, respectively.
+
+ This is useful for creating windowed datasets based on these new, more
+ precise event markers.
+
+ Notes
+ -----
+ This class inherits from :class:`braindecode.preprocessing.Preprocessor`
+ and is intended to be used within a braindecode preprocessing pipeline.
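+
+    Examples
+    --------
+    A sketch of use in a braindecode preprocessing pipeline; ``concat_ds``
+    is a hypothetical dataset of HBN resting-state recordings:
+
+    >>> from braindecode.preprocessing import preprocess
+    >>> preprocess(concat_ds, [hbn_ec_ec_reannotation()])  # doctest: +SKIP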
"""
def __init__(self):
super().__init__(fn=self.transform, apply_on_array=False)
- def transform(self, raw):
- """Reannotate the raw data to create new events for eyes open and eyes closed
+ def transform(self, raw: mne.io.Raw) -> mne.io.Raw:
+ """Create new annotations for eyes-open and eyes-closed periods.
- This function modifies the raw MNE object by creating new events based on
- the existing annotations for "instructed_toCloseEyes" and "instructed_toOpenEyes".
- It generates new events every 2 seconds within specified time ranges after
- the original events, and replaces the existing annotations with these new events.
+ This function finds the original "instructed_to..." annotations and
+ generates new annotations every 2 seconds within specific time ranges
+        relative to the original markers:
+
+        - "eyes_closed": 15s to 29s after "instructed_toCloseEyes"
+        - "eyes_open": 5s to 19s after "instructed_toOpenEyes"
+
+ The original annotations in the `mne.io.Raw` object are replaced by
+ this new set of annotations.
Parameters
----------
raw : mne.io.Raw
- The raw MNE object containing EEG data and annotations.
+ The raw MNE object containing the HBN data and original annotations.
+
+ Returns
+ -------
+ mne.io.Raw
+ The raw MNE object with the modified annotations.
"""
events, event_id = mne.events_from_annotations(raw)
@@ -48,15 +68,27 @@ def transform(self, raw):
# Create new events array for 2-second segments
new_events = []
sfreq = raw.info["sfreq"]
- for event in events[events[:, 2] == event_id["instructed_toCloseEyes"]]:
- # For each original event, create events every 2 seconds from 15s to 29s after
- start_times = event[0] + np.arange(15, 29, 2) * sfreq
- new_events.extend([[int(t), 0, 1] for t in start_times])
- for event in events[events[:, 2] == event_id["instructed_toOpenEyes"]]:
- # For each original event, create events every 2 seconds from 5s to 19s after
- start_times = event[0] + np.arange(5, 19, 2) * sfreq
- new_events.extend([[int(t), 0, 2] for t in start_times])
+ close_event_id = event_id.get("instructed_toCloseEyes")
+        if close_event_id is not None:
+ for event in events[events[:, 2] == close_event_id]:
+ # For each original event, create events every 2s from 15s to 29s after
+ start_times = event[0] + np.arange(15, 29, 2) * sfreq
+ new_events.extend([[int(t), 0, 1] for t in start_times])
+
+ open_event_id = event_id.get("instructed_toOpenEyes")
+        if open_event_id is not None:
+ for event in events[events[:, 2] == open_event_id]:
+ # For each original event, create events every 2s from 5s to 19s after
+ start_times = event[0] + np.arange(5, 19, 2) * sfreq
+ new_events.extend([[int(t), 0, 2] for t in start_times])
+
+ if not new_events:
+ logger.warning(
+ "Could not find 'instructed_toCloseEyes' or 'instructed_toOpenEyes' "
+ "annotations. No new events created."
+ )
+ return raw
# replace events in raw
new_events = np.array(new_events)
@@ -65,6 +97,7 @@ def transform(self, raw):
events=new_events,
event_desc={1: "eyes_closed", 2: "eyes_open"},
sfreq=raw.info["sfreq"],
+ orig_time=raw.info.get("meas_date"),
)
raw.set_annotations(annot_from_events)
diff --git a/eegdash/hbn/windows.py b/eegdash/hbn/windows.py
index bba77731..8a2ee3d5 100644
--- a/eegdash/hbn/windows.py
+++ b/eegdash/hbn/windows.py
@@ -21,7 +21,25 @@
def build_trial_table(events_df: pd.DataFrame) -> pd.DataFrame:
- """One row per contrast trial with stimulus/response metrics."""
+ """Build a table of contrast trials from an events DataFrame.
+
+ This function processes a DataFrame of events (typically from a BIDS
+ `events.tsv` file) to identify contrast trials and extract relevant
+ metrics like stimulus onset, response onset, and reaction times.
+
+ Parameters
+ ----------
+ events_df : pandas.DataFrame
+ A DataFrame containing event information, with at least "onset" and
+ "value" columns.
+
+ Returns
+ -------
+ pandas.DataFrame
+ A DataFrame where each row represents a single contrast trial, with
+ columns for onsets, reaction times, and response correctness.
+
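+ Examples
+ --------
+ A minimal sketch; the ``events.tsv`` file name below is only
+ illustrative:
+
+ .. code-block:: python
+
+     import pandas as pd
+
+     events_df = pd.read_table("sub-01_task-contrastChangeDetection_events.tsv")
+     trials = build_trial_table(events_df)
+     print(len(trials), "trials found")
+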
+ """
events_df = events_df.copy()
events_df["onset"] = pd.to_numeric(events_df["onset"], errors="raise")
events_df = events_df.sort_values("onset", kind="mergesort").reset_index(drop=True)
@@ -92,12 +110,13 @@ def build_trial_table(events_df: pd.DataFrame) -> pd.DataFrame:
return pd.DataFrame(rows)
-# Aux functions to inject the annot
def _to_float_or_none(x):
+ """Safely convert a value to float or None."""
return None if pd.isna(x) else float(x)
def _to_int_or_none(x):
+ """Safely convert a value to int or None."""
if pd.isna(x):
return None
if isinstance(x, (bool, np.bool_)):
@@ -106,22 +125,55 @@ def _to_int_or_none(x):
return int(x)
try:
return int(x)
- except Exception:
+ except (ValueError, TypeError):
return None
def _to_str_or_none(x):
+ """Safely convert a value to string or None."""
return None if (x is None or (isinstance(x, float) and np.isnan(x))) else str(x)
def annotate_trials_with_target(
- raw,
- target_field="rt_from_stimulus",
- epoch_length=2.0,
- require_stimulus=True,
- require_response=True,
-):
- """Create 'contrast_trial_start' annotations with float target in extras."""
+ raw: mne.io.Raw,
+ target_field: str = "rt_from_stimulus",
+ epoch_length: float = 2.0,
+ require_stimulus: bool = True,
+ require_response: bool = True,
+) -> mne.io.Raw:
+ """Create trial annotations with a specified target value.
+
+ This function reads the BIDS events file associated with the `raw` object,
+ builds a trial table, and creates new MNE annotations for each trial.
+ The annotations are labeled "contrast_trial_start" and their `extras`
+ dictionary is populated with trial metrics, including a "target" key.
+
+ Parameters
+ ----------
+ raw : mne.io.Raw
+ The raw data object. Must have a single associated file name from
+ which the BIDS path can be derived.
+ target_field : str, default "rt_from_stimulus"
+ The column from the trial table to use as the "target" value in the
+ annotation extras.
+ epoch_length : float, default 2.0
+ The duration, in seconds, assigned to each new annotation.
+ require_stimulus : bool, default True
+ If True, only include trials that have a recorded stimulus event.
+ require_response : bool, default True
+ If True, only include trials that have a recorded response event.
+
+ Returns
+ -------
+ mne.io.Raw
+ The `raw` object with the new annotations set.
+
+ Raises
+ ------
+ KeyError
+ If `target_field` is not a valid column in the built trial table.
+
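+ Examples
+ --------
+ A minimal sketch, assuming ``raw`` was loaded from a single BIDS
+ recording of the contrast change detection task:
+
+ .. code-block:: python
+
+     raw = annotate_trials_with_target(
+         raw, target_field="rt_from_stimulus", epoch_length=2.0
+     )
+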
+ """
fnames = raw.filenames
assert len(fnames) == 1, "Expected a single filename"
bids_path = get_bids_path_from_fname(fnames[0])
@@ -152,7 +204,6 @@ def annotate_trials_with_target(
extras = []
for i, v in enumerate(targets):
row = trials.iloc[i]
-
extras.append(
{
"target": _to_float_or_none(v),
@@ -169,14 +220,39 @@ def annotate_trials_with_target(
onset=onsets,
duration=durations,
description=descs,
- orig_time=raw.info["meas_date"],
+ orig_time=raw.info.get("meas_date"),
extras=extras,
)
raw.set_annotations(new_ann, verbose=False)
return raw
-def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor"):
+def add_aux_anchors(
+ raw: mne.io.Raw,
+ stim_desc: str = "stimulus_anchor",
+ resp_desc: str = "response_anchor",
+) -> mne.io.Raw:
+ """Add auxiliary annotations for stimulus and response onsets.
+
+ This function inspects existing "contrast_trial_start" annotations and
+ adds new, zero-duration "anchor" annotations at the precise onsets of
+ stimuli and responses for each trial.
+
+ Parameters
+ ----------
+ raw : mne.io.Raw
+ The raw data object with "contrast_trial_start" annotations.
+ stim_desc : str, default "stimulus_anchor"
+ The description for the new stimulus annotations.
+ resp_desc : str, default "response_anchor"
+ The description for the new response annotations.
+
+ Returns
+ -------
+ mne.io.Raw
+ The `raw` object with the auxiliary annotations added.
+
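+ Examples
+ --------
+ A minimal sketch, assuming ``raw`` already carries the
+ "contrast_trial_start" annotations created by
+ :func:`annotate_trials_with_target`:
+
+ .. code-block:: python
+
+     raw = add_aux_anchors(raw)
+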
+ """
ann = raw.annotations
mask = ann.description == "contrast_trial_start"
if not np.any(mask):
@@ -189,28 +265,24 @@ def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor
ex = ann.extras[idx] if ann.extras is not None else {}
t0 = float(ann.onset[idx])
- stim_t = ex["stimulus_onset"]
- resp_t = ex["response_onset"]
+ stim_t = ex.get("stimulus_onset")
+ resp_t = ex.get("response_onset")
if stim_t is None or (isinstance(stim_t, float) and np.isnan(stim_t)):
- rtt = ex["rt_from_trialstart"]
- rts = ex["rt_from_stimulus"]
+ rtt = ex.get("rt_from_trialstart")
+ rts = ex.get("rt_from_stimulus")
if rtt is not None and rts is not None:
stim_t = t0 + float(rtt) - float(rts)
if resp_t is None or (isinstance(resp_t, float) and np.isnan(resp_t)):
- rtt = ex["rt_from_trialstart"]
+ rtt = ex.get("rt_from_trialstart")
if rtt is not None:
resp_t = t0 + float(rtt)
- if (stim_t is not None) and not (
- isinstance(stim_t, float) and np.isnan(stim_t)
- ):
+ if stim_t is not None and not (isinstance(stim_t, float) and np.isnan(stim_t)):
stim_onsets.append(float(stim_t))
stim_extras.append(dict(ex, anchor="stimulus"))
- if (resp_t is not None) and not (
- isinstance(resp_t, float) and np.isnan(resp_t)
- ):
+ if resp_t is not None and not (isinstance(resp_t, float) and np.isnan(resp_t)):
resp_onsets.append(float(resp_t))
resp_extras.append(dict(ex, anchor="response"))
@@ -220,7 +292,7 @@ def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor
onset=new_onsets,
duration=np.zeros_like(new_onsets, dtype=float),
description=[stim_desc] * len(stim_onsets) + [resp_desc] * len(resp_onsets),
- orig_time=raw.info["meas_date"],
+ orig_time=raw.info.get("meas_date"),
extras=stim_extras + resp_extras,
)
raw.set_annotations(ann + aux, verbose=False)
@@ -228,10 +300,10 @@ def add_aux_anchors(raw, stim_desc="stimulus_anchor", resp_desc="response_anchor
def add_extras_columns(
- windows_concat_ds,
- original_concat_ds,
- desc="contrast_trial_start",
- keys=(
+ windows_concat_ds: BaseConcatDataset,
+ original_concat_ds: BaseConcatDataset,
+ desc: str = "contrast_trial_start",
+ keys: tuple = (
"target",
"rt_from_stimulus",
"rt_from_trialstart",
@@ -240,7 +312,31 @@ def add_extras_columns(
"correct",
"response_type",
),
-):
+) -> BaseConcatDataset:
+ """Add columns from annotation extras to a windowed dataset's metadata.
+
+ This function propagates trial-level information stored in the `extras`
+ of annotations to the `metadata` DataFrame of a `WindowsDataset`.
+
+ Parameters
+ ----------
+ windows_concat_ds : BaseConcatDataset
+ The windowed dataset whose metadata will be updated.
+ original_concat_ds : BaseConcatDataset
+ The original (non-windowed) dataset containing the raw data and
+ annotations with the `extras` to be added.
+ desc : str, default "contrast_trial_start"
+ The description of the annotations to source the extras from.
+ keys : tuple of str, optional
+ The keys to extract from each annotation's `extras` dictionary and
+ add as columns to the metadata. Defaults to the trial metric keys
+ shown in the function signature.
+
+ Returns
+ -------
+ BaseConcatDataset
+ The `windows_concat_ds` with updated metadata.
+
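+ Examples
+ --------
+ A minimal sketch, assuming ``windows_ds`` was created from ``concat_ds``
+ with braindecode's windowing utilities:
+
+ .. code-block:: python
+
+     windows_ds = add_extras_columns(
+         windows_ds, concat_ds, desc="contrast_trial_start"
+     )
+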
+ """
float_cols = {
"target",
"rt_from_stimulus",
@@ -292,7 +388,6 @@ def add_extras_columns(
else: # response_type
ser = pd.Series(vals, index=md.index, dtype="string")
- # Replace the whole column to avoid dtype conflicts
md[k] = ser
win_ds.metadata = md.reset_index(drop=True)
@@ -303,7 +398,25 @@ def add_extras_columns(
return windows_concat_ds
-def keep_only_recordings_with(desc, concat_ds):
+def keep_only_recordings_with(
+ desc: str, concat_ds: BaseConcatDataset
+) -> BaseConcatDataset:
+ """Filter a concatenated dataset to keep only recordings with a specific annotation.
+
+ Parameters
+ ----------
+ desc : str
+ The description of the annotation that must be present in a recording
+ for it to be kept.
+ concat_ds : BaseConcatDataset
+ The concatenated dataset to filter.
+
+ Returns
+ -------
+ BaseConcatDataset
+ A new concatenated dataset containing only the filtered recordings.
+
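+ Examples
+ --------
+ A minimal sketch:
+
+ .. code-block:: python
+
+     concat_ds = keep_only_recordings_with("contrast_trial_start", concat_ds)
+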
+ """
kept = []
for ds in concat_ds.datasets:
if np.any(ds.raw.annotations.description == desc):
diff --git a/eegdash/logging.py b/eegdash/logging.py
index fa62708c..0e92bc42 100644
--- a/eegdash/logging.py
+++ b/eegdash/logging.py
@@ -29,6 +29,25 @@
# Now, get your package-specific logger. It will inherit the
# configuration from the root logger we just set up.
logger = logging.getLogger("eegdash")
+"""The primary logger for the EEGDash package.
+
+This logger is configured to use :class:`rich.logging.RichHandler` for
+formatted, colorful output in the console. It inherits its base configuration
+from the root logger, which is set to the ``INFO`` level.
+
+Examples
+--------
+Usage in other modules:
+
+.. code-block:: python
+
+ from .logging import logger
+
+ logger.info("This is an informational message.")
+ logger.warning("This is a warning.")
+ logger.error("This is an error.")
+"""
+
logger.setLevel(logging.INFO)
diff --git a/eegdash/mongodb.py b/eegdash/mongodb.py
index d734530c..897e5643 100644
--- a/eegdash/mongodb.py
+++ b/eegdash/mongodb.py
@@ -4,50 +4,63 @@
"""MongoDB connection and operations management.
-This module provides thread-safe MongoDB connection management and high-level database
-operations for the EEGDash metadata database. It includes methods for finding, adding,
-and updating EEG data records with proper connection pooling and error handling.
+This module provides a thread-safe singleton manager for MongoDB connections,
+ensuring that connections to the database are handled efficiently and consistently
+across the application.
"""
import threading
from pymongo import MongoClient
-
-# MongoDB Operations
-# These methods provide a high-level interface to interact with the MongoDB
-# collection, allowing users to find, add, and update EEG data records.
-# - find:
-# - exist:
-# - add_request:
-# - add:
-# - update_request:
-# - remove_field:
-# - remove_field_from_db:
-# - close: Close the MongoDB connection.
-# - __del__: Destructor to close the MongoDB connection.
+from pymongo.collection import Collection
+from pymongo.database import Database
class MongoConnectionManager:
- """Singleton class to manage MongoDB client connections."""
+ """A thread-safe singleton to manage MongoDB client connections.
+
+ This class ensures that only one connection instance is created for each
+ unique combination of a connection string and staging flag. It provides
+ class methods to get a client and to close all active connections.
+
+ Attributes
+ ----------
+ _instances : dict
+ A dictionary to store singleton instances, mapping a
+ (connection_string, is_staging) tuple to a (client, db, collection)
+ tuple.
+ _lock : threading.Lock
+ A lock to ensure thread-safe instantiation of clients.
+
+ """
- _instances = {}
+ _instances: dict[tuple[str, bool], tuple[MongoClient, Database, Collection]] = {}
_lock = threading.Lock()
@classmethod
- def get_client(cls, connection_string: str, is_staging: bool = False):
- """Get or create a MongoDB client for the given connection string and staging flag.
+ def get_client(
+ cls, connection_string: str, is_staging: bool = False
+ ) -> tuple[MongoClient, Database, Collection]:
+ """Get or create a MongoDB client for the given connection parameters.
+
+ This method returns a cached client if one already exists for the given
+ connection string and staging flag. Otherwise, it creates a new client,
+ connects to the appropriate database ("eegdash" or "eegdashstaging"),
+ and returns the client, database, and "records" collection.
Parameters
----------
connection_string : str
- The MongoDB connection string
- is_staging : bool
- Whether to use staging database
+ The MongoDB connection string.
+ is_staging : bool, default False
+ If True, connect to the staging database ("eegdashstaging").
+ Otherwise, connect to the production database ("eegdash").
Returns
-------
- tuple
- A tuple of (client, database, collection)
+ tuple[MongoClient, Database, Collection]
+ A tuple containing the connected MongoClient instance, the Database
+ object, and the Collection object for the "records" collection.
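+
+ Examples
+ --------
+ A minimal sketch; the connection string below is a placeholder, not a
+ real cluster:
+
+ .. code-block:: python
+
+     client, db, collection = MongoConnectionManager.get_client(
+         "mongodb+srv://user:password@cluster.example.net/", is_staging=False
+     )
+     n_records = collection.count_documents({})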
"""
# Create a unique key based on connection string and staging flag
@@ -66,8 +79,12 @@ def get_client(cls, connection_string: str, is_staging: bool = False):
return cls._instances[key]
@classmethod
- def close_all(cls):
- """Close all MongoDB client connections."""
+ def close_all(cls) -> None:
+ """Close all managed MongoDB client connections.
+
+ This method iterates through all cached client instances and closes
+ their connections. It also clears the instance cache.
+ """
with cls._lock:
for client, _, _ in cls._instances.values():
try:
diff --git a/eegdash/paths.py b/eegdash/paths.py
index 495169a8..25655827 100644
--- a/eegdash/paths.py
+++ b/eegdash/paths.py
@@ -18,12 +18,21 @@
def get_default_cache_dir() -> Path:
- """Resolve a consistent default cache directory for EEGDash.
+ """Resolve the default cache directory for EEGDash data.
+
+ The function determines the cache directory based on the following
+ priority order:
+
+ 1. The path specified by the ``EEGDASH_CACHE_DIR`` environment variable.
+ 2. The path specified by the ``MNE_DATA`` configuration in the MNE-Python
+ config file.
+ 3. A hidden directory named ``.eegdash_cache`` in the current working
+ directory.
+
+ Returns
+ -------
+ pathlib.Path
+ The resolved, absolute path to the default cache directory.
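+
+ Examples
+ --------
+ A minimal sketch; the override path is only illustrative:
+
+ .. code-block:: python
+
+     import os
+
+     os.environ["EEGDASH_CACHE_DIR"] = "/tmp/eegdash_cache"
+     cache_dir = get_default_cache_dir()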
- Priority order:
- 1) Environment variable ``EEGDASH_CACHE_DIR`` if set.
- 2) MNE config ``MNE_DATA`` if set (aligns with tests and ecosystem caches).
- 3) ``.eegdash_cache`` under the current working directory.
"""
# 1) Explicit env var wins
env_dir = os.environ.get("EEGDASH_CACHE_DIR")
diff --git a/eegdash/utils.py b/eegdash/utils.py
index 60f3466a..0a9c520e 100644
--- a/eegdash/utils.py
+++ b/eegdash/utils.py
@@ -11,7 +11,22 @@
from mne.utils import get_config, set_config, use_log_level
-def _init_mongo_client():
+def _init_mongo_client() -> None:
+ """Initialize the default MongoDB connection URI in the MNE config.
+
+ This function checks if the ``EEGDASH_DB_URI`` is already set in the
+ MNE-Python configuration. If it is not set, this function sets it to the
+ default public EEGDash MongoDB Atlas cluster URI.
+
+ The operation is performed with MNE's logging level temporarily set to
+ "ERROR" to suppress verbose output.
+
+ Notes
+ -----
+ This is an internal helper function and is not intended for direct use
+ by end-users.
+
+ """
with use_log_level("ERROR"):
if get_config("EEGDASH_DB_URI") is None:
set_config(
diff --git a/examples/eeg2025/tutorial_challenge_1.py b/examples/eeg2025/tutorial_challenge_1.py
index 669f9529..e74094ae 100644
--- a/examples/eeg2025/tutorial_challenge_1.py
+++ b/examples/eeg2025/tutorial_challenge_1.py
@@ -67,8 +67,8 @@
# Note: For simplicity purposes, we will only show how to do the decoding
# directly in our target task, and it is up to the teams to think about
# how to use the passive task to perform the pre-training.
-#
-######################################################################
+
+#######################################################################
# Install dependencies
# --------------------
# For the challenge, we will need two significant dependencies:
@@ -132,7 +132,7 @@
#
######################################################################
# The brain decodes the problem
-# =============================
+# -----------------------------
#
# Broadly speaking, here *brain decoding* is the following problem:
# given brain time-series signals :math:`X \in \mathbb{R}^{C \times T}` with
@@ -155,7 +155,6 @@
# is the temporal window length/epoch size over the interval of interest.
# Here, :math:`\theta` denotes the parameters learned by the neural network.
#
-# ------------------------------------------------------------------------------
# Input/Output definition
# ---------------------------
# For the competition, the HBN-EEG (Healthy Brain Network EEG Datasets)
@@ -194,8 +193,10 @@
# * The **ramp onset**, the **button press**, and the **feedback** are **time-locked events** that yield ERP-like components.
#
# Your task (**label**) is to predict the response time for the subject during this windows.
-######################################################################
+#
+#######################################################################
# In the figure below, we have the timeline representation of the cognitive task:
+#
# .. image:: https://eeg2025.github.io/assets/img/image-2.jpg
######################################################################
diff --git a/examples/eeg2025/tutorial_challenge_2.py b/examples/eeg2025/tutorial_challenge_2.py
index c7b7ee03..f154f585 100644
--- a/examples/eeg2025/tutorial_challenge_2.py
+++ b/examples/eeg2025/tutorial_challenge_2.py
@@ -319,9 +319,11 @@ def __getitem__(self, index):
# All the braindecode models expect the input to be of shape (batch_size, n_channels, n_times)
# and have a test coverage about the behavior of the model.
# However, you can use any pytorch model you want.
-########################################################################
+#
+######################################################################
# Initialize model
-# ----------------
+# -----------------
+
model = EEGNeX(n_chans=129, n_outputs=1, n_times=2 * SFREQ).to(device)
# Specify optimizer