In [73]:
# dataset.py
import numpy as np
import pandas as pd
from pathlib import Path
from functools import lru_cache
from typing import Sequence, Optional, Tuple, Union, List, Dict

import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split


class SleepEpochDataset(Dataset):
    """
    A PyTorch Dataset that delivers 30-second sleep epochs together
    with optional patient-level labels.

    Parameters
    ----------
    epoch_df : pd.DataFrame
        Epoch-level table. Must contain at least:
            - "nsrrid"       : patient ID
            - "epoch_id"     : 1-based epoch index
            - "path_head"    : prefix to the signal file on disk
    patient_df : pd.DataFrame | str | Path
        Patient-level table or the CSV file path; must contain
        "nsrrid" plus any downstream label columns.
    split : {"train", "val", "test"}
        Which split to use. Partitioning is done by patient ID
        so all epochs from the same patient stay in the same split.
    target_cols : str | Sequence[str] | None
        The patient-level label column(s) to return.  None = no labels
        (e.g. for self-supervised pre-training).
    test_size, val_size : float
        Fractions for patient-level train/val/test split.
    random_state : int
        Seed for deterministic splitting.
    sample_rate : int
        Sampling rate (Hz) of the pre-processed signal.
    cache_size : int
        LRU cache size for whole-night signals to avoid reloading.
    transform : callable | None
        Optional transform / augmentation applied to the epoch tensor.
    split_ids : dict | None
        Pre-defined {"train": [...], "val": [...], "test": [...]} lists.
        If supplied, overrides the random split.
    """

    def __init__(
        self,
        epoch_df: pd.DataFrame,
        patient_df: Union[pd.DataFrame, str, Path],
        split: str = "train",
        *,
        target_cols: Optional[Union[str, Sequence[str]]] = None,
        train_edf_cols = None,
        test_size: float = 0.15,
        val_size: float = 0.15,
        random_state: int = 1337,
        sample_rate: int = 128,
        cache_size: int = 8,
        transform=None,
        split_ids: Optional[Dict[str, Sequence[str]]] = None,
    ):
        assert split in {"train", "val", "test"}
        self.transform = transform
        self.sample_rate = sample_rate
        self.target_cols = (
            [target_cols] if isinstance(target_cols, str) else target_cols
        )

        # ───── patient-level DataFrame ─────
        if isinstance(patient_df, (str, Path)):
            patient_df = pd.read_csv(patient_df)
        self.patient_df = patient_df.set_index("nsrrid")

        # ───── create / use patient splits ─────
        if split_ids is None:
            ids = self.patient_df.index.unique().tolist()

            train_ids, temp_ids = train_test_split(
                ids,
                test_size=(val_size + test_size),
                random_state=random_state,
            )
            rel_val = val_size / (val_size + test_size)
            val_ids, test_ids = train_test_split(
                temp_ids,
                test_size=1.0 - rel_val,
                random_state=random_state,
            )
            split_ids = {
                "train": train_ids,
                "val": val_ids,
                "test": test_ids,
            }

        self.split_ids = split_ids
        self.epoch_df = (
            epoch_df[epoch_df["nsrrid"].isin(split_ids[split])]
            .reset_index(drop=True)
        )

        # ───── LRU cache for whole-night signals ─────
        self._load_patient_array = lru_cache(maxsize=cache_size)(
            self._load_patient_array
        )
        self.train_edf_cols = train_edf_cols

    # ───────── Dataset API ─────────

    def __len__(self) -> int:
        return len(self.epoch_df)

    def __getitem__(self, idx: int):
        row = self.epoch_df.iloc[idx]
        nsrrid = row["nsrrid"]
        epoch_id = int(row["epoch_id"])

        # 1) load full-night signal and slice to this epoch
        t10 = time.perf_counter() 
        
        dfs = []
        for col_name in self.train_edf_cols:
            temp_sig = self._load_patient_array(nsrrid, row["path_head"], col_name = col_name)
            dfs.append(temp_sig)
        full_sig = pd.concat(dfs, axis = 1)
        t11 = time.perf_counter() 
        print(f"time of get channel data {idx} row: {t11 - t10:.3f} s")
        
        t20 = time.perf_counter() 
        start_sec = (epoch_id - 1) * 30
        end_sec = epoch_id * 30
        epoch_sig = full_sig.loc[start_sec: end_sec]
        epoch_sig = epoch_sig.iloc[:-1] # remove the last point
        
        x = torch.tensor(epoch_sig.values, dtype=torch.float32)
        
        # 1.5) can add other transformation here
        if self.transform:
            x = self.transform(x)
        t21 = time.perf_counter() 
        print(f"time of truncate and transform data {idx} row: {t21 - t20:.3f} s")
        
        # 2) add patient-level label(s) if requested
        if self.target_cols:
            y = torch.tensor(
                self.patient_df.loc[nsrrid, self.target_cols].values.astype(float),
                dtype=torch.float32,
            )
            return x, y
        else:
            return x


    def _build_signal_path(self, path_head: str, col_name = 'ECG') -> Path:
        """
        Convert path_head to the actual signal file path.
        Default: '<path_head>_data.npz' with key 'signal'.
        Adjust if your filenames / keys differ.
        """
        return Path(path_head + f"_{col_name}.npz")

    def _load_patient_array(self, nsrrid: str, path_head: str, col_name = 'ECG') -> np.ndarray:
        """
        Load the full-night signal into a NumPy array.
        If you have multiple channels, return shape (C, T).
        """
        fp = self._build_signal_path(path_head, col_name)
        if not fp.is_file():
            raise FileNotFoundError(f"Signal file missing: {fp}")
        with np.load(fp, allow_pickle=False) as npz:
            data = npz['values']
            index = npz['index']

            df_stg = pd.DataFrame(
                data,
                columns=[col_name]
            )
            df_stg.insert(0, "sec", index)
            
            sig = df_stg.set_index("sec")           
            
        return sig.astype(np.float32)


