In [1]:
# Imports
import os
import numpy as np
import dipy
import nibabel as nib
import matplotlib.pyplot as plt
from dipy.io.image import load_nifti, save_nifti
from dipy.viz import regtools
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow.keras.backend as K

2023-12-10 16:56:14.306313: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-10 16:56:14.325871: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-10 16:56:14.325888: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-10 16:56:14.326419: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-10 16:56:14.329878: I tensorflow/core/platform/cpu_feature_guar

In [2]:
HCP_folder = '/Users/sreekar/mounts/GRG/data/HCP/'
HCP_folder = 'data'

In [3]:
os.listdir(HCP_folder)

['PA', 'T1', '.DS_Store', 'AP']

In [4]:
# Data files
T1_file = HCP_folder + "/T1/102109.nii.gz"

In [5]:
T1, T1_affine = dipy.io.image.load_nifti(T1_file)

In [6]:
T1.shape

(128, 128, 128, 1)

In [7]:
def mutualInformation(bin_centers,
                      sigma_ratio=0.5,    # sigma for soft MI. If not provided, it will be half of a bin length
                      max_clip=1,
                      crop_background=False, # crop_background should never be true if local_mi is True
                      local_mi=False,
                      patch_size=1):
    """
    mutual information for image-image pairs.
    Author: Courtney Guo. See thesis https://dspace.mit.edu/handle/1721.1/123142
    """
    # print("vxm:mutual information loss is experimental.", file=sts.stderr)
    
    if local_mi:
        return localMutualInformation(bin_centers, vol_size=(128, 128, 128), sigma_ratio=sigma_ratio, max_clip=max_clip, patch_size=patch_size)

    else:
        return globalMutualInformation(bin_centers, sigma_ratio, max_clip, crop_background)


