<a href="https://colab.research.google.com/github/surajsrivathsa/image_registration/blob/main/FIRE_deformable_transformation_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference paper:

https://arxiv.org/pdf/1907.05062.pdf

In [None]:
!pip install SimpleITK

Collecting SimpleITK
[?25l  Downloading https://files.pythonhosted.org/packages/9c/6b/85df5eb3a8059b23a53a9f224476e75473f9bcc0a8583ed1a9c34619f372/SimpleITK-2.0.2-cp37-cp37m-manylinux2010_x86_64.whl (47.4MB)
[K     |████████████████████████████████| 47.4MB 70kB/s 
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.0.2


In [None]:
# Imports and environment setup (Colab).
import warnings
import os
# NOTE(review): this silences every warning for the rest of the session, including
# PyTorch deprecation and shape-mismatch warnings; consider a narrower filter.
warnings.filterwarnings("ignore")
import numpy as np
import torch
print(torch.__version__)
import torchvision
print(torchvision.__version__)
# (duplicate of the torch import above; harmless)
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
# !pip install --upgrade nibabel
import nibabel as nb
import os, sys, glob
import SimpleITK as sitk
from google.colab import drive
# Mount Google Drive under /content/drive, where the dataset directories live.
drive.mount('/content/drive')
print()

1.8.0+cu101
0.9.0+cu101
Mounted at /content/drive



In [None]:
# Input directories for the two image lists; entries at the same sorted index are paired.
# NOTE(review): data_path_t2 is the same T1 directory as data_path_t1, so with identical
# sorted lists each (t1, t2) pair is the same file twice - confirm whether a T2
# directory was intended here.
data_path_t1 = "/content/drive/My Drive/Image_Registration_Project/dataset_ants_resampled/T1_Train_200_Reg_downsampled_znm/"
data_path_t2 = "/content/drive/My Drive/Image_Registration_Project/dataset_ants_resampled/T1_Train_200_Reg_downsampled_znm/"
# Sorted so the i-th entries of both lists correspond (assumes matching file names - TODO confirm).
file_names_t1 = sorted(glob.glob(os.path.join(data_path_t1, "*.nii.gz")))
file_names_t2 = sorted(glob.glob(os.path.join(data_path_t2, "*.nii.gz")))

In [None]:
# Peek at the raw (un-padded) shape of the first volume in each list.
img_nb1, img_nb2 = (nb.load(names[0]) for names in (file_names_t1, file_names_t2))
print(img_nb1.shape, img_nb2.shape, sep="\n")

(91, 109, 91)
(91, 109, 91)


In [None]:
def load_4D(name, target_shape=(128, 128, 128)):
    """Load an image volume with nibabel and zero-pad it, centred, to a fixed shape.

    Despite the name, the result has the same rank as the volume (3-D for the
    NIfTI files used here); no channel axis is added.

    Args:
        name: path passed to ``nb.load``.
        target_shape: shape of the returned array; defaults to (128, 128, 128).

    Returns:
        float64 array of shape ``target_shape``. The volume is placed in the centre
        and zeros fill the rest. When the padding along an axis is odd, the extra
        zero goes on the high-index side.

    Raises:
        ValueError: if the volume's rank differs from ``len(target_shape)`` or the
            volume is larger than ``target_shape`` along any axis.
    """
    # Reading through dataobj applies the image's intensity scaling.
    volume = np.asanyarray(nb.load(name).dataobj)
    if volume.ndim != len(target_shape):
        raise ValueError(
            "expected a {}-D volume, got shape {}".format(len(target_shape), volume.shape))
    if any(dim > target for dim, target in zip(volume.shape, target_shape)):
        raise ValueError(
            "volume shape {} exceeds target shape {}".format(volume.shape, tuple(target_shape)))

    # Left margin is floor(pad / 2). The end index is derived from the volume size,
    # so both odd and even padding amounts fit exactly.
    region = tuple(
        slice((target - dim) // 2, (target - dim) // 2 + dim)
        for dim, target in zip(volume.shape, target_shape)
    )
    padded = np.zeros(shape=target_shape)
    padded[region] = volume
    return padded

def imgnorm(N_I,index1=0.0001,index2=0.0001):
    """Robust min-max normalisation to [0, 1], returned as float32.

    The lower bound is the value at fraction ``index1`` of the sorted intensities.
    The upper bound is the value at fraction ``index2`` from the top. A handful of
    extreme voxels therefore do not set the range. Values outside the range are clipped.

    Args:
        N_I: array-like of intensities (any shape); it is not modified.
        index1: fraction of the lowest values skipped when picking the lower bound.
        index2: fraction of the highest values skipped when picking the upper bound.

    Returns:
        float32 array with the shape of ``N_I`` and values in [0, 1]. A constant
        input (upper bound == lower bound) yields all zeros instead of NaNs.

    Raises:
        ValueError: if ``N_I`` is empty.
    """
    values = np.asarray(N_I)
    if not np.issubdtype(values.dtype, np.floating):
        # Integer (especially unsigned) inputs would wrap around when subtracting I_min.
        values = values.astype(np.float64)
    flat = values.ravel()
    n = flat.size
    if n == 0:
        raise ValueError("imgnorm() received an empty array")

    lo = int(index1 * n)
    # max(..., 1): when index2 * n < 1, an index of -0 would pick the smallest value
    # as the upper bound; use the true maximum instead.
    hi = n - max(int(index2 * n), 1)
    # Only these two order statistics are needed: partition is O(n), a full sort O(n log n).
    ordered = np.partition(flat, [lo, hi])
    I_min, I_max = ordered[lo], ordered[hi]

    span = I_max - I_min
    if span == 0:
        return np.zeros(values.shape, dtype=np.float32)
    scaled = np.clip((values - I_min) / span, 0.0, 1.0)
    return scaled.astype(np.float32)

def Norm_Zscore(img):
    """Z-score normalisation: subtract the mean, divide by the standard deviation.

    Args:
        img: array-like of intensities.

    Returns:
        Float array of the same shape with zero mean and unit standard deviation.
        A constant input (std == 0) returns all zeros instead of NaNs from 0/0.
    """
    img = np.asarray(img)
    centered = img - np.mean(img)
    std = np.std(img)
    if std == 0:
        return np.zeros_like(centered)
    return centered / std

def save_img(I_img,savename):
    """Write a NumPy array to ``savename`` as an image through SimpleITK.

    NOTE(review): GetImageFromArray treats the array's axes in (z, y, x) order and
    attaches no spacing/origin/direction from any source file. Arrays from load_4D
    (nibabel (x, y, z) order) would presumably be written with axes reversed and
    default geometry - confirm this is intended.
    """
    sitk.WriteImage(sitk.GetImageFromArray(I_img, isVector=False), savename)

In [None]:
class Dataset(Data.Dataset):
    """Paired-volume dataset for PyTorch.

    Item ``i`` is ``(volume from names_t1, volume from names_t2)`` for file index
    ``i % len(names_t1)``. Each volume is loaded with ``load_4D``. When ``norm`` is
    True, both volumes are rescaled with ``imgnorm``.
    """

    def __init__(self, names_t1, names_t2, iterations=1, norm=True):
        """Store the file lists.

        Args:
            names_t1, names_t2: equal-length sequences of file paths; entries at the
                same index form one pair.
            iterations: number of passes over the file list per epoch (``len`` is
                ``len(names_t1) * iterations``).
            norm: if True, apply ``imgnorm`` to both volumes.

        Raises:
            ValueError: if the two file lists differ in length.
        """
        if len(names_t1) != len(names_t2):
            raise ValueError(
                "names_t1 and names_t2 must have the same length ({} != {})".format(
                    len(names_t1), len(names_t2)))
        self.names_t1 = names_t1
        self.names_t2 = names_t2
        self.norm = norm
        self.iterations = iterations

    def __len__(self):
        """Total number of samples: every file pair, repeated ``iterations`` times."""
        return len(self.names_t1) * self.iterations

    def __getitem__(self, step):
        """Load and return the pair for ``step``, normalised if ``self.norm``."""
        if not self.names_t1:
            raise IndexError("Dataset is empty")
        # __len__ counts each pair `iterations` times. Wrap the step back onto the
        # file list; indexing directly failed for any step >= len(names_t1).
        index = step % len(self.names_t1)
        img_A = load_4D(self.names_t1[index])
        img_B = load_4D(self.names_t2[index])

        if self.norm:
            return imgnorm(img_A), imgnorm(img_B)
        return img_A, img_B


In [None]:
# Pass `norm` by keyword: Dataset's third positional parameter is `iterations`, so a
# positional True would set iterations=True (== 1) and leave norm at its default.
training_generator = Data.DataLoader(
    Dataset(file_names_t1, file_names_t2, norm=True),
    batch_size=2,
    shuffle=False,
)

In [None]:
# Only the first batch is needed to sanity-check the normalised intensity range.
first_batch = next(iter(training_generator), None)
if first_batch is not None:
    X, Y = first_batch
    print(X.max())
    print(Y.min())

tensor(1.)
tensor(0.)


In [None]:
# Compare SimpleITK's size of the raw file with the shape after load_4D padding.
sitk_t1 = sitk.ReadImage(file_names_t1[0])
print(sitk_t1.GetSize(), load_4D(file_names_t1[0]).shape, sep="\n")

(91, 109, 91)
(128, 128, 128)


# Residual Block

In [None]:
def conv3x3x3(in_channels, out_channels, stride=1):
    """Bias-free 3x3x3 Conv3d with padding 1 (spatial size preserved when stride == 1)."""
    return nn.Conv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    
class ResidualBlock(nn.Module):
    """3-D basic residual block: conv-BN-ReLU-conv-BN, add shortcut, then ReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        stride: stride of the first 3x3x3 conv (downsamples when > 1).
        downsample: optional module applied to the input on the shortcut path.
            If None and the shortcut would not match the output shape (channel
            count differs or stride != 1), a 1x1x1 conv + BatchNorm projection is
            created. Without it, ``out += residual`` fails with a shape mismatch.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm3d(out_channels)
        if downsample is None and (stride != 1 or in_channels != out_channels):
            downsample = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels),
            )
        self.downsample = downsample

    def forward(self, x):
        # `is not None` rather than truthiness: nn.Sequential defines __len__.
        residual = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return self.relu(out)

In [None]:
# Display the layer structure of a sample block (2 -> 512 channels).
ResidualBlock(2, 512)

ResidualBlock(
  (conv1): Conv3d(2, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
  (bn1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
  (bn2): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

# Transformation deformable network

**Note:** the `cdb_*` layers (strided conv-downsampling blocks) are intended for the full-size images Xa, Xb, not for the already-downsampled features Ga, Gb. Keep them enabled for full-size inputs; if you comment them out, update the `forward` method accordingly.

In production code it is better to hard-code the number of channels instead of deriving it from `start_channel`.

In [None]:
class Transformation_Deformable_Network(nn.Module):
    """Predict a 3-channel deformation field from two single-channel 3-D volumes.

    Pipeline:
      * each input goes through two stride-2 conv/BN/ReLU blocks
        (1 -> 16 -> 64 channels), shrinking each spatial dim by about 4;
      * a 3x3x3 conv + LeakyReLU reduces each branch to 8 channels;
      * the branches are concatenated (16 channels), then passed through a
        residual block, instance norm and LeakyReLU;
      * a final 3x3x3 conv gives 3 channels, followed by instance norm and tanh,
        so every output value lies in [-1, 1].

    For (B, 1, 128, 128, 128) inputs the output is (B, 3, 32, 32, 32).

    Args:
        start_channel: stored for backward compatibility only. Channel counts are
            fixed by the downsampling blocks; deriving them from start_channel only
            matched when start_channel == 2.
    """

    def __init__(self, start_channel):
        # nn.Module.__init__ must run before any attribute is assigned.
        super(Transformation_Deformable_Network, self).__init__()
        self.start_channel = start_channel

        # Downsampling blocks for the full-size images; if they are removed (e.g. to
        # feed pre-downsampled features), forward must be updated too.
        self.cdb_1_1 = self.convdownsampleblock(1, 16)
        self.cdb_1_2 = self.convdownsampleblock(1, 16)
        self.cdb_2_1 = self.convdownsampleblock(16, 64)
        self.cdb_2_2 = self.convdownsampleblock(16, 64)

        # Input width must equal the cdb_2_* output (64 channels).
        self.convblock1 = self.convblock(64, 8)
        self.convblock2 = self.convblock(64, 8)

        # 8 + 8 channels after concatenating the two branches.
        self.rb1 = ResidualBlock(16, 16, 1)

        # Hard-coded to 3 output channels because the deformation field has 3 components.
        self.convblock3 = self.convblock(16, 3)
        self.lkrelublock1 = self.leakyrelublock()
        self.lkrelublock2 = self.leakyrelublock()
        self.lkrelublock3 = self.leakyrelublock()

        # num_features must match the normalised tensor: rb1 outputs 16 channels and
        # convblock3 outputs 3. InstanceNorm3d with default arguments has no learnable
        # parameters or buffers, so the state_dict keys are unaffected.
        self.inb1 = self.instancenormblock(16)
        self.inb2 = self.instancenormblock(3)

        self.tb1 = self.tanhblock()

    def convblock(self, in_channels, out_channels, kernel_size=3, bias=False, batchnorm=False):
        """Conv3d (padding 1), optionally followed by BatchNorm3d when ``batchnorm`` is True."""
        layers = [nn.Conv3d(in_channels, out_channels, kernel_size, bias=bias, padding=1)]
        if batchnorm:
            layers.append(nn.BatchNorm3d(out_channels))
        return nn.Sequential(*layers)

    def convdownsampleblock(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=True):
        """Strided Conv3d -> BatchNorm3d -> ReLU; with the defaults each spatial dim becomes ceil(d / 2)."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
        )

    def leakyrelublock(self):
        """LeakyReLU with default negative slope."""
        return nn.LeakyReLU()

    def instancenormblock(self, out_channels):
        """InstanceNorm3d over ``out_channels`` channels (default arguments)."""
        return nn.InstanceNorm3d(out_channels)

    def tanhblock(self):
        """Tanh activation that bounds the output to [-1, 1]."""
        return nn.Tanh()

    def forward(self, gx, gy):
        """Predict the field for the pair (gx, gy).

        Args:
            gx, gy: tensors of shape (B, 1, D, H, W); cdb_1_* expect one input channel.

        Returns:
            Tensor of shape (B, 3, D', H', W') with D' = ceil(ceil(D / 2) / 2) (same for
            H, W) and values in [-1, 1].
        """
        feat_x = self.lkrelublock1(self.convblock1(self.cdb_2_1(self.cdb_1_1(gx))))
        feat_y = self.lkrelublock2(self.convblock2(self.cdb_2_2(self.cdb_1_2(gy))))
        fused = torch.cat((feat_x, feat_y), 1)
        hidden = self.lkrelublock3(self.inb1(self.rb1(fused)))
        return self.tb1(self.inb2(self.convblock3(hidden)))

In [None]:
# Use the GPU when available, otherwise fall back to CPU instead of failing on .to("cuda").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mymodel = Transformation_Deformable_Network(2).to(device)

In [None]:
# Seeded random single-channel 128^3 volumes (batch of 2) as smoke-test inputs.
# They are created on the model's device so the forward pass cannot hit a device mismatch.
torch.manual_seed(0)
model_device = next(mymodel.parameters()).device
x = torch.randn(size=(2, 1, 128, 128, 128), device=model_device)
y = torch.randn(size=(2, 1, 128, 128, 128), device=model_device)



In [None]:
# Shape check only: no_grad avoids building an autograd graph for this large 3-D forward pass.
with torch.no_grad():
    tanhbo = mymodel(x, y)
tanhbo.shape

torch.Size([2, 16, 32, 32, 32])
torch.Size([2, 16, 32, 32, 32])


# Spatial Transformer

In [None]:
class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer

    Warps ``src`` either with a fixed affine transform (``is_affine=True``) or with
    a dense displacement field ``flow`` given in voxel units (``is_affine=False``).
    """

    def __init__(self, size, is_affine=False, theta = None, mode='bilinear', affine_image_size =  (2, 1, 128, 128, 128)):
        """
        Args:
            size: spatial size of the flow field, e.g. (D, H, W). Used to build the
                identity voxel grid in deformable mode; ignored in affine mode.
            is_affine: use the affine transform ``theta`` instead of a flow field.
            theta: affine matrices for F.affine_grid; required when is_affine is
                True and consumed once, at construction.
            mode: interpolation mode passed to F.grid_sample.
            affine_image_size: output size (N, C, *spatial) passed to F.affine_grid.

        Raises:
            ValueError: if is_affine is True and theta is None.
        """
        super().__init__()

        self.mode = mode
        self.isaffine = is_affine
        self.theta = theta
        self.affine_image_size = affine_image_size

        # The sampling grid is a buffer so that .to(device) moves it with the module.
        # It is therefore also saved in the state dict. register_buffer(...,
        # persistent=False) would avoid that, but would change the keys existing
        # checkpoints contain.
        if self.isaffine:
            if theta is None:
                raise ValueError("theta is required when is_affine=True")
            grid = F.affine_grid(self.theta, self.affine_image_size, align_corners=False)
        else:
            # Identity grid of voxel coordinates, shape (1, len(size), *size).
            vectors = [torch.arange(0, s) for s in size]
            grid = torch.stack(torch.meshgrid(vectors)).unsqueeze(0).type(torch.FloatTensor)
        self.register_buffer('grid', grid)

    def forward(self, src, flow=None):
        """Warp ``src``.

        Affine mode: samples ``src`` on the grid precomputed from ``theta``; ``flow``
        is ignored. Deformable mode: samples ``src`` at identity_grid + ``flow``, where
        ``flow`` has shape (N, ndim, *size) in voxel units. The output keeps
        ``src``'s channels and takes the spatial size of the sampling grid.
        """
        if self.isaffine:
            # Use the registered buffer rather than recomputing from self.theta.
            # The buffer follows .to(device); self.theta is a plain attribute that stays
            # where it was created, so a recomputed grid could sit on the wrong device.
            return F.grid_sample(src, self.grid, mode=self.mode, align_corners=False)

        if flow is None:
            raise ValueError("flow is required when is_affine=False")

        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # Normalise voxel coordinates to [-1, 1] (align_corners=True convention).
        for i, dim in enumerate(shape):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (dim - 1) - 0.5)

        # Move coordinates to the last dim. grid_sample orders the last grid dim as
        # (x, y[, z]), i.e. (W, H[, D]): the reverse of the tensor's spatial dims.
        # Hence the coordinate channels are flipped.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)[..., [2, 1, 0]]

        return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode)


In [None]:
# Deformable-mode transformer whose identity grid matches the spatial size of the field
# predicted by mymodel (32^3 for 128^3 inputs), on the same device as that field.
spatial_transformer_deformable = SpatialTransformer(
    size=tuple(tanhbo.shape[2:]), is_affine=False
).to(tanhbo.device)
print(spatial_transformer_deformable.grid.shape)
print(spatial_transformer_deformable.isaffine)

torch.Size([1, 3, 32, 32, 32])
False



In [None]:
# Random 128-channel feature map on a 32^3 grid, created on the same device as the field.
gx_affine = torch.randn(size=(2, 128, 32, 32, 32), device=tanhbo.device)

# Output of the deformable spatial transformer

In [None]:
# Warp the random feature map with the predicted field; the result keeps src's channels.
stdef_op = spatial_transformer_deformable(src=gx_affine, flow=tanhbo)
print(stdef_op.shape)

torch.Size([1, 3, 32, 32, 32])
torch.Size([2, 3, 32, 32, 32])
torch.Size([2, 128, 32, 32, 32])


In [None]:
# Warp the 128^3 input with the 32^3 field. The output is sampled on the field's grid,
# so it comes out at 32^3 rather than as a full-resolution warp.
stdef_op = spatial_transformer_deformable(src=x, flow=tanhbo)
print(stdef_op.shape)

torch.Size([1, 3, 32, 32, 32])
torch.Size([2, 3, 32, 32, 32])
torch.Size([2, 1, 32, 32, 32])
