From e52475f252469cc5f240bfec933ae7b9c54e59b0 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Sun, 12 Oct 2025 23:10:18 +0200 Subject: [PATCH 01/12] including tree plot --- docs/plot_dataset/__init__.py | 1 + docs/plot_dataset/treemap.py | 351 ++++++++++++++++++++++++ docs/prepare_summary_tables.py | 8 + docs/source/dataset_summary.rst | 4 + docs/source/dataset_summary/treemap.rst | 20 ++ treemap_plot.py | 51 ++++ 6 files changed, 435 insertions(+) create mode 100644 docs/plot_dataset/treemap.py create mode 100644 docs/source/dataset_summary/treemap.rst create mode 100644 treemap_plot.py diff --git a/docs/plot_dataset/__init__.py b/docs/plot_dataset/__init__.py index 85942823..93700636 100644 --- a/docs/plot_dataset/__init__.py +++ b/docs/plot_dataset/__init__.py @@ -11,3 +11,4 @@ ) from .plot_sankey import generate_dataset_sankey # noqa: F401 from .ridgeline import generate_modality_ridgeline # noqa: F401 +from .treemap import generate_dataset_treemap # noqa: F401 diff --git a/docs/plot_dataset/treemap.py b/docs/plot_dataset/treemap.py new file mode 100644 index 00000000..814089f8 --- /dev/null +++ b/docs/plot_dataset/treemap.py @@ -0,0 +1,351 @@ +from __future__ import annotations + +"""Utilities to generate the EEG Dash dataset treemap.""" + +from pathlib import Path +from typing import Iterable + +import math +import pandas as pd +import plotly.graph_objects as go + +try: # Allow import both as a package and as a script + from .colours import CANONICAL_MAP, MODALITY_COLOR_MAP, PATHOLOGY_COLOR_MAP +except ImportError: # pragma: no cover - fallback for direct script execution + from colours import CANONICAL_MAP, MODALITY_COLOR_MAP, PATHOLOGY_COLOR_MAP # type: ignore + +__all__ = ["generate_dataset_treemap"] + +_CATEGORY_COLUMNS = ( + ("Type Subject", "population_type"), + ("modality of exp", "experimental_modality"), +) + +_DATASET_COLUMN = "dataset" +_DATASET_ALIAS = "dataset_name" +_SEPARATORS = ("/", "|", ";", ",") +_DEFAULT_COLOR = "#94a3b8" + + +def _tokenise_cell(value: object, column_key: str) -> list[str]: + """Split multi-valued cells, normalise, and keep Unknown buckets.""" + if value is None or (isinstance(value, float) and math.isnan(value)): + tokens = [] + else: + text = str(value).strip() + if not text or text.lower() in {"nan", "none"}: + tokens = [] + else: + normalised = text + for sep in _SEPARATORS: + normalised = normalised.replace(sep, ",") + tokens = [tok.strip() for tok in normalised.split(",") if tok.strip()] + + if not tokens: + return ["Unknown"] + + canonical = CANONICAL_MAP.get(column_key, {}) + resolved: list[str] = [] + for token in tokens: + lowered = token.lower() + resolved.append(canonical.get(lowered, token)) + final = [tok if tok else "Unknown" for tok in resolved] + return final or ["Unknown"] + + +def _preprocess_dataframe(df: pd.DataFrame) -> pd.DataFrame: + required_columns = {orig for orig, _ in _CATEGORY_COLUMNS} | { + _DATASET_COLUMN, + "n_records", + "n_subjects", + "duration_hours_total", + } + missing = sorted(required_columns - set(df.columns)) + if missing: + raise KeyError(f"Missing required columns: {missing}") + + rename_map = { + "n_records": "records", + "n_subjects": "subjects", + "duration_hours_total": "duration_hours", + _DATASET_COLUMN: _DATASET_ALIAS, + } + for original, alias in _CATEGORY_COLUMNS: + rename_map[original] = alias + + renamed = df.rename(columns=rename_map) + columns_to_keep = [_DATASET_ALIAS] + [alias for _, alias in _CATEGORY_COLUMNS] + cleaned = renamed.loc[:, columns_to_keep].copy() + numeric = 
renamed[["records", "subjects", "duration_hours"]] + + cleaned[_DATASET_ALIAS] = ( + cleaned[_DATASET_ALIAS] + .astype(str) + .replace({"nan": "Unknown", "None": "Unknown", "": "Unknown"}) + .fillna("Unknown") + ) + + cleaned = cleaned.join(numeric) + + for original, alias in _CATEGORY_COLUMNS: + cleaned[alias] = cleaned[alias].map(lambda v: _tokenise_cell(v, original)) + cleaned = cleaned.explode(alias).reset_index(drop=True) + cleaned[alias] = cleaned[alias].fillna("Unknown") + + cleaned["records"] = pd.to_numeric(cleaned["records"], errors="coerce").fillna(0) + cleaned["subjects"] = pd.to_numeric(cleaned["subjects"], errors="coerce").fillna(0) + cleaned["duration_hours"] = pd.to_numeric( + cleaned["duration_hours"], errors="coerce" + ) + cleaned.loc[cleaned["duration_hours"] < 0, "duration_hours"] = pd.NA + + hours = cleaned["duration_hours"] + fallback_mask = hours.isna() | (hours <= 0) + cleaned["hours_from_records"] = 0.0 + cleaned.loc[fallback_mask, "hours_from_records"] = cleaned.loc[ + fallback_mask, "records" + ] + + cleaned["hours"] = hours.fillna(0) + cleaned.loc[fallback_mask, "hours"] = cleaned.loc[fallback_mask, "records"] + cleaned["hours"] = cleaned["hours"].fillna(0).clip(lower=0) + + cleaned["records"] = cleaned["records"].clip(lower=0) + cleaned["subjects"] = cleaned["subjects"].clip(lower=0) + cleaned["hours_from_records"] = cleaned["hours_from_records"].clip(lower=0) + + return cleaned[ + [ + "population_type", + "experimental_modality", + "dataset_name", + "hours", + "records", + "subjects", + "hours_from_records", + ] + ] + + +def _abbreviate(value: float | int) -> str: + try: + num = float(value) + except (TypeError, ValueError): + return "0" + + if not math.isfinite(num): + return "0" + if num == 0: + return "0" + + thresholds = [ + (1_000_000_000, "B"), + (1_000_000, "M"), + (1_000, "k"), + ] + for divisor, suffix in thresholds: + if abs(num) >= divisor: + scaled = num / divisor + text = f"{scaled:.1f}".rstrip("0").rstrip(".") + return f"{text}{suffix}" + return f"{num:.0f}" + + +def _filter_zero_nodes(df: pd.DataFrame, column: str) -> pd.DataFrame: + mask = (df["hours"] > 0) | (df[column] == "Unknown") + return df.loc[mask].copy() + + +def _format_label( + name: str, + hours: float | int, + records: float | int, + hours_from_records: float | int, +) -> str: + area_value = float(hours) if pd.notna(hours) else 0.0 + records_value = float(records) if pd.notna(records) else 0.0 + fallback_value = float(hours_from_records) if pd.notna(hours_from_records) else 0.0 + + unit = " record" if math.isclose(area_value, fallback_value, rel_tol=1e-6) else "" + area_text = f"{area_value:.0f}" + records_text = _abbreviate(records_value) + return ( + f"{name}
{area_text}{unit}" + f" | {records_text} rec" + ) + + +def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: + dataset_level = dataset_level.sort_values( + ["population_type", "experimental_modality", "dataset_name"] + ).reset_index(drop=True) + + level2 = dataset_level.groupby( + ["population_type", "experimental_modality"], dropna=False, as_index=False + ).agg( + hours=("hours", "sum"), + records=("records", "sum"), + subjects=("subjects", "sum"), + hours_from_records=("hours_from_records", "sum"), + ) + level2 = _filter_zero_nodes(level2, "experimental_modality") + + level1 = level2.groupby(["population_type"], dropna=False, as_index=False).agg( + hours=("hours", "sum"), + records=("records", "sum"), + subjects=("subjects", "sum"), + hours_from_records=("hours_from_records", "sum"), + ) + level1 = _filter_zero_nodes(level1, "population_type") + + nodes: list[dict[str, object]] = [] + + total_hours = level1["hours"].sum() + total_records = level1["records"].sum() + total_from_records = level1["hours_from_records"].sum() + + root_label = _format_label( + "EEG Dash Datasets", + total_hours, + total_records, + total_from_records, + ) + nodes.append( + { + "id": "EEG Dash datasets", + "parent": "", + "name": "EEG Dash datasets", + "text": root_label, + "value": float(total_hours), + "color": "white", + "hover": root_label, + } + ) + + for _, row in level1.iterrows(): + name = row["population_type"] or "Unknown" + node_id = name + label = _format_label( + name, + row["hours"], + row["records"], + row["hours_from_records"], + ) + color = PATHOLOGY_COLOR_MAP.get(name, _DEFAULT_COLOR) + nodes.append( + { + "id": node_id, + "parent": "EEG Dash datasets", + "name": name, + "text": label, + "value": float(row["hours"]), + "color": color, + "hover": label, + } + ) + + for _, row in level2.iterrows(): + modality = row["experimental_modality"] or "Unknown" + parent = row["population_type"] or "Unknown" + node_id = f"{parent} / {modality}" + label = _format_label( + modality, + row["hours"], + row["records"], + row["hours_from_records"], + ) + color = MODALITY_COLOR_MAP.get(modality, _DEFAULT_COLOR) + nodes.append( + { + "id": node_id, + "parent": parent, + "name": modality, + "text": label, + "value": float(row["hours"]), + "color": color, + "hover": label, + } + ) + + dataset_level = _filter_zero_nodes(dataset_level, "dataset_name") + for _, row in dataset_level.iterrows(): + dataset_name = row["dataset_name"] or "Unknown" + modality = row["experimental_modality"] or "Unknown" + parent = f"{row['population_type']} / {modality}" + node_id = f"{parent} / {dataset_name}" + label = _format_label( + dataset_name, + row["hours"], + row["records"], + row["hours_from_records"], + ) + color = MODALITY_COLOR_MAP.get(modality, _DEFAULT_COLOR) + nodes.append( + { + "id": node_id, + "parent": parent, + "name": dataset_name, + "text": label, + "value": float(row["hours"]), + "color": color, + "hover": label, + } + ) + + return nodes + + +def _build_figure(nodes: Iterable[dict[str, object]]) -> go.Figure: + node_list = list(nodes) + if not node_list: + raise ValueError("No data available to render the treemap.") + + return go.Figure( + go.Treemap( + ids=[node["id"] for node in node_list], + labels=[node["name"] for node in node_list], + parents=[node["parent"] for node in node_list], + values=[node["value"] for node in node_list], + text=[node["text"] for node in node_list], + customdata=[[node["hover"]] for node in node_list], + branchvalues="total", + marker=dict( + colors=[node["color"] for 
node in node_list], + line=dict(color="white", width=2), + ), + textinfo="text", + hovertemplate="%{customdata[0]}", + pathbar=dict(visible=True, edgeshape="/"), + ) + ) + + +def generate_dataset_treemap( + df: pd.DataFrame, + out_html: str | Path, +) -> Path: + """Generate the dataset treemap and return the output path.""" + cleaned = _preprocess_dataframe(df) + aggregated = cleaned.groupby( + ["population_type", "experimental_modality", "dataset_name"], + dropna=False, + as_index=False, + ).agg( + hours=("hours", "sum"), + records=("records", "sum"), + subjects=("subjects", "sum"), + hours_from_records=("hours_from_records", "sum"), + ) + + aggregated = _filter_zero_nodes(aggregated, "dataset_name") + nodes = _build_nodes(aggregated) + fig = _build_figure(nodes) + fig.update_layout( + uniformtext=dict(minsize=10, mode="hide"), + margin=dict(t=20, l=10, r=10, b=10), + ) + + out_path = Path(out_html) + out_path.parent.mkdir(parents=True, exist_ok=True) + fig.write_html(out_path, include_plotlyjs="cdn", full_html=True) + return out_path diff --git a/docs/prepare_summary_tables.py b/docs/prepare_summary_tables.py index a3ebab06..01188278 100644 --- a/docs/prepare_summary_tables.py +++ b/docs/prepare_summary_tables.py @@ -10,6 +10,7 @@ from plot_dataset import ( generate_dataset_bubble, generate_dataset_sankey, + generate_dataset_treemap, generate_modality_ridgeline, ) from plot_dataset.utils import get_dataset_url, human_readable_size @@ -349,6 +350,13 @@ def main(source_dir: str, target_dir: str): except Exception as exc: print(f"[dataset Sankey] Skipped due to error: {exc}") + try: + treemap_path = target_dir / "dataset_treemap.html" + treemap_output = generate_dataset_treemap(df_raw, treemap_path) + copyfile(treemap_output, STATIC_DATASET_DIR / treemap_output.name) + except Exception as exc: + print(f"[dataset Treemap] Skipped due to error: {exc}") + df = prepare_table(df_raw) # preserve int values df["n_subjects"] = df["n_subjects"].astype(int) diff --git a/docs/source/dataset_summary.rst b/docs/source/dataset_summary.rst index c3ef7a33..17ed1465 100644 --- a/docs/source/dataset_summary.rst +++ b/docs/source/dataset_summary.rst @@ -33,6 +33,10 @@ To leverage recent and ongoing advancements in large-scale computational methods .. include:: dataset_summary/sankey.rst + .. tab-item:: Dataset Treemap + + .. include:: dataset_summary/treemap.rst + .. tab-item:: Scatter of Sample Size vs. Recording Duration .. include:: dataset_summary/bubble.rst diff --git a/docs/source/dataset_summary/treemap.rst b/docs/source/dataset_summary/treemap.rst new file mode 100644 index 00000000..9790ecd5 --- /dev/null +++ b/docs/source/dataset_summary/treemap.rst @@ -0,0 +1,20 @@ +.. title:: Dataset treemap + +.. rubric:: Dataset treemap + +.. raw:: html + +
+   <figure class="eegdash-figure">
+
+.. raw:: html
+   :file: ../_static/dataset_generated/dataset_treemap.html
+
+.. raw:: html
+
+   <figcaption class="eegdash-caption">
+   Figure: Treemap of EEG Dash datasets. The top level groups datasets by population type,
+   the second level splits them by experimental modality, and the leaves list individual datasets.
+   Hover to view aggregated hours (or records when unavailable) and record counts.
+   </figcaption>
+   </figure>
+ diff --git a/treemap_plot.py b/treemap_plot.py new file mode 100644 index 00000000..39a9d909 --- /dev/null +++ b/treemap_plot.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 + +"""Build a Plotly treemap for the EEG/MEG dataset summary CSV.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import pandas as pd + + +def _find_dataset_csv(base_path: Path) -> Path: + """Return the dataset summary path, preferring a CSV next to the script.""" + local_csv = base_path / "dataset_summary.csv" + if local_csv.exists(): + return local_csv + + bundled_csv = base_path / "eegdash" / "dataset" / "dataset_summary.csv" + if bundled_csv.exists(): + return bundled_csv + + msg = "dataset_summary.csv not found next to the script or under eegdash/dataset." + raise FileNotFoundError(msg) + + +def main() -> None: + base_path = Path(__file__).resolve().parent + docs_dir = base_path / "docs" + if docs_dir.exists(): + sys.path.insert(0, str(docs_dir)) + + try: + from plot_dataset.treemap import generate_dataset_treemap + except ImportError as exc: # pragma: no cover - guard for CLI usage + raise SystemExit(f"Unable to import treemap generator: {exc}") from exc + + dataset_csv = _find_dataset_csv(base_path) + df = pd.read_csv(dataset_csv) + + output = base_path / "treemap.html" + generate_dataset_treemap(df, output) + print(f"Treemap saved to {output}") + + +if __name__ == "__main__": + try: + main() + except Exception as exc: # pragma: no cover - guard for CLI usage + print(f"Error: {exc}", file=sys.stderr) + sys.exit(1) From ef86eac04591da2ee00debb4389e91e28c11b276 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Sun, 12 Oct 2025 23:21:16 +0200 Subject: [PATCH 02/12] updating the treemap --- docs/source/dataset_summary/treemap.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/dataset_summary/treemap.rst b/docs/source/dataset_summary/treemap.rst index 9790ecd5..781895fa 100644 --- a/docs/source/dataset_summary/treemap.rst +++ b/docs/source/dataset_summary/treemap.rst @@ -14,7 +14,6 @@
    <figcaption class="eegdash-caption">
    Figure: Treemap of EEG Dash datasets. The top level groups datasets by population type,
    the second level splits them by experimental modality, and the leaves list individual datasets.
-   Hover to view aggregated hours (or records when unavailable) and record counts.
+   Tile area encodes the total number of subjects; hover to view aggregated hours (or records when unavailable).
    </figcaption>
    </figure>
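+
+To regenerate this figure locally, a minimal sketch (assuming the repository
+root as the working directory, with ``docs/`` on ``sys.path`` as done in
+``treemap_plot.py``):
+
+.. code-block:: python
+
+   import pandas as pd
+
+   from plot_dataset.treemap import generate_dataset_treemap
+
+   df = pd.read_csv("eegdash/dataset/dataset_summary.csv")
+   generate_dataset_treemap(df, "docs/source/_static/dataset_generated/dataset_treemap.html")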
- From 01f19c56c10ff74d021839dfe2795dabfa0f80b4 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 13 Oct 2025 00:55:56 +0200 Subject: [PATCH 03/12] first iteration of the treemap and page stats on the top --- docs/plot_dataset/treemap.py | 140 ++++++++++++++++++++++------ docs/source/_static/css/treemap.css | 3 + docs/source/_static/custom.css | 116 +++++++++++++++++++++++ docs/source/conf.py | 104 +++++++++++++++++++++ docs/source/dataset_summary.rst | 55 +++++++++++ 5 files changed, 392 insertions(+), 26 deletions(-) create mode 100644 docs/source/_static/css/treemap.css diff --git a/docs/plot_dataset/treemap.py b/docs/plot_dataset/treemap.py index 814089f8..8f5875c3 100644 --- a/docs/plot_dataset/treemap.py +++ b/docs/plot_dataset/treemap.py @@ -10,9 +10,19 @@ import plotly.graph_objects as go try: # Allow import both as a package and as a script - from .colours import CANONICAL_MAP, MODALITY_COLOR_MAP, PATHOLOGY_COLOR_MAP + from .colours import ( + CANONICAL_MAP, + MODALITY_COLOR_MAP, + PATHOLOGY_COLOR_MAP, + hex_to_rgba, + ) except ImportError: # pragma: no cover - fallback for direct script execution - from colours import CANONICAL_MAP, MODALITY_COLOR_MAP, PATHOLOGY_COLOR_MAP # type: ignore + from colours import ( # type: ignore + CANONICAL_MAP, + MODALITY_COLOR_MAP, + PATHOLOGY_COLOR_MAP, + hex_to_rgba, + ) __all__ = ["generate_dataset_treemap"] @@ -26,6 +36,18 @@ _SEPARATORS = ("/", "|", ";", ",") _DEFAULT_COLOR = "#94a3b8" +MODALITY_EMOJI = { + "Visual": "👁️", + "Auditory": "👂", + "Sleep": "🌙", + "Multisensory": "🧩", + "Tactile": "✋", + "Motor": "🏃", + "Resting State": "🧘", + "Rest": "🧘", + "Other": "🧭", +} + def _tokenise_cell(value: object, column_key: str) -> list[str]: """Split multi-valued cells, normalise, and keep Unknown buckets.""" @@ -152,26 +174,34 @@ def _abbreviate(value: float | int) -> str: def _filter_zero_nodes(df: pd.DataFrame, column: str) -> pd.DataFrame: - mask = (df["hours"] > 0) | (df[column] == "Unknown") + mask = (df["subjects"] > 0) | (df[column] == "Unknown") return df.loc[mask].copy() def _format_label( name: str, + subjects: float | int, hours: float | int, records: float | int, hours_from_records: float | int, + *, + font_px: int = 13, ) -> str: - area_value = float(hours) if pd.notna(hours) else 0.0 + subjects_value = float(subjects) if pd.notna(subjects) else 0.0 + hours_value = float(hours) if pd.notna(hours) else 0.0 records_value = float(records) if pd.notna(records) else 0.0 fallback_value = float(hours_from_records) if pd.notna(hours_from_records) else 0.0 - unit = " record" if math.isclose(area_value, fallback_value, rel_tol=1e-6) else "" - area_text = f"{area_value:.0f}" - records_text = _abbreviate(records_value) + subjects_text = _abbreviate(subjects_value) + if hours_value > 0: + secondary_text = f"{hours_value:.0f} h" + elif fallback_value > 0: + secondary_text = f"{_abbreviate(records_value)} rec" + else: + secondary_text = "0 h" return ( - f"{name}
{area_text}{unit}" - f" | {records_text} rec" + f"{name}
{subjects_text} subj" + f" | {secondary_text}" ) @@ -199,16 +229,20 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: level1 = _filter_zero_nodes(level1, "population_type") nodes: list[dict[str, object]] = [] + level1_meta: list[dict[str, str]] = [] + total_subjects = level1["subjects"].sum() total_hours = level1["hours"].sum() total_records = level1["records"].sum() total_from_records = level1["hours_from_records"].sum() root_label = _format_label( "EEG Dash Datasets", + total_subjects, total_hours, total_records, total_from_records, + font_px=18, ) nodes.append( { @@ -216,7 +250,7 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: "parent": "", "name": "EEG Dash datasets", "text": root_label, - "value": float(total_hours), + "value": float(total_subjects), "color": "white", "hover": root_label, } @@ -227,18 +261,24 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: node_id = name label = _format_label( name, + row["subjects"], row["hours"], row["records"], row["hours_from_records"], + font_px=16, ) - color = PATHOLOGY_COLOR_MAP.get(name, _DEFAULT_COLOR) + base_color = PATHOLOGY_COLOR_MAP.get(name) + if not base_color: + base_color = PATHOLOGY_COLOR_MAP.get("Clinical", _DEFAULT_COLOR) + color = hex_to_rgba(base_color, alpha=0.75) + level1_meta.append({"name": name, "color": base_color}) nodes.append( { "id": node_id, "parent": "EEG Dash datasets", "name": name, "text": label, - "value": float(row["hours"]), + "value": float(row["subjects"]), "color": color, "hover": label, } @@ -248,20 +288,26 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: modality = row["experimental_modality"] or "Unknown" parent = row["population_type"] or "Unknown" node_id = f"{parent} / {modality}" + modality_label = modality + emoji = MODALITY_EMOJI.get(modality) + if emoji: + modality_label = f"{emoji} {modality}" label = _format_label( - modality, + modality_label, + row["subjects"], row["hours"], row["records"], row["hours_from_records"], + font_px=16, ) color = MODALITY_COLOR_MAP.get(modality, _DEFAULT_COLOR) nodes.append( { "id": node_id, "parent": parent, - "name": modality, + "name": modality_label, "text": label, - "value": float(row["hours"]), + "value": float(row["subjects"]), "color": color, "hover": label, } @@ -275,32 +321,41 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: node_id = f"{parent} / {dataset_name}" label = _format_label( dataset_name, + row["subjects"], row["hours"], row["records"], row["hours_from_records"], + font_px=16, ) - color = MODALITY_COLOR_MAP.get(modality, _DEFAULT_COLOR) + _ = row["population_type"] or "Unknown" + if dataset_name == "Unknown": + color = _DEFAULT_COLOR + else: + color = MODALITY_COLOR_MAP.get(modality, _DEFAULT_COLOR) nodes.append( { "id": node_id, "parent": parent, "name": dataset_name, "text": label, - "value": float(row["hours"]), + "value": float(row["subjects"]), "color": color, "hover": label, } ) - return nodes + return nodes, level1_meta -def _build_figure(nodes: Iterable[dict[str, object]]) -> go.Figure: +def _build_figure( + nodes: Iterable[dict[str, object]], + legend_entries: Iterable[dict[str, str]], +) -> go.Figure: node_list = list(nodes) if not node_list: raise ValueError("No data available to render the treemap.") - return go.Figure( + fig = go.Figure( go.Treemap( ids=[node["id"] for node in node_list], labels=[node["name"] for node in node_list], @@ -311,14 +366,46 @@ def _build_figure(nodes: 
Iterable[dict[str, object]]) -> go.Figure: branchvalues="total", marker=dict( colors=[node["color"] for node in node_list], - line=dict(color="white", width=2), + line=dict(color="white", width=1), + pad=dict(t=6, r=6, b=6, l=6), ), textinfo="text", hovertemplate="%{customdata[0]}", - pathbar=dict(visible=True, edgeshape="/"), + pathbar=dict(visible=True, edgeshape="/", thickness=34), + textfont=dict(size=24), + insidetextfont=dict(size=24), + tiling=dict(pad=6, packing="squarify"), + root=dict(color="rgba(255,255,255,0.95)"), ) ) + for entry in legend_entries: + fig.add_trace( + go.Scatter( + x=[None], + y=[None], + mode="markers", + marker=dict(size=14, symbol="square", color=entry["color"]), + name=entry["name"], + showlegend=True, + hoverinfo="skip", + ) + ) + + fig.update_layout( + legend=dict( + orientation="h", + yanchor="bottom", + y=1.08, + xanchor="left", + x=0.0, + font=dict(size=14), + itemwidth=80, + ) + ) + + return fig + def generate_dataset_treemap( df: pd.DataFrame, @@ -338,11 +425,12 @@ def generate_dataset_treemap( ) aggregated = _filter_zero_nodes(aggregated, "dataset_name") - nodes = _build_nodes(aggregated) - fig = _build_figure(nodes) + nodes, legend_entries = _build_nodes(aggregated) + fig = _build_figure(nodes, legend_entries) fig.update_layout( - uniformtext=dict(minsize=10, mode="hide"), - margin=dict(t=20, l=10, r=10, b=10), + uniformtext=dict(minsize=18, mode="hide"), + margin=dict(t=140, l=24, r=24, b=16), + hoverlabel=dict(font_size=16), ) out_path = Path(out_html) diff --git a/docs/source/_static/css/treemap.css b/docs/source/_static/css/treemap.css new file mode 100644 index 00000000..828c26a0 --- /dev/null +++ b/docs/source/_static/css/treemap.css @@ -0,0 +1,3 @@ +.dataset-summary-article .eegdash-figure iframe { + min-height: 680px; +} diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css index 65d4d5dc..4b26cd5f 100644 --- a/docs/source/_static/custom.css +++ b/docs/source/_static/custom.css @@ -140,6 +140,122 @@ figure.eegdash-figure figcaption.eegdash-caption { text-align: center; } +/* Dataset counters on treemap page */ +.dataset-counter-grid { + display: grid; + gap: 1.25rem; + margin: 1.5rem 0 2.5rem; + grid-template-columns: repeat(auto-fit, minmax(210px, 1fr)); +} + +.dataset-counter-card { + display: flex; + align-items: center; + gap: 1rem; + padding: 1.25rem 1.5rem; + border-radius: 1rem; + background-color: var(--sd-color-surface, var(--pst-color-background, #ffffff)); + border: 1px solid rgba(15, 23, 42, 0.08); + box-shadow: 0 14px 32px rgba(15, 23, 42, 0.08); +} + +html[data-theme="dark"] .dataset-counter-card { + background-color: rgba(15, 23, 42, 0.65); + border-color: rgba(148, 163, 184, 0.28); + box-shadow: 0 16px 36px rgba(8, 8, 8, 0.55); +} + +.dataset-counter-icon { + width: 3.25rem; + height: 3.25rem; + border-radius: 50%; + display: grid; + place-items: center; +} + +.dataset-counter-svg { + width: 2rem; + height: 2rem; +} + +.dataset-counter-body { + display: flex; + flex-direction: column; + gap: 0.3rem; +} + +.dataset-counter-label { + font-size: 0.9rem; + letter-spacing: 0.08em; + text-transform: uppercase; + font-weight: 600; + color: var(--pst-color-text-muted, #6b7280); +} + +.dataset-counter-value { + font-family: "Sora", var(--pst-font-family-heading, system-ui), sans-serif; + font-weight: 700; + font-size: clamp(1.7rem, 2.4vw, 2.25rem); + color: var(--pst-color-text, #0f172a); +} + +html[data-theme="dark"] .dataset-counter-value { + color: #f9fafb; +} + +.dataset-counter-card:nth-child(1) 
.dataset-counter-icon { + background: rgba(10, 111, 182, 0.12); + color: #0a6fb6; +} + +.dataset-counter-card:nth-child(2) .dataset-counter-icon { + background: rgba(124, 58, 237, 0.14); + color: #7c3aed; +} + +.dataset-counter-card:nth-child(3) .dataset-counter-icon { + background: rgba(5, 150, 105, 0.14); + color: #059669; +} + +.dataset-counter-card:nth-child(4) .dataset-counter-icon { + background: rgba(245, 158, 11, 0.16); + color: #d97706; +} + +html[data-theme="dark"] .dataset-counter-card:nth-child(1) .dataset-counter-icon { + background: rgba(10, 111, 182, 0.28); +} + +html[data-theme="dark"] .dataset-counter-card:nth-child(2) .dataset-counter-icon { + background: rgba(124, 58, 237, 0.28); +} + +html[data-theme="dark"] .dataset-counter-card:nth-child(3) .dataset-counter-icon { + background: rgba(5, 150, 105, 0.28); +} + +html[data-theme="dark"] .dataset-counter-card:nth-child(4) .dataset-counter-icon { + background: rgba(245, 158, 11, 0.32); +} + +@media (max-width: 640px) { + .dataset-counter-card { + padding: 1.1rem 1.25rem; + gap: 0.85rem; + } + + .dataset-counter-icon { + width: 3rem; + height: 3rem; + } + + .dataset-counter-svg { + width: 1.75rem; + height: 1.75rem; + } +} + /* Make the DataTables filter input and buttons match size */ .dataTables_wrapper .dataTables_filter input { font-size: var(--pst-font-size-base, 1rem); diff --git a/docs/source/conf.py b/docs/source/conf.py index 0a357598..799e8ebc 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -3,6 +3,7 @@ import inspect import os import sys +import shutil from collections import Counter from datetime import datetime, timezone from pathlib import Path @@ -78,6 +79,7 @@ "https://cdn.datatables.net/select/1.7.0/css/select.dataTables.min.css", "https://cdn.datatables.net/searchpanes/2.3.1/css/searchPanes.dataTables.min.css", "custom.css", + "css/treemap.css", ] html_js_files = [ "https://code.jquery.com/jquery-3.7.1.min.js", @@ -602,6 +604,106 @@ def _generate_dataset_docs(app) -> None: continue +def _split_tokens(value: str | None) -> set[str]: + tokens: set[str] = set() + if not value: + return tokens + for part in value.split(","): + cleaned = part.strip() + if cleaned: + tokens.add(cleaned) + return tokens + + +def _compute_dataset_counter_defaults() -> dict[str, int]: + csv_path = Path(importlib.import_module("eegdash.dataset").__file__).with_name( + "dataset_summary.csv" + ) + if not csv_path.exists(): + return {} + + dataset_ids: set[str] = set() + modalities: set[str] = set() + cognitive: set[str] = set() + subject_total = 0 + + with csv_path.open(encoding="utf-8") as handle: + filtered = ( + line + for line in handle + if line.strip() and not line.lstrip().startswith("#") + ) + reader = csv.DictReader(filtered) + for row in reader: + dataset = (row.get("dataset") or row.get("Dataset") or "").strip() + if dataset: + dataset_ids.add(dataset) + + try: + subject_total += int(float(row.get("n_subjects", "0") or 0)) + except (TypeError, ValueError): + pass + + modalities.update(_split_tokens(row.get("record_modality"))) + cognitive.update(_split_tokens(row.get("type of exp"))) + + return { + "datasets": len(dataset_ids), + "subjects": subject_total, + "modalities": len(modalities), + "cognitive": len(cognitive), + } + + +_DATASET_COUNTER_DEFAULTS = _compute_dataset_counter_defaults() + + +def _format_counter(key: str) -> str: + value = _DATASET_COUNTER_DEFAULTS.get(key, 0) + if isinstance(value, (int, float)): + if isinstance(value, float) and not value.is_integer(): + return f"{value:,.2f}" + return 
f"{int(value):,}" + return str(value) + + +_DATASET_COUNTER_PLACEHOLDERS = { + "|datasets_total|": _format_counter("datasets"), + "|subjects_total|": _format_counter("subjects"), + "|modalities_total|": _format_counter("modalities"), + "|cognitive_total|": _format_counter("cognitive"), +} + + +def _copy_dataset_summary(app, exception) -> None: + if exception is not None or not getattr(app, "builder", None): + return + + csv_path = Path(importlib.import_module("eegdash.dataset").__file__).with_name( + "dataset_summary.csv" + ) + if not csv_path.exists(): + LOGGER.warning("dataset_summary.csv not found; skipping counter data copy.") + return + + static_dir = Path(app.outdir) / "_static" + try: + static_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(csv_path, static_dir / "dataset_summary.csv") + except OSError as exc: + LOGGER.warning("Unable to copy dataset_summary.csv to _static: %s", exc) + + +def _inject_counter_values(app, docname, source) -> None: + if docname != "dataset_summary": + return + + text = source[0] + for token, value in _DATASET_COUNTER_PLACEHOLDERS.items(): + text = text.replace(token, value) + source[0] = text + + def setup(app): """Create the back-references directory if it doesn't exist.""" backreferences_dir = os.path.join( @@ -611,6 +713,8 @@ def setup(app): os.makedirs(backreferences_dir) app.connect("builder-inited", _generate_dataset_docs) + app.connect("build-finished", _copy_dataset_summary) + app.connect("source-read", _inject_counter_values) # Configure sitemap URL format (omit .html where possible) diff --git a/docs/source/dataset_summary.rst b/docs/source/dataset_summary.rst index 17ed1465..f49a9043 100644 --- a/docs/source/dataset_summary.rst +++ b/docs/source/dataset_summary.rst @@ -15,6 +15,61 @@ Datasets Catalog To leverage recent and ongoing advancements in large-scale computational methods and to ensure the preservation of scientific data generated from publicly funded research, the EEG-DaSh data archive will create a data-sharing resource for MEEG (EEG, MEG) data contributed by collaborators for machine learning (ML) and deep learning (DL) applications. +.. raw:: html + +
+   <div class="dataset-counter-grid">
+     <div class="dataset-counter-card">
+       <div class="dataset-counter-icon">
+         <svg class="dataset-counter-svg" viewBox="0 0 24 24" aria-hidden="true"><!-- icon path elided --></svg>
+       </div>
+       <div class="dataset-counter-body">
+         <span class="dataset-counter-label">Datasets</span>
+         <span class="dataset-counter-value">|datasets_total|</span>
+       </div>
+     </div>
+     <div class="dataset-counter-card">
+       <div class="dataset-counter-icon">
+         <svg class="dataset-counter-svg" viewBox="0 0 24 24" aria-hidden="true"><!-- icon path elided --></svg>
+       </div>
+       <div class="dataset-counter-body">
+         <span class="dataset-counter-label">Subjects</span>
+         <span class="dataset-counter-value">|subjects_total|</span>
+       </div>
+     </div>
+     <div class="dataset-counter-card">
+       <div class="dataset-counter-icon">
+         <svg class="dataset-counter-svg" viewBox="0 0 24 24" aria-hidden="true"><!-- icon path elided --></svg>
+       </div>
+       <div class="dataset-counter-body">
+         <span class="dataset-counter-label">Experiment Modalities</span>
+         <span class="dataset-counter-value">|modalities_total|</span>
+       </div>
+     </div>
+     <div class="dataset-counter-card">
+       <div class="dataset-counter-icon">
+         <svg class="dataset-counter-svg" viewBox="0 0 24 24" aria-hidden="true"><!-- icon path elided --></svg>
+       </div>
+       <div class="dataset-counter-body">
+         <span class="dataset-counter-label">Cognitive Domains</span>
+         <span class="dataset-counter-value">|cognitive_total|</span>
+       </div>
+     </div>
+   </div>
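+   <!-- The placeholder tokens above are plain text inside this raw block; the
+        source-read hook _inject_counter_values() in conf.py substitutes the
+        values computed from dataset_summary.csv at build time. -->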
+ .. raw:: html From 84f6207c23abeb92ece07817c7dcc357f4626f2e Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 13 Oct 2025 01:05:08 +0200 Subject: [PATCH 04/12] better config --- docs/source/conf.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 799e8ebc..ac071c8c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -223,11 +223,21 @@ def linkcode_resolve(domain, info): sphinx_gallery_conf = { "examples_dirs": [f"{EX_DIR}"], "gallery_dirs": ["generated/auto_examples"], + "binder": { + "org": "sccn", + "repo": "EEGDash", + "branch": "main", + "binderhub_url": "https://mybinder.org", + "dependencies": "binder/requirements.txt", + "notebooks_dir": "notebooks", + "use_jupyter_lab": True, + }, + "capture_repr": ("_repr_html_", "__repr__"), "nested_sections": False, "backreferences_dir": "gen_modules/backreferences", "inspect_global_variables": True, "show_memory": True, - "show_api_usage": False, + "show_api_usage": True, "doc_module": ("eegdash", "numpy", "scipy", "matplotlib"), "reference_url": {"eegdash": None}, "filename_pattern": r"/(?:plot|tutorial)_(?!_).*\.py", From 9098b5faf954aa87066331b4d2c91ef673f29fa6 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 13 Oct 2025 01:25:11 +0200 Subject: [PATCH 05/12] updating the __all__ --- eegdash/features/datasets.py | 6 + eegdash/features/decorators.py | 9 ++ eegdash/features/extractors.py | 10 ++ eegdash/features/feature_bank/signal.py | 21 +-- eegdash/features/feature_bank/utils.py | 9 ++ eegdash/features/inspect.py | 9 ++ eegdash/features/serialization.py | 5 + eegdash/features/utils.py | 6 + scripts/1_nemar_dataset_list.py | 186 ++++++++++++++++++++++++ 9 files changed, 251 insertions(+), 10 deletions(-) create mode 100755 scripts/1_nemar_dataset_list.py diff --git a/eegdash/features/datasets.py b/eegdash/features/datasets.py index 9e933c59..99099c4d 100644 --- a/eegdash/features/datasets.py +++ b/eegdash/features/datasets.py @@ -19,6 +19,12 @@ from ..logging import logger +__all__ = [ + "FeaturesDataset", + "FeaturesConcatDataset", +] + + class FeaturesDataset(EEGWindowsDataset): """A dataset of features extracted from EEG windows. diff --git a/eegdash/features/decorators.py b/eegdash/features/decorators.py index 687d57fe..de8b9ad9 100644 --- a/eegdash/features/decorators.py +++ b/eegdash/features/decorators.py @@ -11,6 +11,15 @@ ) +__all__ = [ + "bivariate_feature", + "FeatureKind", + "FeaturePredecessor", + "multivariate_feature", + "univariate_feature", +] + + class FeaturePredecessor: """A decorator to specify parent extractors for a feature function. diff --git a/eegdash/features/extractors.py b/eegdash/features/extractors.py index c0bb0759..fbd3d663 100644 --- a/eegdash/features/extractors.py +++ b/eegdash/features/extractors.py @@ -9,6 +9,16 @@ from numba.core.dispatcher import Dispatcher +__all__ = [ + "BivariateFeature", + "DirectedBivariateFeature", + "FeatureExtractor", + "MultivariateFeature", + "TrainableFeature", + "UnivariateFeature", +] + + def _get_underlying_func(func: Callable) -> Callable: """Get the underlying function from a potential wrapper. 
diff --git a/eegdash/features/feature_bank/signal.py b/eegdash/features/feature_bank/signal.py index 42601e93..df8fdb4e 100644 --- a/eegdash/features/feature_bank/signal.py +++ b/eegdash/features/feature_bank/signal.py @@ -8,20 +8,21 @@ __all__ = [ "HilbertFeatureExtractor", - "signal_mean", - "signal_variance", - "signal_skewness", + "SIGNAL_PREDECESSORS", + "signal_decorrelation_time", + "signal_hjorth_activity", + "signal_hjorth_complexity", + "signal_hjorth_mobility", "signal_kurtosis", - "signal_std", - "signal_root_mean_square", + "signal_line_length", + "signal_mean", "signal_peak_to_peak", "signal_quantile", + "signal_root_mean_square", + "signal_skewness", + "signal_std", + "signal_variance", "signal_zero_crossings", - "signal_line_length", - "signal_hjorth_activity", - "signal_hjorth_mobility", - "signal_hjorth_complexity", - "signal_decorrelation_time", ] diff --git a/eegdash/features/feature_bank/utils.py b/eegdash/features/feature_bank/utils.py index 7aa29bc8..7d56a4a8 100644 --- a/eegdash/features/feature_bank/utils.py +++ b/eegdash/features/feature_bank/utils.py @@ -1,5 +1,14 @@ import numpy as np + +__all__ = [ + "DEFAULT_FREQ_BANDS", + "get_valid_freq_band", + "reduce_freq_bands", + "slice_freq_band", +] + + DEFAULT_FREQ_BANDS = { "delta": (1, 4.5), "theta": (4.5, 8), diff --git a/eegdash/features/inspect.py b/eegdash/features/inspect.py index 395a3326..c07d5c26 100644 --- a/eegdash/features/inspect.py +++ b/eegdash/features/inspect.py @@ -7,6 +7,15 @@ from .extractors import FeatureExtractor, MultivariateFeature, _get_underlying_func +__all__ = [ + "get_all_feature_extractors", + "get_all_feature_kinds", + "get_all_features", + "get_feature_kind", + "get_feature_predecessors", +] + + def get_feature_predecessors(feature_or_extractor: Callable) -> list: """Get the dependency hierarchy for a feature or feature extractor. 
diff --git a/eegdash/features/serialization.py b/eegdash/features/serialization.py index 75ebdc34..969e40c8 100644 --- a/eegdash/features/serialization.py +++ b/eegdash/features/serialization.py @@ -19,6 +19,11 @@ from .datasets import FeaturesConcatDataset, FeaturesDataset +__all__ = [ + "load_features_concat_dataset", +] + + def load_features_concat_dataset( path: str | Path, ids_to_load: list[int] | None = None, n_jobs: int = 1 ) -> FeaturesConcatDataset: diff --git a/eegdash/features/utils.py b/eegdash/features/utils.py index 5c311496..e9dda425 100644 --- a/eegdash/features/utils.py +++ b/eegdash/features/utils.py @@ -18,6 +18,12 @@ from .extractors import FeatureExtractor +__all__ = [ + "extract_features", + "fit_feature_extractors", +] + + def _extract_features_from_windowsdataset( win_ds: EEGWindowsDataset | WindowsDataset, feature_extractor: FeatureExtractor, diff --git a/scripts/1_nemar_dataset_list.py b/scripts/1_nemar_dataset_list.py new file mode 100755 index 00000000..f664e0d7 --- /dev/null +++ b/scripts/1_nemar_dataset_list.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +"""Script to retrieve and process NEMAR datasets.""" + +import logging +import os +import sys +from typing import Dict, List, Optional + +import requests +import urllib3 + +# Add the project root to the Python path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import eegdash.dataset + +# Disable SSL warnings since we're using verify=False +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(message)s") +logger = logging.getLogger(__name__) + + +class NemarAPI: + """Client for interacting with the NEMAR API.""" + + def __init__(self, token: Optional[str] = None): + """Initialize NEMAR API client. + + Args: + token: NEMAR access token. If not provided, will look for NEMAR_TOKEN env variable. + + Raises: + ValueError: If no token is provided or found in environment. + + """ + self.base_url = "https://nemar.org/api/dataexplorer/datapipeline" + self.token = token or os.environ.get("NEMAR_TOKEN") + if not self.token: + raise ValueError( + "NEMAR token must be provided either as argument or NEMAR_TOKEN environment variable" + ) + + def get_datasets(self, start: int = 0, limit: int = 500) -> Optional[Dict]: + """Get list of datasets from NEMAR. + + Args: + start: Starting index for pagination. + limit: Maximum number of datasets to return. + + Returns: + JSON response containing dataset information or None if request fails. + + """ + payload = { + "nemar_access_token": self.token, + "table_name": "dataexplorer_dataset", + "start": start, + "limit": limit, + } + + try: + response = requests.post( + f"{self.base_url}/list", + headers={"Content-Type": "application/json"}, + json=payload, + verify=False, + ) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + logger.error("Error fetching datasets: %s", e) + return None + + @staticmethod + def extract_dataset_info(datasets_response: Dict) -> List[Dict]: + """Extract relevant information from datasets response. + + Args: + datasets_response: Response from get_datasets(). + + Returns: + List of dictionaries containing dataset information. 
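+
+        Example (illustrative)::
+
+            info = NemarAPI.extract_dataset_info(nemar.get_datasets())
+            # -> [{"id": "ds...", "name": "...", "file_size_gb": ...}, ...]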
+ + """ + if not datasets_response or "entries" not in datasets_response: + return [] + + return [ + { + "id": data["id"], + "name": data["name"], + "modalities": data["modalities"], + "participants": data["participants"], + "file_size": data["file_size"], + "file_size_gb": float(data["file_size"]) / (1024 * 1024 * 1024), + "tasks": data.get("tasks", ""), + "authors": data.get("Authors", ""), + "doi": data.get("DatasetDOI", ""), + } + for _, data in datasets_response["entries"].items() + ] + + +def fetch_all_datasets() -> List[Dict]: + """Fetch all available datasets from NEMAR. + + Returns: + List of dataset information dictionaries. + + """ + try: + nemar = NemarAPI() + except ValueError as e: + logger.error("Error: %s", e) + logger.error( + "Please set your NEMAR token using: export NEMAR_TOKEN='your_token_here'" + ) + return [] + + all_datasets = [] + start = 0 + batch_size = 500 + + logger.info("Fetching datasets...") + while True: + datasets = nemar.get_datasets(start=start, limit=batch_size) + if not datasets or not datasets.get("entries"): + break + + batch_info = nemar.extract_dataset_info(datasets) + if not batch_info: + break + + all_datasets.extend(batch_info) + logger.info("Retrieved %d datasets so far...", len(all_datasets)) + + if len(batch_info) < batch_size: + break + + start += batch_size + + return all_datasets + + +def find_undigested_datasets() -> List[Dict]: + """Find datasets that haven't been digested into eegdash yet. + + Returns: + List of dataset information dictionaries for undigested datasets. + + """ + # Get all available datasets from NEMAR + all_datasets = fetch_all_datasets() + + # Get all classes from eegdash.dataset + eegdash_classes = dir(eegdash.dataset) + + # Filter for undigested datasets + undigested = [] + for dataset in all_datasets: + # Convert dataset ID to expected class name format (e.g., ds001785 -> DS001785) + class_name = dataset["id"].upper() + + # Check if this dataset exists as a class in eegdash.dataset + if class_name not in eegdash_classes: + undigested.append(dataset) + + return undigested + + +def main(): + """Main function to find and output undigested datasets.""" + undigested = find_undigested_datasets() + + # Print just the dataset IDs and names + print("\nUndigested Datasets:") + print("-" * 80) + for dataset in undigested: + print(f"{dataset['id']}: {dataset['name']}") + print(f"\nTotal undigested datasets: {len(undigested)}") + + +if __name__ == "__main__": + main() From 3b0ab27078fc819f274da3c1e9230d8fccef2257 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 13 Oct 2025 01:28:03 +0200 Subject: [PATCH 06/12] fixing bad link --- eegdash/features/serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eegdash/features/serialization.py b/eegdash/features/serialization.py index 969e40c8..ae9af9e6 100644 --- a/eegdash/features/serialization.py +++ b/eegdash/features/serialization.py @@ -2,7 +2,7 @@ See Also -------- -https://github.com/braindecode/braindecode//blob/master/braindecode/datautil/serialization.py#L165-L229 +https://github.com/braindecode/braindecode/blob/master/braindecode/datautil/serialization.py#L165-L229 """ From 424b0af5cd25c2c08a20442159188d4460c09346 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 13 Oct 2025 01:34:03 +0200 Subject: [PATCH 07/12] updating the first iteration --- docs/plot_dataset/colours.py | 28 +++++ docs/plot_dataset/treemap.py | 195 ++++++++++++++++++++++++++++------- 2 files changed, 183 insertions(+), 40 deletions(-) diff --git 
a/docs/plot_dataset/colours.py b/docs/plot_dataset/colours.py index 7d2d50ae..52b57f2e 100644 --- a/docs/plot_dataset/colours.py +++ b/docs/plot_dataset/colours.py @@ -21,6 +21,34 @@ "Unknown": "#94a3b8", } +MODALITY_EMOJI = { + "Visual": "👁️", + "Auditory": "👂", + "Sleep": "🌙", + "Multisensory": "🧩", + "Tactile": "✋", + "Motor": "🏃", + "Resting State": "🧘", + "Rest": "🧘", + "Other": "🧭", + "Unknown": "❔", +} + +PATHOLOGY_PASTEL_OVERRIDES = { + "Healthy": "#bbf7d0", + "Unknown": "#d0d7df", + "Dementia": "#fcd4d4", + "Schizophrenia": "#f9d0e7", + "Psychosis": "#f9d0e7", + "Epilepsy": "#f9d7c4", + "Parkinson's": "#f8c8c8", + "TBI": "#f9cabd", + "Surgery": "#f7d9b8", + "Other": "#f8cbdc", + "Clinical": "#f8d0d0", +} + + TYPE_COLOR_MAP = { "Perception": "#3b82f6", "Decision-making": "#eab308", diff --git a/docs/plot_dataset/treemap.py b/docs/plot_dataset/treemap.py index 8f5875c3..c56df643 100644 --- a/docs/plot_dataset/treemap.py +++ b/docs/plot_dataset/treemap.py @@ -14,6 +14,8 @@ CANONICAL_MAP, MODALITY_COLOR_MAP, PATHOLOGY_COLOR_MAP, + PATHOLOGY_PASTEL_OVERRIDES, + MODALITY_EMOJI, hex_to_rgba, ) except ImportError: # pragma: no cover - fallback for direct script execution @@ -21,6 +23,8 @@ CANONICAL_MAP, MODALITY_COLOR_MAP, PATHOLOGY_COLOR_MAP, + PATHOLOGY_PASTEL_OVERRIDES, + MODALITY_EMOJI, hex_to_rgba, ) @@ -36,18 +40,6 @@ _SEPARATORS = ("/", "|", ";", ",") _DEFAULT_COLOR = "#94a3b8" -MODALITY_EMOJI = { - "Visual": "👁️", - "Auditory": "👂", - "Sleep": "🌙", - "Multisensory": "🧩", - "Tactile": "✋", - "Motor": "🏃", - "Resting State": "🧘", - "Rest": "🧘", - "Other": "🧭", -} - def _tokenise_cell(value: object, column_key: str) -> list[str]: """Split multi-valued cells, normalise, and keep Unknown buckets.""" @@ -160,6 +152,12 @@ def _abbreviate(value: float | int) -> str: if num == 0: return "0" + if abs(num) < 1000: + rounded = round(num / 10.0) * 10.0 + if rounded == 0 and num > 0: + rounded = 10.0 + return f"{int(rounded):,}" + thresholds = [ (1_000_000_000, "B"), (1_000_000, "M"), @@ -168,11 +166,50 @@ def _abbreviate(value: float | int) -> str: for divisor, suffix in thresholds: if abs(num) >= divisor: scaled = num / divisor + scaled = round(scaled, 1) text = f"{scaled:.1f}".rstrip("0").rstrip(".") return f"{text}{suffix}" return f"{num:.0f}" +def _lighten_hex(hex_color: str, factor: float = 0.55) -> str: + if not isinstance(hex_color, str) or not hex_color.startswith("#"): + return _DEFAULT_COLOR + hex_color = hex_color.lstrip("#") + if len(hex_color) != 6: + return _DEFAULT_COLOR + try: + r = int(hex_color[0:2], 16) + g = int(hex_color[2:4], 16) + b = int(hex_color[4:6], 16) + except ValueError: + return _DEFAULT_COLOR + r = int(r + (255 - r) * factor) + g = int(g + (255 - g) * factor) + b = int(b + (255 - b) * factor) + return f"#{r:02x}{g:02x}{b:02x}" + + +def _pathology_colors(name: str) -> tuple[str, str, str]: + """Return (fill_rgba, legend_hex, group_key).""" + base_hex = PATHOLOGY_PASTEL_OVERRIDES.get(name) + if not base_hex: + fallback = PATHOLOGY_COLOR_MAP.get(name) + if fallback: + base_hex = _lighten_hex(fallback, 0.6) + else: + base_hex = PATHOLOGY_PASTEL_OVERRIDES.get("Clinical", _DEFAULT_COLOR) + + fill = hex_to_rgba(base_hex, alpha=0.65) + if name == "Healthy": + group = "healthy" + elif name == "Unknown": + group = "unknown" + else: + group = "clinical" + return fill, base_hex, group + + def _filter_zero_nodes(df: pd.DataFrame, column: str) -> pd.DataFrame: mask = (df["subjects"] > 0) | (df[column] == "Unknown") return df.loc[mask].copy() @@ -198,14 +235,16 @@ def 
_format_label( elif fallback_value > 0: secondary_text = f"{_abbreviate(records_value)} rec" else: - secondary_text = "0 h" + secondary_text = "records unavailable" return ( f"{name}
{subjects_text} subj" f" | {secondary_text}" ) -def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: +def _build_nodes( + dataset_level: pd.DataFrame, +) -> tuple[list[dict[str, object]], list[dict[str, str]]]: dataset_level = dataset_level.sort_values( ["population_type", "experimental_modality", "dataset_name"] ).reset_index(drop=True) @@ -229,7 +268,12 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: level1 = _filter_zero_nodes(level1, "population_type") nodes: list[dict[str, object]] = [] - level1_meta: list[dict[str, str]] = [] + legend_entries: list[dict[str, str]] = [] + seen_groups: set[str] = set() + modality_meta: dict[str, dict[str, str]] = {} + modality_priority = { + name: idx for idx, name in enumerate(MODALITY_COLOR_MAP.keys()) + } total_subjects = level1["subjects"].sum() total_hours = level1["hours"].sum() @@ -267,11 +311,8 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: row["hours_from_records"], font_px=16, ) - base_color = PATHOLOGY_COLOR_MAP.get(name) - if not base_color: - base_color = PATHOLOGY_COLOR_MAP.get("Clinical", _DEFAULT_COLOR) - color = hex_to_rgba(base_color, alpha=0.75) - level1_meta.append({"name": name, "color": base_color}) + fill_color, _, group = _pathology_colors(name) + seen_groups.add(group) nodes.append( { "id": node_id, @@ -279,7 +320,7 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: "name": name, "text": label, "value": float(row["subjects"]), - "color": color, + "color": fill_color, "hover": label, } ) @@ -288,10 +329,10 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: modality = row["experimental_modality"] or "Unknown" parent = row["population_type"] or "Unknown" node_id = f"{parent} / {modality}" + emoji_symbol = MODALITY_EMOJI.get(modality) modality_label = modality - emoji = MODALITY_EMOJI.get(modality) - if emoji: - modality_label = f"{emoji} {modality}" + if emoji_symbol and row["subjects"] >= 120: + modality_label = f"{emoji_symbol} {modality}" label = _format_label( modality_label, row["subjects"], @@ -301,6 +342,18 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: font_px=16, ) color = MODALITY_COLOR_MAP.get(modality, _DEFAULT_COLOR) + legend_label = ( + f"{(emoji_symbol + ' ') if emoji_symbol else ''}{modality}".strip() + ) + if modality not in modality_meta: + order = modality_priority.get(modality, len(modality_priority)) + modality_meta[modality] = { + "name": legend_label, + "color": color, + "group": "level2", + "order": 100 + order, + "legendgroup": "modalities", + } nodes.append( { "id": node_id, @@ -327,7 +380,6 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: row["hours_from_records"], font_px=16, ) - _ = row["population_type"] or "Unknown" if dataset_name == "Unknown": color = _DEFAULT_COLOR else: @@ -344,7 +396,41 @@ def _build_nodes(dataset_level: pd.DataFrame) -> list[dict[str, object]]: } ) - return nodes, level1_meta + group_config = { + "healthy": { + "name": "Healthy", + "color": PATHOLOGY_PASTEL_OVERRIDES.get("Healthy", "#bbf7d0"), + "order": 0, + }, + "clinical": { + "name": "Clinical", + "color": PATHOLOGY_PASTEL_OVERRIDES.get("Clinical", "#f8d0d0"), + "order": 1, + }, + "unknown": { + "name": "To be Categorised", + "color": PATHOLOGY_PASTEL_OVERRIDES.get("Unknown", "#d0d7df"), + "order": 2, + }, + } + + for key, cfg in group_config.items(): + if key in seen_groups: + legend_entries.append( + { + "name": cfg["name"], + "color": 
cfg["color"], + "group": "level1", + "order": cfg["order"], + "legendgroup": "populations", + } + ) + + legend_entries.extend( + sorted(modality_meta.values(), key=lambda item: item["order"]) + ) + + return nodes, legend_entries def _build_figure( @@ -355,6 +441,17 @@ def _build_figure( if not node_list: raise ValueError("No data available to render the treemap.") + legend_list = list(legend_entries) + seen: set[str] = set() + deduped: list[dict[str, str]] = [] + for entry in legend_list: + if entry["name"] in seen: + continue + seen.add(entry["name"]) + deduped.append(entry) + + deduped.sort(key=lambda item: item.get("order", 999)) + fig = go.Figure( go.Treemap( ids=[node["id"] for node in node_list], @@ -367,41 +464,58 @@ def _build_figure( marker=dict( colors=[node["color"] for node in node_list], line=dict(color="white", width=1), - pad=dict(t=6, r=6, b=6, l=6), + pad=dict(t=10, r=10, b=10, l=10), ), textinfo="text", hovertemplate="%{customdata[0]}", - pathbar=dict(visible=True, edgeshape="/", thickness=34), + pathbar=dict( + visible=True, edgeshape="/", thickness=34, textfont=dict(size=14) + ), textfont=dict(size=24), insidetextfont=dict(size=24), - tiling=dict(pad=6, packing="squarify"), + tiling=dict(pad=10, packing="squarify"), root=dict(color="rgba(255,255,255,0.95)"), ) ) - for entry in legend_entries: + for entry in deduped: fig.add_trace( go.Scatter( - x=[None], - y=[None], + x=[0], + y=[0], mode="markers", - marker=dict(size=14, symbol="square", color=entry["color"]), + marker=dict(size=12, symbol="square", color=entry["color"]), name=entry["name"], showlegend=True, hoverinfo="skip", + xaxis="x2", + yaxis="y2", + legendgroup=entry.get("legendgroup"), ) ) fig.update_layout( legend=dict( orientation="h", - yanchor="bottom", + yanchor="top", y=1.08, - xanchor="left", - x=0.0, + xanchor="center", + x=0.5, font=dict(size=14), - itemwidth=80, - ) + itemsizing="constant", + traceorder="normal", + itemclick=False, + itemdoubleclick=False, + bgcolor="rgba(255,255,255,0)", + bordercolor="rgba(0,0,0,0)", + borderwidth=0, + ), + legend_traceorder="normal", + ) + + fig.update_layout( + xaxis2=dict(visible=False), + yaxis2=dict(visible=False), ) return fig @@ -429,8 +543,9 @@ def generate_dataset_treemap( fig = _build_figure(nodes, legend_entries) fig.update_layout( uniformtext=dict(minsize=18, mode="hide"), - margin=dict(t=140, l=24, r=24, b=16), - hoverlabel=dict(font_size=16), + margin=dict(t=60, l=32, r=220, b=40), + hoverlabel=dict(font=dict(size=16), align="left"), + height=860, ) out_path = Path(out_html) From 2056cbab51cc6505b8d9f9fe4ab873feae3e3a2f Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 13 Oct 2025 01:45:58 +0200 Subject: [PATCH 08/12] finishing to sleep --- docs/plot_dataset/treemap.py | 29 ++++++++++++++----- docs/source/_templates/autosummary/module.rst | 16 ++++++++++ 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/docs/plot_dataset/treemap.py b/docs/plot_dataset/treemap.py index c56df643..1ed8a752 100644 --- a/docs/plot_dataset/treemap.py +++ b/docs/plot_dataset/treemap.py @@ -464,7 +464,7 @@ def _build_figure( marker=dict( colors=[node["color"] for node in node_list], line=dict(color="white", width=1), - pad=dict(t=10, r=10, b=10, l=10), + pad=dict(t=15, r=15, b=15, l=15), ), textinfo="text", hovertemplate="%{customdata[0]}", @@ -473,18 +473,28 @@ def _build_figure( ), textfont=dict(size=24), insidetextfont=dict(size=24), - tiling=dict(pad=10, packing="squarify"), - root=dict(color="rgba(255,255,255,0.95)"), + # Increase pad to create more 
visual separation between tiles, + # especially the top-level (population_type) nodes. + tiling=dict(pad=8, packing="squarify"), + # Slightly more transparent root to avoid harsh borders + root=dict(color="rgba(255,255,255,0.98)"), ) ) + # Add legend swatches. increase marker size and use a thin white border so + # legend squares visually separate from adjacent tiles when exported. for entry in deduped: fig.add_trace( go.Scatter( x=[0], y=[0], mode="markers", - marker=dict(size=12, symbol="square", color=entry["color"]), + marker=dict( + size=14, + symbol="square", + color=entry["color"], + line=dict(color="white", width=1.5), + ), name=entry["name"], showlegend=True, hoverinfo="skip", @@ -541,11 +551,14 @@ def generate_dataset_treemap( aggregated = _filter_zero_nodes(aggregated, "dataset_name") nodes, legend_entries = _build_nodes(aggregated) fig = _build_figure(nodes, legend_entries) + # Tune text sizes and margins so the increased padding doesn't cause + # labels to overflow. Keeping uniformtext minsize slightly lower ensures + # smaller tiles don't get crowded. fig.update_layout( - uniformtext=dict(minsize=18, mode="hide"), - margin=dict(t=60, l=32, r=220, b=40), - hoverlabel=dict(font=dict(size=16), align="left"), - height=860, + uniformtext=dict(minsize=20, mode="hide"), + margin=dict(t=56, l=28, r=28, b=36), + hoverlabel=dict(font=dict(size=14), align="left"), + height=880, ) out_path = Path(out_html) diff --git a/docs/source/_templates/autosummary/module.rst b/docs/source/_templates/autosummary/module.rst index 3efd0ea8..a1226c0b 100644 --- a/docs/source/_templates/autosummary/module.rst +++ b/docs/source/_templates/autosummary/module.rst @@ -63,3 +63,19 @@ {% endif %} {%- endblock %} +{% if sg_api_usage %} +.. _sg_api_{{ fullname }}: + +API Usage +--------- + +.. raw:: html + +
+   <div>
+
+{{ sg_api_usage }}
+
+.. raw:: html
+
+   </div>
+{% endif %} From 86625ec59676564bcc60d6427b06644cf3918abb Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 13 Oct 2025 01:49:42 +0200 Subject: [PATCH 09/12] overlap legends --- docs/plot_dataset/treemap.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/plot_dataset/treemap.py b/docs/plot_dataset/treemap.py index 1ed8a752..81475a20 100644 --- a/docs/plot_dataset/treemap.py +++ b/docs/plot_dataset/treemap.py @@ -553,10 +553,11 @@ def generate_dataset_treemap( fig = _build_figure(nodes, legend_entries) # Tune text sizes and margins so the increased padding doesn't cause # labels to overflow. Keeping uniformtext minsize slightly lower ensures - # smaller tiles don't get crowded. + # smaller tiles don't get crowded, and the larger top margin prevents + # the custom legend from overlapping the pathbar. fig.update_layout( uniformtext=dict(minsize=20, mode="hide"), - margin=dict(t=56, l=28, r=28, b=36), + margin=dict(t=96, l=28, r=28, b=36), hoverlabel=dict(font=dict(size=14), align="left"), height=880, ) From ade95bdb053d32fe36749ddc5a54bc1ba71fcf32 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 13 Oct 2025 01:52:44 +0200 Subject: [PATCH 10/12] padding --- docs/plot_dataset/treemap.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/plot_dataset/treemap.py b/docs/plot_dataset/treemap.py index 81475a20..eed7ad49 100644 --- a/docs/plot_dataset/treemap.py +++ b/docs/plot_dataset/treemap.py @@ -508,7 +508,7 @@ def _build_figure( legend=dict( orientation="h", yanchor="top", - y=1.08, + y=1.15, xanchor="center", x=0.5, font=dict(size=14), @@ -553,11 +553,11 @@ def generate_dataset_treemap( fig = _build_figure(nodes, legend_entries) # Tune text sizes and margins so the increased padding doesn't cause # labels to overflow. Keeping uniformtext minsize slightly lower ensures - # smaller tiles don't get crowded, and the larger top margin prevents - # the custom legend from overlapping the pathbar. + # smaller tiles don't get crowded, while the higher top margin combined + # with the raised legend keeps it clear of the pathbar. 
     fig.update_layout(
         uniformtext=dict(minsize=20, mode="hide"),
-        margin=dict(t=96, l=28, r=28, b=36),
+        margin=dict(t=132, l=28, r=28, b=36),
         hoverlabel=dict(font=dict(size=14), align="left"),
         height=880,
     )

From 86a77853501470b248141ba9fc9dc15a84a212d4 Mon Sep 17 00:00:00 2001
From: bruAristimunha
Date: Mon, 13 Oct 2025 01:59:48 +0200
Subject: [PATCH 11/12] pre-commit

---
 docs/plot_dataset/treemap.py           | 6 +++---
 docs/source/conf.py                    | 2 +-
 eegdash/features/datasets.py           | 1 -
 eegdash/features/decorators.py         | 1 -
 eegdash/features/extractors.py         | 1 -
 eegdash/features/feature_bank/utils.py | 1 -
 eegdash/features/inspect.py            | 1 -
 eegdash/features/serialization.py      | 1 -
 eegdash/features/utils.py              | 1 -
 9 files changed, 4 insertions(+), 11 deletions(-)

diff --git a/docs/plot_dataset/treemap.py b/docs/plot_dataset/treemap.py
index eed7ad49..45e9136a 100644
--- a/docs/plot_dataset/treemap.py
+++ b/docs/plot_dataset/treemap.py
@@ -2,10 +2,10 @@
 
 """Utilities to generate the EEG Dash dataset treemap."""
 
+import math
 from pathlib import Path
 from typing import Iterable
 
-import math
 import pandas as pd
 import plotly.graph_objects as go
 
@@ -13,18 +13,18 @@
     from .colours import (
         CANONICAL_MAP,
         MODALITY_COLOR_MAP,
+        MODALITY_EMOJI,
         PATHOLOGY_COLOR_MAP,
         PATHOLOGY_PASTEL_OVERRIDES,
-        MODALITY_EMOJI,
         hex_to_rgba,
     )
 except ImportError:  # pragma: no cover - fallback for direct script execution
     from colours import (  # type: ignore
         CANONICAL_MAP,
         MODALITY_COLOR_MAP,
+        MODALITY_EMOJI,
         PATHOLOGY_COLOR_MAP,
         PATHOLOGY_PASTEL_OVERRIDES,
-        MODALITY_EMOJI,
         hex_to_rgba,
     )
 
diff --git a/docs/source/conf.py b/docs/source/conf.py
index ac071c8c..bd6e5b70 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -2,8 +2,8 @@
 import importlib
 import inspect
 import os
-import sys
 import shutil
+import sys
 from collections import Counter
 from datetime import datetime, timezone
 from pathlib import Path
diff --git a/eegdash/features/datasets.py b/eegdash/features/datasets.py
index 99099c4d..81515645 100644
--- a/eegdash/features/datasets.py
+++ b/eegdash/features/datasets.py
@@ -18,7 +18,6 @@ from ..logging import logger
 
-
 __all__ = [
     "FeaturesDataset",
     "FeaturesConcatDataset",
diff --git a/eegdash/features/decorators.py b/eegdash/features/decorators.py
index de8b9ad9..0e933557 100644
--- a/eegdash/features/decorators.py
+++ b/eegdash/features/decorators.py
@@ -10,7 +10,6 @@
     _get_underlying_func,
 )
 
-
 __all__ = [
     "bivariate_feature",
     "FeatureKind",
diff --git a/eegdash/features/extractors.py b/eegdash/features/extractors.py
index fbd3d663..451f1636 100644
--- a/eegdash/features/extractors.py
+++ b/eegdash/features/extractors.py
@@ -8,7 +8,6 @@ import numpy as np
 from numba.core.dispatcher import Dispatcher
 
-
 __all__ = [
     "BivariateFeature",
     "DirectedBivariateFeature",
diff --git a/eegdash/features/feature_bank/utils.py b/eegdash/features/feature_bank/utils.py
index 7d56a4a8..954d5bed 100644
--- a/eegdash/features/feature_bank/utils.py
+++ b/eegdash/features/feature_bank/utils.py
@@ -1,6 +1,5 @@
 import numpy as np
 
-
 __all__ = [
     "DEFAULT_FREQ_BANDS",
     "get_valid_freq_band",
diff --git a/eegdash/features/inspect.py b/eegdash/features/inspect.py
index c07d5c26..1f5e6826 100644
--- a/eegdash/features/inspect.py
+++ b/eegdash/features/inspect.py
@@ -6,7 +6,6 @@ from . import extractors, feature_bank
 from .extractors import FeatureExtractor, MultivariateFeature, _get_underlying_func
 
-
 __all__ = [
     "get_all_feature_extractors",
     "get_all_feature_kinds",
diff --git a/eegdash/features/serialization.py b/eegdash/features/serialization.py
index ae9af9e6..2cdd962c 100644
--- a/eegdash/features/serialization.py
+++ b/eegdash/features/serialization.py
@@ -18,7 +18,6 @@
 
 from .datasets import FeaturesConcatDataset, FeaturesDataset
 
-
 __all__ = [
     "load_features_concat_dataset",
 ]
diff --git a/eegdash/features/utils.py b/eegdash/features/utils.py
index e9dda425..ec5c24a1 100644
--- a/eegdash/features/utils.py
+++ b/eegdash/features/utils.py
@@ -17,7 +17,6 @@
 
 from .datasets import FeaturesConcatDataset, FeaturesDataset
 from .extractors import FeatureExtractor
 
-
 __all__ = [
     "extract_features",
     "fit_feature_extractors",

From 0ac111e7f2973ea2c5eb2c95fbfc5cfa07ed2a5e Mon Sep 17 00:00:00 2001
From: Bru
Date: Mon, 13 Oct 2025 02:01:45 +0200
Subject: [PATCH 12/12] Delete scripts/1_nemar_dataset_list.py

---
 scripts/1_nemar_dataset_list.py | 186 --------------------------------
 1 file changed, 186 deletions(-)
 delete mode 100755 scripts/1_nemar_dataset_list.py

diff --git a/scripts/1_nemar_dataset_list.py b/scripts/1_nemar_dataset_list.py
deleted file mode 100755
index f664e0d7..00000000
--- a/scripts/1_nemar_dataset_list.py
+++ /dev/null
@@ -1,186 +0,0 @@
-#!/usr/bin/env python3
-"""Script to retrieve and process NEMAR datasets."""
-
-import logging
-import os
-import sys
-from typing import Dict, List, Optional
-
-import requests
-import urllib3
-
-# Add the project root to the Python path
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
-import eegdash.dataset
-
-# Disable SSL warnings since we're using verify=False
-urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format="%(message)s")
-logger = logging.getLogger(__name__)
-
-
-class NemarAPI:
-    """Client for interacting with the NEMAR API."""
-
-    def __init__(self, token: Optional[str] = None):
-        """Initialize NEMAR API client.
-
-        Args:
-            token: NEMAR access token. If not provided, will look for NEMAR_TOKEN env variable.
-
-        Raises:
-            ValueError: If no token is provided or found in environment.
-
-        """
-        self.base_url = "https://nemar.org/api/dataexplorer/datapipeline"
-        self.token = token or os.environ.get("NEMAR_TOKEN")
-        if not self.token:
-            raise ValueError(
-                "NEMAR token must be provided either as argument or NEMAR_TOKEN environment variable"
-            )
-
-    def get_datasets(self, start: int = 0, limit: int = 500) -> Optional[Dict]:
-        """Get list of datasets from NEMAR.
-
-        Args:
-            start: Starting index for pagination.
-            limit: Maximum number of datasets to return.
-
-        Returns:
-            JSON response containing dataset information or None if request fails.
-
-        """
-        payload = {
-            "nemar_access_token": self.token,
-            "table_name": "dataexplorer_dataset",
-            "start": start,
-            "limit": limit,
-        }
-
-        try:
-            response = requests.post(
-                f"{self.base_url}/list",
-                headers={"Content-Type": "application/json"},
-                json=payload,
-                verify=False,
-            )
-            response.raise_for_status()
-            return response.json()
-        except requests.exceptions.RequestException as e:
-            logger.error("Error fetching datasets: %s", e)
-            return None
-
-    @staticmethod
-    def extract_dataset_info(datasets_response: Dict) -> List[Dict]:
-        """Extract relevant information from datasets response.
-
-        Args:
-            datasets_response: Response from get_datasets().
- - Returns: - List of dictionaries containing dataset information. - - """ - if not datasets_response or "entries" not in datasets_response: - return [] - - return [ - { - "id": data["id"], - "name": data["name"], - "modalities": data["modalities"], - "participants": data["participants"], - "file_size": data["file_size"], - "file_size_gb": float(data["file_size"]) / (1024 * 1024 * 1024), - "tasks": data.get("tasks", ""), - "authors": data.get("Authors", ""), - "doi": data.get("DatasetDOI", ""), - } - for _, data in datasets_response["entries"].items() - ] - - -def fetch_all_datasets() -> List[Dict]: - """Fetch all available datasets from NEMAR. - - Returns: - List of dataset information dictionaries. - - """ - try: - nemar = NemarAPI() - except ValueError as e: - logger.error("Error: %s", e) - logger.error( - "Please set your NEMAR token using: export NEMAR_TOKEN='your_token_here'" - ) - return [] - - all_datasets = [] - start = 0 - batch_size = 500 - - logger.info("Fetching datasets...") - while True: - datasets = nemar.get_datasets(start=start, limit=batch_size) - if not datasets or not datasets.get("entries"): - break - - batch_info = nemar.extract_dataset_info(datasets) - if not batch_info: - break - - all_datasets.extend(batch_info) - logger.info("Retrieved %d datasets so far...", len(all_datasets)) - - if len(batch_info) < batch_size: - break - - start += batch_size - - return all_datasets - - -def find_undigested_datasets() -> List[Dict]: - """Find datasets that haven't been digested into eegdash yet. - - Returns: - List of dataset information dictionaries for undigested datasets. - - """ - # Get all available datasets from NEMAR - all_datasets = fetch_all_datasets() - - # Get all classes from eegdash.dataset - eegdash_classes = dir(eegdash.dataset) - - # Filter for undigested datasets - undigested = [] - for dataset in all_datasets: - # Convert dataset ID to expected class name format (e.g., ds001785 -> DS001785) - class_name = dataset["id"].upper() - - # Check if this dataset exists as a class in eegdash.dataset - if class_name not in eegdash_classes: - undigested.append(dataset) - - return undigested - - -def main(): - """Main function to find and output undigested datasets.""" - undigested = find_undigested_datasets() - - # Print just the dataset IDs and names - print("\nUndigested Datasets:") - print("-" * 80) - for dataset in undigested: - print(f"{dataset['id']}: {dataset['name']}") - print(f"\nTotal undigested datasets: {len(undigested)}") - - -if __name__ == "__main__": - main()
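
A minimal usage sketch for the treemap entry point added earlier in this series, for readers wiring it into a docs build. Only the out_html keyword is visible in the hunks above; the positional DataFrame argument, the CSV filename, and the output path below are assumptions for illustration, not part of the patches.

    import pandas as pd

    # Hypothetical driver, assumed to run from the docs/ directory, where
    # plot_dataset/__init__.py re-exports generate_dataset_treemap. The input
    # table must provide the columns checked by _preprocess_dataframe
    # (dataset, n_records, n_subjects, duration_hours_total, plus the
    # subject-type and modality category columns).
    from plot_dataset import generate_dataset_treemap

    summary = pd.read_csv("dataset_summary.csv")  # assumed filename
    generate_dataset_treemap(summary, out_html="dataset_treemap.html")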