##### Understand, in 3D context, torch-interpol and torch-functional affine matrix, grid and warping

#### Import

In [1]:
####![image_info](./dicomCT.png)

In [2]:
#Imports
import os, sys, json, pathlib, shutil, glob
import pandas as pd
import csv
import SimpleITK as sitk
import nibabel as nib
import random
import math
import numpy as np
from PIL import Image, ImageFont, ImageDraw

import scipy
from scipy.io import loadmat
from scipy import  signal

import matplotlib.pyplot as plt
import plotly.graph_objects as go

import ipywidgets as widgets
from ipywidgets import interactive,interact, interact_manual, HBox, Layout,VBox
from IPython.display import display, clear_output

import interpol
from interpol.api import affine_grid

from scipy import ndimage

from functools import partial

import torch
import torch.nn.functional as F
from torch.nn import MSELoss
from viu.io import volume
from viu.io.volume import read_volume
from viu.torch.deformation.fields import DVF, set_identity_mapping_cache
from viu.torch.io.deformation import *
from viu.util.body_mask import seg_body

# import ipywidgets as ipyw
import ipywidgets as widgets
from ipywidgets import interactive,interact, interact_manual, HBox, Layout,VBox
from IPython.display import display, clear_output

In [3]:
from exampleUtils import *

In [4]:
runTestCode_sec1=True
runTestCode_sec2=True
runTestCode_sec3=True

##### Read test volume

In [5]:
# experimentFolder = pathlib.Path('/mnt/data/supratik/diaphragm_detection/data/')
experimentFolder = pathlib.Path('/home/wd974888/Downloads/workingFolder/DeformationExperiment/')
#experimentFolder = pathlib.Path('/media/data/supratik/workingFolder/DeformationExperiment/demoDataDvfAndPCA/')

patientMRN='Patient11AB_bin19'
dcmFolder = experimentFolder / patientMRN
print(f'dcmFolder {dcmFolder}')
vol, res, pos = read_volume(str(dcmFolder))
print(f'vol type {(type(vol))} shape_ZYX {vol.shape} \
      min {np.min(vol)} max {np.max(vol)} ')
print(f'res_XYZ type {type(res)} value {res} ')
print(f'pos_XYZ type {type(pos)} value {pos}')

#Make the dimensions unequal to better manifest effect of rotation about
# different axis. Truncated volume in the width direction from one side
# (not ceneterd).
truncation_w = 24
vol = vol[:,:,:-truncation_w]
#With truncation in X direction, the pos matching the center of the volume
# in patient co-ordinate also shifts
pos[0] = pos[0] - 0.5*(truncation_w)*res[0]
print(f'truncated vol shape_ZYX {vol.shape}')
print(f'truncated vol pos_VYZ {pos}')
depth, height, width = vol.shape


dcmFolder /home/wd974888/Downloads/workingFolder/DeformationExperiment/Patient11AB_bin19
vol type <class 'numpy.ndarray'> shape_ZYX (452, 512, 512)       min -1890.0 max 36208.0 
res_XYZ type <class 'numpy.ndarray'> value [0.9765625 0.9765625 1.       ] 
pos_XYZ type <class 'numpy.ndarray'> value [ 1.71875000e-03 -2.09498281e+02 -9.00000000e-01]
truncated vol shape_ZYX (452, 512, 488)
truncated vol pos_VYZ [ -11.71703125 -209.49828125   -0.9       ]


In [6]:
v1_volumeComparisonViewer3D(
    listVolumes=[vol],listLabels=['original'],
    maxZ0=vol.shape[0], maxZ1=vol.shape[1], maxZ2=vol.shape[2],
    figsize=(12,8), cmap='gray',
    displayColorbar=False, useExternalWindowCenter=True, wMin=-500, wMax=500)

interactive(children=(Output(),), _dom_classes=('widget-interact',))

<exampleUtils.v1_volumeComparisonViewer3D at 0x7f2bc1d4c310>

In [7]:
align_corners=True
device=torch.device('cuda:0')
batchSize=1

In [8]:
#volume tensor with batch and channel
vol_tensor = torch.from_numpy(vol).to(device).unsqueeze(0).unsqueeze(0)
printTensor("vol_tensor", vol_tensor)


vol_tensor shape: torch.Size([1, 1, 452, 512, 488]) device: cuda:0 dtype: torch.float32


identity grid in 2D and in 3D in matrix (ij) and cartesian co-ordinate system

identity grid in 2D

