In [None]:
import os
from collections import defaultdict
from collections import deque
from itertools import combinations

import gudhi
import gudhi.representations
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.io
import seaborn as sns
from gudhi.hera import bottleneck_distance
from gudhi.hera import wasserstein_distance
from gudhi.representations import Landscape
from scipy.interpolate import interp1d
from scipy.spatial.distance import pdist
from scipy.stats import ks_2samp

### NOTE: the middle lobe is not handled yet; this notebook only implements the general two-input (from-node, to-node) process.

In [None]:
# Data extractor functions
def connectivityExtractor(name, pruned):
    """
    Load the vessel connectivity table of one network from its MATLAB file.

    Parameters
    ----------
    name : str
        Network identifier inserted into the file name.
    pruned : int
        0 reads 'Networks/Network_Vessels_<name>.mat',
        1 reads 'Pruned/Pruned_Network_<name>.mat'.

    Returns
    -------
    pandas.DataFrame
        Columns 'Parent', 'Daughter1', 'Daughter2', 'Daughter3'. Every 0 is
        replaced by NaN, except the 'Parent' entry of row 0, which is forced to 0.

    Raises
    ------
    ValueError
        If `pruned` is neither 0 nor 1.
    """
    if pruned == 0:
        file_path = 'Networks/Network_Vessels_' + name + '.mat'
    elif pruned == 1:
        file_path = 'Pruned/Pruned_Network_' + name + '.mat'
    else:
        # Previously this fell through and failed later on an unbound `file_path`.
        raise ValueError(f"pruned must be 0 or 1, got {pruned!r}")

    matlab_data = scipy.io.loadmat(file_path)
    # 'Data' is a MATLAB struct; [0, 0] unwraps the single struct element.
    connectivity_raw = matlab_data['Data']['connectivity'][0, 0]
    # squeeze() collapses a one-row table to 1-D; atleast_2d restores the row axis.
    connectivity_data = np.atleast_2d(connectivity_raw.squeeze())

    connectivity_df = pd.DataFrame(
        connectivity_data,
        columns=['Parent', 'Daughter1', 'Daughter2', 'Daughter3'],
    )
    # 0 marks a nonexistent vessel in the file, so turn it into NaN ...
    connectivity_df = connectivity_df.replace(0, np.nan)
    # ... but the first vessel really is ID 0, so put it back.
    connectivity_df.at[0, 'Parent'] = 0
    return connectivity_df

def nodesExtractor(name, pruned): #extracts nodes and their corresponding information
    """
    Load the node table ('nodesC2') of one network from its MATLAB file.

    Parameters
    ----------
    name : str
        Network identifier inserted into the file name.
    pruned : int
        0 reads 'Networks/Network_Vessels_<name>.mat',
        1 reads 'Pruned/Pruned_Network_<name>.mat'.

    Returns
    -------
    pandas.DataFrame
        Columns 'NodeID', 'X', 'Y', 'Z', 'Degree', one row per node.

    Raises
    ------
    ValueError
        If `pruned` is neither 0 nor 1.
    """
    if pruned == 0:
        file_path = 'Networks/Network_Vessels_' + name + '.mat'
    elif pruned == 1:
        file_path = 'Pruned/Pruned_Network_' + name + '.mat'
    else:
        # Previously this fell through and failed later on an unbound `file_path`.
        raise ValueError(f"pruned must be 0 or 1, got {pruned!r}")

    matlab_data = scipy.io.loadmat(file_path)
    # squeeze() collapses a one-row table to 1-D; atleast_2d restores the row axis.
    nodes_data = np.atleast_2d(matlab_data['nodesC2'].squeeze())
    return pd.DataFrame(nodes_data, columns=['NodeID', 'X', 'Y', 'Z', 'Degree'])

def edgesExtractor(name, pruned): #extracts segments to create a dataframe of from and to nodes
    """
    Load the segment table ('segments') of one network from its MATLAB file.

    Parameters
    ----------
    name : str
        Network identifier inserted into the file name.
    pruned : int
        0 reads 'Networks/Network_Vessels_<name>.mat',
        1 reads 'Pruned/Pruned_Network_<name>.mat'.

    Returns
    -------
    pandas.DataFrame
        Columns 'ID', 'From', 'To', one row per segment.

    Raises
    ------
    ValueError
        If `pruned` is neither 0 nor 1.
    """
    if pruned == 0:
        file_path = 'Networks/Network_Vessels_' + name + '.mat'
    elif pruned == 1:
        file_path = 'Pruned/Pruned_Network_' + name + '.mat'
    else:
        # Previously this fell through and failed later on an unbound `file_path`.
        raise ValueError(f"pruned must be 0 or 1, got {pruned!r}")

    matlab_data = scipy.io.loadmat(file_path)
    # squeeze() collapses a one-row table to 1-D; atleast_2d restores the row axis.
    edges_data = np.atleast_2d(matlab_data['segments'].squeeze())
    return pd.DataFrame(edges_data, columns=['ID', 'From', 'To'])
    
def findInputVessel(segments,fromnode,to):
    """
    Return the ID of the segment that joins two nodes, in either direction.

    Parameters
    ----------
    segments : pandas.DataFrame
        Table with 'ID', 'From' and 'To' columns.
    fromnode, to :
        The two end-node IDs; their order does not matter.

    Returns
    -------
    int
        'ID' of the first matching row.

    Raises
    ------
    IndexError
        If no segment connects the two nodes.
    """
    forward = (segments['From'] == fromnode) & (segments['To'] == to)
    backward = (segments['From'] == to) & (segments['To'] == fromnode)
    matches = segments.loc[forward | backward, 'ID']
    if matches.empty:
        # Same exception type as the bare .iloc[0] used to raise, but with the
        # node IDs in the message so a mistyped pair is easy to track down.
        raise IndexError(f"no segment joins nodes {fromnode} and {to}")
    return int(matches.iloc[0])

def mapIDExtractor(name, pruned):
    """
    Load the vessel-ID mapping table ('Data.mapIDs') of one network.

    Parameters
    ----------
    name : str
        Network identifier inserted into the file name.
    pruned : int
        0 reads 'Networks/Network_Vessels_<name>.mat',
        1 reads 'Pruned/Pruned_Network_<name>.mat'.

    Returns
    -------
    pandas.DataFrame
        Columns 'New' and 'Old', one row per mapped vessel.

    Raises
    ------
    ValueError
        If `pruned` is neither 0 nor 1.
    """
    if pruned == 0:
        file_path = 'Networks/Network_Vessels_' + name + '.mat'
    elif pruned == 1:
        file_path = 'Pruned/Pruned_Network_' + name + '.mat'
    else:
        # Previously this fell through and failed later on an unbound `file_path`.
        raise ValueError(f"pruned must be 0 or 1, got {pruned!r}")

    matlab_data = scipy.io.loadmat(file_path)
    # 'Data' is a MATLAB struct; [0, 0] unwraps the single struct element.
    map_raw = matlab_data['Data']['mapIDs'][0, 0]
    # squeeze() collapses a one-row table to 1-D; atleast_2d restores the row axis.
    map_data = np.atleast_2d(map_raw.squeeze())
    return pd.DataFrame(map_data, columns=['New', 'Old'])

