From TUEV Readme:

lab files use 4 letter codes:

- spsw: spike and slow wave  
- gped: generalized periodic epileptiform discharge  
- pled: periodic lateralized epileptiform discharge  
- eyem: eye movement  
- artf: artifact  
- bckg: background

In the format:
`117100000 117200000 eyem`

rec files use numeric codes:

1: spsw  
2: gped  
3: pled  
4: eyem  
5: artf  
6: bckg

in the format `channel,start,end,label` (start and end in seconds), e.g. `13,90.4,91.4,6`

In [3]:
from pyhealth.datasets import TUEVDataset
# Build the PyHealth TUEV dataset from the EDF directory. The absolute path is
# specific to the machine this notebook was run on.
dataset = TUEVDataset(root="/share/Share/valera/tuev/edf/")
# NOTE(review): presumably these print summary statistics and a description of
# the dataset structure — confirm against the PyHealth documentation.
dataset.stat()
dataset.info()

In [6]:
from pathlib import Path
import numpy as np
import pandas as pd
from collections import defaultdict

# Flatten dataset.patients into a single list with the full path of every
# recording: each record's directory ("load_from_path") joined with its
# file name ("signal_file").
edf_paths = [
    Path(record["load_from_path"]) / record["signal_file"]
    for patient_records in dataset.patients.values()
    for record in patient_records
]

# Gather one tuple per annotated event over all recordings:
#   (rec_path, label, channels, start, end)
# `channels` is the sorted tuple of distinct channel numbers that share the
# same (label, start, end) inside one .rec file.
all_events = []

for edf_path in edf_paths:
    rec_path = edf_path.with_suffix(".rec")
    if rec_path.exists():
        try:
            events = np.genfromtxt(rec_path, delimiter=",")
            if events.shape == (4,):
                # A one-line file parses as a flat vector; make it a 1 x 4 table
                # so the row-wise unpacking below works.
                events = events.reshape(1, 4)
            # (label, start, end) -> every channel number annotated with it
            channels_by_event = defaultdict(list)
            for chan, start, end, label in events:
                event_id = (int(label), float(start), float(end))
                channels_by_event[event_id].append(int(chan))
            all_events.extend(
                (rec_path, event_label, tuple(sorted(set(event_chans))), event_start, event_end)
                for (event_label, event_start, event_end), event_chans in channels_by_event.items()
            )
        except Exception as err:
            # Best effort: report the unreadable file and move on to the next one.
            print(f"Error reading {rec_path}: {err}")

# Order the events by (rec_path, label, channels, start, end). The tuples
# already hold their fields in that order, so plain tuple ordering suffices.
all_events.sort()

# Fuse back-to-back intervals: when an event has the same file, label and
# channel set as the previously kept one and starts where that one ends,
# stretch the kept event's end time instead of adding a new entry.
merged_events = []
for event in all_events:
    previous = merged_events[-1] if merged_events else None
    extends_previous = (
        previous is not None
        and previous[:3] == event[:3]            # same rec_path, label, channels
        and np.isclose(previous[4], event[3])    # previous end == current start
    )
    if extends_previous:
        merged_events[-1] = previous[:4] + event[4:]
    else:
        merged_events.append(event)

# Index the merged events: (rec_path, label, start, end) -> sorted list of
# distinct channel numbers.
# The key does not contain the channel set, so two merged events that cover the
# same interval with the same label in the same file, but on different channel
# sets, map to the same key. Their channels are unioned here instead of letting
# the later event silently overwrite the earlier one. When no such collision
# occurs the stored value is exactly list(chans).
event_channel_counts = defaultdict(list)
for rec_path, label, chans, start, end in merged_events:
    key = (rec_path, label, start, end)
    event_channel_counts[key] = sorted(set(event_channel_counts[key]).union(chans))

In [7]:
from collections import Counter

# event_channel_counts maps (rec_path, label, start, end) -> list of channel
# numbers. Tally how many events exist for every combination of
# (label, number of distinct channels the event appears on).
pair_counter = Counter(
    (event_key[1], len(set(channels)))
    for event_key, channels in event_channel_counts.items()
)

# One row per (label, n_channels) pair, ordered by label, then channel count.
rows = [(*pair, count) for pair, count in pair_counter.items()]
df = pd.DataFrame(rows, columns=["label", "n_channels", "count"])
df = df.sort_values(["label", "n_channels"]).reset_index(drop=True)

display(df)

In [8]:
import matplotlib.pyplot as plt
import numpy as np

# Numeric .rec label -> four-letter event code
label_map = {
    1: "spsw",
    2: "gped",
    3: "pled",
    4: "eyem",
    5: "artf",
    6: "bckg",
}

