## Load the data

In [1]:
import os
import glob
import SimpleITK as sitk
import numpy as np
import pydicom
import matplotlib.pyplot as plt
# %matplotlib widget
import nibabel as nib
import torch
from skimage import exposure
import seaborn as sns
from sewar.full_ref import mse
os.environ['VXM_BACKEND'] = 'pytorch'
import voxelmorph as vxm   # nopep8

from scripts.torch.utils import *
from scipy import ndimage
from skimage.morphology import area_closing, binary_dilation
from skimage.exposure import match_histograms


In [2]:
def minmax(img):
    """Linearly rescale an image so its values span [0, 1].

    Parameters
    ----------
    img : array_like
        Image of any shape; converted with ``np.asarray``.

    Returns
    -------
    numpy.ndarray
        ``(img - img.min()) / (img.max() - img.min())``. A constant image has
        no range to stretch, so it maps to an all-zero float array instead of
        the NaNs (and RuntimeWarning) that a 0/0 division would produce.
    """
    img = np.asarray(img)
    lo = np.min(img)
    span = np.max(img) - lo
    if span == 0:
        return np.zeros(img.shape, dtype=float)
    return (img - lo) / span

def preprocess(array, new_shape, range=10):
    """Resample an image, zero out its darkest intensities, and rescale to [0, 1].

    Parameters
    ----------
    array : numpy.ndarray
        Image to process; wrapped with ``sitk.GetImageFromArray``.
    new_shape :
        Target shape forwarded to ``resize_img`` together with
        ``sitk.sitkLinear``. ``resize_img`` is not defined in this notebook
        (presumably it comes from ``scripts.torch.utils``) -- TODO confirm the
        expected format of ``new_shape``.
    range : float, default 10
        Intensity margin above the resampled minimum: voxels below
        ``min + range`` are set to 0. The name shadows the builtin ``range``
        inside this function; it is kept so keyword callers keep working.

    Returns
    -------
    numpy.ndarray
        The thresholded image rescaled to [0, 1] with ``minmax``.
    """
    resized = resize_img(sitk.GetImageFromArray(array), new_shape, sitk.sitkLinear)
    resized_values = sitk.GetArrayFromImage(resized)
    # lo/hi rather than min/max so the builtins are not shadowed.
    lo, hi = float(resized_values.min()), float(resized_values.max())
    # sitk.Threshold keeps voxels inside [lower, upper] and sets all others to
    # outsideValue, i.e. everything darker than lo + range becomes 0.
    thresholded = sitk.Threshold(resized, lower=lo + range, upper=hi, outsideValue=0)
    return minmax(sitk.GetArrayFromImage(thresholded))

