In [1]:
import h5py
filename = '/pscratch/sd/y/ypincha/mhd/3d-h5-sphere/torus.mhd_w_bcc.00181.athdf-spherical-64-100.0.h5'

with h5py.File(filename, "r") as f:
    print(f.keys())
    print(f[list(f.keys())[0]])
    for key in f.keys():
        if f[key].shape != (64, 64, 64):
            print(f'dimension {key}', f[key].shape)


<KeysViewHDF5 ['bcc1', 'bcc2', 'bcc3', 'beta', 'dens', 'edot_kinematic', 'edot_mag', 'edot_potential', 'edot_thermal', 'eint', 'mdot', 'mdotin', 'mdotout', 'phi', 'r', 'temp', 'theta', 'velx', 'vely', 'velz']>
<HDF5 dataset "bcc1": shape (64, 64, 64), type "<f4">
dimension phi (64,)
dimension r (64,)
dimension theta (64,)


In [2]:
f = h5py.File(filename, "r")
f['bcc1'][:].shape

(64, 64, 64)

In [None]:
from torch.utils.data import Dataset
import glob
import torch
import h5py

class SphericalDataset(Dataset):
    FIELDS = ['bcc1', 'bcc2', 'bcc3', 'dens', 'eint', 'velx', 'vely', 'velz']

    def __init__(self, pscratch_path, mode='train', train_ratio=0.8):

        self.pscratch_path = pscratch_path
        self.mode = mode
        self.train_ratio = train_ratio

        pattern = f"{self.pscratch_path}/mhd/3d-h5-sphere/*.h5"
        files = sorted(glob.glob(pattern))
        if not files:
            raise FileNotFoundError(f"no .h5 files found under {pattern}")

        split_idx = int(len(files) * train_ratio)
        if mode == 'train':
            self.files = files[:split_idx]
        else:
            self.files = files[split_idx:]

    def __len__(self):
        return len(self.files)

    def stack_tensor(self, fname):
        with h5py.File(fname, 'r') as f:
            arrays = [
                torch.tensor(f[field][:], dtype=torch.float32)
                for field in self.FIELDS
            ]
        return torch.stack(arrays).reshape(8, 64, 64, 64)

    def __getitem__(self, idx):
        return self.stack_tensor(self.files[idx])


In [11]:
dataset = SphericalDataset('/pscratch/sd/y/ypincha/', 'train', 0.2)
dataset.sorted_files()

['/pscratch/sd/y/ypincha//mhd/3d-h5-sphere/torus.mhd_w_bcc.00181.athdf-spherical-64-100.0.h5',
 '/pscratch/sd/y/ypincha//mhd/3d-h5-sphere/torus.mhd_w_bcc.00182.athdf-spherical-64-100.0.h5',
 '/pscratch/sd/y/ypincha//mhd/3d-h5-sphere/torus.mhd_w_bcc.00183.athdf-spherical-64-100.0.h5',
 '/pscratch/sd/y/ypincha//mhd/3d-h5-sphere/torus.mhd_w_bcc.00184.athdf-spherical-64-100.0.h5',
 '/pscratch/sd/y/ypincha//mhd/3d-h5-sphere/torus.mhd_w_bcc.00185.athdf-spherical-64-100.0.h5',
 '/pscratch/sd/y/ypincha//mhd/3d-h5-sphere/torus.mhd_w_bcc.00186.athdf-spherical-64-100.0.h5',
 '/pscratch/sd/y/ypincha//mhd/3d-h5-sphere/torus.mhd_w_bcc.00187.athdf-spherical-64-100.0.h5',
 '/pscratch/sd/y/ypincha//mhd/3d-h5-sphere/torus.mhd_w_bcc.00188.athdf-spherical-64-100.0.h5',
 '/pscratch/sd/y/ypincha//mhd/3d-h5-sphere/torus.mhd_w_bcc.00189.athdf-spherical-64-100.0.h5',
 '/pscratch/sd/y/ypincha//mhd/3d-h5-sphere/torus.mhd_w_bcc.00190.athdf-spherical-64-100.0.h5',
 '/pscratch/sd/y/ypincha//mhd/3d-h5-sphere/torus.m

In [12]:
dataset.sorted_files()[0]

'/pscratch/sd/y/ypincha//mhd/3d-h5-sphere/torus.mhd_w_bcc.00181.athdf-spherical-64-100.0.h5'

In [13]:
stacked = dataset.stack_tensor('/pscratch/sd/y/ypincha//mhd/3d-h5-sphere/torus.mhd_w_bcc.00181.athdf-spherical-64-100.0.h5')

In [14]:
stacked.shape

torch.Size([8, 64, 64, 64])

In [15]:
dataset.__getitem__(4).shape

torch.Size([8, 64, 64, 64])

In [16]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

In [21]:
for batch in dataloader:
    print(batch[0].shape)

torch.Size([8, 64, 64, 64])
torch.Size([8, 64, 64, 64])
torch.Size([8, 64, 64, 64])
torch.Size([8, 64, 64, 64])


KeyboardInterrupt: 