In [None]:
% load_ext autoreload
% autoreload 2

## Test convolutional PCA reconstruction

In [3]:
import os

import torch
import vedo

from headrecbaselines.models.PCAH import PCAH_Net
from headrecbaselines.utils.datasetHeads import MeshHeadsDataset

### Load test images

In [4]:
img_path = os.path.abspath('../datasets/cq500mesh')
path_test = os.path.join(img_path, 'test')
path_sim = os.path.join(path_test, 'sim_defects')

# Each partition file lists one image filename per line. Read them with
# context managers so the file handles are closed right away.
with open("partitions/test.txt", 'r') as f_test:
    images_test = f_test.read().splitlines()
with open("partitions/test_sim.txt", 'r') as f_sim:
    images_sim = f_sim.read().splitlines()

dtset_tst = MeshHeadsDataset(images_test, path_test, test=True)
dtset_tst_sim = MeshHeadsDataset(images_sim, path_sim, test=True)
print(
    f"Test images: {len(dtset_tst)}. Simulated defect skulls: {len(dtset_tst_sim)}")

Test images: 36. Simulated defect skulls: 36


In [5]:
# Scale factor applied to the h / w / slices sizes passed to the model.
interp_factor = 0.85
config = {
    'device': 'cuda:0',
    'latents': 60,
    'inputsize': 1024,
    'interpolate': True,
    'h': int(512 * interp_factor),
    'w': int(512 * interp_factor),
    'slices': int(233 * interp_factor),
}

# Pass a copy so the constructor cannot mutate the notebook's config dict.
pcaNet = PCAH_Net(config.copy()).to(config['device'])
# map_location restores the checkpoint tensors directly on config['device'],
# so loading does not depend on the device the weights were saved from.
pcaNet.load_state_dict(
    torch.load('trained/PCAH/bestMSE.pt', map_location=config['device']))
pcaNet.eval()
print('Model loaded')

Model loaded


### Predict on test and simulated-defect images

In [6]:
# All predicted meshes share the same connectivity (all .faces() are the
# same), so the faces of any reference mesh can be reused for every prediction.
ex_img = os.path.join(
    img_path, 'CQ500-CT-0_CT PLAIN THIN_decimated_1perc_dfm.vtk'
)
faces = vedo.Mesh(ex_img).faces()

out_pth_tst = os.path.join(path_test, 'PCA_im2mesh')  # Predictions subfolder
out_pth_sim = os.path.join(path_sim, 'PCA_im2mesh')  # Predictions subfolder
os.makedirs(out_pth_tst, exist_ok=True)
os.makedirs(out_pth_sim, exist_ok=True)

# Inference only: disable autograd so no computation graph is built per sample.
with torch.no_grad():
    for samples, out_path, fnames in zip([dtset_tst, dtset_tst_sim],
                                         [out_pth_tst, out_pth_sim],
                                         [images_test, images_sim]):
        print(f"saving in: {out_path}.")
        for i, sample in enumerate(samples):
            test_im = sample['image'].unsqueeze(0)  # add a batch dimension
            out_img = pcaNet(test_im.to(config['device']))
            # Flatten the network output into an (N, 3) array of vertex coordinates.
            points = out_img.reshape(-1, 3).cpu().numpy()

            f_name = fnames[i].replace('.nii.gz', '.stl')
            out_fpath = os.path.join(out_path, f_name)

            restored_mesh = vedo.Mesh([points, faces])
            restored_mesh.write(out_fpath)
            print(f'  saved mesh {f_name}')
            del test_im, out_img

saving in: /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/PCA_im2mesh.
  saved mesh CQ500-CT-370_CT BONE THIN.stl
  saved mesh CQ500-CT-356_CT PLAIN THIN.stl
  saved mesh CQ500-CT-469_CT 0.625mm.stl
  saved mesh CQ500-CT-241_CT PLAIN THIN.stl
  saved mesh CQ500-CT-228_CT 4cc sec 150cc D3D on-2.stl
  saved mesh CQ500-CT-328_CT 0.625mm.stl
  saved mesh CQ500-CT-105_CT I To S.stl
  saved mesh CQ500-CT-67_CT BONE.stl
  saved mesh CQ500-CT-122_CT PRE CONTRAST THIN.stl
  saved mesh CQ500-CT-33_CT 4cc sec 150cc D3D on.stl
  saved mesh CQ500-CT-287_CT Thin Plain.stl
  saved mesh CQ500-CT-365_CT Thin Plain.stl
  saved mesh CQ500-CT-373_CT Thin Plain.stl
  saved mesh CQ500-CT-485_CT PLAIN THIN.stl
  saved mesh CQ500-CT-80_CT 0.625mm.stl
  saved mesh CQ500-CT-103_CT Thin Plain.stl
  saved mesh CQ500-CT-39_CT PRE CONTRAST THIN.stl
  saved mesh CQ500-CT-453_CT PLAIN THIN.stl
  saved mesh CQ500-CT-384_CT PLAIN THIN.stl
  saved mesh CQ500-CT-49_CT PRE CONTRAST THIN.stl
  saved mesh CQ500-CT

