This file can either compute the Wasserstein, Bottleneck or Landscape distances to the 
Wasserstein, Bottleneck or the Landscape barycenters of each class.
It can also compute the respective distances to the origin diagram.

In [16]:
import numpy as np
import pyedflib
import statistics
import plotly.graph_objects as go
import pandas as pd
from gtda.time_series import SingleTakensEmbedding
from gtda.homology import VietorisRipsPersistence
from gtda.diagrams import PersistenceEntropy, Amplitude, NumberOfPoints, ComplexPolynomial, PersistenceLandscape, HeatKernel, Silhouette, BettiCurve, PairwiseDistance, ForgetDimension
import random
from sklearn import preprocessing
import plotly.io as pio
from gtda.plotting import plot_diagram
import h5py



# Choose parameters

In [17]:
# choose individuum
subject = "m292"

In [18]:
# Set the distance metric here

metric = "wasserstein"
#metric = "landscape"
#metric = "bottleneck"

# Load Data

In [19]:
label_list = [0, 1, 2, 3, 4]

In [20]:
dataframes = {}

for label in label_list:
    filename = "Data/"+str(subject)+"/run0"+str(label)+"/Brain_Imaging_Data.h5"
    file = h5py.File(filename,'r')
    dataframes[label] = file['Data']

In [21]:
# Load persistence diagrams

persistence_diagrams = np.load('Embeddings_and_Persistence_Diagrams/'+str(subject)+'/Persistence_Diagrams.npy', \
                allow_pickle=True).item() # .item() to convert the dtype to dict again

In [22]:
extended_persistence_diagrams = np.load('Embeddings_and_Persistence_Diagrams/'+str(subject)+'/Extended_Persistence_Diagrams.npz', \
                allow_pickle=True)

# Computing the distance to the Wasserstein Barycenter

## Wasserstein Barycenter

The Wasserstein Barycenter is the most representative persistence diagram in a set of diagrams (of one class), so the one with the lowest overall (Wasserstein) distance to all other diagrams. Because it takes long to compute, we will for now only use a part of the data as training data. For now, these training samples can also be in the test set of the simple classifier in the end.

In [8]:
# We will look at 0-, 1- and 2-dimensional holes
homology_dimensions = [0, 1, 2]

# We will use a Vietoris Rips filtrations
persistence = VietorisRipsPersistence(
    homology_dimensions=homology_dimensions, n_jobs=10
)

### Only take random subset of persistence diagrams into account (for computational efficiency)

In [52]:
# TODO If I use the segment barycenters, use portion of segments here

def get_random_subsets(extended_persistence_diagrams, subset_ratio):
    """
    Selects a random subset of elements from each key in the provided dictionary of persistence diagrams.

    Parameters:
    - extended_persistence_diagrams (dict): A dictionary where each value is a list or array of persistence diagrams.
    - subset_ratio (float): The ratio of elements to select from each list. Default is 0.15 (i.e., 15%).

    Returns:
    - random_subsets (dict): A new dictionary containing the random subsets.
    """
    random_subsets = {}
    
    for key, value in extended_persistence_diagrams.items():
        subset_size = int(len(value) * subset_ratio)
        indices = np.arange(len(value))
        random_indices = random.sample(sorted(indices), subset_size)
        random_subsets[key] = value[random_indices]
    
    return random_subsets

subset_ratio = (len(extended_persistence_diagrams["Label_0"])-80*50)/len(extended_persistence_diagrams["Label_0"])

random_subsets = get_random_subsets(extended_persistence_diagrams, subset_ratio)

### Computing the Wasserstein Barycenter for all labels

Compute Wasserstein barycenter of each segment (but only using a portion of the diagrams in each segment for computational efficiency), and then compute the Wasserstein barycenter of all segment Wasserstein barycenters.

In [None]:
import random
import numpy as np
from gtda.diagrams import PairwiseDistance
from joblib import Parallel, delayed
import sys


def get_random_subsets(diagrams, subset_ratio):
    """
    Selects a random subset of elements from each key in the provided dictionary of persistence diagrams.

    Parameters:
    - extended_persistence_diagrams (dict): A dictionary where each value is a list or array of persistence diagrams.
    - subset_ratio (float): The ratio of elements to select from each list. Default is 0.15 (i.e., 15%).

    Returns:
    - random_subsets (dict): A new dictionary containing the random subsets.
    """
    random_subsets = {}
    
    subset_size = int(len(diagrams) * subset_ratio)
    indices = np.arange(len(diagrams))
    random_indices = random.sample(sorted(indices), subset_size)
    random_subsets = diagrams[random_indices]
    
    return random_subsets


