In [1]:
import argparse
from argparse import Namespace

import os.path

import numpy as np
import argparse
import nibabel as nib

from scipy import signal
from torchmetrics import MeanSquaredError
from viu.io import volume
from viu.torch.deformation.fields import DVF
from viu.torch.math import torch_affine_to_vol_mat
from viu.util.body_mask import seg_body

from pamomo.registration.deformable import reg
import torch
import torch.nn.functional as tt

from exampleUtils import *

In [2]:
# Experiment configuration: source/destination volumes plus registration options.
args = Namespace(**{
    'src': '/home/wd974888/Downloads/workingFolder/DeformationExperiment/readFormat/'
           'Patient09PB_bin_01.nii.gz',
    'dst': '/home/wd974888/Downloads/workingFolder/DeformationExperiment/readFormat/'
           'Patient09PB_bin_11.nii.gz',
    'body_seg': True,
    'air_threshold': -300,
    'alpha': 50,
    'verboseMode': False,
    'maxNumberOfIterations': 100,
    'numLevels': 3,
    'saveMovedAsNIIFlag': True,
})
print(args)

Namespace(src='/home/wd974888/Downloads/workingFolder/DeformationExperiment/readFormat/Patient09PB_bin_01.nii.gz', dst='/home/wd974888/Downloads/workingFolder/DeformationExperiment/readFormat/Patient09PB_bin_11.nii.gz', body_seg=True, air_threshold=-300, alpha=50, verboseMode=False, maxNumberOfIterations=100, numLevels=3, saveMovedAsNIIFlag=True)


In [3]:
runFreshRegistration = False  # True: recompute the DVF with reg(); False: reuse the DVF cached on disk by a previous run

In [5]:
# Load the source and destination volumes, then obtain the DVF either from a fresh registration
# run or from the on-disk cache written by a previous run.
#registrationCommand(args)
dst_path = os.path.dirname(args.dst)

src_vol: np.ndarray
if args.src.endswith('.nii.gz') and args.dst.endswith('.nii.gz'):
    src_vol, src_res, src_pos = ImportFromNII(args.src)
    dst_vol, dst_res, dst_pos = ImportFromNII(args.dst)
else:
    src_vol, src_res, src_pos = volume.read_volume(args.src)
    dst_vol, dst_res, dst_pos = volume.read_volume(args.dst)

print(f'src_vol dtype {src_vol.dtype},  shape ZYX {src_vol.shape}, res XYZ {src_res}, pos XYZ {src_pos}')
# src_vol dtype float64,  shape ZYX (590, 512, 512), res XYZ [0.9765625 0.9765625 1.       ], pos XYZ [-249.50999451 -469.51000977 -352.3999939 ]
print(f'dst_vol dtype {dst_vol.dtype},  shape ZYX {dst_vol.shape}, res XYZ {dst_res}, pos XYZ {dst_pos}')
# dst_vol dtype float64,  shape ZYX (590, 512, 512), res XYZ [0.9765625 0.9765625 1.       ], pos XYZ [-249.50999451 -469.51000977 -352.3999939 ]

cachedPath = os.path.join(dst_path, 'Patient09PB_cachedDVF.npz')
# Register when explicitly requested OR when no cache exists yet; otherwise np.load below would
# raise FileNotFoundError and Restart & Run All would fail on a fresh checkout.
if runFreshRegistration or not os.path.exists(cachedPath):

    additional_args = {}
    if args.body_seg:
        print('Find connected components...')
        # Body mask of dst_vol handed to reg(); the key names suggest it masks the similarity
        # measure (NOTE(review): confirm against the pamomo reg() documentation).
        msk = seg_body(dst_vol, air_threshold=args.air_threshold)
        additional_args.update({'dst_seg': {'SM_bodymask': msk}, 'similarityMaskMultilevelStrategy': 'STRICTINTERIOR'})

    print('Start registration...')
    vol_min = -1200  # intensities below this value are clipped before registration
    # NOTE: args.verboseMode is currently not forwarded to reg().
    dvf, dvf_res, dvf_pos = reg(src_vol.clip(min=vol_min), src_res,
                                dst_vol.clip(min=vol_min), dst_res,
                                alpha=args.alpha,
                                numLevels=args.numLevels,
                                maxNumberOfIterations=args.maxNumberOfIterations,
                                **additional_args)
    # Saved positionally, i.e. as arr_0 / arr_1 / arr_2; the cache reader below relies on those keys.
    np.savez_compressed(cachedPath, dvf, dvf_res, dvf_pos)
    # Logged run (Fraunhofer MEVIS CUDA registration v1.6.2): 3 levels, total runtime ~103 s.
