From TUEV Readme:

lab files use 4 letter codes:

- spsw: spike and slow wave  
- gped: generalized periodic epileptiform discharge  
- pled: periodic lateralized epileptiform discharge  
- eyem: eye movement  
- artf: artifact  
- bckg: background

In the format:
`117100000 117200000 eyem`

rec files use numeric codes:

1: spsw  
2: gped  
3: pled  
4: eyem  
5: artf  
6: bckg

in the format `channel,start,end,label` (start and end in seconds), e.g. `13,90.4,91.4,6`

In [None]:
from pyhealth.datasets import TUEVDataset
# Root of the TUEV EDF tree; this path is specific to the machine the
# notebook was run on.
tuev_root = "/share/Share/valera/tuev/edf/"
dataset = TUEVDataset(root=tuev_root)

# Report on the loaded dataset.
# NOTE(review): stat()/info() presumably print summary statistics and a
# description of the dataset structure -- confirm against the PyHealth docs.
dataset.stat()
dataset.info()

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
from collections import defaultdict

# One .edf path per record, over every patient in dataset.patients.
# Each record supplies its folder ("load_from_path") and file name
# ("signal_file").
edf_paths = [
    Path(record["load_from_path"]) / record["signal_file"]
    for patient_records in dataset.patients.values()
    for record in patient_records
]

# Flat list of events, one tuple per distinct (label, start, end) in a file:
# (rec_path, label, chans_tuple, start, end)
all_events = []

for edf_path in edf_paths:
    rec_path = edf_path.with_suffix(".rec")
    if not rec_path.exists():
        continue
    try:
        table = np.genfromtxt(rec_path, delimiter=",")
        # A file with a single row is parsed as a flat vector of 4 values;
        # turn it into a 1 x 4 table so the row loop below works.
        single_row = table.ndim == 1 and table.size == 4
        if single_row:
            table = table[np.newaxis, :]

        # Rows sharing label and interval describe one event seen on several
        # channels: gather the channel numbers per (label, start, end).
        channels_by_event = {}
        for chan, start, end, label in table:
            event_id = (int(label), float(start), float(end))
            channels_by_event.setdefault(event_id, set()).add(int(chan))

        # Emit one tuple per event with its channels de-duplicated and sorted.
        all_events.extend(
            (rec_path, label, tuple(sorted(chan_set)), start, end)
            for (label, start, end), chan_set in channels_by_event.items()
        )
    except Exception as e:
        # Any parsing problem skips this file's events but keeps the scan going.
        print(f"Error reading {rec_path}: {e}")

# Order events by (rec_path, label, chans, start, end). The tuples already
# hold their fields in exactly that order, so the natural tuple ordering is
# the sort key.
all_events.sort()

# Fuse back-to-back intervals: an event is folded into the previously kept
# one when file, label and channel set are identical and it starts where the
# previous one ends (compared with np.isclose to tolerate float noise).
merged_events = []
for rec_file, code, chan_set, t_start, t_end in all_events:
    if merged_events:
        prev_file, prev_code, prev_chans, prev_start, prev_end = merged_events[-1]
        same_source = (prev_file, prev_code, prev_chans) == (rec_file, code, chan_set)
        if same_source and np.isclose(prev_end, t_start):
            # Extend the kept event to the end of the current one.
            merged_events[-1] = (prev_file, prev_code, prev_chans, prev_start, t_end)
            continue
    merged_events.append((rec_file, code, chan_set, t_start, t_end))

# Index the merged events: (rec_path, label, start, end) -> list of channels.
# If two merged events share a key, the later one replaces the earlier one.
event_channel_counts = defaultdict(list)
event_channel_counts.update(
    ((rec_file, code, t_start, t_end), list(chan_set))
    for rec_file, code, chan_set, t_start, t_end in merged_events
)

In [None]:
from collections import Counter

# event_channel_counts maps (rec_path, label, start, end) -> [chan, chan, ...].
# Tally how many events exist for every (label, number of distinct channels)
# combination.
pair_counter = Counter(
    (key[1], len(set(chans))) for key, chans in event_channel_counts.items()
)

# One table row per (label, n_channels) pair, ordered by label and then by
# channel count.
rows = [
    (label, n_channels, count)
    for (label, n_channels), count in pair_counter.items()
]
df = pd.DataFrame(rows, columns=["label", "n_channels", "count"])
df = df.sort_values(["label", "n_channels"]).reset_index(drop=True)

display(df)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Mapping from the numeric label used in .rec files to its 4-letter code
label_map = {
    1: "spsw",
    2: "gped",
    3: "pled",
    4: "eyem",
    5: "artf",
    6: "bckg"
}

