<a href="https://colab.research.google.com/github/surajsrivathsa/image_registration/blob/main/ADMIR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Preamble

In [1]:
!pip install --upgrade nibabel

Collecting nibabel
[?25l  Downloading https://files.pythonhosted.org/packages/d7/7f/d3c29792fae50ef4f1f8f87af8a94d5d9fe76550b86ebcf8a251110169d8/nibabel-3.2.0-py3-none-any.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 2.7MB/s 
Installing collected packages: nibabel
  Found existing installation: nibabel 3.0.2
    Uninstalling nibabel-3.0.2:
      Successfully uninstalled nibabel-3.0.2
Successfully installed nibabel-3.2.0


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import numpy as np
from matplotlib import pyplot as plt
import nibabel as nb
import os, sys, glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

print("nibabel version: {}".format(nb.__version__))
print("pytorch version: {}".format(torch.__version__))

nibabel version: 3.2.0
pytorch version: 1.7.0+cu101


In [4]:
t1_fn = '/content/drive/My Drive/Colab Notebooks/image_registration/T1_ixi/IXI002-Guys-0828-T1.nii.gz'
t2_fn = '/content/drive/My Drive/Colab Notebooks/image_registration/T2_ixi/IXI002-Guys-0828-T2.nii.gz'  

In [5]:
data_path = "/content/drive/My Drive/Colab Notebooks/image_registration/resampled_mri/admir_data"
file_names = glob.glob(os.path.join(data_path, "*.nii.gz"))

In [6]:
len(file_names)

5

# Image Processing

In [7]:
img_nb1 = nb.load(file_names[0])
img_nb1.shape
img_nb2 = nb.load(file_names[1])
img_nb2.shape

(128, 128, 128)

In [8]:
def load_4D(name):
    X_nb = nb.load(name)
    X_np = X_nb.dataobj
    X_np = np.reshape(X_np, (1,)+ X_np.shape)
    return X_np

def imgnorm(N_I,index1=0.0001,index2=0.0001):
    I_sort = np.sort(N_I.flatten())
    I_min = I_sort[int(index1*len(I_sort))]
    I_max = I_sort[-int(index2*len(I_sort))]
    N_I =1.0*(N_I-I_min)/(I_max-I_min)
    N_I[N_I>1.0]=1.0
    N_I[N_I<0.0]=0.0
    N_I2 = N_I.astype(np.float32)
    return N_I2

def Norm_Zscore(img):
    img= (img-np.mean(img))/np.std(img) 
    return img

In [9]:
class Dataset(Data.Dataset):
  'Characterizes a dataset for PyTorch'
  def __init__(self, names,iterations,norm=True):
        'Initialization'
        self.names = names
        self.norm = norm
        self.iterations = iterations
  def __len__(self):
        'Denotes the total number of samples'
        return len(self.names) * 2

  def __getitem__(self, step):
        'Generates one sample of data'
        # Select sample
        # print(self.names)
        index_pair = np.random.permutation(len(self.names)) [0:4]
        img_A = load_4D(self.names[index_pair[0]])
        img_B = load_4D(self.names[index_pair[1]])     
        
        if self.norm:
            return  Norm_Zscore(imgnorm(img_A)) , Norm_Zscore(imgnorm(img_B))
        else:
            return img_A, img_B



In [10]:
training_generator = Data.DataLoader(Dataset(file_names,iterations=2,norm=True), batch_size=2, shuffle=True)

In [11]:
ex1 = torch.rand(2, 40, 4, 4, 4)
ex2 = ex1.flatten(start_dim=1, end_dim=4)
ex2.shape

torch.Size([2, 2560])

In [12]:
for X,Y in training_generator:
  print(X.shape)
  print(Y.shape)

torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])


# Building Affine Model

In [13]:
class Admir_Affine_Encoder(nn.Module):
    def __init__(self, in_channel, start_channel, num_conv_blocks=6):
        self.in_channel = in_channel
        self.start_channel = start_channel
        self.num_conv_blocks = num_conv_blocks
        self.encoder_layer_list = []
        super(Admir_Affine_Encoder, self).__init__()
        self.create_model()

    def affine_conv_block(self, in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1, bias=True ):
      layer = nn.Sequential(nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                            nn.BatchNorm3d(out_channels),
                            nn.LeakyReLU(negative_slope=0.1))
      return layer
    

    def create_model(self):
      for i in range(self.num_conv_blocks):
          if(i == 0):
            lyr = self.affine_conv_block(in_channels = self.in_channel, out_channels = self.start_channel)
            self.encoder_layer_list.append(lyr)
          else:
            lyr = self.affine_conv_block(in_channels= self.start_channel * i, out_channels = self.start_channel * (i+1))
            self.encoder_layer_list.append(lyr)

    def forward(self, x, y):
      # print("x,y", x.shape, "  ", y.shape)
      x_in=torch.cat((x, y), 1)
      e0 = self.encoder_layer_list[0](x_in)
      e1 = self.encoder_layer_list[1](e0)
      e2 = self.encoder_layer_list[2](e1)
      e3 = self.encoder_layer_list[3](e2)
      e4 = self.encoder_layer_list[4](e3)
      return e4


In [14]:
affine_conv_model = Admir_Affine_Encoder(in_channel=2, start_channel=8, num_conv_blocks=5)

