In [10]:
import SimpleITK as sitk
from DataHandler import DataHandler
import pickle
import numpy as np
import os
import pathlib

In [11]:
def _reorder_xyz_columns(points, order):
    """Return `points` as a numpy array whose first three (x, y, z) columns are permuted.

    Args:
        points: landmark container (list of tuples or array), one point per row.
        order: source column for each of the first three output columns.

    Returns:
        A new numpy array; columns beyond the third (if any) keep their position.
        An empty input is returned as an empty array unchanged.
    """
    arr = np.asarray(points)
    if arr.size == 0:
        return arr
    columns = list(order) + list(range(3, arr.shape[-1]))
    return arr[..., columns]


def get_landmarks(fixed_image_path: str, indexing: str = 'zyx', preResample=False,
                  pickle_dir: str = '/home/cschellenberger/Documents/vectorPickles'):
    """Load corresponding moving/fixed landmark sets for one image pair.

    The landmark pickle is located from the fixed image's base name, with the
    ``_atn_3.nrrd`` suffix stripped to obtain the model name.

    Args:
        fixed_image_path: path of the fixed image; only its base name is used.
        indexing: axis order of the returned coordinates. The stored data are in
            ``xyz`` order; any permutation of ``'xyz'`` (e.g. ``zyx``, ``xzy``,
            ``yxz``) reorders the coordinate columns accordingly.
        preResample: if True, read ``{model}_vec_frame1_to_frame2.p`` (fields
            ``1X/1Y/1Z`` -> moving, ``2X/2Y/2Z`` -> fixed landmarks); otherwise read
            ``CT_points_t1_t3_withRegion_Continuous/{model}_idx.p`` (keys ``'t1'`` ->
            moving, ``'t3'`` -> fixed).
        pickle_dir: directory holding the landmark pickles.

    Returns:
        ``(moving_landmarks, fixed_landmarks, regions)``. For ``indexing='xyz'`` the
        landmark containers are returned as loaded (a list of float tuples when
        ``preResample`` is True). For any other order they are returned as numpy
        arrays with reordered columns. ``regions`` is ``np.array(loaded['Region'])``.

    Raises:
        ValueError: if ``indexing`` is not a permutation of ``'xyz'``.
    """
    # Validate before any file I/O. An `assert` would be stripped under `python -O`.
    if not isinstance(indexing, str) or sorted(indexing) != ['x', 'y', 'z']:
        raise ValueError(f'indexing must be a permutation of "xyz" (e.g. xyz, zyx, xzy, yxz). Got: {indexing}')

    model_name = os.path.basename(fixed_image_path).replace('_atn_3.nrrd', '')
    if preResample:
        pickle_path = os.path.join(pickle_dir, f'{model_name}_vec_frame1_to_frame2.p')
    else:
        pickle_path = os.path.join(pickle_dir, 'CT_points_t1_t3_withRegion_Continuous', f'{model_name}_idx.p')

    # SECURITY: pickle.load can execute arbitrary code. Only use it on trusted,
    # locally produced files. `with` closes the handle, which the old code leaked.
    with open(pickle_path, 'rb') as pickle_file:
        loaded_points = pickle.load(pickle_file)

    if preResample:
        records = [loaded_points[idx] for idx, _ in enumerate(loaded_points)]
        moving_landmarks = [(float(rec['1X']), float(rec['1Y']), float(rec['1Z'])) for rec in records]
        fixed_landmarks = [(float(rec['2X']), float(rec['2Y']), float(rec['2Z'])) for rec in records]
    else:
        moving_landmarks = loaded_points['t1']
        fixed_landmarks = loaded_points['t3']
    regions = np.array(loaded_points['Region'])

    if indexing != 'xyz':
        # Output column n is stored column 'xyz'.index(indexing[n]), e.g. 'zyx' -> [2, 1, 0].
        # Converting to an array first also makes this work for the list-of-tuples case,
        # where the old `[:, [...]]` swaps raised.
        order = ['xyz'.index(axis) for axis in indexing]
        moving_landmarks = _reorder_xyz_columns(moving_landmarks, order)
        fixed_landmarks = _reorder_xyz_columns(fixed_landmarks, order)
    return moving_landmarks, fixed_landmarks, regions

