# Minimal working UNet

## 0. imports and data check

In [1]:
import os
import h5py  # note: importing h5py multiple times can cause an error

In [2]:
import numpy as np
import pandas as pd

import torch as t
import torch.nn as nn
import torch.nn.functional as f
import torch.optim as optim

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [3]:
# these directories have a tiny amount of data for a minimal test
# they can be replaced with any other directories depending on OS and preferences

train_dir = 'data/h5/train'
val_dir = 'data/h5/val'

os.listdir(train_dir)
train_subjects = os.listdir(train_dir)
train_subject = train_subjects[0]

In [53]:
def split_volume(np_arr, chunk_size, stride_size):
    """ Splits an np array of shape (nchannels, nx, ny, nz)
    """
    shape = np_arr.shape
    
    assert len(shape) == 4
    assert shape[1] == shape[2] == shape[3]
    
    num_chunks = ((shape[1] - chunk_size) // stride_size) + 1

    chunks_lvl0 = []
    
    # triple loop over x, y, z dimensions of volume 
    for xc in range(num_chunks):
        x0 = xc * stride_size
        x1 = x0 + chunk_size
        
        chunks_lvl1 = []
        for yc in range(num_chunks):
            y0 = yc * stride_size
            y1 = y0 + chunk_size
            
            chunks_lvl2 = []
            for zc in range(num_chunks):
                z0 = zc * stride_size
                z1 = z0 + chunk_size
        
                chunk = np_arr[:, x0:x1, y0:y1, z0:z1]
                chunks_lvl2.append(chunk)
            
            # zoom out 
            chunks_lvl1.append(chunks_lvl2)
        chunks_lvl0.append(chunks_lvl1)

    return chunks_lvl0

In [54]:
chunk_size = 84  # a little under 256 // 3
stride_size = 56  # with a chunk_size of 84, a stride size of 64 allows for some overlap of chunks.

raw_chunks = split_volume(raw_np, chunk_size=chunk_size, stride_size=stride_size)

In [55]:
# (x_chunks, y_chunks, z_chunks, channels, xsize, ysize, size)
np.array(raw_chunks).shape

(4, 4, 4, 1, 84, 84, 84)

In [106]:
def split_datapoint_and_save(h5_dataset, 
                             chunk_size, 
                             stride_size, 
                             save_directory,
                             save_filename_prefix):
    
    """ Splits an h5 file with compatible-shaped
        datasets 'raw' and 'label';
        Returns None.
        
        Its effect is to save out a list of datasets by chunking 'raw' and 'label'
        into a smaller new h5 files.
        
        The newly created h5 files are saved with a filename suffix indicating which
        x, y, z chunk of the original data they contain.
    """    
    
    raw_np = np.array(h5_dataset.get('raw'))
    label_np = np.array(h5_dataset.get('label'))
    
    split_raw_volumes = split_volume(raw_np, chunk_size, stride_size)
    split_label_volumes = split_volume(label_np, chunk_size, stride_size)
    
    nchunks_1d = len(split_raw_volumes[0])  # assumes x, y, z dimensions are equal
    for x in range(nchunks_1d):
        for y in range(nchunks_1d):
            for z in range(nchunks_1d):
                # save out a new h5 file for each chunk of the original volume
                print('chunk:', x, y, z)
                v1 = split_raw_volumes[x][y][z]
                v2 = split_label_volumes[x][y][z]
                
                file = h5py.File(f'{save_directory}/{save_filename_prefix}_chunk_{x}_{y}_{z}.h5', 'w')
                file.create_dataset('raw', data=v1)
                file.create_dataset('label', data=v2)
                file.close()
    
    return

In [107]:
# make directory for new chunked dataset
# in which each original volume (h5 file)
# is saved out as several distinct chunks (h5 files)
test_chunk_dir = '.data/test_chunk_dir'  # may need to change for your own tests
os.makedirs(test_chunk_dir, exist_ok=True)

chunk_size = 84  # a little under 256 // 3
stride_size = 56  # with a chunk_size of 84, a stride size of 64 allows for some overlap of chunks.

h5_volume = h5py.File(f'{train_dir}/{train_subject}', 'r')
h5_prefix = train_subject.split('.')[0] # filter off filename extension .h5

split_datapoint_and_save(h5_volume, 
                         chunk_size, 
                         stride_size, 
                         save_directory=test_chunk_dir,
                         save_filename_prefix=h5_prefix
                        )

chunk: 0 0 0
chunk: 0 0 1
chunk: 0 0 2
chunk: 0 0 3
chunk: 0 1 0
chunk: 0 1 1
chunk: 0 1 2
chunk: 0 1 3
chunk: 0 2 0
chunk: 0 2 1
chunk: 0 2 2
chunk: 0 2 3
chunk: 0 3 0
chunk: 0 3 1
chunk: 0 3 2
chunk: 0 3 3
chunk: 1 0 0
chunk: 1 0 1
chunk: 1 0 2
chunk: 1 0 3
chunk: 1 1 0
chunk: 1 1 1
chunk: 1 1 2
chunk: 1 1 3
chunk: 1 2 0
chunk: 1 2 1
chunk: 1 2 2
chunk: 1 2 3
chunk: 1 3 0
chunk: 1 3 1
chunk: 1 3 2
chunk: 1 3 3
chunk: 2 0 0
chunk: 2 0 1
chunk: 2 0 2
chunk: 2 0 3
chunk: 2 1 0
chunk: 2 1 1
chunk: 2 1 2
chunk: 2 1 3
chunk: 2 2 0
chunk: 2 2 1
chunk: 2 2 2
chunk: 2 2 3
chunk: 2 3 0
chunk: 2 3 1
chunk: 2 3 2
chunk: 2 3 3
chunk: 3 0 0
chunk: 3 0 1
chunk: 3 0 2
chunk: 3 0 3
chunk: 3 1 0
chunk: 3 1 1
chunk: 3 1 2
chunk: 3 1 3
chunk: 3 2 0
chunk: 3 2 1
chunk: 3 2 2
chunk: 3 2 3
chunk: 3 3 0
chunk: 3 3 1
chunk: 3 3 2
chunk: 3 3 3


In [115]:
def chunk_entire_dataset(data_in_dir, 
                         data_out_dir, 
                         chunk_size, 
                         stride_size
                        ):
    """
        Takes in a dataset with one h5 volume's worth of data ('raw' and 'label')
        for each subject in 'data_in_dir'
        Populates a directory 'data_out_dir' with a several chunk volumes that are
        obtained from the original volumes. E.g., 8 octants, each their own h5 file,
        for each 1 original h5 file.
    """
    
    h5_filenames = sorted(os.listdir(data_in_dir))
    for h5_filename in h5_filenames:

        chunk_size = 84  # a little under 256 // 3
        stride_size = 56  # with a chunk_size of 84, a stride size of 64 allows for some overlap of chunks.

        h5_in_file = h5py.File(f'{data_in_dir}/{h5_filename}', 'r')
        h5_prefix = h5_filename.split('.')[0] # filter off filename extension .h5

        split_datapoint_and_save(h5_in_file, 
                                 chunk_size, 
                                 stride_size=stride_size,
                                 save_directory=data_out_dir,
                                 save_filename_prefix=h5_prefix
                                )
        
    return

In [116]:
os.listdir('data/h5')

['train_chunked', 'train', '.DS_Store', 'val']

In [117]:
chunk_size = 84
stride_size = 56

data_in_dir = train_dir
data_out_dir = 'data/h5/train_chunked'  # may need to change for your own tests
os.makedirs(data_out_dir, exist_ok=True)

chunk_entire_dataset(data_in_dir=data_in_dir,
                     data_out_dir=data_out_dir, 
                     chunk_size=chunk_size, 
                     stride_size=stride_size)

chunk: 0 0 0
chunk: 0 0 1
chunk: 0 0 2
chunk: 0 0 3
chunk: 0 1 0
chunk: 0 1 1
chunk: 0 1 2
chunk: 0 1 3
chunk: 0 2 0
chunk: 0 2 1
chunk: 0 2 2
chunk: 0 2 3
chunk: 0 3 0
chunk: 0 3 1
chunk: 0 3 2
chunk: 0 3 3
chunk: 1 0 0
chunk: 1 0 1
chunk: 1 0 2
chunk: 1 0 3
chunk: 1 1 0
chunk: 1 1 1
chunk: 1 1 2
chunk: 1 1 3
chunk: 1 2 0
chunk: 1 2 1
chunk: 1 2 2
chunk: 1 2 3
chunk: 1 3 0
chunk: 1 3 1
chunk: 1 3 2
chunk: 1 3 3
chunk: 2 0 0
chunk: 2 0 1
chunk: 2 0 2
chunk: 2 0 3
chunk: 2 1 0
chunk: 2 1 1
chunk: 2 1 2
chunk: 2 1 3
chunk: 2 2 0
chunk: 2 2 1
chunk: 2 2 2
chunk: 2 2 3
chunk: 2 3 0
chunk: 2 3 1
chunk: 2 3 2
chunk: 2 3 3
chunk: 3 0 0
chunk: 3 0 1
chunk: 3 0 2
chunk: 3 0 3
chunk: 3 1 0
chunk: 3 1 1
chunk: 3 1 2
chunk: 3 1 3
chunk: 3 2 0
chunk: 3 2 1
chunk: 3 2 2
chunk: 3 2 3
chunk: 3 3 0
chunk: 3 3 1
chunk: 3 3 2
chunk: 3 3 3
chunk: 0 0 0
chunk: 0 0 1
chunk: 0 0 2
chunk: 0 0 3
chunk: 0 1 0
chunk: 0 1 1
chunk: 0 1 2
chunk: 0 1 3
chunk: 0 2 0
chunk: 0 2 1
chunk: 0 2 2
chunk: 0 2 3
chunk: 0 3 0

## 1. Dataloader logic

In [5]:
class HDF5Dataset(Dataset):

    """ A custom Dataset class to iterate over subjects.
        This Dataset assumes that the data take the following form:
            data_dir/
                -- subject0.hdf5 (file with two datasets)
                    -- x_name: 4D array
                    -- y_name: 4D array
                -- subject1.hdf5 (next file with two datasets)
                    -- ...
        Note also that this directory should not contain any other files
        besides h5 files for subjects intended to be included in this dataset.
        -----
        Arguments:
            data_dir
            x_name
            y_name
            ordered_subject_list
        -----       
        Returns:
            Pytorch index-based Dataset where each sample is an x, y pair of tensors
                corresponding to a 3D T1 scan (padded to 4D with a dummy dimension)
                and a 4D set of anatomical labels (one-hot encoded classes)
        
    """
    
    def __init__(self, 
                 data_dir, 
                 x_name=None,
                 y_name=None,
                 ordered_subject_list=None):
        
        self.data_dir = data_dir

        # parse default args
        x_name = 'raw' if x_name is None else x_name
        y_name = 'label' if y_name is None else y_name
        self.x_name = x_name
        self.y_name = y_name
        
        # parse subject ordering, if specified
        if ordered_subject_list is None:
            ordered_subject_list = sorted(os.listdir(data_dir))
        self.subjects = ordered_subject_list
        

    def __len__(self):
        return len(self.subjects)
    

    def __getitem__(self, index):
        subject = self.subjects[index]  # Select the current datapoint (subject)    
        h5 = h5py.File(f'{self.data_dir}/{subject}', 'r')
        
        x_np = h5.get(self.x_name)
        y_np = h5.get(self.y_name)
        
        x = t.from_numpy(np.array(x_np))
        y = t.from_numpy(np.array(y_np))
        
        h5.close() # close the h5 file to avoid extra memory usage

        # If necessary, apply any preprocessing or transformations to the data
        # data = ...

        return x, y

In [6]:
ds_train = HDF5Dataset(data_dir=train_dir)
ds_val = HDF5Dataset(data_dir=val_dir)

In [7]:
batch_size = 1

dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
dl_val = DataLoader(ds_val, batch_size=batch_size, shuffle=False)

## 2. Model

In [8]:
# translated from imports of unet3d.model
from unet3d.buildingblocks import DoubleConv, ResNetBlock, ResNetBlockSE, \
    create_decoders, create_encoders
from unet3d.utils import get_class, number_of_features_per_level
from unet3d.model import UNet3D

In [9]:
in_channels = 1
out_channels = 102

model = UNet3D(in_channels=in_channels, out_channels=out_channels)

In [44]:
# specifying "float" may or may not be necessary on the GPU
# but it is required on CPU
def get_device():
    device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
    return device
    
model.to(device, dtype=float) 

UNet3D(
  (encoders): ModuleList(
    (0): Encoder(
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (groupnorm): GroupNorm(1, 1, eps=1e-05, affine=True)
          (conv): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
        (SingleConv2): SingleConv(
          (groupnorm): GroupNorm(8, 32, eps=1e-05, affine=True)
          (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
      )
    )
    (1): Encoder(
      (pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (groupnorm): GroupNorm(8, 64, eps=1e-05, affine=True)
          (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
     

In [39]:
device

device(type='cuda')

In [40]:
# test forward pass
di_train = iter(dl_train)
di_val = iter(dl_val)

xt, yt = next(di_train)
xv, yv = next(di_val)

In [41]:
# if this cell doesn't run, training won't work either
xt = xt.to(device)

In [47]:
model.cuda()

UNet3D(
  (encoders): ModuleList(
    (0): Encoder(
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (groupnorm): GroupNorm(1, 1, eps=1e-05, affine=True)
          (conv): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
        (SingleConv2): SingleConv(
          (groupnorm): GroupNorm(8, 32, eps=1e-05, affine=True)
          (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
      )
    )
    (1): Encoder(
      (pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (groupnorm): GroupNorm(8, 64, eps=1e-05, affine=True)
          (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
     

In [43]:
output = model(xt)

RuntimeError: GET was unable to find an engine to execute this computation

## 3. Training loop

In [None]:
checkpoint_dir = './checkpoints'  # change this based on your OS and preferences
os.makedirs(checkpoint_dir, exist_ok=True)

In [46]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Other training parameters
epochs = 10
lr_scheduler_patience = 3
lr_scheduler_factor = 0.1

def train(dl_train, 
          dl_val, 
          model, 
          optimizer,
          criterion,
          n_batches=1e3,
          lr_scheduler=None,
         ):
    """ Function to wrap the main training loop.
    """
    
    device = get_device()  # defined in above section
    
    # parse default parameters
    if lr_scheduler is None:
        default_lr_scheduler_patience = 3
        default_lr_scheduler_factor = 0.1
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                            patience=default_lr_scheduler_patience, 
                                                            factor=default_lr_scheduler_factor
                                                           )
  
    # Training loop
    best_val_loss = float('inf')
    
    for b in range(n_batches):
        # Training
        model.train()
        train_loss = 0.0
        for inputs, labels in dl_train:
            optimizer.zero_grad()

            # Forward pass
            inputs = inputs.to(device)
            outputs = outputs.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, labels in dl_val:
                inputs = inputs.to(device)
                outputs = outputs.to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        # Compute average loss
        train_loss /= len(dl_train)
        val_loss /= len(dl_val)

        # Update learning rate scheduler
        lr_scheduler.step(val_loss)

        # Print progress
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Check if current validation loss is the best so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Save the model checkpoint if desired
            

        # Check early stopping condition if desired
        # TODO

    # Training complete
    return model

In [18]:
train(dl_train, dl_val, model, optimizer, criterion)

RuntimeError: mixed dtype (CPU): expect input to have scalar type of BFloat16

## 4. Evaluation, visualizaitons, etc.