In [1]:
import argparse
from argparse import Namespace

import os.path

import numpy as np
import argparse
import nibabel as nib

from scipy import signal
from torchmetrics import MeanSquaredError
from viu.io import volume
from viu.torch.deformation.fields import DVF
from viu.torch.math import torch_affine_to_vol_mat
from viu.util.body_mask import seg_body

from pamomo.registration.deformable import reg
import torch
import torch.nn.functional as tt

from exampleUtils import *

In [2]:
# Experiment configuration. The data directory may be overridden with the
# DEFORMATION_DATA_DIR environment variable instead of editing this cell; the
# default reproduces the original absolute paths exactly.
DATA_DIR = os.environ.get(
    'DEFORMATION_DATA_DIR',
    '/home/wd974888/Downloads/workingFolder/DeformationExperiment/readFormat')
args = Namespace(
    src=os.path.join(DATA_DIR, 'Patient09PB_bin_01.nii.gz'),
    dst=os.path.join(DATA_DIR, 'Patient09PB_bin_11.nii.gz'),
    body_seg=True,             # build a body mask of dst with seg_body() and pass it to reg()
    air_threshold=-300,        # air_threshold forwarded to seg_body()
    alpha=50,                  # alpha forwarded to reg()
    verboseMode=False,         # NOTE(review): not forwarded to reg() in the registration cell
    maxNumberOfIterations=100,
    numLevels=3,
    saveMovedAsNIIFlag=True)   # NOTE(review): not read by any cell shown here
print(args)

Namespace(src='/home/wd974888/Downloads/workingFolder/DeformationExperiment/readFormat/Patient09PB_bin_01.nii.gz', dst='/home/wd974888/Downloads/workingFolder/DeformationExperiment/readFormat/Patient09PB_bin_11.nii.gz', body_seg=True, air_threshold=-300, alpha=50, verboseMode=False, maxNumberOfIterations=100, numLevels=3, saveMovedAsNIIFlag=True)


In [3]:
runFreshRegistration = False  # True: recompute the registration below; False: reload the cached DVF

In [4]:
#registrationCommand(args)
# Load the src/dst volumes, then obtain the DVF either from a fresh deformable
# registration or from the compressed .npz cache written by a previous run.
dst_path = os.path.dirname(args.dst)

src_vol: np.ndarray
# Both volumes go through the same reader so they share one convention:
# ImportFromNII only when BOTH paths are .nii.gz, volume.read_volume otherwise.
if args.src.endswith('.nii.gz') and args.dst.endswith('.nii.gz'):
    src_vol, src_res, src_pos = ImportFromNII(args.src)
    dst_vol, dst_res, dst_pos = ImportFromNII(args.dst)
else:
    src_vol, src_res, src_pos = volume.read_volume(args.src)
    dst_vol, dst_res, dst_pos = volume.read_volume(args.dst)

print(f'src_vol dtype {src_vol.dtype},  shape ZYX {src_vol.shape}, res XYZ {src_res}, pos XYZ {src_pos}')
print(f'dst_vol dtype {dst_vol.dtype},  shape ZYX {dst_vol.shape}, res XYZ {dst_res}, pos XYZ {dst_pos}')

cachedPath = os.path.join(dst_path, 'Patient09PB_cachedDVF.npz')
# Register when explicitly requested OR when no cache exists yet, so a fresh
# checkout does not fail with FileNotFoundError in np.load.
if runFreshRegistration or not os.path.exists(cachedPath):
    if not runFreshRegistration:
        print(f'No cached DVF at {cachedPath}; running registration instead.')

    additional_args = {}
    if args.body_seg:
        print('Find connected components...')
        # Body mask of dst_vol, handed to reg() under 'SM_bodymask'
        # (presumably a similarity-measure mask — confirm in reg() docs).
        msk = seg_body(dst_vol, air_threshold=args.air_threshold)
        additional_args.update({'dst_seg': {'SM_bodymask': msk}, 'similarityMaskMultilevelStrategy': 'STRICTINTERIOR'})

    print('Start registration...')
    # Intensities below vol_min are clipped before registration
    # (presumably HU values — TODO confirm the volume units).
    vol_min = -1200
    # NOTE(review): args.verboseMode is not forwarded to reg().
    dvf, dvf_res, dvf_pos = reg(src_vol.clip(min=vol_min), src_res,
                                dst_vol.clip(min=vol_min), dst_res,
                                alpha=args.alpha,
                                numLevels=args.numLevels,
                                maxNumberOfIterations=args.maxNumberOfIterations,
                                **additional_args)
    # Saved positionally, hence the arr_0/arr_1/arr_2 keys read back below.
    np.savez_compressed(cachedPath, dvf, dvf_res, dvf_pos)
    # A previous run of this registration (3 levels) took ~103 s in total.
