In [3]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pandas as pd
import pickle
from scipy import stats
from tqdm import tqdm

import h5py
import os
import dill


In [4]:
def save_to_hdf5(path, data, file_path):
    """Append ``data`` to the HDF5 dataset ``path`` in ``file_path``.

    On the first write the dataset is created gzip-compressed and resizable
    (unlimited) along axis 0; later calls grow it along axis 0 and write the
    new rows at the end, so repeated calls concatenate their inputs.

    Parameters
    ----------
    path : str
        Dataset name/path inside the HDF5 file.
    data : array-like
        Values to append. Lists and other sequences are converted with
        ``np.asarray``. For N-D input every axis except the first must match
        the existing dataset. Empty input (``len(data) == 0``) is a no-op.
    file_path : str
        HDF5 file, opened in append mode (``"a"``).
    """
    if len(data) == 0:
        return
    # Normalise once: the append branch needs `.shape`, which plain lists lack.
    data = np.asarray(data)
    n_new = data.shape[0]
    with h5py.File(file_path, "a") as hdf5_file:
        if path not in hdf5_file:
            # Only axis 0 is growable; trailing axes are fixed by the first
            # write (a hard-coded maxshape=(None,) would reject N-D data).
            hdf5_file.create_dataset(
                path, data=data, maxshape=(None,) + data.shape[1:], compression="gzip"
            )
        else:
            dataset = hdf5_file[path]
            start = dataset.shape[0]
            dataset.resize(start + n_new, axis=0)
            dataset[start:] = data

def get_idxs(idxs, soz_idx):
    """Split (row, col) index pairs into four groups by SOZ membership.

    Parameters
    ----------
    idxs : pair of sequences
        ``(rows, cols)``; pairs are formed with ``zip(rows, cols)`` (e.g. the
        tuple returned by ``np.nonzero``).
    soz_idx : iterable of int
        Channel indices flagged as seizure-onset zone (SOZ).

    Returns
    -------
    soz_soz, soz_non, non_soz, non_non : np.ndarray, each of shape (k, 2)
        - soz_soz: both row and col are in ``soz_idx``
        - soz_non: only the row is in ``soz_idx``
        - non_soz: only the col is in ``soz_idx``
        - non_non: neither is in ``soz_idx``
        Groups without members are empty ``(0, 2)`` integer arrays, so
        ``arr[:, 0]`` and fancy indexing work for every group.
    """
    # A set gives O(1) membership; `x in ndarray` rescans the array per pair.
    if isinstance(soz_idx, np.ndarray):
        soz = set(np.ravel(soz_idx).tolist())
    else:
        soz = set(soz_idx)

    soz_soz, soz_non, non_soz, non_non = [], [], [], []
    for x, y in zip(idxs[0], idxs[1]):
        x_in, y_in = x in soz, y in soz
        if x_in and y_in:
            soz_soz.append((x, y))
        elif x_in:
            soz_non.append((x, y))
        elif y_in:
            non_soz.append((x, y))
        else:
            non_non.append((x, y))

    # np.array([]) would be 1-D float; keep empty groups shaped like the rest.
    return tuple(
        np.array(group) if group else np.empty((0, 2), dtype=int)
        for group in (soz_soz, soz_non, non_soz, non_non)
    )

def safe_slice_and_flatten(measure, idx_array):
    """Gather ``measure[:, i, j]`` for every (i, j) row of ``idx_array``, flattened.

    Column 0 of ``idx_array`` indexes axis 1 of ``measure`` and column 1
    indexes axis 2; axis 0 is kept whole. The gathered block is flattened in
    C order, so all pairs for ``measure[0]`` come first, then ``measure[1]``...

    An empty ``idx_array`` short-circuits to an empty 1-D array instead of
    attempting to index with it.
    """
    if len(idx_array) > 0:
        rows, cols = idx_array[:, 0], idx_array[:, 1]
        return measure[:, rows, cols].flatten()
    return np.array([])

# --- Paths ------------------------------------------------------------------
# NOTE(review): absolute workstation paths; consider one configurable base dir.
output_path = "/media/dan/Data/data/calculations"
mapping_path = "/media/dan/Big/manuiscript_0001_hfo_rates/data/FULL_composite_patient_info.csv"  
ilae_path = "/media/dan/Big/manuiscript_0001_hfo_rates/ravi_hfo_numbers~N59+v03.csv"
calculation_path = "/media/dan/Big/network_mining/calculations/sixrun/calculations/six_run"


# --- Per-channel metadata and surgical outcomes -------------------------------
mappings = pd.read_csv(mapping_path)  # several rows per patient (presumably one per electrode)
ilae = pd.read_csv(ilae_path)

# Resolve one ILAE outcome per patient in `mappings`:
#   * patients present in the ILAE table take their first listed `ilae` value;
#   * all others fall back to the `seizureFree` flag:
#       -1  -> seizure free
#       100 -> not seizure free (sentinel, not a real ILAE class)
# `== True` is deliberate: a missing (NaN) seizureFree is truthy, but must
# map to 100, not -1.
ilae_numbers = {}
for p in mappings["pid"].unique():
    known = ilae["patient"] == p
    if known.any():
        ilae_numbers[p] = ilae.loc[known, "ilae"].values[0]
        continue
    seizure_free = mappings.loc[mappings["pid"] == p, "seizureFree"].values[0]
    ilae_numbers[p] = -1 if seizure_free == True else 100