# Outer-ring wedges holding at most this share of all events get no text
# label (0.015 == 1.5%), to keep tiny sectors from overlapping.
MIN_LABELLED_FRACTION = 0.015

labels = df['label'].unique()
label_counts = df.groupby('label')['count'].sum()

# Outer ring: one wedge per (label, n_channels) row of df
outer_sizes = df['count']
outer_n_channels = df['n_channels'].values
outer_labels = [f"{row.n_channels}" for row in df.itertuples()]

# Inner ring: total per label. The grouped totals are indexed by `labels`
# explicitly so that wedge sizes, wedge texts and the outer ring all follow
# the same label order even if df is not sorted by label.
inner_sizes = label_counts.reindex(labels).values
inner_labels = [f"{label_map.get(l, str(l))} ({label_counts[l]})" for l in labels]

# Colors
cmap = plt.get_cmap("tab20c")
# One color per distinct n_channels value, shared across all labels
unique_n_channels = sorted(df['n_channels'].unique())
nchan_color_map = {n: cmap(i * 3 % 20) for i, n in enumerate(unique_n_channels)}
outer_colors = [nchan_color_map[n] for n in outer_n_channels]
inner_colors = cmap(np.arange(len(labels)) * 2)

fig, ax = plt.subplots(figsize=(8, 8))

# Inner pie; its texts are placed manually below, so no automatic labels
wedges1, _ = ax.pie(
    inner_sizes, radius=1, labels=None,
    colors=inner_colors, wedgeprops=dict(width=0.3, edgecolor='w')
)

# Put "<code> (<total>)" at radius 0.5 along each wedge's mid-angle,
# i.e. closer to the center than the ring itself
for i, w in enumerate(wedges1):
    ang = (w.theta2 + w.theta1) / 2
    x = 0.5 * np.cos(np.deg2rad(ang))
    y = 0.5 * np.sin(np.deg2rad(ang))
    ax.text(x, y, inner_labels[i], ha='center', va='center', fontsize=14, weight='bold')

# Outer pie (n_channels per label): list df rows label by label, in the same
# label order as the inner ring, so each group sits over its inner wedge
outer_order = []
for l in labels:
    outer_order.extend(df[df['label'] == l].index.tolist())
outer_sizes_ordered = df.loc[outer_order, 'count']
outer_labels_ordered = [outer_labels[i] for i in outer_order]
outer_colors_ordered = [outer_colors[i] for i in outer_order]

# Blank out the text of small sectors (see MIN_LABELLED_FRACTION)
total_outer = outer_sizes_ordered.sum()
outer_labels_final = [
    lbl if size / total_outer > MIN_LABELLED_FRACTION else ""
    for lbl, size in zip(outer_labels_ordered, outer_sizes_ordered)
]

wedges2, _ = ax.pie(
    outer_sizes_ordered, radius=1.3, labels=outer_labels_final, labeldistance=1.08,
    colors=outer_colors_ordered, wedgeprops=dict(width=0.3, edgecolor='w')
)

ax.set(aspect="equal")
plt.subplots_adjust(top=0.85)  # Leave room above the axes for the title
plt.suptitle("Nested Pie Chart: Event Type (inner) and n_channels (outer)", y=0.98, fontsize=16)
plt.show()

# per file stats

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import mne
from collections import defaultdict, Counter

# --- Gather statistics ---
# One dict per recording that has a .rec file next to it and whose EDF could
# be opened.
stats = []
event_type_names = label_map

for edf_path in edf_paths:
    rec_path = edf_path.with_suffix(".rec")
    if not rec_path.exists():
        continue

    # Channel count and duration. preload=False asks MNE not to load the
    # signal into memory, which is enough for these two values.
    try:
        raw = mne.io.read_raw_edf(edf_path, preload=False, verbose=False)
        n_chan = len(raw.ch_names)
        duration = raw.n_times / raw.info["sfreq"]
        raw.close()
    except Exception as e:
        print(f"Could not read {edf_path}: {e}")
        continue

    # Count .rec rows per label (4th column). Rows are per channel, so an
    # event marked on several channels is counted once per channel.
    # An unreadable .rec is reported and recorded as having no events.
    try:
        events = np.genfromtxt(rec_path, delimiter=",")
        if events.ndim == 1 and events.size == 4:
            events = events[np.newaxis, :]
        event_counts = Counter(int(row[3]) for row in events)
    except Exception as e:
        print(f"Could not read {rec_path}: {e}")
        event_counts = Counter()

    stats.append(dict(
        edf_path=edf_path,
        length_sec=duration,
        n_channels=n_chan,
        n_events_total=sum(event_counts.values()),
        event_counts=dict(event_counts),
    ))

