In [1]:
import h5py
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import os
import gc


In [21]:
# Strictly-upper-triangle (k=1 skips the diagonal) row/column indices of a
# small 5x5 matrix, returned as a (rows, cols) pair of arrays.
n_demo = 5
upper_tri_indices = np.triu_indices(n_demo, k=1)
upper_tri_indices

(array([0, 0, 0, 0, 1, 1, 1, 2, 2, 3]), array([1, 2, 3, 4, 2, 3, 4, 3, 4, 4]))

In [22]:
# Print the same pairs as 1-based channel numbers, one pair per line.
for row, col in zip(*upper_tri_indices):
    print(row + 1, col + 1)


1 2
1 3
1 4
1 5
2 3
2 4
2 5
3 4
3 5
4 5


In [2]:
# New format check: take the pdist_euclidean matrix from the first epoch's
# pickled results so it can be compared with the combined HDF5 further down.
# (dill can execute arbitrary code on load; only use it on files we produced.)
import dill as pickle
first_epoch = "/media/dan/Data/data/connectivity/six_run/011_epoch_0000000000-0000001024/calc.pkl"
with open(first_epoch, 'rb') as f:
    tmp = pickle.load(f)['pdist_euclidean'].values


In [3]:
# Top-left 5x5 corner of the pickled pdist_euclidean matrix (diagonal is NaN),
# for eyeballing against the combined-HDF5 version below.
tmp[:5, :5]

array([[        nan, 26.59590548, 36.04899322, 31.85905297, 31.71056188],
       [26.59590548,         nan, 34.50576593, 34.14190158, 35.79508584],
       [36.04899322, 34.50576593,         nan, 21.13273723, 29.81815887],
       [31.85905297, 34.14190158, 21.13273723,         nan, 16.61480794],
       [31.71056188, 35.79508584, 29.81815887, 16.61480794,         nan]])

In [18]:
# Read patient 011's pdist_euclidean adjacency stack plus its SOZ flags and
# ILAE value from the combined per-patient HDF5.
combined_h5 = "/media/dan/Data/outputs/ubiquitous-spork/pyspi_combined_patient_hdf5s/011_20250414.h5"
with h5py.File(combined_h5, 'r') as f:
    test = f['metadata/adjacency_matrices/pdist_euclidean'][()]
    soz, ilae = (f[f'metadata/patient_info/{field}'][()] for field in ('soz', 'ilae'))

In [19]:
# ILAE value stored for this patient (presumably the ILAE surgical-outcome
# class — confirm against the HDF5 writer).
ilae

1

In [5]:
# Same 5x5 corner from the combined HDF5, first slice along the last axis;
# it should equal tmp[:5, :5] from the pickle cell above.
test[:5, :5, 0]

array([[        nan, 26.59590548, 36.04899322, 31.85905297, 31.71056188],
       [26.59590548,         nan, 34.50576593, 34.14190158, 35.79508584],
       [36.04899322, 34.50576593,         nan, 21.13273723, 29.81815887],
       [31.85905297, 34.14190158, 21.13273723,         nan, 16.61480794],
       [31.71056188, 35.79508584, 29.81815887, 16.61480794,         nan]])

In [6]:
# Per-channel SOZ indicator for patient 011 (non-zero marks a SOZ channel).
soz

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
       0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=uint8)

In [7]:
# Labels for every ordered combination of endpoint classes ('non'/'soz').
# NOTE(review): not referenced elsewhere in the visible notebook;
# extract_class_connections pools non_soz and soz_non into one "mix" class.
groups = [f"{first}_{second}" for first in ('non', 'soz') for second in ('non', 'soz')]

