In [1]:
import h5py
import os
import torch
import pandas as pd
import numpy as np
from einops import rearrange
from torch.utils.data import DataLoader, TensorDataset, Dataset
from tqdm import tqdm




class AWD_Dataset(torch.utils.data.Dataset):
    def __init__(self, hdf5_file, label_type, transform=None):
        self.csv_file = hdf5_file
        self.label_type = label_type
        self.transform_mask = None
        self.transform = transform
        df = pd.read_csv(self.csv_file)
        self.image_ids = df['image_id'].tolist()
        self.hdf5_file_paths = df.set_index('image_id')['hdf5_file_path'].to_dict()
        self.num_images = len(self.image_ids)

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        hdf5_file_path = self.hdf5_file_paths[image_id]
        image, mask = self.load_and_process(image_id, hdf5_file_path)
        return image, mask

    def __len__(self):
        return len(self.image_ids)

    def load_and_process(self, image_id, hdf5_file_path):
        with h5py.File(hdf5_file_path, 'r', libver='latest', swmr=True) as hdf5:
            image = hdf5['image'][...]
            labels = hdf5['label'][self.label_type][...]

        image = torch.tensor(image).float()
        image = image / 10000

        if self.label_type in ["binary_seg_maps"]:
            mask = torch.tensor(labels)
            mask[mask > 0] = 1
            mask = torch.nn.functional.one_hot(mask, 2)
            mask = mask[:, :, 1].float().unsqueeze(0)  # Add channel dimension
            
        elif self.label_type in ["multi_class_seg_maps"]:
            mask = torch.tensor(labels)
            mask = torch.nn.functional.one_hot(mask, num_classes=4).permute(2, 0, 1).float()

        
        if self.transform:
            image = self.transform(image).squeeze(0)
            mask = self.transform_mask(mask).squeeze(0)


        return image.float(), mask.float()

    def _adjust_shapes(self, image, mask):
        if image.shape[2] != 4 or mask.shape[0] != 1:
            if image.shape[0] == 4:
                image = image.transpose(2, 0, 1)
            elif image.shape[1] == 4:
                image = image.transpose(0, 2, 1)
                if mask.shape[0] == 1:
                    if image.shape[0] != mask.shape[1]:
                        image = image.transpose(1, 0, 2)
            if mask.shape[0] == 1:
                mask = mask.transpose(2, 0, 1)

        if (image.shape[1] != mask.shape[1] or image.shape[2] != mask.shape[2]):
            raise ValueError(f'The shape of object {image.shape} is incorrect')

        return image, mask

In [2]:
train_ds = AWD_Dataset(hdf5_file="train_csv_file.csv", 
                        label_type="binary_seg_maps", 
                        transform=None)

train_loader = DataLoader(train_ds, 
                            batch_size=1,
                            num_workers=8, 
                            pin_memory=True,
                            prefetch_factor=8,
                            persistent_workers=False)


In [3]:
mean = [0.0,0.0,0.0,0.0]
channels_sum = [0.0,0.0,0.0,0.0]
channels_sqrd_sum = [0.0,0.0,0.0,0.0]
num_batches = [0,0,0,0]
std = [0.0,0.0,0.0,0.0]
max_channel = [0,0,0,0]
min_channel = [0,0,0,0]


for image, mask in tqdm(train_loader):       
    for channel in range(4):
        channel_data = image[:, channel, :, :]
        channels_sum[channel] += torch.mean(channel_data)
        channels_sqrd_sum[channel] += torch.mean(channel_data**2)
        num_batches[channel] += 1
        if(torch.max(channel_data).item()>max_channel[channel]):
            max_channel[channel] = torch.max(channel_data).item()
        if(torch.min(channel_data).item()<min_channel[channel]):
            min_channel[channel] = torch.min(channel_data).item()


for channel in range(4):
    mean[channel] = channels_sum[channel].item() / num_batches[channel]
    std[channel] = (channels_sqrd_sum[channel].item() / num_batches[channel] - mean[channel]**2) ** 0.5
        
print(mean)
print(std)
print(max_channel)
print(min_channel)

100%|██████████| 167436/167436 [12:08<00:00, 229.71it/s] 

[0.03978963475689965, 0.06304591754461406, 0.060595331014536895, 0.23040860575085406]
[0.03658732531954657, 0.045488653216840556, 0.05400182458998847, 0.11936749497700323]
[1.9407999515533447, 1.9943000078201294, 2.7534000873565674, 4.175000190734863]
[0, 0, 0, 0]