In [9]:
if runTestCode_sec1:
    tmp_h, tmp_w = 5, 6
    print(f'tmp_h {tmp_h}, tmp_w {tmp_w}')
    #tmp_h 5, tmp_w 6
    tmp_IdGrid_matrix_2d = torch.stack(torch.meshgrid(*[torch.arange(s).float() for s in [tmp_h, tmp_w] ],indexing='ij'), -1).numpy().astype('float32')
    print(f'Pass shape [tmp_h {tmp_h}, tmp_w {tmp_w}]: tmp_IdGrid_matrix_2d type {(type(tmp_IdGrid_matrix_2d))} shape {tmp_IdGrid_matrix_2d.shape} dtype {tmp_IdGrid_matrix_2d.dtype}')

    tmp_IdGrid_cartesian_2d = torch.stack(torch.meshgrid(*[torch.arange(s).float() for s in [tmp_h, tmp_w] ],indexing='xy'), -1).numpy().astype('float32')
    print(f'Pass shape [tmp_h {tmp_h}, tmp_w {tmp_w}]: tmp_IdGrid_cartesian_2d type {(type(tmp_IdGrid_cartesian_2d))} shape {tmp_IdGrid_cartesian_2d.shape} dtype {tmp_IdGrid_cartesian_2d.dtype}')


tmp_h 5, tmp_w 6
Pass shape [tmp_h 5, tmp_w 6]: tmp_IdGrid_matrix_2d type <class 'numpy.ndarray'> shape (5, 6, 2) dtype float32
Pass shape [tmp_h 5, tmp_w 6]: tmp_IdGrid_cartesian_2d type <class 'numpy.ndarray'> shape (6, 5, 2) dtype float32


identity grid in 3D

In [10]:
if runTestCode_sec1:
    tmp_d, tmp_h, tmp_w = 4, 5, 6
    print(f'tmp_d {tmp_d}, tmp_h {tmp_h}, tmp_w {tmp_w}')
    #tmp_d 4, tmp_h 5, tmp_w 6
    tmp_IdGrid_matrix = torch.stack(torch.meshgrid(*[torch.arange(s).float() for s in [tmp_d, tmp_h, tmp_w] ],indexing='ij'), -1).numpy().astype('float32')
    print(f'Pass shape [tmp_d {tmp_d}, tmp_h {tmp_h}, tmp_w {tmp_w}]: tmp_IdGrid_matrix type {(type(tmp_IdGrid_matrix))} shape {tmp_IdGrid_matrix.shape} dtype {tmp_IdGrid_matrix.dtype}')
    # Pass shape [tmp_d 4, tmp_h 5, tmp_w 6]: tmp_IdGrid_matrix type <class 'numpy.ndarray'> shape (4, 5, 6, 3) dtype float32
    #tmp_IdGrid_matrix[k,:,:,0]: Channel 0 all elements filled with z-coordinate (namely k) of the slice: for all slice k=0,1,...,toy_d-1 
    #tmp_IdGrid_matrix[k,:,:,1]: Channel 1 columns filled with  (i) y-coordinates: all columns  are transpose([0,1,... toy_h-1]), for all slice k=0,1,...,toy_d-1 
    #tmp_IdGrid_matrix[k,:,:,2]: Channel 2 rows filled with  (j) x-coordinates:  all rows  are [0,1,... toy_w-1],, for all slice  k=0,1,...,toy_d-1 
    tmp_IdGrid_cartesian = torch.stack(torch.meshgrid(*[torch.arange(s).float() for s in [tmp_d, tmp_h, tmp_w] ],indexing='xy'), -1).numpy().astype('float32')
    print(f'Pass shape [tmp_d {tmp_d}, tmp_h {tmp_h}, tmp_w {tmp_w}]: tmp_IdGrid_cartesian type {(type(tmp_IdGrid_cartesian))} shape {tmp_IdGrid_cartesian.shape} dtype {tmp_IdGrid_cartesian.dtype}')
    # Pass shape [tmp_d 4, tmp_h 5, tmp_w 6]: tmp_IdGrid_cartesian type <class 'numpy.ndarray'> shape (5, 4, 6, 3) dtype float32
    print('If we assumed, 1st dim is depth then in generated cartesian mesgrid, 2nd dimension corresponds to depth.')
    #If we assumed, 1st dim is depth then in generated cartesian mesgrid, 2nd dimension corresponds to depth.
    #tmp_IdGrid_cartesian[:,k,:,0]: Channel 0 all elements filled with z-coordinate (namely k) of the slice: for all slice k=0,1,...,toy_d-1 
    #tmp_IdGrid_cartesian[:,k,:,1]: Channel 1 columns filled with  (i) y-coordinates: all columns  are transpose([0,1,... toy_h-1]), for all slice k=0,1,...,toy_d-1 
    #tmp_IdGrid_cartesian[:,k,:,2]: Channel 2 rows filled with  (j) x-coordinates:  all rows  are [0,1,... toy_w-1],, for all slice  k=0,1,...,toy_d-1 

