In [None]:
import os
import h5py
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import lightning as L
import pickle


class HDF5Dataset(Dataset):
    """Map-style dataset over an already-loaded, indexable collection.

    Args:
        data: Any object supporting ``len()`` and integer indexing; each
            element is treated as one sample.
        transform: Optional callable applied to a sample before it is
            returned. Skipped when falsy.
    """

    def __init__(self, data, transform=None):
        self.transform = transform
        self.data = data

    def __len__(self):
        """Return the number of samples held by the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, transformed when a transform is set."""
        item = self.data[idx]
        return self.transform(item) if self.transform else item


class HDF5DataModule(L.LightningDataModule):
    """Lightning data module that serves the top-level datasets of HDF5 files.

    ``prepare_data`` reads every top-level dataset of every ``.hdf5`` file in
    ``data_dir`` into memory and caches the resulting list as a pickle file.
    ``setup`` loads that cache and builds the datasets for the requested stage.

    Args:
        data_dir: Directory that is scanned (non-recursively) for ``.hdf5`` files.
        batch_size: Batch size used by every dataloader.
        transform: Optional callable applied to each sample.
        processed_data_path: Location of the pickle cache.
    """

    def __init__(self, data_dir: str, batch_size: int = 32, transform=None, processed_data_path: str = "processed_data.pkl"):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transform
        self.processed_data_path = processed_data_path

    def prepare_data(self):
        """Build the pickle cache from the HDF5 files unless it already exists.

        Raises:
            ValueError: If no top-level dataset is found in ``data_dir``.
        """
        if os.path.exists(self.processed_data_path):
            return

        print("Processing HDF5 files...")
        # os.listdir returns entries in arbitrary order; sorting keeps the sample
        # order, and therefore the seeded train/val split, reproducible.
        hdf5_files = sorted(
            os.path.join(self.data_dir, name)
            for name in os.listdir(self.data_dir)
            if name.endswith('.hdf5')
        )

        all_data = []
        for file in hdf5_files:
            with h5py.File(file, 'r') as h5_file:
                for key in h5_file.keys():
                    # Only top-level datasets are collected; groups are skipped.
                    if isinstance(h5_file[key], h5py.Dataset):
                        all_data.append(h5_file[key][:])  # load the data into memory

        if not all_data:
            # Fail before writing anything: an empty cache would be reused on
            # every later run, even after data files have been added.
            raise ValueError(f"No HDF5 datasets found in {self.data_dir}")

        # Save the processed data to disk. Write to a temporary file first and
        # rename it into place, so an interrupted run never leaves a truncated
        # cache that passes the existence check above.
        tmp_path = self.processed_data_path + ".tmp"
        with open(tmp_path, 'wb') as cache_file:
            pickle.dump(all_data, cache_file)
        os.replace(tmp_path, self.processed_data_path)
        print(f"Data processed and saved to {self.processed_data_path}")

    def setup(self, stage: str):
        """Load the cached data and create the datasets for ``stage``.

        Args:
            stage: ``"fit"`` or ``"validate"`` create the 80/20 train/validation
                split; ``"test"`` and ``"predict"`` expose the full dataset.
                Any other value only loads the cache.
        """
        # SECURITY: pickle.load executes arbitrary code from the file. The cache
        # must only ever be a file written by prepare_data, never external input.
        with open(self.processed_data_path, 'rb') as cache_file:
            all_data = pickle.load(cache_file)

        dataset = HDF5Dataset(all_data, transform=self.transform)

        # Lightning passes "validate" when validation runs without fitting; the
        # validation split must exist in that case as well.
        if stage in ("fit", "validate"):
            train_size = int(len(dataset) * 0.8)
            val_size = len(dataset) - train_size
            # Fixed seed so every call produces the same split.
            self.train_data, self.val_data = random_split(
                dataset, [train_size, val_size],
                generator=torch.Generator().manual_seed(42)
            )
        elif stage == "test":
            # NOTE(review): the test set is the full dataset, so it overlaps
            # with the training split -- confirm this is intended.
            self.test_data = dataset
        elif stage == "predict":
            self.predict_data = dataset

    def train_dataloader(self):
        """Return a shuffling dataloader over the training split."""
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return a dataloader over the validation split."""
        return DataLoader(self.val_data, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return a dataloader over the test dataset."""
        return DataLoader(self.test_data, batch_size=self.batch_size)

    def predict_dataloader(self):
        """Return a dataloader over the prediction dataset."""
        return DataLoader(self.predict_data, batch_size=self.batch_size)


In [None]:
# Smoke test: build the data module and iterate over every split once.
# NOTE(review): batch_size=1 -- the printed batches appear to have differing
# ranks, which default collation could not stack into larger batches; confirm
# before raising the batch size.
dm = HDF5DataModule(data_dir="/data/home/ty/embody", batch_size=1)

# Builds the pickle cache on the first run; later runs reuse the cached file.
dm.prepare_data()

# The "fit" stage creates both the training and the validation split.
dm.setup(stage="fit")

for batch_number, batch in enumerate(dm.train_dataloader(), start=1):
    print(f"Train Batch {batch_number}: {batch}")

# No separate setup call for validation: the split already exists after
# setup("fit"), and "val" is not a stage name that setup() recognises, so the
# former call only re-read the cache from disk.
for batch_number, batch in enumerate(dm.val_dataloader(), start=1):
    print(f"Validation Batch {batch_number}: {batch}")

dm.teardown(stage="fit")

dm.setup(stage="test")
for batch_number, batch in enumerate(dm.test_dataloader(), start=1):
    print(f"Test Batch {batch_number}: {batch}")

dm.teardown(stage="test")

Train Batch 1: tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 1.0996e-03, -8.4078e-04, -1.2174e-03,  ..., -7.7706e-02,
           0.0000e+00,  0.0000e+00],
         [ 1.0729e-03, -6.8784e-05, -3.4416e-04,  ..., -4.7019e-02,
           0.0000e+00,  0.0000e+00],
         [ 1.9264e-03,  2.6798e-04,  1.3959e-04,  ..., -8.7400e-02,
           0.0000e+00,  0.0000e+00]]], dtype=torch.float64)
Train Batch 2: tensor([[[[[0.],
           [0.],
           [0.],
           ...,
           [0.],
           [0.],
           [0.]],

          [[0.],
           [0.],
           [0.],
           ...,
           [0.],
           [0.],
           [0.]],

          [[0.],
           [0.],
           [0.],
           .