# --- Plot distributions: length, n_channels, total events ---
lengths = [entry["length_sec"] for entry in stats]
n_channels = [entry["n_channels"] for entry in stats]
n_events_total = [entry["n_events_total"] for entry in stats]

fig, axes = plt.subplots(3, 1, figsize=(10, 10), sharex=False)

# Channel counts are integers: use unit-wide bins centered on each integer
# from the smallest to the largest observed count.
channel_bins = np.arange(min(n_channels), max(n_channels) + 2) - 0.5

# (values, bins, color, title) for the three stacked histograms
panels = (
    (lengths, 50, 'C0', "Recording Length (seconds)"),
    (n_channels, channel_bins, 'C1', "Number of Channels"),
    (n_events_total, 50, 'C2', "Total Number of Events"),
)
for ax, (values, bins, color, title) in zip(axes, panels):
    ax.hist(values, bins=bins, color=color)
    ax.set_title(title)

axes[-1].set_xlabel("Value")
plt.tight_layout()
plt.show()

# --- Plot distributions: number of events of each type ---
# For every known label, the per-recording event count (0 when the label
# does not occur in that recording).
event_types = sorted(event_type_names.keys())
event_type_counts = {
    etype: [s["event_counts"].get(etype, 0) for s in stats]
    for etype in event_types
}

# One row per event type. The row count follows event_types instead of a
# hard-coded 6, and squeeze=False keeps `axes` indexable even for one type.
# The height scales so that six types give the original 10 x 16 figure.
n_types = len(event_types)
fig, axes = plt.subplots(
    n_types, 1, figsize=(10, 16 * n_types / 6), sharex=True, squeeze=False
)
axes = axes[:, 0]
for i, (ax, etype) in enumerate(zip(axes, event_types)):
    ax.hist(event_type_counts[etype], bins=50, color=f"C{i}")
    ax.set_title(f"Number of Events: {event_type_names[etype]} (label={etype}). Y axis: log scale")
    ax.set_yscale("log")
axes[-1].set_xlabel("Number of Events")
plt.tight_layout()
plt.show()

# SPSWs

In [None]:

# Largest per-recording channel count. `n_channels` is expected to be the
# list of per-file counts built by the per-file statistics cell, so that cell
# must have been run first.
max_channels = max(n_channels)

print("Maximum number of channels:", max_channels)

In [None]:
from tqdm import tqdm
from matplotlib.backends.backend_pdf import PdfPages

# True: shade the event red on annotated channels and green on the others.
# False: shade it grey on every channel.
color_channels = False

# Every merged SPSW event (label == 1) as (rec_path, label, start, end, chans)
spsw_events = [k + (v,) for k, v in event_channel_counts.items() if k[1] == 1]
n_spikes = len(spsw_events)
spikes_per_plot = 8   # events (columns) per PDF page
half_window = 1       # seconds shown on each side of the event center

pdf_filename = f"spsw_spikes_color_channels_{color_channels}.pdf"

# The most recently decoded recording is kept so that consecutive events
# from the same file do not decode the whole EDF again.
cached_edf_path = None
data = times = channel_names = None

with PdfPages(pdf_filename) as pdf:
    for plot_idx in tqdm(range(0, n_spikes, spikes_per_plot)):
        # squeeze=False always returns a 2-D axes array, so axes[row, col]
        # is valid even with a single channel row or a single column.
        fig, axes = plt.subplots(
            max_channels, spikes_per_plot,
            figsize=(spikes_per_plot * 4, max_channels * 1.5),
            sharey='row', sharex=False, squeeze=False,
        )
        for col, event_key in enumerate(spsw_events[plot_idx:plot_idx + spikes_per_plot]):
            rec_path, label, start, end, chans = event_key
            edf_path = rec_path.with_suffix('.edf')

            if edf_path != cached_edf_path:
                try:
                    raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
                    raw.pick("eeg")
                    data, times = raw.get_data(return_times=True)
                    channel_names = list(raw.ch_names)
                    raw.close()
                except Exception as e:
                    print(f"Could not load {edf_path}: {e}")
                    # The cached arrays may be half-updated: invalidate them
                    # and leave this column empty.
                    cached_edf_path = None
                    continue
                cached_edf_path = edf_path

            # Window of +/- half_window seconds around the event center
            center = (start + end) / 2
            tmin = max(center - half_window, 0)
            tmax = center + half_window
            idx_min = np.searchsorted(times, tmin)
            idx_max = np.searchsorted(times, tmax)

            # NOTE(review): channel numbers from the .rec file are compared
            # with row indices of the picked EEG data -- confirm that both
            # use the same channel numbering.
            annotated = set(chans)

            for ch_idx, ch_name in enumerate(channel_names):
                ax = axes[ch_idx, col]
                ax.plot(times[idx_min:idx_max], data[ch_idx, idx_min:idx_max], color='k', linewidth=0.7)
                # Mark the SPSW interval
                if color_channels:
                    color = "red" if ch_idx in annotated else "green"
                else:
                    color = "grey"
                ax.axvspan(start, end, color=color, alpha=0.3)
                if ch_idx == 0:
                    ax.set_title(f"{edf_path.name}\nSPSW {plot_idx + col + 1}", fontsize=10)
                if col == 0:
                    ax.set_ylabel(ch_name)
                ax.set_xlim(tmin, tmax)
                ax.grid(True)
        plt.tight_layout()
        pdf.savefig(fig)
        plt.close(fig)