In [None]:
from TrainVoxelmorph import VoxelmorphTF
def get_moved_points(points: np.ndarray, displacement: sitk.Image) -> list:
    """Map each point through a displacement-field transform built from `displacement`.

    Args:
        points: iterable of 3-component points (tuples, lists or numpy rows).
            NOTE(review): SimpleITK's TransformPoint works in physical coordinates.
            Confirm the caller's landmarks are physical points rather than
            (continuous) voxel indices.
        displacement: vector image used to build a sitk.DisplacementFieldTransform
            (SimpleITK requires a sitkVectorFloat64 pixel type for this).

    Returns:
        A list with one transformed point (tuple of floats) per input point.
    """
    # Build the transform from a copy so the caller's `displacement` image is not the
    # object handed to the DisplacementFieldTransform constructor.
    displacement_copy = displacement.__copy__()
    displacement_transform = sitk.DisplacementFieldTransform(displacement_copy)
    # Normalise every point to a tuple of Python floats. The SWIG wrapper expects a
    # sequence of doubles, and numpy rows with e.g. float32 elements are not reliably accepted.
    return [displacement_transform.TransformPoint(tuple(float(c) for c in point)) for point in points]
    
# Register validation pairs (moving images under .../t1/Synthetic_MR, fixed images
# under .../t3/Synthetic_CT) with a trained VoxelMorph model, then push the fixed
# landmarks through the predicted displacement field.
weights_path = '/home/cschellenberger/Documents/scripts/models/synthetic/localmi_thirdTrain_intsteps3_reg002_1000_st14_lr1e-05_bat3/bestWeights.h5'
downsize = 1
n_eval_pairs = 1  # number of validation pairs to evaluate (was a hard-coded range(1))
dh = DataHandler(val_images=12)
dh.get_synthetic_data(
    fixed_path='/home/cschellenberger/datam2olie/synthetic/orig/t3/Synthetic_CT/',
    moving_path='/home/cschellenberger/datam2olie/synthetic/orig/t1/Synthetic_MR/')
resErr = 0  # NOTE(review): not read by any cell of this notebook
dists = {}  # NOTE(review): not read by any cell of this notebook
moving_image_paths = dh.x_val
fixed_image_paths = dh.y_val
nb_features = [[16, 16, 32, 32], [32, 32, 32, 32, 32, 16, 16]]
device = '/cpu:0'
imgReg = VoxelmorphTF(weights_path, sitk.ReadImage(fixed_image_paths[0]), nb_features, downsize)
for i in range(n_eval_pairs):
    fixed_image = sitk.ReadImage(fixed_image_paths[i])
    moving_image = sitk.ReadImage(moving_image_paths[i])
    moving_landmarks, fixed_landmarks, regions = get_landmarks(fixed_image_paths[i], indexing='xyz')
    moved_img, displacement_np, time = imgReg.register_images(moving_image, fixed_image, device)
    # float64 because sitk.DisplacementFieldTransform requires a sitkVectorFloat64 image.
    # NOTE(review): ITK expects displacements in physical units with components in
    # x, y, z order. If VoxelMorph returns voxel units in numpy (z, y, x) order, the
    # components must be reversed and scaled by spacing. Confirm against VoxelmorphTF.
    displacement = sitk.GetImageFromArray(displacement_np.squeeze().astype(np.float64), isVector=True)
    # Put the field on the fixed image grid. The direction was previously left as
    # identity, which misplaces the field for obliquely oriented images.
    # Assumes the field has the fixed image's size (downsize == 1). TODO confirm.
    displacement.SetSpacing(fixed_image.GetSpacing())
    displacement.SetOrigin(fixed_image.GetOrigin())
    displacement.SetDirection(fixed_image.GetDirection())
    # NOTE(review): other cells convert these landmarks with
    # TransformContinuousIndexToPhysicalPoint, i.e. treat them as voxel indices,
    # while TransformPoint expects physical points. Confirm the units.
    moved_landmarks = get_moved_points(fixed_landmarks, displacement)

In [19]:
# Stamp per-landmark displacement vectors into `displacement` at the voxels of the
# moving landmarks (continuous index -> physical point -> nearest voxel index).
# `disp_np` was not defined by any cell of this notebook, so this cell raised NameError
# on a fresh kernel. Define it explicitly in the same way as the landmark-field
# reconstruction cell does (moving minus fixed landmark coordinates).
# NOTE(review): confirm this is the intended per-landmark vector.
disp_np = np.asarray(moving_landmarks, dtype=np.float64) - np.asarray(fixed_landmarks, dtype=np.float64)
for i, point in enumerate(moving_landmarks):
    point = moving_image.TransformPhysicalPointToIndex(moving_image.TransformContinuousIndexToPhysicalPoint(point))
    # 1.8: the same empirical scale factor the reconstruction cell applies. TODO document its origin.
    # SetPixel on a vector image takes a sequence of floats, hence the tuple.
    displacement.SetPixel(point[0], point[1], point[2], tuple(float(c) for c in disp_np[i] * 1.8))

