In [124]:
import os
import torch
import numpy as np
import pickle as pkl

from utils.VTKHelpers.CardiacMesh import CardiacMeshPopulation, Cardiac3DMesh
from torch.utils.data import TensorDataset, DataLoader, random_split
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union, Dict

import pytorch_lightning as pl
import logging
from argparse import Namespace

In [126]:
def get_3d_mesh(ids, root_folder, mesh_file="models/LV_time001.npy"):
    """
    Lazily load one mesh array per subject folder.

    params:
        ids: iterable of subject identifiers; each one names a folder
            directly under ``root_folder``.
        root_folder: directory that holds one sub-folder per subject.
        mesh_file: path of the ``.npy`` file, relative to the subject folder.

    yields:
        ``(subject_id, array)`` pairs, in the order of ``ids``. Subjects whose
        mesh file does not exist are skipped silently.
    """
    for subject_id in ids:
        npy_file = f"{root_folder}/{subject_id}/{mesh_file}"
        try:
            pc = np.load(npy_file)
        except (FileNotFoundError, NotADirectoryError):
            # NotADirectoryError is raised when the entry is a regular file
            # (e.g. a stray log sitting next to the subject folders); treat
            # it like any other subject without a mesh.
            continue
        yield subject_id, pc
        
        
def load_procrustes_transforms(filename):
    """
    Read a pickled object from ``filename`` and return it.

    The caller in this notebook expects a mapping from subject id to a dict
    of transform parameters, but this function returns whatever was pickled.

    NOTE: unpickling can execute arbitrary code -- only point this at files
    from a trusted source.
    """
    # The context manager guarantees the handle is closed even if
    # unpickling fails.
    with open(filename, "rb") as pkl_file:
        return pkl.load(pkl_file)

In [163]:
class CardiacMeshPopulationDataset(TensorDataset):
    """
    In-memory dataset of cardiac meshes keyed by subject id.

    Every mesh is optionally aligned with a per-subject Procrustes transform
    (see ``transform_mesh``). When transforms are supplied, subjects flagged
    ``discard`` -- or absent from the transforms -- are left out of the
    dataset entirely, so ``ids``, ``meshes`` and ``len()`` only describe the
    meshes that were kept.

    Items can be fetched by integer position or by subject id (string).
    """

    def __init__(
        self,
        meshes: Mapping[str, np.ndarray],
        procrustes_transforms: Union[str, Mapping[str, Dict], None] = None,
        context=Namespace(logger=logging.getLogger())
    ):
        """
        params:
            meshes: mapping from subject id to a mesh array. All meshes must
                share one shape so they can be stacked into a single tensor.
            procrustes_transforms: mapping from subject id to the keyword
                arguments of ``transform_mesh`` (``rotation``, ``traslation``,
                ``discard``), or the path of a pickle file containing such a
                mapping. ``None`` keeps every mesh untransformed.
            context: currently unused; kept for interface compatibility.

        raises:
            TypeError: if ``meshes`` is not a mapping.
        """
        if not isinstance(meshes, Mapping):
            raise TypeError(f"Argument should be a dictionary but is a {type(meshes)}")

        if isinstance(procrustes_transforms, str):
            procrustes_transforms = load_procrustes_transforms(procrustes_transforms)

        kept_ids, kept_meshes = [], []
        for mesh_id, mesh in meshes.items():
            if procrustes_transforms is None:
                transform_kwargs = {}
            else:
                # A subject without a stored transform cannot be aligned with
                # the rest of the population, so it is discarded.
                transform_kwargs = procrustes_transforms.get(mesh_id, {"discard": True})

            aligned = self.transform_mesh(np.asarray(mesh, dtype=float), **transform_kwargs)
            if aligned is None:
                continue
            kept_ids.append(mesh_id)
            kept_meshes.append(aligned)

        self.ids = kept_ids
        self.meshes = torch.Tensor(np.array(kept_meshes))
        # Populate TensorDataset.tensors so the base-class attribute is valid.
        super().__init__(self.meshes)

        self._data_dict = dict(zip(self.ids, self.meshes))

    def __getitem__(self, id):
        """
        Return the mesh tensor at integer position ``id`` or for subject id
        ``id`` (string). An unknown subject id yields ``None``; an
        out-of-range position raises ``IndexError``.
        """
        # np.integer is accepted too, so indices coming from NumPy arrays work.
        if isinstance(id, (int, np.integer)):
            return self._data_dict[self.ids[id]]
        if isinstance(id, str):
            return self._data_dict.get(id)
        raise TypeError(f"Index must be an int or a str but is a {type(id)}")

    def __len__(self):
        """Number of meshes kept in the dataset."""
        return len(self.ids)

    def transform_mesh(self, mesh, rotation: Union[None, np.ndarray] = None, traslation: Union[None, np.ndarray] = None, discard=False):
        """
        Apply a rigid transform to ``mesh`` without modifying the input array.

        params:
            mesh: array of points; the centroid is taken as the mean over axis 0.
            rotation: matrix that right-multiplies the centred mesh, or None.
            traslation: offset subtracted from the mesh before rotating, or
                None. (The misspelling is part of the keyword interface used
                by the stored transforms, so it is kept.)
            discard: when true, no transform is computed.

        returns:
            The transformed array, or ``None`` when ``discard`` is true.
        """
        if discard:
            return None

        if traslation is not None:
            mesh = mesh - traslation

        if rotation is not None:
            # Rotate about the centroid; out-of-place so the caller's array
            # is left untouched.
            centroid = mesh.mean(axis=0)
            mesh = (mesh - centroid).dot(rotation) + centroid

        return mesh
    
    