tmp_d 4, tmp_h 5, tmp_w 6
Pass shape [tmp_d 4, tmp_h 5, tmp_w 6]: tmp_IdGrid_matrix type <class 'numpy.ndarray'> shape (4, 5, 6, 3) dtype float32
Pass shape [tmp_d 4, tmp_h 5, tmp_w 6]: tmp_IdGrid_cartesian type <class 'numpy.ndarray'> shape (5, 4, 6, 3) dtype float32
If we assumed, 1st dim is depth then in generated cartesian mesgrid, 2nd dimension corresponds to depth.


##### Test steps

For rotation around Z, Y, X axis passing through center of volume

    - Find affine matrix  for torch-interpol grid

    - Create torch-interpol push  grid and apply torch-interpol grid_push

    - Create torch-interpol pull  grid and apply torch-interpol grid_pull

    - Convert torch-interpol pull grid to torch functional pull grid and  apply torch functional grid_warp(pull)

    - Convert  torch functional pull grid  back to torch-interpol pull grid  and apply torch-interpol grid_pull

In [11]:
debug=True

Create affine matrix

In [12]:
# positive theta_deg <==> clockwise (axial, coronal) or counter-clockwise (sagittal)
# negative theta_deg <==> counter-clockwise (axial, coronal) or clockwise (sagittal)
theta_deg = 10.00
viewString='sagittal'
tra_z, tra_y, tra_x = 30, 10, 60

if (theta_deg >=0 and viewString in ["axial", "coronal"]) or\
   (theta_deg <0 and viewString in ["sagittal"]):
    rot_dir = 'clockwise'
else:
    rot_dir = 'counter-clockwise'
description = f"Expected {rot_dir} rotation by {abs(theta_deg)} degrees around "\
    + f"{viewString} axis followed by translation \n in  all rotated directions: "\
    + f" by {tra_z}  in z (F--> H), {tra_y} in  y (A-->P), {tra_x} in  x (R-->L)"
# print(description)

#Get the push affine matrix in XYZ order
pushAffine_np = getPushRotationMatrix(
    theta_deg=theta_deg,
    viewString=viewString,
    center_slice_z=depth/2.0,
    center_row_y=height/2.0,
    center_col_x=width/2.0,
    tra_z=tra_z, tra_y=tra_y, tra_x=tra_x)

if debug:
    print(f'Unnormalized pushAffine_np {pushAffine_np}')

pullAffine_np = np.linalg.inv(pushAffine_np)
if debug:
    print(f'Unnormalized pullAffine_np {pullAffine_np}')


Unnormalized pushAffine_np [[  0.9848077    0.17364818   0.         -11.020477  ]
 [ -0.17364818   0.9848077    0.          53.133713  ]
 [  0.           0.           1.          60.        ]
 [  0.           0.           0.           1.        ]]
Unnormalized pullAffine_np [[  0.9848078  -0.1736482   0.         20.079624 ]
 [  0.1736482   0.9848078   0.        -50.412807 ]
 [  0.          0.          1.        -60.       ]
 [  0.          0.          0.          1.       ]]


Create torch interpol push grid and apply  push warp

In [13]:
if runTestCode_sec1:
    pushAffine_unBatchedTensor = torch.from_numpy(pushAffine_np).to(device)
    if debug:
        printTensor("pushAffine_unBatchedTensor", pushAffine_unBatchedTensor)
    aRotGrid_push_unbatched = affine_grid(pushAffine_unBatchedTensor,
        [depth, height, width])
    if debug:
        printTensor("aRotGrid_push_unbatched", aRotGrid_push_unbatched)
    #Expand to batch
    aRotGrid_push_batched = aRotGrid_push_unbatched.expand(batchSize,
        *aRotGrid_push_unbatched.shape)
    if debug:
        printTensor("aRotGrid_push_batched", aRotGrid_push_batched)

pushAffine_unBatchedTensor shape: torch.Size([4, 4]) device: cuda:0 dtype: torch.float32
aRotGrid_push_unbatched shape: torch.Size([452, 512, 488, 3]) device: cuda:0 dtype: torch.float32
aRotGrid_push_batched shape: torch.Size([1, 452, 512, 488, 3]) device: cuda:0 dtype: torch.float32


In [14]:
if runTestCode_sec1:
   aVol_warp_interpol_push_tensor = interpol.grid_push(vol_tensor, aRotGrid_push_batched,
      interpolation='cubic', bound='zeros',prefilter=True)
   if debug:
      printTensor("aVol_warp_interpol_push_tensor", aVol_warp_interpol_push_tensor)
   aVol_warp_interpol_push_np = aVol_warp_interpol_push_tensor\
      .clone().squeeze(0).squeeze().cpu().numpy()
   if debug:
      print(f"aVol_warp_interpol_push_np shape {aVol_warp_interpol_push_np.shape} "
            f"dtype {aVol_warp_interpol_push_np.dtype}")
   v1_volumeComparisonViewer3D(
      listVolumes=[aVol_warp_interpol_push_np],
      listLabels=[f'Interpol_PushWarp_Rot:  {description}'],
      maxZ0=aVol_warp_interpol_push_np.shape[0],
      maxZ1=aVol_warp_interpol_push_np.shape[1],
      maxZ2=aVol_warp_interpol_push_np.shape[2],
      figsize=(12,8), cmap='gray',
      displayColorbar=False, useExternalWindowCenter=True, wMin=-500, wMax=500)

  output = spline_coeff_nd(input, *opt)