def seg(img, new_shape):
    """Return a filled mask of the largest contour found in an image.

    Parameters
    ----------
    img : numpy.ndarray
        2-D image. It is multiplied by 255 and cast to uint8, so it is assumed
        to be scaled to [0, 1] (e.g. the output of ``minmax``) -- TODO confirm.
    new_shape : tuple of int
        Shape of the returned mask. Contour coordinates come from ``img``, so
        this should normally equal ``img.shape``.

    Returns
    -------
    numpy.ndarray
        uint8 array of shape ``new_shape``: 255 inside the largest contour
        (by ``cv2.contourArea``), 0 elsewhere.

    Raises
    ------
    ValueError
        If no contour is found (e.g. an all-zero image).
    """
    # cv2 is used here but is not imported explicitly anywhere in the notebook.
    import cv2

    img_u8 = (img * 255).astype(np.uint8)
    # findContours returns (contours, hierarchy) in OpenCV 4 and
    # (image, contours, hierarchy) in OpenCV 3; [-2] selects contours in both.
    contours = cv2.findContours(img_u8, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        raise ValueError("seg: no contour found in image")
    largest = max(contours, key=cv2.contourArea)
    mask = np.zeros(new_shape, np.uint8)
    # thickness=-1 fills the contour interior instead of drawing its outline.
    return cv2.drawContours(mask, [largest], -1, 255, -1)

In [3]:
from sklearn.preprocessing import normalize

In [4]:
test_folder = '../data/PostconT1w_dataset/test'
# glob.glob returns files in filesystem order, which is not guaranteed.
# Sorting makes the frame order, and therefore which volume ends up as the
# fixed reference (frame 0) in the registration cell, reproducible.
test_files = sorted(glob.glob(os.path.join(test_folder, '*.npy')))
if not test_files:
    raise FileNotFoundError(f"No .npy volumes found in {test_folder!r}")
vols_all = []
for file in test_files:
    # fixed_affine is overwritten on every iteration; only the last file's value survives.
    vols, fixed_affine = vxm.py.utils.load_volfile(file, add_batch_axis=True, add_feat_axis=True, ret_affine=True)
    # Drop the batch/feature axes and move the first remaining axis last;
    # assumes each squeezed volume is 3-D (frames, H, W) -- TODO confirm.
    vols_all.append(np.squeeze(vols).transpose(1, 2, 0))
# Concatenate all files along the last (frame) axis -> (H, W, total_frames).
vols_all = np.dstack(vols_all)
vols_all.shape

(192, 192, 48)

## Load the registration model

In [5]:
device = 'cpu'
# Checkpoint location. The default is a machine-specific absolute path; set the
# VXM_MODEL_PATH environment variable to point at the checkpoint elsewhere.
model_path = os.environ.get(
    'VXM_MODEL_PATH',
    '/Users/mona/Documents/data/DelftBlueDownloads/PostconT1w_ncc_0.1_l2_1000.pt',
)
# load and set up model
model = vxm.networks.VxmDense.load(model_path, device)
model.to(device)
# eval() switches the module to inference mode; as the last expression of the
# cell it also displays the model summary.
model.eval()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


VxmDense(
  (unet_model): Unet(
    (encoder): ModuleList(
      (0): ModuleList(
        (0): ConvBlock(
          (main): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (activation): LeakyReLU(negative_slope=0.2)
        )
      )
      (1): ModuleList(
        (0): ConvBlock(
          (main): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (activation): LeakyReLU(negative_slope=0.2)
        )
      )
      (2): ModuleList(
        (0): ConvBlock(
          (main): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (activation): LeakyReLU(negative_slope=0.2)
        )
      )
      (3): ModuleList(
        (0): ConvBlock(
          (main): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (activation): LeakyReLU(negative_slope=0.2)
        )
      )
    )
    (decoder): ModuleList(
      (0): ModuleList(
        (0): ConvBlock(
          (main): Conv2d(32, 32, kernel_size=(3, 3

In [6]:
# Optional preprocessing of the fixed frame (currently disabled):
# fixed_array = match_histograms(fixed_array, reference, channel_axis=-1)
# fixed_array = preprocess(fixed_array, new_shape, 0)

# Frame 0 of the stacked test volumes is the fixed (reference) image; every
# other frame is registered to it.
fixed_array = vols_all[:, :, 0]
# Add batch and channel axes -> (1, 1, H, W), as float32 for the network.
input_fixed = torch.from_numpy(fixed_array[None, None, ...].astype(np.float32))

# Keep every list entry a NumPy array so np.stack works regardless of device.
# The fixed frame seeds `moved` and `orig` (it is its own registration
# result); `warp` gets no entry for it, so it ends up one frame shorter.
fixed_frame = input_fixed.squeeze().numpy()
output_moved = [fixed_frame]
output_warp = []
output_orig = [fixed_frame]

input_fixed = input_fixed.to(device)
# Inference only: no_grad avoids recording an autograd graph for every frame.
with torch.no_grad():
    for frame_idx in range(1, vols_all.shape[-1]):
        # Put the moving frame on the same device as the fixed frame and model.
        input_moving = torch.from_numpy(
            vols_all[None, None, :, :, frame_idx].astype(np.float32)
        ).to(device)
        # registration=True: VxmDense returns (moved image, flow field).
        moved, warp = model(input_moving, input_fixed, registration=True)

        output_moved.append(moved.cpu().numpy().squeeze())
        output_warp.append(warp.cpu().numpy().squeeze())
        output_orig.append(input_moving.cpu().numpy().squeeze())


In [7]:
# Stack the per-frame lists into arrays. `moved` and `orig` start with the
# fixed frame; `warp` has one fewer entry because the fixed frame has no
# deformation field (presumably shape (N-1, 2, H, W) for this 2-D model --
# TODO confirm).
moved = np.stack(output_moved)
warp = np.stack(output_warp)
orig = np.stack(output_orig)

# Histogram-equalize each stack before saving/visualising.
# NOTE(review): equalize_hist remaps values through the array's cumulative
# histogram. For `warp` this discards the actual displacement values, so
# test_warp.nii and the morph-field plot below do not hold the raw
# deformation field -- confirm this is intended.
moved = exposure.equalize_hist(moved)
warp = exposure.equalize_hist(warp)
orig = exposure.equalize_hist(orig)

# Images and warp fields go to the same directory.
# NOTE(review): the checkpoint loaded earlier is named *_ncc_0.1_l2_1000.pt,
# but this directory says ncc_0.01 -- check the name is not stale.
moved_path = 'results/Lisbon_ncc_0.01/'
warp_path = 'results/Lisbon_ncc_0.01/'
os.makedirs(moved_path, exist_ok=True)
os.makedirs(warp_path, exist_ok=True)
# Replace the affine returned by load_volfile with None before saving
# (how save_volfile treats a None affine is not visible here).
fixed_affine = None

vxm.py.utils.save_volfile(moved, os.path.join(moved_path, "test_registered.nii"), fixed_affine)

# Reorder warp axes (0, 1, 2, 3) -> (2, 3, 0, 1) before writing it out;
# if warp is (N-1, 2, H, W) this gives (H, W, N-1, 2).
warp = warp.transpose(2, 3, 0, 1)
vxm.py.utils.save_volfile(warp, os.path.join(warp_path, "test_warp.nii"), fixed_affine)

# Reorient for the GIF / morph-field helpers: frames move to the last axis
# for the images. The warp is fully reversed (from the (H, W, N-1, 2) layout
# above to (2, N-1, W, H)) and then flipped along axis 2.
orig = orig.transpose(1, 2, 0)
moved = moved.transpose(1, 2, 0)
warp = warp.transpose(3, 2, 1, 0)
warp = np.flip(warp, axis=2)
# print(f"Shape of orig {orig.shape} and moved {moved.shape} and warp {warp.shape}")
# save_gif / save_morphField are not defined in this notebook; presumably
# they come from `scripts.torch.utils` via the star import -- TODO confirm.
moved_gif_path = save_gif(moved, "test_intrasubject", moved_path, "registered")
orig_gif_path = save_gif(orig, "test_intrasubject", moved_path, "original")
# quiver_path = save_quiver(warp, 'test', warp_path)
morph_field_path = save_morphField(warp, 'test_intrasubject', moved_path)