print(f"Saved all SPSW plots to {pdf_filename}")

# PLEDs

In [None]:
from tqdm import tqdm
from matplotlib.backends.backend_pdf import PdfPages

# True: shade the event red on annotated channels and green on the others.
# False: shade it grey on every channel.
color_channels = True

# The first 100 merged PLED events (label == 3) as
# (rec_path, label, start, end, chans)
pled_events = [k + (v,) for k, v in event_channel_counts.items() if k[1] == 3][:100]
n_pled = len(pled_events)
pled_per_plot = 3     # events (columns) per PDF page
half_window = 4       # seconds shown on each side of the event center

pdf_filename = f"pled_spikes_color_channels_{color_channels}.pdf"

# The most recently decoded recording is kept so that consecutive events
# from the same file do not decode the whole EDF again.
cached_edf_path = None
data = times = channel_names = None

with PdfPages(pdf_filename) as pdf:
    for plot_idx in tqdm(range(0, n_pled, pled_per_plot)):
        # max_channels rows, pled_per_plot columns. squeeze=False always
        # returns a 2-D axes array, so axes[row, col] is valid even with a
        # single channel row or a single column.
        fig, axes = plt.subplots(
            max_channels, pled_per_plot,
            figsize=(pled_per_plot * 4, max_channels * 1.5),
            sharey='row', sharex=False, squeeze=False,
        )
        for col, event_key in enumerate(pled_events[plot_idx:plot_idx + pled_per_plot]):
            rec_path, label, start, end, chans = event_key
            edf_path = rec_path.with_suffix('.edf')

            if edf_path != cached_edf_path:
                try:
                    raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
                    raw.pick("eeg")
                    data, times = raw.get_data(return_times=True)
                    channel_names = list(raw.ch_names)
                    raw.close()
                except Exception as e:
                    print(f"Could not load {edf_path}: {e}")
                    # The cached arrays may be half-updated: invalidate them
                    # and leave this column empty.
                    cached_edf_path = None
                    continue
                cached_edf_path = edf_path

            # Window of +/- half_window seconds around the event center
            center = (start + end) / 2
            tmin = max(center - half_window, 0)
            tmax = center + half_window
            idx_min = np.searchsorted(times, tmin)
            idx_max = np.searchsorted(times, tmax)

            # NOTE(review): channel numbers from the .rec file are compared
            # with row indices of the picked EEG data -- confirm that both
            # use the same channel numbering.
            annotated = set(chans)

            for ch_idx, ch_name in enumerate(channel_names):
                ax = axes[ch_idx, col]
                ax.plot(times[idx_min:idx_max], data[ch_idx, idx_min:idx_max], color='k', linewidth=0.7)
                # Mark the PLED interval
                if color_channels:
                    color = "red" if ch_idx in annotated else "green"
                else:
                    color = "grey"
                ax.axvspan(start, end, color=color, alpha=0.3)
                if ch_idx == 0:
                    ax.set_title(f"{edf_path.name}\nPLED {plot_idx + col + 1}", fontsize=10)
                if col == 0:
                    ax.set_ylabel(ch_name)
                ax.set_xlim(tmin, tmax)
                ax.grid(True)
        plt.tight_layout()
        pdf.savefig(fig)
        plt.close(fig)

# Only the selected subset was plotted, so report its size rather than "all".
print(f"Saved the first {n_pled} PLED plots to {pdf_filename}")