def lobeExtractor(name, vesID,pruned):
    """
    Return the connectivity rows of every vessel downstream of `vesID`.

    The connectivity table is read with connectivityExtractor(name, pruned),
    turned into a parent -> daughters adjacency map, and walked breadth-first
    from `vesID`. The starting vessel itself is excluded from the result.

    Parameters
    ----------
    name : str
        Network identifier forwarded to connectivityExtractor.
    vesID :
        ID of the vessel whose subtree is wanted, in the numbering used by the
        'Parent' column.
    pruned : int
        Forwarded to connectivityExtractor (0 = full network, 1 = pruned).

    Returns
    -------
    pandas.DataFrame
        The rows of the connectivity table whose 'Parent' is downstream of `vesID`.
    """
    data = connectivityExtractor(name,pruned)

    # parent ID -> list of existing daughter IDs (NaN marks "no daughter")
    tree = defaultdict(list)
    family_columns = ['Parent', 'Daughter1', 'Daughter2', 'Daughter3']
    for parent, *daughters in data[family_columns].itertuples(index=False):
        tree[parent].extend(d for d in daughters if pd.notna(d))

    # Breadth-first walk; `deque` comes from the notebook's import cell.
    visited = set()
    queue = deque([vesID])
    while queue:
        current = queue.popleft()
        if current in visited:
            continue
        visited.add(current)
        # .get() avoids inserting empty entries into the defaultdict for leaves.
        queue.extend(tree.get(current, []))

    visited.discard(vesID)  # keep only vessels strictly downstream of the start
    return data[data['Parent'].isin(visited)]

def node_loc(name,lobe_nodes,pruned):
    """
    Look up the coordinates of a set of nodes.

    Reads the node table with nodesExtractor(name, pruned) and returns the
    'X', 'Y', 'Z' columns of the rows whose 'NodeID' appears in `lobe_nodes`.
    """
    all_nodes = nodesExtractor(name, pruned)
    in_lobe = all_nodes['NodeID'].isin(lobe_nodes)
    return all_nodes.loc[in_lobe, ['X', 'Y', 'Z']]

def lobeTermLoc(name,fromnode,tonode,pruned):
    """
    Return the coordinates of every node in the lobe fed by one input vessel.

    The input vessel is identified by its two end nodes. Its segment ID is
    translated to the connectivity numbering through the 'Old' -> 'New' map,
    the downstream subtree is collected with lobeExtractor, and the subtree's
    vessels are translated back to segment IDs to gather their end nodes.

    Parameters
    ----------
    name : str
        Network identifier forwarded to the extractor functions.
    fromnode, tonode :
        End-node IDs of the lobe's input vessel (order does not matter).
    pruned : int
        0 = full network files, 1 = pruned network files.

    Returns
    -------
    pandas.DataFrame
        'X', 'Y', 'Z' of the lobe's nodes, divided by 1000
        (presumably a unit conversion -- TODO confirm the units).
    """
    segments = edgesExtractor(name, pruned)
    maps = mapIDExtractor(name, pruned)

    # findInputVessel takes exactly (segments, fromnode, to); passing `pruned`
    # as a fourth argument raised TypeError.
    vesID = findInputVessel(segments, fromnode, tonode)
    newID = int(maps.loc[maps['Old'] == vesID, 'New'].iloc[0])

    # lobeExtractor and node_loc both require `pruned`; it was missing before.
    lobe_ves = lobeExtractor(name, newID, pruned)
    new_lobe_ves_ID = lobe_ves['Parent'].to_numpy()
    oldID = maps.loc[maps['New'].isin(new_lobe_ves_ID), 'Old'].to_numpy()

    # Filter the segment table once and read both end-node columns from it.
    lobe_segments = segments[segments['ID'].isin(oldID)]
    lobe_nodes = np.unique(
        np.concatenate((lobe_segments['From'].to_numpy(), lobe_segments['To'].to_numpy()))
    ).astype(int)

    lobe_node_loc = node_loc(name, lobe_nodes, pruned) / 1000
    return lobe_node_loc

In [None]:
#TDA functions
def compute_persistence(points):
    """
    Compute the alpha-complex persistence diagram of a point cloud.

    Parameters
    ----------
    points : array-like
        Point coordinates handed to gudhi.AlphaComplex.

    Returns
    -------
    diag : list of (dim, (birth, death))
        Every persistence pair returned by GUDHI. Infinite bars are kept in
        the diagram; a warning line is printed for each of them.
    last_deaths : dict {dim: float}
        Largest finite death value per dimension. A dimension with no finite
        death has no entry.

    Notes
    -----
    Per the GUDHI documentation, AlphaComplex filtration values are squared
    radii by default -- confirm for the installed version before interpreting
    birth/death values as lengths.
    """
    simplex_tree = gudhi.AlphaComplex(points=points).create_simplex_tree()
    # persistence() computes the pairs itself, so a separate
    # compute_persistence() call beforehand is not needed.
    diag = list(simplex_tree.persistence())

    # Diagnostic pass first, so all warnings are printed before any bookkeeping.
    for dim, (birth, death) in diag:
        if np.isinf(death):
            print(f"WARNING: Infinite bar in H{dim}: birth={birth}, death=∞")

    last_deaths = {}
    for dim, (birth, death) in diag:
        if np.isfinite(death):
            last_deaths[dim] = max(last_deaths.get(dim, death), death)

    return diag, last_deaths

def average_curves(curve_dicts, resolution=100000):
    """
    Average a collection of per-dimension curves on a shared grid.

    curve_dicts: list of dicts {dim: (grid, curve)} (Betti, lifespan, ...).
    resolution: number of samples of the shared grid.

    For each dimension found in any dict, the curves with a non-empty grid are
    linearly interpolated onto a grid spanning the smallest first grid value
    to the largest last grid value (0 outside a curve's own range) and
    averaged point-wise. A dimension with no usable curve maps to
    (empty array, zeros(resolution)).

    Returns: dict {dim: (common_grid, avg_curve)}
    """
    every_dim = set()
    for entry in curve_dicts:
        every_dim.update(entry.keys())

    averaged = {}
    for dim in every_dim:
        # Only datasets that have this dimension with a non-empty grid take part.
        usable = [
            entry[dim]
            for entry in curve_dicts
            if dim in entry and len(entry[dim][0]) != 0
        ]

        if len(usable) == 0:
            averaged[dim] = (np.array([]), np.zeros(resolution))
            continue

        lowest = min(grid[0] for grid, _ in usable)
        highest = max(grid[-1] for grid, _ in usable)
        common_grid = np.linspace(lowest, highest, resolution)

        resampled = [
            interp1d(grid, curve, bounds_error=False, fill_value=0)(common_grid)
            for grid, curve in usable
        ]
        averaged[dim] = (common_grid, np.mean(resampled, axis=0))

    return averaged

