In [1]:
%load_ext autoreload
%autoreload 2

## Test convolutional PCA reconstruction

In [1]:
import os

import torch
import vedo

from headrecbaselines.models.PCAH import PCAH_Net
from headrecbaselines.utils.datasetHeads import MeshHeadsDataset

### Load test images

In [2]:
path_imgs = os.path.abspath('../datasets/cq500mesh')
path_test = os.path.join(path_imgs, 'test')
path_sim = os.path.join(path_test, 'sim_defects')
images_test = open("partitions/test.txt", 'r').read().splitlines()
images_sim = open("partitions/test_sim.txt", 'r').read().splitlines()

dtset_tst = MeshHeadsDataset(images_test, path_test, test=True)
dtset_tst_sim = MeshHeadsDataset(images_sim, path_sim, test=True)
print(f"Test images: {len(dtset_tst)}. "
      f"Simulated defect skulls: {len(dtset_tst_sim)}")

Test images: 36. Simulated defect skulls: 36


In [3]:
interp_factor = 0.85
config = {
    'device': 'cuda',
    'latents': 60,
    'inputsize': 1024,
    'interpolate': True,
    'h': int(512 * interp_factor),
    'w': int(512 * interp_factor),
    'slices': int(233 * interp_factor),
}

pcaNet = PCAH_Net(config.copy()).to(config['device'])
pcaNet.load_state_dict(torch.load('trained/PCAH/bestMSE.pt'))
pcaNet.eval()
print('Model loaded')

Model loaded


### Predict on test and simulated imgs

In [4]:
# I can use any img since all .faces() are the same
ex_img = os.path.join(
    path_imgs, 'CQ500-CT-0_CT PLAIN THIN_decimated_1perc_dfm.vtk'
)
faces = vedo.Mesh(ex_img).faces()

out_pth_tst = os.path.join(path_test, f'PCA_im2mesh')  # Predictions subfolder
out_pth_sim = os.path.join(path_sim, f'PCA_im2mesh')  # Predictions subfolder
os.makedirs(out_pth_tst, exist_ok=True)
os.makedirs(out_pth_sim, exist_ok=True)

for samples, out_path, fnames in zip([dtset_tst, dtset_tst_sim],
                                     [out_pth_tst, out_pth_sim],
                                     [images_test, images_sim]):
    print(f"saving in: {out_path}.")
    for i, sample in enumerate(samples):
        test_im = sample['image'].unsqueeze(0)
        out_img = pcaNet(test_im.to(config['device']))
        points = out_img.reshape(-1, 3).detach().cpu().numpy()

        f_name = fnames[i].replace('.nii.gz', '.stl')
        out_fpath = os.path.join(out_path, f_name)

        restored_mesh = vedo.Mesh([points, faces])
        restored_mesh.write(out_fpath)
        print(f'  saved mesh {f_name}')
        del test_im, out_img

