In [7]:
import cv2 as cv
import SimpleITK as sitk
import numpy as np
from numpy import inf
import pandas as pd
import glob
import os
import skimage.morphology as m
from skimage import measure

import seaborn as sns
import codecs, json 
from ipywidgets import IntProgress
from collections import defaultdict
from operator import itemgetter

import gudhi as gd
import gudhi.representations
import matplotlib.pyplot as plt
from tqdm import tqdm

from sklearn.svm import LinearSVC
import sklearn.linear_model as lm
from sklearn import metrics
from sklearn.model_selection import cross_val_score, train_test_split, ShuffleSplit

import sys
sys.path.append(r'../CT_scan_oa')

In [8]:
def load_meta_img(fn):
    """Read a MetaImage file and return its voxel data as a list of int16 arrays.

    Parameters
    ----------
    fn : str
        Path of the file handed to the SimpleITK reader (read with the
        "MetaImageIO" backend).

    Returns
    -------
    list
        The sub-arrays obtained by iterating over the first axis of the
        image array after casting it to ``int16``.
    """
    meta_reader = sitk.ImageFileReader()
    meta_reader.SetFileName(fn)
    meta_reader.SetImageIO("MetaImageIO")

    # Keep the SimpleITK image bound to a local for as long as the array view
    # over its buffer is being read.
    sitk_image = meta_reader.Execute()
    voxel_view = sitk.GetArrayViewFromImage(sitk_image)
    voxels_int16 = np.int16(voxel_view)
    return [plane for plane in voxels_int16]

def mask_to_boundary_pts(mask):
    """Extract boundary points from the slices of a binary mask volume.

    Every slice of ``mask`` is inspected. A slice with a single distinct
    value is treated as not containing the patella mask and skipped; a slice
    with exactly two distinct values has its contours traced with
    ``skimage.measure.find_contours`` and the first contour is kept. Slices
    with more than two distinct values are ignored.

    Parameters
    ----------
    mask : iterable of array-like
        The slices to scan, e.g. the list returned by ``load_meta_img``.

    Returns
    -------
    numpy.ndarray or None
        The first contour of the *last* slice that holds exactly two distinct
        values, or ``None`` when no slice qualifies (including an empty
        ``mask``).

    Notes
    -----
    NOTE(review): only the contour of the last qualifying slice survives the
    loop; contours of earlier slices are overwritten. Confirm whether the
    boundary of every slice should be collected instead.
    """
    # Start from None so a volume without any mask slice yields a defined
    # result instead of an unbound local at the return statement.
    boundary_pts = None
    for slide in mask:
        # Exactly two distinct values: background plus the mask label.
        if len(np.unique(slide)) == 2:
            boundary_pts = measure.find_contours(slide)[0]
    return boundary_pts

In [9]:
# Path of the mask volume to inspect, relative to the notebook's working directory.
fn = '../CT_scan_oa/Mako 001/ST0/SE1_mask.mha'
# Read the volume into a list of int16 arrays (see load_meta_img).
mask = load_meta_img(fn)

In [15]:
# Number of distinct values in element 250 of `mask`; mask_to_boundary_pts
# only traces contours on slices where this count is exactly 2.
len(np.unique(mask[250]))

2

In [13]:
# Number of elements in the list returned by load_meta_img.
len(mask)

390

In [None]:
# Representation names understood by _make_vectoriser.
_TDA_REPRESENTATIONS = ('poly', 'sil', 'entropy', 'landscape', 'pi')


def _make_vectoriser(repre):
    """Return a fresh gudhi vectoriser for the representation name ``repre``."""
    vm = gd.representations.vector_methods
    if repre == 'poly':
        return vm.ComplexPolynomial()
    if repre == 'sil':
        return vm.Silhouette(resolution=100)
    if repre == 'entropy':
        return vm.Entropy(mode='scalar')
    if repre == 'landscape':
        return vm.Landscape()
    if repre == 'pi':
        return vm.PersistenceImage(bandwidth=1.0)
    raise ValueError(
        f"Unknown representation {repre!r}; expected one of {_TDA_REPRESENTATIONS}"
    )


