# Medical image registration

In [1]:
import numpy as np
from dipy.viz import regtools
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.keras.layers as layers
from dipy.io.image import load_nifti

2023-12-10 17:53:35.738542: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-10 17:53:35.758455: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-10 17:53:35.758478: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-10 17:53:35.759051: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-10 17:53:35.762588: I tensorflow/core/platform/cpu_feature_guar

In [2]:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        tf.print(e)

# This is telling the default strategy for multi-GPUs
strategy = tf.distribute.MirroredStrategy()

# setting batch size. 1 per each gpu in this case.
BATCH_SIZE_PER_REPLICA = 1
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


2023-12-10 17:53:36.529329: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-12-10 17:53:36.548863: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-12-10 17:53:36.548955: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-

## Cross correlation loss for same modality registration

In [3]:
def ncc(y_true, y_pred):
    eps = tf.constant(1e-7, 'float32')
    ndim = len(tf.keras.backend.int_shape(y_true))

    y_true_mean = tf.reduce_mean(y_true, axis=range(1, ndim-1),
                                  keepdims=True)
    y_pred_mean = tf.reduce_mean(y_pred, axis=range(1, ndim-1),
                                  keepdims=True)

    y_true_std = tf.math.reduce_std(y_true, axis=range(1, ndim-1),
                                    keepdims=True)
    y_pred_std = tf.math.reduce_std(y_pred, axis=range(1, ndim-1),
                                    keepdims=True)

    y_true_hat = (y_true - y_true_mean) / (y_true_std + eps)
    y_pred_hat = (y_pred - y_pred_mean) / (y_pred_std + eps)

    return -tf.reduce_mean(y_true_hat * y_pred_hat)

In [4]:
def mutualInformation(bin_centers,
                      sigma_ratio=0.5,    # soft binning
                      max_clip=1,
                      crop_background=False, 
                      local_mi=False,
                      patch_size=1):
    
    if local_mi:
        return localMutualInformation(bin_centers, vol_size=(128, 128, 128), sigma_ratio=sigma_ratio, max_clip=max_clip, patch_size=patch_size)

    else:
        return globalMutualInformation(bin_centers, sigma_ratio, max_clip, crop_background)

