In [None]:
import pyvista as pv
import numpy as np
import os
import sys; sys.path.append("..")

In [None]:
from utils.VTKHelpers.CardiacMesh import Cardiac3DMesh, Cardiac4DMesh

In [None]:
import vedo

In [None]:
import vtk
import numpy as np
import os
import meshio  # tested with 2.3.0

import sys
from Constants import *
from tqdm import tqdm

import logging
import random

from IPython import embed  # For debugging

"""
This module aims to simplify the implementation of common tasks on VTK triangular meshes,
which become overly convoluted when the usual VTK Python wrapper for C++ is used
and make the code difficult to follow.
"""


def set_logger(logger):
    """Return `logger` unchanged, or the root logger when `logger` is None."""
    if logger is None:
        return logging.getLogger()
    return logger

In [None]:
class Cardiac4DMesh:

    """
    Class representing a collection of cardiac meshes for one individual, across the cardiac cycle.
    Public attributes:
      meshes:
      triangles:
      time_frames: list of integers
      subjectID
      LVEDV, LVESV, LVEF, LVSV, LVM
    """

    def __init__(self, root_folder, time_frames=None, logger=None):

        """
        root_folder: path to the PyCardioX output for the given individual
        time_frames: list|tuple containing 1-50 and/or "ED"/"ES"
        """

        self._root_folder = root_folder
        self._logger = set_logger(logger)

        if time_frames is None:
            self.time_frames = [i + 1 for i in range(50)]
        else:
            self.time_frames = time_frames
        self._time_frames_to_int()

        self._vtk_paths = self._get_vtk_paths()
        self._load_meshes()

    # def _set_logger(self, logger):
    #    return logger if logger is not None else logging.getLogger()

    @property
    def subjectID(self):
        """
        This method assumes that the subject ID is the name of the rightmost folder.
        """
        self._subjectID = os.path.basename(self._root_folder.strip("/"))
        return self._subjectID

    def _load_meshes(self, load_connectivity_flag=True):
        self.meshes = []
        for i, vtk_path in enumerate(self._vtk_paths):
            if i == 0 and load_connectivity_flag:
                # For efficiency reasons, connectivity is loaded only for the first mesh and copied (as a reference) to the rest.
                mesh = Cardiac3DMesh(vtk_path, load_connectivity_flag=True)
                self.meshes.append(mesh)
            else:
                mesh = Cardiac3DMesh(vtk_path, load_connectivity_flag=False)
                self.meshes.append(mesh)
                if load_connectivity_flag:
                    self.meshes[i].triangles = self.meshes[0].triangles

        if load_connectivity_flag:
            self.triangles = self.meshes[0].triangles

    def as_numpy_array(self):
        """ """
        kk = [x.points for x in self.meshes]
        try:
            kk = np.stack(kk, axis=0)
        except:
            # Handle this error better
            # embed()
            # raise ValueError(
            self.logger.error(
                """
            Not possible to create Numpy array for individual {id}. \
            The folder is likely to be incomplete.
            """.format(
                    id=self.subjectID
                )
            )
            raise ValueError

        return kk

    
    def _get_vtk_paths(self):

        # TODO: provide file pattern as argument to constructor
        # The current (hardcoded) file pattern is the one used by the SPASM output.
        fp = os.path.join(
            self._root_folder, "models/fhm_time{time_frame}.vtk"
        )
        return [fp.format(time_frame=x) for x in self._time_frames_as_path]

    
    def _time_frames_to_int(self):
        """
        Convert cardiac phases like "ED" and "ES" (for end-diastole and end-systole) to the corresponding integer indices
        :return: None
        """
        self._time_frames_dict = {t: t for t in self.time_frames}
        self._time_frames_dict["ED"] = 1
        self._time_frames_dict["ES"] = self.ES_time_frame
        self.time_frames = [self._time_frames_dict[t] for t in self.time_frames]

    def __repr__(self):
        return "Time series of {} meshes (class {}) for subject {}".format(
            len(self.meshes), self.meshes[0].__class__.__name__, self.subjectID
        )

    def __getitem__(self, timeframe):
        return self.meshes[self._time_frames_dict[timeframe]]

    @property
    # TODO: TEST
    def ES_time_frame(self):
        try:
            with open(os.path.join(self._root_folder, "ES_time_step.csv"), "rt") as ff:
                self._ES_time_frame = int(ff.read().strip())
            return self._ES_time_frame
        except:
            return None

    @property
    # TODO: TEST
    def LVEF(self):
        try:
            with open(os.path.join(self._root_folder, "Ejection_fraction.csv"), "rt") as ff:
                self._LVEF = float(ff.read().strip())
            return self._LVEF
        except:
            return None

    @property
    # TODO: TEST
    def LVSV(self):
        try:
            with open(os.path.join(self._root_folder, "Stroke_volume.csv"), "rt") as ff:
                self._LVSV = float(ff.read().strip())
            return self._LVSV
        except:
            return None

    @property
    def _time_frames_as_path(self):
        return [
            "0" * (3 - len(str(t))) + str(t) for t in self.time_frames
        ]  # "001", "002", ..., "050"

    def generate_gif(self, gif_path, paraview_config):
        """
        Generate a GIF file showing the moving mesh, using Paraview.
        gif_path: path to the GIF output file
        paraview_config: object representing the Paraview config (specify what's needed)
        """
        raise NotImplementedError

