# Imports

In [None]:
# imports
import os, sys

# third party imports
import numpy as np
import tensorflow as tf
if not tf.__version__.startswith('2.'):
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    raise RuntimeError('This tutorial assumes Tensorflow 2.0+')

# local imports
import voxelmorph as vxm
import neurite as ne

# U-net Keras Backbone Model (2D Weakly Supervised)

In [None]:
#model

from tensorflow.keras import layers
from tensorflow import keras

def get_model(moving_image_shape, fixed_image_shape, with_label_inputs=True,
              up_filters=(64, 128, 256), down_filters=(256, 128, 64, 32)):
    """Build a residual encoder-decoder that predicts a 2-channel displacement field.

    The moving and fixed images are concatenated along the channel axis, passed
    through a stride-2 entry convolution, one residual downsampling block per
    entry of ``up_filters`` and one residual upsampling block per entry of
    ``down_filters``. A final 3x3 convolution with linear activation emits the
    2-channel displacement field (DDF).

    Args:
        moving_image_shape: Shape of one moving image without the batch axis,
            e.g. ``(80, 80, 1)``.
        fixed_image_shape: Shape of one fixed image without the batch axis.
        with_label_inputs: If True, two extra inputs (moving label, fixed label;
            same shapes as the images) are appended to the model inputs. They
            are not used to compute the output.
        up_filters: Feature depths of the downsampling (encoder) blocks. Each
            block halves the spatial size (the name is historical).
        down_filters: Feature depths of the upsampling (decoder) blocks. Each
            block doubles the spatial size.

    Returns:
        A ``keras.Model`` with a single DDF output of 2 channels.

    Note:
        The spatial size is halved ``1 + len(up_filters)`` times (rounding up,
        ``padding="same"``) and doubled ``len(down_filters)`` times. The output
        therefore matches the input size only when those counts are equal and
        H and W are divisible by ``2 ** (1 + len(up_filters))``. With the
        defaults that means divisible by 16 (e.g. 80x80).
    """
    input_moving_image = keras.Input(moving_image_shape)
    input_fixed_image = keras.Input(fixed_image_shape)

    if with_label_inputs:
        input_moving_label = keras.Input(moving_image_shape)
        input_fixed_label = keras.Input(fixed_image_shape)

    # The two images are stacked as channels of a single network input.
    concatenate_layer = layers.Concatenate(axis=-1)([input_moving_image, input_fixed_image])

    ### [First half of the network: downsampling inputs] ###

    # Entry block: the stride-2 convolution halves the spatial size.
    x = layers.Conv2D(32, 3, strides=2, padding="same")(concatenate_layer)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Encoder blocks differ only in feature depth; each halves the spatial
    # size with max pooling and adds a strided 1x1 projection of its input.
    for filters in up_filters:
        x = layers.Activation("relu")(x)
        x = layers.Conv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual to the new depth and spatial size.
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    # Decoder blocks: stride-1 transposed convs, then 2x upsampling; the
    # residual is upsampled and projected with a 1x1 convolution.
    for filters in down_filters:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Project residual
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Per-pixel regression of the 2-component displacement field (linear output).
    out_ddf = layers.Conv2D(2, 3, activation="linear", padding="same")(x)

    # Define the model
    if with_label_inputs:
        model_inputs = [input_moving_image, input_fixed_image, input_moving_label, input_fixed_label]
    else:
        model_inputs = [input_moving_image, input_fixed_image]
    return keras.Model(inputs=model_inputs, outputs=[out_ddf])

# Data Generator (2D, Weakly Supervised)

In [None]:
#data generator

from skimage.transform import resize

import nibabel as nib

def resize_2d_image(image, shape):
    """Resize ``image`` to ``shape`` and min-max rescale it to [0, 1].

    Singleton axes are squeezed away before resizing. A constant result
    (max == min) is returned without rescaling, which avoids dividing by zero.
    """
    out = resize(image.squeeze(), output_shape=shape)
    lo = np.amin(out)
    hi = np.amax(out)
    if hi == lo:
        return out
    return (out - lo) / (hi - lo)

