In [None]:
% load_ext autoreload
% autoreload 2

## Test convolutional PCA reconstruction

In [3]:
import os

import torch
import vedo

from headrecbaselines.models.PCAH import PCAH_Net
from headrecbaselines.utils.datasetHeads import MeshHeadsDataset

### Load test images

In [4]:
# Dataset root and the two evaluation splits: the test set and its
# simulated-defect counterpart stored under test/sim_defects.
img_path = os.path.abspath('../datasets/cq500mesh')
path_test = os.path.join(img_path, 'test')
path_sim = os.path.join(path_test, 'sim_defects')

# Partition files list one filename per line. Read them inside context
# managers so the file handles are closed instead of leaked.
with open("partitions/test.txt", 'r') as fh:
    images_test = fh.read().splitlines()
with open("partitions/test_sim.txt", 'r') as fh:
    images_sim = fh.read().splitlines()

dtset_tst = MeshHeadsDataset(images_test, path_test, test=True)
dtset_tst_sim = MeshHeadsDataset(images_sim, path_sim, test=True)
print(
    f"Test images: {len(dtset_tst)}. Simulated defect skulls: {len(dtset_tst_sim)}")

Test images: 36. Simulated defect skulls: 36


In [5]:
# Inference configuration. With 'interpolate' enabled, h / w / slices are the
# nominal 512 x 512 x 233 grid scaled by interp_factor (presumably the size
# the network resamples its input to — see PCAH_Net).
interp_factor = 0.85
config = {
    'device': 'cuda:0',
    'latents': 60,
    'inputsize': 1024,
    'interpolate': True,
    'h': int(512 * interp_factor),
    'w': int(512 * interp_factor),
    'slices': int(233 * interp_factor),
}

pcaNet = PCAH_Net(config.copy()).to(config['device'])
# map_location puts the checkpoint tensors directly on the configured device,
# so a checkpoint saved from a different GPU index still loads here.
pcaNet.load_state_dict(
    torch.load('trained/PCAH/bestMSE.pt', map_location=config['device']))
pcaNet.eval()
print('Model loaded')

Model loaded


### Predict on test and simulated-defect images

In [6]:
# All predicted meshes share the same connectivity (.faces()), so the faces of
# any reference mesh can be reused to rebuild each prediction from its points.
ex_img = os.path.join(
    img_path, 'CQ500-CT-0_CT PLAIN THIN_decimated_1perc_dfm.vtk'
)
faces = vedo.Mesh(ex_img).faces()

out_pth_tst = os.path.join(path_test, 'PCA_im2mesh')  # Predictions subfolder
out_pth_sim = os.path.join(path_sim, 'PCA_im2mesh')  # Predictions subfolder
os.makedirs(out_pth_tst, exist_ok=True)
os.makedirs(out_pth_sim, exist_ok=True)

# Inference only: disable autograd so no computation graph is built per sample.
with torch.no_grad():
    for samples, out_path, fnames in zip([dtset_tst, dtset_tst_sim],
                                         [out_pth_tst, out_pth_sim],
                                         [images_test, images_sim]):
        print(f"saving in: {out_path}.")
        # NOTE(review): assumes the dataset yields samples in the same order
        # as the filenames it was built from.
        for i, sample in enumerate(samples):
            test_im = sample['image'].unsqueeze(0)  # add a batch dimension
            out_img = pcaNet(test_im.to(config['device']))
            # Flatten the prediction into an (n_points, 3) vertex array.
            points = out_img.reshape(-1, 3).cpu().numpy()

            f_name = fnames[i].replace('.nii.gz', '.stl')
            out_fpath = os.path.join(out_path, f_name)

            restored_mesh = vedo.Mesh([points, faces])
            restored_mesh.write(out_fpath)
            print(f'  saved mesh {f_name}')
            del test_im, out_img

saving in: /media/fmatzkin/data/franco/Code/datasets/cq500mesh/test/PCA_im2mesh.
  saved mesh CQ500-CT-370_CT BONE THIN.stl
  saved mesh CQ500-CT-356_CT PLAIN THIN.stl
  saved mesh CQ500-CT-469_CT 0.625mm.stl
  saved mesh CQ500-CT-241_CT PLAIN THIN.stl
  saved mesh CQ500-CT-228_CT 4cc sec 150cc D3D on-2.stl
  saved mesh CQ500-CT-328_CT 0.625mm.stl
  saved mesh CQ500-CT-105_CT I To S.stl
  saved mesh CQ500-CT-67_CT BONE.stl
  saved mesh CQ500-CT-122_CT PRE CONTRAST THIN.stl
  saved mesh CQ500-CT-33_CT 4cc sec 150cc D3D on.stl
  saved mesh CQ500-CT-287_CT Thin Plain.stl
  saved mesh CQ500-CT-365_CT Thin Plain.stl
  saved mesh CQ500-CT-373_CT Thin Plain.stl
  saved mesh CQ500-CT-485_CT PLAIN THIN.stl
  saved mesh CQ500-CT-80_CT 0.625mm.stl
  saved mesh CQ500-CT-103_CT Thin Plain.stl
  saved mesh CQ500-CT-39_CT PRE CONTRAST THIN.stl
  saved mesh CQ500-CT-453_CT PLAIN THIN.stl
  saved mesh CQ500-CT-384_CT PLAIN THIN.stl
  saved mesh CQ500-CT-49_CT PRE CONTRAST THIN.stl
  saved mesh CQ500-CT

