# 3D image matching with LDDMM

## import required libraries

In [None]:
import numpy as np
%matplotlib notebook
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.ion() # for drawing in real time
import nibabel as nib
import tensorflow as tf

## load images to register

In [None]:
# Input image filenames.
# The first (atlas) image is the one that gets deformed to match the second (target).
atlas_image_fname, target_image_fname = (
    'Adt27-55_02_Adt27-55_02_MNI.img',
    'Adt27-55_03_Adt27-55_03_MNI.img',
)

In [None]:
# load them: one nibabel image object per filename, atlas first, target second
fnames = [atlas_image_fname, target_image_fname]
img = list(map(nib.load, fnames))

In [None]:
# get info about domains
# This example assumes both images have the same number of voxels and the same
# voxel spacing. The assumption is checked below, so a mismatch fails here with
# a clear message instead of later inside tensorflow.
if not all(fname.lower().endswith('.img')
           for fname in (atlas_image_fname, target_image_fname)):
    # I'm only working with analyze for now
    raise ValueError('Only Analyze images supported for now')
# cast the header fields to plain integer / float arrays so that later
# arithmetic does not inherit the header's storage dtypes
nx = np.asarray(img[0].header['dim'][1:4], dtype=int)
dx = np.asarray(img[0].header['pixdim'][1:4], dtype=float)
if not (np.array_equal(nx, img[1].header['dim'][1:4])
        and np.allclose(dx, img[1].header['pixdim'][1:4])):
    raise ValueError('Atlas and target must have the same grid size and voxel spacing')
# voxel coordinates along each axis, starting at 0
x = [np.arange(nxi)*dxi for nxi,dxi in zip(nx,dx)]

In [None]:
# get the images as numpy arrays.
# img.get_data() is deprecated and removed in newer nibabel releases;
# np.asanyarray(img.dataobj) is its documented replacement.
# These files include a fourth axis for time that we don't want, so only the
# first volume is kept. Plain 3D files are used as-is instead of failing.
def _volume_data(image):
    '''Return the data of a nibabel image as a 3D array (first volume if 4D).'''
    data = np.asanyarray(image.dataobj)
    if data.ndim == 4:
        data = data[:, :, :, 0]
    return data

I = _volume_data(img[0])
J = _volume_data(img[1])