In [5]:
def localMutualInformation(bin_centers,
                      vol_size=(128, 128, 128),
                      sigma_ratio=0.5,
                      max_clip=1,
                      patch_size=1):

    vol_bin_centers = K.variable(bin_centers)
    num_bins = len(bin_centers)
    sigma = np.mean(np.diff(bin_centers))*sigma_ratio

    preterm = K.variable(1 / (2 * np.square(sigma)))

    def local_mi(y_true, y_pred):
        y_pred = K.clip(y_pred, 0, max_clip)
        y_true = K.clip(y_true, 0, max_clip)

        on = [1, 1, 1, 1, num_bins]
        vbc = K.reshape(vol_bin_centers, on)
        
        # compute padding sizes
        x, y, z = vol_size
        x_r = -x % patch_size
        y_r = -y % patch_size
        z_r = -z % patch_size
        pad_dims = [[0,0]]
        pad_dims.append([x_r//2, x_r - x_r//2])
        pad_dims.append([y_r//2, y_r - y_r//2])
        pad_dims.append([z_r//2, z_r - z_r//2])
        pad_dims.append([0,0])
        padding = tf.constant(pad_dims)

        I_a = K.exp(- preterm * K.square(tf.pad(y_true, padding, 'CONSTANT')  - vbc))
        I_a /= K.sum(I_a, -1, keepdims=True)

        I_b = K.exp(- preterm * K.square(tf.pad(y_pred, padding, 'CONSTANT')  - vbc))
        I_b /= K.sum(I_b, -1, keepdims=True)

        I_a_patch = tf.reshape(I_a, [(x+x_r)//patch_size, patch_size, (y+y_r)//patch_size, patch_size, (z+z_r)//patch_size, patch_size, num_bins])
        I_a_patch = tf.transpose(I_a_patch, [0, 2, 4, 1, 3, 5, 6])
        I_a_patch = tf.reshape(I_a_patch, [-1, patch_size**3, num_bins])

        I_b_patch = tf.reshape(I_b, [(x+x_r)//patch_size, patch_size, (y+y_r)//patch_size, patch_size, (z+z_r)//patch_size, patch_size, num_bins])
        I_b_patch = tf.transpose(I_b_patch, [0, 2, 4, 1, 3, 5, 6])
        I_b_patch = tf.reshape(I_b_patch, [-1, patch_size**3, num_bins])

        I_a_permute = K.permute_dimensions(I_a_patch, (0,2,1))
        pab = K.batch_dot(I_a_permute, I_b_patch) 
        pab /= patch_size**3
        pa = tf.reduce_mean(I_a_patch, 1, keepdims=True)
        pb = tf.reduce_mean(I_b_patch, 1, keepdims=True)
        
        papb = K.batch_dot(K.permute_dimensions(pa, (0,2,1)), pb) + K.epsilon()
        mi = K.mean(K.sum(K.sum(pab * K.log(pab/papb + K.epsilon()), 1), 1))

        return mi

    def loss(y_true, y_pred):
        return -local_mi(y_true, y_pred)

    return loss

def globalMutualInformation(bin_centers,
                      sigma_ratio=0.5,
                      max_clip=1,
                      crop_background=False):

    vol_bin_centers = K.variable(bin_centers)
    num_bins = len(bin_centers)
    sigma = np.mean(np.diff(bin_centers))*sigma_ratio

    preterm = K.variable(1 / (2 * np.square(sigma)))

    def mi(y_true, y_pred):

        y_pred = K.clip(y_pred, 0, max_clip)
        y_true = K.clip(y_true, 0, max_clip)

        if crop_background:
            thresh = 0.0001
            padding_size = 20
            filt = tf.ones([padding_size, padding_size, padding_size, 1, 1])

            smooth = tf.nn.conv3d(y_true, filt, [1, 1, 1, 1, 1], "SAME")
            mask = smooth > thresh
            y_pred = tf.boolean_mask(y_pred, mask)
            y_true = tf.boolean_mask(y_true, mask)
            y_pred = K.expand_dims(K.expand_dims(y_pred, 0), 2)
            y_true = K.expand_dims(K.expand_dims(y_true, 0), 2)

        else:
            y_true = K.reshape(y_true, (-1, K.prod(K.shape(y_true)[1:])))
            y_true = K.expand_dims(y_true, 2)
            y_pred = K.reshape(y_pred, (-1, K.prod(K.shape(y_pred)[1:])))
            y_pred = K.expand_dims(y_pred, 2)
        
        nb_voxels = tf.cast(K.shape(y_pred)[1], tf.float32)

        o = [1, 1, np.prod(vol_bin_centers.get_shape().as_list())]
        vbc = K.reshape(vol_bin_centers, o)

        I_a = K.exp(- preterm * K.square(y_true  - vbc))
        I_a /= K.sum(I_a, -1, keepdims=True)

        I_b = K.exp(- preterm * K.square(y_pred  - vbc))
        I_b /= K.sum(I_b, -1, keepdims=True)

        I_a_permute = K.permute_dimensions(I_a, (0,2,1))
        pab = K.batch_dot(I_a_permute, I_b)
        pab /= nb_voxels
        pa = tf.reduce_mean(I_a, 1, keepdims=True)
        pb = tf.reduce_mean(I_b, 1, keepdims=True)
        
        papb = K.batch_dot(K.permute_dimensions(pa, (0,2,1)), pb) + K.epsilon()
        mi = K.sum(K.sum(pab * K.log(pab/papb + K.epsilon()), 1), 1)

        return mi

    def loss(y_true, y_pred):
        return -mi(y_true, y_pred)

    return loss


def gradient_loss(phi, norm=2):
    di = tf.abs(phi[:, 1:, :, :, :] - phi[:, :-1, :, :, :])
    dj = tf.abs(phi[:, :, 1:, :, :] - phi[:, :, :-1, :, :])
    dk = tf.abs(phi[:, :, :, 1:, :] - phi[:, :, :, :-1, :])

    loss = tf.reduce_mean(di) + tf.reduce_mean(dj) + tf.reduce_mean(dk)
    if norm == 2:
        loss = tf.reduce_mean(di**2) + tf.reduce_mean(dj**2) + tf.reduce_mean(dk**2)    
    return loss

def regular_grid_3d(depth, height, width):
    i = tf.linspace(-1.0, 1.0, depth)
    j = tf.linspace(-1.0, 1.0, height)
    k = tf.linspace(-1.0, 1.0, width)

    I, J, K = tf.meshgrid(i, j, k, indexing='ij')

    grid = tf.stack([I, J, K], axis=-1)
    return grid

def grid_sample_3d(moving, grid):
    nb, nd, nh, nw, nc = tf.shape(moving)

    i = grid[..., 0]  # shape (N, D, H, W)
    j = grid[..., 1]
    k = grid[..., 2]
    i = tf.cast(i, 'float32')
    j = tf.cast(j, 'float32')
    k = tf.cast(k, 'float32')

    # Scale i, j and k from [-1.0, 1.0] to [0, D], [0, H] and [0, W] respectively.
    i = (i + 1.0) * 0.5 * tf.cast(nd-1, 'float32')
    j = (j + 1.0) * 0.5 * tf.cast(nh-1, 'float32')
    k = (k + 1.0) * 0.5 * tf.cast(nw-1, 'float32')

    i_max = tf.cast(nd - 1, 'int32')
    j_max = tf.cast(nh - 1, 'int32')
    k_max = tf.cast(nw - 1, 'int32')
    zero = tf.constant(0, 'int32')

    # The value at (i, j, k) is a weighted average of the values at the
    # eight nearest integer locations: (i0, j0, k0), (i0, j0, k1), (i0, j1, k0),
    # (i0, j1, k1), (i1, j0, k0), (i1, j0, k1), (i1, j1, k0) and (i1, j1, k1)
    # where i0 = floor(i), i1 = ceil(i).
    i0 = tf.cast(tf.floor(i), 'int32')
    i1 = i0 + 1
    j0 = tf.cast(tf.floor(j), 'int32')
    j1 = j0 + 1
    k0 = tf.cast(tf.floor(k), 'int32')
    k1 = k0 + 1

    # Make sure indices are within the boundaries of the image.
    i0 = tf.clip_by_value(i0, zero, i_max)
    i1 = tf.clip_by_value(i1, zero, i_max)
    j0 = tf.clip_by_value(j0, zero, j_max)
    j1 = tf.clip_by_value(j1, zero, j_max)
    k0 = tf.clip_by_value(k0, zero, k_max)
    k1 = tf.clip_by_value(k1, zero, k_max)

    # Collect indices of the four corners.
    b = tf.ones_like(i0) * tf.reshape(tf.range(nb), [nb, 1, 1, 1])
    idx_a = tf.stack([b, i1, j0, k0], axis=-1)  # all front-top-left corners
    idx_b = tf.stack([b, i1, j1, k0], axis=-1)  # all front-bottom-left corners
    idx_c = tf.stack([b, i1, j0, k1], axis=-1)  # all front-top-right corners
    idx_d = tf.stack([b, i1, j1, k1], axis=-1)  # all front-bottom-right corners
    idx_e = tf.stack([b, i0, j0, k0], axis=-1)  # all back-top-left corners
    idx_f = tf.stack([b, i0, j1, k0], axis=-1)  # all back-bottom-left corners
    idx_g = tf.stack([b, i0, j0, k1], axis=-1)  # all back-top-right corners
    idx_h = tf.stack([b, i0, j1, k1], axis=-1)  # all back-bottom-right corners
    # shape (N, D, H, W, 3)

    # Collect values at the corners.
    moving_a = tf.gather_nd(moving, idx_a)  # all front-top-left values
    moving_b = tf.gather_nd(moving, idx_b)  # all front-bottom-left values
    moving_c = tf.gather_nd(moving, idx_c)  # all front-top-right values
    moving_d = tf.gather_nd(moving, idx_d)  # all front-bottom-right values
    moving_e = tf.gather_nd(moving, idx_e)  # all back-top-left values
    moving_f = tf.gather_nd(moving, idx_f)  # all back-bottom-left values
    moving_g = tf.gather_nd(moving, idx_g)  # all back-top-right values
    moving_h = tf.gather_nd(moving, idx_h)  # all back-bottom-right values
    # shape (N, D, H, W, C)

    i0_f = tf.cast(i0, 'float32')
    i1_f = tf.cast(i1, 'float32')
    j0_f = tf.cast(j0, 'float32')
    j1_f = tf.cast(j1, 'float32')
    k0_f = tf.cast(k0, 'float32')
    k1_f = tf.cast(k1, 'float32')

    # Calculate the weights.
    wa = tf.expand_dims((i - i0_f) * (j1_f - j) * (k1_f - k), axis=-1)
    wb = tf.expand_dims((i - i0_f) * (j - j0_f) * (k1_f - k), axis=-1)
    wc = tf.expand_dims((i - i0_f) * (j1_f - j) * (k - k0_f), axis=-1)
    wd = tf.expand_dims((i - i0_f) * (j - j0_f) * (k - k0_f), axis=-1)
    we = tf.expand_dims((i1_f - i) * (j1_f - j) * (k1_f - k), axis=-1)
    wf = tf.expand_dims((i1_f - i) * (j - j0_f) * (k1_f - k), axis=-1)
    wg = tf.expand_dims((i1_f - i) * (j1_f - j) * (k - k0_f), axis=-1)
    wh = tf.expand_dims((i1_f - i) * (j - j0_f) * (k - k0_f), axis=-1)

    # Calculate the weighted sum.
    moved = tf.add_n([wa * moving_a, wb * moving_b, wc * moving_c,
                      wd * moving_d, we * moving_e, wf * moving_f,
                      wg * moving_g, wh * moving_h])
    return moved


In [6]:
def voxelmorph1(input_shape=(128, 128, 1)):
    moving = layers.Input(shape=input_shape, name='moving')
    static = layers.Input(shape=input_shape, name='static')
    x_in = layers.concatenate([static, moving], axis=-1)

    # encoder
    x1 = layers.Conv3D(16, kernel_size=3, strides=2, padding='same',
                        kernel_initializer='he_normal')(x_in)
    x1 = layers.LeakyReLU(alpha=0.2)(x1)  # 16

    x2 = layers.Conv3D(32, kernel_size=3, strides=2, padding='same',
                        kernel_initializer='he_normal')(x1)
    x2 = layers.LeakyReLU(alpha=0.2)(x2)  # 8

    x3 = layers.Conv3D(32, kernel_size=3, strides=2, padding='same',
                        kernel_initializer='he_normal')(x2)
    x3 = layers.LeakyReLU(alpha=0.2)(x3)  # 4

    x4 = layers.Conv3D(32, kernel_size=3, strides=2, padding='same',
                        kernel_initializer='he_normal')(x3)
    x4 = layers.LeakyReLU(alpha=0.2)(x4)  # 2

    # decoder [32, 32, 32, 32, 8, 8]
    x = layers.Conv3D(32, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal')(x4)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.UpSampling3D(size=2)(x)  # 4
    x = layers.concatenate([x, x3], axis=-1)  # 4

    x = layers.Conv3D(32, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.UpSampling3D(size=2)(x)  # 8
    x = layers.concatenate([x, x2], axis=-1)  # 8

    x = layers.Conv3D(32, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.UpSampling3D(size=2)(x)  # 16
    x = layers.concatenate([x, x1], axis=-1)  # 16

    x = layers.Conv3D(32, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)

    x = layers.Conv3D(8, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)  # 16

    x = layers.UpSampling3D(size=2)(x)  # 32
    x = layers.concatenate([x, x_in], axis=-1)
    x = layers.Conv3D(8, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)  # 32

    kernel_initializer = tf.keras.initializers.RandomNormal(mean=0.0,
                                                            stddev=1e-5)
    deformation = layers.Conv3D(3, kernel_size=3, strides=1,
                                padding='same',
                                kernel_initializer=kernel_initializer)(x)

    nb, nd, nh, nw, nc = tf.shape(deformation)

    # Regular grid.
    grid = regular_grid_3d(nd, nh, nw)  # shape (D, H, W, 2)
    grid = tf.expand_dims(grid, axis=0)  # shape (1, D, H, W, 2)
    multiples = tf.stack([nb, 1, 1, 1, 1])
    grid = tf.tile(grid, multiples)

    # Compute the new sampling grid.
    grid_new = grid + deformation
    grid_new = tf.clip_by_value(grid_new, -1, 1)

    # Sample the moving image using the new sampling grid.
    moved = grid_sample_3d(moving, grid_new)

    model = tf.keras.Model(inputs=[static, moving],
                            outputs=[moved, deformation], name='voxelmorph1')
    return model

@tf.function
def train_step(model, moving, moving2, static, criterion, optimizer):
    nb, nd, nh, nw, nc = tf.keras.backend.int_shape(moving)  # moving.shape

    # # Repeat the static image along the batch dim.
    # multiples = tf.constant([nb, 1, 1, 1, 1], tf.int32)
    # static = tf.tile(static, multiples)

    # Define the GradientTape context for automatic differentiation.
    with tf.GradientTape() as tape:
        # Get the deformation field
        # inputs = tf.concat([moving, static], axis=-1)
        moved, deformation = model({'moving': moving, 'static': static})

        moved2, deformation2 = model({'moving': moving2, 'static': static})

        # Compute the loss.
        # loss = criterion(moved, static)
        loss = -1 * criterion(static, moved) + 1 * gradient_loss(deformation) - 1 * criterion(static, moved2) + 1 * gradient_loss(deformation2) - 1 * criterion(moved, moved2)
        
    # Compute gradients.
    grads = tape.gradient(loss, model.trainable_variables)
    # Update the trainable parameters.
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

@tf.function
def test_step(model, moving, moving2, static, criterion):
    nb, nd, nh, nw, nc = tf.keras.backend.int_shape(moving)  # moving.shape

    # Repeat the static image along the batch dim.
    # multiples = tf.constant([nb, 1, 1, 1, 1], tf.int32)
    # static = tf.tile(static, multiples)

    # Get the deformation field.
    # inputs = tf.concat([moving, static], axis=-1)
    moved, deformation = model({'moving': moving, 'static': static}, training=False)
    moved2, deformation2 = model({'moving': moving2, 'static': static}, training=False)

    # Compute the loss.
    # loss = criterion(moved, static)
    loss = -1 * criterion(static, moved) + 1 * gradient_loss(deformation) - 1 * criterion(static, moved2) + 1 * gradient_loss(deformation2) - 1 * criterion(moved, moved2)
    return loss

def plot_images(model, moving, moving2, static):
    nb, nd, nh, nw, nc = moving.shape

    # Repeat the static image along the batch dim.
    # multiples = tf.constant([nb, 1, 1, 1, 1], tf.int32)
    # static = tf.tile(static, multiples)

    moved, deformation = model({'moving': moving, 'static': static}, training=False)
    moved2, deformation2 = model({'moving': moving2, 'static': static}, training=False)

    tf.print(deformation.shape, tf.reduce_max(deformation), tf.reduce_min(deformation), tf.reduce_mean(deformation))
    tf.print(deformation2.shape, tf.reduce_max(deformation2), tf.reduce_min(deformation2), tf.reduce_mean(deformation2))

    deformation = deformation.numpy()
    moved = moved.numpy().squeeze(axis=-1) * 255.0
    # moved = moved.astype(np.uint8)[:,:,nh//2,...]
    moving = moving.numpy().squeeze(axis=-1) * 255.0
    # moving = moving.astype(np.uint8)[:,:,nh//2,...]
    static = static.numpy().squeeze(axis=-1) * 255.0
    # static = static.astype(np.uint8)[:,:,nh//2,...]

    deformation2 = deformation2.numpy()
    moved2 = moved2.numpy().squeeze(axis=-1) * 255.0
    # moved = moved.astype(np.uint8)[:,:,nh//2,...]
    moving2 = moving2.numpy().squeeze(axis=-1) * 255.0
    # moving = moving.astype(np.uint8)[:,:,nh//2,...]

    # # Plot images.
    # fig = plt.figure(figsize=(3 * 1.7, nb * 1.7))
    # titles_list = ['Static', 'Moved', 'Moving']
    # images_list = [static, moved, moving]

    # moved = moving
    for i in range(nb):
        # for j in range(3):
        #     ax = fig.add_subplot(nb, 3, i * 3 + j + 1)
        #     if i == 0:
        #         ax.set_title(titles_list[j], fontsize=20)
        #     ax.set_axis_off()
        #     ax.imshow(images_list[j][i], cmap='gray')
        regtools.overlay_slices(static[i], moved[i], None, 0,
                                "%d Static"%i, "%d Transformed"%i,
                                "%d_0.png" % (i))
        regtools.overlay_slices(static[i], moved[i], None, 1,
                                "%d Static"%i, "%d Transformed"%i,
                                "%d_1.png" % (i))
        regtools.overlay_slices(static[i], moved[i], None, 2,
                                "%d Static"%i, "%d Transformed"%i,
                                "%d_2.png" % (i))
        
        regtools.overlay_slices(static[i], moved2[i], None, 0,
                                "%d Static"%i, "%d Transformed"%i,
                                "%d_0.png" % (i))
        regtools.overlay_slices(static[i], moved2[i], None, 1,
                                "%d Static"%i, "%d Transformed"%i,
                                "%d_1.png" % (i))
        regtools.overlay_slices(static[i], moved2[i], None, 2,
                                "%d Static"%i, "%d Transformed"%i,
                                "%d_2.png" % (i))
        d = {'static': static[i], 'moved': moved[i], 'moving': moving[i], 'deformation': deformation[i],
             'moved2': moved2[i], 'moving2': moving2[i], 'deformation2': deformation2[i]}
        np.save('drive/My Drive/DIPY/t1_t2/results/sample%d.npy'%(i), d)

    # plt.tight_layout()
    # plt.show()

def train_gen():
    import random
    random.shuffle(idx_list)
    for file_idx in idx_list:
        ap, _ = load_nifti(AP_list[file_idx])
        ap = np.expand_dims(np.interp(ap, (ap.min(), ap.max()), (0, 1)), -1)
        pa, _ = load_nifti(PA_list[file_idx])
        pa = np.expand_dims(np.interp(pa, (pa.min(), pa.max()), (0, 1)), -1)
        t1, _ = load_nifti(T1_list[file_idx])
        t1 = np.interp(t1, (t1.min(), t1.max()), (0, 1))
        yield ap, pa, t1

def test_gen():
    for file_idx in range(len(AP_list)-test_n, len(AP_list)):
        ap, _ = load_nifti(AP_list[file_idx])
        ap = np.expand_dims(np.interp(ap, (ap.min(), ap.max()), (0, 1)), -1)
        pa, _ = load_nifti(PA_list[file_idx])
        pa = np.expand_dims(np.interp(pa, (pa.min(), pa.max()), (0, 1)), -1)
        t1, _ = load_nifti(T1_list[file_idx])
        t1 = np.interp(t1, (t1.min(), t1.max()), (0, 1))
        yield ap, pa, t1


In [7]:
def main(args):
    train_ds = tf.data.Dataset.from_generator(
    train_gen,
    output_types=(tf.float32, tf.float32, tf.float32),
    output_shapes=((128, 128, 128, 1), (128, 128, 128, 1), (128, 128, 128, 1))
    )
    train_ds = train_ds.batch(BATCH_SIZE)

    test_ds = tf.data.Dataset.from_generator(
    test_gen,
    output_types=(tf.float32, tf.float32, tf.float32),
    output_shapes=((128, 128, 128, 1), (128, 128, 128, 1), (128, 128, 128, 1))
    )

    test_ds = test_ds.batch(BATCH_SIZE)


    model = voxelmorph1(input_shape = (128, 128, 128, 1))
    # model.load_weights('vxm_dense_brain_T1_3D_mse.h5')

    # Select optimizer and loss function.
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)
    # criterion = ncc_new
    bin_centers = np.linspace(0, 1, 32)
    criterion = mutualInformation(bin_centers, sigma_ratio=0.5, max_clip=1, crop_background=False, local_mi=False, patch_size=12)

    # Define the metrics to track training and testing losses.
    m_train = tf.keras.metrics.Mean(name='loss_train')
    m_test = tf.keras.metrics.Mean(name='loss_test')
    
    # Train and evaluate the model.
    for epoch in range(args.epochs):
        m_train.reset_states()
        m_test.reset_states()

        for moving, moving2, static in train_ds:
            loss_train = train_step(model, moving, moving2, static, criterion,
                                    optimizer)
            m_train.update_state(loss_train)

        for moving, moving2, static in test_ds:
            loss_test = test_step(model, moving, moving2, static, criterion)
            m_test.update_state(loss_test)

        model.save_weights('%d.h5'%epoch)
        tf.print('Epoch: %3d/%d\tTrain Loss: %.6f\tTest Loss: %.6f'
              % (epoch + 1, args.epochs, m_train.result(), m_test.result()))
    tf.print('\n')

    # Save the trained model.
    if args.save_model:
        model.save_weights('voxelmorph-weights.h5')


In [8]:
"""# Moved vs. Static"""

if __name__ == '__main__':
    AP_list = []
    PA_list = []
    T1_list = []
    idx_list = []
    import os
    for subj_n in os.listdir('data/AP'):
        T1_list.append('data/T1/'+subj_n)
        AP_list.append('data/AP/'+subj_n)
        PA_list.append('data/PA/'+subj_n)

    test_n = len(AP_list) // 10
    idx_list = [i for i in range(len(AP_list) - test_n)]
    # tf.print(AP_list[-test_n:])
    class Args():
        batch_size = BATCH_SIZE
        epochs = 100
        lr = 0.001
        # label = 1  # which digit images to train on?
        # num_samples = 5  # number of sample results to show
        save_model = True
    
    args = Args()
    main(args)


2023-12-10 17:53:40.632777: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8904
2023-12-10 17:53:44.508504: I external/local_xla/xla/service/service.cc:168] XLA service 0x7fd382c134a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-12-10 17:53:44.508525: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA RTX A6000, Compute Capability 8.6
2023-12-10 17:53:44.511065: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1702248824.574483  143256 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Epoch:   1/100	Train Loss: 0.007069	Test Loss: 0.001904
Epoch:   2/100	Train Loss: 0.000998	Test Loss: 0.000653
Epoch:   3/100	Train Loss: 0.758103	Test Loss: 0.019494
Epoch:   4/100	Train Loss: 0.008574	Test Loss: 0.005129
Epoch:   5/100	Train Loss: 0.003580	Test Loss: 0.002857
Epoch:   6/100	Train Loss: 0.002631	Test Loss: 0.002028
Epoch:   7/100	Train Loss: 0.001535	Test Loss: 0.000834
Epoch:   8/100	Train Loss: 0.001465	Test Loss: 0.000926
Epoch:   9/100	Train Loss: 0.000899	Test Loss: 0.000311
Epoch:  10/100	Train Loss: 0.000265	Test Loss: 0.000206
Epoch:  11/100	Train Loss: 0.000201	Test Loss: 0.000159
Epoch:  12/100	Train Loss: 0.000147	Test Loss: 0.000129
Epoch:  13/100	Train Loss: 0.000124	Test Loss: 0.000108
Epoch:  14/100	Train Loss: 0.000104	Test Loss: 0.000095
Epoch:  15/100	Train Loss: 0.000092	Test Loss: 0.000096
Epoch:  16/100	Train Loss: 0.000083	Test Loss: 0.000109
Epoch:  17/100	Train Loss: 0.000077	Test Loss: 0.000105
Epoch:  18/100	Train Loss: 0.000073	Test Loss: 0