def train_generator_(f_path, batch_size, moving_image_shape, fixed_image_shape, with_label_inputs=True, num_labels=6):
    """Yield random training batches built from the middle 2D slice of NIfTI volumes.

    Same-named files are read from ``us_images`` (moving) and ``mr_images``
    (fixed) under ``f_path``. With labels, files are also read from
    ``us_labels``/``mr_labels``. For each case the slice at the centre of
    axis 0 is taken and passed to ``resize_2d_image``.

    Args:
        f_path: Directory holding the image (and label) sub-directories.
        batch_size: Number of cases per batch.
        moving_image_shape: Per-sample moving shape, e.g. ``(80, 80, 1)``.
        fixed_image_shape: Per-sample fixed shape.
        with_label_inputs: If True, labels are loaded and included in the
            yielded inputs and targets.
        num_labels: Label channels are drawn uniformly from
            ``range(num_labels)`` (default 6). The range is clamped to the number
            of channels present in both label volumes, so cases with fewer
            channels no longer raise IndexError.

    Yields:
        ``(inputs, targets)`` tuples of float32 tensors. With labels:
        inputs = (moving, fixed, moving_label, fixed_label) and
        targets = (fixed, zero_displacement, fixed_label). Without labels the
        label entries are omitted. ``zero_displacement`` has shape
        ``(batch_size, *moving_image_shape[:-1], 2)``.

    Raises:
        ValueError: On the first ``next()``, if ``mr_images`` contains no files.
    """
    moving_images_path = os.path.join(f_path, 'us_images')
    fixed_images_path = os.path.join(f_path, 'mr_images')
    moving_labels_path = os.path.join(f_path, 'us_labels')
    fixed_labels_path = os.path.join(f_path, 'mr_labels')

    all_names = np.array(os.listdir(fixed_images_path))
    n_cases = len(all_names)
    if n_cases == 0:
        # Without this, every batch would silently be all zeros.
        raise ValueError(f'No images found in {fixed_images_path}')

    def middle_slice(volume):
        # Slice at the centre of axis 0; the remaining axes are kept.
        return volume[volume.shape[0] // 2]

    def to_float32(array):
        return tf.convert_to_tensor(array, dtype=tf.float32)

    while True:
        if n_cases >= batch_size:
            picks = np.random.permutation(n_cases)[:batch_size]
        else:
            # Too few cases for distinct picks: sample with replacement so no
            # batch row is left as zero padding.
            picks = np.random.randint(n_cases, size=batch_size)
        batch_names = all_names[picks]

        moving_images_batch = np.zeros((batch_size, *moving_image_shape))
        fixed_images_batch = np.zeros((batch_size, *fixed_image_shape))

        if with_label_inputs:
            moving_labels_batch = np.zeros((batch_size, *moving_image_shape))
            fixed_labels_batch = np.zeros((batch_size, *fixed_image_shape))

        for i, f_name in enumerate(batch_names):
            moving_image = nib.load(os.path.join(moving_images_path, f_name)).get_fdata()
            fixed_image = nib.load(os.path.join(fixed_images_path, f_name)).get_fdata()
            moving_images_batch[i] = resize_2d_image(middle_slice(moving_image), moving_image_shape)
            fixed_images_batch[i] = resize_2d_image(middle_slice(fixed_image), fixed_image_shape)

            if with_label_inputs:
                moving_label = nib.load(os.path.join(moving_labels_path, f_name)).get_fdata()
                fixed_label = nib.load(os.path.join(fixed_labels_path, f_name)).get_fdata()

                # Pick one label channel at random; the same index is used for
                # both the moving and the fixed label volume.
                n_available = min(num_labels, moving_label.shape[-1], fixed_label.shape[-1])
                label_to_select = np.random.randint(n_available)

                moving_labels_batch[i] = resize_2d_image(middle_slice(moving_label)[..., label_to_select], moving_image_shape)
                fixed_labels_batch[i] = resize_2d_image(middle_slice(fixed_label)[..., label_to_select], fixed_image_shape)

        # Target for the displacement-field output (the smoothness loss).
        zero_phis = np.zeros([batch_size, *moving_image_shape[:-1], 2])

        if with_label_inputs:
            inputs = (
                to_float32(moving_images_batch),
                to_float32(fixed_images_batch),
                to_float32(moving_labels_batch),
                to_float32(fixed_labels_batch),
            )
            outputs = (
                to_float32(fixed_images_batch),
                to_float32(zero_phis),
                to_float32(fixed_labels_batch),
            )
        else:
            inputs = (
                to_float32(moving_images_batch),
                to_float32(fixed_images_batch),
            )
            outputs = (
                to_float32(fixed_images_batch),
                to_float32(zero_phis),
            )

        yield inputs, outputs

def test_generator(f_path, batch_size, moving_image_shape, fixed_image_shape, start_index, end_index, label_num, with_label_inputs=True, drop_remainder=True):
    """Yield evaluation batches of middle 2D slices in a deterministic order.

    Cases are the file names in ``f_path/mr_images``, sorted and then sliced
    with ``[start_index:end_index]`` (either bound may be None). For each case
    the slice at the centre of axis 0 of the ``us_images`` (moving) and
    ``mr_images`` (fixed) volumes is passed to ``resize_2d_image``. With
    labels, channel ``label_num`` of the middle label slice is used.

    Args:
        f_path: Directory holding ``us_images``/``mr_images`` (and
            ``us_labels``/``mr_labels`` when ``with_label_inputs``).
        batch_size: Cases per batch.
        moving_image_shape: Per-sample moving shape, e.g. ``(80, 80, 1)``.
        fixed_image_shape: Per-sample fixed shape.
        start_index: Start of the slice into the sorted case list, or None.
        end_index: End of the slice into the sorted case list, or None.
        label_num: Label channel read from every label volume.
        with_label_inputs: If True, labels are loaded and included.
        drop_remainder: If True (default), a trailing batch smaller than
            ``batch_size`` is skipped. If False, it is yielded with fewer rows.

    Yields:
        ``(inputs, targets)`` where both are lists of float64 numpy arrays.
        With labels: inputs = [moving, fixed, moving_label, fixed_label] and
        targets = [fixed, zero_displacement, fixed_label]. Without labels the
        label entries are omitted.
    """
    moving_images_path = os.path.join(f_path, 'us_images')
    fixed_images_path = os.path.join(f_path, 'mr_images')
    moving_labels_path = os.path.join(f_path, 'us_labels')
    fixed_labels_path = os.path.join(f_path, 'mr_labels')

    # os.listdir order is arbitrary; sort so an index range always selects the
    # same cases on every machine.
    all_names = np.array(sorted(os.listdir(fixed_images_path)))[start_index:end_index]

    # Count batches from the cases actually selected, so out-of-range or
    # negative indices never produce empty (all-zero) batches.
    n_full, remainder = divmod(len(all_names), batch_size)
    n_steps = n_full + (1 if remainder and not drop_remainder else 0)

    for step in range(n_steps):
        batch_names = all_names[step * batch_size:(step + 1) * batch_size]
        n_rows = len(batch_names)  # < batch_size only for a kept trailing batch

        moving_images_batch = np.zeros((n_rows, *moving_image_shape))
        fixed_images_batch = np.zeros((n_rows, *fixed_image_shape))

        if with_label_inputs:
            moving_labels_batch = np.zeros((n_rows, *moving_image_shape))
            fixed_labels_batch = np.zeros((n_rows, *fixed_image_shape))

        for i, f_name in enumerate(batch_names):
            moving_image = nib.load(os.path.join(moving_images_path, f_name)).get_fdata()
            fixed_image = nib.load(os.path.join(fixed_images_path, f_name)).get_fdata()
            # TAKE THE MIDDLE SLICE (axis 0) FROM THE 3D IMAGE
            moving_images_batch[i] = resize_2d_image(moving_image[moving_image.shape[0] // 2], moving_image_shape)
            fixed_images_batch[i] = resize_2d_image(fixed_image[fixed_image.shape[0] // 2], fixed_image_shape)

            if with_label_inputs:
                # No fallback for missing label files: nib.load fails if absent.
                moving_label = nib.load(os.path.join(moving_labels_path, f_name)).get_fdata()
                fixed_label = nib.load(os.path.join(fixed_labels_path, f_name)).get_fdata()

                # The same fixed channel (label_num) is read for every case.
                moving_labels_batch[i] = resize_2d_image(moving_label[moving_label.shape[0] // 2, :, :, label_num], moving_image_shape)
                fixed_labels_batch[i] = resize_2d_image(fixed_label[fixed_label.shape[0] // 2, :, :, label_num], fixed_image_shape)

        zero_phis = np.zeros([n_rows, *moving_image_shape[:-1], 2])

        if with_label_inputs:
            inputs = [moving_images_batch, fixed_images_batch, moving_labels_batch, fixed_labels_batch]
            outputs = [fixed_images_batch, zero_phis, fixed_labels_batch]
        else:
            inputs = [moving_images_batch, fixed_images_batch]
            outputs = [fixed_images_batch, zero_phis]

        yield (inputs, outputs)

# Training Loop (2D, Weakly Supervised, US-MRI; the NCC image loss is weighted 0, so training optimises Dice + smoothness)

In [None]:
# Train VoxelMorph: environment setup for this cell.
# Expose only GPU index 1 to TensorFlow and set its C++ log level to 3.
# NOTE(review): tensorflow is imported earlier in the file, so this presumably
# only works if no GPU has been initialised yet and cannot silence messages
# already logged at import time — confirm.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1', 'TF_CPP_MIN_LOG_LEVEL': '3'})

import matplotlib.pyplot as plt

def _first_epoch_value(history, key):
    """Return ``history[key][0]`` from a Keras ``History.history`` dict, or NaN if missing.

    Per-output loss keys depend on auto-generated layer names (and possibly
    on the Keras version), so a missing key is recorded as NaN with a printed
    warning instead of aborting a long training run with KeyError.
    """
    if key is None or key not in history:
        print(f'    Warning: loss key {key!r} not in training history; recording NaN.')
        return float('nan')
    return history[key][0]


def _build_registration_model(moving_image_shape, fixed_image_shape, lambda_param):
    """Build and compile the weakly supervised registration model.

    Returns ``(registration_model, spatial_transformer)``.

    The model takes [moving image, fixed image, moving label, fixed label] and
    outputs [moved image, displacement field, moved label]. The fixed-label
    input feeds no layer. It exists only so the model accepts the generators'
    4-input batches; the fixed label reaches the Dice loss as a target.
    """
    model = get_model(moving_image_shape, fixed_image_shape, with_label_inputs=False)

    print('\nBackbone model inputs and outputs:')
    print('    input shape: ', ', '.join([str(t.shape) for t in model.inputs]))
    print('    output shape:', ', '.join([str(t.shape) for t in model.outputs]))

    # One warping layer, reused for both the image and the label.
    spatial_transformer = vxm.layers.SpatialTransformer(name='transformer')

    moving_image = model.inputs[0]
    fixed_image = model.inputs[1]
    input_moving_label = keras.Input(moving_image_shape, name="moving_label")
    input_fixed_label = keras.Input(fixed_image_shape, name="fixed_label")
    inputs = [moving_image, fixed_image, input_moving_label, input_fixed_label]

    ddf = model.outputs[0]
    moved_image = spatial_transformer([moving_image, ddf])
    moved_label = spatial_transformer([input_moving_label, ddf])

    registration_model = keras.Model(inputs=inputs, outputs=[moved_image, ddf, moved_label])

    print('\nRegistration network inputs and outputs:')
    print('    input shape: ', ', '.join([str(t.shape) for t in registration_model.inputs]))
    print('    output shape:', ', '.join([str(t.shape) for t in registration_model.outputs]))

    # One loss per output: NCC on the moved image, an L2 gradient (smoothness)
    # penalty on the displacement field, and Dice on the moved label. The NCC
    # weight is 0, so only smoothness and Dice drive the optimisation.
    loss_fns = [vxm.losses.NCC().loss, vxm.losses.Grad('l2').loss, vxm.losses.Dice().loss]
    registration_model.compile(optimizer='Adam', loss=loss_fns, loss_weights=[0, lambda_param, 1])
    return registration_model, spatial_transformer


def _plot_validation_example(moving_images, fixed_images, moved_images,
                             moving_labels, fixed_labels, moved_labels, ddf):
    """Plot the first sample of a validation batch: images, labels and flow field."""
    print(moving_images.shape)
    rows = (
        ('Image', moved_images, moving_images, fixed_images),
        ('Label', moved_labels, moving_labels, fixed_labels),
    )
    for what, moved, moving, fixed in rows:
        panels = (('Moved', moved), ('Moving', moving), ('Fixed', fixed))
        for col, (prefix, arr) in enumerate(panels, start=1):
            plt.subplot(1, 3, col)
            plt.imshow(tf.squeeze(arr[0]), cmap='gray')
            plt.title(f'{prefix} {what}')
        plt.show()

    print("ddf_val shape: ", ddf.shape)
    # Swap and negate the two displacement components for neurite's flow plot
    # (presumably an axis-order/sign convention conversion — confirm).
    ddf_flow = np.stack([-ddf[0][..., 1], -ddf[0][..., 0]], axis=-1)
    ne.plot.flow([ddf_flow], width=5)


def _validate_dice(registration_model, spatial_transformer, val_path,
                   moving_image_shape, fixed_image_shape, verbose):
    """Return one Dice score per validation batch (batch size 4, label channel 0)."""
    dice_scores = []
    val_gen = test_generator(val_path, 4, moving_image_shape, fixed_image_shape,
                             start_index=None, end_index=None, label_num=0,
                             with_label_inputs=True)
    # A plain for-loop stops only on StopIteration, so genuine errors inside a
    # batch propagate instead of silently truncating validation.
    for step, (val_inputs, _val_targets) in enumerate(val_gen):
        moving_images_val, fixed_images_val, moving_labels_val, fixed_labels_val = val_inputs
        _, ddf_val, _ = registration_model.predict(
            (moving_images_val, fixed_images_val, moving_labels_val, fixed_labels_val), verbose=0)
        moved_labels_val = spatial_transformer([moving_labels_val, ddf_val])

        if verbose and step == 0:
            # VISUALLY CHECK THE FIRST BATCH ONLY
            moved_images_val = spatial_transformer([moving_images_val, ddf_val])
            _plot_validation_example(moving_images_val, fixed_images_val, moved_images_val,
                                     moving_labels_val, fixed_labels_val, moved_labels_val, ddf_val)

        # Dice().loss is negated, hence the sign flip to get a Dice score.
        dice_score = np.array(-1.0 * vxm.losses.Dice().loss(
            tf.convert_to_tensor(moved_labels_val, dtype='float32'),
            tf.convert_to_tensor(fixed_labels_val, dtype='float32')))
        dice_scores.append(dice_score)
        print(".", end='')
    return dice_scores


def _plot_training_history(losses, transformer_losses, conv2d_losses, val_dice):
    """Plot loss curves and validation Dice; saves the figure to voxelmorph_val_dice_1.png."""
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.subplots_adjust(wspace=0.5)

    plt.plot(losses, label="Loss")
    plt.plot(transformer_losses, label="Transformer Loss")
    plt.xlabel('Trials')
    plt.ylabel('Losses')
    plt.legend()

    plt.subplot(1, 3, 2)
    plt.plot(conv2d_losses, label="Conv2D Loss")
    plt.xlabel('Trials')
    plt.ylabel('Losses')
    plt.legend()

    plt.subplot(1, 3, 3)
    plt.plot(val_dice, 'r')
    plt.xlabel('Trials')
    plt.ylabel('Dice')
    plt.savefig('voxelmorph_val_dice_1.png')
    plt.show()


def train_model(last_trial=None, latest_weights=None, Verbose=False):
    """Train the 2D weakly supervised US->MR registration network.

    Each trial is one Keras epoch of 32 generator steps on
    ``nifti_data/train``, followed by a mean-Dice evaluation on
    ``nifti_data/val``. After every trial the loss/Dice history is written as
    ``.npy`` files into ``model_save_path``, overwriting earlier files. Every
    8th trial, the weights are saved there too.

    Args:
        last_trial: Trial index stored in ``latest_weights``. Required when
            resuming; training restarts at ``last_trial + 1`` and the saved
            history arrays are truncated to ``last_trial + 1`` entries.
        latest_weights: Path of a ``.weights.h5`` file to resume from, or
            None to start from scratch.
        Verbose: If True, plot the first validation batch (images, labels,
            flow) and the training curves after every trial.

    Raises:
        ValueError: If ``latest_weights`` is given without ``last_trial``.
    """
    if latest_weights is not None and last_trial is None:
        raise ValueError('last_trial must be given when resuming from latest_weights')

    for lambda_param in [0.05]:  # smoothness-loss weight(s) to train with

        moving_image_shape = (80, 80, 1)
        fixed_image_shape = (80, 80, 1)

        registration_model, spatial_transformer = _build_registration_model(
            moving_image_shape, fixed_image_shape, lambda_param)

        f_path = 'nifti_data/train'  # Colab: '/content/drive/MyDrive/nifti_data_smol/train'
        val_path = 'nifti_data/val'  # Colab: '/content/drive/MyDrive/nifti_data_smol/val'

        model_save_path = 'DICE_2D_WS_US-MRI_NCC_checkpoints'
        os.makedirs(model_save_path, exist_ok=True)

        batch_size = 32  # decrease this if you are running out of RAM

        # ------------------------------
        # Resume from last checkpoint
        # ------------------------------
        if latest_weights is not None:
            print(f"Resuming from {latest_weights} (trial {last_trial})")
            registration_model.load_weights(latest_weights)
            start_trial = last_trial + 1

            def load_history(file_name):
                # Keep entries 0..last_trial so the history matches the weights.
                return np.load(os.path.join(model_save_path, file_name)).tolist()[:last_trial + 1]

            val_dice = load_history("val_dice.npy")
            transformer_losses = load_history("transformer_losses.npy")
            losses = load_history("losses.npy")
            conv2d_losses = load_history("conv2d_losses.npy")
        else:
            print("No previous weights found, starting fresh")
            start_trial = 0
            val_dice, losses, transformer_losses, conv2d_losses = [], [], [], []

        train_gen = train_generator_(f_path, batch_size, moving_image_shape, fixed_image_shape, with_label_inputs=True)

        num_trials = 1024

        for trial in range(start_trial, num_trials):

            print(f'\nTrial {trial} / {num_trials-1}:')

            hist = registration_model.fit(train_gen, epochs=1, steps_per_epoch=32, verbose=1)

            dice_scores = _validate_dice(registration_model, spatial_transformer, val_path,
                                         moving_image_shape, fixed_image_shape, Verbose)
            mean_dice = np.mean(dice_scores) if dice_scores else float('nan')
            val_dice.append(mean_dice)

            history = hist.history
            # Trial 0 logs 0 for the total and conv2d losses — presumably so
            # the large first loss does not dominate the plot (confirm).
            losses.append(0 if trial == 0 else history["loss"][0])
            transformer_losses.append(_first_epoch_value(history, "transformer_loss"))
            # The DDF output comes from an auto-named conv2d_<N> layer, so its
            # loss key is found by pattern.
            conv_loss_key = next(
                (key for key in history if key.startswith('conv2d_') and key.endswith('_loss')),
                None,
            )
            conv2d_losses.append(0 if trial == 0 else _first_epoch_value(history, conv_loss_key))

            if Verbose:
                _plot_training_history(losses, transformer_losses, conv2d_losses, val_dice)
            print('    Validation Dice: ', mean_dice)

            np.save(os.path.join(model_save_path, "losses.npy"), np.array(losses))
            np.save(os.path.join(model_save_path, "transformer_losses.npy"), np.array(transformer_losses))
            np.save(os.path.join(model_save_path, "conv2d_losses.npy"), np.array(conv2d_losses))
            np.save(os.path.join(model_save_path, "val_dice.npy"), np.array(val_dice))
            print("Training history saved.")

            if trial % 8 == 0:
                print("Saving Weights:")
                save_path = os.path.join(model_save_path, f"weights_trial_{trial}.weights.h5")
                registration_model.save_weights(save_path)
                print(f"Weights saved to {save_path}")


In [None]:
# Fresh run from trial 0; each trial rewrites the .npy history files in the checkpoint directory.
train_model(last_trial=None, latest_weights=None, Verbose=True)

In [None]:
# Resume after trial 0. train_model saves weights to and reloads history from
# 'DICE_2D_WS_US-MRI_NCC_checkpoints', so the weights come from that directory too.
train_model(0, 'DICE_2D_WS_US-MRI_NCC_checkpoints/weights_trial_0.weights.h5', Verbose=True)

In [None]:
# Starts training from scratch again (trial 0), overwriting the saved .npy history files.
train_model(last_trial=None, latest_weights=None, Verbose=True)