In [None]:
def render_mesh_as_png(mesh3D, faces, filename, camera_position='xy', show_edges=False, **kwargs):

    '''
    Produces a png file representing a static 3D mesh.
    - params
    ::mesh3D:: vertex coordinates of the mesh, in a form accepted by pyvista.PolyData.
               Objects exposing .cpu().numpy() (e.g. torch.Tensor) are converted to Numpy first.
    ::faces:: array of F x 3 containing the indices of the mesh's triangular faces.
    ::filename:: the name of the output png file (".png" is appended if missing).
    ::camera_position:: camera position for pyvista plotter (check relevant docs)
    ::show_edges:: whether the edges of the triangles are drawn.
    ::kwargs:: accepted for call compatibility; currently unused.

    - return:
    None, only produces the png file.
    '''

    try:
        # if mesh3D is a torch.Tensor, this should run OK
        mesh3D = mesh3D.cpu().numpy()
    except AttributeError:
        # Not a tensor-like object: use it as it is.
        pass

    # pyvista expects each face to be prefixed by its number of vertices (3 for triangles).
    connectivity = np.c_[np.full(faces.shape[0], 3), faces].astype(int)
    output_path = filename if filename.endswith("png") else filename + ".png"

    pv.set_plot_theme("document")
    plotter = pv.Plotter(off_screen=True, notebook=False)
    try:
        plotter.add_mesh(pv.PolyData(mesh3D, connectivity), show_edges=show_edges)
        plotter.camera_position = camera_position
        plotter.screenshot(output_path)
    finally:
        # Release the rendering resources even if the screenshot fails.
        plotter.close()


def generate_gif(mesh4D, faces, filename, camera_position='xy', show_edges=False, **kwargs):

    '''
    Produces a gif file representing the motion of the input mesh.

    params:
      ::mesh4D:: either an object with a `meshes` attribute (each mesh exposing its vertices as `.v`),
                 or a tensor-like object exposing .cpu().numpy(), whose first element along the
                 leading axis is used as the sequence of per-frame vertex arrays
                 (presumably the batch axis - confirm against the caller),
                 or an array-like of per-frame vertex arrays.
      ::faces:: array of F x 3 containing the indices of the mesh's triangular faces.
      ::filename:: the name of the output gif file.
      ::camera_position:: camera position for pyvista plotter (check relevant docs)
      ::show_edges:: whether the edges of the triangles are drawn.
      ::kwargs:: accepted for call compatibility; currently unused.

    return:
      None, only produces the gif file.

    raises:
      ValueError if mesh4D contains no time frames.
    '''

    if hasattr(mesh4D, "meshes"):
        frames = [mesh.v for mesh in mesh4D.meshes]
    else:
        try:
            # if mesh4D is a torch.Tensor, this should run OK
            frames = mesh4D.cpu().numpy()[0]
        except AttributeError:
            frames = np.asarray(mesh4D)

    if len(frames) == 0:
        raise ValueError("mesh4D contains no time frames")

    # pyvista expects each face to be prefixed by its number of vertices (3 for triangles).
    connectivity = np.c_[np.full(faces.shape[0], 3), faces].astype(int)

    pv.set_plot_theme("document")
    # abspath makes this work for bare file names as well as relative and absolute paths.
    os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)

    plotter = pv.Plotter(notebook=False, off_screen=True)
    try:
        # Open a gif
        plotter.open_gif(filename)

        surface = pv.PolyData(frames[0], connectivity)
        plotter.add_mesh(surface, show_edges=show_edges)

        # Every frame (including the first one) is written to the gif.
        for vertices in frames:
            # Moving the points of the mesh already in the scene updates the plot.
            surface.points = vertices
            plotter.camera_position = camera_position
            plotter.render()
            plotter.write_frame()
    finally:
        # Finalizes the gif and releases the rendering resources, even on failure.
        plotter.close()

In [None]:
# Identifier of the subject whose results folder is read below.
ID = "1000215"
# Load a single mesh (file fhm_time010.vtk) from a hard-coded, machine-specific path.
mesh = Cardiac3DMesh(f"/home/rodrigo/01_repos/CardiacMotion/data/cached/cardio/Results/{ID}/models/fhm_time010.vtk")
# Disabled: load the meshes of the whole cardiac cycle for the same subject.
# mesh4D = Cardiac4DMesh(root_folder=f"/home/rodrigo/01_repos/CardiacMotion/data/cached/cardio/Results/{ID}")

In [None]:
# render_mesh_as_png(mesh.v, mesh.f, filename=f"{ID}.png")

In [None]:
# generate_gif(mesh4D, mesh4D.triangles, f"{ID}.gif")

In [None]:
# Number of vertices of each subpart of the mesh, ordered from the largest to the smallest subpart.
n_vert = dict(
    sorted(
        {part: mesh[part].v.shape[0] for part in mesh.distinct_subparts}.items(),
        key=lambda item: item[1],
        reverse=True,
    )
)

In [None]:
# Build a vedo mesh from the vertices (.v) and faces (.f) of the "LV" subpart.
mesh_vedo = vedo.mesh.Mesh([mesh["LV"].v, mesh["LV"].f])
# Decimate the mesh; presumably fraction=0.05 keeps about 5% of the original - confirm against the vedo docs.
mesh_decimated = mesh_vedo.decimate(fraction=0.05)
# Last expression of the cell: the notebook displays the points of the decimated mesh.
mesh_decimated.points()