labels = df['label'].unique()
label_counts = df.groupby('label')['count'].sum()

# Outer ring: one wedge per (label, n_channels) row, captioned with n_channels
outer_sizes = df['count']
outer_n_channels = df['n_channels'].values
outer_labels = [f"{n}" for n in df['n_channels']]

# Inner ring: one wedge per label, captioned "<code> (<total events>)"
inner_sizes = label_counts.values
inner_labels = [
    f"{label_map.get(code, str(code))} ({label_counts[code]})" for code in labels
]

# Colours: every distinct n_channels value gets its own tab20c entry (stepping
# through the 20-colour palette by 3); labels take every second entry.
cmap = plt.get_cmap("tab20c")
unique_n_channels = sorted(df['n_channels'].unique())
nchan_color_map = {
    n: cmap(position * 3 % 20) for position, n in enumerate(unique_n_channels)
}
outer_colors = [nchan_color_map[n] for n in outer_n_channels]
inner_colors = cmap(2 * np.arange(len(labels)))

fig, ax = plt.subplots(figsize=(8, 8))

# Inner ring, drawn without matplotlib's own labels
ring_style = dict(width=0.3, edgecolor='w')
wedges1, _ = ax.pie(
    inner_sizes, labels=None, radius=1,
    colors=inner_colors, wedgeprops=dict(ring_style),
)

# Write each label's caption at radius 0.5 along the bisector of its wedge,
# i.e. inside the hole of the inner ring.
for wedge, caption in zip(wedges1, inner_labels):
    mid_angle = np.deg2rad((wedge.theta2 + wedge.theta1) / 2)
    ax.text(
        0.5 * np.cos(mid_angle), 0.5 * np.sin(mid_angle), caption,
        ha='center', va='center', fontsize=14, weight='bold',
    )

# Outer ring: list the df rows label by label so the outer wedges line up with
# the inner wedge of their label.
outer_order = [
    row_idx
    for code in labels
    for row_idx in df.index[df['label'] == code].tolist()
]
outer_sizes_ordered = df.loc[outer_order, 'count']
outer_labels_ordered = [outer_labels[i] for i in outer_order]
outer_colors_ordered = [outer_colors[i] for i in outer_order]

# Caption only the wedges that cover more than 1.5% of all events; smaller
# wedges get an empty caption to avoid overlapping text.
min_label_fraction = 0.015
total_outer = outer_sizes_ordered.sum()
outer_labels_final = [
    caption if size / total_outer > min_label_fraction else ""
    for caption, size in zip(outer_labels_ordered, outer_sizes_ordered)
]

wedges2, _ = ax.pie(
    outer_sizes_ordered, labels=outer_labels_final, labeldistance=1.08, radius=1.3,
    colors=outer_colors_ordered, wedgeprops=dict(ring_style),
)

ax.set(aspect="equal")
plt.subplots_adjust(top=0.85)  # leave room above the axes for the title
plt.suptitle("Nested Pie Chart: Event Type (inner) and n_channels (outer)", y=0.98, fontsize=16)

In [12]:
import mne
import matplotlib.pyplot as plt
import numpy as np

# Open every recording without preloading its signal and remember the largest
# number of EEG channels found in any single file.
max_channels = 0
for edf_path in edf_paths:
    try:
        raw = mne.io.read_raw_edf(edf_path, verbose=False, preload=False)
        raw.pick("eeg")  # drop everything that is not an EEG channel
        n_ch = len(raw.ch_names)
        max_channels = max(max_channels, n_ch)
        raw.close()
    except Exception as err:
        # Best effort: report the unreadable file and keep scanning.
        print(f"Could not read {edf_path}: {err}")

print("Maximum number of channels:", max_channels)

In [18]:
from tqdm import tqdm

# Plot every SPSW event (label == 1) in a 2-second window centred on the event.
# Each figure holds up to `spikes_per_plot` events: one column per event and one
# row per EEG channel (max_channels rows, so any recording fits).
# Each entry of spsw_events is (rec_path, label, start, end, chans).
spsw_events = [k + (v,) for k, v in event_channel_counts.items() if k[1] == 1]
n_spikes = len(spsw_events)
spikes_per_plot = 8

# Path of the recording whose data/times/channel_names are currently in memory.
# Consecutive events frequently belong to the same file, so the EDF is only
# re-read when the file changes instead of once per event.
loaded_path = None

