# Minimal working UNet

## 0. Imports and data check

In [2]:
import os
import h5py # note: importing h5py multiple times can cause an error

import numpy as np
import pandas as pd

import torch as t
import torch.nn.functional as f

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [3]:
# Directory holding one HDF5 file per subject.
# NOTE(review): the variable is named train_dir but the path ends in 'test' --
# confirm this is the intended split.
train_dir = 'data/h5/test'

# os.listdir() returns entries in arbitrary order; sorting makes the subject
# picked below reproducible across runs and machines.
train_subjects = sorted(os.listdir(train_dir))
train_subject = train_subjects[0]

In [4]:
# Copy both datasets into in-memory NumPy arrays, then release the file handle.
# The context manager closes the file even if one of the reads fails, and
# indexing (rather than .get) raises KeyError if a dataset is missing.
with h5py.File(f'{train_dir}/{train_subject}', 'r') as train_h5:
    train_h5_raw_np = np.array(train_h5['raw'])
    train_h5_label_np = np.array(train_h5['label'])

raw_shape = train_h5_raw_np.shape
label_shape = train_h5_label_np.shape

# Bare expression so the notebook displays both shapes as the cell output.
raw_shape, label_shape

((1, 256, 256, 256), (102, 256, 256, 256))

## 1. Dataloader logic

In [5]:
# Subject file names from train_dir, put into a fixed (sorted) order.
ordered_subject_list = os.listdir(train_dir)
ordered_subject_list.sort()

In [12]:
class HDF5Dataset(Dataset):
    """Index-based Dataset that yields one (x, y) tensor pair per subject file.

    Expected on-disk layout::

        data_dir/
            subject0.hdf5    # one HDF5 file per subject, with two datasets:
                x_name       #   input array
                y_name       #   label array
            subject1.hdf5
            ...

    When ``ordered_subject_list`` is not given, every entry of ``data_dir``
    is treated as a subject file, so the directory should not contain
    anything besides the HDF5 files intended for this dataset.

    Arguments:
        data_dir: Directory containing the per-subject HDF5 files.
        x_name: Name of the input dataset inside each file.
            Defaults to ``'raw'``.
        y_name: Name of the label dataset inside each file.
            Defaults to ``'label'``.
        ordered_subject_list: File names (relative to ``data_dir``) in the
            order they should be indexed. Defaults to the sorted listing
            of ``data_dir``.

    Returns:
        PyTorch index-based Dataset where each sample is an ``(x, y)`` pair
        of tensors. The intended use is a T1 scan for ``x`` and a one-hot
        set of anatomical labels for ``y``; the class itself does not check
        shapes or dtypes.
    """

    def __init__(self,
                 data_dir,
                 x_name=None,
                 y_name=None,
                 ordered_subject_list=None):
        super().__init__()

        self.data_dir = data_dir

        # Fall back to the default dataset names when none are given.
        self.x_name = 'raw' if x_name is None else x_name
        self.y_name = 'label' if y_name is None else y_name

        # Without an explicit ordering, use the sorted directory listing so
        # that indices map to the same subjects on every run.
        if ordered_subject_list is None:
            ordered_subject_list = sorted(os.listdir(data_dir))
        self.subjects = ordered_subject_list

    def __len__(self):
        """Return the number of subjects in the dataset."""
        return len(self.subjects)

    def __getitem__(self, index):
        """Load the subject at ``index`` and return its ``(x, y)`` tensors.

        Raises:
            KeyError: if ``x_name`` or ``y_name`` is not present in the file.
        """
        subject = self.subjects[index]
        path = os.path.join(self.data_dir, subject)

        # The context manager closes the file even if reading or conversion
        # fails. np.array() copies the data into memory, so the tensors stay
        # valid after the file is closed.
        with h5py.File(path, 'r') as h5:
            x = t.from_numpy(np.array(h5[self.x_name]))
            y = t.from_numpy(np.array(h5[self.y_name]))

        # If necessary, apply any preprocessing or transformations here.

        return x, y

In [13]:
# Build the dataset over train_dir, using the default dataset names and the
# default (sorted directory listing) subject order.
ds = HDF5Dataset(train_dir)

In [None]:
# Fetch sample 0 from the dataset and unpack it into its x and y tensors.
xt, yt = ds[0]

## 2. Model

## 3. Training loop

## 4. Evaluation, visualizations, etc.