In [8]:
def create_legend_plot(output_path=None, orientation='horizontal'):
    """
    Draw a stand-alone legend (no axes) for the Non-EZ / EZ / Non->EZ colours.

    Parameters:
    -----------
    output_path : str, optional
        Where to save the legend. The file is always written as PNG. The
        suffix ``_<orientation>`` is inserted before the extension unless the
        name already ends with it, e.g. ``legend.png`` -> ``legend_horizontal.png``
        and ``legend_horizontal.png`` stays unchanged. If None, the figure is
        shown instead of saved.
    orientation : str, optional
        'horizontal' (one row of three entries); any other value gives a
        single-column vertical layout.
    """
    from matplotlib.patches import Patch

    # Figure size and column count depend on the layout.
    if orientation == 'horizontal':
        figsize, ncol = (3, 0.5), 3
    else:  # vertical
        figsize, ncol = (1, 1.5), 1

    fig, ax = plt.subplots(figsize=figsize)

    # Same colours as the KDE palette used in the plotting cell
    # (non='black', soz='#D41159', mix='#1A85FF').
    legend_elements = [
        Patch(facecolor='black', label='Non-EZ', alpha=0.6),
        Patch(facecolor='#D41159', label='EZ', alpha=0.6),
        Patch(facecolor='#1A85FF', label='Non->EZ', alpha=0.6)
    ]
    ax.legend(handles=legend_elements,
              loc='center',
              ncol=ncol,
              frameon=False)

    # Only the legend should be visible.
    ax.axis('off')
    fig.tight_layout()

    if output_path:
        # Strip a trailing .png, add the orientation suffix once, re-add .png.
        base_path = output_path[:-len('.png')] if output_path.endswith('.png') else output_path
        suffix = f"_{orientation}"
        if not base_path.endswith(suffix):
            base_path += suffix
        fig.savefig(f"{base_path}.png", dpi=300, bbox_inches='tight', pad_inches=0)
        # Close this specific figure so repeated calls don't accumulate figures.
        plt.close(fig)
    else:
        plt.show()


# Write one legend image per layout next to the seminar plots.
legend_path = '/media/dan/Data/git/ubiquitous-spork/plots_for_seminar/legend.png'
for legend_orientation in ('horizontal', 'vertical'):
    create_legend_plot(legend_path, orientation=legend_orientation)

In [9]:
def combine_data(mix_data, non_data, soz_data):
    """
    Stack three groups of values into one long-format DataFrame for seaborn.

    Rows appear in the order mix, non, soz. Each row carries its value in
    ``Measure Value`` and its group name ('mix', 'non' or 'soz') in ``group``.
    Rows whose value is NaN or +/-inf are dropped. Surviving rows keep their
    original positional index, which is not reset.
    """
    sources = (('mix', mix_data), ('non', non_data), ('soz', soz_data))

    values = np.concatenate([data for _, data in sources])
    labels = [name for name, data in sources for _ in range(len(data))]

    frame = pd.DataFrame({'Measure Value': values, 'group': labels})
    finite_rows = np.isfinite(frame['Measure Value'])
    return frame.loc[finite_rows]

# Example usage (one array of connection values per class, e.g. from extract_class_connections):
# df = combine_data(mix_data=mix_nums, non_data=non_nums, soz_data=soz_nums)


def extract_class_connections(adjacency_matrices, soz):
    """
    Extract upper-triangle connections, split by the SOZ membership of both endpoints.

    Parameters:
    -----------
    adjacency_matrices : np.ndarray
        Shape (n, n) or (n, n, time, ...). Only entries strictly above the
        diagonal are used, so each channel pair is counted once.
    soz : array-like
        Length-n SOZ indicator. Any non-zero / True entry marks a SOZ channel.
        Boolean arrays are accepted.

    Returns:
    --------
    non_connections : np.ndarray
        Flattened values of non-SOZ <-> non-SOZ pairs.
    mix_connections : np.ndarray
        Flattened values of SOZ <-> non-SOZ pairs.
    soz_connections : np.ndarray
        Flattened values of SOZ <-> SOZ pairs.
    With trailing (e.g. time) axes, each pair's values are contiguous in the output.

    Raises:
    -------
    ValueError
        If the first two dimensions of ``adjacency_matrices`` are not (n, n)
        with n == len(soz).
    """
    adjacency_matrices = np.asarray(adjacency_matrices)

    # Binarise and widen before summing. In numpy, bool + bool is a logical OR
    # (True + True == True), which would misfile SOZ-SOZ pairs as "mix".
    # Small unsigned dtypes could also wrap.
    is_soz = (np.asarray(soz).ravel() != 0).astype(np.int64)
    n = is_soz.size
    if adjacency_matrices.ndim < 2 or adjacency_matrices.shape[:2] != (n, n):
        raise ValueError(
            f"adjacency_matrices must have shape ({n}, {n}, ...) to match soz; "
            f"got {adjacency_matrices.shape}"
        )

    # Pair class: 0 = non-non, 1 = mix, 2 = soz-soz.
    mask = is_soz[np.newaxis, :] + is_soz[:, np.newaxis]

    rows, cols = np.triu_indices(n, k=1)
    mask_values = mask[rows, cols]

    # Fancy indexing on the first two axes gives (n_pairs,) or (n_pairs, time, ...).
    connections = adjacency_matrices[rows, cols]

    non_connections = connections[mask_values == 0].flatten()
    mix_connections = connections[mask_values == 1].flatten()
    soz_connections = connections[mask_values == 2].flatten()

    return non_connections, mix_connections, soz_connections

