In [None]:
import numpy as np
import itk
import matplotlib.pyplot as plt
target_path = "RegLib_C01_1.nrrd"
source_path = "RegLib_C01_2.nrrd"


def load_volume(path, label):
    """Read a volume with ITK, print its geometry, and return it in three forms.

    Args:
        path: file path passed to ``itk.imread``.
        label: prefix used in the printed lines (e.g. "Target").

    Returns:
        ``(itk_image, metadata_dict, numpy_array)`` where ``metadata_dict`` is
        ``dict(itk_image)`` and ``numpy_array`` is ``np.asarray(itk_image)``.
        Note: ITK's NumPy bridge orders array axes in reverse of ITK's index
        order, so ``shape`` is (z, y, x) while spacing/direction are (x, y, z).
    """
    image = itk.imread(path)
    meta = dict(image)
    array = np.asarray(image)
    print(f"{label} shape: {array.shape}")
    print(f"{label} spacing: {meta['spacing']}")
    print(f"{label} direction: {meta['direction']}")
    return image, meta, array


target_itk, target_meta, target = load_volume(target_path, "Target")
source_itk, source_meta, source = load_volume(source_path, "Source")

# Check whether the orientation of the images are the same (reuse the metadata
# already extracted instead of rebuilding both dicts).
assert np.array_equal(target_meta["direction"], source_meta["direction"]), "The orientation of source and target images need to be the same."

# Preview one slice along the first NumPy axis. Clamp the index so volumes with
# fewer than PREVIEW_SLICE + 1 slices show their last slice instead of raising.
PREVIEW_SLICE = 100
fig, axes = plt.subplots(1, 2)
for ax, volume, title in ((axes[0], source, 'Source'), (axes[1], target, 'Target')):
    ax.imshow(volume[min(PREVIEW_SLICE, volume.shape[0] - 1)], cmap="gray")
    ax.set_title(title)
plt.show()


In [None]:
# Processing images
import torch
import torch.nn.functional as F
def preprocess(img, type="ct"):
  if type == "ct":
    clamp = [-1000, 1000]
    img = (torch.clamp(img, clamp[0], clamp[1]) - clamp[0])/(clamp[1]-clamp[0])
    return F.interpolate(img, [175, 175, 175], mode="trilinear", align_corners=False)
  elif type == "mri":
    im_min, im_max = torch.min(img), torch.quantile(img.view(-1), 0.99)
    img = torch.clip(img, im_min, im_max)
    img = (img-im_min) / (im_max-im_min)
    return F.interpolate(img, [175, 175, 175], mode="trilinear", align_corners=False)
  else:
    print(f"Error: Do not support the type {type}")
    return img

# Build (1, 1, ...) float32 tensors from the ITK images rather than from the
# numpy `target`/`source` arrays this cell overwrites. Reading from the ITK
# images keeps the cell idempotent: re-running it no longer stacks extra
# batch/channel dims onto already-preprocessed tensors.
target = preprocess(torch.tensor(np.asarray(target_itk), dtype=torch.float32)[None, None], type="mri")
source = preprocess(torch.tensor(np.asarray(source_itk), dtype=torch.float32)[None, None], type="mri")

In [None]:
from unigradicon import get_unigradicon
# Module.cuda() and Module.eval() both return the module itself, so the chain
# leaves `net` bound to the model on the GPU in evaluation mode. Because the
# line is an assignment, the cell displays nothing: no bulky module repr, and
# no need for a blank print() to hide it.
net = get_unigradicon().cuda().eval()

In [None]:
# Single forward pass without autograd. The result is not captured here: the
# plotting cell reads net.warped_image_A and net.phi_AB_vectorfield afterwards,
# so the model presumably stores its outputs on itself.
with torch.no_grad():
    net(*(volume.cuda() for volume in (source, target)))

In [None]:
def show_as_grid_contour(ax, phi, linewidth=1, stride=8, flip=False):
    """Overlay a deformation as grid lines by contouring two coordinate channels.

    Args:
        ax: matplotlib Axes (anything providing ``contour`` and ``set_ylim``).
        phi: tensor of shape (2, H, W) holding two coordinate channels in voxel
            units. Channel 0 is assumed to vary along H and channel 1 along W.
            NOTE(review): assumption based on the usual channel-i <-> axis-i
            convention; confirm against the model's coordinate ordering.
        linewidth: contour line width.
        stride: approximate spacing, in voxels, between drawn grid lines.
        flip: if True, set the y-limits to ``[0, H]`` to match images drawn
            with ``origin='lower'``.
    """
    height, width = phi.shape[-2:]
    # NOTE(review): half-voxel shift of the coordinates; intent not evident here.
    plot_phi = phi.cpu() - 0.5
    # Each channel's levels span its own axis extent. Using a single extent for
    # both misplaces the grid lines whenever H != W.
    for channel, extent in ((1, width), (0, height)):
        levels = np.linspace(0, extent, int(extent / stride))
        ax.contour(plot_phi[channel], levels, linewidths=linewidth, alpha=0.8)
    if flip:
        ax.set_ylim([0, height])

def show_pair(source, target, warped, phi, axes, idx, flip=False):
    """Plot one slice of a registration result across six panels.

    Panels, in order: source, target, warped, target with the deformation grid
    overlaid, target - source, and target - warped.

    Args:
        source, target, warped: image tensors indexed as ``[0, 0, idx]``;
            presumably (1, 1, D, H, W).
        phi: 5-D deformation tensor, presumably (1, 3, D, H, W) with
            coordinates normalized to [0, 1]. Each channel is rescaled to voxel
            units by (size - 1) of the matching spatial axis of ``phi``.
        axes: sequence of at least six matplotlib Axes.
        idx: slice index along the first spatial axis.
        flip: if True, draw images with ``origin='lower'`` instead of 'upper'.
    """
    # Derive the scale from phi's own spatial shape rather than a hard-coded
    # 175, and build it on phi's device so a CUDA phi also works.
    spatial = torch.tensor(phi.shape[2:], dtype=torch.float32, device=phi.device)
    phi_scaled = phi * (spatial - 1).view(1, -1, 1, 1, 1)
    origin = 'lower' if flip else 'upper'

    # Slice before .cpu() so only a 2-D slice, not the whole volume, is copied
    # off the GPU, and do it once per image.
    src = source[0, 0, idx].cpu()
    tgt = target[0, 0, idx].cpu()
    wrp = warped[0, 0, idx].cpu()

    axes[0].imshow(src, cmap="gray", origin=origin)
    axes[1].imshow(tgt, cmap="gray", origin=origin)
    axes[2].imshow(wrp, cmap="gray", origin=origin)
    axes[3].imshow(tgt, cmap="gray", origin=origin)
    show_as_grid_contour(axes[3], phi_scaled[0, [1, 2], idx], linewidth=0.6, stride=4, flip=flip)
    axes[4].imshow(tgt - src, origin=origin)
    axes[5].imshow(tgt - wrp, origin=origin)

In [None]:
# Six-panel summary of the registration at slice 100 of the preprocessed volumes.
fig, axes = plt.subplots(1, 6, figsize=(12,2))
show_pair(source, target, net.warped_image_A, net.phi_AB_vectorfield.cpu(), axes, 100)

font_size = 10
panel_titles = (
    'Source',
    'Target',
    'Warped',
    'Target+Grids',
    'Difference Before',
    'Difference After',
)
# Title each panel and hide tick marks (pixel indices add no information here).
for axe, title in zip(axes, panel_titles):
    axe.set_title(title, fontsize=font_size)
    axe.set_xticks([])
    axe.set_yticks([])
plt.show()