In [1]:
# datamodule.py
from pathlib import Path
from typing import List, Sequence, Optional, Dict, Union

import pandas as pd          # 只有当您走「多验证集」分支时才需要
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from melp.datasets.pretrain_dataset import SleepEpochDataset    # ↖ 路径视情况调整

class SleepDataModule(pl.LightningDataModule):
    """Lightning data module that serves ``SleepEpochDataset`` splits.

    Lightning drives the module as ``prepare_data() -> setup(stage) ->
    *_dataloader()``; the datasets are therefore created lazily in
    :meth:`setup` rather than in ``__init__``.

    Args:
        csv_dir: Directory forwarded to ``SleepEpochDataset`` as ``csv_dir``.
        is_pretrain: When equal to ``1`` (``True`` also compares equal), the
            training set uses ``split="pretrain"`` and no ``target_cols`` are
            passed for it; otherwise ``split="train"`` with ``target_cols``.
        val_dataset_list: Optional list of validation-set names. Only its
            length is used: one "val" dataset/loader is built per entry
            (a single one when the list is empty or ``None``).
        batch_size: Batch size used by every dataloader.
        num_workers: Worker count used by every dataloader. Persistent
            workers are enabled whenever this is greater than zero.
        target_cols: Label column(s) forwarded to the dataset for the
            train (non-pretrain), val and test splits.
        train_edf_cols: Channel names forwarded to the dataset.
        transforms: Transform applied to the training set only; validation
            and test sets are built with ``transform=None``.
        n_views: Stored on the instance; not consumed anywhere in this class.
        cache_size: Forwarded to the dataset.
        sample_rate: Forwarded to the dataset (presumably Hz - TODO confirm).
        window_size: Forwarded to the dataset (presumably seconds - TODO confirm).
        pin_memory: Forwarded to every ``DataLoader``.
    """

    def __init__(
        self,
        csv_dir: str | Path,
        *,
        is_pretrain,
        val_dataset_list: Optional[List[str]] = None,
        batch_size: int = 128,
        num_workers: int = 4,
        target_cols: Optional[Union[str, Sequence[str]]] = None,
        train_edf_cols: Sequence[str] | None,  # forwarded to the Dataset
        transforms=None,
        n_views: int = 1,
        cache_size: int = 8,                   # forwarded to the Dataset
        sample_rate: int = 128,
        window_size: int = 30,
        pin_memory: bool = True,
    ):
        super().__init__()
        # `transforms` is kept out of hparams and stored as a plain attribute
        # instead.
        self.save_hyperparameters(ignore=["transforms"])

        self.csv_dir = csv_dir
        self.transforms = transforms
        self.n_views = n_views
        self.pin_memory = pin_memory
        self.is_pretrain = is_pretrain
        self.target_cols = target_cols
        # Populated by setup(); None means "not built yet".
        self._train_set = self._val_sets = self._test_set = None

    # ---------- 1. Dataset construction ----------
    def _build_dataset(self, split: str, *, transform=None, with_targets: bool = True):
        """Create one ``SleepEpochDataset`` for ``split`` from the stored settings."""
        kwargs = dict(
            csv_dir=self.csv_dir,
            split=split,
            train_edf_cols=self.hparams.train_edf_cols,
            transform=transform,
            sample_rate=self.hparams.sample_rate,
            window_size=self.hparams.window_size,
            cache_size=self.hparams.cache_size,
        )
        # The pretrain split is built without `target_cols`, so the dataset's
        # own default applies there.
        if with_targets:
            kwargs["target_cols"] = self.target_cols
        return SleepEpochDataset(**kwargs)

    # ---------- 2. setup ----------
    def setup(self, stage: str | None = None):
        """Build the datasets required for ``stage``.

        ``"fit"`` builds train + val, ``"validate"`` builds val only,
        ``"test"`` builds test only, and ``None`` builds everything.
        """
        if stage in ("fit", None):
            if self.is_pretrain == 1:
                self._train_set = self._build_dataset(
                    "pretrain", transform=self.transforms, with_targets=False
                )
            else:
                self._train_set = self._build_dataset(
                    "train", transform=self.transforms
                )

        # "validate" is included so that `trainer.validate(...)` finds the
        # validation datasets; previously they were only built for "fit".
        if stage in ("fit", "validate", None):
            names = self.hparams.val_dataset_list   # e.g. ["ptbxl_super_class", ...]
            # NOTE(review): the names themselves are never passed to the
            # dataset, so every entry yields an identical "val" dataset. The
            # count is preserved because it determines how many validation
            # dataloaders (dataloader_idx values) the LightningModule sees.
            n_val = len(names) if names else 1
            # Validation is built without augmentation (transform=None).
            self._val_sets = [
                self._build_dataset("val", transform=None) for _ in range(n_val)
            ]

        if stage in ("test", None):
            self._test_set = self._build_dataset("test", transform=None)

    # ---------- 3. DataLoader ----------
    @staticmethod
    def _require_built(dataset, what: str, stage: str):
        """Return ``dataset`` or raise if ``setup()`` has not created it yet."""
        if dataset is None:
            raise RuntimeError(
                f"{what} dataset is not initialised; call setup(stage={stage!r}) "
                f"(or setup()) before requesting its dataloader."
            )
        return dataset

    def _make_loader(self, dataset, *, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a ``DataLoader`` using the shared loader settings."""
        n_workers = self.hparams.num_workers
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            num_workers=n_workers,
            pin_memory=self.pin_memory,
            # persistent_workers is only valid when worker processes exist.
            persistent_workers=n_workers > 0,
        )

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        train_set = self._require_built(self._train_set, "train", "fit")
        return self._make_loader(train_set, shuffle=True)

    def val_dataloader(self):
        """Return a list with one non-shuffled dataloader per validation dataset."""
        val_sets = self._require_built(self._val_sets, "validation", "fit")
        return [self._make_loader(ds, shuffle=False) for ds in val_sets]

    def test_dataloader(self):
        """Return the non-shuffled test dataloader."""
        test_set = self._require_built(self._test_set, "test", "test")
        return self._make_loader(test_set, shuffle=False)


In [2]:
# Smoke-test configuration for the data module.
CSV_DIR = "/scratch/besp/shared_data/sleep_data_split_test"
EDF_CHANNELS = ["ECG", "EEG_C3_A2"]   # channels to load

dm = SleepDataModule(
    csv_dir=CSV_DIR,
    is_pretrain=1,                    # training set uses the "pretrain" split
    train_edf_cols=EDF_CHANNELS,
    batch_size=4,
    num_workers=2,
)

# Build the train/val datasets, then fetch the training loader.
dm.setup(stage="fit")
train_loader = dm.train_dataloader()


In [4]:
# Pull a single batch from the training loader as a smoke test.
batch = next(iter(train_loader))

# The batch is accessed by key below rather than unpacked as an (x, y) tuple;
# the tuple-style inspection is kept here for reference only.
# x, y = batch
# print("x", x.shape, x.dtype)   # e.g. torch.Size([4, 3840, 3]) torch.float32
# print("y", y.shape, y.dtype)   # prints the label if the Dataset returns one

psg = batch["psg"]
print(psg.shape)

torch.Size([4, 3840, 2])