for plot_idx in tqdm(range(0, n_spikes, spikes_per_plot)):
    # The grid is sized from max_channels (not from whichever recording was
    # opened last), and squeeze=False keeps `axes` 2-D for every grid shape so
    # axes[row, col] is always valid.
    fig, axes = plt.subplots(
        max_channels, spikes_per_plot,
        figsize=(spikes_per_plot * 4, max_channels * 1.5),
        sharey='row', sharex=False, squeeze=False,
    )
    for col, event_key in enumerate(spsw_events[plot_idx:plot_idx + spikes_per_plot]):
        rec_path, label, start, end, chans = event_key
        edf_path = rec_path.with_suffix('.edf')
        if edf_path != loaded_path:
            try:
                raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
                raw.pick("eeg")
                data, times = raw.get_data(return_times=True)
                channel_names = raw.ch_names
                raw.close()
            except Exception as e:
                print(f"Could not load {edf_path}: {e}")
                loaded_path = None  # nothing valid is cached; leave this column empty
                continue
            loaded_path = edf_path

        # 1 second on either side of the event centre, clipped at the start of
        # the recording.
        center = (start + end) / 2
        tmin = max(center - 1, 0)
        tmax = center + 1
        idx_min, idx_max = np.searchsorted(times, [tmin, tmax])

        # Rows beyond the grid cannot be drawn, hence the slice; rows the
        # recording does not have simply stay empty.
        for ch_idx, ch_name in enumerate(channel_names[:max_channels]):
            ax = axes[ch_idx, col]
            ax.plot(times[idx_min:idx_max], data[ch_idx, idx_min:idx_max], color='k', linewidth=0.7)
            # Shade the event interval: red on annotated channels, green elsewhere.
            # NOTE(review): assumes the channel numbers in the .rec file index the
            # picked EEG channels in file order — confirm against the TUEV docs.
            ax.axvspan(start, end, color="red" if ch_idx in chans else "green", alpha=0.3)
            if ch_idx == 0:
                ax.set_title(f"{edf_path.name}\nSPSW {plot_idx + col + 1}", fontsize=10)
            if col == 0:
                ax.set_ylabel(ch_name)
            ax.set_xlim(tmin, tmax)
            ax.grid(True)
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure so memory does not grow with every page

# run locally

In [31]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import mne

# Plot the first channels of one recording with its .rec annotations overlaid.

# Path to the recording; its .rec annotation file is expected next to it, with
# the same file stem.
record_path = Path("../data/aaaaaaar/aaaaaaar_00000001.edf")
rec_path = record_path.with_suffix(".rec")

# Load the signal
raw = mne.io.read_raw_edf(record_path, preload=True)
raw.pick("eeg")  # only keep EEG channels
sfreq = raw.info["sfreq"]
data, times = raw.get_data(return_times=True)

# Save channel names before closing
channel_names = raw.ch_names.copy()
raw.close()

# Load the .rec annotations as rows of (channel, start, end, label).
# reshape(-1, 4) keeps a single-row file 2-D (genfromtxt returns a flat vector
# for it) and turns an empty result into zero rows, so the loop below can
# always unpack four values per row.
events = np.genfromtxt(rec_path, delimiter=",").reshape(-1, 4)

# Plot up to 21 channels and up to the first 10000 samples, clamped to what the
# recording actually contains so short or small files do not raise IndexError.
n_channels_to_plot = min(21, data.shape[0])
end_idx = min(10000, len(times) - 1)
time_end = times[end_idx]

# squeeze=False keeps the axes array indexable even when a single channel is plotted.
fig, axes = plt.subplots(
    n_channels_to_plot, 1,
    figsize=(15, 2.5 * n_channels_to_plot), sharex=True, squeeze=False,
)
axes = axes[:, 0]

for i in range(n_channels_to_plot):
    ax = axes[i]
    ax.plot(times[:end_idx], data[i, :end_idx], label=channel_names[i], linewidth=0.5)

    # Overlay the annotations of the current channel that end inside the
    # plotted time range.
    # NOTE(review): assumes the channel numbers in the .rec file index the
    # picked EEG channels in file order — confirm against the TUEV docs.
    for chan, start, end, label in events:
        if end >= time_end:
            continue
        if int(chan) == i:
            ax.axvspan(start, end, color='red', alpha=0.3)
            # Numeric label, centred on the interval at 80% of the trace maximum
            ax.text((start + end) / 2, np.max(data[i, :end_idx]) * 0.8,
                    str(int(label)), color='black', ha='center', fontsize=8)

    ax.set_ylabel("Amp (V)")
    ax.legend(loc="upper right")
    ax.grid(True)

axes[-1].set_xlabel("Time (s)")
# plt.suptitle(f"EEG signals from {record_path.name} — First {n_channels_to_plot} channels")
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()