In [12]:
# Number of unique channel pairs for a 118x118 matrix: 118 * 117 / 2 = 6903.
# NOTE(review): this rebinds upper_tri_indices from the 5x5 demo above, so
# re-running the 1-based print loop after this cell prints 6903 pairs.
upper_tri_indices = np.triu_indices(118, k=1)
upper_tri_indices[0].shape

(6903,)

In [13]:
# Full 118x118 entry count: 2 * 6903 off-diagonal entries + 118 diagonal = 13924.
118**2

13924

In [14]:
path = "/media/dan/Data/outputs/ubiquitous-spork/pyspi_combined_patient_hdf5s"

# Keep only patients with at least one SOZ channel (otherwise the soz/mix
# groups are empty), and gather the union of metrics available across them.
full_metrics = set()
files_to_process = []
for patient in sorted(os.listdir(path)):
    if not patient.endswith('.h5'):
        continue
    in_path = os.path.join(path, patient)
    with h5py.File(in_path, 'r') as f:
        soz = f['metadata/patient_info/soz'][()]
        if not np.any(soz):
            continue
        keys = list(f['metadata/adjacency_matrices'].keys())
    files_to_process.append(in_path)
    full_metrics.update(keys)
# Sorted so the metric order is reproducible (set iteration order is not).
full_metrics = sorted(full_metrics)

# Debug restrictions, previously hard-coded as `continue`/`break` statements.
# One of those breaks sat before the accumulation, so nothing was ever
# collected. Set PATIENT_SUBSET = None to pool across every patient.
metric = "bary-sq_euclidean_max"
PATIENT_SUBSET = {os.path.join(path, "001_20250414.h5")}
assert metric in full_metrics, f"{metric!r} not found in any patient file"

selected_files = [p for p in files_to_process
                  if PATIENT_SUBSET is None or p in PATIENT_SUBSET]

# Collect per-patient arrays and concatenate once (avoids quadratic re-copying).
non_parts, mix_parts, soz_parts = [], [], []
for in_path in tqdm(selected_files):
    with h5py.File(in_path, 'r') as f:
        data = f['metadata/adjacency_matrices'][metric][()]
        soz = f['metadata/patient_info/soz'][()]
    out = extract_class_connections(data, soz)
    non_parts.append(out[0])
    mix_parts.append(out[1])
    soz_parts.append(out[2])


def _pool_finite(parts):
    """Concatenate per-patient arrays (as float64 or wider) and drop NaN and +/-inf."""
    pooled = np.concatenate([np.array([]), *parts])
    return pooled[np.isfinite(pooled)]


non_nums = _pool_finite(non_parts)
mix_nums = _pool_finite(mix_parts)
soz_nums = _pool_finite(soz_parts)




  0%|          | 0/69 [00:00<?, ?it/s]

In [17]:
# Count of mixed (SOZ <-> non-SOZ) connection values, flattened over the
# trailing axis, for the last patient file processed above.
out[1].shape

(473193,)

In [None]:
# KDE of the pooled connection values per class for the selected metric.
combined = combine_data(mix_data=mix_nums, non_data=non_nums, soz_data=soz_nums)
if combined.empty:
    # An empty frame would otherwise be saved silently as a blank plot.
    raise ValueError(f"No finite connection values collected for {metric!r}")

fig, ax = plt.subplots()
sns.kdeplot(data=combined, x="Measure Value", hue="group", common_grid=True, common_norm=False, cut=0,
            palette={'non': 'black', 'soz': '#D41159', 'mix': '#1A85FF'}, fill=False, alpha=0.6, legend=False,
            ax=ax)
ax.set_title(metric, fontsize=10)

output_dir = "/media/dan/Data/git/ubiquitous-spork/plots_for_seminar/mean_columns"
os.makedirs(output_dir, exist_ok=True)
fig.savefig(os.path.join(output_dir, f"{metric}.png"), dpi=300)
plt.close(fig)