# Broadcast each patient's outcome onto all of that patient's rows.
ilae_list = []
for p in mappings["pid"]:
    ilae_list.append(ilae_numbers[p])
mappings["ilae"] = ilae_list


# Every entry of the calculation directory starts with a zero-padded patient
# id followed by "_" (e.g. "<pid:03>_chnames.csv" plus per-recording result
# directories that each contain a `calc.pkl`).
files = sorted(os.listdir(calculation_path))
pids = sorted({int(entry.split("_")[0]) for entry in files})

# Open the alphabetically first entry whose name contains "epoch", only to
# discover the measure names (level 0 of its column MultiIndex).
# NOTE(review): dill.load can execute arbitrary code; inputs must be trusted.
first_epoch_file = [entry for entry in files if "epoch" in entry][0]
with open(os.path.join(calculation_path, first_epoch_file, "calc.pkl"), "rb") as f:
    calc = dill.load(f)
columns = calc.columns.levels[0].unique().values

# Exploratory pass over patients: load every per-recording result, then stack
# the first Granger-causality ("gc_gaussian") measure across recordings.
# NOTE: the final `break` deliberately stops after the first patient that has
# SOZ channels, loadable files and a gc_gaussian measure -- the cells below
# inspect/export exactly that `pid`, `col`, `measure` and `pid_mappings`.
# Remove it for a full run.
for pid in tqdm(pids, desc="Patients", leave=True):
    # Match on the parsed id prefix (the same rule used to build `pids`)
    # rather than a raw string prefix.
    pid_files = sorted(name for name in files if int(name.split("_")[0]) == pid)
    chnames_idx = pid_files.index(f"{pid:03}_chnames.csv")
    chnames = pd.read_csv(os.path.join(calculation_path, pid_files[chnames_idx]))['0'].values
    # What remains are the per-recording result directories.
    pid_files.pop(chnames_idx)

    # Channel metadata aligned 1:1 with `chnames`; channels missing from
    # `mappings` become all-NaN rows so positions still line up.
    pid_mappings = mappings[mappings["pid"] == pid]
    pid_mappings = pid_mappings[pid_mappings["electrode"].isin(chnames)]
    pid_mappings = pid_mappings.set_index("electrode").reindex(chnames).reset_index()

    # Positional indices (0..len(chnames)-1) of seizure-onset-zone channels.
    soz_idx = pid_mappings.index[pid_mappings["soz"] == 1].values
    # Use the per-patient lookup instead of `pid_mappings["ilae"].iloc[0]`:
    # after the reindex the first row is NaN whenever the first channel is
    # unmapped (and iloc[0] raises when there are no channels at all).
    ilae_group = ilae_numbers.get(pid, np.nan)

    if len(soz_idx) == 0:
        continue

    # Load every recording; skip the whole patient if any file fails.
    data = []
    skip = False
    for file in pid_files:
        try:
            with open(os.path.join(calculation_path, file, 'calc.pkl'), "rb") as f:
                # dill.load can execute arbitrary code -- trusted inputs only.
                data.append(dill.load(f))
        except Exception as exc:  # not a bare `except:`, so Ctrl-C still works
            print(f"Error loading {file}: {exc!r}")
            skip = True
            break
    if skip:
        continue
    if not data:
        print(f"No result files for {pid:03}")
        continue

    # Stack the first gc_gaussian measure across recordings:
    # measure[i] is data[i][col].values.
    measure = None
    for col in tqdm(data[0].columns.levels[0].unique(), desc=f"Columns for {pid}", leave=False):
        if "gc_gaussian" not in col:
            continue
        measure = np.array([r[col].values for r in data])
        break  # exploratory: only the first gc_gaussian measure
    if measure is None:
        # Never stop with a stale `measure` left over from another patient.
        continue
    break  # exploratory: only the first usable patient

Patients:   0%|          | 0/72 [00:41<?, ?it/s]


In [5]:
# Name of the measure column the exploratory loop stopped on: the first
# level-0 column name containing "gc_gaussian".
col

'gc_gaussian_k-1_kt-1_l-1_lt-1'

In [6]:
# Axis 0 has one entry per loaded recording (in `pid_files` order); the
# remaining axes are each file's matrix for `col` -- presumably
# channel x channel, matching len(chnames). TODO confirm.
measure.shape

(609, 118, 118)

In [8]:
# Export the stacked measure and SOZ labels of the inspected patient to MATLAB.
# NOTE(review): move this import into the top imports cell.
import scipy.io

# Pass the SOZ labels as a plain ndarray: some SciPy versions treat objects
# exposing keys()/values()/items() (like a pandas Series) as MATLAB structs
# and may write an empty struct instead of the label vector.
scipy.io.savemat(
    f"{pid}_grangercause.mat",
    {"measure_directed": measure, "soz": pid_mappings["soz"].to_numpy()},
)