else:
    # `with` closes the .npz file handle; each array is read fully on access.
    with np.load(cachedPath) as cachedDVF:
        dvf, dvf_res, dvf_pos = cachedDVF['arr_0'], cachedDVF['arr_1'], cachedDVF['arr_2']

print(f'B4 DVFify dvf type {type(dvf)}  dtype {dvf.dtype},  shape ZYX {dvf.shape}')
print(f'B4 tensorify dvf_res dtype {dvf_res.dtype},  shape {dvf_res.shape}, value XYZ  {dvf_res}')
print(f'B4 tensorify dvf_pos dtype {dvf_pos.dtype},  shape {dvf_pos.shape}, value XYZ  {dvf_pos}')


src_vol dtype float64,  shape ZYX (590, 512, 512), res XYZ [0.9765625 0.9765625 1.       ], pos XYZ [-249.50999451 -469.51000977 -352.3999939 ]
dst_vol dtype float64,  shape ZYX (590, 512, 512), res XYZ [0.9765625 0.9765625 1.       ], pos XYZ [-249.50999451 -469.51000977 -352.3999939 ]
B4 DVFify dvf type <class 'numpy.ndarray'>  dtype float32,  shape ZYX (295, 256, 256, 3)
B4 tensorify dvf_res dtype float64,  shape (3,), value XYZ  [1.95695466 1.95695466 2.00340136]
B4 tensorify dvf_pos dtype float64,  shape (3,), value XYZ  [0. 0. 0.]


In [5]:
# Convert the DVF/dst geometry to float64 tensors for DVF.resample().
# torch.as_tensor (instead of torch.tensor) keeps this cell idempotent: on a
# re-run the names already hold tensors, which are returned as-is rather than
# triggering torch's "copy construct from a tensor" UserWarning.
dvf_res = torch.as_tensor(dvf_res, dtype=torch.float64)
print(f'After tensorify dvf_res dtype {dvf_res.dtype},  shape {dvf_res.shape}, value XYZ  {dvf_res}')
dvf_pos = torch.as_tensor(dvf_pos, dtype=torch.float64)
print(f'After tensorify dvf_pos dtype {dvf_pos.dtype},  shape {dvf_pos.shape}, value XYZ  {dvf_pos}')

dst_dim = torch.tensor(dst_vol.shape[::-1]) #From ZYX to XYZ
print(f'dst_dim in XYZ dtype {dst_dim.dtype},  shape {dst_dim.shape}, value XYZ  {dst_dim}')
dst_res = torch.as_tensor(dst_res, dtype=torch.float64)
print(f'dst_res in XYZ dtype {dst_res.dtype},  shape {dst_res.shape}, value XYZ  {dst_res}')

After tensorify dvf_res dtype torch.float64,  shape torch.Size([3]), value XYZ  tensor([1.9570, 1.9570, 2.0034], dtype=torch.float64)
After tensorify dvf_pos dtype torch.float64,  shape torch.Size([3]), value XYZ  tensor([0., 0., 0.], dtype=torch.float64)
dst_dim in XYZ dtype torch.int64,  shape torch.Size([3]), value XYZ  tensor([512, 512, 590])
dst_res in XYZ dtype torch.float64,  shape torch.Size([3]), value XYZ  tensor([0.9766, 0.9766, 1.0000], dtype=torch.float64)


In [6]:
# Create a BATCHED DVF object. DVF() only accepts tensors of shape (B,H,W,2) or
# (B,D,H,W,3); passing the unbatched array, i.e. DVF(dvf), fails with:
#   DVFs can be formed only of tensors of shape (B,H,W,2) or (B,D,H,W,3). Provided shape: torch.Size([295, 256, 256, 3]).
# so a leading batch axis of size 1 is added with dvf[None, ...].
# from_millimeter(dvf_res) presumably converts reg()'s mm displacements into
# DVF's internal units — TODO confirm.
batched_dvf = DVF(dvf[None, ...]).from_millimeter(dvf_res).to(torch.float32)
print(f'batched_dvf DVF object   dtype {batched_dvf.dtype},  shape BZYX3 {batched_dvf.shape}')

batched_dvf DVF object   dtype torch.float32,  shape ZYX3 torch.Size([1, 295, 256, 256, 3])


