# Minimal working UNet

## 0. imports and data check

In [1]:
import os
import h5py # note: importing h5py multiple times can cause an error

In [2]:
import numpy as np
import pandas as pd

import torch as t
import torch.nn as nn
import torch.nn.functional as f
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [39]:
# these directories have a tiny amount of data for a minimal test
# they can be replaced with any other directories depending on OS and preferences

train_dir = 'data/h5/train'
val_dir = 'data/h5/val'

os.listdir(train_dir)
train_subjects = os.listdir(train_dir)
train_subject = train_subjects[0]

In [4]:
train_h5 = h5py.File(f'{train_dir}/{train_subject}', 'r')

train_h5_raw_np = np.array(train_h5.get('raw'))
train_h5_label_np = np.array(train_h5.get('label'))

raw_shape = train_h5_raw_np.shape
label_shape = train_h5_label_np.shape

raw_shape, label_shape

((1, 256, 256, 256), (102, 256, 256, 256))

## 1. Dataloader logic

In [5]:
ordered_subject_list = sorted(os.listdir(train_dir))

In [6]:
class HDF5Dataset(Dataset):

    """ A custom Dataset class to iterate over subjects.
        This Dataset assumes that the data take the following form:
            data_dir/
                -- subject0.hdf5 (file with two datasets)
                    -- x_name: 4D array
                    -- y_name: 4D array
                -- subject1.hdf5 (next file with two datasets)
                    -- ...
        Note also that this directory should not contain any other files
        besides h5 files for subjects intended to be included in this dataset.
        -----
        Arguments:
            data_dir
            x_name
            y_name
            ordered_subject_list
        -----       
        Returns:
            Pytorch index-based Dataset where each sample is an x, y pair of tensors
                corresponding to a 3D T1 scan and a 4D set of anatomical labels (one-hot)
        
    """
    
    def __init__(self, 
                 data_dir, 
                 x_name=None,
                 y_name=None,
                 ordered_subject_list=None):
        
        self.data_dir = data_dir

        # parse default args
        x_name = 'raw' if x_name is None else x_name
        y_name = 'label' if y_name is None else y_name
        self.x_name = x_name
        self.y_name = y_name
        
        # parse subject ordering, if specified
        if ordered_subject_list is None:
            ordered_subject_list = sorted(os.listdir(data_dir))
        self.subjects = ordered_subject_list
        

    def __len__(self):
        return len(self.subjects)
    

    def __getitem__(self, index):
        subject = self.subjects[index]  # Select the current datapoint (subject)    
        h5 = h5py.File(f'{self.data_dir}/{subject}', 'r')
        
        x_np = h5.get(self.x_name)
        y_np = h5.get(self.y_name)
        
        x = t.from_numpy(np.array(x_np))
        y = t.from_numpy(np.array(y_np))
        
        h5.close() # close the h5 file to avoid extra memory usage

        # If necessary, apply any preprocessing or transformations to the data
        # data = ...

        return x, y

In [40]:
ds_train = HDF5Dataset(data_dir=train_dir)
ds_val = HDF5Dataset(data_dir=val_dir)

In [41]:
batch_size = 1

dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
dl_val = DataLoader(ds_val, batch_size=batch_size, shuffle=False)

## 2. Model

In [11]:
# translated from imports of unet3d.model
from unet3d.buildingblocks import DoubleConv, ResNetBlock, ResNetBlockSE, \
    create_decoders, create_encoders
from unet3d.utils import get_class, number_of_features_per_level
from unet3d.model import UNet3D

In [12]:
in_channels = 1
out_channels = 102

model = UNet3D(in_channels=in_channels, out_channels=out_channels)

In [42]:
# test forward pass
di_train = iter(dl_train)
di_val = iter(dl_val)

xt, yt = next(di_train)
xv, yv = next(di_val)

In [26]:
model

UNet3D(
  (encoders): ModuleList(
    (0): Encoder(
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (groupnorm): GroupNorm(1, 1, eps=1e-05, affine=True)
          (conv): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
        (SingleConv2): SingleConv(
          (groupnorm): GroupNorm(8, 32, eps=1e-05, affine=True)
          (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
      )
    )
    (1): Encoder(
      (pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (groupnorm): GroupNorm(8, 64, eps=1e-05, affine=True)
          (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
     

In [36]:
# specifying "float" may or may not be necessary on the GPU
# but it is required on CPU

if torch.cuda.is_available():
    # GPU is available
    device = torch.device("cuda")
    model.to(device) 
else:
    # GPU is not available, fall back to CPU
    device = torch.device("cpu")
    model.to(device, dtype=float) 

In [35]:
# if this cell doesn't run, training won't work either
output = model(x)

RuntimeError: [enforce fail at alloc_cpu.cpp:75] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 695784701952 bytes. Error code 12 (Cannot allocate memory)

## 3. Training loop

In [19]:
checkpoint_dir = './checkpoints'  # change this based on your OS and preferences
os.makedirs(checkpoint_dir, exist_ok=True)

In [17]:
import torch
import torch.nn as nn
import torch.optim as optim

optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Other training parameters
epochs = 10
lr_scheduler_patience = 3
lr_scheduler_factor = 0.1

def train(dl_train, 
          dl_val, 
          model, 
          optimizer,
          criterion,
          n_batches=1e3,
          lr_scheduler=None
         ):
    """ Function to wrap the main training loop.
    """
    
    # parse default parameters
    if lr_scheduler is None:
        default_lr_scheduler_patience = 3
        default_lr_scheduler_factor = 0.1
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                            patience=default_lr_scheduler_patience, 
                                                            factor=default_lr_scheduler_factor
                                                           )
  
    # Training loop
    best_val_loss = float('inf')
    
    for b in range(n_batches):
        # Training
        model.train()
        train_loss = 0.0
        for inputs, labels in dl_train:
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, labels in dl_val:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        # Compute average loss
        train_loss /= len(dl_train)
        val_loss /= len(dl_val)

        # Update learning rate scheduler
        lr_scheduler.step(val_loss)

        # Print progress
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Check if current validation loss is the best so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Save the model checkpoint if desired
            

        # Check early stopping condition if desired
        # TODO

    # Training complete
    return model

In [18]:
train(dl_train, dl_val, model, optimizer, criterion)

RuntimeError: mixed dtype (CPU): expect input to have scalar type of BFloat16

## 4. Evaluation, visualizaitons, etc.