In [None]:
import pyssam
from pathlib import Path
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
import pyvista as pv
import point_cloud_utils as pcu
from copy import copy

In [None]:
# Input masks and landmark output location.
# NOTE(review): absolute, user-specific paths — consider reading these from an
# environment variable or a config cell so the notebook runs on other machines.
basepath = Path('/home/simoneponcioni/Documents/01_PHD/03_Methods/HFE/01_DATA/TIBIA')
LANDMARK_SUBDIR = 'landmarks_int'  # set to 'landmarks_ext' for the external landmark set
landmarkpath = Path('/home/simoneponcioni/Documents/01_PHD/03_Methods/HFE/QMSKI/SSAM/') / LANDMARK_SUBDIR


def _list_masks(directory, tag, suffix='.mhd'):
    """Return the sorted files directly inside ``directory`` whose name contains ``tag``.

    Only the top level of ``directory`` is scanned (no recursion into
    subdirectories), and only files whose suffix equals ``suffix`` are kept.
    Sorting makes the order of shapes independent of filesystem listing order.
    """
    return sorted(
        path for path in directory.iterdir()
        if path.is_file() and tag in path.name and path.suffix == suffix
    )


# Cortical and trabecular MetaImage (.mhd) masks. Previously the cortical
# loop appended into `trab_list`, leaving `cort_list` permanently empty.
cort_list = _list_masks(basepath, '_CORTMASK')
trab_list = _list_masks(basepath, '_TRABMASK')

In [None]:
# Surface-point sampling parameters (used in the loop below).
num_voxels_per_axis = 4096  # voxel-grid resolution per bounding-box axis for thinning the surface
num_landmarks = 1024        # every shape must contribute the same number of landmarks to be stacked
rng = np.random.default_rng(42)  # seeded so the random landmark subset is reproducible on Run All

landmarks_list = []
for file in trab_list:
    print(file)
    img = pv.read(str(file))
    # Threshold the 'MetaImage' scalars at 1 (presumably keeping the mask voxels)
    # and take the outer surface of the resulting volume.
    pv_threshold = img.threshold(value=1, scalars='MetaImage')
    surf = pv_threshold.extract_surface()
    points = np.asarray(surf.points, dtype=np.float32)
    normals = np.asarray(surf.point_normals, dtype=np.float32)
    values = np.ones(points.shape[0], dtype=np.float32)
    print(len(points))

    # Quick visual check of the extracted surface (x-y projection).
    fig, ax = plt.subplots()
    ax.scatter(points[:, 0], points[:, 1], s=1)
    ax.set(title=file.name, xlabel='x', ylabel='y')
    plt.show()
    plt.close(fig)

    # Thin the point cloud so at most one point remains per voxel; the extra
    # arrays (normals, values) are averaged per voxel alongside the points.
    bbox_size = points.max(axis=0) - points.min(axis=0)
    sizeof_voxel = bbox_size / num_voxels_per_axis
    v_sampled, _, _ = pcu.downsample_point_cloud_on_voxel_grid(sizeof_voxel, points, normals, values)

    # Bring every shape to exactly `num_landmarks` points.
    n_available = v_sampled.shape[0]
    if n_available > num_landmarks:
        keep = rng.choice(n_available, num_landmarks, replace=False)
    elif n_available < num_landmarks:
        # Pad by repeating randomly chosen surface points instead of appending
        # zeros: zero rows would be fake landmarks at the origin and bias the PCA.
        extra = rng.choice(n_available, num_landmarks - n_available, replace=True)
        keep = np.concatenate([np.arange(n_available), extra])
    else:
        keep = np.arange(n_available)
    # NOTE(review): randomly selected points are not in anatomical correspondence
    # across shapes, whereas the SSM treats row i of every shape as the same
    # landmark — consider a registration/correspondence step before the PCA.
    landmarks_list.append(v_sampled[keep])

if not landmarks_list:
    raise ValueError(f"No '_TRABMASK' .mhd files found in {basepath}")
# Stacked landmark sets consumed by the SSM cell: (n_shapes, num_landmarks, 3).
landmark_coordinates = np.stack(landmarks_list)

In [None]:
# Build the statistical shape model from the stacked landmark sets and fit its PCA
# on the model's own scaled landmark columns.
ssm_obj = pyssam.SSM(landmark_coordinates)
ssm_obj.create_pca_model(ssm_obj.landmarks_columns_scale)

# Mean shape, kept both as the flat column vector (what morph_model consumes)
# and as one row of xyz coordinates per landmark (for plotting).
mean_shape_columnvector = ssm_obj.compute_dataset_mean()
mean_shape = np.reshape(mean_shape_columnvector, (-1, 3))

# Principal components of the fitted model.
shape_model_components = ssm_obj.pca_model_components

In [None]:
# Define some plotting functions

def plot_cumulative_variance(explained_variance, target_variance=-1):
    number_of_components = np.arange(0, len(explained_variance))+1
    fig, ax = plt.subplots(1,1)
    color = "blue"
    ax.plot(number_of_components, explained_variance*100.0, marker="o", ms=2, color=color, mec=color, mfc=color)
    if target_variance > 0.0:
        ax.axhline(target_variance*100.0)
    
    ax.set_ylabel("Variance [%]")
    ax.set_xlabel("Number of components")
    ax.grid(axis="x")
    plt.show()
    