In [7]:
# NOTE(review): legacy cell, fully commented out and kept for reference only.
# It resampled the *unbatched* DVF, which DVF() rejects (a batch axis is
# required); the active resampling code is in the batched_dvf1 / batched_dvf2
# cells that follow. Consider deleting this cell.
# #Old resample of unbatched DVF
# dvf = dvf.resample(
#     dst_dim,
#     dst_res,
#     #dst_pos: torch.Tensor = torch.zeros(3, dtype=torch.float64), <--- default
#     dvf_res=dvf_res,
#     dvf_pos=dvf_pos
#     # mode="bilinear",
#     # padding_mode="border"
#     )
# print(f'Resampled dvf type: {type(dvf)}')
# ######## Exception happening  otherwise ######
# if dvf.dtype !=torch.tensor(src_vol).dtype:
#     #Making both float32
#     # dvf = dvf.to(torch.tensor(src_vol).dtype)
#     dvf = dvf.to(torch.float32)
#     src_vol = src_vol.astype('float32')
# #########################################
# print(f'Resampled DVF object   dtype {dvf.dtype},  shape ZYX3 {dvf.shape}')
# #With change in sample() we need to have batch dimension added to DVF and batch and channel  dimension added to  input
# batched_dvf = dvf.unsqueeze(0)
# print(f'batched_dvf  object   dtype {batched_dvf.dtype},  shape ZYX3 {batched_dvf.shape}')

In [7]:
# Resample the batched DVF onto the dst grid (dst_dim voxels at dst_res) with
# mode="bilinear" and no spline prefilter.
batched_dvf1 = batched_dvf.resample(
    dst_dim,
    dst_res,
    # dst_pos is left at its default, torch.zeros(3, dtype=torch.float64)
    dvf_res=dvf_res,
    dvf_pos=dvf_pos,
    mode="bilinear",
    padding_mode="border",
    prefilter=False,
)
# The DVF and the volume it warps must share a dtype; standardise both on
# float32. The old check built torch.tensor(src_vol) just to read its dtype,
# copying the whole volume (~1.2 GB at float64). astype(copy=False) makes a
# re-run free once src_vol is already float32.
batched_dvf1 = batched_dvf1.to(torch.float32)
src_vol = src_vol.astype(np.float32, copy=False)
print(f'Resampled DVF object  type {type(batched_dvf1)}   dtype {batched_dvf1.dtype},  shape BZYX3 {batched_dvf1.shape}')

Resampled DVF object  type <class 'viu.torch.deformation.fields.DVF'>   dtype torch.float32,  shape ZYX3 torch.Size([1, 590, 512, 512, 3])


In [8]:
# Resample the batched DVF onto the dst grid (dst_dim voxels at dst_res), this
# time with mode="cubic" and a spline prefilter.
batched_dvf2 = batched_dvf.resample(
    dst_dim,
    dst_res,
    # dst_pos is left at its default, torch.zeros(3, dtype=torch.float64)
    dvf_res=dvf_res,
    dvf_pos=dvf_pos,
    mode="cubic",
    padding_mode="border",
    prefilter=True,
)
# The DVF and the volume it warps must share a dtype; standardise both on
# float32. Building torch.tensor(src_vol) merely to read its dtype would copy
# the whole volume, and astype(copy=False) is free when src_vol is already float32.
batched_dvf2 = batched_dvf2.to(torch.float32)
src_vol = src_vol.astype(np.float32, copy=False)
print(f'Resampled DVF object  type: {type(batched_dvf2)} dtype {batched_dvf2.dtype} shape BZYX3 {batched_dvf2.shape}')

  output = spline_coeff_nd(input, *opt)


Resampled DVF object  type: <class 'viu.torch.deformation.fields.DVF'> dtypetorch.float32 shape BZYX3 torch.Size([1, 590, 512, 512, 3])


In [9]:
# Compare the two resampled DVFs. The difference (1x590x512x512x3 float32,
# ~1.9 GB) and its absolute value are each materialised once instead of being
# recomputed for every statistic, and are freed at the end of the cell.
dvf_diff = batched_dvf1 - batched_dvf2
dvf_abs_diff = torch.abs(dvf_diff)
print(f'batched_dvf1 min {torch.min(batched_dvf1)} max {torch.max(batched_dvf1)}')
print(f'batched_dvf2 min {torch.min(batched_dvf2)} max {torch.max(batched_dvf2)}')
print(f'batched_dvf1 - batched_dvf2 min {torch.min(dvf_diff)} max {torch.max(dvf_diff)} mean {torch.mean(dvf_diff)}')
print(f'|batched_dvf1 - batched_dvf2| min {torch.min(dvf_abs_diff)} max {torch.max(dvf_abs_diff)} mean {torch.mean(dvf_abs_diff)}')
del dvf_diff, dvf_abs_diff