def plot_curves(curves_dicts, labels, dimension, title='Betti Curves Comparison', limit = [0,100]):
    """
    Show one figure overlaying the curves of a single dimension.

    curves_dicts: list of curve dicts {dim: (grid, curve)}
    labels: one legend label per entry of curves_dicts
    dimension: the dimension (int) to plot
    title: figure title prefix
    limit: x-axis limits passed to plt.xlim

    A dict that lacks `dimension` prints a warning; a dict whose grid is empty
    is skipped silently.
    """
    palette = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'cyan', 'magenta']
    dashes = ['-', '--', '-.', ':']

    plt.figure(figsize=(10, 6))

    for position, pair in enumerate(zip(curves_dicts, labels)):
        curve_dict, label = pair
        if dimension not in curve_dict:
            print(f"Warning: {label} does not contain Betti-{dimension}")
            continue
        grid, curve = curve_dict[dimension]
        if len(grid) == 0:
            continue
        # Colour and line style cycle independently with the curve's position.
        plt.plot(
            grid,
            curve,
            color=palette[position % len(palette)],
            linestyle=dashes[position % len(dashes)],
            label=f'{label} (Dimension-{dimension})',
        )

    plt.title(f'{title} (Betti-{dimension})', fontsize=28)
    plt.xlabel('Filtration Value', fontsize=20)
    plt.ylabel(f'Betti-{dimension}', fontsize=20)
    plt.xlim(limit)
    plt.legend(fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.grid(False)
    plt.tight_layout()
    plt.show()

def save_curves(curves_dicts, labels, dimension, filename, title='Betti Curves Comparison', limit=[0, 100], dpi=300):
    """
    Draw the curves of a single dimension and write the figure to disk.

    curves_dicts: list of curve dicts {dim: (grid, curve)}
    labels: one legend label per entry of curves_dicts
    dimension: the dimension (int) to plot
    filename: path handed to plt.savefig (e.g., 'output.png' or 'output.jpg')
    title: figure title prefix
    limit: x-axis limits passed to plt.xlim
    dpi: resolution of the saved image (default 300)

    A dict that lacks `dimension` prints a warning; a dict whose grid is empty
    is skipped silently. The figure is closed after saving.
    """
    palette = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'cyan', 'magenta']
    dashes = ['-', '--', '-.', ':']

    plt.figure(figsize=(10, 6))

    for position, pair in enumerate(zip(curves_dicts, labels)):
        curve_dict, label = pair
        if dimension not in curve_dict:
            print(f"Warning: {label} does not contain Betti-{dimension}")
            continue
        grid, curve = curve_dict[dimension]
        if len(grid) == 0:
            continue
        # Colour and line style cycle independently with the curve's position.
        plt.plot(
            grid,
            curve,
            color=palette[position % len(palette)],
            linestyle=dashes[position % len(dashes)],
            label=f'{label} (Dimension-{dimension})',
        )

    plt.title(f'{title} (Betti-{dimension})', fontsize=28)
    plt.xlabel('Filtration Value', fontsize=20)
    plt.ylabel(f'Betti-{dimension}', fontsize=20)
    plt.xlim(limit)
    plt.legend(fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.grid(False)
    plt.tight_layout()

    plt.savefig(filename, dpi=dpi, bbox_inches='tight')
    plt.close()  # release the figure so repeated saves do not accumulate memory

    print(f"Saved plot to {filename}")

In [None]:
# Run configuration: selects which node-pair table the cells below assign.
# Both values are strings; the tables compare `pressure` against '1'..'4' and
# `lobe` against 'left', 'superior', 'inferior' and 'postcaval'.
pressure = '1'
lobe = 'left'

middle
p1
m1053007 = [1841, 1945, 1841, 1964]
m2053007 = [2463, 2497, 2502, 2501]
m1053107 = [4179, 4304, 4018, 4379]
m2053107 = [2920, 2921, 3150, 3149]
m1060107 = [2766, 2744, 2992, 1332]
m1060407 = [2456, 2459, 2455, 1061]
m2060407 = [2963, 2568, 2964, 2995]
m3060407 = [2509, 2613, 2507, 408]
m1060507 = [1992, 1929, 2030, 2065]
m2060507 = [2390, 2227, 2284, 2283]
m3060507 = [2624, 2377, 2568, 1365]
m2060607 = [1916, 2052, 2057, 1991]
m3060607 = [2604, 2466, 2677, 2607]
p2
m1053007 = [1841, 1957, 1841, 2025]
m2053007 = [1831, 1854, 1908, 1919]
m1053107 = [2738, 2739, 2959, 1567]
m2053107 = [1645, 1644, 1642, 1641]
m1060107 = [2164, 2098, 2153, 927]
m1060407 = [1289, 1269, 1253, 1252]
m2060407 = [1464, 1466, 1465, 592]
m3060407 = [1907, 1962, 1939, 337]
m1060507 = [1322, 1400, 1320, 1425]
m2060507 = [1565, 1566, 1578, 1577]
m3060507 = [1732, 1589, 1749, 1109]
m2060607 = [1441, 1221, 1460, 221]
m3060607 = [1925, 1946, 2002, 2094]
p3
m1053007 = [1721, 1720, 1721, 95]
m2053007 = [1731, 1737, 1682, 1733]
m1053107 = [3103, 2859, 2892, 844]
m2053107 = [1621, 1622, 1714, 1709]
m1060107 = [1651, 1727, 1611, 1610]
m1060407 = [1110, 1109, 1140, 426]
m2060407 = [854, 283, 891, 50]
m3060407 = [968, 938, 967, 976]
m1060507 = [1470, 1475, 1469, 1370]
m2060507 = [1041, 1121, 1042, 1163]
m3060507 = [1000, 451, 1062, 1016]
m2060607 = [607, 595, 616, 640]
m3060607 = [10, 11]
p4
m1053007 = [10, 11]
m2053007 = [10, 11]
m1053107 = [10, 11]
m2053107 = [10, 11]
m1060107 = [10, 11]
m1060407 = [10, 11]
m2060407 = [10, 11]
m3060407 = [10, 11]
m1060507 = [10, 11]
m2060507 = [10, 11]
m3060507 = [10, 11]
m2060607 = [10, 11]
m3060607 = [10, 11]

In [None]:
# Input-vessel end nodes per mouse for pressure '1', one table per lobe.
# Each entry is a two-element list that the `datasets` cell indexes as [0] and [1].
# NOTE(review): the lobe checks are independent `if`s with no `else`; an
# unrecognised `lobe` assigns nothing and leaves any earlier m* values in place.
if pressure == '1':
    if lobe == 'left':
        m1053007 = [1831, 1858]
        m2053007 = [2367, 2368]
        m1053107 = [3954, 2900]
        m2053107 = [2866, 2868]
        m1060107 = [190, 2722]
        m1060407 = [2272, 2273]
        m2060407 = [2776, 2774]
        m3060407 = [2473, 2472]
        m1060507 = [2121, 2040]
        m2060507 = [2257, 2258]
        m3060507 = [692, 2524]
        m2060607 = [1475, 1997]
        m3060607 = [53, 2576]
    if lobe == 'superior':
        m1053007 = [1836, 1835]
        m2053007 = [2464, 2406]
        m1053107 = [4071, 685]
        m2053107 = [2867, 2979]
        # NOTE(review): the comma makes this a tuple of two node pairs, so [0] and [1]
        # are lists rather than node IDs; the author flagged it as not running.
        m1060107 = [2780, 2716], [2780, 2781] #won't run
        m1060407 = [2274, 2283]
        m2060407 = [2742, 2598]
        # NOTE(review): same two-pair tuple form as m1060107 above.
        m3060407 = [2418, 2419], [2420, 2491]
        m1060507 = [1993, 1997]
        m2060507 = [2259, 2330]
        m3060507 = [2392, 2398]
        m2060607 = [1914, 1915]
        m3060607 = [2603, 2518]
    if lobe == 'inferior':
        m1053007 = [1841, 1839]
        m2053007 = [2502, 2501]
        m1053107 = [4018, 4019]
        m2053107 = [3150, 3170]
        m1060107 = [2992, 2895]
        m1060407 = [2455, 2394]
        m2060407 = [2964, 2727]
        m3060407 = [2507, 2508]
        m1060507 = [2030, 2027]
        m2060507 = [2391, 2423]
        m3060507 = [2537, 1693]
        m2060607 = [2057, 2056]
        m3060607 = [2677, 2676]
    if lobe == 'postcaval':
        m1053007 = [1864, 35]
        m2053007 = [2465, 2558]
        m1053107 = [4296, 1851]
        m2053107 = [2915, 2913]
        m1060107 = [2766, 2808]
        m1060407 = [2454, 2602]
        m2060407 = [2886, 2962]
        m3060407 = [2421, 2469]
        m1060507 = [1994, 2178]
        m2060507 = [2284, 2454]
        m3060507 = [2628, 2355]
        m2060607 = [1916, 1970]
        m3060607 = [2701, 2688]

In [None]:
# Input-vessel end nodes per mouse for pressure '2', one table per lobe.
# Each entry is a two-element list that the `datasets` cell indexes as [0] and [1].
# NOTE(review): the lobe checks are independent `if`s with no `else`; an
# unrecognised `lobe` assigns nothing and leaves any earlier m* values in place.
if pressure == '2':
    if lobe == 'left':
        m1053007 = [1075, 1960]
        m2053007 = [217, 1868]
        m1053107 = [1506, 2897]
        m2053107 = [44, 1616]
        m1060107 = [2089, 2169]
        m1060407 = [1225, 2265]
        m2060407 = [1371, 1392]
        m3060407 = [1882, 1924]
        m1060507 = [118, 1381]
        m2060507 = [1509, 1519]
        m3060507 = [51, 1591]
        m2060607 = [1495, 1419]
        m3060607 = [535, 1916]
    if lobe == 'superior':
        m1053007 = [1829, 1828]
        m2053007 = [1829, 1830]
        m1053107 = [2909, 2895]
        m2053107 = [1640, 1693]
        m1060107 = [2162, 2163]
        m1060407 = [1266, 1239]
        m2060407 = [1387, 1405]
        # NOTE(review): the comma makes this a tuple of two node pairs, so [0] and [1]
        # are lists rather than node IDs; the author flagged it as not running.
        m3060407 = [1883, 1890], [1880, 1881] #won't run
        m1060507 = [1468, 1454]
        m2060507 = [1528, 1552]
        m3060507 = [1706, 1705]
        m2060607 = [1440, 1426]
        m3060607 = [1923, 1924]
    if lobe == 'inferior':
        m1053007 = [1842, 2020]
        m2053007 = [1908, 1979]
        m1053107 = [2959, 2957]
        m2053107 = [1642, 1755]
        m1060107 = [2153, 2151]
        m1060407 = [1253, 1252]
        m2060407 = [1465, 1515]
        m3060407 = [1939, 1982]
        m1060507 = [1320, 1424]
        m2060507 = [1620, 1602]
        m3060507 = [1749, 1809]
        m2060607 = [1460, 1375]
        m3060607 = [2002, 1903]
    if lobe == 'postcaval':
        m1053007 = [1862, 2003]
        m2053007 = [1907, 1909]
        m1053107 = [2056, 1829]
        m2053107 = [1724, 1722]
        m1060107 = [2164, 2190]
        m1060407 = [1290, 1323]
        m2060407 = [1386, 1410]
        m3060407 = [1907, 1796]
        m1060507 = [1319, 1321]
        m2060507 = [1578, 1650]
        m3060507 = [1739, 1689]
        m2060607 = [1441, 1509]
        m3060607 = [1951, 1950]

In [None]:
# Input-vessel end nodes per mouse for pressure '3', one table per lobe.
# Each entry is a two-element list that the `datasets` cell indexes as [0] and [1].
# NOTE(review): several entries are [10, 11] -- presumably placeholders for
# node pairs that have not been identified yet; TODO confirm.
# NOTE(review): the lobe checks are independent `if`s with no `else`; an
# unrecognised `lobe` assigns nothing and leaves any earlier m* values in place.
if pressure == '3':
    if lobe == 'left':
        m1053007 = [1630, 1628]
        m2053007 = [1619, 1621]
        m1053107 = [246, 2864]
        m2053107 = [717, 1751]
        m1060107 = [318, 1645]
        m1060407 = [1122, 1121]
        m2060407 = [808, 814]
        m3060407 = [919, 921]
        m1060507 = [80, 1354]
        m2060507 = [1097, 1116]
        m3060507 = [191, 959]
        m2060607 = [653, 620]
        m3060607 = [10, 11]
    if lobe == 'superior':
        m1053007 = [1657, 1656]
        m2053007 = [1620, 1703]
        m1053107 = [2982, 2955]
        m2053107 = [1612, 1610]
        m1060107 = [1644, 1650]
        m1060407 = [1126, 1095]
        m2060407 = [862, 832]
        # NOTE(review): the comma makes this a tuple of two node pairs, so [0] and [1]
        # are lists rather than node IDs; the author flagged it as not running.
        m3060407 = [969, 970], [920, 924] # won't run
        m1060507 = [1405, 1461]
        m2060507 = [1096, 1095]
        m3060507 = [952, 950]
        m2060607 = [615, 642]
        m3060607 = [10, 11]
    if lobe == 'inferior':
        m1053007 = [1721, 1624]
        m2053007 = [1682, 1683]
        m1053107 = [2892, 1313]
        m2053107 = [1714, 1715]
        m1060107 = [1611, 1757]
        m1060407 = [1140, 1157]
        m2060407 = [891, 844]
        m3060407 = [967, 975]
        m1060507 = [1469, 1488]
        m2060507 = [1040, 1043]
        m3060507 = [1062, 106]
        m2060607 = [616, 637]
        m3060607 = [10, 11]
    if lobe == 'postcaval':
        m1053007 = [1755, 1636]
        m2053007 = [1704, 1732]
        m1053107 = [3039, 1474]
        m2053107 = [1655, 1653]
        m1060107 = [1651, 29]
        m1060407 = [1134, 1141]
        m2060407 = [851, 853]
        m3060407 = [968, 1020]
        m1060507 = [1381, 1379]
        m2060507 = [1042, 1136]
        m3060507 = [10, 11]
        m2060607 = [607, 664]
        m3060607 = [10, 11]

In [None]:
# Input-vessel end nodes per mouse for pressure '4', one table per lobe.
# NOTE(review): every entry here is [10, 11] -- presumably placeholders for
# node pairs that have not been identified yet; TODO confirm before running
# the analysis with pressure '4'.
# NOTE(review): the lobe checks are independent `if`s with no `else`; an
# unrecognised `lobe` assigns nothing and leaves any earlier m* values in place.
if pressure == '4':
    if lobe == 'left':
        m1053007 = [10, 11]
        m2053007 = [10, 11]
        m1053107 = [10, 11]
        m2053107 = [10, 11]
        m1060107 = [10, 11]
        m1060407 = [10, 11]
        m2060407 = [10, 11]
        m3060407 = [10, 11]
        m1060507 = [10, 11]
        m2060507 = [10, 11]
        m3060507 = [10, 11]
        m2060607 = [10, 11]
        m3060607 = [10, 11]
    if lobe == 'superior':
        m1053007 = [10, 11]
        m2053007 = [10, 11]
        m1053107 = [10, 11]
        m2053107 = [10, 11]
        m1060107 = [10, 11]
        m1060407 = [10, 11]
        m2060407 = [10, 11]
        m3060407 = [10, 11]
        m1060507 = [10, 11]
        m2060507 = [10, 11]
        m3060507 = [10, 11]
        m2060607 = [10, 11]
        m3060607 = [10, 11]
    if lobe == 'inferior':
        m1053007 = [10, 11]
        m2053007 = [10, 11]
        m1053107 = [10, 11]
        m2053107 = [10, 11]
        m1060107 = [10, 11]
        m1060407 = [10, 11]
        m2060407 = [10, 11]
        m3060407 = [10, 11]
        m1060507 = [10, 11]
        m2060507 = [10, 11]
        m3060507 = [10, 11]
        m2060607 = [10, 11]
        m3060607 = [10, 11]
    if lobe == 'postcaval':
        m1053007 = [10, 11]
        m2053007 = [10, 11]
        m1053107 = [10, 11]
        m2053107 = [10, 11]
        m1060107 = [10, 11]
        m1060407 = [10, 11]
        m2060407 = [10, 11]
        m3060407 = [10, 11]
        m1060507 = [10, 11]
        m2060507 = [10, 11]
        m3060507 = [10, 11]
        m2060607 = [10, 11]
        m3060607 = [10, 11]

if middle

elif superior && pressure == 1

else

In [None]:
# Per-mouse inputs: dataset key -> (input-vessel node pair, pruned flag).
# The key is the mouse tag plus the date code ('m1' + '053007'); the network
# name handed to lobeTermLoc inserts 'p<pressure>_' between the two parts.
# Building every call from one table keeps each mouse paired with its own
# node list (two rows previously mixed in another mouse's second node).
# Insertion order matters: later cells split the keys into the first five
# and the last eight.
mouse_inputs = {
    'm1053007': (m1053007, 1),
    'm2053007': (m2053007, 1),
    'm1053107': (m1053107, 1),
    'm2053107': (m2053107, 1),
    'm1060107': (m1060107, 1),
    'm1060407': (m1060407, 0),
    'm2060407': (m2060407, 0),
    'm3060407': (m3060407, 0),
    'm1060507': (m1060507, 0),
    'm2060507': (m2060507, 0),
    'm3060507': (m3060507, 0),
    'm2060607': (m2060607, 0),
    'm3060607': (m3060607, 0),
}

datasets = {
    key: lobeTermLoc(key[:2] + 'p' + pressure + '_' + key[2:], node_pair[0], node_pair[1], pruned_flag)
    for key, (node_pair, pruned_flag) in mouse_inputs.items()
}

# Persistence diagram and per-dimension last finite death for every dataset
persistence_results = {
    name: compute_persistence(points)
    for name, points in datasets.items()
}

# Separate results
persistence_diagrams = {name: res[0] for name, res in persistence_results.items()}
last_deaths = {name: res[1] for name, res in persistence_results.items()}

In [None]:
# Largest finite death per homology dimension, pooled over every dataset
global_last_deaths = {}
for deaths_by_dim in last_deaths.values():
    for dim, death in deaths_by_dim.items():
        # First sighting of a dimension stores its death; later ones keep the max.
        global_last_deaths[dim] = max(global_last_deaths.get(dim, death), death)

print("Global max deaths by dimension:")
for dim in sorted(global_last_deaths):
    print(f"  H{dim}: {global_last_deaths[dim]:.4f}")

# Betti Curve

In [None]:
def compute_betti_curve(diag, max_dim=2, resolution=100000, cutoff=None):
    """
    Sample the Betti curve of every dimension on one shared filtration grid.

    Parameters
    ----------
    diag : list of (dim, (birth, death))
        Persistence diagram; deaths may be infinite.
    max_dim : int
        Curves are produced for dimensions 0..max_dim.
    resolution : int
        Number of grid samples.
    cutoff : float or None
        Upper end of the grid. When None, the largest finite birth/death
        value in `diag` is used (1 if `diag` is empty).

    Returns
    -------
    dict {dim: (grid, curve)}
        `curve[i]` is the number of bars of that dimension with
        birth <= grid[i] <= death. A dimension without bars gets the same
        grid with an all-zero curve, so every entry can be interpolated and
        averaged like any other.
    """
    # Filtration range over ALL dimensions, ignoring infinite deaths
    finite_values = []
    for _, (birth, death) in diag:
        finite_values.append(birth)
        if np.isfinite(death):
            finite_values.append(death)

    global_min = np.min(finite_values) if finite_values else 0
    global_max = np.max(finite_values) if finite_values else 1
    if cutoff is not None:
        global_max = cutoff

    betti_curves = {}
    for dim in range(max_dim + 1):
        grid = np.linspace(global_min, global_max, resolution)
        bars = np.array([pair for d, pair in diag if d == dim], dtype=float).reshape(-1, 2)

        if len(bars) == 0:
            betti_curves[dim] = (grid, np.zeros_like(grid))
            continue

        # Bars alive at t = (#births <= t) - (#deaths < t). An infinite death
        # sorts last and is never below t, so such a bar stays alive from its
        # birth onward. Assumes birth <= death for every bar.
        births = np.sort(bars[:, 0])
        deaths = np.sort(bars[:, 1])
        born = np.searchsorted(births, grid, side='right')
        died = np.searchsorted(deaths, grid, side='left')
        betti_curves[dim] = (grid, (born - died).astype(float))

    return betti_curves

In [None]:
# One cutoff for every dataset and dimension: the largest finite death overall.
# (This used to read `global_last_deaths[dim]`, where `dim` was whatever value
# an earlier loop had left behind, so the cutoff depended on cell run order.)
betti_cutoff = max(global_last_deaths.values())

betti_curves = {}
for name, diag in persistence_diagrams.items():
    betti_curves[name] = compute_betti_curve(diag, cutoff=betti_cutoff)

In [None]:
# Dataset names in insertion order; the first five form the hyper group and
# the last eight the control group
keys = [*betti_curves]

hyper_betti_curves = [betti_curves[key] for key in keys[:5]]
control_betti_curves = [betti_curves[key] for key in keys[-8:]]

average_betti_control = average_curves(control_betti_curves)
average_betti_hyper = average_curves(hyper_betti_curves)

In [None]:
#Plots Betti curves for one mouse
# Disabled cell: the code below is wrapped in a string literal, so it never runs.
# NOTE(review): hyper_betti_curves[4] is the fifth dataset ('m1060107' in the
# `datasets` insertion order), but the label says "m1053007" -- fix before re-enabling.
"""for dim in np.arange(3):    
    plot_curves(
        [hyper_betti_curves[4]],
        labels=["m1053007"], dimension = dim,
        title="Betti Curve", limit = [0,global_last_deaths[dim]]
    )"""

In [None]:
#Plots Average Betti Curves
# Disabled cell: the code below is wrapped in a string literal, so it never runs.
"""for dim in np.arange(3):    
    plot_curves(
        [average_betti_hyper, average_betti_control],
        labels=["Hyper", "Control"], dimension = dim,
        title="Comparison of Average Betti Curves", limit = [0,global_last_deaths[dim]]
    )"""

# Persistence Landscapes

In [None]:
def compute_landscape_curves(
    persistence_diagram,
    t_min,
    t_max,
    max_dim=2,
    k=5,
    resolution=500
):
    """
    Build the first k persistence landscapes of every dimension 0..max_dim.

    Only bars with a finite death are used. Each dimension maps to
    (t_vals, landscapes): t_vals holds `resolution` samples from t_min to
    t_max and landscapes has shape (k, resolution). A dimension without
    finite bars gets an all-zero landscape array.
    """
    curves_by_dim = {}

    for dim in range(max_dim + 1):
        sample_points = np.linspace(t_min, t_max, resolution)

        # Finite bars of this dimension only
        finite_bars = np.array([
            (b, d)
            for h, (b, d) in persistence_diagram
            if h == dim and np.isfinite(d)
        ])

        if finite_bars.shape[0] > 0:
            transformer = Landscape(
                num_landscapes=k,
                resolution=resolution,
                sample_range=(t_min, t_max)
            )
            # fit_transform expects a list of diagrams; take the single result
            # and unflatten it into one row per landscape.
            flat = transformer.fit_transform([finite_bars])[0]
            layers = flat.reshape(k, resolution)
        else:
            layers = np.zeros((k, resolution))

        curves_by_dim[dim] = (sample_points, layers)

    return curves_by_dim


def save_landscape(t_vals, landscapes, dim, filename, title=None, limit=[0, 100], dpi=300):
    """
    Plot every landscape layer against t_vals and write the figure to disk.

    landscapes: iterable of layers (e.g. a (k, resolution) array), one line each
    dim: homology dimension, used only in the default title
    filename: path handed to plt.savefig
    title: overrides the default "Persistence Landscape (H_<dim>)" title
    limit: x-axis limits passed to plt.xlim
    dpi: resolution of the saved image
    """
    heading = title or f"Persistence Landscape (H_{dim})"

    plt.figure(figsize=(10, 6))
    for i, layer in enumerate(landscapes):
        # Layers are labelled lambda_1, lambda_2, ... in the legend
        plt.plot(t_vals, layer, label=f"$\\lambda_{{{i+1}}}$")

    plt.xlabel("Filtration Value", fontsize=20)
    plt.ylabel("Landscape Value", fontsize=20)
    plt.title(heading, fontsize=28)
    plt.xlim(limit)
    plt.legend(fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.grid(False)
    plt.tight_layout()

    plt.savefig(filename, dpi=dpi, bbox_inches='tight')
    plt.close()

    print(f"Saved landscape plot to {filename}")

def plot_landscape(t_vals, landscapes, dim, title=None):
    """
    Show every landscape layer against t_vals in one figure.

    dim is used only in the default title, "Persistence Landscape (H_<dim>)",
    which `title` overrides when given.
    """
    heading = title or f"Persistence Landscape (H_{dim})"

    plt.figure(figsize=(8, 5))
    for i, layer in enumerate(landscapes):
        # Layers are labelled lambda_1, lambda_2, ... in the legend
        plt.plot(t_vals, layer, label=f"$\\lambda_{{{i+1}}}$")

    plt.xlabel("t")
    plt.ylabel("Landscape value")
    plt.title(heading)
    plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Persistence landscapes of every dataset, all sampled on [0, largest finite death]
landscape_settings = dict(t_min=0, max_dim=2, k=5, resolution=500)

landscape_curves = {}
for name, diag in persistence_diagrams.items():
    landscape_t_max = max(global_last_deaths.values())
    landscape_curves[name] = compute_landscape_curves(
        diag, t_max=landscape_t_max, **landscape_settings
    )

# Lifespan Curves

In [None]:
def extract_intervals(persistence_diagram, dim):
    """
    Collect the finite (birth, death) bars of one dimension.

    Returns an array with one row per bar; when nothing matches, the result
    is an empty 1-D array (size 0).
    """
    finite_bars = []
    for bar_dim, (birth, death) in persistence_diagram:
        if bar_dim == dim and np.isfinite(death):
            finite_bars.append((birth, death))
    return np.array(finite_bars)

def compute_lifespan_curves(
    persistence_diagram,
    t_min,
    t_max,
    max_dim=2,
    resolution=500
):
    """
    Sample the lifespan curve of every dimension 0..max_dim.

    At each grid value t the curve is the summed lifespan (death - birth) of
    the finite bars alive at t, i.e. with birth <= t < death.

    Parameters
    ----------
    persistence_diagram : list of (dim, (birth, death))
    t_min, t_max : float
        Ends of the sampling grid, shared by all dimensions.
    max_dim : int
    resolution : int
        Number of grid samples.

    Returns
    -------
    dict {dim: (t_vals, curve)}
        A dimension without finite bars gets the same grid with an all-zero
        curve, so it can be interpolated and averaged like any other.
    """
    lifespan_curves = {}

    for dim in range(max_dim + 1):
        intervals = extract_intervals(persistence_diagram, dim)
        t_vals = np.linspace(t_min, t_max, resolution)

        if intervals.size == 0:
            # Keep the real grid (not zeros) so downstream interpolation is well defined.
            lifespan_curves[dim] = (t_vals, np.zeros_like(t_vals))
            continue

        births = intervals[:, 0]
        deaths = intervals[:, 1]
        lifespans = deaths - births

        LC = np.zeros_like(t_vals)
        for i, t in enumerate(t_vals):
            alive = (births <= t) & (t < deaths)
            LC[i] = lifespans[alive].sum()

        lifespan_curves[dim] = (t_vals, LC)

    return lifespan_curves

In [None]:
# One grid end for every dataset and dimension: the largest finite death overall.
# (This used to read `global_last_deaths[dim]`, where `dim` was whatever value
# an earlier loop had left behind, so the range depended on cell run order.)
lifespan_t_max = max(global_last_deaths.values())

lifespan_curves = {}
for name, diag in persistence_diagrams.items():
    lifespan_curves[name] = compute_lifespan_curves(diag, t_min=0, t_max=lifespan_t_max, resolution=100000)

In [None]:
# Dataset names in insertion order; the first five form the hyper group and
# the last eight the control group
keys = [*lifespan_curves]

hyper_lifespan_curves = [lifespan_curves[key] for key in keys[:5]]
control_lifespan_curves = [lifespan_curves[key] for key in keys[-8:]]

average_lifespan_control = average_curves(control_lifespan_curves)
average_lifespan_hyper = average_curves(hyper_lifespan_curves)

In [None]:
# Plots lifespan curves for one mouse
# Disabled cell: the code below is wrapped in a string literal, so it never runs.
# NOTE(review): hyper_lifespan_curves[4] is the fifth dataset ('m1060107' in the
# `datasets` insertion order), but the label says "m1053007" -- fix before re-enabling.
"""for dim in np.arange(3):    
    plot_curves(
        [hyper_lifespan_curves[4]],
        labels=["m1053007"], dimension = dim,
        title="Lifespan Curve", limit = [0,global_last_deaths[dim]]
    )"""

In [None]:
#Plots average lifespan curves
# Disabled cell: the code below is wrapped in a string literal, so it never runs.
"""for dim in np.arange(3):    
    plot_curves(
        [average_lifespan_hyper, average_lifespan_control],
        labels=["Hyper", "Control"], dimension = dim,
        title="Comparison of Average Lifespan Curves", limit = [0,global_last_deaths[dim]]
    )"""

# Norm Lifespan Curves

In [None]:
def compute_norm_lifespan_curves(
    persistence_diagram,
    t_min,
    t_max,
    max_dim=2,
    resolution=500
):
    """
    Sample the normalised lifespan curve of every dimension 0..max_dim.

    At each grid value t the curve is the summed lifespan of the finite bars
    alive at t (birth <= t < death) divided by the total lifespan of all
    finite bars of that dimension. If the total is not positive the curve is
    all zeros.

    Parameters
    ----------
    persistence_diagram : list of (dim, (birth, death))
    t_min, t_max : float
        Ends of the sampling grid, shared by all dimensions.
    max_dim : int
    resolution : int
        Number of grid samples.

    Returns
    -------
    dict {dim: (t_vals, curve)}
        A dimension without finite bars gets the same grid with an all-zero
        curve, so it can be interpolated and averaged like any other.
    """
    lifespan_curves = {}

    for dim in range(max_dim + 1):
        intervals = extract_intervals(persistence_diagram, dim)
        t_vals = np.linspace(t_min, t_max, resolution)

        if intervals.size == 0:
            # Keep the real grid (not zeros) so downstream interpolation is well defined.
            lifespan_curves[dim] = (t_vals, np.zeros_like(t_vals))
            continue

        births = intervals[:, 0]
        deaths = intervals[:, 1]
        lifespans = deaths - births

        # Normalisation constant; checked once instead of at every grid point
        total_lifespan = lifespans.sum()

        LC = np.zeros_like(t_vals)
        if total_lifespan > 0:
            for i, t in enumerate(t_vals):
                alive = (births <= t) & (t < deaths)
                LC[i] = lifespans[alive].sum() / total_lifespan

        lifespan_curves[dim] = (t_vals, LC)

    return lifespan_curves

In [None]:
# One grid end for every dataset and dimension: the largest finite death overall.
# (This used to read `global_last_deaths[dim]`, where `dim` was whatever value
# an earlier loop had left behind, so the range depended on cell run order.)
norm_lifespan_t_max = max(global_last_deaths.values())

norm_lifespan_curves = {}
for name, diag in persistence_diagrams.items():
    norm_lifespan_curves[name] = compute_norm_lifespan_curves(diag, t_min=0, t_max=norm_lifespan_t_max, resolution=100000)

In [None]:
# Get keys in insertion order
keys = list(norm_lifespan_curves.keys())

hyper_norm_lifespan_curves = [norm_lifespan_curves[k] for k in keys[:5]]
control_norm_lifespan_curves  = [norm_lifespan_curves[k] for k in keys[-8:]]

# Average the *normalised* curves. These two lines previously averaged the
# un-normalised lifespan lists, so the "norm" averages were copies of the
# plain lifespan averages.
average_norm_lifespan_control = average_curves(control_norm_lifespan_curves)
average_norm_lifespan_hyper = average_curves(hyper_norm_lifespan_curves)

In [None]:
#Plot the normalised lifespan curves for one mouse
# Disabled cell: the code below is wrapped in a string literal, so it never runs.
"""for dim in np.arange(3):    
    plot_curves(
        [hyper_norm_lifespan_curves[4]],
        labels=[keys[4]], dimension = dim,
        title="Norm Lifespan Curve", limit = [0,global_last_deaths[dim]]
    )"""

In [None]:
#Plot average lifespan curves across groups in all three dimensions
# Disabled cell: the code below is wrapped in a string literal, so it never runs.
# NOTE(review): the title says "Norm Lifespan" but the curves passed in are
# average_lifespan_hyper / average_lifespan_control (the un-normalised
# averages) -- switch to the average_norm_lifespan_* variables before re-enabling.
"""for dim in np.arange(3):    
    plot_curves(
        [average_lifespan_hyper, average_lifespan_control],
        labels=["Hyper", "Control"], dimension = dim,
        title="Comparison of Average Norm Lifespan Curves", limit = [0,global_last_deaths[dim]]
    )"""

# Save Figures

In [None]:
# Save the Hyper-vs-Control average figure of every curve family
# (Betti, Lifespan, Norm_Lifespan) for dimensions 0, 1 and 2.
for dim in np.arange(3):
    # (file-name prefix, [hyper average, control average], figure title)
    for curve_prefix, curve_pair, curve_title in (
        ('Betti',
         [average_betti_hyper, average_betti_control],
         f"Comparison of Average Betti Curves Lobe {lobe}"),
        ('Lifespan',
         [average_lifespan_hyper, average_lifespan_control],
         f"Comparison of Average Lifespan Curves Lobe {lobe}"),
        ('Norm_Lifespan',
         [average_norm_lifespan_hyper, average_norm_lifespan_control],
         "Comparison of Average Norm Lifespan Curves"),
    ):
        filename = (
            f'lobeTDAGraphs/Pressure{pressure}/{lobe}/'
            f'{curve_prefix}_{dim}_Average_P{pressure}_{lobe}'
        )
        save_curves(
            curve_pair,
            labels=['Hyper', 'Control'],
            dimension=dim,
            filename=filename,
            title=curve_title,
            limit=[0, global_last_deaths[dim]],
            dpi=300,
        )

Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_0_Average_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_0_Average_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Norm_Lifespan_0_Average_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_1_Average_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_1_Average_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Norm_Lifespan_1_Average_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_2_Average_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_2_Average_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Norm_Lifespan_2_Average_P3_postcaval


## Betti Curve

In [None]:
# Save one Betti-curve figure per mouse and homology dimension.
keys = list(betti_curves)
graph = 'Betti'
for dim in range(3):
    for i in range(13):
        mouse_name = keys[i]
        # The first five keys are filed under "Hyper", the remaining ones under "Control".
        cohort = 'Hyper' if i < 5 else 'Control'
        filename = (
            f'lobeTDAGraphs/Pressure{pressure}/{lobe}/'
            f'{graph}_{dim}_{cohort}_{mouse_name}_P{pressure}_{lobe}'
        )
        save_curves(
            [betti_curves[mouse_name]],
            labels=[mouse_name],
            dimension=dim,
            filename=filename,
            title=f"Betti Curve Lobe {lobe}",
            limit=[0, global_last_deaths[dim]],
            dpi=300,
        )

Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_0_Hyper_m1053007_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_0_Hyper_m2053007_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_0_Hyper_m1053107_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_0_Hyper_m2053107_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_0_Hyper_m1060107_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_0_Control_m1060407_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_0_Control_m2060407_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_0_Control_m3060407_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_0_Control_m1060507_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_0_Control_m2060507_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Betti_0_Control_m3060507_P3_postcaval
Saved plo

## Lifespan

In [None]:

# Save one lifespan-curve figure per mouse and homology dimension.
keys = list(lifespan_curves)
graph = 'Lifespan'
for dim in range(3):
    for i in range(13):
        mouse_name = keys[i]
        # The first five keys are filed under "Hyper", the remaining ones under "Control".
        cohort = 'Hyper' if i < 5 else 'Control'
        filename = (
            f'lobeTDAGraphs/Pressure{pressure}/{lobe}/'
            f'{graph}_{dim}_{cohort}_{mouse_name}_P{pressure}_{lobe}'
        )
        save_curves(
            [lifespan_curves[mouse_name]],
            labels=[mouse_name],
            dimension=dim,
            filename=filename,
            title=f"Lifespan Curve Lobe {lobe}",
            limit=[0, global_last_deaths[dim]],
            dpi=300,
        )



Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_0_Hyper_m1053007_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_0_Hyper_m2053007_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_0_Hyper_m1053107_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_0_Hyper_m2053107_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_0_Hyper_m1060107_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_0_Control_m1060407_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_0_Control_m2060407_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_0_Control_m3060407_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_0_Control_m1060507_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_0_Control_m2060507_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Lifespan_0_Contro

## Norm Lifespan

In [None]:

# Save one normalized-lifespan-curve figure per mouse and homology dimension.
keys = list(norm_lifespan_curves)
graph = 'Norm_Lifespan'

for dim in range(3):
    for i in range(13):
        mouse_name = keys[i]
        # The first five keys are filed under "Hyper", the remaining ones under "Control".
        cohort = 'Hyper' if i < 5 else 'Control'
        filename = (
            f'lobeTDAGraphs/Pressure{pressure}/{lobe}/'
            f'{graph}_{dim}_{cohort}_{mouse_name}_P{pressure}_{lobe}'
        )
        save_curves(
            [norm_lifespan_curves[mouse_name]],
            labels=[mouse_name],
            dimension=dim,
            filename=filename,
            title=f"Norm Lifespan Curve Lobe {lobe}",
            limit=[0, global_last_deaths[dim]],
            dpi=300,
        )



Saved plot to 2DComplementGraphs/Pressure3/postcaval/Norm_Lifespan_0_Hyper_m1053007_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Norm_Lifespan_0_Hyper_m2053007_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Norm_Lifespan_0_Hyper_m1053107_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Norm_Lifespan_0_Hyper_m2053107_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Norm_Lifespan_0_Hyper_m1060107_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Norm_Lifespan_0_Control_m1060407_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Norm_Lifespan_0_Control_m2060407_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Norm_Lifespan_0_Control_m3060407_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Norm_Lifespan_0_Control_m1060507_P3_postcaval
Saved plot to 2DComplementGraphs/Pressure3/postcaval/Norm_Lifespan_0_Control_m2060507_P3_postcaval
Saved plot to 2DComp

## Persistence Landscapes

In [None]:
# Save one persistence-landscape figure per mouse and homology dimension.
keys = list(landscape_curves.keys())
graph = 'Persistence_Landscape'
for i in range(13):
    mouse_name = keys[i]
    # The first five keys are filed under "Hyper", the remaining ones under "Control".
    cohort = 'Hyper' if i < 5 else 'Control'
    for dim in [0, 1, 2]:
        # Written to the same 'lobeTDAGraphs' tree as the Betti / Lifespan /
        # Norm_Lifespan figures so all plots of one lobe end up together.
        filename = (
            f'lobeTDAGraphs/Pressure{pressure}/{lobe}/'
            f'{graph}_{dim}_{cohort}_{mouse_name}_P{pressure}_{lobe}'
        )
        t_vals, landscapes = landscape_curves[mouse_name][dim]
        save_landscape(
            t_vals,
            landscapes,
            dim=dim,
            filename=filename,
            # Use the mouse being plotted; `name` was a leftover from an earlier
            # loop and put the same mouse name on every figure.
            title=f'Persistence Landscape {mouse_name} Dimension {dim} Lobe {lobe}',
            limit=[0, global_last_deaths[dim]],
            dpi=150
        )

Saved landscape plot to 2DComplementGraphs/Pressure3/postcaval/Persistence_Landscape_0_Hyper_m1053007_P3_postcaval
Saved landscape plot to 2DComplementGraphs/Pressure3/postcaval/Persistence_Landscape_1_Hyper_m1053007_P3_postcaval
Saved landscape plot to 2DComplementGraphs/Pressure3/postcaval/Persistence_Landscape_2_Hyper_m1053007_P3_postcaval
Saved landscape plot to 2DComplementGraphs/Pressure3/postcaval/Persistence_Landscape_0_Hyper_m2053007_P3_postcaval
Saved landscape plot to 2DComplementGraphs/Pressure3/postcaval/Persistence_Landscape_1_Hyper_m2053007_P3_postcaval
Saved landscape plot to 2DComplementGraphs/Pressure3/postcaval/Persistence_Landscape_2_Hyper_m2053007_P3_postcaval
Saved landscape plot to 2DComplementGraphs/Pressure3/postcaval/Persistence_Landscape_0_Hyper_m1053107_P3_postcaval
Saved landscape plot to 2DComplementGraphs/Pressure3/postcaval/Persistence_Landscape_1_Hyper_m1053107_P3_postcaval
Saved landscape plot to 2DComplementGraphs/Pressure3/postcaval/Persistence_Lands