In [10]:
# Paint t1 intensities onto the t3 image: for every landmark pair, copy the t1 (moving)
# value at the moving landmark's voxel into the t3 (fixed) image at the fixed
# landmark's voxel. Landmarks are rounded to voxels via continuous index -> physical
# point -> index, using the moving image geometry.
dh = DataHandler(val_images=0)
dh.get_synthetic_data(
    fixed_path='/home/cschellenberger/datam2olie/synthetic/orig/t3/Synthetic_CT/',
    moving_path='/home/cschellenberger/datam2olie/synthetic/orig/t1/Synthetic_CT/')
moving_image_paths = dh.x_train
fixed_image_paths = dh.y_train
SAVE_NEW_T3 = False  # set True to write the result (previously a commented-out WriteImage call)
for idx in range(1):
    print(idx)
    moving_image = sitk.ReadImage(moving_image_paths[idx])
    moving_landmarks, fixed_landmarks, regions = get_landmarks(fixed_image_paths[idx], indexing='xyz')
    newT3 = sitk.ReadImage(fixed_image_paths[idx])
    moving_size = moving_image.GetSize()
    canvas_size = newT3.GetSize()
    n_skipped = 0
    for i in range(len(fixed_landmarks)):
        pointFixed = moving_image.TransformPhysicalPointToIndex(moving_image.TransformContinuousIndexToPhysicalPoint(fixed_landmarks[i]))
        pointMoving = moving_image.TransformPhysicalPointToIndex(moving_image.TransformContinuousIndexToPhysicalPoint(moving_landmarks[i]))
        # Explicit bounds check instead of a bare `except: continue`. The old code also
        # hid unrelated errors (e.g. pixel-type mismatches) and even KeyboardInterrupt.
        fixed_inside = all(0 <= c < s for c, s in zip(pointFixed, canvas_size))
        moving_inside = all(0 <= c < s for c, s in zip(pointMoving, moving_size))
        if fixed_inside and moving_inside:
            newT3.SetPixel(*pointFixed, moving_image.GetPixel(*pointMoving))
        else:
            n_skipped += 1
    if n_skipped:
        print(f'skipped {n_skipped} landmark(s) outside the image')
    path = pathlib.Path(fixed_image_paths[idx])
    if SAVE_NEW_T3:
        sitk.WriteImage(newT3, f'/home/cschellenberger/Documents/newT3/{path.name}')

0


In [5]:
# Paint t1 intensities onto a fresh copy of the t1 image itself: for every landmark
# pair, copy the t1 value at the moving landmark's voxel to the fixed landmark's
# voxel. Landmarks are rounded to voxels via continuous index -> physical point ->
# index, using the moving image geometry.
dh = DataHandler(val_images=0)
dh.get_synthetic_data(
    fixed_path='/home/cschellenberger/datam2olie/synthetic/orig/t3/Synthetic_CT/',
    moving_path='/home/cschellenberger/datam2olie/synthetic/orig/t1/Synthetic_CT/')