aVol_warp_interpol_push_tensor shape: torch.Size([1, 1, 452, 512, 488]) device: cuda:0 dtype: torch.float32
aVol_warp_interpol_push_np shape (452, 512, 488) dtype float32


interactive(children=(Output(),), _dom_classes=('widget-interact',))

Create torch interpol pull grid and apply pull warp

In [15]:
if runTestCode_sec1:
    pullAffine_unBatchedTensor = torch.from_numpy(pullAffine_np).to(device)
    if debug:
        printTensor("pullAffine_unBatchedTensor", pullAffine_unBatchedTensor)
    aRotGrid_pull_unbatched = affine_grid(pullAffine_unBatchedTensor,
        [depth, height, width])
    if debug:
        printTensor("aRotGrid_pull_unbatched", aRotGrid_pull_unbatched)
    #Expand to batch
    aRotGrid_pull_batched = aRotGrid_pull_unbatched.expand(batchSize,
        *aRotGrid_pull_unbatched.shape)
    if debug:
        printTensor("aRotGrid_pull_batched", aRotGrid_pull_batched)

pullAffine_unBatchedTensor shape: torch.Size([4, 4]) device: cuda:0 dtype: torch.float32
aRotGrid_pull_unbatched shape: torch.Size([452, 512, 488, 3]) device: cuda:0 dtype: torch.float32
aRotGrid_pull_batched shape: torch.Size([1, 452, 512, 488, 3]) device: cuda:0 dtype: torch.float32


In [16]:
if runTestCode_sec1:
   aVol_warp_interpol_pull_tensor = interpol.grid_pull(vol_tensor,
      aRotGrid_pull_batched, interpolation='cubic', bound='zeros',prefilter=True)
   if debug:
      printTensor("aVol_warp_interpol_pull_tensor", aVol_warp_interpol_pull_tensor)
   aVol_warp_interpol_pull_np = aVol_warp_interpol_pull_tensor.clone().squeeze().squeeze().cpu().numpy()
   if debug:
      print(f"aVol_warp_interpol_pull_np shape {aVol_warp_interpol_pull_np.shape} "
            f"dtype {aVol_warp_interpol_pull_np.dtype}")
   v1_volumeComparisonViewer3D(
      listVolumes=[aVol_warp_interpol_pull_np],
      listLabels=[f'Interpol_pullWarp_Rot: {description}'],
      maxZ0=aVol_warp_interpol_pull_np.shape[0],
      maxZ1=aVol_warp_interpol_pull_np.shape[1],
      maxZ2=aVol_warp_interpol_pull_np.shape[2],
      figsize=(12,8), cmap='gray',
      displayColorbar=False, useExternalWindowCenter=True, wMin=-500, wMax=500)

 does not have profile information (Triggered internally at /opt/conda/conda-bld/pytorch_1682343997789/work/third_party/nvfuser/csrc/graph_fuser.cpp:104.)
  output = spline_coeff_nd(input, *opt)


aVol_warp_interpol_pull_tensor shape: torch.Size([1, 1, 452, 512, 488]) device: cuda:0 dtype: torch.float32
aVol_warp_interpol_pull_np shape (452, 512, 488) dtype float32


interactive(children=(Output(),), _dom_classes=('widget-interact',))

#### Direct conversion from torch interpol pull grid to Torch functional pull grid by normalization

Create torch functional pull grid  from torch interpol pull grid and apply  PyTorch pull warp

In [17]:
if runTestCode_sec1:
    aTorchFunctionalPullGrid_batched =\
        convertGrid_interpol2functional(aRotGrid_pull_batched,depth, height, width)
    if debug:
        printTensor("aTorchFunctionalPullGrid_batched", aTorchFunctionalPullGrid_batched)
    aVol_warp_functional_pull_tensor = F.grid_sample(
        input=vol_tensor,
        grid=aTorchFunctionalPullGrid_batched,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=align_corners)
    if debug:
        printTensor("aVol_warp_functional_pull_tensor", aVol_warp_functional_pull_tensor)
    aVol_warp_functional_pull_np = aVol_warp_functional_pull_tensor.clone().squeeze().squeeze().cpu().numpy()
    if debug:
        print(f"aVol_warp_functional_pull_np shape {aVol_warp_functional_pull_np.shape} "
              f"dtype {aVol_warp_functional_pull_np.dtype}")
    v1_volumeComparisonViewer3D(
        listVolumes=[aVol_warp_functional_pull_np],
        listLabels=[f'Fuctional_PullWarp_Rot:  {description}'],
        maxZ0=aVol_warp_functional_pull_np.shape[0],
        maxZ1=aVol_warp_functional_pull_np.shape[1],
        maxZ2=aVol_warp_functional_pull_np.shape[2],
        figsize=(12,8), cmap='gray',
        displayColorbar=False, useExternalWindowCenter=True, wMin=-500, wMax=500)