## Convert meshes to volumes

In [None]:
def binarize_franco(self, source_img, invert=False):
    """Rasterize a closed mesh onto the voxel grid of a SimpleITK image.

    Written in method style: ``self`` is the mesh (anything exposing
    ``polydata()``, e.g. a ``vedo.Mesh``), so it can be called as
    ``binarize_franco(mesh, ref_img)`` or attached to ``vedo.Mesh``.

    Parameters
    ----------
    self : vedo.Mesh
        Mesh whose polydata is used as the stencil.
    source_img : SimpleITK.Image
        Reference image; its size, spacing and origin define the output grid.
    invert : bool, optional
        If False (default), voxels inside the mesh are 1 and outside are 0.
        If True, the two values are swapped.

    Returns
    -------
    vedo.Volume
        Unsigned-char volume holding the rasterized mask.

    Raises
    ------
    TypeError
        If ``source_img`` is not a ``SimpleITK.Image``.
    """
    # Local imports: vtk is otherwise not imported at notebook level, and this
    # keeps the function usable regardless of which cells ran before it.
    import SimpleITK as sitk
    import vtk

    if not isinstance(source_img, sitk.Image):
        raise TypeError('source_img must be a SimpleITK image')

    polydata = self.polydata()

    # Apply the invert swap exactly once. Swapping both the fill value AND the
    # stencil direction would cancel out and make `invert` a no-op.
    inside_val, outside_val = (0, 1) if invert else (1, 0)

    white_image = vtk.vtkImageData()
    # NOTE(review): hard-coded direction (x and y flipped). The reference
    # image's own direction cosines are not used — confirm this matches the data.
    white_image.SetDirectionMatrix(-1, 0, 0, 0, -1, 0, 0, 0, 1)
    white_image.SetSpacing(source_img.GetSpacing())

    dim = source_img.GetSize()
    white_image.SetDimensions(dim)
    white_image.SetExtent(0, dim[0] - 1, 0, dim[1] - 1, 0, dim[2] - 1)

    white_image.SetOrigin(source_img.GetOrigin())
    white_image.AllocateScalars(vtk.VTK_UNSIGNED_CHAR, 1)

    # Fill every voxel in one native call instead of a per-voxel Python loop.
    white_image.GetPointData().GetScalars().FillComponent(0, inside_val)

    # Polygonal data --> image stencil on the same grid as white_image.
    pol2stenc = vtk.vtkPolyDataToImageStencil()
    pol2stenc.SetInputData(polydata)
    pol2stenc.SetOutputOrigin(white_image.GetOrigin())
    pol2stenc.SetOutputSpacing(white_image.GetSpacing())
    pol2stenc.SetOutputWholeExtent(white_image.GetExtent())
    pol2stenc.Update()

    # Keep voxels inside the stencil and set everything outside to outside_val.
    imgstenc = vtk.vtkImageStencil()
    imgstenc.SetInputData(white_image)
    imgstenc.SetStencilConnection(pol2stenc.GetOutputPort())
    imgstenc.SetReverseStencil(False)
    imgstenc.SetBackgroundValue(outside_val)
    imgstenc.Update()
    return vedo.Volume(imgstenc.GetOutput())

In [1]:
import vedo
import SimpleITK as sitk
import numpy as np
from headctools.preprocessing.utils import fixed_pad_sitk
import os

jp = os.path.join
im_p = os.path.abspath('../datasets/cq500mesh/test/')
preds_sim_fld = jp(im_p, 'sim_defects/PCA_im2mesh/')

# Reference CT that defines the output voxel grid (size, spacing, origin).
# Resolved under im_p rather than a machine-specific absolute path.
src_Img = sitk.ReadImage(jp(im_p, 'CQ500-CT-62_CT Thin Plain.nii.gz'))

# NOTE(review): the mesh is case CT-12 but the reference grid comes from CT-62.
# This is only correct if all cases share the same image geometry — confirm.
inp_mesh = vedo.Mesh(jp(preds_sim_fld, 'CQ500-CT-12_CT Thin Plain0_sim.stl'))
# NOTE(review): binarize_franco is defined as a plain function in this notebook
# and is never attached to vedo.Mesh; this method call presumably relies on a
# vedo build that provides it (binarize_franco(inp_mesh, src_Img) otherwise).
binarized_vol = inp_mesh.binarize_franco(src_Img)