moving_image_paths = dh.x_train
fixed_image_paths = dh.y_train
SAVE_NEW_T3 = False  # set True to write the result (previously a commented-out WriteImage call)
for idx in range(1):
    print(idx)
    moving_image = sitk.ReadImage(moving_image_paths[idx])
    moving_landmarks, fixed_landmarks, regions = get_landmarks(fixed_image_paths[idx], indexing='xyz')
    # Canvas is a separately read t1 image, so reads from moving_image see original values.
    newT3 = sitk.ReadImage(moving_image_paths[idx])
    moving_size = moving_image.GetSize()
    canvas_size = newT3.GetSize()
    n_skipped = 0
    for i in range(len(fixed_landmarks)):
        pointFixed = moving_image.TransformPhysicalPointToIndex(moving_image.TransformContinuousIndexToPhysicalPoint(fixed_landmarks[i]))
        pointMoving = moving_image.TransformPhysicalPointToIndex(moving_image.TransformContinuousIndexToPhysicalPoint(moving_landmarks[i]))
        # Explicit bounds check instead of a bare `except: continue`. The old code also
        # hid unrelated errors (e.g. pixel-type mismatches) and even KeyboardInterrupt.
        fixed_inside = all(0 <= c < s for c, s in zip(pointFixed, canvas_size))
        moving_inside = all(0 <= c < s for c, s in zip(pointMoving, moving_size))
        if fixed_inside and moving_inside:
            newT3.SetPixel(*pointFixed, moving_image.GetPixel(*pointMoving))
        else:
            n_skipped += 1
    if n_skipped:
        print(f'skipped {n_skipped} landmark(s) outside the image')
    path = pathlib.Path(fixed_image_paths[idx])
    if SAVE_NEW_T3:
        sitk.WriteImage(newT3, f'/home/cschellenberger/Documents/newT3/{path.name}')

0


In [19]:
def findPixel(i, j, k, image, field=None):
    """Return a non-zero displacement vector adjacent to voxel (i, j, k) along k.

    Checks (i, j, k + 1) first, then (i, j, k - 1), and returns the first vector
    that is not (0, 0, 0). Neighbours outside [0, image.GetSize()[2]) are skipped.
    Returns (0, 0, 0) when neither neighbour qualifies.

    Args:
        i, j, k: voxel index.
        image: only ``image.GetSize()[2]`` is used, to bound k. It should have the
            same size as ``field``.
        field: vector image to read from. Defaults to the notebook-global
            ``displacement``, which the original implementation always read
            implicitly, so existing 4-argument calls behave as before.

    Returns:
        The neighbouring vector as returned by ``field.GetPixel``, or (0, 0, 0).
    """
    if field is None:
        field = displacement  # legacy implicit dependency on the notebook global
    zero_vector = (0, 0, 0)
    if k < image.GetSize()[2] - 1:
        next_vector = field.GetPixel(i, j, k + 1)
        if next_vector != zero_vector:
            return next_vector
    if k > 0:
        previous_vector = field.GetPixel(i, j, k - 1)
        if previous_vector != zero_vector:
            return previous_vector
    return zero_vector

In [None]:

from matplotlib.pyplot import imsave
import matplotlib.cm as cm

In [20]:
# Reconstruct a "t3" volume from the t1 CT using pre-resample landmark data:
# 1. place the scaled landmark displacement vectors into a field on the t1 grid;
# 2. fill empty voxels from their k-neighbours;
# 3. resample t1 through the resulting DisplacementFieldTransform.
DISPLACEMENT_SCALE = 1.8  # empirical factor applied to every landmark displacement. TODO document its origin.
SLICE_X = 256  # x index of the (z, y) plane exported as JPEG; must be < image width
PREVIEW_DIR = './ValImgOrig'
SAVE_NEW_T3V2 = False  # set True to write the reconstructed volume (previously commented out)


def _fill_zero_vectors_along_k(field_kji: np.ndarray) -> np.ndarray:
    """Fill all-zero vectors from their k-neighbours, sweeping k upwards.

    Vectorised equivalent of visiting every voxel in (i, j, k) loop order and
    replacing each zero vector with ``findPixel``. A zero vector at k takes the
    vector at k + 1 if that one is non-zero (still unmodified at that point of
    the sweep). Otherwise it takes the vector at k - 1 if non-zero (already
    processed, so filled values propagate towards larger k). Voxel columns with
    different (i, j) never interact, so each k step can update all of them at once.

    Args:
        field_kji: array of shape (K, J, I, C), the layout of
            ``sitk.GetArrayFromImage`` for a vector image.

    Returns:
        A filled copy; the input array is not modified.
    """
    filled = field_kji.copy()
    depth = filled.shape[0]
    for k in range(depth):
        is_zero = ~np.any(filled[k] != 0, axis=-1)
        if k < depth - 1:
            take_next = is_zero & np.any(filled[k + 1] != 0, axis=-1)
            filled[k][take_next] = filled[k + 1][take_next]
            is_zero &= ~take_next
        if k > 0:
            take_prev = is_zero & np.any(filled[k - 1] != 0, axis=-1)
            filled[k][take_prev] = filled[k - 1][take_prev]
    return filled