aTorchFunctionalPullGrid_batched shape: torch.Size([1, 452, 512, 488, 3]) device: cuda:0 dtype: torch.float32
aVol_warp_functional_pull_tensor shape: torch.Size([1, 1, 452, 512, 488]) device: cuda:0 dtype: torch.float32
aVol_warp_functional_pull_np shape (452, 512, 488) dtype float32


interactive(children=(Output(),), _dom_classes=('widget-interact',))

#### Direct conversion from torch functional pull grid to torch interpol pull grid to  by de-normalization

Create torch interpol pull grid from  torch functional pull grid  and apply  torch interpol pull warp

In [18]:
if runTestCode_sec1:
   aRotGrid_pull_batched_from_converted =\
      convertGrid_functional2interpol(aTorchFunctionalPullGrid_batched, depth, height, width)
   aVol_warp_interpol_pull_from_converted = interpol.grid_pull(vol_tensor,
      aRotGrid_pull_batched_from_converted,
      interpolation='cubic', bound='zeros',prefilter=True)
   if debug:
      printTensor("aVol_warp_interpol_pull_from_converted",
         aVol_warp_interpol_pull_from_converted)
   aVol_warp_interpol_pull_from_converted_np = aVol_warp_interpol_pull_from_converted.clone().squeeze().squeeze().cpu().numpy()
   if debug:
      print(f"aVol_warp_interpol_pull_from_converted_np shape "
            f"{aVol_warp_interpol_pull_from_converted_np.shape} "
            f"dtype {aVol_warp_interpol_pull_from_converted_np.dtype}")
   v1_volumeComparisonViewer3D(
      listVolumes=[aVol_warp_interpol_pull_from_converted_np],
      listLabels=[f'F2InterpolConverted_pullWarp_Rot:  {description}'],
      maxZ0=aVol_warp_interpol_pull_from_converted_np.shape[0],
      maxZ1=aVol_warp_interpol_pull_from_converted_np.shape[1],
      maxZ2=aVol_warp_interpol_pull_from_converted_np.shape[2],
      figsize=(12,8), cmap='gray',
      displayColorbar=False, useExternalWindowCenter=True, wMin=-500, wMax=500)

aVol_warp_interpol_pull_from_converted shape: torch.Size([1, 1, 452, 512, 488]) device: cuda:0 dtype: torch.float32
aVol_warp_interpol_pull_from_converted_np shape (452, 512, 488) dtype float32


interactive(children=(Output(),), _dom_classes=('widget-interact',))

##### Test steps

-   Create torch-functional affine mat from un-normalized  torch-interpol affine matrix. Create both push and pull type affine mat

-   Create Torch functional pull grid from pull affine Mat, apply warp and display. Note torch functional has only pull-warp.

-   Create VIU DVF push grid and pull grid from push and pull affine Mat.

-   Apply VIU DVF pull warp  without prefilter using pull grid and display. Internally torch  functional pull-warp will be used.

-   Apply VIU DVF  pull warp WITH prefilter using pull grid and display. Internally torch-interpol pull-warp will be used as prefilter=True.

-   Apply VIU DVF  push warp WITH prefilter using push grid and display. Internally torch-interpol pull-warp will be used as (push-warp= | prefilter=True).



Create torch-functional affine mat from un-normalized  torch-interpol affine matrix. Create both push and pull type affine mat

In [19]:
pyTorchPushAffineMat3x4 = getPyTorchAffineMatTensor(pushAffine_np,
    depth, height, width, depth, height, width, device)
if debug:
    printTensor('pyTorchPushAffineMat3x4', pyTorchPushAffineMat3x4)
    print(f'pyTorchPushAffineMat3x4 {pyTorchPushAffineMat3x4}')
pyTorchPullAffineMat3x4 = getPyTorchAffineMatTensor(pullAffine_np,
    depth, height, width, depth, height, width, device)
if debug:
    printTensor('pyTorchPullAffineMat3x4', pyTorchPullAffineMat3x4)
    print(f'pyTorchPullAffineMat3x4 {pyTorchPullAffineMat3x4}')