def plot_shape_modes(
    mean_shape_columnvector,
    mean_shape,
    original_shape_parameter_vector,
    shape_model_components,
    mode_to_plot,
    ssm=None,
    weights=(-2, 0, 2),
):
    """Plot the mean shape morphed along a single shape mode for several weights.

    One panel is drawn per weight, each showing the x-z projection of the
    morphed landmarks. The weight-0 panel is drawn in gray and titled
    "mean shape". It is exactly the mean only if the other entries of
    ``original_shape_parameter_vector`` are zero.

    The other panels are coloured by each landmark's distance from the mean
    shape. The distance is signed negative where the landmark ends up closer
    to the origin than in the mean shape (blue in the "seismic" colormap) and
    positive otherwise (red).

    Parameters
    ----------
    mean_shape_columnvector : array
        Mean shape as a flat column vector, as returned by ``compute_dataset_mean``.
    mean_shape : array, shape (n_landmarks, 3)
        The same mean shape with one row per landmark.
    original_shape_parameter_vector : array
        Baseline model parameters; copied, never modified.
    shape_model_components : array
        PCA components passed to ``morph_model``.
    mode_to_plot : int
        0-based index of the parameter overridden by each weight.
    ssm : pyssam model, optional
        Object providing ``morph_model``. Defaults to the notebook-level
        ``ssm_obj`` for backward compatibility. Pass it explicitly to avoid
        relying on global state.
    weights : sequence of float, optional
        Values assigned to ``mode_to_plot``, one panel each; default ``(-2, 0, 2)``.

    Notes
    -----
    Distances are measured from the origin, which assumes the model
    coordinates are centred — TODO confirm. The colour range is fixed to
    [-1, 1] in model units.
    NOTE(review): pyssam's ``morph_model`` may only use a default number of
    leading modes. If ``mode_to_plot`` lies beyond it, the panels could look
    identical — confirm.
    """
    if ssm is None:
        ssm = ssm_obj  # backward-compatible fallback to the notebook-level model
    origin = np.zeros(3)
    # Loop-invariant: distance of each mean-shape landmark from the origin.
    mean_shape_dist_from_centre = pyssam.utils.euclidean_distance(mean_shape, origin)

    # squeeze=False keeps the axes indexable even when a single weight is given.
    _, axes = plt.subplots(1, len(weights), squeeze=False)
    for ax_j, weight in zip(axes[0], weights):
        shape_parameter_vector = copy(original_shape_parameter_vector)
        shape_parameter_vector[mode_to_plot] = weight
        mode_i_coords = ssm.morph_model(
            mean_shape_columnvector,
            shape_model_components,
            shape_parameter_vector,
        ).reshape(-1, 3)

        if weight == 0:
            ax_j.scatter(mode_i_coords[:, 0], mode_i_coords[:, 2], c="gray", s=1)
            ax_j.set_title("mean shape")
        else:
            offset_dist = pyssam.utils.euclidean_distance(mean_shape, mode_i_coords)
            mode_i_dist_from_centre = pyssam.utils.euclidean_distance(mode_i_coords, origin)
            # Colour points blue if closer to the point-cloud centre than the mean shape.
            offset_dist = np.where(
                mode_i_dist_from_centre < mean_shape_dist_from_centre,
                -offset_dist,
                offset_dist,
            )
            ax_j.scatter(
                mode_i_coords[:, 0],
                mode_i_coords[:, 2],
                c=offset_dist,
                cmap="seismic",
                vmin=-1,
                vmax=1,
                s=1,
            )
            ax_j.set_title(f"mode {mode_to_plot} \nweight {weight}")
        ax_j.axis('off')
        ax_j.margins(0, 0)
        ax_j.xaxis.set_major_locator(plt.NullLocator())
        ax_j.yaxis.set_major_locator(plt.NullLocator())

    plt.show()

In [None]:
# Report how many modes reach the model's target variance, and draw that same
# target (ssm_obj.desired_variance, not a hard-coded 0.9) as the reference line
# so the plot always agrees with the printed number.
print(f"To obtain {ssm_obj.desired_variance*100}% variance, {ssm_obj.required_mode_number} modes are required")
plot_cumulative_variance(np.cumsum(ssm_obj.pca_object.explained_variance_ratio_), ssm_obj.desired_variance)

In [None]:
# Mode indices are 0-based, so mode_to_plot = 1 selects the second principal mode.
mode_to_plot = 1
print(f"explained variance is {ssm_obj.pca_object.explained_variance_ratio_[mode_to_plot]}")

plot_shape_modes(
    mean_shape_columnvector=mean_shape_columnvector,
    mean_shape=mean_shape,
    original_shape_parameter_vector=ssm_obj.model_parameters,
    shape_model_components=ssm_obj.pca_model_components,
    mode_to_plot=mode_to_plot,
)

In [None]:
# convert mean shape to pyvista mesh
mean_shape_mesh = pv.PolyData(mean_shape)
mean_shape_mesh.plot(notebook=False)