<a href="https://colab.research.google.com/github/surajsrivathsa/image_registration/blob/main/ADMIR_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Preamble

In [None]:
!pip install --upgrade nibabel

Collecting nibabel
[?25l  Downloading https://files.pythonhosted.org/packages/d7/7f/d3c29792fae50ef4f1f8f87af8a94d5d9fe76550b86ebcf8a251110169d8/nibabel-3.2.0-py3-none-any.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 4.8MB/s 
Installing collected packages: nibabel
  Found existing installation: nibabel 3.0.2
    Uninstalling nibabel-3.0.2:
      Successfully uninstalled nibabel-3.0.2
Successfully installed nibabel-3.2.0


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import numpy as np
from matplotlib import pyplot as plt
import nibabel as nb
import os, sys, glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

print("nibabel version: {}".format(nb.__version__))
print("pytorch version: {}".format(torch.__version__))

nibabel version: 3.2.0
pytorch version: 1.7.0+cu101


In [None]:
t1_fn = '/content/drive/My Drive/Colab Notebooks/ImageRegistrationUsingDeepLearning/ADMIR/Dataset/IXI002-Guys-0828-T1_resampled.nii.gz'
t2_fn = '/content/drive/My Drive/Colab Notebooks/ImageRegistrationUsingDeepLearning/ADMIR/Dataset/IXI002-Guys-0828-T2_resampled.nii.gz'  

In [None]:
data_path = "/content/drive/My Drive/Colab Notebooks/ImageRegistrationUsingDeepLearning/ADMIR/Dataset/"
file_names = glob.glob(os.path.join(data_path, "*.nii.gz"))

In [None]:
len(file_names)

0

# Image Processing

In [None]:
img_nb1 = nb.load(file_names[0])
img_nb1.shape
img_nb2 = nb.load(file_names[1])
img_nb2.shape

IndexError: ignored

In [None]:
def load_4D(name):
    X_nb = nb.load(name)
    X_np = X_nb.dataobj
    X_np = np.reshape(X_np, (1,)+ X_np.shape)
    return X_np

def imgnorm(N_I,index1=0.0001,index2=0.0001):
    I_sort = np.sort(N_I.flatten())
    I_min = I_sort[int(index1*len(I_sort))]
    I_max = I_sort[-int(index2*len(I_sort))]
    N_I =1.0*(N_I-I_min)/(I_max-I_min)
    N_I[N_I>1.0]=1.0
    N_I[N_I<0.0]=0.0
    N_I2 = N_I.astype(np.float32)
    return N_I2

def Norm_Zscore(img):
    img= (img-np.mean(img))/np.std(img) 
    return img

In [None]:
class Dataset(Data.Dataset):
  'Characterizes a dataset for PyTorch'
  def __init__(self, names,iterations,norm=True):
        'Initialization'
        self.names = names
        self.norm = norm
        self.iterations = iterations
  def __len__(self):
        'Denotes the total number of samples'
        return len(self.names) * 2

  def __getitem__(self, step):
        'Generates one sample of data'
        # Select sample
        # print(self.names)
        index_pair = np.random.permutation(len(self.names)) [0:4]
        img_A = load_4D(self.names[index_pair[0]])
        img_B = load_4D(self.names[index_pair[1]])     
        
        if self.norm:
            return  Norm_Zscore(imgnorm(img_A)) , Norm_Zscore(imgnorm(img_B))
        else:
            return img_A, img_B



In [None]:
training_generator = Data.DataLoader(Dataset(file_names,iterations=2,norm=True), batch_size=2, shuffle=True)

In [None]:
ex1 = torch.rand(2, 40, 4, 4, 4)
ex2 = ex1.flatten(start_dim=1, end_dim=4)
ex2.shape

In [None]:
for X,Y in training_generator:
  print(X.shape)
  print(Y.shape)

# Building Affine Model

In [None]:
class Admir_Affine_Encoder(nn.Module):
    def __init__(self, in_channel, start_channel, num_conv_blocks=6):
        self.in_channel = in_channel
        self.start_channel = start_channel
        self.num_conv_blocks = num_conv_blocks
        self.encoder_layer_list = []
        super(Admir_Affine_Encoder, self).__init__()
        self.create_model()

    def affine_conv_block(self, in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1, bias=True ):
      layer = nn.Sequential(nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                            nn.BatchNorm3d(out_channels),
                            nn.LeakyReLU(negative_slope=0.1))
      return layer
    

    def create_model(self):
      for i in range(self.num_conv_blocks):
          if(i == 0):
            lyr = self.affine_conv_block(in_channels = self.in_channel, out_channels = self.start_channel)
            self.encoder_layer_list.append(lyr)
          else:
            lyr = self.affine_conv_block(in_channels= self.start_channel * i, out_channels = self.start_channel * (i+1))
            self.encoder_layer_list.append(lyr)

    def forward(self, x, y):
      # print("x,y", x.shape, "  ", y.shape)
      x_in=torch.cat((x, y), 1)
      e0 = self.encoder_layer_list[0](x_in)
      e1 = self.encoder_layer_list[1](e0)
      e2 = self.encoder_layer_list[2](e1)
      e3 = self.encoder_layer_list[3](e2)
      e4 = self.encoder_layer_list[4](e3)
      return e4


In [None]:
affine_conv_model = Admir_Affine_Encoder(in_channel=2, start_channel=8, num_conv_blocks=5)

In [None]:
class Admir_Affine_Output(nn.Module):
  def __init__(self, in_units, out_units=128, dropout_prob = 0.3):
    
    self.in_units = in_units
    self.out_units = out_units
    self.dropout_prob = dropout_prob
    super(Admir_Affine_Output, self).__init__()
    self.trns_ob = self.translation_output_block(self.in_units, self.out_units)
    self.rss_ob = self.rot_scale_shear_output_block(self.in_units, self.out_units)
    return;
  
  def translation_output_block(self, in_units, out_units):
    layer = nn.Sequential(
          nn.Linear(in_features = in_units, out_features= out_units),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units, out_features= out_units//2),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units//2, out_features= out_units//4),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units//4, out_features= out_units//8),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units//8, out_features= 3))
    return layer

  def rot_scale_shear_output_block(self, in_units, out_units):
    layer = nn.Sequential(
          nn.Linear(in_features = in_units, out_features= out_units),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units, out_features= out_units//2),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units//2, out_features= out_units//4),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units//4, out_features= out_units//8),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units//8, out_features= 9),
          nn.Tanh())
    return layer
  
  def forward(self, input_tnsr):
    ip = input_tnsr.flatten(start_dim=1, end_dim=4)
    #print(ip.shape)
    translation_output = self.trns_ob(ip)
    rotate_scale_shear_output = self.rss_ob(ip)
    return [translation_output, rotate_scale_shear_output]

In [None]:
affine_output_model = Admir_Affine_Output( in_units= 2560)

In [None]:
for X,Y in training_generator:
  print(X.shape)
  print(Y.shape)
  conv_out = affine_conv_model(X, Y)
  print(conv_out.shape)
  output_out = affine_output_model(conv_out)
  print(output_out[0].shape)
  print(output_out[1].shape)
  print("========== ============== =============")
  print()

torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 40, 4, 4, 4])
torch.Size([2, 3])
torch.Size([2, 9])

torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 40, 4, 4, 4])
torch.Size([2, 3])
torch.Size([2, 9])

torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 40, 4, 4, 4])
torch.Size([2, 3])
torch.Size([2, 9])

torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 40, 4, 4, 4])
torch.Size([2, 3])
torch.Size([2, 9])

torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 40, 4, 4, 4])
torch.Size([2, 3])
torch.Size([2, 9])



# Spatial Transform

In [None]:
class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer
    """

    def __init__(self, size, is_affine=False, theta = None, mode='bilinear'):
        super().__init__()

        self.mode = mode
        self.isaffine = is_affine
        self.theta = theta
        # create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)
        # registering the grid as a buffer cleanly moves it to the GPU, but it also
        # adds it to the state dict. this is annoying since everything in the state dict
        # is included when saving weights to disk, so the model files are way bigger
        # than they need to be. so far, there does not appear to be an elegant solution.
        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer('grid', grid)

    def forward(self, src, flow):      
      if (self.isaffine):
        grid = F.affine_grid(self.theta, (2, 1, 128, 128, 128))
        warped_image = F.grid_sample(src, grid)
        return warped_image
      else:
        # new locations
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # need to normalize grid values to [-1, 1] for resampler
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # move channels dim to last position
        # also not sure why, but the channels need to be reversed
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)


In [None]:
spatial_transformer = SpatialTransformer(size=(128, 28, 128), is_affine=True, theta=torch.randn(size=(3, 4)))

In [None]:
print(spatial_transformer.grid.shape)
print(spatial_transformer.isaffine)
print("========= =========== ======")
print()
#print(spatial_transformer.grid)

torch.Size([1, 3, 128, 28, 128])
True



# Deformable ConvNet

In [None]:
class Admir_Deformable_UNet(nn.Module):
  def __init__(self,in_channel  , n_classes,start_channel):
        self.in_channel = in_channel
        self.n_classes = n_classes
        self.start_channel = start_channel
        super(Admir_Deformable_UNet, self).__init__()
        self.eninput = self.encoder(self.in_channel, self.start_channel, bias=False)

        self.ec1 = self.encoder(self.start_channel, self.start_channel, bias=False)
        self.ec2 = self.encoder(self.start_channel, self.start_channel*2, stride=2, bias=False)

        self.ec3 = self.encoder(self.start_channel*2, self.start_channel*2, bias=False)
        self.ec4 = self.encoder(self.start_channel*2, self.start_channel*4, stride=2, bias=False)

        self.ec5 = self.encoder(self.start_channel*4, self.start_channel*4, bias=False)
        self.ec6 = self.encoder(self.start_channel*4, self.start_channel*8, stride=2, bias=False)

       
    
        self.dc1 = self.encoder(self.start_channel*8, self.start_channel*8, kernel_size=3, stride=1, bias=False) 
        self.dc2 = self.encoder(self.start_channel*4, self.start_channel*4, kernel_size=3, stride=1, bias=False)          
        self.dc3 = self.encoder(self.start_channel*2, self.start_channel*2, kernel_size=3, stride=1, bias=False)

        self.up1 = self.decoder(self.start_channel*8, self.start_channel*4)
        self.up2 = self.decoder(self.start_channel*4, self.start_channel*2)
        self.up3 = self.decoder(self.start_channel*2, self.start_channel)

        self.dc4 = self.output(self.start_channel, self.n_classes,kernel_size=1,bias=False)

  def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
                bias=True):
    layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm3d(out_channels),
                nn.LeakyReLU(negative_slope=0.1))
    return layer

  def decoder(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0,
                output_padding=0, bias=True):
    layer = nn.Sequential(
                nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
                               padding=padding, output_padding=output_padding, bias=bias),
            nn.LeakyReLU(negative_slope=0.1))
    return layer
       
  def output(self, in_channels, out_channels, kernel_size=3, 
                bias=False, batchnorm=False):
    layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, bias=bias),
               )
    return layer

  def forward(self, x,y):
        # print("x,y", x.shape, "  ", y.shape)
        x_in=torch.cat((x, y), 1)  
        e0 = self.eninput(x_in)

        # print("e0", e0.shape)

        e0 = self.ec1(e0)
        es1 = self.ec2(e0)   #strided
        # print("e0", e0.shape)
        # print("es1", es1.shape)

        e1 = self.ec3(es1)   
        es2 = self.ec4(e1)   #strided
        # print("e1", e1.shape)
        # print("es2", es2.shape)

        e2 = self.ec5(es2)
        es3 = self.ec6(e2)   #strided
        # print("e2", e2.shape)
        # print("es3", es3.shape)

        

        d0 = self.dc1(es3)
        # print("d0", d0.shape)

        d0 = torch.add(self.up1(d0), e2)
        # print("d0", d0.shape)

        d1 = self.dc2(d0)
        d1 = torch.add(self.up2(d1), e1)
        # print("d1", d1.shape)

        d2 = self.dc3(d1)
        d2 = torch.add(self.up3(d2), e0)
        print("d2", d2.shape)

        output = self.dc4(d2)
        return output

In [None]:
torch.cuda.empty_cache()
model = Admir_Deformable_UNet(2,3,16).cuda() # assining cuda to model

for X,Y in training_generator:
  X = X.cuda().float()
  Y = Y.cuda().float()
  print(X.shape)
  print(Y.shape)
  out = model(X, Y)
  print(out.shape)
  print("========== ============== =============")
  print()

torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
d2 torch.Size([2, 16, 128, 128, 128])
torch.Size([2, 3, 128, 128, 128])

torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
d2 torch.Size([2, 16, 128, 128, 128])
torch.Size([2, 3, 128, 128, 128])



# Loss Function NCC

Reference: https://github.com/yuta-hi/pytorch_similarity

In [None]:
def normalized_cross_correlation(x, y, return_map, reduction='mean', eps=1e-8):
    """ N-dimensional normalized cross correlation (NCC)
    Args:
        x (~torch.Tensor): Input tensor.
        y (~torch.Tensor): Input tensor.
        return_map (bool): If True, also return the correlation map.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'mean'`` | ``'sum'``. Defaults to ``'sum'``.
        eps (float, optional): Epsilon value for numerical stability. Defaults to 1e-8.
    Returns:
        ~torch.Tensor: Output scalar
        ~torch.Tensor: Output tensor
    """

    shape = x.shape
    b = shape[0]

    # reshape
    x = x.view(b, -1)
    y = y.view(b, -1)

    # mean
    x_mean = torch.mean(x, dim=1, keepdim=True)
    y_mean = torch.mean(y, dim=1, keepdim=True)

    # deviation
    x = x - x_mean
    y = y - y_mean

    dev_xy = torch.mul(x,y)
    dev_xx = torch.mul(x,x)
    dev_yy = torch.mul(y,y)

    dev_xx_sum = torch.sum(dev_xx, dim=1, keepdim=True)
    dev_yy_sum = torch.sum(dev_yy, dim=1, keepdim=True)

    ncc = torch.div(dev_xy + eps / dev_xy.shape[1],
                    torch.sqrt( torch.mul(dev_xx_sum, dev_yy_sum)) + eps)
    ncc_map = ncc.view(b, *shape[1:])

    # reduce
    if reduction == 'mean':
        ncc = torch.mean(torch.sum(ncc, dim=1))
    elif reduction == 'sum':
        ncc = torch.sum(ncc)
    else:
        raise KeyError('unsupported reduction type: %s' % reduction)

    if not return_map:
        return ncc

    return ncc, ncc_map


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class NormalizedCrossCorrelation(nn.Module):
    """ N-dimensional normalized cross correlation (NCC)
    Args:
        eps (float, optional): Epsilon value for numerical stability. Defaults to 1e-8.
        return_map (bool, optional): If True, also return the correlation map. Defaults to False.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'mean'`` | ``'sum'``. Defaults to ``'mean'``.
    """
    def __init__(self,
                 eps=1e-8,
                 return_map=False,
                 reduction='mean'):

        super(NormalizedCrossCorrelation, self).__init__()

        self._eps = eps
        self._return_map = return_map
        self._reduction = reduction

    def forward(self, x, y):

        return normalized_cross_correlation(x, y,self._return_map, self._reduction, self._eps)

In [None]:
#  Checking NCC loss

similarity_loss = NormalizedCrossCorrelation()
for X,Y in training_generator:
  X = X.cuda().float()
  Y = Y.cuda().float()
  print(X.shape)
  print(Y.shape)
  out = similarity_loss(X, Y)
  print(out)
  print("========== ============== =============")
  print()

torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
tensor(0.7159, device='cuda:0')

torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
tensor(0.7159, device='cuda:0')



# Regularizer - DVF edge smoothness

In [None]:
x = torch.randn(size=(4, 3, 128, 128, 128))
sobel = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
depth=x.size()[1]
channels=x.size()[2]

print(depth)
print(channels)
print(sobel)
print()
sobel_kernel = torch.FloatTensor(sobel).expand(depth,channels,3,3).unsqueeze(0)
print(sobel_kernel.shape) 
malignacy = F.conv3d(x, sobel_kernel, stride=1, padding=1)
print(malignacy.shape)

3
128
[[1, 2, 1], [0, 0, 0], [-1, -2, -1]]

torch.Size([1, 3, 128, 3, 3])
torch.Size([4, 1, 3, 128, 128])