pyTorchPushAffineMat3x4 shape: torch.Size([3, 4]) device: cuda:0 dtype: torch.float32
pyTorchPushAffineMat3x4 tensor([[ 1.0000,  0.0000,  0.0000,  0.2459],
        [ 0.0000,  0.9848, -0.1533,  0.0391],
        [ 0.0000,  0.1967,  0.9848,  0.1327]], device='cuda:0')
pyTorchPullAffineMat3x4 shape: torch.Size([3, 4]) device: cuda:0 dtype: torch.float32
pyTorchPullAffineMat3x4 tensor([[ 1.0000,  0.0000,  0.0000, -0.2459],
        [ 0.0000,  0.9848,  0.1533, -0.0588],
        [ 0.0000, -0.1967,  0.9848, -0.1230]], device='cuda:0')


Create VIU DVF push grid and pull grid from push and pull affine Mat.

In [20]:
if debug:
    printTensor('pyTorchPushAffineMat3x4', pyTorchPushAffineMat3x4)
    printTensor('pyTorchPullAffineMat3x4', pyTorchPullAffineMat3x4)
nb_dim = pyTorchPullAffineMat3x4.shape[-1]-1
viuDVFGridSizeParam=(batchSize, 1, depth, height, width, nb_dim)
if debug:
    print(f"viuDVFGridSizeParam shape: (batch, channel, D, H, W, nbDim=3): "
          f"{viuDVFGridSizeParam}")

#While calling DVF.affine, in the affine matrix, the last row (0,0,1) for 2D or (0,0,0,1) for is removed. However no batch is added.
#Also  while passing the size, the  DVF  dimension  (2 or 3) is being passed  as the last elemet of the size  tuple.
#Further  line #393 of field.py  is to be modified.
#Another important difference  is that  the DVF returned by DVF.affine  includes a subtraction of the identity map
viu_grid_push = DVF.affine(pyTorchPushAffineMat3x4, size=viuDVFGridSizeParam)
printTensor('viu_grid_push', viu_grid_push)
viu_grid_pull = DVF.affine(pyTorchPullAffineMat3x4, size=viuDVFGridSizeParam)
printTensor('viu_grid_pull', viu_grid_pull)

pyTorchPushAffineMat3x4 shape: torch.Size([3, 4]) device: cuda:0 dtype: torch.float32
pyTorchPullAffineMat3x4 shape: torch.Size([3, 4]) device: cuda:0 dtype: torch.float32
viuDVFGridSizeParam shape: (batch, channel, D, H, W, nbDim=3): (1, 1, 452, 512, 488, 3)
viu_grid_push shape: torch.Size([1, 452, 512, 488, 3]) device: cuda:0 dtype: torch.float32
viu_grid_pull shape: torch.Size([1, 452, 512, 488, 3]) device: cuda:0 dtype: torch.float32


Create Torch functional pull grid from pull affine Mat, apply warp and display. Note torch functional has only pull-warp.

In [21]:
if runTestCode_sec2:
    #Torch functional affine grid generation requires affine matrix with batch dimension
    pyTorchPullAffineMat3x4_batchAdded =\
        pyTorchPullAffineMat3x4.expand(batchSize, *pyTorchPullAffineMat3x4.shape)
    if debug:
        printTensor('pyTorchPullAffineMat3x4_batchAdded', pyTorchPullAffineMat3x4_batchAdded)
    # print(f'pyTorchThetaForAffineGrid_lastRowRemoved_batchAdded type {type(pyTorchThetaForAffineGrid_lastRowRemoved_batchAdded)} shape {pyTorchThetaForAffineGrid_lastRowRemoved_batchAdded.shape}')
    # print(f'pyTorchThetaForAffineGrid_lastRowRemoved_batchAdded  {pyTorchThetaForAffineGrid_lastRowRemoved_batchAdded}')
    grid_size_F = torch.Size([batchSize, 1, depth, height, width]) #batch, channel, depth, height, width
    print(f'grid_size_F {grid_size_F}')
    grid_F = F.affine_grid(
        theta=pyTorchPullAffineMat3x4_batchAdded,
        size=grid_size_F,
        align_corners=align_corners)
    if debug:
        print(f'grid_F type {type(grid_F)} shape {grid_F.shape} dtype {grid_F.dtype}')
    aVol_warp_functional_pullUsingMat_tensor = F.grid_sample(
        input=vol_tensor,
        grid=grid_F,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=align_corners)
    if debug:
        printTensor("aVol_warp_functional_pullUsingMat_tensor", aVol_warp_functional_pullUsingMat_tensor)
    aVol_warp_functional_pullUsingMat_np =\
        aVol_warp_functional_pullUsingMat_tensor.clone().squeeze().squeeze().cpu().numpy()
    if debug:
        print(f"aVol_warp_functional_pullUsingMat_np shape "
              f"{aVol_warp_functional_pullUsingMat_np.shape} "
              f"dtype {aVol_warp_functional_pullUsingMat_np.dtype}")
    v1_volumeComparisonViewer3D(
        listVolumes=[aVol_warp_functional_pullUsingMat_np],
        listLabels=[f'Fuctional_PullWarpUsingMat_Rot:  {description}'],
        maxZ0=aVol_warp_functional_pullUsingMat_np.shape[0],
        maxZ1=aVol_warp_functional_pullUsingMat_np.shape[1],
        maxZ2=aVol_warp_functional_pullUsingMat_np.shape[2],
        figsize=(12,8), cmap='gray',
        displayColorbar=False, useExternalWindowCenter=True, wMin=-500, wMax=500)