else:
    # The context manager closes the underlying zip file; indexing reads each array eagerly.
    with np.load(cachedPath) as cachedDVF:
        dvf, dvf_res, dvf_pos = cachedDVF['arr_0'], cachedDVF['arr_1'], cachedDVF['arr_2']

print(f'B4 DVFify dvf type {type(dvf)}  dtype {dvf.dtype},  shape ZYX {dvf.shape}')
# B4 DVFify dvf type <class 'numpy.ndarray'>  dtype float32,  shape ZYX (295, 256, 256, 3)
print(f'B4 tensorify dvf_res dtype {dvf_res.dtype},  shape {dvf_res.shape}, value XYZ  {dvf_res}')
# B4 tensorify dvf_res dtype float64,  shape (3,), value XYZ  [1.95695466 1.95695466 2.00340136]
print(f'B4 tensorify dvf_pos dtype {dvf_pos.dtype},  shape {dvf_pos.shape}, value XYZ  {dvf_pos}')
# B4 tensorify dvf_pos dtype float64,  shape (3,), value XYZ  [0. 0. 0.]


src_vol dtype float64,  shape ZYX (590, 512, 512), res XYZ [0.9765625 0.9765625 1.       ], pos XYZ [-249.50999451 -469.51000977 -352.3999939 ]
dst_vol dtype float64,  shape ZYX (590, 512, 512), res XYZ [0.9765625 0.9765625 1.       ], pos XYZ [-249.50999451 -469.51000977 -352.3999939 ]
B4 DVFify dvf type <class 'numpy.ndarray'>  dtype float32,  shape ZYX (295, 256, 256, 3)
B4 tensorify dvf_res dtype float64,  shape (3,), value XYZ  [1.95695466 1.95695466 2.00340136]
B4 tensorify dvf_pos dtype float64,  shape (3,), value XYZ  [0. 0. 0.]


In [6]:
# Convert the geometry vectors (all in XYZ order) to tensors for DVF.resample.
dvf_res = torch.tensor(dvf_res, dtype=torch.float64)
print(f'After tensorify dvf_res dtype {dvf_res.dtype},  shape {dvf_res.shape}, value XYZ  {dvf_res}')
dvf_pos = torch.tensor(dvf_pos, dtype=torch.float64)
print(f'After tensorify dvf_pos dtype {dvf_pos.dtype},  shape {dvf_pos.shape}, value XYZ  {dvf_pos}')

# The numpy shape is ZYX; reversing it yields the XYZ size vector (integer tensor).
dst_dim = torch.tensor(tuple(reversed(dst_vol.shape)))
print(f'dst_dim in XYZ dtype {dst_dim.dtype},  shape {dst_dim.shape}, value XYZ  {dst_dim}')
dst_res = torch.tensor(dst_res, dtype=torch.float64)
print(f'dst_res in XYZ dtype {dst_res.dtype},  shape {dst_res.shape}, value XYZ  {dst_res}')

After tensorify dvf_res dtype torch.float64,  shape torch.Size([3]), value XYZ  tensor([1.9570, 1.9570, 2.0034], dtype=torch.float64)
After tensorify dvf_pos dtype torch.float64,  shape torch.Size([3]), value XYZ  tensor([0., 0., 0.], dtype=torch.float64)
dst_dim in XYZ dtype torch.int64,  shape torch.Size([3]), value XYZ  tensor([512, 512, 590])
dst_res in XYZ dtype torch.float64,  shape torch.Size([3]), value XYZ  tensor([0.9766, 0.9766, 1.0000], dtype=torch.float64)


In [7]:
# Build the DVF object with a leading singleton batch axis. DVF() rejects the unbatched
# (295, 256, 256, 3) array with:
#   "DVFs can be formed only of tensors of shape (B,H,W,2) or (B,D,H,W,3)."
# from_millimeter(dvf_res) is then applied and the result cast to float32.
batched_dvf = DVF(dvf[np.newaxis, ...]).from_millimeter(dvf_res).to(torch.float32)
print(f'batched_dvf DVF object   dtype {batched_dvf.dtype},  shape ZYX3 {batched_dvf.shape}')