def custom_logger(idx, total_segments):
    sys.stdout.write(f"\rProcessing pair {idx + 1}/{total_segments}")
    sys.stdout.flush()

def compute_wasserstein_distances(diagrams):
    n = len(diagrams)
    total_pairs = (n * (n - 1)) // 2  # Total number of unique pairs
    distances = np.zeros((n, n))
    
    # Initialize PairwiseDistance with the Wasserstein metric
    pairwise_dist = PairwiseDistance(metric='wasserstein')
    
    def compute_pair(i, j, idx):
        # Log progress
        custom_logger(idx, total_pairs)
        # Compute the Wasserstein distance between diagrams[i] and diagrams[j]
        return pairwise_dist.fit_transform([diagrams[i], diagrams[j]])[0, 0]
    
    # Generate all unique pairs of indices (i, j)
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    results = Parallel(n_jobs=-1)(delayed(compute_pair)(i, j, idx)
                                  for idx, (i, j) in enumerate(pairs))
    
    for idx, dist in enumerate(results):
        i, j = pairs[idx]
        distances[i, j] = dist
        distances[j, i] = dist
    
    return distances


def find_barycenters_for_each_segment(diagrams, dataframes, label, subsample_size, n_jobs=-1):
    # Initialize PairwiseDistance with the Wasserstein metric
    pairwise_dist = PairwiseDistance(metric='wasserstein')

    segment_length = 80 

    segment_barycenters = []

    for idx in range(int(len(diagrams) / segment_length)):

        print("Processing segment " + str(idx) + "/" + str(int(len(diagrams) / segment_length)))
        
        # Compute the pairwise Wasserstein distances
        segment_diagrams = diagrams[segment_length * idx:segment_length * (idx+1)]

        subsample_of_segment_diagrams = get_random_subsets(segment_diagrams, 0.2)        
        
        pairwise_wasserstein_distances_for_segment = compute_wasserstein_distances(subsample_of_segment_diagrams)

        # For each hole, calculate the sum of distances to all other holes
        sum_distances_for_segment = [sum(dist) for dist in pairwise_wasserstein_distances_for_segment]

        # Find the index of the Wasserstein barycenter
        most_representative_index_in_segment = np.argmin(sum_distances_for_segment)

        # Wasserstein Barycenter for our label
        most_representative_diagram_in_segment = persistence.fit_transform([dataframes[most_representative_index_in_segment]])

        segment_barycenters.append(most_representative_diagram_in_segment[0])

    return segment_barycenters


def find_barycenter(diagrams, dataframes, label, subsample_size, n_jobs=-1):

    segment_barycenters = find_barycenters_for_each_segment(diagrams, dataframes, label, subsample_size, n_jobs=-1)

    print(segment_barycenters)

    # Compute pairwise Wasserstein distances in parallel
    pairwise_wasserstein_distances = compute_wasserstein_distances(segment_barycenters)

    # For each diagram, calculate the sum of distances to all other diagrams
    sum_distances = [sum(dist) for dist in pairwise_wasserstein_distances]

    # Find the index of the Wasserstein barycenter
    most_representative_index = np.argmin(sum_distances)

    # Wasserstein Barycenter for our label
    most_representative_diagram = persistence.fit_transform([dataframes[most_representative_index]])

    fig = plot_diagram(most_representative_diagram[0])
    fig.show()
    pio.write_image(fig, 'Plots/BI_'+str(subject)+'_Label ' + str(label) + ' Most Representative Diagram (Extended Diagrams).png')

    return most_representative_diagram

representative_diagrams = {}  # barycenters for all labels

for label in label_list:
    print("Starting computation for label " +str(label))
    representative_diagrams["Label_"+str(label)] = find_barycenter(extended_persistence_diagrams["Label_"+str(label)], dataframes[label], label, subsample_size=0.3, n_jobs=-1)


Starting computation for label 0
Processing segment 0/75
Processing pair 120/120Processing segment 1/75
Processing pair 10/120

In [None]:
np.save('Embeddings_and_Persistence_Diagrams/'+str(subject)+'/Most_Representative_Diagrams_(Extended_Diagrams).npy', \
            np.array(representative_diagrams, dtype=object), allow_pickle=True)