In [None]:
import vtk
import numpy as np
import pandas as pd
import open3d as o3d
import copy

from CardioMesh.CardiacMesh import Cardiac3DMesh
from scipy import sparse as sp
import ipywidgets as widgets
from ipywidgets import interact
from trimesh import Trimesh

___

## Global registration

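1. Load the source (template) and target (subject) meshes and convert them to Open3D point clouds.
2. Downsample the clouds, estimate normals and compute FPFH features.
3. Align the shapes with RANSAC-based global registration.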

In [None]:
def draw_registration_result(source, target, transformation):
    """Display `source` moved by `transformation` on top of `target`.

    The source is painted yellow-orange and the target blue. Both geometries are
    deep-copied first, so the caller's objects are neither recoloured nor moved.

    Parameters
    ----------
    source, target : o3d.geometry.PointCloud
    transformation : (4, 4) array-like
        Homogeneous transform applied to the copy of `source`.
    """
    moved_source = copy.deepcopy(source)
    moved_source.paint_uniform_color([1, 0.706, 0])
    moved_source.transform(transformation)

    fixed_target = copy.deepcopy(target)
    fixed_target.paint_uniform_color([0, 0.651, 0.929])

    # Fixed camera parameters so that successive plots share one viewpoint.
    o3d.visualization.draw_geometries(
        [moved_source, fixed_target],
        zoom=0.4559,
        front=[0.6452, -0.3036, -0.7011],
        lookat=[1.9892, 2.0208, 1.8945],
        up=[-0.2779, -0.9482, 0.1556],
    )

In [None]:
def preprocess_point_cloud(pcd, voxel_size):
    """Downsample `pcd`, estimate its normals and compute FPFH descriptors.

    Search radii scale with `voxel_size`: normals use 2x, FPFH features use 5x.

    Parameters
    ----------
    pcd : o3d.geometry.PointCloud
        Input cloud (not modified; downsampling produces a new cloud).
    voxel_size : float
        Voxel edge length, in the same units as the cloud's coordinates.

    Returns
    -------
    (downsampled cloud, FPFH feature)
    """
    print(":: Downsample with a voxel size %.3f." % voxel_size)
    downsampled = pcd.voxel_down_sample(voxel_size)

    normal_radius = 2 * voxel_size
    print(":: Estimate normal with search radius %.3f." % normal_radius)
    normal_search = o3d.geometry.KDTreeSearchParamHybrid(radius=normal_radius, max_nn=30)
    downsampled.estimate_normals(normal_search)

    feature_radius = 5 * voxel_size
    print(":: Compute FPFH feature with search radius %.3f." % feature_radius)
    feature_search = o3d.geometry.KDTreeSearchParamHybrid(radius=feature_radius, max_nn=100)
    fpfh = o3d.pipelines.registration.compute_fpfh_feature(downsampled, feature_search)

    return downsampled, fpfh

In [None]:
def execute_global_registration(source, target, source_fpfh, target_fpfh, voxel_size):
    """Coarsely align `source` onto `target` with RANSAC over FPFH feature matches.

    The estimation is rigid (point-to-point without scaling). Matched points must
    lie within 1.5 * `voxel_size` to count as inliers.

    Parameters
    ----------
    source, target : o3d.geometry.PointCloud
        Downsampled clouds (as returned by preprocess_point_cloud).
    source_fpfh, target_fpfh : FPFH features of `source` / `target`.
    voxel_size : float
        Voxel size used for downsampling; sets the correspondence threshold.

    Returns
    -------
    Open3D RegistrationResult (``.transformation``, ``.fitness``, ``.inlier_rmse``).
    """
    reg = o3d.pipelines.registration
    distance_threshold = voxel_size * 1.5

    # Prune candidate correspondence sets cheaply before the full evaluation.
    edge_checker = reg.CorrespondenceCheckerBasedOnEdgeLength(0.9)
    distance_checker = reg.CorrespondenceCheckerBasedOnDistance(distance_threshold)

    return reg.registration_ransac_based_on_feature_matching(
        source=source,
        target=target,
        source_feature=source_fpfh,
        target_feature=target_fpfh,
        mutual_filter=True,
        max_correspondence_distance=distance_threshold,
        estimation_method=reg.TransformationEstimationPointToPoint(False),
        ransac_n=3,
        checkers=[edge_checker, distance_checker],
        # NOTE(review): in Open3D >= 0.12 these are (max_iteration, confidence).
        criteria=reg.RANSACConvergenceCriteria(100000, 0.999),
    )

In [None]:
# Disabled cell: interactive viewer for a single partition of `fhm_mesh`.
# NOTE(review): it needs a `partitions` dict that is not defined in the visible
# cells; define it before re-enabling this cell.
# @interact
# def show_partition(partition=widgets.Select(options=partitions.keys(), value="left_ventricle")):
#     
#     partition = partitions[partition]
#     print(partition)
#     partition_mesh = fhm_mesh[partition]
#     
#     return Trimesh(partition_mesh.v, partition_mesh.f).show()

### Build target mesh

In [None]:
ID = "1000215"

# Inputs for the target (subject) mesh and its per-vertex partition labels.
# TODO: these are absolute local paths; make the data root configurable.
TARGET_MESH_FILE = f"/home/rodrigo/01_repos/CardiacCOMA/data/cardio/meshes/by_id/{ID}/models/FHM_time001.npy"
FHM_FACES_FILE = "/home/rodrigo/01_repos/CardioMesh/data/faces_fhm_10pct_decimation.csv"
FHM_SUBPART_FILE = "/home/rodrigo/01_repos/CardioMesh/data/subpartIDs_FHM_10pct.txt"

fhm_mesh = Cardiac3DMesh(
    filename=TARGET_MESH_FILE,
    faces_filename=FHM_FACES_FILE,
    subpart_id_filename=FHM_SUBPART_FILE,
)

# One row per vertex, holding its partition label; add one boolean column per label.
subpart_df = pd.read_csv(FHM_SUBPART_FILE, header=None)
subpart_df.columns = ["partition"]

small_subpart_names = ["LV", "AVP", "LA", "MVP", "RV", "PVP", "PV1", "PV2", "PV3", "PV4", "PV5", "RA", "TVP", "PV6", "PV7", "aorta"]
for subpart_name in small_subpart_names:
    subpart_df[subpart_name] = subpart_df["partition"] == subpart_name

LV_SUBPART_NAMES = ("LV", "AVP")
LVRV_SUBPART_NAMES = ("LV", "AVP", "RV", "PVP", "TVP")

# Boolean vertex masks: True where the vertex belongs to any of the listed partitions.
lv_subparts = subpart_df[list(LV_SUBPART_NAMES)].any(axis=1)
lvrv_subparts = subpart_df[list(LVRV_SUBPART_NAMES)].any(axis=1)

In [None]:
def subsetting_matrix(mask):
    """Build a sparse 0/1 matrix that selects the vertices flagged in `mask`.

    Row k has a single 1 in the column of the k-th True entry of `mask`, so
    ``subsetting_matrix(mask) @ vertices`` returns the selected vertex rows in
    their original order. Index labels are used as column positions, so `mask`
    must carry a default RangeIndex (as the masks built from `subpart_df` do).
    """
    col_ind = mask.index[mask].to_list()
    row_ind = list(range(len(col_ind)))
    return sp.csc_matrix(
        (np.ones(len(col_ind)), (row_ind, col_ind)),
        shape=(len(col_ind), len(mask)),
    )


# LV: centre at the LV centroid and scale to unit RMS distance from it.
lv_subsetting_mtx = subsetting_matrix(lv_subparts)
lv_precenter = lv_subsetting_mtx @ fhm_mesh.v
lv_center = lv_precenter.mean(0)
lv_poscenter = lv_precenter - lv_center
lv_scale = np.sqrt((lv_poscenter ** 2).sum(1).mean())
lv_scaled = lv_poscenter / lv_scale

lv_target_pcd = o3d.geometry.PointCloud()
lv_target_pcd.points = o3d.utility.Vector3dVector(lv_scaled)

# LV+RV: reuse the LV centre and scale, so both clouds live in one frame and the
# LV registration can be carried over to the LV+RV cloud.
lvrv_subsetting_mtx = subsetting_matrix(lvrv_subparts)
lvrv = (lvrv_subsetting_mtx @ fhm_mesh.v - lv_center) / lv_scale

lvrv_target_pcd = o3d.geometry.PointCloud()
lvrv_target_pcd.points = o3d.utility.Vector3dVector(lvrv)

### Build source mesh

In [None]:
FILE = "/home/rodrigo/01_repos/CardioMesh/data/template/myo_ED_AHA17.vtk"

# Read the LV myocardium template as VTK polydata.
reader = vtk.vtkPolyDataReader()
reader.SetFileName(FILE)
reader.Update()
mesh = reader.GetOutput()

# Raw vertex coordinates, one row per point (kept for the LV+RV cell below).
n_points = mesh.GetNumberOfPoints()
lv_source_points_precenter = np.array([mesh.GetPoint(idx) for idx in range(n_points)])

# Centre at the origin and scale to unit RMS distance, like the LV target cloud.
lv_source_centered = lv_source_points_precenter - lv_source_points_precenter.mean(0)
lv_source_points = lv_source_centered / np.sqrt((lv_source_centered ** 2).sum(1).mean())

lv_source_pcd = o3d.geometry.PointCloud()
lv_source_pcd.points = o3d.utility.Vector3dVector(lv_source_points)
lv_source_pcd

In [None]:
FILE = "/home/rodrigo/01_repos/CardioMesh/data/template/heart_ED.vtk"

reader = vtk.vtkPolyDataReader()
reader.SetFileName(FILE)
reader.Update()
mesh = reader.GetOutput()

n_points = mesh.GetNumberOfPoints()
lvrv_source_points = np.array([mesh.GetPoint(i) for i in range(n_points)])

# Express the LV+RV template in the LV template's frame: same centre AND same
# scale as lv_source_points. The rigid (no-scaling) LV registration is later
# applied to this cloud, so it must share that frame. This mirrors how
# lvrv_target_pcd is normalised with the LV target's statistics.
lv_source_center = lv_source_points_precenter.mean(0)
lv_source_scale = np.sqrt(((lv_source_points_precenter - lv_source_center) ** 2).sum(1).mean())
lvrv_source_points = (lvrv_source_points - lv_source_center) / lv_source_scale

lvrv_source_pcd = o3d.geometry.PointCloud()
lvrv_source_pcd.points = o3d.utility.Vector3dVector(lvrv_source_points)
lvrv_source_pcd

In [None]:
def prepare_dataset(source, target, voxel_size, trans_init=None):
    """Optionally move `source`, then downsample both clouds and compute FPFH features.

    Parameters
    ----------
    source, target : o3d.geometry.PointCloud
        Input clouds. The caller's `source` is not modified: the initial
        transform is applied to a copy.
    voxel_size : float
        Downsampling voxel size, forwarded to preprocess_point_cloud.
    trans_init : (4, 4) array-like, optional
        Initial pose perturbation applied to the copy of `source`, e.g. to test
        robustness to the starting pose. Defaults to the identity.

    Returns
    -------
    source, target, source_down, target_down, source_fpfh, target_fpfh
    """
    if trans_init is None:
        trans_init = np.identity(4)

    print(f":: Source has {len(source.points)} points, target has {len(target.points)} points.")
    # Transform a copy so the caller's point cloud is never changed in place.
    source = copy.deepcopy(source).transform(trans_init)

    source_down, source_fpfh = preprocess_point_cloud(source, voxel_size)
    target_down, target_fpfh = preprocess_point_cloud(target, voxel_size)

    return source, target, source_down, target_down, source_fpfh, target_fpfh


In [None]:
# voxel_size is in the normalised units of the clouds built above (each LV cloud
# is centred and scaled to unit RMS distance), not in physical units.
voxel_size = 0.2

# Stage 1: coarse alignment of the LV template onto the subject's LV.
source, target = lv_source_pcd, lv_target_pcd
source, target, source_down, target_down, source_fpfh, target_fpfh = prepare_dataset(source, target, voxel_size)
result_ransac1 = execute_global_registration(source_down, target_down, source_fpfh, target_fpfh, voxel_size)

# Stage 2: move the LV+RV template with the stage-1 transform, then register it
# onto the subject's LV+RV.
# NOTE: RANSAC is randomised, so the transformations can differ between runs.
source = copy.deepcopy(lvrv_source_pcd).transform(result_ransac1.transformation)
target = lvrv_target_pcd
source, target, source_down, target_down, source_fpfh, target_fpfh = prepare_dataset(source, target, voxel_size)
result_ransac2 = execute_global_registration(source_down, target_down, source_fpfh, target_fpfh, voxel_size)

draw_registration_result(lv_source_pcd, lv_target_pcd, result_ransac1.transformation)
draw_registration_result(source, target, result_ransac2.transformation)
result_ransac1, result_ransac2

In [None]:
def dump_points(pcd, path, transformations=()):
    """Save the points of `pcd`, after applying `transformations` in order, to `path`.

    `pcd` itself is not modified; the transformations are applied to a copy.
    np.save writes a plain .npy file (no pickle), readable by np.load with or
    without allow_pickle.
    """
    moved = copy.deepcopy(pcd)
    for transformation in transformations:
        moved.transform(transformation)
    np.save(path, np.asarray(moved.points))


# Template (source) clouds get stage 1 followed by stage 2; subject (target)
# clouds are saved as they are.
# NOTE(review): a later cell loads f"{BASEFN}_source.npy" (e.g. mallas/lv_source.npy),
# while these files are written to the working directory; confirm the intended location.
registration = (result_ransac1.transformation, result_ransac2.transformation)
dump_points(lv_source_pcd, "lv_source.npy", registration)
dump_points(lv_target_pcd, "lv_target.npy")
dump_points(lvrv_source_pcd, "lvrv_source.npy", registration)
dump_points(lvrv_target_pcd, "lvrv_target.npy")

In [None]:
# Template meshes used as registration sources.
LV_MESH_SOURCE = "/home/rodrigo/01_repos/CardioMesh/data/template/myo_ED_AHA17.vtk"
LVRV_MESH_SOURCE = "/home/rodrigo/01_repos/CardioMesh/data/template/heart_ED.vtk"

# Target-mesh partitions kept per experiment; the LV+RV set extends the LV set.
LV_PARTITIONS = ("LV", "AVP")
LVRV_PARTITIONS = LV_PARTITIONS + ("RV", "PVP", "TVP")

In [None]:
# Which experiment to export: "lv" (LV only) or "lvrv" (LV + RV).
EXPERIMENT = "lv"

BASEFN = f"mallas/{EXPERIMENT}"
PARTITIONS, SOURCEFILE = {
    "lv": (LV_PARTITIONS, LV_MESH_SOURCE),
    "lvrv": (LVRV_PARTITIONS, LVRV_MESH_SOURCE),
}[EXPERIMENT]

In [None]:
def map_elements_to_integers(elements):
    """Assign consecutive integer codes to the distinct values of `elements`.

    Codes follow the order of first appearance, because ``pd.Series.unique``
    preserves that order.

    Returns
    -------
    (element_to_integer, integer_to_element) : tuple of dict
        The value -> code mapping and its inverse, code -> value.
    """
    distinct = pd.Series(elements).unique()
    element_to_integer = {element: code for code, element in enumerate(distinct)}
    integer_to_element = dict(enumerate(distinct))
    return element_to_integer, integer_to_element

In [None]:
# NOTE: `save_mesh_as_vtk` is defined in the next cell; on a fresh kernel run
# that cell first (TODO: move the definition above this cell).
source_mesh = Cardiac3DMesh(SOURCEFILE)

target_mesh = Cardiac3DMesh(
    filename=f"/home/rodrigo/01_repos/CardiacCOMA/data/cardio/meshes/by_id/{ID}/models/FHM_time001.npy",
    faces_filename="/home/rodrigo/01_repos/CardioMesh/data/faces_fhm_10pct_decimation.csv",
    subpart_id_filename="/home/rodrigo/01_repos/CardioMesh/data/subpartIDs_FHM_10pct.txt"
)[PARTITIONS]

# Attach the registered point clouds exported earlier. Both meshes must already
# exist before their attributes can be set.
# NOTE(review): the exports below read `.v`, not `.points`; confirm whether
# Cardiac3DMesh uses `.points`, otherwise these assignments do not affect them.
# allow_pickle=True because the .npy files may have been written with
# ndarray.dump (a pickle); only load files produced by this notebook.
source_mesh.points = np.load(f"{BASEFN}_source.npy", allow_pickle=True)
target_mesh.points = np.load(f"{BASEFN}_target.npy", allow_pickle=True)

# Export the target surface, centred at the origin, under the experiment's base name.
save_mesh_as_vtk(target_mesh.v - target_mesh.v.mean(0), target_mesh.f, f"{BASEFN}_target.vtk")

# np.save writes plain .npy files (no pickle) that np.load reads directly.
np.save(f"{BASEFN}_source_verts.npy", source_mesh.v)
np.save(f"{BASEFN}_source_faces.npy", source_mesh.f)

np.save(f"{BASEFN}_target_verts.npy", target_mesh.v)
np.save(f"{BASEFN}_target_faces.npy", target_mesh.f)

# save_mesh_as_vtk(source_mesh.v, source_mesh.f, f"{BASEFN}_source.vtk")
# source_mesh.save_to_vtk(f"{BASEFN}_source.vtk")

# element_to_integer, _ = map_elements_to_integers(list(target_mesh.subpartID))
# target_mesh.subpartID = pd.Series(target_mesh.subpartID).apply(lambda x: element_to_integer[x]).to_list()
# save_mesh_as_vtk(target_mesh.v, target_mesh.f, f"{BASEFN}_target.vtk")

In [None]:
def save_mesh_as_vtk(vertices, faces, filename):
    """Write a triangle mesh to `filename` as a legacy VTK polydata file.

    Parameters
    ----------
    vertices : array-like, shape (n_vertices, 3)
        Vertex coordinates.
    faces : array-like, shape (n_faces, >=3)
        0-based indices into `vertices`; the first three entries of each row
        define a triangle.
    filename : str
        Destination path.

    Raises
    ------
    OSError
        If the VTK writer reports an error (e.g. the file cannot be opened).
    """
    points = vtk.vtkPoints()
    for vertex in vertices:
        points.InsertNextPoint(vertex)

    cells = vtk.vtkCellArray()
    for face in faces:
        triangle = vtk.vtkTriangle()
        point_ids = triangle.GetPointIds()
        for corner in range(3):
            # int() so numpy integer scalars reach VTK as plain Python ints.
            point_ids.SetId(corner, int(face[corner]))
        cells.InsertNextCell(triangle)

    polydata = vtk.vtkPolyData()
    polydata.SetPoints(points)
    polydata.SetPolys(cells)

    writer = vtk.vtkPolyDataWriter()
    writer.SetFileName(filename)
    writer.SetInputData(polydata)
    # VTK reports write failures through its error machinery, not Python
    # exceptions, so check both the return status and the error code.
    if not writer.Write() or writer.GetErrorCode() != 0:
        raise OSError(f"Failed to write VTK file {filename!r}")

    
def read_vtk_file(filename):
    """Read a legacy VTK polydata file into vertex and triangle arrays.

    Parameters
    ----------
    filename : str
        Path of the ``.vtk`` polydata file.

    Returns
    -------
    vertices : np.ndarray, shape (n_points, 3), float
    faces : np.ndarray, shape (n_faces, 3), int64 vertex indices

    Raises
    ------
    ValueError
        If no points can be read (missing/unreadable file or not polydata), or
        if a polygon is not a triangle.
    """
    reader = vtk.vtkPolyDataReader()
    reader.SetFileName(filename)
    reader.Update()
    polydata = reader.GetOutput()

    vtk_points = polydata.GetPoints()
    if vtk_points is None:
        raise ValueError(f"No points could be read from {filename!r}")

    num_points = vtk_points.GetNumberOfPoints()
    vertices = np.zeros((num_points, 3))
    for i in range(num_points):
        vertices[i, :] = vtk_points.GetPoint(i)

    # Traverse the polygon cell array: GetNextCell fills `point_ids` with the
    # next cell's point indices and returns 0 once every cell has been visited.
    polys = polydata.GetPolys()
    point_ids = vtk.vtkIdList()
    faces = []
    polys.InitTraversal()
    while polys.GetNextCell(point_ids):
        if point_ids.GetNumberOfIds() != 3:
            raise ValueError(
                f"{filename!r} contains a polygon with {point_ids.GetNumberOfIds()} "
                "vertices; only triangles are supported"
            )
        faces.append([point_ids.GetId(k) for k in range(3)])

    # reshape keeps a (0, 3) shape when the file has no polygons.
    faces = np.array(faces, dtype=np.int64).reshape(-1, 3)

    return vertices, faces

In [None]:
# Example usage: read the exported source mesh back and show its arrays.
# NOTE(review): f"{BASEFN}_source.vtk" is only written by commented-out code in an
# earlier cell; make sure the file exists before running this.
vertices, faces = read_vtk_file(f"{BASEFN}_source.vtk")
for header, array in (("Vertices:\n", vertices), ("Faces:\n", faces)):
    print(header, array)

In [None]:
# Read the exported target mesh back and report its number of points.
# A *reader* is required here; vtkDataSetReader accepts any legacy VTK dataset
# type (polydata as written by save_mesh_as_vtk, unstructured grids, ...).
reader = vtk.vtkDataSetReader()
reader.SetFileName(f"{BASEFN}_target.vtk")
reader.Update()
mesh = reader.GetOutput()
if mesh is None:
    raise ValueError(f"Could not read a VTK dataset from {BASEFN}_target.vtk")

n_points = mesh.GetNumberOfPoints()
n_points

___

In [None]:
def read_polydata(filename):
    """Read a legacy VTK polydata file and return its vtkPolyData output."""
    reader = vtk.vtkPolyDataReader()
    reader.SetFileName(filename)
    reader.Update()
    return reader.GetOutput()


mesh1 = read_polydata("/home/rodrigo/01_repos/CardioMesh/data/template/endo_ED_AHA17.vtk")
mesh2 = read_polydata("/home/rodrigo/01_repos/CardioMesh/data/template/epi_ED_AHA17.vtk")

# vtkAppendPolyData concatenates the points and cells of all inputs. Per the VTK
# docs, it keeps point/cell attribute arrays only when every input has them, so
# no manual copying of point data is needed (or possible) here.
appendFilter = vtk.vtkAppendPolyData()
appendFilter.AddInputData(mesh1)
appendFilter.AddInputData(mesh2)
appendFilter.Update()
mergedMesh = appendFilter.GetOutput()

writer = vtk.vtkPolyDataWriter()
writer.SetFileName("mergedMesh.vtk")
writer.SetInputData(mergedMesh)
writer.Write()

In [None]:
# Per-point values of the first point-data array of the endo_ED_AHA17 template (mesh1).
pp = mesh1.GetPointData().GetAbstractArray(0)
if pp is None:
    raise ValueError("mesh1 has no point-data arrays")
# Iterate over mesh1's own points: `n_points` from earlier cells belongs to other meshes.
# NOTE(review): assumes a single-component array (one value per point) — confirm.
subpartID = [pp.GetValue(i) for i in range(mesh1.GetNumberOfPoints())]