In [74]:
import pandas as pd
from torch.utils.data import DataLoader

epoch_df = pd.read_csv('/scratch/besp/shared_data/df_epoch_level_all.csv')
patient_df = pd.read_csv('/u/ztshuai/ondemand/postprocess/example_patient_level_master.csv')
TRAIN_EDF_COLS = [ECG]
train_set = SleepEpochDataset(
    epoch_df, patient_df, split="train",
    train_edf_cols = TRAIN_EDF_COLS,
    target_cols=["nsrr_bmi"],     # or None
    sample_rate=128,
)

In [75]:
train_loader = DataLoader(
    train_set,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)


import time 
it = iter(train_loader) 
t0 = time.perf_counter() 
batch = next(it)
t1 = time.perf_counter() 
print(f"1st-batch load: {t1 - t0:.3f} s")

time of get channel data 3055559 row: 0.799 s
time of truncate and ransform data 3055559 row: 0.005 s
time of get channel data 1468497 row: 0.879 s
time of get channel data 727750 row: 0.926 s
time of truncate and ransform data 1468497 row: 0.009 stime of truncate and ransform data 727750 row: 0.005 s

time of get channel data 522362 row: 0.999 s
time of truncate and ransform data 522362 row: 0.005 s
time of get channel data 1451926 row: 0.139 s
time of truncate and ransform data 1451926 row: 0.003 s
time of get channel data 908522 row: 0.333 s
time of truncate and ransform data 908522 row: 0.005 s
time of get channel data 1441013 row: 0.490 s
time of truncate and ransform data 1441013 row: 0.004 s
time of get channel data 3963005 row: 0.512 s
time of truncate and ransform data 3963005 row: 0.003 s
time of get channel data 2949548 row: 0.404 s
time of get channel data 1146600 row: 0.306 s
time of truncate and ransform data 1146600 row: 0.003 s
time of truncate and ransform data 2949548

In [68]:
x, y = batch
print(x.shape)
print(y.shape)

torch.Size([4, 3840, 1])
torch.Size([4, 1])


In [71]:
print(epoch_df.shape)

(7568873, 18)


In [76]:
import time, numpy as np, torch

x0 = x[0].squeeze(-1).cpu().numpy()          # (3840,)  or (3840,1) → (3840,)
np.savez("test.npz", signal=x0)
print("saved:", "test.npz", x0.shape)


t0 = time.perf_counter()
with np.load("test.npz") as npz:
    sig = npz["signal"]
t1 = time.perf_counter()

print(f"load time: {(t1 - t0)*1e3:.2f} ms  |  shape: {sig.shape}")


saved: test.npz (3840,)
load time: 1.66 ms  |  shape: (3840,)