In [5]:
# Reverse the axis order of the vedo voxel array: the transposed shape
# (233, 512, 512) matches the (z, y, x) layout SimpleITK.GetImageFromArray expects.
array = binarized_vol.tonumpy().astype(np.uint8).transpose(2, 1, 0)
array.shape

(233, 512, 512)

In [6]:
# Wrap the mask on the reference CT's grid (size, spacing, origin, direction)
# and save it next to the predicted mesh.
vol_sitk = sitk.GetImageFromArray(array)

vol_sitk.CopyInformation(src_Img)
sitk.WriteImage(vol_sitk, jp(preds_sim_fld, 'CQ500-CT-12_CT Thin Plain0_sim.nii'))

In [63]:
# Quick visual check in the external viewer used by sitk.Show.
# The vedo array must be reversed to SimpleITK's (z, y, x) order, exactly as
# before writing the mask. Without it, the axes (and thus the spacing) are
# misassigned.
sitkimg = sitk.GetImageFromArray(
    np.transpose(binarized_vol.tonumpy(), (2, 1, 0)).astype(np.uint8))
sitkimg.CopyInformation(src_Img)
sitk.Show(sitkimg)

In [None]:
# Pad the mask to the reference image size, re-attach its geometry and save.
# NOTE(review): np.roll((x, y, z), 1) yields (z, x, y). This equals numpy's
# (z, y, x) order only when x == y (presumably 512 == 512 here). Confirm what
# fixed_pad_sitk expects; src_Img.GetSize()[::-1] is the general numpy order.
vol_sitk = fixed_pad_sitk(vol_sitk, tuple(np.roll(src_Img.GetSize(), 1)))
vol_sitk.CopyInformation(src_Img)
# NOTE(review): output is named after case CT-62 while the mask written in the
# earlier cell is case CT-12 — confirm which case this is.
sitk.WriteImage(
    vol_sitk,
    jp(preds_sim_fld, 'CQ500-CT-62_CT Thin Plain0_sim_vedo.nii.gz')
)

In [None]:
# NOTE(review): repeats the padding call of the previous cell. Running both pads
# vol_sitk twice, which is harmless only if fixed_pad_sitk is idempotent.
vol_sitk = fixed_pad_sitk(
    vol_sitk,
    tuple(np.roll(src_Img.GetSize(), shift=1)),
)

In [66]:
from headctools.tools.img_processing import FillDefects

ch = FillDefects.convex_hull  # short alias for the project's convex-hull helper
src_ImgCh = ch(src_Img)       # FillDefects.convex_hull applied to the reference CT

In [67]:
# Open the reference CT and its convex-hull result in the external viewer.
for img_to_show in (src_Img, src_ImgCh):
    sitk.Show(img_to_show)

MetaImageIO (0x55f6a5d18800): Unsupported or empty metaData item ITK_FileNotes of type Ssfound, won't be written to image file

MetaImageIO (0x55f6a5d18800): Unsupported or empty metaData item aux_file of type Ssfound, won't be written to image file

MetaImageIO (0x55f6a5d18800): Unsupported or empty metaData item descrip of type Ssfound, won't be written to image file

MetaImageIO (0x55f6a5d18800): Unsupported or empty metaData item intent_name of type Ssfound, won't be written to image file



In [38]:
import SimpleITK as sitk
import numpy as np
import trimesh

# Alternative route: voxelize the predicted mesh with trimesh instead of vedo.
malla = trimesh.load_mesh('CQ500-CT-62_CT Thin Plain0_sim.stl')  # "malla" = mesh
src_Img = sitk.ReadImage('CQ500-CT-62_CT Thin Plain.nii.gz')

# Voxelize with pitch 1, then fill the interior.
# Filling options: base, holes, orthographic
v = malla.voxelized(1).fill('base')

# The binvox export has been observed to raise
# "Can only export binvox with uniform scale". Report it instead of aborting
# the cell, so the NIfTI export below still runs.
try:
    v.export('CQ500-CT-62_CT Thin Plain0_sim.binvox')
except ValueError as err:
    print(f'binvox export skipped: {err}')

# The voxel matrix is indexed (x, y, z); reverse it to SimpleITK's (z, y, x).
# NOTE(review): spacing/origin/direction are not set, so the image lives in
# voxel-index space rather than on src_Img's grid.
vol_sitk = sitk.GetImageFromArray(
    np.transpose(v.matrix, (2, 1, 0)).astype(np.uint8))
sitk.WriteImage(vol_sitk, 'CQ500-CT-62_CT Thin Plain0_sim_base.nii.gz')