# Training a 3D U-Net

TensorFlow 2 code to train a 3D U-Net on the brain tumor segmentation ([BraTS](https://www.med.upenn.edu/sbia/brats2017.html)) subset of the [Medical Segmentation Decathlon dataset](http://medicaldecathlon.com/).

This model can achieve a [Dice coefficient](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC1415224/) of > 0.80 on the whole tumor using just the [FLAIR](https://en.wikipedia.org/wiki/Fluid-attenuated_inversion_recovery) channel.

In [1]:
import tensorflow as tf
from tensorflow import keras as K

import nibabel as nib
import numpy as np
import os
import datetime
    
import matplotlib.pyplot as plt
%matplotlib inline

## Determine if we are using Intel-optimized TensorFlow (DNNL)

In [2]:
def test_intel_tensorflow():
    """
    Check if Intel version of TensorFlow is installed
    """
    import tensorflow as tf

    print("We are using Tensorflow version {}".format(tf.__version__))

    major_version = int(tf.__version__.split(".")[0])
    if major_version >= 2:
        from tensorflow.python.util import _pywrap_util_port
        print("Intel-optimizations (DNNL) enabled:",
              _pywrap_util_port.IsMklEnabled())
    else:
        print("Intel-optimizations (DNNL) enabled:",
              tf.pywrap_tensorflow.IsMklEnabled())


test_intel_tensorflow()  # Prints if Intel-optimized TensorFlow is used.

We are using Tensorflow version 2.17.0
Intel-optimizations (DNNL) enabled: True


In [3]:
test_intel_tensorflow()

print("\nDevice being used:", tf.config.list_physical_devices())


We are using Tensorflow version 2.17.0
Intel-optimizations (DNNL) enabled: True

Device being used: [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]


## Define the settings

In [4]:
data_path= r"C:\Users\basur\Music\ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData"
train_val_split = 0.80
val_test_split = 0.50
bz_train=8
bz_val=4
bz_test=1
num_epochs=3

crop_dim = (240, 240, 155, 1)
number_output_classes = 3

filters=8
saved_model_name = "3d_unet_decathlon"

seed=816


## Define a data loader

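The `DataLoader` class below collects the paths of all NIfTI (`.nii.gz`) files under `data_path` and can load a single volume on request. Later, `tf.data` uses these paths to load batches of 3D images and masks at runtime, only when a new batch is requested.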

In [5]:
class DataLoader:
    def __init__(self, data_path):
        self.data_path = data_path
        self.file_list = self._create_file_list()

    def _create_file_list(self):
        """
        Create a list of file paths for the dataset.
        """
        file_list = []
        for root, dirs, files in os.walk(self.data_path):
            for file in files:
                if file.endswith(".nii.gz"):
                    file_list.append(os.path.join(root, file))
        return file_list

    def load_data(self, idx):
        """
        Load the image at a given index and add a channel dimension.
        """
        img_file = self.file_list[idx]
        img = nib.load(img_file).get_fdata()
        img = np.expand_dims(img, axis=-1)  # Add channel dimension
        return img

    def __len__(self):
        return len(self.file_list)

# Example usage
data_loader = DataLoader(data_path)
print(f"Number of files: {len(data_loader)}")
sample_img = data_loader.load_data(0)
print(f"Sample image shape: {sample_img.shape}")

Number of files: 6255
Sample image shape: (240, 240, 155, 1)


In [6]:
brats_datafiles = DataLoader(data_path)
for file_path in brats_datafiles.file_list:
    print(file_path)

C:\Users\basur\Music\ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData\BraTS-GLI-00000-000\BraTS-GLI-00000-000-seg.nii.gz
C:\Users\basur\Music\ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData\BraTS-GLI-00000-000\BraTS-GLI-00000-000-t1c.nii.gz
C:\Users\basur\Music\ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData\BraTS-GLI-00000-000\BraTS-GLI-00000-000-t1n.nii.gz
C:\Users\basur\Music\ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData\BraTS-GLI-00000-000\BraTS-GLI-00000-000-t2f.nii.gz
C:\Users\basur\Music\ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData\BraTS-GLI-00000-000\BraTS-GLI-00000-000-t2w.nii.gz
C:\Users\basur\Music\ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData\BraTS-GLI-00002-000\BraTS-GLI-00002-000-seg.nii.gz
C:\Users\basur\Music\ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData\BraTS-GLI-00002-000\BraTS-GLI-00002-000-t1c.nii.gz
C:\Users\basur\Music\ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData\BraTS-GLI-00002-000\BraTS-GLI-00002-000-t1n.nii.gz
C:\Users\basur\Music\ASNR-MICCAI
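The listing above contains every `.nii.gz` file of each case: the `-seg` label map plus the four scans (`-t1c`, `-t1n`, `-t2f`, `-t2w`). Indexing `file_list` directly therefore mixes label maps and non-FLAIR scans into the training images. Below is a minimal sketch of a FLAIR-only file list. It assumes the BraTS 2023 naming shown above, where `-t2f` is the T2-FLAIR scan and `-seg` is the label map in the same folder. `list_flair_pairs` is a hypothetical helper, not part of the original notebook.

```python
import os


def list_flair_pairs(data_path):
    """Return sorted (flair_path, seg_path) pairs, one per case.

    Assumes BraTS 2023 naming: '<case>-t2f.nii.gz' is the T2-FLAIR scan and
    '<case>-seg.nii.gz' is the label map stored in the same folder.
    """
    pairs = []
    for root, _dirs, files in os.walk(data_path):
        for name in files:
            if name.endswith("-t2f.nii.gz"):
                flair = os.path.join(root, name)
                seg = flair[: -len("-t2f.nii.gz")] + "-seg.nii.gz"
                if os.path.exists(seg):  # skip cases without a label map
                    pairs.append((flair, seg))
    return sorted(pairs)


# Example (hypothetical usage):
# pairs = list_flair_pairs(data_path)
# print(len(pairs), pairs[0])
```

If you build the splits from `len(pairs)` instead of `len(brats_datafiles)`, each case contributes exactly one sample.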

## Data preprocessing

Here we preprocess the 3D MRI scans. Each volume is cropped to `crop_dim` (with an optional random offset), z-score normalized and, during training, randomly flipped or rotated by a multiple of 90°.

In [7]:
def z_normalize_img(img):
    """
    Normalize the image so that the mean value for each image
    is 0 and the standard deviation is 1.
    """
    if not isinstance(img, np.ndarray):
        raise TypeError("Input must be a numpy array")
    
    if img.size == 0:
        raise ValueError("Input array is empty")
        
    for channel in range(img.shape[-1]):
        img_temp = img[..., channel]
        std = np.std(img_temp)
        
        # Avoid division by zero
        if std == 0:
            img[..., channel] = img_temp - np.mean(img_temp)
        else:
            img[..., channel] = (img_temp - np.mean(img_temp)) / std

    return img


In [8]:
    
def crop(img, msk, randomize):
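    """
    Crop the image and mask to crop_dim around the center. If randomize
    is True, the crop window is shifted by a random offset half of the time.
    """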

    slices = []
    
    # Do we randomize?
    is_random = randomize and np.random.rand() > 0.5

    for idx in range(len(img.shape)-1):  # Go through each dimension

        cropLen = crop_dim[idx]
        imgLen = img.shape[idx]

        start = (imgLen-cropLen)//2

        ratio_crop = 0.20  # Random offset of up to this fraction of the centered start index
        # Number of pixels to offset crop in this dimension
        offset = int(np.floor(start*ratio_crop))

        if offset > 0:
            if is_random:
                start += np.random.choice(range(-offset, offset))
                if ((start + cropLen) > imgLen):  # Don't fall off the image
                    start = (imgLen-cropLen)//2
        else:
            start = 0

        slices.append(slice(start, start+cropLen))

    return img[tuple(slices)], msk[tuple(slices)]
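`crop()` only moves the crop window when the centered margin `(imgLen - cropLen)//2` is large enough for `offset` to be positive. The hypothetical helper `crop_window` below repeats that per-dimension arithmetic and returns three values: the centered start, and the smallest and largest start a random crop can pick. With `crop_dim` equal to the full 240 × 240 × 155 volume, every margin is 0, so the crop and its random offset do nothing. A crop size larger than the image also falls back to `start = 0`, so `crop()` cannot enlarge a dimension.

```python
import numpy as np


def crop_window(img_len, crop_len, ratio_crop=0.20):
    """Return (centered_start, min_start, max_start) as computed in crop()."""
    start = (img_len - crop_len) // 2
    offset = int(np.floor(start * ratio_crop))
    if offset > 0:
        # np.random.choice(range(-offset, offset)) draws from [-offset, offset - 1]
        return start, start - offset, start + offset - 1
    # No room for an offset (or crop larger than image): crop() uses start = 0
    return 0, 0, 0


print(crop_window(240, 240))  # (0, 0, 0): full-size crop, nothing to randomize
print(crop_window(240, 128))  # (56, 45, 66)
print(crop_window(155, 160))  # (0, 0, 0): the slice stops at 155, no padding
```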


In [9]:
    
def augment_data(img, msk, crop_dim):
    """
    Data augmentation
    Flip image and mask. Rotate image and mask.
    """
    
    # Determine if axes are equal and can be rotated
    # If the axes aren't equal then we can't rotate them.
    equal_dim_axis = []
    for idx in range(0, len(crop_dim)-1):  # Exclude last dimension (channels)
        for jdx in range(idx+1, len(crop_dim)-1):
            if crop_dim[idx] == crop_dim[jdx]:
                equal_dim_axis.append([idx, jdx])  # Valid rotation axes
    dim_to_rotate = equal_dim_axis

    if np.random.rand() > 0.5:
        # Flip along one randomly chosen spatial axis (0, 1, or 2)
        ax = np.random.choice(np.arange(len(crop_dim)-1))
        img = np.flip(img, ax)
        msk = np.flip(msk, ax)

    elif (len(dim_to_rotate) > 0) and (np.random.rand() > 0.5):
        rot = np.random.choice([1, 2, 3])  # 90, 180, or 270 degrees

        # This will choose the axes to rotate
        # Axes must be equal in size
        random_axis = dim_to_rotate[np.random.choice(len(dim_to_rotate))]
        
        img = np.rot90(img, rot, axes=random_axis)  # Rotate in the chosen plane
        msk = np.rot90(msk, rot, axes=random_axis)  # Apply the same rotation to the mask

    return img, msk
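`augment_data` only rotates in planes whose two axes have equal length, because a 90° rotation swaps the sizes of those axes. The short check below uses a hypothetical `rotation_planes` helper that mirrors the nested loop above. For `crop_dim = (240, 240, 155, 1)`, the only valid plane is axes `[0, 1]`.

```python
from itertools import combinations

import numpy as np


def rotation_planes(crop_dim):
    """Pairs of spatial axes with equal length (channel axis excluded)."""
    spatial = crop_dim[:-1]
    return [[i, j] for i, j in combinations(range(len(spatial)), 2)
            if spatial[i] == spatial[j]]


print(rotation_planes((240, 240, 155, 1)))  # [[0, 1]]
print(rotation_planes((128, 128, 128, 1)))  # [[0, 1], [0, 2], [1, 2]]

# A small stand-in volume: rotating in an unequal plane changes the shape
vol = np.zeros((8, 8, 5, 1), dtype=np.float32)
print(np.rot90(vol, 1, axes=(0, 1)).shape)  # (8, 8, 5, 1)
print(np.rot90(vol, 1, axes=(0, 2)).shape)  # (5, 8, 8, 1)
```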


In [10]:
    
def read_nifti_file(idx, crop_dim, randomize=False):
    """
    Read a NIfTI image and its mask, then crop, normalize, and optionally augment them
    """
    
    idx = idx.numpy()
    # Get the image file path
    img_file = brats_datafiles.file_list[idx]
    
    # Load image
    img = nib.load(img_file).get_fdata()
    
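    # Derive the mask path by replacing the modality suffix (t1c, t1n, t2f, t2w)
    # with "seg", e.g. BraTS-GLI-00000-000-t2f.nii.gz -> BraTS-GLI-00000-000-seg.nii.gz.
    # (Replacing only 't2w' would leave other modalities as their own "mask".)
    msk_file = img_file.rsplit("-", 1)[0] + "-seg.nii.gz"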
    
    # Load mask if it exists, otherwise create zero mask
    if os.path.exists(msk_file):
        msk = nib.load(msk_file).get_fdata()
    else:
        msk = np.zeros_like(img)
    
    img = np.expand_dims(img, axis=-1)  # Add channel dimension
    img = np.rot90(img)
    
    msk = np.rot90(msk)

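    # Label map as given in the Medical Segmentation Decathlon BraTS task:
    #   0: background, 1: edema, 2: non-enhancing tumor, 3: enhancing tumour
    # BraTS 2023 numbers the tumor classes differently (1: non-enhancing
    # tumor core, 2: surrounding FLAIR hyperintensity/edema, 3: enhancing
    # tumor), so check the label map that ships with the data in data_path.
    # One output class: merge all tumor labels. Three classes: one-hot encode.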
    if number_output_classes == 1:
        msk[msk > 0] = 1.0
        msk = np.expand_dims(msk, -1)
    else:
        msk_temp = np.zeros(list(msk.shape) + [number_output_classes])
        # Create one-hot encoded mask
        for channel in range(number_output_classes):
            if channel == 0:  # background
                msk_temp[..., channel] = (msk == 0).astype(np.float32)
            elif channel == 1:  # edema
                msk_temp[..., channel] = (msk == 1).astype(np.float32)
            elif channel == 2:  # non-enhancing tumor + enhancing tumor
                msk_temp[..., channel] = ((msk == 2) | (msk == 3)).astype(np.float32)
        msk = msk_temp
    
    imgFilename = os.path.basename(img_file).split(".nii.gz")[0]
    
    # Crop
    img, msk = crop(img, msk, randomize)
    
    # Normalize
    img = z_normalize_img(img)
    
    # Randomly rotate
    if randomize:
        img, msk = augment_data(img, msk, crop_dim)
    
    return img, msk
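The channel loop in `read_nifti_file` amounts to stacking three boolean masks. The vectorized sketch below produces the same 3-channel mask for labels 0 to 3; the helper name `one_hot_regions` is hypothetical. Each voxel lands in exactly one channel. A label outside 0 to 3, such as the label 4 that older BraTS releases used for enhancing tumor, gets all-zero channels, so it is worth confirming the label values in your data.

```python
import numpy as np


def one_hot_regions(msk):
    """Channel 0: background (0), channel 1: label 1, channel 2: labels 2 or 3."""
    return np.stack([msk == 0, msk == 1, (msk == 2) | (msk == 3)],
                    axis=-1).astype(np.float32)


labels = np.array([[0, 1], [2, 3]], dtype=np.float64)  # get_fdata() returns floats
one_hot = one_hot_regions(labels)
print(one_hot.shape)     # (2, 2, 3)
print(one_hot[1, 1])     # [0. 0. 1.]
print(one_hot_regions(np.array([4.0])))  # [[0. 0. 0.]]: unexpected label, no class
```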
    

In [11]:
import os
import nibabel as nib
import numpy as np

def is_nifti_empty(nifti_path):
    """Check if a NIfTI file is empty based on multiple conditions."""
    try:
        if not os.path.exists(nifti_path):
            print(f"File does not exist: {nifti_path}")
            return True
            
        img = nib.load(nifti_path)
        data = img.get_fdata()

        # Condition 1: All values are zero
        if np.all(data == 0):
            return True

        # Condition 2: Data has no shape (invalid file)
        if data.size == 0 or data.shape == ():
            return True

        # Condition 3: All values are NaN
        if np.isnan(data).all():
            return True

        # Condition 4: All values are close to zero (e.g., due to precision errors)
        if np.all(np.abs(data) < 1e-6):
            return True

        return False  # File contains meaningful data
    except Exception as e:
        print(f"Error processing {nifti_path}: {str(e)}")
        return True

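def discard_empty_nifti(directory):
    """Delete empty NIfTI files located directly in `directory`.

    Note: os.listdir() is not recursive, so files inside per-case
    subfolders (the BraTS layout) are not checked.
    """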
    if not os.path.exists(directory):
        print(f"Directory does not exist: {directory}")
        return
        
    count = 0
    for file in os.listdir(directory):
        if file.endswith(".nii") or file.endswith(".nii.gz"):
            file_path = os.path.join(directory, file)
            if is_nifti_empty(file_path):
                os.remove(file_path)
                count += 1
                print(f"Deleted empty NIfTI file: {file}")
    print(f"Total empty files deleted: {count}")

# Use the data_path that was already defined
print(f"Checking for empty NIfTI files in: {data_path}")
if os.path.exists(data_path):
    discard_empty_nifti(data_path)
else:
    print(f"Data path does not exist: {data_path}")

Checking for empty NIfTI files in: C:\Users\basur\Music\ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData
Total empty files deleted: 0
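`discard_empty_nifti` uses `os.listdir`, which only looks at the top level of `data_path`. The BraTS files sit one folder deeper, in one folder per case, so the check above never sees them, which is consistent with the 0 deletions reported. A recursive listing like the sketch below could feed `is_nifti_empty` instead. The helper `find_nifti_files` is hypothetical. We suggest reviewing the candidates before deleting anything. Removing one file breaks the pairing for the rest of that case, and `brats_datafiles.file_list` was built before this cleanup, so it would still contain the deleted paths.

```python
import os


def find_nifti_files(directory):
    """Recursively collect .nii and .nii.gz paths under directory (no deletion)."""
    found = []
    for root, _dirs, files in os.walk(directory):
        for name in files:
            if name.endswith((".nii", ".nii.gz")):
                found.append(os.path.join(root, name))
    return sorted(found)


# Dry run (hypothetical usage): list candidates first, delete only after review
# candidates = [p for p in find_nifti_files(data_path) if is_nifti_empty(p)]
# print(len(candidates), candidates[:5])
```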


## tf.data

Define the training, testing, and validation data loaders.

In [12]:
# Calculate dataset sizes
numFiles = len(brats_datafiles)
numTrain = int(numFiles * train_val_split)
numValTest = numFiles - numTrain

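# Create and shuffle the dataset once. reshuffle_each_iteration=False keeps the
# take()/skip() splits below fixed. Otherwise the order, and with it the split,
# is redrawn every epoch, and training samples leak into validation and test.
ds = tf.data.Dataset.range(numFiles).shuffle(numFiles, seed=seed, reshuffle_each_iteration=False)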

# Split into train, validation, and test sets
ds_train = ds.take(numTrain)
ds_val_test = ds.skip(numTrain)
ds_val = ds_val_test.take(int(numValTest * val_test_split))
ds_test = ds_val_test.skip(int(numValTest * val_test_split))

# Define output shapes for the py_function
output_shapes = (
    tf.TensorSpec(shape=crop_dim, dtype=tf.float32),
    tf.TensorSpec(shape=(crop_dim[0], crop_dim[1], crop_dim[2], number_output_classes), dtype=tf.float32)
)

# Map the data loading function
ds_train = ds_train.map(
    lambda x: tf.py_function(read_nifti_file, [x, crop_dim, True], [tf.float32, tf.float32]),
    num_parallel_calls=tf.data.AUTOTUNE
).map(lambda x, y: (tf.ensure_shape(x, crop_dim), tf.ensure_shape(y, (crop_dim[0], crop_dim[1], crop_dim[2], number_output_classes))))

ds_val = ds_val.map(
    lambda x: tf.py_function(read_nifti_file, [x, crop_dim, False], [tf.float32, tf.float32]),
    num_parallel_calls=tf.data.AUTOTUNE
).map(lambda x, y: (tf.ensure_shape(x, crop_dim), tf.ensure_shape(y, (crop_dim[0], crop_dim[1], crop_dim[2], number_output_classes))))

ds_test = ds_test.map(
    lambda x: tf.py_function(read_nifti_file, [x, crop_dim, False], [tf.float32, tf.float32]),
    num_parallel_calls=tf.data.AUTOTUNE
).map(lambda x, y: (tf.ensure_shape(x, crop_dim), tf.ensure_shape(y, (crop_dim[0], crop_dim[1], crop_dim[2], number_output_classes))))

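# Configure datasets for performance.
# The training set is not cached: cache() after the randomized map would
# replay the first epoch's crops and flips in every later epoch.
ds_train = ds_train.shuffle(100).batch(bz_train).prefetch(tf.data.AUTOTUNE)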
ds_val = ds_val.cache().batch(bz_val).prefetch(tf.data.AUTOTUNE)
ds_test = ds_test.cache().batch(bz_test).prefetch(tf.data.AUTOTUNE)

# Print dataset information
print(f"Number of training batches: {tf.data.experimental.cardinality(ds_train)}")
print(f"Number of validation batches: {tf.data.experimental.cardinality(ds_val)}")
print(f"Number of test batches: {tf.data.experimental.cardinality(ds_test)}")

Number of training batches: 626
Number of validation batches: 157
Number of test batches: 626
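The batch counts follow directly from the split arithmetic. The hypothetical helper below reproduces them from the 6255 files found earlier. The count is of files, not patients, because `file_list` holds every `.nii.gz` file of each case (1251 cases × 5 files if every case is complete). The test split ends up one sample larger than the validation split because `int()` truncates 1251 × 0.5.

```python
import math


def split_and_batches(num_files, train_val_split, val_test_split, batch_sizes):
    """Reproduce the take()/skip() split sizes and the resulting batch counts."""
    num_train = int(num_files * train_val_split)
    num_val_test = num_files - num_train
    num_val = int(num_val_test * val_test_split)
    num_test = num_val_test - num_val
    sizes = (num_train, num_val, num_test)
    # batch() keeps the final partial batch by default (drop_remainder=False)
    batches = tuple(math.ceil(n / bz) for n, bz in zip(sizes, batch_sizes))
    return sizes, batches


sizes, batches = split_and_batches(6255, 0.80, 0.50, (8, 4, 1))
print(sizes)    # (5004, 625, 626)
print(batches)  # (626, 157, 626)
```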


## Define the loss and metrics

In [15]:
import tensorflow as tf

def dice_coef(target, prediction, axis=(1, 2, 3), smooth=0.0001):
    """
    Sørensen-Dice coefficient on rounded (0/1) predictions, for comparing
    the similarity of two batches of binary masks.

    Args:
        target (TensorFlow tensor): ground truth
        prediction (TensorFlow tensor): prediction
        axis (tuple): axes to calculate the dice coefficient over
        smooth (float): small constant to avoid division by zero

    Returns:
        float: dice coefficient
    """
    prediction = tf.round(prediction)  # Round predictions to 0 or 1
    intersection = tf.reduce_sum(target * prediction, axis=axis)
    union = tf.reduce_sum(target + prediction, axis=axis)
    dice = (2.0 * intersection + smooth) / (union + smooth)
    return tf.reduce_mean(dice)


def soft_dice_coef(target, prediction, axis=(1, 2, 3), smooth=0.0001):
    """
    Sørensen (soft) Dice coefficient, computed without rounding predictions.

    Args:
        target (TensorFlow tensor): ground truth
        prediction (TensorFlow tensor): prediction
        axis (tuple): axes to calculate the dice coefficient over
        smooth (float): small constant to avoid division by zero

    Returns:
        float: soft dice coefficient
    """
    intersection = tf.reduce_sum(target * prediction, axis=axis)
    union = tf.reduce_sum(target + prediction, axis=axis)
    dice = (2.0 * intersection + smooth) / (union + smooth)
    return tf.reduce_mean(dice)


def dice_loss(target, prediction, axis=(1, 2, 3), smooth=0.0001):
    """
    Sørensen (soft) Dice loss.
    Uses -log(Dice) to obtain a smoother optimization surface.

    Args:
        target (TensorFlow tensor): ground truth
        prediction (TensorFlow tensor): prediction
        axis (tuple): axes to calculate the dice coefficient over
        smooth (float): small constant to avoid division by zero

    Returns:
        float: dice loss
    """
    intersection = tf.reduce_sum(prediction * target, axis=axis)
    p = tf.reduce_sum(prediction, axis=axis)
    t = tf.reduce_sum(target, axis=axis)
    numerator = tf.reduce_mean(intersection + smooth)
    denominator = tf.reduce_mean(t + p + smooth)
    dice_loss = -tf.math.log(2.*numerator) + tf.math.log(denominator)
    return dice_loss
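Here is a toy example of how the three functions differ, written in NumPy for a single sample and channel. The helper `dice` and the arrays are made up for illustration. `dice_coef` rounds predictions at 0.5, so it ignores how confident they are. `soft_dice_coef` uses the raw probabilities. `dice_loss` is approximately `-log` of the soft Dice; it averages the sums over the batch before taking the log, and its smoothing term enters the numerator twice.

```python
import numpy as np


def dice(target, prediction, smooth=1e-4):
    """Single-sample Dice: (2 * |A∩B| + smooth) / (|A| + |B| + smooth)."""
    intersection = np.sum(target * prediction)
    return (2.0 * intersection + smooth) / (np.sum(target) + np.sum(prediction) + smooth)


target = np.array([1, 1, 1, 0, 0, 0], dtype=np.float32)
prediction = np.array([0.9, 0.8, 0.2, 0.1, 0.0, 0.0], dtype=np.float32)

hard = dice(target, np.round(prediction))  # like dice_coef
soft = dice(target, prediction)            # like soft_dice_coef
loss = -np.log(soft)                       # roughly what dice_loss minimizes
print(f"{hard:.3f} {soft:.3f} {loss:.3f}")  # 0.800 0.760 0.274
```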


## Define the 3D U-Net

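Build the 3D U-Net with Keras. The encoder has four max-pooling steps, and the decoder uses transposed convolutions (or upsampling). Skip connections concatenate each encoder level with the matching decoder level.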

In [16]:
# Keras is already imported as 'K' in cell 1 with:
# from tensorflow import keras as K

In [31]:
# Default input shape: the four 2x max-pooling steps need each spatial dimension divisible by 16, so depth 155 is padded to 160
crop_dim = (240, 240, 160, 1)

import tensorflow as tf
from tensorflow import keras as K

def unet_3d(fms=32, input_dim=crop_dim, use_upsampling=False, concat_axis=-1):
    """
    3D U-Net
    """
    
    def ConvolutionBlock(x, name, fms, params):
        """
        Convolutional block of layers
        Per the original paper this is back to back 3D convs
        with batch norm and then ReLU.
        """

        x = K.layers.Conv3D(filters=fms, **params, name=name+"_conv0")(x)
        x = K.layers.BatchNormalization(name=name+"_bn0")(x)
        x = K.layers.Activation("relu", name=name+"_relu0")(x)

        x = K.layers.Conv3D(filters=fms, **params, name=name+"_conv1")(x)
        x = K.layers.BatchNormalization(name=name+"_bn1")(x)
        x = K.layers.Activation("relu", name=name)(x)

        return x

    inputs = K.layers.Input(shape=input_dim, name="MRImages")

    params = dict(kernel_size=(3, 3, 3), activation=None,
                  padding="same", 
                  kernel_initializer="he_uniform")

    # Transposed convolution parameters
    params_trans = dict(kernel_size=(2, 2, 2), strides=(2, 2, 2),
                        padding="same")


    # BEGIN - Encoding path
    encodeA = ConvolutionBlock(inputs, "encodeA", fms, params)
    poolA = K.layers.MaxPooling3D(name="poolA", pool_size=(2, 2, 2))(encodeA)

    encodeB = ConvolutionBlock(poolA, "encodeB", fms*2, params)
    poolB = K.layers.MaxPooling3D(name="poolB", pool_size=(2, 2, 2))(encodeB)

    encodeC = ConvolutionBlock(poolB, "encodeC", fms*4, params)
    poolC = K.layers.MaxPooling3D(name="poolC", pool_size=(2, 2, 2))(encodeC)

    encodeD = ConvolutionBlock(poolC, "encodeD", fms*8, params)
    poolD = K.layers.MaxPooling3D(name="poolD", pool_size=(2, 2, 2))(encodeD)

    encodeE = ConvolutionBlock(poolD, "encodeE", fms*16, params)
    # END - Encoding path

    # BEGIN - Decoding path
    if use_upsampling:
        # UpSampling3D repeats voxels (nearest neighbor); unlike UpSampling2D it has no interpolation argument
        up = K.layers.UpSampling3D(name="upE", size=(2, 2, 2))(encodeE)
    else:
        up = K.layers.Conv3DTranspose(name="transconvE", filters=fms*8,
                                      **params_trans)(encodeE)
    concatD = K.layers.concatenate(
        [up, encodeD], axis=concat_axis, name="concatD")

    decodeC = ConvolutionBlock(concatD, "decodeC", fms*8, params)

    if use_upsampling:
        up = K.layers.UpSampling3D(name="upC", size=(2, 2, 2))(decodeC)
    else:
        up = K.layers.Conv3DTranspose(name="transconvC", filters=fms*4,
                                      **params_trans)(decodeC)
    concatC = K.layers.concatenate(
        [up, encodeC], axis=concat_axis, name="concatC")

    decodeB = ConvolutionBlock(concatC, "decodeB", fms*4, params)

    if use_upsampling:
        up = K.layers.UpSampling3D(name="upB", size=(2, 2, 2))(decodeB)
    else:
        up = K.layers.Conv3DTranspose(name="transconvB", filters=fms*2,
                                      **params_trans)(decodeB)
    concatB = K.layers.concatenate(
        [up, encodeB], axis=concat_axis, name="concatB")

    decodeA = ConvolutionBlock(concatB, "decodeA", fms*2, params)

    if use_upsampling:
        up = K.layers.UpSampling3D(name="upA", size=(2, 2, 2))(decodeA)
    else:
        up = K.layers.Conv3DTranspose(name="transconvA", filters=fms,
                                      **params_trans)(decodeA)
    concatA = K.layers.concatenate(
        [up, encodeA], axis=concat_axis, name="concatA")

    # END - Decoding path

    convOut = ConvolutionBlock(concatA, "convOut", fms, params)

    prediction = K.layers.Conv3D(name="PredictionMask",
                                 filters=number_output_classes, kernel_size=(1, 1, 1),
                                 activation="sigmoid")(convOut)

    model = K.models.Model(inputs=[inputs], outputs=[prediction], name="3d_unet_decathlon")

    model.summary()

    return model


In [None]:
# Ensure dimensions are divisible by 16 (2^4 for 4 max pooling layers)
# Calculate the padding needed to make dimensions divisible by 16
def pad_to_divisible(dim, factor=16):
    return ((dim + factor - 1) // factor) * factor

height = pad_to_divisible(240)
width = pad_to_divisible(240)
depth = pad_to_divisible(155)

# Update the dimensions
input_shape = (None, height, width, depth, 1)
crop_dim = (height, width, depth, 1)

print(f"Using padded dimensions: {crop_dim}")

# Create the model with the padded dimensions
model = unet_3d(fms=filters, input_dim=crop_dim)

In [19]:
# The model is compiled in the "Train the model" cells below

## Define the training callbacks

This includes model checkpoints and TensorBoard logs.

In [25]:
import os
import datetime

# Add .keras extension to the model name
model_filepath = f"{saved_model_name}.keras"

checkpoint = K.callbacks.ModelCheckpoint(
    filepath=model_filepath,
    monitor='val_dice_coef',  # Monitor validation dice coefficient
    mode='max',              # We want to maximize the dice coefficient
    save_best_only=True,     # Only save the best model
    verbose=1
)

# TensorBoard
logs_dir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tb_logs = K.callbacks.TensorBoard(log_dir=logs_dir)

callbacks = [checkpoint, tb_logs]


In [23]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir logs 

## Train the model

In [38]:
# Calculate dimensions that are divisible by 16 (2^4 for 4 max pooling layers)
def pad_to_divisible(dim, factor=16):
    return ((dim + factor - 1) // factor) * factor

# Pad height and width to be divisible by 16
padded_height = pad_to_divisible(240)  # 240 is already divisible by 16
padded_width = pad_to_divisible(240)   # 240 is already divisible by 16
padded_depth = 155  # Not padded: 155 is not divisible by 16, which triggers the error below

# Update crop_dim with the new padded dimensions
crop_dim = (padded_height, padded_width, padded_depth, 1)
print(f"Using padded dimensions: {crop_dim}")

# Clear any previous models from memory
tf.keras.backend.clear_session()

# Create the model with the padded dimensions
model = unet_3d(fms=filters, input_dim=crop_dim)

# Create optimizer with specified learning rate
optimizer = K.optimizers.Adam(learning_rate=0.0001)

# Compile the model
model.compile(loss=dice_loss, 
             metrics=[dice_coef, soft_dice_coef], 
             optimizer=optimizer)

# Print model summary to verify the dimensions
model.summary()

Using padded dimensions: (240, 240, 155, 1)



ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concatenation axis. Received: input_shape=[(None, 30, 30, 18, 64), (None, 30, 30, 19, 64)]

In [36]:
# Update crop_dim to match the actual data dimensions
crop_dim = (240, 240, 155, 1)  # Use the original dimensions from the data

# Recreate the model with the correct dimensions
model = unet_3d(fms=filters, input_dim=crop_dim)

# Recompile the model
optimizer = K.optimizers.Adam(learning_rate=0.0001)
model.compile(loss=dice_loss, 
             metrics=[dice_coef, soft_dice_coef], 
             optimizer=optimizer)

# Train the model
history = model.fit(
    ds_train,
    epochs=num_epochs,
    validation_data=ds_val,
    callbacks=callbacks,
    verbose=2
)

ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concatenation axis. Received: input_shape=[(None, 30, 30, 18, 64), (None, 30, 30, 19, 64)]
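Both cells fail for the same reason. `MaxPooling3D` with a 2×2×2 pool floors odd sizes, while `Conv3DTranspose` with stride 2 and `padding="same"` exactly doubles them. For the skip connections to line up, every spatial size must therefore be divisible by 2^4 = 16. Tracing the depth axis with the hypothetical helper below reproduces the reported mismatch (18 vs 19 at the first concatenation), while 240 and 160 pass.

```python
def trace_unet_depth(size, levels=4):
    """Follow one spatial dimension through the encoder and decoder."""
    encoder = [size]
    for _ in range(levels):
        encoder.append(encoder[-1] // 2)  # MaxPooling3D(pool_size=2), "valid": floor
    mismatches = []
    for level in range(levels, 0, -1):
        upsampled = encoder[level] * 2    # Conv3DTranspose(strides=2, padding="same")
        skip = encoder[level - 1]         # encoder output it is concatenated with
        if upsampled != skip:
            mismatches.append((upsampled, skip))
    # Keras stops at the first mismatch (concatD); later ones are listed for completeness
    return encoder, mismatches


for size in (240, 155, 160):
    print(size, trace_unet_depth(size))
# 240 ([240, 120, 60, 30, 15], [])
# 155 ([155, 77, 38, 19, 9], [(18, 19), (76, 77), (154, 155)])
# 160 ([160, 80, 40, 20, 10], [])
```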

In [83]:
# Calculate dimensions that are divisible by 16 (2^4 for 4 max pooling layers)
def pad_to_divisible(dim, factor=16):
    return ((dim + factor - 1) // factor) * factor

# Pad dimensions to be divisible by 16
padded_height = pad_to_divisible(240)
padded_width = pad_to_divisible(240)
padded_depth = pad_to_divisible(155)

# Update crop_dim to match padded dimensions
crop_dim = (padded_height, padded_width, padded_depth, 1)
print(f"Using padded dimensions: {crop_dim}")

# Recreate the model with padded dimensions
model = unet_3d(fms=filters, input_dim=crop_dim)

# Create a fresh optimizer instance
optimizer = K.optimizers.Adam(learning_rate=0.0001)

# Compile the model with the new optimizer
model.compile(loss=dice_loss, 
             metrics=[dice_coef, soft_dice_coef], 
             optimizer=optimizer)


Using padded dimensions: (240, 240, 160, 1)
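The model compiled above expects inputs with depth 160, but `read_nifti_file` still returns 155-slice volumes. `crop()` cannot enlarge a dimension: when `crop_dim` exceeds the image, it falls back to `start = 0`, and the slice simply stops at the edge. The volumes themselves therefore need padding. Below is a minimal sketch with `np.pad`. It assumes zero padding at the end of each spatial axis is acceptable for this data; `pad_volume` is a hypothetical helper, not part of the original notebook.

```python
import numpy as np


def pad_volume(arr, target_spatial):
    """Zero-pad the first three (spatial) axes of arr at the end, up to target_spatial.

    Axes beyond the third (e.g. channels) are left unchanged; dimensions that
    already meet or exceed the target are not padded.
    """
    pad_width = [(0, max(t - s, 0)) for s, t in zip(arr.shape[:3], target_spatial)]
    pad_width += [(0, 0)] * (arr.ndim - 3)
    return np.pad(arr, pad_width)  # default mode="constant" pads with 0


# Small stand-ins for a (240, 240, 155, 1) image and a (240, 240, 155) label map
demo_img = np.ones((4, 4, 3, 1), dtype=np.float32)
demo_lbl = np.ones((4, 4, 3))
print(pad_volume(demo_img, (4, 4, 8)).shape)  # (4, 4, 8, 1)
print(pad_volume(demo_lbl, (4, 4, 8)).shape)  # (4, 4, 8)
```

In `read_nifti_file`, the image could be padded after `np.expand_dims`, and the label map `msk` before the one-hot step, so the padded slices are labeled background (label 0). The `crop_dim` value and the `ensure_shape`/`TensorSpec` shapes in the `tf.data` cell would then need depth 160 as well.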


## Evaluate the model

Evaluate the final model on the test dataset. This gives us an idea of how the model should perform on data it has never seen.

In [None]:
# Use new names so the dice_coef / soft_dice_coef metric functions are not overwritten
test_loss, test_dice, test_soft_dice = model.evaluate(ds_test)

print("Average Dice Coefficient on test dataset = {:.4f}".format(test_dice))

*Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. SPDX-License-Identifier: EPL-2.0*

*Copyright (c) 2019-2020 Intel Corporation*