# run locally

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import mne

# Set the path to the .edf file; the .rec annotation file is expected next
# to it with the same stem (aaaaaaar_00000001.rec)
record_path = Path("../data/aaaaaaar/aaaaaaar_00000001.edf")
rec_path = record_path.with_suffix(".rec")

# Load the signal
raw = mne.io.read_raw_edf(record_path, preload=True)
raw.pick("eeg")  # only keep EEG channels
sfreq = raw.info["sfreq"]
data, times = raw.get_data(return_times=True)

# Save channel names before closing
channel_names = raw.ch_names.copy()
raw.close()

# Load the .rec annotations as rows of (channel, start, end, label).
# A file with a single row is parsed as a flat vector; make it 1 x 4 so the
# row loop below can unpack it.
events = np.genfromtxt(rec_path, delimiter=",")
if events.ndim == 1 and events.size == 4:
    events = events.reshape(1, 4)

# Plot up to 21 channels and up to the first 10000 samples, clamped to what
# the recording actually contains.
n_channels_to_plot = min(21, data.shape[0])
end_idx = min(10000, len(times) - 1)
time_end = times[end_idx]

# squeeze=False keeps `axes` an array even when only one channel is plotted
fig, axes = plt.subplots(
    n_channels_to_plot, 1,
    figsize=(15, 2.5 * n_channels_to_plot), sharex=True, squeeze=False,
)
axes = axes[:, 0]

for i, ax in enumerate(axes):
    segment = data[i, :end_idx]
    ax.plot(times[:end_idx], segment, label=channel_names[i], linewidth=0.5)

    # Add annotations for the current channel; events that end at or after
    # the plotted window are skipped
    text_y = None
    for chan, start, end, label in events:
        if end >= time_end:
            continue
        if int(chan) == i:
            if text_y is None:
                # Label height: 80% of the plotted maximum, computed once
                # per channel and only when there is something to label
                text_y = np.max(segment) * 0.8
            ax.axvspan(start, end, color='red', alpha=0.3)
            ax.text((start + end) / 2, text_y,
                    str(int(label)), color='black', ha='center', fontsize=8)

    ax.set_ylabel("Amp (V)")
    ax.legend(loc="upper right")
    ax.grid(True)

axes[-1].set_xlabel("Time (s)")
# plt.suptitle(f"EEG signals from {record_path.name} — First {n_channels_to_plot} channels")
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# Write EDF with annotations

In [None]:
import mne
import numpy as np
from mne.annotations import Annotations
from pathlib import Path

# edf_paths = [Path("/Users/vzuev/Documents/git/gh_zuevval/FPCM/data/aaaaaaar/aaaaaaar_00000001.edf")]

# Mapping from the numeric label used in .rec files to its 4-letter code
label_map = {
    1: "spsw",
    2: "gped",
    3: "pled",
    4: "eyem",
    5: "artf",
    6: "bckg"
}

# For every recording with a .rec file, write "<stem>_annot.edf" next to it
# containing the signal plus the .rec events as annotations.
for edf_path in edf_paths:
    rec_path = edf_path.with_suffix(".rec")
    if not rec_path.exists():
        continue
    try:
        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
    except Exception as e:
        print(f"Could not read {edf_path}: {e}")
        continue

    # From here on the recording is open: `finally` closes it on every path,
    # including the `continue` below and unexpected errors.
    try:
        # Parse the .rec and build the annotation lists inside one guarded
        # step, so a malformed row skips this file instead of aborting the
        # whole export loop.
        try:
            events = np.genfromtxt(rec_path, delimiter=",")
            if events.ndim == 1 and events.size == 4:
                events = events.reshape(1, 4)

            # One annotation per .rec row. Rows are per channel and the
            # channel number is not carried over, so an event marked on
            # several channels yields several identical annotations.
            onsets, durations, descriptions = [], [], []
            for chan, start, end, label in events:
                onsets.append(start)
                durations.append(end - start)
                descriptions.append(label_map.get(int(label), str(int(label))))
        except Exception as e:
            print(f"Could not read {rec_path}: {e}")
            continue

        # Add annotations to raw
        annots = Annotations(onset=onsets, duration=durations, description=descriptions)
        raw.set_annotations(annots)

        # Write new EDF with annotations
        out_path = edf_path.with_name(edf_path.stem + "_annot.edf")
        try:
            raw.export(out_path, fmt='edf', overwrite=True)
            print(f"Wrote: {out_path}")
        except Exception as e:
            print(f"Could not write {out_path}: {e}")
    finally:
        raw.close()