batched_dvf1 min DVF(-0.1099) max DVF(0.0270)
batched_dvf2 min DVF(-0.1099) max DVF(0.0271)
batched_dvf1 - batched_dvf2 min DVF(-0.0001) max DVF(0.0002) mean DVF(2.7087e-08)
|batched_dvf1 - batched_dvf2| min DVF(0.) max DVF(0.0002) mean DVF(3.5387e-06)


In [10]:
# Show the three displacement components of (batched_dvf1 - batched_dvf2) in
# the 3-D viewer with a diverging colormap, windowed to [-0.05, 0.05].
diff_dvf = (batched_dvf1 - batched_dvf2).cpu().numpy()
diff_components = [diff_dvf[0, :, :, :, axis] for axis in (0, 1, 2)]
dim0, dim1, dim2 = diff_components[0].shape
v1_volumeComparisonViewer3D(
    listVolumes=diff_components,
    listLabels=['diff_ch0', 'diff_ch1', 'diff_ch2'],
    maxZ0=dim0, maxZ1=dim1, maxZ2=dim2,
    figsize=(12, 8),
    cmap='coolwarm',
    displayColorbar=True,
    useExternalWindowCenter=True,
    wMin=-0.05,
    wMax=0.05)

interactive(children=(Output(),), _dom_classes=('widget-interact',))

<exampleUtils.v1_volumeComparisonViewer3D at 0x7f5c93ca5280>

In [11]:
# Add batch and channel axes, (Z, Y, X) -> (1, 1, Z, Y, X), because sample()
# needs both on its input. torch.tensor copies, so vol_tensor does not alias src_vol.
vol_tensor = torch.tensor(src_vol)[None, None]
print(f'vol_tensor  object   dtype {vol_tensor.dtype},  shape BCZYX {vol_tensor.shape}')

vol_tensor  object   dtype torch.float32,  shape ZYX3 torch.Size([1, 1, 590, 512, 512])


In [12]:
# Pull-warp src_vol with batched_dvf1 (the DVF resampled without prefilter),
# sampling the image bilinearly and without a prefilter. The explicit sample()
# call is used instead of batched_dvf1(vol_tensor) (the __call__ path, pull
# warping without prefilter) so every sampling option is visible.
warped_vol_F_noPreFilter = (
    batched_dvf1.sample(vol_tensor, mode="bilinear", padding_mode="zeros",
                        warpingModeString="pull", prefilter=False)
    .squeeze(0).squeeze(0)
    .cpu().numpy()  # .cpu() so this also works when the DVF is on a GPU
)
print(f'warped_vol_F_noPreFilter  object   dtype {warped_vol_F_noPreFilter.dtype},  shape ZYX {warped_vol_F_noPreFilter.shape}')

# Difference images before (src - dst) and after warping, windowed to [-100, 100].
v1_volumeComparisonViewer3D(
    listVolumes=[src_vol-dst_vol, warped_vol_F_noPreFilter-dst_vol],listLabels=['F-M', 'M*_NoPreFilter-M'],
    maxZ0=src_vol.shape[0], maxZ1=src_vol.shape[1], maxZ2=src_vol.shape[2],
    figsize=(12,8), cmap='coolwarm',
    displayColorbar=False, useExternalWindowCenter=True, wMin=-100, wMax=100)

warped_vol_F_noPreFilter  object   dtype float32,  shape ZYX3 (590, 512, 512)


interactive(children=(Output(),), _dom_classes=('widget-interact',))

<exampleUtils.v1_volumeComparisonViewer3D at 0x7f5c93c9bca0>

In [13]:
# Pull-warp src_vol with batched_dvf2, the DVF resampled with cubic
# interpolation + prefilter. "preFilter" in the names below refers to that DVF
# resampling: the image itself is sampled bilinearly WITHOUT a prefilter, as in
# the batched_dvf1 warp, so the two results presumably isolate the effect of
# the DVF resampling method.
# Cubic + prefiltered image sampling variant (disabled):
# warped_vol_interpol_preFilter = batched_dvf2.sample(vol_tensor, mode="cubic", padding_mode="zeros", warpingModeString="pull",prefilter=True).squeeze(0).squeeze(0).numpy()
warped_vol_interpol_preFilter = (
    batched_dvf2.sample(vol_tensor, mode="bilinear", padding_mode="zeros",
                        warpingModeString="pull", prefilter=False)
    .squeeze(0).squeeze(0)
    .cpu().numpy()
)
print(f'warped_vol_interpol_preFilter  object   dtype {warped_vol_interpol_preFilter.dtype},  shape ZYX {warped_vol_interpol_preFilter.shape}')