## Convert meshes to volumes
Using the training-split image `CQ500-CT-0_CT PLAIN THIN.nii.gz` as a reference (for dimensions and metadata), binarize the predicted meshes and save them as volumes.

In [8]:
import vedo

# Binarize a single mesh into a voxel volume, using a reference volume for the
# spacing and the image size, and write it as NIfTI.
mesh_pth = 'test.vtk'  # Mesh to binarize
refv_path = 'test.nii.gz'  # Reference volume

mesh = vedo.Mesh(mesh_pth)
volm = vedo.Volume(refv_path)

# Spacing is read from the reference volume's underlying VTK image data.
spacing = volm._data.GetSpacing()
# The origin is hard-coded because volm._data.GetOrigin() returns (0, 0, 0),
# which is not the correct origin for this volume.
# NOTE(review): this value is specific to one reference image; confirm it
# before reusing the cell with a different volume.
origin = (-110.5, -135.6, -4)
# Presumably a row-major 3x3 direction matrix: diag(-1, -1, 1) flips x and y.
direction_matrix = (-1, 0, 0, 0, -1, 0, 0, 0, 1)
fg_val = 1  # voxel value inside the mesh
bg_val = 0  # voxel value outside the mesh
image_size = volm.tonumpy().shape

# Positional arguments, presumably (spacing, invert, direction_matrix,
# image_size, origin, fg_val, bg_val) -- verify against the installed vedo's
# Mesh.binarize signature.
bin_vol = mesh.binarize(spacing, False, direction_matrix, image_size, origin, fg_val, bg_val)
bin_vol.write('bin_vol.nii')

In [3]:
# NOTE(review): the saved output of this cell, (0.0, 0.0, 0.0), comes from
# execution count 3. It does not match the origin hard-coded in the In [8]
# binarization cell, so it reflects a different kernel state and is stale.
# Re-run after that cell to see the current value.
origin

(0.0, 0.0, 0.0)

In [2]:
from headrecbaselines.utils.utils import mesh2vol
import os

absp = os.path.abspath
# Folder with the PCA meshes predicted for the simulated-defect test set.
m_fld = absp('../datasets/cq500mesh/test/sim_defects/PCA_im2mesh/')
# Training-split reference image, used for dims/metadata.
ref_im = absp('../datasets/cq500mesh/CQ500-CT-0_CT PLAIN THIN.nii.gz')
# Sort the listing: os.listdir order is arbitrary, and sorting gives a
# reproducible processing order.
meshes_paths = [os.path.join(m_fld, f)
                for f in sorted(os.listdir(m_fld)) if f.endswith('.stl')]

# Use a dedicated loop variable so the `mesh` object created in the earlier
# binarization cell is not overwritten with a path string.
# NOTE(review): per the saved output, mesh2vol appears to skip meshes whose
# *_binMesh.nii.gz already exists -- confirm in headrecbaselines.utils.utils.
for mesh_path in meshes_paths:
    mesh2vol(mesh_path, ref_im)

Binarizing /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-47_CT PRE CONTRAST THIN0_sim.stl..  file /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-47_CT PRE CONTRAST THIN0_sim_binMesh.nii.gz already exists.
Binarizing /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-49_CT PRE CONTRAST THIN0_sim.stl..  file /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-49_CT PRE CONTRAST THIN0_sim_binMesh.nii.gz already exists.
Binarizing /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-462_CT Thin Plain0_sim.stl..  file /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-462_CT Thin Plain0_sim_binMesh.nii.gz already exists.
Binarizing /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/sim_defects/PCA_im2mesh/CQ500-CT-356_CT PLAIN THIN0_sim.stl..  file 

### Voxelize with trimesh (does not work with our skulls)

In [38]:
import SimpleITK as sitk
import numpy as np
import trimesh

malla = trimesh.load_mesh('CQ500-CT-62_CT Thin Plain0_sim.stl')
src_Img = sitk.ReadImage('CQ500-CT-62_CT Thin Plain.nii.gz')

# Voxelize with pitch 1 and fill the interior.
# Filling options: base, holes, orthographic
v = malla.voxelized(1).fill('base')

# The binvox export raises "Can only export binvox with uniform scale" for
# this grid. Report the error instead of aborting the cell, so the NIfTI
# export below still runs.
try:
    v.export('CQ500-CT-62_CT Thin Plain0_sim.binvox')
except ValueError as err:
    print(f'Skipping binvox export: {err}')

# trimesh indexes the voxel matrix as (x, y, z), while SimpleITK's
# GetImageFromArray interprets numpy axes as (z, y, x). Transpose so the
# written volume is not axis-swapped.
vol_sitk = sitk.GetImageFromArray(
    np.transpose(v.matrix, (2, 1, 0)).astype(np.uint8))
# TODO: spacing/origin are left at SimpleITK defaults. CopyInformation(src_Img)
# requires both images to have the same size, which the voxel grid generally
# does not, so the geometry must be set explicitly to align with src_Img.
sitk.WriteImage(vol_sitk, 'CQ500-CT-62_CT Thin Plain0_sim_base.nii.gz')