pyTorchPullAffineMat3x4_batchAdded shape: torch.Size([1, 3, 4]) device: cuda:0 dtype: torch.float32
grid_size_F torch.Size([1, 1, 452, 512, 488])
grid_F type <class 'torch.Tensor'> shape torch.Size([1, 452, 512, 488, 3]) dtype torch.float32
aVol_warp_functional_pullUsingMat_tensor shape: torch.Size([1, 1, 452, 512, 488]) device: cuda:0 dtype: torch.float32
aVol_warp_functional_pullUsingMat_np shape (452, 512, 488) dtype float32


interactive(children=(Output(),), _dom_classes=('widget-interact',))

Apply VIU DVF pull warp  without prefilter using pull grid and display. Internally torch  functional pull-warp will be used.

In [22]:
if runTestCode_sec2:
    if debug:
        printTensor('viu_grid_pull', viu_grid_pull)
        printTensor('vol_tensor', vol_tensor)
    #sample(self, input, mode="bilinear", padding_mode="zeros", warpingModeString="pull",prefilter=False)
    aVol_viu_grid_pull_warped = viu_grid_pull.sample(
        input=vol_tensor,
        mode='bilinear',
        padding_mode='zeros',
        warpingModeString="pull",
        prefilter=False)
    if debug:
        printTensor('aVol_viu_grid_pull_warped', aVol_viu_grid_pull_warped)
    aVol_viu_grid_pull_warped_np = aVol_viu_grid_pull_warped.squeeze().squeeze().clone().cpu().numpy()
    v1_volumeComparisonViewer3D(
        listVolumes=[aVol_viu_grid_pull_warped_np],
        listLabels=[f'VIUDVF_PullWarp_Rot:  {description}'],
        maxZ0=aVol_viu_grid_pull_warped_np.shape[0],
        maxZ1=aVol_viu_grid_pull_warped_np.shape[1],
        maxZ2=aVol_viu_grid_pull_warped_np.shape[2],
        figsize=(12,8), cmap='gray',
        displayColorbar=False, useExternalWindowCenter=True, wMin=-500, wMax=500)

viu_grid_pull shape: torch.Size([1, 452, 512, 488, 3]) device: cuda:0 dtype: torch.float32
vol_tensor shape: torch.Size([1, 1, 452, 512, 488]) device: cuda:0 dtype: torch.float32
aVol_viu_grid_pull_warped shape: torch.Size([1, 1, 452, 512, 488]) device: cuda:0 dtype: torch.float32


interactive(children=(Output(),), _dom_classes=('widget-interact',))

Apply VIU DVF  pull warp WITH prefilter using pull grid and display. Internally torch-interpol pull-warp will be used as prefilter=True.

In [23]:
if runTestCode_sec2:
    if debug:
        printTensor('viu_grid_pull', viu_grid_pull)
        printTensor('vol_tensor', vol_tensor)
    #sample(self, input, mode="bilinear", padding_mode="zeros", warpingModeString="pull",prefilter=False)
    aVol_viu_grid_pull_prefilter_warped = viu_grid_pull.sample(
        input=vol_tensor,
        mode='cubic',
        padding_mode='zeros',
        warpingModeString="pull",
        prefilter=True)
    if debug:
        printTensor('aVol_viu_grid_pull_prefilter_warped', aVol_viu_grid_pull_prefilter_warped)
    aVol_viu_grid_pull_prefilter_warped_np = aVol_viu_grid_pull_prefilter_warped.squeeze().squeeze().clone().cpu().numpy()
    v1_volumeComparisonViewer3D(
        listVolumes=[aVol_viu_grid_pull_prefilter_warped_np],
        listLabels=[f'VIUDVF_pull_prefilterWarp_Rot:  {description}'],
        maxZ0=aVol_viu_grid_pull_prefilter_warped_np.shape[0],
        maxZ1=aVol_viu_grid_pull_prefilter_warped_np.shape[1],
        maxZ2=aVol_viu_grid_pull_prefilter_warped_np.shape[2],
        figsize=(12,8), cmap='gray',
        displayColorbar=False, useExternalWindowCenter=True, wMin=-500, wMax=500)