In [None]:
# simple function for drawing 3 slices
def draw_slices(x,I,axlist,**kwargs):
    '''
    Draw three orthogonal slices through the middle of a 3D image.

    Parameters
    ----------
    x : not used by this function.
    I : 3D array to draw.
    axlist : indexable collection of axes; axlist[0], axlist[1] and
        axlist[2] receive the central slice across axis 2, 1 and 0
        respectively.
    **kwargs : forwarded unchanged to every imshow call.
    '''
    for panel, axis in enumerate((2, 1, 0)):
        # the slice at the middle index of this axis (axis removed)
        middle = np.take(I, I.shape[axis]//2, axis=axis)
        axlist[panel].imshow(np.squeeze(middle), **kwargs)
        

In [None]:
# draw my images: atlas slices on the top row, target slices on the bottom row
f, ax = plt.subplots(2, 3)
for row_axes, volume in zip(ax, (I, J)):
    draw_slices(x, volume, row_axes, aspect='equal', cmap='gray')

## Now we need to define a trilinear interpolation function in TensorFlow for 3D data

In [None]:
def interp3(x0,x1,x2,I,phi0,phi1,phi2):
    '''
    Trilinear interpolation of a 3D tensorflow image.

    Sample the image I, whose voxels sit on the regular grid given by the
    1D coordinate arrays x0, x1, x2 (one per axis of I, in axis order),
    at the points (phi0, phi1, phi2). Points outside the grid take the
    value of the nearest boundary voxel (neighbor indices are clamped).

    Parameters
    ----------
    x0, x1, x2 : 1D arrays
        Voxel coordinates along axes 0, 1, 2 of I. Assumed uniformly
        spaced: only the first two entries are used to get the spacing,
        and len(xi) must equal I.shape[i].
    I : 3D tensor (or array)
        Image to sample.
    phi0, phi1, phi2 : tensors (or arrays) of identical shape
        Coordinates of the sample points along axes 0, 1, 2, in the same
        units as x0, x1, x2.

    Returns
    -------
    Tensor with the shape of phi0 holding the interpolated values.
    '''
    I = tf.convert_to_tensor(I)
    # grid size, origin and spacing along each axis (plain Python numbers so
    # they convert cleanly to the tensor dtype)
    nx = [len(x0), len(x1), len(x2)]
    origin = [float(x0[0]), float(x1[0]), float(x2[0])]
    dx = [float(x0[1]-x0[0]), float(x1[1]-x1[0]), float(x2[1]-x2[0])]
    out_shape = tf.shape(phi0)

    # per axis: clamped lower/upper neighbor indices (flattened, int64) and
    # the fractional distance past the lower neighbor
    lower = []
    upper = []
    frac = []
    for phi, o, d, n in zip((phi0, phi1, phi2), origin, dx, nx):
        index = (tf.cast(phi, I.dtype) - o)/d
        index_floor = tf.floor(index)
        frac.append(index - index_floor)
        # clamp while still floating point, then convert to int64 so the flat
        # index below is exact (float32 only represents integers up to 2**24)
        lower.append(tf.reshape(tf.cast(tf.clip_by_value(index_floor, 0.0, n-1.0), tf.int64), [-1]))
        upper.append(tf.reshape(tf.cast(tf.clip_by_value(index_floor+1.0, 0.0, n-1.0), tf.int64), [-1]))

    # indices recall that the LAST INDEX IS CONTIGUOUS (row-major strides)
    I_flat = tf.reshape(I, [-1])
    stride0 = nx[1]*nx[2]
    stride1 = nx[2]
    weights = [(1.0 - p, p) for p in frac]

    # combine the 8 corner samples with their trilinear weights
    Il = None
    for i0, w0 in zip((lower[0], upper[0]), weights[0]):
        for i1, w1 in zip((lower[1], upper[1]), weights[1]):
            for i2, w2 in zip((lower[2], upper[2]), weights[2]):
                corner = tf.reshape(tf.gather(I_flat, stride0*i0 + stride1*i1 + i2), out_shape)
                term = corner*w0*w1*w2
                Il = term if Il is None else Il + term
    return Il

## and we need a gradient function

In [None]:
def grad3(I,dx):
    '''
    Centered-difference gradient of a 3D tensor with periodic boundaries.

    Component k is (I[i+1] - I[i-1]) / (2*dx[k]) along axis k. The
    neighbors come from tf.manip.roll, so they wrap around the edges and
    boundary voxels use values from the opposite face.

    Returns the tuple (dI/dx0, dI/dx1, dI/dx2).
    '''
    def centered(axis):
        ahead = tf.manip.roll(I, shift=-1, axis=axis)
        behind = tf.manip.roll(I, shift=1, axis=axis)
        return (ahead - behind)/2.0/dx[axis]

    # tensors cannot be assigned into, so one-sided differences at the
    # edges (e.g. out[0] = out[1] - out[0]) are not applied here
    return tuple(centered(axis) for axis in range(3))

## now we set some parameters


In [None]:
# gradient descent stepsize. It is fed through the placeholder epsilonph each
# time we take a gradient descent step, so it can be changed from time to time
# without rebuilding the graph.
epsilon = 1e-1
epsilonph = tf.placeholder(tf.float32, shape=())
niter = 20 # number of optimization iterations

# flow parameters
nt = 5 # number of timesteps
dt = 1.0/nt # increment in time for each step
alpha = dx[0]*2.0 # spatial scale of smoothing operator (two voxels along axis 0)
power = 2.0 # power of (identity - alpha^2*Laplacian)

# cost parameters
# matching cost standard deviation, smaller means larger cost.
# A data-driven choice would be 10% of the target intensity range,
# (max(J) - min(J))*0.1, but it was immediately overwritten by this fixed
# value, so only the value that actually takes effect is kept.
sigmaM = 1e1
sigmaR = 1e0 # regularization cost standard deviation

## build a Fourier domain and differential operators

In [None]:
# Frequency of each DFT bin along each axis (unshifted order, in 1/length units)
f0, f1, f2 = (np.arange(n)/d/n for n, d in zip(nx, dx))
F0,F1,F2 = np.meshgrid(f0, f1, f2, indexing='ij')
# Fourier symbol of L = (identity - alpha^2 * discrete Laplacian)^power.
# Per axis, the 3-point second difference (I[i+1] - 2I[i] + I[i-1])/dx^2 has
# symbol (-2 + 2cos(2 pi dx f))/dx^2; the Laplacian symbol is their sum.
Lhat = (1.0 - alpha**2*sum(
    (-2.0 + 2.0*np.cos(2*np.pi*d*F))/d**2
    for F, d in zip((F0, F1, F2), dx)))**power
# TODO: with real-input FFTs only about half of these frequencies would be needed
LLhat = Lhat**2  # symbol of L^T L
Khat = 1.0/LLhat  # symbol of the smoothing kernel K = (L^T L)^{-1}
# convert to tensorflow; Khat is made complex so it can multiply complex spectra
Khattf = tf.complex(tf.constant(Khat,dtype=tf.float32),0.)
LLhattf = tf.constant(LLhat,dtype=tf.float32)

## initialize my optimization variables

Each iteration there will be an old version and a new version of each velocity field component.


In [None]:
# One velocity variable per spatial component, shape (nx0, nx1, nx2, nt),
# plus a "new" copy of each. They are created under AUTO_REUSE so this cell
# can be re-run: get_variable then returns the existing variable instead of
# raising because the name is already taken.
def _velocity_variable(name):
    '''Get or create a zero-initialized, non-trainable float32 velocity variable.'''
    return tf.get_variable(name, shape=[nx[0],nx[1],nx[2],nt], dtype=tf.float32,
                           trainable=False, initializer=tf.zeros_initializer())

with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    vt0 = _velocity_variable('vt0')
    vt1 = _velocity_variable('vt1')
    vt2 = _velocity_variable('vt2')

    vt0new = _velocity_variable('vt0new')
    vt1new = _velocity_variable('vt1new')
    vt2new = _velocity_variable('vt2new')

## Implement the tensorflow graph for one iteration of gradient descent

In [None]:
# tensorflow constants for the atlas (I) and target (J) intensities
Itf, Jtf = (tf.constant(vol, dtype=tf.float32) for vol in (I, J))
# coordinate vectors along each axis, and the full coordinate grid
x0, x1, x2 = x
X0, X1, X2 = np.meshgrid(x0, x1, x2, indexing='ij')
X0tf, X1tf, X2tf = (tf.constant(grid, dtype=tf.float32) for grid in (X0, X1, X2))

In [None]:
# flow forwards
# Each step composes the inverse map with a small backward displacement,
# phiinv_{t+1}(x) = phiinv_t(x - v_t(x)*dt), and deforms the atlas with it.
It = [Itf]
phiinv0 = X0tf
phiinv1 = X1tf
phiinv2 = X2tf
ER = 0
for t in range(nt):
    v0 = vt0[:,:,:,t]
    v1 = vt1[:,:,:,t]
    v2 = vt2[:,:,:,t]
    # use the float32 tensor grid (not the numpy X0, X1, X2) for the sample points
    X0s = X0tf - v0*dt
    X1s = X1tf - v1*dt
    X2s = X2tf - v2*dt

    # update diffeomorphism; the displacement (phiinv - identity) is what gets
    # interpolated, so out-of-grid samples extend the displacement, not phiinv
    phiinv0 = interp3(x0,x1,x2,phiinv0-X0tf,X0s,X1s,X2s)+X0s
    phiinv1 = interp3(x0,x1,x2,phiinv1-X1tf,X0s,X1s,X2s)+X1s
    phiinv2 = interp3(x0,x1,x2,phiinv2-X2tf,X0s,X1s,X2s)+X2s

    # deform the image
    It.append(interp3(x0,x1,x2,Itf,phiinv0,phiinv1,phiinv2))

    # get regularization energy in the Fourier domain (Parseval), hence the
    # 1/(number of elements) normalizer below. LLhat is a 3D symbol, so a full
    # 3D FFT is required: fft2d would only transform the last two axes.
    v0hat = tf.fft3d(tf.cast(v0,tf.complex64))
    v1hat = tf.fft3d(tf.cast(v1,tf.complex64))
    v2hat = tf.fft3d(tf.cast(v2,tf.complex64))
    ER = ER + tf.reduce_sum( ( tf.abs(v0hat)**2 + tf.abs(v1hat)**2 + tf.abs(v2hat)**2 ) * LLhattf )
ER = ER*dt*dx[0]*dx[1]*dx[2]/sigmaR**2/2.0/nx[0]/nx[1]/nx[2]

# now compute the error (against the target tensor Jtf, not the numpy J)
lambda1 = (It[-1] - Jtf)/sigmaM**2

# get matching energy
EM = tf.reduce_sum((It[-1] - Jtf)**2)/sigmaM**2*dx[0]*dx[1]*dx[2]/2.0
E = EM + ER

# flow the error backwards
# phiinv is rebuilt from the identity, composing the velocities from the last
# time step down to t
phiinv0 = X0tf
phiinv1 = X1tf
phiinv2 = X2tf
vt0new_ = []
vt1new_ = []
vt2new_ = []
for t in range(nt-1,-1,-1):
    v0 = vt0[:,:,:,t]
    v1 = vt1[:,:,:,t]
    v2 = vt2[:,:,:,t]
    X0s = X0tf + v0*dt
    X1s = X1tf + v1*dt
    X2s = X2tf + v2*dt
    phiinv0 = interp3(x0,x1,x2,phiinv0-X0tf,X0s,X1s,X2s) + X0s
    phiinv1 = interp3(x0,x1,x2,phiinv1-X1tf,X0s,X1s,X2s) + X1s
    phiinv2 = interp3(x0,x1,x2,phiinv2-X2tf,X0s,X1s,X2s) + X2s

    # compute the gradient of the image at this time
    I_0,I_1,I_2 = grad3(It[t],dx)

    # compute the determinant of the jacobian of phiinv.
    # Differentiate the displacement (phiinv - identity) and add the identity
    # back on the diagonal: grad3 wraps around the edges, so applied to phiinv
    # itself it would difference x[-1] against x[0] at the boundary voxels.
    phiinv0_0,phiinv0_1,phiinv0_2 = grad3(phiinv0-X0tf,dx)
    phiinv1_0,phiinv1_1,phiinv1_2 = grad3(phiinv1-X1tf,dx)
    phiinv2_0,phiinv2_1,phiinv2_2 = grad3(phiinv2-X2tf,dx)
    phiinv0_0 = phiinv0_0 + 1.0
    phiinv1_1 = phiinv1_1 + 1.0
    phiinv2_2 = phiinv2_2 + 1.0
    detjac = phiinv0_0*(phiinv1_1*phiinv2_2 - phiinv1_2*phiinv2_1)\
        - phiinv0_1*(phiinv1_0*phiinv2_2 - phiinv1_2*phiinv2_0)\
        + phiinv0_2*(phiinv1_0*phiinv2_1 - phiinv1_1*phiinv2_0)

    # get the lambda for this time
    lambda_ = interp3(x0,x1,x2,lambda1,phiinv0,phiinv1,phiinv2)*detjac

    # set up the gradient
    grad0 = -lambda_*I_0
    grad1 = -lambda_*I_1
    grad2 = -lambda_*I_2

    # smooth it with the kernel Khat (a 3D Fourier multiplier, so 3D FFTs)
    grad0 = tf.real(tf.ifft3d(tf.fft3d(tf.cast(grad0,tf.complex64))*Khattf))
    grad1 = tf.real(tf.ifft3d(tf.fft3d(tf.cast(grad1,tf.complex64))*Khattf))
    grad2 = tf.real(tf.ifft3d(tf.fft3d(tf.cast(grad2,tf.complex64))*Khattf))

    # add the regularization
    grad0 = grad0 + v0/sigmaR**2
    grad1 = grad1 + v1/sigmaR**2
    grad2 = grad2 + v2/sigmaR**2

    # and calculate the new v
    vt0new_.append(v0 - epsilonph*grad0)
    vt1new_.append(v1 - epsilonph*grad1)
    vt2new_.append(v2 - epsilonph*grad2)

# stack, restoring forward time order
vt0new = tf.stack(vt0new_[::-1],axis=3)
vt1new = tf.stack(vt1new_[::-1],axis=3)
vt2new = tf.stack(vt2new_[::-1],axis=3)

# define a graph operation that updates ALL three velocity components.
# The control dependency makes every new value finish computing (reading the
# old velocities) before any variable is overwritten.
with tf.control_dependencies([vt0new, vt1new, vt2new]):
    step = tf.group(
      vt0.assign(vt0new),
      vt1.assign(vt1new),
      vt2.assign(vt2new))

In [None]:
EMall = []
ERall = []
Eall = []
# display options for the slice images (the loop used imopts without defining it)
imopts = {'aspect': 'equal', 'cmap': 'gray'}

def _mid_slice(vol):
    '''Return the slice through the middle of the last axis of a 3D array.'''
    return vol[:, :, vol.shape[2]//2]

f,ax = plt.subplots(2,2)
# the target does not change during optimization, so draw it once.
# imshow cannot display a whole 3D volume, so the middle slice is shown.
ax[1][1].imshow(_mid_slice(J),**imopts)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(niter):
        # take a step of gradient descent
        sess.run(step, feed_dict={epsilonph: epsilon})

        # Fetch everything in a single run: each separate .eval() would
        # re-execute the whole flow, and one run keeps all values consistent
        # with the same velocity fields.
        Idnp, lambda1np, EMnp, ERnp, Enp = sess.run([It[-1], lambda1, EM, ER, E])

        # clear before redrawing so images don't pile up on the axes
        ax[0][0].cla()
        ax[0][0].imshow(_mid_slice(Idnp),**imopts)
        ax[0][0].set_title('deformed image {}'.format(i))

        ax[0][1].cla()
        ax[0][1].imshow(_mid_slice(lambda1np),**imopts)
        ax[0][1].set_title('error')

        EMall.append(EMnp)
        ERall.append(ERnp)
        Eall.append(Enp)
        ax[1][0].cla()
        ax[1][0].plot(list(zip(Eall,EMall,ERall)))
        xlim = ax[1][0].get_xlim()
        ylim = ax[1][0].get_ylim()
        ax[1][0].set_aspect((xlim[1]-xlim[0])/(ylim[1]-ylim[0]))
        ax[1][0].legend(['Etot','Ematch','Ereg'])

        f.canvas.draw()