saving in: /home/franco/Code/datasets/cq500mesh/test/PCA_im2mesh.


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.12 GiB (GPU 0; 1.96 GiB total capacity; 561.36 MiB already allocated; 934.38 MiB free; 574.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

## Convert meshes to volumes
Using as a reference the image CQ500-CT-0_CT PLAIN THIN.nii.gz of the training split (for dims/metadata), binarize the meshes and save them as volumes.

In [8]:
import vedo

mesh_pth = 'test.vtk'  # Mesh to binarize
refv_path = 'test.nii.gz'  # Reference volume

mesh = vedo.Mesh(mesh_pth)
volm = vedo.Volume(refv_path)

spacing = volm._data.GetSpacing()
# volm._data.GetOrigin() gives 0,0,0 origin when it's not this one :(
origin = (-110.5, -135.6, -4)
direction_matrix = (-1, 0, 0, 0, -1, 0, 0, 0, 1)
fg_val = 1
bg_val = 0
image_size = volm.tonumpy().shape

bin_vol = mesh.binarize(spacing, False, direction_matrix, image_size, origin, fg_val, bg_val)
bin_vol.write('bin_vol.nii')

In [2]:
from headrecbaselines.utils.utils import mesh2vol
import os

absp = os.path.abspath
m_fld = absp('../datasets/cq500mesh/test/sim_defects/PCA_im2mesh/')
ref_im = absp('../datasets/cq500mesh/CQ500-CT-0_CT PLAIN THIN.nii.gz')
meshes_paths = [os.path.join(m_fld, f)
                for f in os.listdir(m_fld) if f.endswith('.stl')]

for mesh in meshes_paths:
    mesh2vol(mesh, ref_im)

Binarizing /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-47_CT PRE CONTRAST THIN0_sim.stl..  file /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-47_CT PRE CONTRAST THIN0_sim_binMesh.nii.gz already exists.
Binarizing /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-49_CT PRE CONTRAST THIN0_sim.stl..  file /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-49_CT PRE CONTRAST THIN0_sim_binMesh.nii.gz already exists.
Binarizing /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-462_CT Thin Plain0_sim.stl..  file /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-462_CT Thin Plain0_sim_binMesh.nii.gz already exists.
Binarizing /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-356_CT PLAIN THIN0_sim.stl..  file 

### Align the predictions to the inputs

The prections may be disaligned in the edges, so with this I warp the output mesh to the input

In [1]:
import os
from headctools.preprocessing import utils

preds_folder = os.path.abspath('../datasets/cq500mesh/test/sim_defects/PCA_im2mesh/')
vols_folder = os.path.abspath('../datasets/cq500mesh/test/sim_defects/')

In [2]:
# Convert the input volumes to meshes so I can align the output meshes later
utils.nii_to_stl_marching_cubes(vols_folder, vols_folder, False, '_sim.nii.gz')
paths_inp_mshs = [os.path.join(vols_folder, f) for f in os.listdir(vols_folder)
                  if f.endswith('_sim.stl')]

    Saved mesh in /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/CQ500-CT-392_CT PLAIN THIN0_sim.stl
    Saved mesh in /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/CQ500-CT-49_CT PRE CONTRAST THIN0_sim.stl
    Saved mesh in /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/CQ500-CT-287_CT Thin Plain0_sim.stl
    Saved mesh in /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/CQ500-CT-195_CT PRE CONTRAST THIN0_sim.stl
    Saved mesh in /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/CQ500-CT-373_CT Thin Plain0_sim.stl
    Saved mesh in /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/CQ500-CT-444_CT Thin Plain0_sim.stl
    Saved mesh in /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/CQ500-CT-103_CT Thin Plain0_sim.stl
    Saved mesh in /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/CQ500-CT-384_CT PLAIN THIN0_sim.stl
   

In [10]:
import vedo
import SimpleITK as sitk
from headrecbaselines.utils.VedoMorpher3 import fit_mesh_dmap
import os

basef = os.path.expanduser('~/Code/datasets/cq500mesh/')
pred = basef + 'test/sim_defects/PCA_im2mesh/CQ500-CT-444_CT Thin Plain0_sim.stl'
inp = basef + 'test/sim_defects/CQ500-CT-444_CT Thin Plain0_sim_decimated_1perc.vtk'
# reg = basef + 'reg_CQ500-CT-444_CT Thin Plain0_sim_decimated_1perc/DeterministicAtlas__Reconstruction__skull__subject_CQ500-CT-444_CT Thin Plain0_sim_decimated_1perc.vtk'
distmap = basef + 'test/sim_defects/CQ500-CT-444_CT Thin Plain0_sim_distmap.nii.gz'

predmesh = vedo.Mesh(pred)
inpmesh = vedo.Mesh(inp)
# regmesh = vedo.Mesh(reg).decimate(.1)
distmap = sitk.ReadImage(distmap)

In [11]:
predmesh.write('predmesh_b.vtk')
predmesh.align_to(inpmesh, use_centroids=True)
predmesh.write('predmesh_a_c.vtk')
inpmesh.write('inpmesh.vtk')

<Mesh(0x561526619860) at 0x7ff233f45580>

In [7]:
pm = {
        'dm_threshold': 0,  # distance map threshold
        'save_path': pred.replace('.stl', f'_morphed.stl'),  # path to save the morphed mesh
        'no_angle' : 35,  # Angle threshold for normals
        'n_points' : 100,  # Number of points to use for the closest point search
        
    }

mr = fit_mesh_dmap(predmesh, inpmesh, distmap, pm)

No near point found point 648
No near point found point 1057
No near point found point 1397
No near point found point 1908
Morphed mesh saved to /home/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-444_CT Thin Plain0_sim_morphed.stl


In [10]:
# Draw an arrow for each normal of the inpmesh
arrs = []
for i, pt in enumerate(inpmesh.points()):
    arrs.append([pt, pt + inpmesh.normals()[i] * 10])
v_arrs = vedo.Arrows(arrs)
vedo.show(inpmesh, v_arrs).close()

In [8]:
# Draw an arrow for each normal of the prediction meshq
arrs = []
for i, pt in enumerate(predmesh.points()):
    arrs.append([pt, pt + predmesh.normals()[i] * 10])
v_arrs = vedo.Arrows(arrs)
vedo.show(predmesh, v_arrs).close()

In [9]:
v_arrs.save(basef + 'test/sim_defects/PCA_im2mesh/CQ500-CT-444_CT Thin Plain0_sim_normals.vtk')

AttributeError: 'Arrows' object has no attribute 'save'

In [4]:
# Draw an arrow for each normal of the morphed mesh
arrs = []
for i, pt in enumerate(mr.morphed.points()):
    arrs.append([pt, pt + mr.morphed.normals()[i] * 10])
v_arrs = vedo.Arrows(arrs)
vedo.show(mr.morphed, v_arrs).close()