Merge pull request #110 from wolny/fix_lazy_loader
Fix lazy HDF5 loader
wolny committed Feb 18, 2024
2 parents 4355dff + df8cf5c commit cca661e
Showing 3 changed files with 26 additions and 56 deletions.
9 changes: 4 additions & 5 deletions .github/workflows/conda-build.yml
@@ -20,15 +20,14 @@ jobs:
           channel-priority: false
       - shell: bash -l {0}
         run: conda info --envs
-      - name: Build pytorch-3dunet using Boa
+      - name: Build pytorch-3dunet
         shell: bash -l {0}
         run: |
-          conda install --yes -c conda-forge mamba
-          mamba install -q boa
-          conda mambabuild -c pytorch -c nvidia -c conda-forge conda-recipe
+          conda install -q conda-build
+          conda build -c pytorch -c nvidia -c conda-forge conda-recipe
       - name: Create pytorch3dunet env
         run: |
-          mamba create -n pytorch3dunet -c pytorch -c nvidia -c conda-forge pytorch-3dunet pytest
+          conda create -n pytorch3dunet -c pytorch -c nvidia -c conda-forge pytorch-3dunet pytest
       - name: Run pytest
         shell: bash -l {0}
         run: |
63 changes: 16 additions & 47 deletions pytorch3dunet/datasets/hdf5.py
@@ -34,9 +34,9 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
         self.phase = phase
         self.file_path = file_path
 
-        input_file = self.create_h5_file(file_path)
+        input_file = h5py.File(file_path, 'r')
 
-        self.raw = self.load_dataset(input_file, raw_internal_path)
+        self.raw = self._load_dataset(input_file, raw_internal_path)
 
         stats = calculate_stats(self.raw, global_normalization)
 
@@ -46,11 +46,11 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
         if phase != 'test':
             # create label/weight transform only in train/val phase
             self.label_transform = self.transformer.label_transform()
-            self.label = self.load_dataset(input_file, label_internal_path)
+            self.label = self._load_dataset(input_file, label_internal_path)
 
             if weight_internal_path is not None:
                 # look for the weight map in the raw file
-                self.weight_map = self.load_dataset(input_file, weight_internal_path)
+                self.weight_map = self._load_dataset(input_file, weight_internal_path)
                 self.weight_transform = self.transformer.weight_transform()
             else:
                 self.weight_map = None
@@ -70,10 +70,12 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
         self.patch_count = len(self.raw_slices)
         logger.info(f'Number of patches: {self.patch_count}')
 
-    @staticmethod
-    def load_dataset(input_file, internal_path):
+    def load_dataset(self, input_file, internal_path):
+        raise NotImplementedError
+
+    def _load_dataset(self, input_file, internal_path):
         assert internal_path in input_file, f"Internal path: {internal_path} not found in the H5 file"
-        ds = input_file[internal_path][:]
+        ds = self.load_dataset(input_file, internal_path)
         assert ds.ndim in [3, 4], \
             f"Invalid dataset dimension: {ds.ndim}. Supported dataset formats: (C, Z, Y, X) or (Z, Y, X)"
         return ds
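The hunk above turns dataset loading into a small template method: `_load_dataset` in the base class checks that the internal path exists and that the volume is 3D or 4D, while the `load_dataset` hook decides how the data is actually read. A minimal sketch of why the same assertions cover both the eager and the lazy subclass below (the file name and shape are invented for illustration); `numpy.ndarray` and `h5py.Dataset` both expose `ndim`, so the check holds whether the hook returns an in-memory copy or a handle into the file:

import h5py
import numpy as np

# throwaway (Z, Y, X) volume just to probe the two return types
with h5py.File('example.h5', 'w') as f:
    f.create_dataset('raw', data=np.zeros((8, 64, 64), dtype='float32'))

with h5py.File('example.h5', 'r') as f:
    assert 'raw' in f                   # same membership test as _load_dataset
    assert f['raw'][:].ndim in [3, 4]   # eager: numpy.ndarray copied into memory
    assert f['raw'].ndim in [3, 4]      # lazy: h5py.Dataset, nothing read yet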
@@ -106,10 +108,6 @@ def __getitem__(self, idx):
     def __len__(self):
         return self.patch_count
 
-    @staticmethod
-    def create_h5_file(file_path):
-        raise NotImplementedError
-
     @staticmethod
     def _check_volume_sizes(raw, label):
         def _volume_shape(volume):
@@ -182,9 +180,9 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config,
                          label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
                          global_normalization=global_normalization)
 
-    @staticmethod
-    def create_h5_file(file_path):
-        return h5py.File(file_path, 'r')
+    def load_dataset(self, input_file, internal_path):
+        # load the dataset from the H5 file into memory
+        return input_file[internal_path][:]
 
 
 class LazyHDF5Dataset(AbstractHDF5Dataset):
@@ -198,37 +196,8 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config,
                          label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
                          global_normalization=global_normalization)
 
-        logger.info("Using modified HDF5Dataset!")
+        logger.info("Using LazyHDF5Dataset")
 
-    @staticmethod
-    def create_h5_file(file_path):
-        return LazyHDF5File(file_path)
-
-
-class LazyHDF5File:
-    """Implementation of the LazyHDF5File class for the LazyHDF5Dataset."""
-
-    def __init__(self, path, internal_path=None):
-        self.path = path
-        self.internal_path = internal_path
-        if self.internal_path:
-            with h5py.File(self.path, "r") as f:
-                self.ndim = f[self.internal_path].ndim
-                self.shape = f[self.internal_path].shape
-
-    def ravel(self):
-        with h5py.File(self.path, "r") as f:
-            data = f[self.internal_path][:].ravel()
-        return data
-
-    def __getitem__(self, arg):
-        if isinstance(arg, str) and not self.internal_path:
-            return LazyHDF5File(self.path, arg)
-
-        if arg == Ellipsis:
-            return LazyHDF5File(self.path, self.internal_path)
-
-        with h5py.File(self.path, "r") as f:
-            data = f[self.internal_path][arg]
-
-        return data
+    def load_dataset(self, input_file, internal_path):
+        # load the dataset from the H5 file lazily
+        return input_file[internal_path]
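Net effect of the two overrides: StandardHDF5Dataset copies the whole volume into memory when the dataset is built, while LazyHDF5Dataset keeps using the h5py handle that `__init__` now opens directly with `h5py.File(file_path, 'r')`, so only the voxels a patch slice touches are read from disk. The removed LazyHDF5File wrapper, which reopened the file on every access, is no longer involved. A rough usage sketch, assuming a hypothetical 'example.h5' with a 'raw' volume rather than anything from the repository:

import h5py

f = h5py.File('example.h5', 'r')

eager = f['raw'][:]              # StandardHDF5Dataset-style: full numpy array in RAM
lazy = f['raw']                  # LazyHDF5Dataset-style: h5py.Dataset, nothing read yet

patch = lazy[0:4, 0:32, 0:32]    # only this (Z, Y, X) block is read from disk
print(type(eager).__name__, type(lazy).__name__, patch.shape)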
10 changes: 6 additions & 4 deletions pytorch3dunet/datasets/utils.py
@@ -134,7 +134,7 @@ class FilterSliceBuilder(SliceBuilder):
     Filter patches containing more than `1 - threshold` of ignore_index label
     """
 
-    def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, ignore_index=(0,),
+    def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, ignore_index=None,
                  threshold=0.6, slack_acceptance=0.01, **kwargs):
         super().__init__(raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, **kwargs)
         if label_dataset is None:
@@ -144,15 +144,17 @@ def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stri
 
         def ignore_predicate(raw_label_idx):
             label_idx = raw_label_idx[1]
-            patch = np.copy(label_dataset[label_idx])
-            for ii in ignore_index:
-                patch[patch == ii] = 0
+            patch = label_dataset[label_idx]
+            if ignore_index is not None:
+                patch = np.copy(patch)
+                patch[patch == ignore_index] = 0
             non_ignore_counts = np.count_nonzero(patch != 0)
             non_ignore_counts = non_ignore_counts / patch.size
             return non_ignore_counts > threshold or rand_state.rand() < slack_acceptance
 
         zipped_slices = zip(self.raw_slices, self.label_slices)
         # ignore slices containing too much ignore_index
+        logger.info(f'Filtering slices...')
         filtered_slices = list(filter(ignore_predicate, zipped_slices))
         # unzip and save slices
         raw_slices, label_slices = zip(*filtered_slices)
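In the updated FilterSliceBuilder, ignore_index is a single optional label (default None) rather than a tuple, and a label patch is kept when its fraction of non-zero, non-ignored voxels exceeds `threshold` (or, with probability `slack_acceptance`, at random). A standalone sketch of that predicate on toy data; the array below is illustrative, the 0.6 threshold comes from the signature above, and the random slack is left out:

import numpy as np

def keeps_patch(patch, ignore_index=None, threshold=0.6):
    # mirrors the updated ignore_predicate, minus the random slack acceptance
    if ignore_index is not None:
        patch = np.copy(patch)
        patch[patch == ignore_index] = 0
    non_ignore_fraction = np.count_nonzero(patch != 0) / patch.size
    return non_ignore_fraction > threshold

labels = np.array([[0, 1, 1, 2],
                   [2, 2, 1, 0]])
print(keeps_patch(labels))                  # 6/8 = 0.75 > 0.6 -> True
print(keeps_patch(labels, ignore_index=2))  # 3/8 = 0.375 > 0.6 -> False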