dh = DataHandler(val_images=0)
dh.get_synthetic_data(
    fixed_path='/home/cschellenberger/datam2olie/synthetic/native/t3/Synthetic_CT/',
    moving_path='/home/cschellenberger/datam2olie/synthetic/native/t1/Synthetic_CT/')
moving_image_paths = dh.x_train
fixed_image_paths = dh.y_train
os.makedirs(PREVIEW_DIR, exist_ok=True)  # imsave fails if the directory is missing
for idx in range(1):
    print(idx)
    moving_image = sitk.ReadImage(moving_image_paths[idx])
    fixed_image = sitk.ReadImage(fixed_image_paths[idx])
    moving_landmarks, fixed_landmarks, regions = get_landmarks(fixed_image_paths[idx], indexing='xyz', preResample=True)
    displacement_np = np.array(moving_landmarks) - np.array(fixed_landmarks)

    # Sparse field on the moving (t1) grid in numpy layout (k, j, i, component).
    size_ijk = moving_image.GetSize()
    field = np.zeros((size_ijk[2], size_ijk[1], size_ijk[0], 3), dtype=np.float64)
    for fixed_point, landmark_vector in zip(fixed_landmarks, displacement_np):
        voxel = moving_image.TransformPhysicalPointToIndex(moving_image.TransformContinuousIndexToPhysicalPoint(fixed_point))
        # Explicit bounds check replaces the bare `except: continue`. It is also
        # required because numpy would silently wrap negative indices.
        if all(0 <= c < s for c, s in zip(voxel, size_ijk)):
            field[voxel[2], voxel[1], voxel[0]] = landmark_vector * DISPLACEMENT_SCALE

    # Vectorised replacement for the per-voxel GetPixel/SetPixel sweep with findPixel.
    field = _fill_zero_vectors_along_k(field)
    displacement = sitk.GetImageFromArray(field, isVector=True)  # float64 -> sitkVectorFloat64
    # Copies spacing, origin AND direction (the direction was previously left as identity).
    displacement.CopyInformation(moving_image)

    displacement_transform = sitk.DisplacementFieldTransform(displacement.__copy__())
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(moving_image)
    resampler.SetTransform(displacement_transform)
    newT3 = resampler.Execute(moving_image)
    # Earlier experiment (disabled): replace near-zero (-1 < v < 1) voxels of newT3
    # with the fixed image's values.

    # Keep `newT3` as the 3-D image; the JPEG preview uses a separate 2-D slice.
    newT3_slice = sitk.GetArrayFromImage(newT3)[:, :, SLICE_X]
    imsave(os.path.join(PREVIEW_DIR, f'newT3v2{idx}.jpg'), np.flip(newT3_slice, 0), cmap=cm.gray)
    path = pathlib.Path(fixed_image_paths[idx])
    if SAVE_NEW_T3V2:
        sitk.WriteImage(newT3, f'/home/cschellenberger/Documents/newT3v2/{path.name}')

0


In [5]:
# Keep the current `newT3` under a versioned name so reconstructions can be compared
# side by side in the viewer cell. This is an alias, not a copy: it stays distinct only
# as long as `newT3` is later rebound to a new object rather than modified in place.
newT3v1 = newT3

In [7]:
# Keep the current `newT3` under a versioned name so reconstructions can be compared
# side by side in the viewer cell. This is an alias, not a copy: it stays distinct only
# as long as `newT3` is later rebound to a new object rather than modified in place.
newT3v2 = newT3

In [9]:
# Keep the current `newT3` under a versioned name so reconstructions can be compared
# side by side in the viewer cell. This is an alias, not a copy: it stays distinct only
# as long as `newT3` is later rebound to a new object rather than modified in place.
newT3v3 = newT3

In [16]:
from pyM2aia import M2aiaOnlineHelper
# Show the t1/t3 images and the current reconstruction in the remote M2aia viewer.
M2aiaHelper = M2aiaOnlineHelper(
    "ipynbViewer",
    "jtfc.de:5050/m2aia/m2aia-no-vnc:with_exit",
    "8899",
)
with M2aiaHelper as helper:
    # Optional extra layers: "disp": displacement, "newT3v1": newT3v1,
    # "newT3v2": newT3v2, "newT3v3": newT3v3
    layers_to_show = {
        "t1": moving_image,
        "t3": fixed_image,
        "newT3v0": newT3,
    }
    helper.show(layers_to_show)

You can find your images @  http://141.19.142.80:8899