In [15]:
class Admir_Affine_Output(nn.Module):
  def __init__(self, in_units, out_units=128, dropout_prob = 0.3):
    
    self.in_units = in_units
    self.out_units = out_units
    self.dropout_prob = dropout_prob
    super(Admir_Affine_Output, self).__init__()
    self.trns_ob = self.translation_output_block(self.in_units, self.out_units)
    self.rss_ob = self.rot_scale_shear_output_block(self.in_units, self.out_units)
    return;
  
  def translation_output_block(self, in_units, out_units):
    layer = nn.Sequential(
          nn.Linear(in_features = in_units, out_features= out_units),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units, out_features= out_units//2),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units//2, out_features= out_units//4),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units//4, out_features= out_units//8),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units//8, out_features= 3))
    return layer

  def rot_scale_shear_output_block(self, in_units, out_units):
    layer = nn.Sequential(
          nn.Linear(in_features = in_units, out_features= out_units),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units, out_features= out_units//2),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units//2, out_features= out_units//4),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units//4, out_features= out_units//8),
          nn.Dropout(p=self.dropout_prob),
          nn.Linear(in_features=out_units//8, out_features= 9),
          nn.Tanh())
    return layer
  
  def forward(self, input_tnsr):
    ip = input_tnsr.flatten(start_dim=1, end_dim=4)
    #print(ip.shape)
    translation_output = self.trns_ob(ip)
    rotate_scale_shear_output = self.rss_ob(ip)
    return [translation_output, rotate_scale_shear_output]

In [16]:
affine_output_model = Admir_Affine_Output( in_units= 2560)

In [26]:
# Suraj: Added spatial transformation code for warping
for X,Y in training_generator:
  print(X.shape)
  print(Y.shape)
  conv_out = affine_conv_model(X, Y)
  print(conv_out.shape)
  output_out = affine_output_model(conv_out)
  print(output_out[0].shape)
  print(output_out[1].shape)
  affine_tnsr = torch.cat((output_out[1], output_out[0]), 1)
  theta = torch.reshape(affine_tnsr, (2, 3, 4))
  print("========== ============== =============")
  print(theta.shape)
  print(theta)
  print("========== ============== =============")
  grid = F.affine_grid(theta, (2, 1, 128, 128, 128))
  print(grid.shape)
  #print(grid)
  warped_image = F.grid_sample(Y, grid)
  print("========== ============== =============")
  print()

torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 40, 4, 4, 4])
torch.Size([2, 3])
torch.Size([2, 9])
torch.Size([2, 3, 4])
tensor([[[-0.1220, -0.0160,  0.0155,  0.2409],
         [-0.1245,  0.1718,  0.3042,  0.2036],
         [ 0.3949,  0.1187,  0.0862,  0.1176]],

        [[ 0.1460, -0.1198,  0.0899,  0.1257],
         [-0.0433, -0.0366,  0.3735, -0.1034],
         [ 0.3608, -0.0035,  0.1467,  0.1008]]], grad_fn=<ViewBackward>)
torch.Size([2, 128, 128, 128, 3])





torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 40, 4, 4, 4])
torch.Size([2, 3])
torch.Size([2, 9])
torch.Size([2, 3, 4])
tensor([[[-0.0956,  0.1057,  0.1511,  0.1364],
         [ 0.0248,  0.1241,  0.1382,  0.1282],
         [ 0.2615, -0.0539, -0.1083,  0.1467]],

        [[ 0.0624,  0.0283, -0.0429,  0.3442],
         [ 0.0137,  0.1154,  0.1751,  0.0942],
         [ 0.2517,  0.2129,  0.2051, -0.0048]]], grad_fn=<ViewBackward>)
torch.Size([2, 128, 128, 128, 3])

torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 1, 128, 128, 128])
torch.Size([2, 40, 4, 4, 4])
torch.Size([2, 3])
torch.Size([2, 9])
torch.Size([2, 3, 4])
tensor([[[ 0.0196, -0.0208,  0.0681,  0.1238],
         [-0.0448,  0.1987,  0.2191, -0.0853],
         [ 0.2400,  0.0952,  0.1087,  0.1642]],

        [[ 0.1690, -0.0879, -0.0913,  0.1474],
         [ 0.1242,  0.1627,  0.4899,  0.1611],
         [ 0.5175,  0.1139,  0.0516, -0.1211]]], grad_fn=<ViewBackward>)
torch.Size([2, 128, 128, 128, 

In [19]:
class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer
    """

    def __init__(self, size, mode='bilinear'):
        super().__init__()

        self.mode = mode

        # create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)

        # registering the grid as a buffer cleanly moves it to the GPU, but it also
        # adds it to the state dict. this is annoying since everything in the state dict
        # is included when saving weights to disk, so the model files are way bigger
        # than they need to be. so far, there does not appear to be an elegant solution.
        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        # new locations
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # need to normalize grid values to [-1, 1] for resampler
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # move channels dim to last position
        # also not sure why, but the channels need to be reversed
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)

In [20]:
spatial_transformer = SpatialTransformer(size=(128, 28, 128))

In [21]:
print(spatial_transformer.grid.shape)
print("========= =========== ======")
print()
#print(spatial_transformer.grid)

torch.Size([1, 3, 128, 28, 128])