def _finite_intervals(simplex_tree, dim):
    """Return the persistence intervals of dimension ``dim`` with ``inf`` entries set to 0."""
    intervals = simplex_tree.persistence_intervals_in_dimension(dim)
    intervals[intervals == np.inf] = 0
    return intervals


def _vectorise(vectoriser, intervals):
    """Fit ``vectoriser`` on one diagram and return its flattened real-valued transform."""
    vectoriser.fit([intervals])
    return vectoriser.transform([intervals]).real.flatten()


def _tda_single_feat(simplex_tree, repre, multi_pers, kth_pers):
    """Compute one representation vector from a simplex tree whose persistence is already computed."""
    vectoriser = _make_vectoriser(repre)
    # multi_pers concatenates the vectors of dimensions 0 and 1;
    # otherwise only dimension kth_pers is used.
    dims = (0, 1) if multi_pers else (kth_pers,)
    parts = [_vectorise(vectoriser, _finite_intervals(simplex_tree, dim)) for dim in dims]
    return np.hstack(parts).flatten()


def tda_feat_extract(shape_data, repre_model_params:dict, plot_diagrams=True):
    """Turn each point set in ``shape_data`` into a TDA feature vector.

    For every ID an alpha complex is built from ``shape_data[ID]``, its
    persistence is computed, and one vector per configured representation is
    produced. The vectors of all configured representations are stacked
    horizontally into a single 1-D feature vector per ID.

    Parameters
    ----------
    shape_data : mapping
        Iterating it yields the IDs; ``shape_data[ID]`` is passed as the
        ``points`` argument of ``gd.AlphaComplex``.
    repre_model_params : dict
        Must hold the keys ``'representation'``, ``'multi_pers'`` and
        ``'kth_pers'``, each a list. Position ``k`` of the three lists
        describes the k-th representation: its name (one of ``'poly'``,
        ``'sil'``, ``'entropy'``, ``'landscape'``, ``'pi'``), whether to
        concatenate the dimension-0 and dimension-1 vectors, and otherwise
        which single homology dimension to use. The lists are zipped, so
        entries beyond the shortest list are ignored.
    plot_diagrams : bool, optional
        When true (the default) the persistence diagram of every ID is
        plotted with ``gd.plot_persistence_diagram``. The diagram is plotted
        once per ID, since it is the same for every representation.

    Returns
    -------
    list of numpy.ndarray
        One 1-D feature vector per ID, in iteration order of ``shape_data``.

    Raises
    ------
    ValueError
        If no representation is configured or a representation name is
        unknown. The configuration is checked before any sample is processed.
    """
    # Read the three lists by key rather than by dict order, so the tuple
    # layout never depends on how the caller built the dict.
    configs = list(zip(repre_model_params['representation'],
                       repre_model_params['multi_pers'],
                       repre_model_params['kth_pers']))
    if not configs:
        raise ValueError("repre_model_params does not configure any representation")
    unknown = [repre for repre, _, _ in configs if repre not in _TDA_REPRESENTATIONS]
    if unknown:
        raise ValueError(
            f"Unknown representation(s) {unknown}; expected one of {_TDA_REPRESENTATIONS}"
        )

    X_tda = []
    for ID in tqdm(list(shape_data)):
        # The complex and its persistence depend only on the sample, so they
        # are computed once and shared by all configured representations.
        simplex_tree = gd.AlphaComplex(points=shape_data[ID]).create_simplex_tree()
        diagram = simplex_tree.persistence()
        if plot_diagrams:
            gd.plot_persistence_diagram(diagram)

        per_repre = [_tda_single_feat(simplex_tree, *config) for config in configs]
        X_tda.append(np.hstack(per_repre))

    return X_tda

# Settings consumed by tda_feat_extract: position k of each list describes
# the k-th representation (its name, the multi_pers flag, the homology dimension).
repre_model_params = dict(
    representation=['sil', 'entropy'],
    multi_pers=[False, False],
    kth_pers=[0, 1],
)