# Difference images before (src - dst) and after warping, windowed to [-100, 100].
v1_volumeComparisonViewer3D(
    listVolumes=[src_vol-dst_vol, warped_vol_interpol_preFilter-dst_vol],listLabels=['F-M', 'M*_PreFilter-M'],
    maxZ0=src_vol.shape[0], maxZ1=src_vol.shape[1], maxZ2=src_vol.shape[2],
    figsize=(12,8), cmap='coolwarm',
    displayColorbar=False, useExternalWindowCenter=True, wMin=-100, wMax=100)

warped_vol_interpol_preFilter  object   dtype float32,  shape ZYX3 (590, 512, 512)


interactive(children=(Output(),), _dom_classes=('widget-interact',))

<exampleUtils.v1_volumeComparisonViewer3D at 0x7f5c7c26e5b0>

In [14]:
# Cross-check: warp src_vol with interpol.grid_pull (cubic, spline prefilter,
# zero boundary), using batched_dvf1.mapping() converted to interpol's grid
# layout by convertGrid_functional2interpol.
# NOTE(review): `interpol` is not imported in the import cell; presumably it
# arrives via `from exampleUtils import *` — confirm and import it explicitly.
batched_dvf1Grid_to_interpol = convertGrid_functional2interpol(
    batched_dvf1.mapping(), src_vol.shape[0], src_vol.shape[1], src_vol.shape[2])
warped_vol_interpol = (
    interpol.grid_pull(vol_tensor, batched_dvf1Grid_to_interpol,
                       interpolation='cubic', bound='zero', prefilter=True)
    .squeeze(0).squeeze(0)
    .cpu().numpy()
)
print(f'warped_vol_interpol  object   dtype {warped_vol_interpol.dtype},  shape ZYX {warped_vol_interpol.shape}')

#Display
v1_volumeComparisonViewer3D(
    listVolumes=[src_vol, dst_vol, warped_vol_interpol],listLabels=['src', 'dst', 'warped_vol_interpol'],
    maxZ0=src_vol.shape[0], maxZ1=src_vol.shape[1], maxZ2=src_vol.shape[2],
    figsize=(12,8), cmap='gray',
    displayColorbar=False, useExternalWindowCenter=True, wMin=-500, wMax=500)

 does not have profile information (Triggered internally at /opt/conda/conda-bld/pytorch_1682343997789/work/third_party/nvfuser/csrc/graph_fuser.cpp:104.)
  output = spline_coeff_nd(input, *opt)


warped_vol_interpol  object   dtype float32,  shape ZYX3 (590, 512, 512)


interactive(children=(Output(),), _dom_classes=('widget-interact',))

<exampleUtils.v1_volumeComparisonViewer3D at 0x7f5c7c28adf0>

In [15]:
# Pull-warp src_vol with batched_dvf1 (DVF resampled without prefilter), this
# time sampling the image with mode="cubic" and a spline prefilter.
warped_vol_resample_F_warp_interpol = (
    batched_dvf1.sample(vol_tensor, mode="cubic", padding_mode="zeros",
                        warpingModeString="pull", prefilter=True)
    .squeeze(0).squeeze(0)
    .cpu().numpy()  # .cpu() so this also works when the DVF is on a GPU
)
print(f'warped_vol_resample_F_warp_interpol  object   dtype {warped_vol_resample_F_warp_interpol.dtype},  shape ZYX {warped_vol_resample_F_warp_interpol.shape}')

# Display the volume computed above. This cell previously passed
# warped_vol_F_noPreFilter under this label (copy-paste slip), so the viewer
# showed the bilinear result instead.
v1_volumeComparisonViewer3D(
    listVolumes=[src_vol, dst_vol, warped_vol_resample_F_warp_interpol],
    listLabels=['src', 'dst', 'warped_vol_resample_F_warp_interpol'],
    maxZ0=src_vol.shape[0], maxZ1=src_vol.shape[1], maxZ2=src_vol.shape[2],
    figsize=(12,8), cmap='gray',
    displayColorbar=False, useExternalWindowCenter=True, wMin=-500, wMax=500)

warped_vol_resample_F_warp_interpol  object   dtype float32,  shape ZYX3 (590, 512, 512)


interactive(children=(Output(),), _dom_classes=('widget-interact',))

<exampleUtils.v1_volumeComparisonViewer3D at 0x7f5c7c2ad790>