class CardiacMeshPopulationDM(pl.LightningDataModule):

    '''
    PyTorch Lightning datamodule that builds a CardiacMeshPopulationDataset
    and serves random train / validation / test splits of it.
    '''

    def __init__(self,
        cardiac_population: Union[Mapping[str, np.array], CardiacMeshPopulation, None] = None,
        procrustes_transforms: Union[str, Mapping[str, Dict], None] = None,
        batch_size: int = 32,
        split_lengths: Union[None, List[int]]=None
    ):

        '''
        params:
            cardiac_population: meshes keyed by subject id; forwarded as-is to
                CardiacMeshPopulationDataset.
                NOTE(review): the annotation also allows a
                CardiacMeshPopulation, but the dataset only accepts a
                mapping -- confirm the intended input types.
            procrustes_transforms: per-subject transforms, or the path of a
                pickle file containing them; forwarded to the dataset.
            batch_size: batch size of the training dataloader.
            split_lengths: [train, val, test] sizes passed to random_split.
                When None, a 60/20/20 split is derived from the dataset
                length in setup().
        '''

        super().__init__()

        self.cardiac_population = cardiac_population
        self.procrustes_transforms = procrustes_transforms

        self.batch_size = batch_size
        self.split_lengths = split_lengths

        # Populated by setup().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None


    def setup(self, stage: Optional[str] = None):
        '''
        Build the dataset and split it into train / val / test subsets.

        The split is drawn only once: Lightning may call setup() once per
        stage, and redrawing a random split on each call would let samples
        move between the train and test subsets.
        '''

        if self.train_dataset is not None:
            return

        popu = CardiacMeshPopulationDataset(
            meshes=self.cardiac_population,
            procrustes_transforms=self.procrustes_transforms
        )

        if self.split_lengths is None:
            train_len = int(0.6 * len(popu))
            test_len = int(0.2 * len(popu))
            # The validation split absorbs the rounding remainder so the
            # three lengths always add up to len(popu).
            val_len = len(popu) - train_len - test_len
            self.split_lengths = [train_len, val_len, test_len]

        self.train_dataset, self.val_dataset, self.test_dataset = random_split(popu, self.split_lengths)


    def train_dataloader(self):
        '''Dataloader over the training split, batched by self.batch_size.'''
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=8)

    def val_dataloader(self):
        '''Dataloader over the validation split (fixed batch size of 2).'''
        return DataLoader(self.val_dataset, batch_size=2, num_workers=8)

    def test_dataloader(self):
        '''Dataloader over the test split (one mesh per batch).'''
        return DataLoader(self.test_dataset, batch_size=1, num_workers=8)

In [201]:
# root_folder = "data/cardio/Results"
# NOTE: machine-specific absolute path; adjust it to where the meshes live.
root_folder = "/home/rodrigo/CISTIB/UKBB/data/meshes/FHM/Results"
PROCRUSTES_FILENAME = "utils/VTKHelpers/procrustes_transforms_35k.pkl"

# Upper bound on the number of subject folders scanned for meshes.
MAX_SUBJECTS = 1000

# os.listdir returns entries in arbitrary order, so sort before truncating
# to get the same subset of subjects on every run.
ids = sorted(os.listdir(root_folder))[:MAX_SUBJECTS]
meshes = dict(get_3d_mesh(ids, root_folder))

In [176]:
# Example (disabled): build the dataset directly, without the datamodule.
# mm = CardiacMeshPopulationDataset(meshes, PROCRUSTES_FILENAME)

In [200]:
# Hand the loaded meshes and the Procrustes transform file to the datamodule.
dm = CardiacMeshPopulationDM(cardiac_population=meshes, procrustes_transforms=PROCRUSTES_FILENAME)

In [194]:
# Build the dataset and its train/val/test splits (no specific stage requested).
dm.setup(stage=None)