viu_grid_pull shape: torch.Size([1, 452, 512, 488, 3]) device: cuda:0 dtype: torch.float32
vol_tensor shape: torch.Size([1, 1, 452, 512, 488]) device: cuda:0 dtype: torch.float32
aVol_viu_grid_pull_prefilter_warped shape: torch.Size([1, 1, 452, 512, 488]) device: cuda:0 dtype: torch.float32


interactive(children=(Output(),), _dom_classes=('widget-interact',))

##### Apply VIU DVF  push warp WITH prefilter using push grid and display. Internally torch-interpol pull-warp will be used as (push-warp= | prefilter=True).

In [24]:
if runTestCode_sec2:
    if debug:
        printTensor('viu_grid_push', viu_grid_push)
        printTensor('vol_tensor', vol_tensor)
    #sample(self, input, mode="bilinear", padding_mode="zeros", warpingModeString="pull",prefilter=False)
    aVol_viu_grid_push_prefilter_warped = viu_grid_push.sample(
        input=vol_tensor,
        mode='cubic',
        padding_mode='zeros',
        warpingModeString="push",
        prefilter=True)
    if debug:
        printTensor('aVol_viu_grid_push_prefilter_warped', aVol_viu_grid_push_prefilter_warped)
    aVol_viu_grid_push_prefilter_warped_np = aVol_viu_grid_push_prefilter_warped.squeeze().squeeze().clone().cpu().numpy()
    v1_volumeComparisonViewer3D(
        listVolumes=[aVol_viu_grid_push_prefilter_warped_np],
        listLabels=[f'VIUDVF_push_prefilterWarp_Rot:  {description}'],
        maxZ0=aVol_viu_grid_push_prefilter_warped_np.shape[0], 
        maxZ1=aVol_viu_grid_push_prefilter_warped_np.shape[1], 
        maxZ2=aVol_viu_grid_push_prefilter_warped_np.shape[2],
        figsize=(12,8), cmap='gray',
        displayColorbar=False, useExternalWindowCenter=True, wMin=-500, wMax=500)

viu_grid_push shape: torch.Size([1, 452, 512, 488, 3]) device: cuda:0 dtype: torch.float32
vol_tensor shape: torch.Size([1, 1, 452, 512, 488]) device: cuda:0 dtype: torch.float32
aVol_viu_grid_push_prefilter_warped shape: torch.Size([1, 1, 452, 512, 488]) device: cuda:0 dtype: torch.float32


interactive(children=(Output(),), _dom_classes=('widget-interact',))

#### Timing consideration

Execute each of interpol grid pull and  functional grid pull N times and compare time taken

In [25]:
numRepeat = 10
import time

In [26]:
if runTestCode_sec3:
    start = time.time()
    for count in range(numRepeat):
        aVol_viu_grid_pull_warped = viu_grid_pull.sample(
            input=vol_tensor,
            mode='bilinear',
            padding_mode='zeros',
            warpingModeString="pull",
            prefilter=False)
    end = time.time()
    print(f'Volume shape {vol.shape}, numRepeat {numRepeat} : '
          f'VIU DVF Pull + bilinear + no-PreFilter => (Internally F.grid_sample) =>: '
          f'{(end - start):.2f} seconds')

Volume shape (452, 512, 488), numRepeat 10 : VIU DVF Pull + bilinear + no-PreFilter => (Internally F.grid_sample) =>: 0.05 seconds


In [27]:
if runTestCode_sec3:
    start = time.time()
    for count in range(numRepeat):
        Vol_viu_grid_pull_prefilter_warped = viu_grid_pull.sample(
            input=vol_tensor,
            mode='cubic',
            padding_mode='zeros',
            warpingModeString="pull",
            prefilter=True)
    end = time.time()
    print(f'Volume shape {vol.shape}, numRepeat {numRepeat} : '
          f'VIU DVF Pull + cubic + PreFilter => (Internally preFilter + interpol.grid_pull) =>: '
          f'{(end - start):.2f} seconds')

Volume shape (452, 512, 488), numRepeat 10 : VIU DVF Pull + cubic + PreFilter => (Internally preFilter + interpol.grid_pull) =>: 11.86 seconds


In [28]:
if runTestCode_sec3:
    start = time.time()
    for count in range(numRepeat):
        aVol_viu_grid_push_prefilter_warped = viu_grid_push.sample(
            input=vol_tensor,
            mode='cubic',
            padding_mode='zeros',
            warpingModeString="push",
            prefilter=True)
    end = time.time()
    print(f'Volume shape {vol.shape}, numRepeat {numRepeat} : '
          f'VIU DVF Push + cubic + PreFilter => (Internally preFilter + interpol.grid_push) =>: '
          f'{(end - start):.2f} seconds')

Volume shape (452, 512, 488), numRepeat 10 : VIU DVF Push + cubic + PreFilter => (Internally preFilter + interpol.grid_push) =>: 14.31 seconds