batched_dvf DVF object   dtype torch.float32,  shape ZYX3 torch.Size([1, 295, 256, 256, 3])


In [8]:
# #Old resample of unbatched DVF
# dvf = dvf.resample(
#     dst_dim,
#     dst_res,
#     #dst_pos: torch.Tensor = torch.zeros(3, dtype=torch.float64), <--- default
#     dvf_res=dvf_res,
#     dvf_pos=dvf_pos
#     # mode="bilinear",
#     # padding_mode="border"
#     )
# print(f'Resampled dvf type: {type(dvf)}')
# ######## Exception happening  otherwise ######
# if dvf.dtype !=torch.tensor(src_vol).dtype:
#     #Making both float32
#     # dvf = dvf.to(torch.tensor(src_vol).dtype)
#     dvf = dvf.to(torch.float32)
#     src_vol = src_vol.astype('float32')
# #########################################
# print(f'Resampled DVF object   dtype {dvf.dtype},  shape ZYX3 {dvf.shape}')
# #With change in sample() we need to have batch dimension added to DVF and batch and channel  dimension added to  input
# batched_dvf = dvf.unsqueeze(0)
# print(f'batched_dvf  object   dtype {batched_dvf.dtype},  shape ZYX3 {batched_dvf.shape}')

In [9]:
# Resample of batched DVF without preFilter (grid_sample path: bilinear, border padding).
batched_dvf1 = batched_dvf.resample(
    dst_dim,
    dst_res,
    # dst_pos is left at its default (torch.zeros(3, dtype=torch.float64))
    dvf_res=dvf_res,
    dvf_pos=dvf_pos,
    mode="bilinear",
    padding_mode="border",
    prefilter=False
    )
# Look up the torch dtype that matches src_vol via an empty array: torch.tensor(src_vol) would
# copy the entire volume (~1.2 GB for the 590x512x512 float64 input) just to read its dtype.
if batched_dvf1.dtype != torch.from_numpy(np.empty(0, dtype=src_vol.dtype)).dtype:
    # Making both float32
    batched_dvf1 = batched_dvf1.to(torch.float32)
    src_vol = src_vol.astype('float32')
print(f'Resampled DVF object  type {type(batched_dvf1)}   dtype {batched_dvf1.dtype},  shape ZYX3 {batched_dvf1.shape}')

Resampled DVF object  type <class 'viu.torch.deformation.fields.DVF'>   dtype torch.float32,  shape ZYX3 torch.Size([1, 590, 512, 512, 3])


In [10]:
from __future__ import annotations
from typing import Any
import torch
import torch.nn.functional as F
from torch import nn, optim
from viu.torch.math import torch_affine_to_vol_mat, st_mat_ex
# from .utils import permute_input, permute_output, ensure_dimensions
import os, sys
import interpol

def _norm_to_index_mat(sizes, dtype=torch.float64, device=None):
    """Homogeneous matrix: align_corners=True normalized coordinates -> voxel indices.

    ``sizes`` are the spatial sizes in the same axis order as the coordinates being
    converted. Normalized -1 maps to index 0 and +1 to ``size - 1``; a singleton axis
    maps every coordinate to index 0.
    """
    mat = torch.eye(len(sizes) + 1, dtype=dtype, device=device)
    for axis, size in enumerate(sizes):
        half_extent = (int(size) - 1) / 2.0
        mat[axis, axis] = half_extent
        mat[axis, -1] = half_extent
    return mat


def _index_to_norm_mat(sizes, dtype=torch.float64, device=None):
    """Homogeneous matrix: voxel indices -> align_corners=True normalized coordinates.

    Inverse of ``_norm_to_index_mat`` for every axis with size > 1; a singleton axis is
    mapped to normalized coordinate 0 instead of dividing by zero.
    """
    mat = torch.eye(len(sizes) + 1, dtype=dtype, device=device)
    for axis, size in enumerate(sizes):
        size = int(size)
        if size > 1:
            mat[axis, axis] = 2.0 / (size - 1)
            mat[axis, -1] = -1.0
        else:
            mat[axis, axis] = 0.0
            mat[axis, -1] = 0.0
    return mat