def localMutualInformation(bin_centers,
                      vol_size=(128, 128, 128),
                      sigma_ratio=0.5,
                      max_clip=1,
                      patch_size=1):
    """
    Local Mutual Information for image-image pairs
    # vol_size is something like (160, 192, 224)  
    This function assumes that y_true and y_pred are both (batch_sizexheightxwidthxdepthxchan)
    Author: Courtney Guo. See thesis at https://dspace.mit.edu/handle/1721.1/123142
    """
    # print("vxm:mutual information loss is experimental.", file=sts.stderr)

    """ prepare MI. """
    vol_bin_centers = K.variable(bin_centers)
    num_bins = len(bin_centers)
    sigma = np.mean(np.diff(bin_centers))*sigma_ratio

    preterm = K.variable(1 / (2 * np.square(sigma)))

    def local_mi(y_true, y_pred):
        y_pred = K.clip(y_pred, 0, max_clip)
        y_true = K.clip(y_true, 0, max_clip)

        # reshape bin centers to be (1, 1, B)
        o = [1, 1, 1, 1, num_bins]
        vbc = K.reshape(vol_bin_centers, o)
        
        # compute padding sizes
        x, y, z = vol_size
        x_r = -x % patch_size
        y_r = -y % patch_size
        z_r = -z % patch_size
        pad_dims = [[0,0]]
        pad_dims.append([x_r//2, x_r - x_r//2])
        pad_dims.append([y_r//2, y_r - y_r//2])
        pad_dims.append([z_r//2, z_r - z_r//2])
        pad_dims.append([0,0])
        padding = tf.constant(pad_dims)

        # compute image terms
        # num channels of y_true and y_pred must be 1
        I_a = K.exp(- preterm * K.square(tf.pad(y_true, padding, 'CONSTANT')  - vbc))
        I_a /= K.sum(I_a, -1, keepdims=True)

        I_b = K.exp(- preterm * K.square(tf.pad(y_pred, padding, 'CONSTANT')  - vbc))
        I_b /= K.sum(I_b, -1, keepdims=True)

        I_a_patch = tf.reshape(I_a, [(x+x_r)//patch_size, patch_size, (y+y_r)//patch_size, patch_size, (z+z_r)//patch_size, patch_size, num_bins])
        I_a_patch = tf.transpose(I_a_patch, [0, 2, 4, 1, 3, 5, 6])
        I_a_patch = tf.reshape(I_a_patch, [-1, patch_size**3, num_bins])

        I_b_patch = tf.reshape(I_b, [(x+x_r)//patch_size, patch_size, (y+y_r)//patch_size, patch_size, (z+z_r)//patch_size, patch_size, num_bins])
        I_b_patch = tf.transpose(I_b_patch, [0, 2, 4, 1, 3, 5, 6])
        I_b_patch = tf.reshape(I_b_patch, [-1, patch_size**3, num_bins])

        # compute probabilities
        I_a_permute = K.permute_dimensions(I_a_patch, (0,2,1))
        pab = K.batch_dot(I_a_permute, I_b_patch)  # should be the right size now, nb_labels x nb_bins
        pab /= patch_size**3
        pa = tf.reduce_mean(I_a_patch, 1, keepdims=True)
        pb = tf.reduce_mean(I_b_patch, 1, keepdims=True)
        
        papb = K.batch_dot(K.permute_dimensions(pa, (0,2,1)), pb) + K.epsilon()
        mi = K.mean(K.sum(K.sum(pab * K.log(pab/papb + K.epsilon()), 1), 1))

        return mi

    def loss(y_true, y_pred):
        return -local_mi(y_true, y_pred)

    return loss

def globalMutualInformation(bin_centers,
                      sigma_ratio=0.5,
                      max_clip=1,
                      crop_background=False):
    """
    Mutual Information for image-image pairs
    Building from neuron.losses.MutualInformationSegmentation()    
    This function assumes that y_true and y_pred are both (batch_size x height x width x depth x nb_chanels)
    Author: Courtney Guo. See thesis at https://dspace.mit.edu/handle/1721.1/123142
    """
    # print("vxm:mutual information loss is experimental.", file=sts.stderr)

    """ prepare MI. """
    vol_bin_centers = K.variable(bin_centers)
    num_bins = len(bin_centers)
    sigma = np.mean(np.diff(bin_centers))*sigma_ratio

    preterm = K.variable(1 / (2 * np.square(sigma)))

    def mi(y_true, y_pred):
        """ soft mutual info """
        y_pred = K.clip(y_pred, 0, max_clip)
        y_true = K.clip(y_true, 0, max_clip)

        if crop_background:
            # does not support variable batch size
            thresh = 0.0001
            padding_size = 20
            filt = tf.ones([padding_size, padding_size, padding_size, 1, 1])

            smooth = tf.nn.conv3d(y_true, filt, [1, 1, 1, 1, 1], "SAME")
            mask = smooth > thresh
            # mask = K.any(K.stack([y_true > thresh, y_pred > thresh], axis=0), axis=0)
            y_pred = tf.boolean_mask(y_pred, mask)
            y_true = tf.boolean_mask(y_true, mask)
            y_pred = K.expand_dims(K.expand_dims(y_pred, 0), 2)
            y_true = K.expand_dims(K.expand_dims(y_true, 0), 2)

        else:
            # reshape: flatten images into shape (batch_size, heightxwidthxdepthxchan, 1)
            y_true = K.reshape(y_true, (-1, K.prod(K.shape(y_true)[1:])))
            y_true = K.expand_dims(y_true, 2)
            y_pred = K.reshape(y_pred, (-1, K.prod(K.shape(y_pred)[1:])))
            y_pred = K.expand_dims(y_pred, 2)
        
        nb_voxels = tf.cast(K.shape(y_pred)[1], tf.float32)

        # reshape bin centers to be (1, 1, B)
        o = [1, 1, np.prod(vol_bin_centers.get_shape().as_list())]
        vbc = K.reshape(vol_bin_centers, o)
        
        # compute image terms
        I_a = K.exp(- preterm * K.square(y_true  - vbc))
        I_a /= K.sum(I_a, -1, keepdims=True)

        I_b = K.exp(- preterm * K.square(y_pred  - vbc))
        I_b /= K.sum(I_b, -1, keepdims=True)

        # compute probabilities
        I_a_permute = K.permute_dimensions(I_a, (0,2,1))
        pab = K.batch_dot(I_a_permute, I_b)  # should be the right size now, nb_labels x nb_bins
        pab /= nb_voxels
        pa = tf.reduce_mean(I_a, 1, keepdims=True)
        pb = tf.reduce_mean(I_b, 1, keepdims=True)
        
        papb = K.batch_dot(K.permute_dimensions(pa, (0,2,1)), pb) + K.epsilon()
        mi = K.sum(K.sum(pab * K.log(pab/papb + K.epsilon()), 1), 1)

        return mi

    def loss(y_true, y_pred):
        return -mi(y_true, y_pred)

    return loss


In [8]:
def gradient_loss(phi, norm=2):
    di = tf.abs(phi[:, 1:, :, :, :] - phi[:, :-1, :, :, :])
    dj = tf.abs(phi[:, :, 1:, :, :] - phi[:, :, :-1, :, :])
    dk = tf.abs(phi[:, :, :, 1:, :] - phi[:, :, :, :-1, :])

    loss = tf.reduce_mean(di) + tf.reduce_mean(dj) + tf.reduce_mean(dk)
    if norm == 2:
        loss = tf.reduce_mean(di**2) + tf.reduce_mean(dj**2) + tf.reduce_mean(dk**2)    
    return loss

In [9]:
def regular_grid_3d(depth, height, width):
    i = tf.linspace(-1.0, 1.0, depth)
    j = tf.linspace(-1.0, 1.0, height)
    k = tf.linspace(-1.0, 1.0, width)

    I, J, K = tf.meshgrid(i, j, k, indexing='ij')

    grid = tf.stack([I, J, K], axis=-1)
    return grid

In [10]:
def grid_sample_3d(moving, grid):
    nb, nd, nh, nw, nc = tf.shape(moving)

    i = grid[..., 0]  # shape (N, D, H, W)
    j = grid[..., 1]
    k = grid[..., 2]
    i = tf.cast(i, 'float32')
    j = tf.cast(j, 'float32')
    k = tf.cast(k, 'float32')

    # Scale i, j and k from [-1.0, 1.0] to [0, D], [0, H] and [0, W] respectively.
    i = (i + 1.0) * 0.5 * tf.cast(nd-1, 'float32')
    j = (j + 1.0) * 0.5 * tf.cast(nh-1, 'float32')
    k = (k + 1.0) * 0.5 * tf.cast(nw-1, 'float32')

    i_max = tf.cast(nd - 1, 'int32')
    j_max = tf.cast(nh - 1, 'int32')
    k_max = tf.cast(nw - 1, 'int32')
    zero = tf.constant(0, 'int32')

    # The value at (i, j, k) is a weighted average of the values at the
    # eight nearest integer locations: (i0, j0, k0), (i0, j0, k1), (i0, j1, k0),
    # (i0, j1, k1), (i1, j0, k0), (i1, j0, k1), (i1, j1, k0) and (i1, j1, k1)
    # where i0 = floor(i), i1 = ceil(i).
    i0 = tf.cast(tf.floor(i), 'int32')
    i1 = i0 + 1
    j0 = tf.cast(tf.floor(j), 'int32')
    j1 = j0 + 1
    k0 = tf.cast(tf.floor(k), 'int32')
    k1 = k0 + 1

    # Make sure indices are within the boundaries of the image.
    i0 = tf.clip_by_value(i0, zero, i_max)
    i1 = tf.clip_by_value(i1, zero, i_max)
    j0 = tf.clip_by_value(j0, zero, j_max)
    j1 = tf.clip_by_value(j1, zero, j_max)
    k0 = tf.clip_by_value(k0, zero, k_max)
    k1 = tf.clip_by_value(k1, zero, k_max)

    # Collect indices of the four corners.
    b = tf.ones_like(i0) * tf.reshape(tf.range(nb), [nb, 1, 1, 1])
    idx_a = tf.stack([b, i1, j0, k0], axis=-1)  # all front-top-left corners
    idx_b = tf.stack([b, i1, j1, k0], axis=-1)  # all front-bottom-left corners
    idx_c = tf.stack([b, i1, j0, k1], axis=-1)  # all front-top-right corners
    idx_d = tf.stack([b, i1, j1, k1], axis=-1)  # all front-bottom-right corners
    idx_e = tf.stack([b, i0, j0, k0], axis=-1)  # all back-top-left corners
    idx_f = tf.stack([b, i0, j1, k0], axis=-1)  # all back-bottom-left corners
    idx_g = tf.stack([b, i0, j0, k1], axis=-1)  # all back-top-right corners
    idx_h = tf.stack([b, i0, j1, k1], axis=-1)  # all back-bottom-right corners
    # shape (N, D, H, W, 3)

    # Collect values at the corners.
    moving_a = tf.gather_nd(moving, idx_a)  # all front-top-left values
    moving_b = tf.gather_nd(moving, idx_b)  # all front-bottom-left values
    moving_c = tf.gather_nd(moving, idx_c)  # all front-top-right values
    moving_d = tf.gather_nd(moving, idx_d)  # all front-bottom-right values
    moving_e = tf.gather_nd(moving, idx_e)  # all back-top-left values
    moving_f = tf.gather_nd(moving, idx_f)  # all back-bottom-left values
    moving_g = tf.gather_nd(moving, idx_g)  # all back-top-right values
    moving_h = tf.gather_nd(moving, idx_h)  # all back-bottom-right values
    # shape (N, D, H, W, C)

    i0_f = tf.cast(i0, 'float32')
    i1_f = tf.cast(i1, 'float32')
    j0_f = tf.cast(j0, 'float32')
    j1_f = tf.cast(j1, 'float32')
    k0_f = tf.cast(k0, 'float32')
    k1_f = tf.cast(k1, 'float32')

    # Calculate the weights.
    wa = tf.expand_dims((i - i0_f) * (j1_f - j) * (k1_f - k), axis=-1)
    wb = tf.expand_dims((i - i0_f) * (j - j0_f) * (k1_f - k), axis=-1)
    wc = tf.expand_dims((i - i0_f) * (j1_f - j) * (k - k0_f), axis=-1)
    wd = tf.expand_dims((i - i0_f) * (j - j0_f) * (k - k0_f), axis=-1)
    we = tf.expand_dims((i1_f - i) * (j1_f - j) * (k1_f - k), axis=-1)
    wf = tf.expand_dims((i1_f - i) * (j - j0_f) * (k1_f - k), axis=-1)
    wg = tf.expand_dims((i1_f - i) * (j1_f - j) * (k - k0_f), axis=-1)
    wh = tf.expand_dims((i1_f - i) * (j - j0_f) * (k - k0_f), axis=-1)

    # Calculate the weighted sum.
    moved = tf.add_n([wa * moving_a, wb * moving_b, wc * moving_c,
                      wd * moving_d, we * moving_e, wf * moving_f,
                      wg * moving_g, wh * moving_h])
    return moved

In [11]:
def voxelmorph1(input_shape=(320, 320, 1)):
    in_channels = 1
    out_channels = 3
    input_shape = input_shape + (in_channels,)
    moving = layers.Input(shape=input_shape, name='moving')
    static = layers.Input(shape=input_shape, name='static')
    x_in = layers.concatenate([static, moving], axis=-1)

    # encoder
    x1 = layers.Conv3D(16, kernel_size=3, strides=2, padding='same',
                        kernel_initializer='he_normal')(x_in)
    x1 = layers.LeakyReLU(alpha=0.2)(x1)  # 16

    x2 = layers.Conv3D(32, kernel_size=3, strides=2, padding='same',
                        kernel_initializer='he_normal')(x1)
    x2 = layers.LeakyReLU(alpha=0.2)(x2)  # 8

    x3 = layers.Conv3D(32, kernel_size=3, strides=2, padding='same',
                        kernel_initializer='he_normal')(x2)
    x3 = layers.LeakyReLU(alpha=0.2)(x3)  # 4

    x4 = layers.Conv3D(32, kernel_size=3, strides=2, padding='same',
                        kernel_initializer='he_normal')(x3)
    x4 = layers.LeakyReLU(alpha=0.2)(x4)  # 2

    # decoder [32, 32, 32, 32, 8, 8]
    x = layers.Conv3D(32, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal')(x4)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.UpSampling3D(size=2)(x)  # 4
    x = layers.concatenate([x, x3], axis=-1)  # 4

    x = layers.Conv3D(32, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.UpSampling3D(size=2)(x)  # 8
    x = layers.concatenate([x, x2], axis=-1)  # 8

    x = layers.Conv3D(32, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.UpSampling3D(size=2)(x)  # 16
    x = layers.concatenate([x, x1], axis=-1)  # 16

    x = layers.Conv3D(32, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)

    x = layers.Conv3D(8, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)  # 16

    x = layers.UpSampling3D(size=2)(x)  # 32
    x = layers.concatenate([x, x_in], axis=-1)
    x = layers.Conv3D(8, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)  # 32

    kernel_initializer = tf.keras.initializers.RandomNormal(mean=0.0,
                                                            stddev=1e-5)
    deformation = layers.Conv3D(out_channels, kernel_size=3, strides=1,
                                padding='same',
                                kernel_initializer=kernel_initializer)(x)

    nb, nd, nh, nw, nc = tf.shape(deformation)

    # Regular grid.
    grid = regular_grid_3d(nd, nh, nw)  # shape (D, H, W, 2)
    grid = tf.expand_dims(grid, axis=0)  # shape (1, D, H, W, 2)
    multiples = tf.stack([nb, 1, 1, 1, 1])
    grid = tf.tile(grid, multiples)

    # Compute the new sampling grid.
    grid_new = grid + deformation
    grid_new = tf.clip_by_value(grid_new, -1, 1)

    # Sample the moving image using the new sampling grid.
    moved = grid_sample_3d(moving, grid_new)

    model = tf.keras.Model(inputs=[static, moving],
                            outputs=[moved, deformation], name='voxelmorph1')
    return model

In [12]:
@tf.function
def train_step(model, moving, static, criterion, optimizer):
    nb, nd, nh, nw, nc = tf.keras.backend.int_shape(moving)  # moving.shape

    # Repeat the static image along the batch dim.
    multiples = tf.constant([nb, 1, 1, 1, 1], tf.int32)
    static = tf.tile(static, multiples)

    # Define the GradientTape context for automatic differentiation.
    with tf.GradientTape() as tape:
        # Get the deformation field
        # inputs = tf.concat([moving, static], axis=-1)
        moved, deformation = model({'moving': moving, 'static': static})
        # Compute the loss.
        # loss = criterion(moved, static)
        # loss = criterion(static, moved) + 2*gradient_loss(deformation)
        loss = 0.5*criterion(static, moved) + 0.5*criterion(moved, static) + 2*gradient_loss(deformation)
    # Compute gradients.
    grads = tape.gradient(loss, model.trainable_variables)
    # Update the trainable parameters.
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

In [13]:
@tf.function
def test_step(model, moving, static, criterion):
    nb, nd, nh, nw, nc = tf.keras.backend.int_shape(moving)  # moving.shape

    # Repeat the static image along the batch dim.
    multiples = tf.constant([nb, 1, 1, 1, 1], tf.int32)
    static = tf.tile(static, multiples)

    # Get the deformation field.
    # inputs = tf.concat([moving, static], axis=-1)
    moved, deformation = model({'moving': moving, 'static': static}, training=False)

    # Compute the loss.
    # loss = criterion(moved, static)
    # loss = criterion(static, moved) + 2*gradient_loss(deformation)
    loss = 0.5*criterion(static, moved) + 0.5*criterion(moved, static) + 2*gradient_loss(deformation)
    return loss

In [14]:
def plot_images(model, moving, static):
    nb, nd, nh, nw, nc = moving.shape

    # Repeat the static image along the batch dim.
    multiples = tf.constant([nb, 1, 1, 1, 1], tf.int32)
    static = tf.tile(static, multiples)

    moved, deformation = model({'moving': moving, 'static': static}, training=False)
    print(deformation.shape, tf.reduce_max(deformation), tf.reduce_min(deformation), tf.reduce_mean(deformation))

    deformation = deformation.numpy()
    moved = moved.numpy().squeeze(axis=-1) * 255.0
    # moved = moved.astype(np.uint8)[:,:,nh//2,...]
    moving = moving.numpy().squeeze(axis=-1) * 255.0
    # moving = moving.astype(np.uint8)[:,:,nh//2,...]
    static = static.numpy().squeeze(axis=-1) * 255.0
    # static = static.astype(np.uint8)[:,:,nh//2,...]


    # # Plot images.
    # fig = plt.figure(figsize=(3 * 1.7, nb * 1.7))
    # titles_list = ['Static', 'Moved', 'Moving']
    # images_list = [static, moved, moving]

    # moved = moving
    dd = []
    for i in range(nb):
        # for j in range(3):
        #     ax = fig.add_subplot(nb, 3, i * 3 + j + 1)
        #     if i == 0:
        #         ax.set_title(titles_list[j], fontsize=20)
        #     ax.set_axis_off()
        #     ax.imshow(images_list[j][i], cmap='gray')
        regtools.overlay_slices(static[i], moved[i], None, 0,
                                "%d Static"%i, "%d Transformed"%i,
                                "%d_0.png" % (i))
        regtools.overlay_slices(static[i], moved[i], None, 1,
                                "%d Static"%i, "%d Transformed"%i,
                                "%d_1.png" % (i))
        regtools.overlay_slices(static[i], moved[i], None, 2,
                                "%d Static"%i, "%d Transformed"%i,
                                "%d_2.png" % (i))
        d = {'static': static[i].astype(np.uint8), 'moved': moved[i].astype(np.uint8), 'moving': moving[i].astype(np.uint8)}
        dd.append(d)
    print(moved.min(), moved.max(), static.min(), static.max(), moving.min(), moving.max())
    np.save('/content/drive/My Drive/DIPY/T1_T2_new/results_200.npy', dd)

    # plt.tight_layout()
    # plt.show()

In [15]:
# Load 10 images from dataset
T1_dir = HCP_folder + "/T1/"

T1_files = os.listdir(T1_dir)
T1_files.sort()
# T1_files = T1_files[:200]


In [16]:
def read_dataset(file_list):
    images = []
    for file in file_list:
        image, _ = load_nifti(T1_dir + file)
        # image = np.expand_dims(image, axis=-1)
        images.append(image)
    images = np.array(images)
    return images

In [17]:
# imgs = read_dataset(T1_files)

In [18]:
# imgs.shape

In [19]:
def main(args):

    # brain = np.load('/content/drive/My Drive/DIPY/t1_t2/t1_t2_affine_128.npy')
    # static = np.load('/content/drive/My Drive/DIPY/t1_t2/static_128.npy')
    dataset = read_dataset(T1_files)
    print("dataset:",dataset.shape)
    brain = dataset[:50,...]
    print("brain:",brain.shape)
    static = dataset[50::,...]
    print("static:",static.shape)
    del dataset
    brain = brain.astype(np.float32)/255.0
    x_train = brain[:25,...][...,None].copy()
    print("x_train:",x_train.shape)
    # x_train = x_train.astype(np.float32)/255.0
    x_test = brain[25:,...][...,None].copy()
    print("x_test:",x_test.shape)
    # x_test = x_test.astype(np.float32)/255.0
    x_sample = x_test.copy()
    # static = brain[-1,...][None,...,None]
    static = static[None,...,None]
    print("static:",static.shape)
    static = static.astype(np.float32)/255.0



    min_val = min(static.min(), brain.min())
    max_val = max(static.max(), brain.max())
    print(min_val, max_val)
    bin_centers = np.linspace(min_val, max_val, 32)

    del brain


    x_train = tf.convert_to_tensor(x_train, dtype='float32')
    x_test = tf.convert_to_tensor(x_test, dtype='float32')
    x_sample = tf.convert_to_tensor(x_sample, dtype='float32')
    static = tf.convert_to_tensor(static, dtype='float32')
    # print(x_train.shape, x_test.shape, x_sample.shape, static.shape)

    from_tensor_slices = tf.data.Dataset.from_tensor_slices
    x_train = from_tensor_slices(x_train).shuffle(10000).batch(args.batch_size)
    x_test = from_tensor_slices(x_test).shuffle(10000).batch(args.batch_size)

    # S = 128
    # # Create a model instance.
    # model = voxelmorph1(input_shape=(S, S, S))
    model = voxelmorph1(input_shape=(128, 128, 128))

    # replace
    # model.load_weights('drive/My Drive/DIPY/t1_t2/results/checkpoints/59.h5')
    # # or
    # tf.keras.models.load_model('name of .h5')


    # Select optimizer and loss function.
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)
    # criterion = ncc_new
    criterion = mutualInformation(bin_centers, sigma_ratio=0.5, max_clip=1, crop_background=False, local_mi=True, patch_size=12)

    # Define the metrics to track training and testing losses.
    m_train = tf.keras.metrics.Mean(name='loss_train')
    m_test = tf.keras.metrics.Mean(name='loss_test')

    # Train and evaluate the model.
    for epoch in range(args.epochs):
        m_train.reset_states()
        m_test.reset_states()
        for i, moving in enumerate(x_train):
            loss_train = train_step(model, moving, static, criterion,
                                    optimizer)
            m_train.update_state(loss_train)

        for i, moving in enumerate(x_test):
            loss_test = test_step(model, moving, static, criterion)
            m_test.update_state(loss_test)

        model.save_weights('data/checkpoints/%d_2.h5'%epoch)
        print('Epoch: %3d/%d\tTrain Loss: %.6f\tTest Loss: %.6f'
              % (epoch + 1, args.epochs, m_train.result(), m_test.result()))
    print('\n')

    # Show sample results.
    plot_images(model, x_sample, static)

    # Save the trained model.
    if args.save_model:
        model.save_weights('data/results/voxelmorph1-weights.h5')

In [20]:
if __name__ == '__main__':

    class Args():
        batch_size = 1
        epochs = 60
        lr = 0.001
        # label = 1  # which digit images to train on?
        # num_samples = 5  # number of sample results to show
        save_model = True

    args = Args()
    main(args)

dataset: (1065, 128, 128, 128, 1)
brain: (50, 128, 128, 128, 1)
static: (1015, 128, 128, 128, 1)
x_train: (25, 128, 128, 128, 1, 1)
x_test: (25, 128, 128, 128, 1, 1)
static: (1, 1015, 128, 128, 128, 1, 1)
0.0 0.003921569


2023-12-10 16:57:16.350261: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-12-10 16:57:16.368718: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2256] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


ValueError: in user code:

    File "/tmp/ipykernel_141993/1463614979.py", line 3, in train_step  *
        nb, nd, nh, nw, nc = tf.keras.backend.int_shape(moving)  # moving.shape

    ValueError: too many values to unpack (expected 5)