def resample_try(self_dvf,
    dst_dim: torch.Tensor,
    dst_res: torch.Tensor,
    dst_pos: torch.Tensor = None,
    dvf_res: torch.Tensor = None,
    dvf_pos: torch.Tensor = None,
    mode="bilinear",
    padding_mode="border",
    prefilter=False,
    attempt='attempt3'):
    """Experimental stand-alone version of ``DVF.resample`` for a *batched* DVF.

    NOTE: dst_dim, dst_res, dst_pos, dvf_res and dvf_pos are all given in XYZ (resp. XY)
    order, while the spatial axes of the DVF tensor are in ZYX (resp. YX) order.

    Parameters
    ----------
    self_dvf : DVF / tensor of shape (B,H,W,2) or (B,D,H,W,3)
        Displacement field; the displacement components live in the last axis.
    dst_dim : integer tensor with nb_dim target sizes (XYZ).
    dst_res : tensor with nb_dim target voxel spacings (XYZ).
    dst_pos : target position (XYZ); ``None`` -> zeros of length nb_dim.
    dvf_res : DVF voxel spacing (XYZ); ``None`` -> ones, with a printed warning.
    dvf_pos : DVF position (XYZ); ``None`` -> zeros of length nb_dim.
    mode : with prefilter=False one of 'nearest', 'bilinear', 'bicubic'
        (torch.nn.functional.grid_sample); with prefilter=True one of 'nearest',
        'linear', 'quadratic', 'cubic' (torch-interpol).
    padding_mode : grid_sample ``padding_mode`` or torch-interpol ``bound``.
    prefilter : bool; True selects the torch-interpol path.
    attempt : torch-interpol strategy: 'attempt1' (interpol.affine_grid), 'attempt2'
        (convert the grid_sample grid to voxel indices) or 'attempt3' (interpol.resize,
        which ignores res/pos and only uses dst_dim).

    Returns
    -------
    DVF of shape (B, *dst_dim[::-1], nb_dim), or ``self_dvf`` itself when the DVF geometry
    already matches the destination geometry.

    Raises
    ------
    AssertionError on inconsistent shapes or unsupported options.
    """
    # NOTE hardcoded behaviour: every normalized coordinate below follows align_corners=True.
    align_corners = True
    field_shape = self_dvf.shape
    assert  ((4==len(field_shape) and 2==field_shape[-1] ) or (5==len(field_shape) and 3==field_shape[-1] ) ),\
            f'Expected field.shape == (B,H,W,2) or (B,D,H,W,3). Provied field.shape {field_shape}.'
    # Field layout is (B, *spatial, nb_dim).
    nb_dim = len(field_shape) - 2

    # Positions default to the origin. They are created per call with nb_dim entries (rather than
    # a shared 3-element default tensor), so 2D fields no longer need explicit positions.
    if dst_pos is None:
        dst_pos = torch.zeros(nb_dim, dtype=torch.float64)
    if dvf_pos is None:
        dvf_pos = torch.zeros(nb_dim, dtype=torch.float64)
    if dvf_res is None:
        print(f'Warning DVF resolution not given;  assuming 1.')
        dvf_res = torch.ones(len(dst_dim), dtype=torch.float64)
    if 4==len(field_shape):
        assert (2==len(dst_dim) and 2==len(dst_res) and 2==len(dst_pos) and 2==len(dvf_res) and 2==len(dvf_pos)),\
            f'If 4==len(self.shape) all of dst_dim, dst_res, dst_pos, and original dvf_res, dvf_pos should be 2-element tensor.\
                Provided values: dst_dim {dst_dim} dst_res {dst_res} dst_pos {dst_pos} dvf_res {dvf_res} dvf_pos {dvf_pos}'
    if 5==len(field_shape):
        assert (3==len(dst_dim) and 3==len(dst_res) and 3==len(dst_pos) and 3==len(dvf_res) and 3==len(dvf_pos)),\
            f'If 5==len(self.shape) all of dst_dim, dst_res, dst_pos, and original dvf_res, dvf_pos should be 3-element tensor.\
                Provided values: dst_dim {dst_dim} dst_res {dst_res} dst_pos {dst_pos} dvf_res {dvf_res} dvf_pos {dvf_pos}'

    assert prefilter in [True, False], f'prefilter should be a boolean True or False. Passed value : {prefilter}.'
    if (False==prefilter):
        behaviour="torch_Functional"
        assert mode in ['nearest', 'bilinear', 'bicubic'],f'In torch_Functional library, mode should be one of nearest, bilinear, bicubic. Passed value: {mode}'
    else:
        behaviour="torch_interpol"
        assert mode in ['nearest', 'linear', 'quadratic', 'cubic'],f'With prefilter=True, using torch_interpol library and  mode should be one of nearest, linear, quadratic, cubic. Passed value: {mode}'

    batch_Size = self_dvf.shape[0]  # a leading batch dimension is required
    dvf_dim = torch.tensor(self_dvf.shape[1:-1][::-1])  # DVF spatial sizes, reversed to XYZ
    assert nb_dim==len(dst_dim) and nb_dim==len(dvf_dim),\
        f'Expected nb_dim==len(dst_dim)==len(dvf_dim). Provided: nb_dim {nb_dim} len(dst_dim) {len(dst_dim)} len(dvf_dim) {len(dvf_dim)}'

    # torch_affine_to_vol_mat (math.py) only builds the 4x4 matrix of a 3D DVF; st_mat_ex is used
    # because it supports both the 3x3 (2D) and 4x4 (3D) homogeneous matrices.
    mat_dvf = st_mat_ex(s=0.5 * dvf_res * (dvf_dim - 1), t=dvf_pos, nb_dim=nb_dim).to(self_dvf.device)
    mat_vol = st_mat_ex(s=0.5 * dst_res * (dst_dim - 1), t=dst_pos, nb_dim=nb_dim).to(self_dvf.device)
    # Used as the affine_grid theta: destination normalized coords -> DVF normalized coords (XYZ order).
    mat = mat_dvf.inverse().matmul(mat_vol)
    assert torch.Size([nb_dim + 1, nb_dim + 1]) == mat.shape,\
        f'With nb_dim {nb_dim} expected mat.shape: [{nb_dim+1}, {nb_dim+1}] but found {mat.shape}'

    # torch-interpol is only trusted for a pure resize (mat close to identity). Otherwise the
    # user's choice is overridden and the torch.nn.functional path is used instead.
    identityMatTensor = torch.eye(nb_dim + 1, dtype=mat.dtype, device=mat.device)
    if "torch_interpol"==behaviour and not torch.isclose(mat, identityMatTensor).all():
        behaviour = "torch_Functional"
        if 'nearest' != mode:
            mode = 'bilinear'
        prefilter = False

    if torch.isclose(dvf_dim, dst_dim).all() and \
            torch.isclose(dvf_res, dst_res).all() and \
            torch.isclose(dvf_pos, dst_pos).all():
        # Geometry already matches: nothing to resample.
        return self_dvf

    dst_spatial = dst_dim.tolist()[::-1]  # tensor axis order (ZYX / YX)
    dvf_spatial = dvf_dim.tolist()[::-1]
    # Treat the DVF as an image whose channels are the displacement components: (B, nb_dim, *spatial).
    dvf = self_dvf.permute(0, nb_dim + 1, *range(1, nb_dim + 1))

    if "torch_Functional" == behaviour or 'attempt2' == attempt:
        # Theta in the DVF's dtype: grid_sample requires input and grid to share a dtype.
        pyTorchTheta = mat[:nb_dim, :].to(dvf.dtype)
        pyTorchTheta = pyTorchTheta.expand(batch_Size, *pyTorchTheta.shape)
        grid = F.affine_grid(pyTorchTheta, size=[batch_Size, 1] + dst_spatial, align_corners=align_corners)

    if "torch_Functional" == behaviour:
        out = F.grid_sample(dvf, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
    else:  # behaviour == "torch_interpol"
        assert attempt in ['attempt1', 'attempt2', 'attempt3'], 'Undefined attempt.'
        if 'attempt1' == attempt:
            # torch-interpol grids hold INPUT voxel indices in tensor axis order, and
            # interpol.affine_grid takes a matrix mapping output voxel indices to input voxel
            # indices. Build  norm->index(DVF) @ P mat P @ index->norm(destination),
            # where P reverses the spatial axes (XYZ <-> ZYX) including the translation column.
            axis_order = list(range(nb_dim))[::-1] + [nb_dim]
            mat_zyx = mat[axis_order][:, axis_order]
            mat_interpol = (_norm_to_index_mat(dvf_spatial, mat.dtype, mat.device)
                            .matmul(mat_zyx)
                            .matmul(_index_to_norm_mat(dst_spatial, mat.dtype, mat.device)))
            grid_interpol = interpol.api.affine_grid(mat_interpol.to(dvf.dtype), dst_spatial).to(dvf.device)
            grid_interpol_batched = grid_interpol.expand(batch_Size, *grid_interpol.shape)
            out = interpol.grid_pull(dvf, grid_interpol_batched, interpolation=mode, bound=padding_mode, prefilter=prefilter)
        elif 'attempt2' == attempt:
            # affine_grid yields (x, y[, z]) coordinates normalized w.r.t. the INPUT (DVF) grid.
            # Flip them to tensor axis order and denormalize with the DVF sizes (size-1 under
            # align_corners=True) to obtain torch-interpol voxel indices.
            norm_to_index = _norm_to_index_mat(dvf_spatial, grid.dtype, grid.device)
            grid_index = torch.flip(grid, [-1]).matmul(norm_to_index[:nb_dim, :nb_dim].T) + norm_to_index[:nb_dim, -1]
            out = interpol.grid_pull(dvf, grid_index, interpolation=mode, bound=padding_mode, prefilter=prefilter)
        else:  # 'attempt3': plain resize; dst_res, dst_pos, dvf_res and dvf_pos are not used.
            out = interpol.resize(
                image=dvf,
                factor=None,
                shape=dst_spatial,
                anchor='c' if align_corners else 'e',
                interpolation=mode,
                prefilter=prefilter)

    # Put the displacement components back into the last axis: (B, *spatial, nb_dim).
    out = out.permute(0, *range(2, nb_dim + 2), 1)
    # NOTE Are we ensuring that the original DVF object's data got resampled with all gradients and other properties maintained?
    if not isinstance(out, DVF):
        out = DVF(out)
    return out

In [11]:
#Resample of batched DVF with preFilter
# localMethod=True runs the notebook's experimental resample_try(); False uses DVF.resample().
localMethod=False
if localMethod:
    batched_dvf2 = resample_try(batched_dvf,
        dst_dim,
        dst_res,
        # dst_pos left at its default
        dvf_res=dvf_res,
        dvf_pos=dvf_pos,
        mode="cubic",
        padding_mode="border",
        prefilter=True,
        attempt='attempt3'
        )
else:
    batched_dvf2 = batched_dvf.resample(
        dst_dim,
        dst_res,
        # dst_pos left at its default (torch.zeros(3, dtype=torch.float64))
        dvf_res=dvf_res,
        dvf_pos=dvf_pos,
        mode="cubic",
        padding_mode="border",
        prefilter=True
        )
# Compare against src_vol's dtype via an empty array instead of copying the whole volume.
if batched_dvf2.dtype != torch.from_numpy(np.empty(0, dtype=src_vol.dtype)).dtype:
    # Making both float32
    batched_dvf2 = batched_dvf2.to(torch.float32)
    src_vol = src_vol.astype('float32')
print(f'Resampled DVF object  type: {type(batched_dvf2)} dtype{batched_dvf2.dtype} shape BZYX3 {batched_dvf2.shape}')

#Compare -- each difference field is as large as the full-resolution DVF, so compute it once.
dvf_diff = batched_dvf1 - batched_dvf2
dvf_abs_diff = torch.abs(dvf_diff)
print(f'batched_dvf1 min {torch.min(batched_dvf1)} max {torch.max(batched_dvf1)}')
print(f'batched_dvf2 min {torch.min(batched_dvf2)} max {torch.max(batched_dvf2)}')
print(f'batched_dvf1 - batched_dvf2 min {torch.min(dvf_diff)} max {torch.max(dvf_diff)} mean {torch.mean(dvf_diff)}')
print(f'|batched_dvf1 - batched_dvf2| min {torch.min(dvf_abs_diff)} max {torch.max(dvf_abs_diff)} mean {torch.mean(dvf_abs_diff)}')
del dvf_diff, dvf_abs_diff  # release the large temporaries before the warping cells

# Ideal output
# batched_dvf1 min DVF(-0.1099) max DVF(0.0270)
# batched_dvf2 min DVF(-0.1099) max DVF(0.0271)
# batched_dvf1 - batched_dvf2 min DVF(-0.0001) max DVF(0.0002) mean DVF(2.7087e-08)
# |batched_dvf1 - batched_dvf2| min DVF(0.) max DVF(0.0002) mean DVF(3.5387e-06)

  output = spline_coeff_nd(input, *opt)


Resampled DVF object  type: <class 'viu.torch.deformation.fields.DVF'> dtypetorch.float32 shape BZYX3 torch.Size([1, 590, 512, 512, 3])
batched_dvf1 min DVF(-0.1099) max DVF(0.0270)
batched_dvf2 min DVF(-0.1099) max DVF(0.0271)
batched_dvf1 - batched_dvf2 min DVF(-0.0001) max DVF(0.0002) mean DVF(2.7087e-08)
|batched_dvf1 - batched_dvf2| min DVF(0.) max DVF(0.0002) mean DVF(3.5387e-06)


In [12]:
# Per-component difference between the bilinear (no prefilter) and the prefiltered-cubic DVFs.
diff_dvf = (batched_dvf1 - batched_dvf2).cpu().numpy()
v1_volumeComparisonViewer3D(
    listVolumes=[diff_dvf[0, ..., component] for component in range(3)],
    listLabels=[f'diff_ch{component}' for component in range(3)],
    maxZ0=diff_dvf.shape[1], maxZ1=diff_dvf.shape[2], maxZ2=diff_dvf.shape[3],
    figsize=(12, 8), cmap='coolwarm',
    displayColorbar=True, useExternalWindowCenter=True, wMin=-0.05, wMax=0.05)

interactive(children=(Output(),), _dom_classes=('widget-interact',))

<exampleUtils.v1_volumeComparisonViewer3D at 0x7f28bc3e5bb0>

In [13]:
# Add batch and channel axes in front of the (Z, Y, X) volume: -> (1, 1, Z, Y, X).
vol_tensor = torch.tensor(src_vol)[None, None]
print(f'vol_tensor  object   dtype {vol_tensor.dtype},  shape ZYX3 {vol_tensor.shape}')

vol_tensor  object   dtype torch.float32,  shape ZYX3 torch.Size([1, 1, 590, 512, 512])


In [14]:
# Warp src with batched_dvf1 (resampled without prefilter): pull warping, bilinear, no prefilter.
# The explicit sample() call is used instead of the __call__ form batched_dvf1(vol_tensor).
# .cpu() before .numpy() keeps this working when the tensors live on a GPU (as in the later cells).
warped_vol_F_noPreFilter = batched_dvf1.sample(vol_tensor, mode="bilinear", padding_mode="zeros", warpingModeString="pull",prefilter=False).squeeze(0).squeeze(0).cpu().numpy()
print(f'warped_vol_F_noPreFilter  object   dtype {warped_vol_F_noPreFilter.dtype},  shape ZYX3 {warped_vol_F_noPreFilter.shape}')

# Difference images before and after warping.
v1_volumeComparisonViewer3D(
    listVolumes=[src_vol-dst_vol, warped_vol_F_noPreFilter-dst_vol],listLabels=['F-M', 'M*_NoPreFilter-M'],
    maxZ0=src_vol.shape[0], maxZ1=src_vol.shape[1], maxZ2=src_vol.shape[2],
    figsize=(12,8), cmap='coolwarm',
    displayColorbar=False, useExternalWindowCenter=True, wMin=-100, wMax=100)

warped_vol_F_noPreFilter  object   dtype float32,  shape ZYX3 (590, 512, 512)


interactive(children=(Output(),), _dom_classes=('widget-interact',))

<exampleUtils.v1_volumeComparisonViewer3D at 0x7f28bc4742b0>

In [15]:
# Warp src with batched_dvf2 (the DVF resampled with prefilter=True). The warp itself is plain
# pull sampling: bilinear, zero padding, no prefilter.
# Alternative kept for reference: sample(..., mode="cubic", prefilter=True).
warped_vol_interpol_preFilter = (
    batched_dvf2
    .sample(vol_tensor, mode="bilinear", padding_mode="zeros", warpingModeString="pull", prefilter=False)
    .squeeze(0)
    .squeeze(0)
    .cpu()
    .numpy()
)
print(f'warped_vol_interpol_preFilter  object   dtype {warped_vol_interpol_preFilter.dtype},  shape ZYX3 {warped_vol_interpol_preFilter.shape}')

# Difference images before and after warping.
v1_volumeComparisonViewer3D(
    listVolumes=[src_vol - dst_vol, warped_vol_interpol_preFilter - dst_vol],
    listLabels=['F-M', 'M*_PreFilter-M'],
    maxZ0=src_vol.shape[0], maxZ1=src_vol.shape[1], maxZ2=src_vol.shape[2],
    figsize=(12, 8), cmap='coolwarm',
    displayColorbar=False, useExternalWindowCenter=True, wMin=-100, wMax=100)

warped_vol_interpol_preFilter  object   dtype float32,  shape ZYX3 (590, 512, 512)


interactive(children=(Output(),), _dom_classes=('widget-interact',))

<exampleUtils.v1_volumeComparisonViewer3D at 0x7f27a4fbd8e0>

In [16]:
# Re-express batched_dvf1's grid_sample mapping as torch-interpol coordinates for src_vol's
# (Z, Y, X) size, then pull-warp vol_tensor with cubic interpolation, zero bound and prefiltering.
batched_dvf1Grid_to_interpol = convertGrid_functional2interpol(batched_dvf1.mapping(), *src_vol.shape[:3])
warped_vol_interpol = (
    interpol.grid_pull(vol_tensor, batched_dvf1Grid_to_interpol, interpolation='cubic', bound='zero', prefilter=True)
    .squeeze(0)
    .squeeze(0)
    .cpu()
    .numpy()
)
print(f'warped_vol_interpol  object   dtype {warped_vol_interpol.dtype},  shape ZYX3 {warped_vol_interpol.shape}')

#Display
v1_volumeComparisonViewer3D(
    listVolumes=[src_vol, dst_vol, warped_vol_interpol],
    listLabels=['src', 'dst', 'warped_vol_interpol'],
    maxZ0=src_vol.shape[0], maxZ1=src_vol.shape[1], maxZ2=src_vol.shape[2],
    figsize=(12, 8), cmap='gray',
    displayColorbar=False, useExternalWindowCenter=True, wMin=-500, wMax=500)

 does not have profile information (Triggered internally at /opt/conda/conda-bld/pytorch_1682343997789/work/third_party/nvfuser/csrc/graph_fuser.cpp:104.)
  output = spline_coeff_nd(input, *opt)


warped_vol_interpol  object   dtype float32,  shape ZYX3 (590, 512, 512)


interactive(children=(Output(),), _dom_classes=('widget-interact',))

<exampleUtils.v1_volumeComparisonViewer3D at 0x7f27a32a9760>

In [17]:
# Warp src with batched_dvf1 (resampled via grid_sample) but sample with prefiltered cubic
# interpolation. .cpu() keeps the numpy conversion GPU-safe.
warped_vol_resample_F_warp_interpol = batched_dvf1.sample(vol_tensor, mode="cubic", padding_mode="zeros", warpingModeString="pull",prefilter=True).squeeze(0).squeeze(0).cpu().numpy()
print(f'warped_vol_resample_F_warp_interpol  object   dtype {warped_vol_resample_F_warp_interpol.dtype},  shape ZYX3 {warped_vol_resample_F_warp_interpol.shape}')

#Display -- show the volume computed in this cell under its own label.
v1_volumeComparisonViewer3D(
    listVolumes=[src_vol, dst_vol, warped_vol_resample_F_warp_interpol],listLabels=['src', 'dst', 'warped_vol_resample_F_warp_interpol'],
    maxZ0=src_vol.shape[0], maxZ1=src_vol.shape[1], maxZ2=src_vol.shape[2],
    figsize=(12,8), cmap='gray',
    displayColorbar=False, useExternalWindowCenter=True, wMin=-500, wMax=500)

# v1_volumeComparisonViewer3D(
#     listVolumes=[src_vol-dst_vol, warped_vol_resample_F_warp_interpol-dst_vol],listLabels=['F-M', 'M*_resamp_F_warp_interpol-M'],
#     maxZ0=src_vol.shape[0], maxZ1=src_vol.shape[1], maxZ2=src_vol.shape[2],
#     figsize=(12,8), cmap='coolwarm',
#     displayColorbar=False, useExternalWindowCenter=True, wMin=-100, wMax=100)

warped_vol_resample_F_warp_interpol  object   dtype float32,  shape ZYX3 (590, 512, 512)


interactive(children=(Output(),), _dom_classes=('widget-interact',))

<exampleUtils.v1_volumeComparisonViewer3D at 0x7f27a33a8e20>