In [1]:
import pandas as pd
import numpy as np
from scipy.signal import butter, filtfilt, hilbert
import json
from tqdm import tqdm
import h5py
import multiprocessing

In [2]:
def load_data(root, subject):
    data = pd.read_csv(f'{root}/{subject}/{subject}_data.csv', header=None)
    with open(f'{root}/{subject}/{subject}_param.json', 'r') as file:
        param = json.load(file)
    if len(param['channel_id_labels']) == data.shape[1]:    
        print(f'Loaded {data.shape[0]} datapoints and {len(param.keys())} parameters for {data.shape[1]} channels')
    else:
        print('Error loading data')
    return data, param

In [3]:
def extract_params(param, datashape):
    Nch = datashape[1]
    sample_rate = param['sample_rate']
    tperm = param['tperm']
    Nperm = param['Nperm']
    sbin = round(param['tbin'] * sample_rate)
    sstp = round(param['tstp'] * sample_rate)
    rT = np.arange(0, datashape[0]-sbin, sstp)
    Nt = len(rT)
    t = (rT + sbin / 2) / sample_rate
    rP = np.column_stack((param['fP'][:-1], param['fP'][1:]))
    rA = np.column_stack((param['fA'][:-1], param['fA'][1:]))
    NP = rP.shape[0]
    NA = rA.shape[0]
    edges = np.arange(-np.pi, np.pi+param['delta'], param['delta'])
    x = edges[:-1] + param['delta'] / 2
    Nx = len(x)
    return Nch, sample_rate, tperm, Nperm, sbin, sstp, rT, Nt, t, rP, rA, NP, NA, edges, x, Nx

In [5]:
def save_results(PACmi, root, subject):
    PACmi_transposed = np.transpose(PACmi, (4, 3, 2, 1, 0))
    filename = f"{root}/{subject}/{subject}_PACmi.h5"
    with h5py.File(filename, 'w') as f:
        # Save the array to the file
        f.create_dataset('PACmi', data=PACmi_transposed)

In [4]:
root = '/Users/tnl/matlab/sEEG_proc'
subject = 'HUP241_RID890'
data, param = load_data(root, subject)
Nch, sample_rate, tperm, Nperm, sbin, sstp, rT, Nt, t, rP, rA, NP, NA, edges, x, Nx = extract_params(param, data.shape)

Loaded 2914304 datapoints and 10 parameters for 36 channels
2914304
36


In [13]:
ich = 0
tmi = np.zeros((NP, NA, Nt))
tmip = np.zeros((NP, NA, Nt))
data_numpy = data[ich].to_numpy()
log_Nx = np.log(Nx)
iP = 0

bP, aP = butter(2, rP[iP, :] / (sample_rate / 2), btype='band')
P = np.angle(hilbert(filtfilt(bP, aP, data[ich].to_numpy())))
Pbin = np.transpose(np.vstack([np.digitize(P[t1:t1+sbin], edges) for t1 in rT]))
Pbin.shape

(245760, 44)

In [14]:
iA = 0
bA, aA = butter(2, rA[iA, :] / (sample_rate / 2), btype='band')
A = np.abs(hilbert(filtfilt(bA, aA, data_numpy)))
PAC = np.zeros((Nch, NP, NA, Nt, Nx))

In [22]:
for iT in range(Nt):
    cA = A[rT[iT]:rT[iT]+sbin]
    cAm = np.array([np.mean(cA[Pbin[:, iT] == jj+1]) for jj in range(Nx)])
    PAC[ich, iP, iA, iT, :] = cAm
    cAm /= np.sum(cAm)
    cAm[cAm == 0] = 1e-10
    tmi[iP, iA, iT] = (log_Nx + np.sum(cAm * np.log(cAm))) / log_Nx

In [18]:
MIperm = np.zeros((Nperm, Nt))
iperm = 0
tshift = tperm[0] + np.diff(tperm)[0] * np.random.rand()
nshift = round(tshift * sample_rate)
Ashift = np.roll(A, nshift)

In [24]:
def calculate_MI(Ashift, Pbin_slice, Nx, log_Nx):
    cAm = np.array([np.mean(Ashift[Pbin_slice == jj+1]) for jj in range(Nx)])
    cAm /= np.sum(cAm)
    cAm[cAm == 0] = 1e-10
    return (log_Nx + np.sum(cAm * np.log(cAm))) / log_Nx


In [25]:
for iT in range(Nt):
                t1 = rT[iT]
                t2 = rT[iT] + sbin
                MIperm[iperm, iT] = calculate_MI(Ashift[t1:t2], Pbin[:, iT], Nx, log_Nx)

In [27]:
MIpermA = np.zeros((Nperm, Nt))
for iT in range(Nt):
    t1 = rT[iT]
    t2 = rT[iT] + sbin
    cA = Ashift[t1:t2]
    cAm = np.zeros(Nx)
    for jj in range(Nx):
        cAm[jj] = np.mean(cA[Pbin[:, iT] == jj+1])
        cAm /= np.sum(cAm)
        cAm[cAm == 0] = 1e-10
        MIpermA[iperm, iT] = (log_Nx + np.sum(cAm * np.log(cAm))) / log_Nx

In [32]:
(MIperm == MIpermA).all()

False

In [47]:
iT = 0
t1 = rT[iT]
t2 = rT[iT] + sbin
cA = A[t1:t2]
cAm = np.zeros(Nx)

In [56]:
for jj in range(Nx):
    cAm[jj] = np.mean(cA[Pbin[:, iT] == jj])
cAm

array([        nan, 18.60148554, 18.57019077, 18.55491346, 18.55011831,
       18.52758127, 18.5256682 , 18.53360616, 18.6002194 , 18.66921464,
       18.74652309, 18.75860608, 18.71290433, 18.68013297, 18.65824344,
       18.6609928 ])

In [57]:
for jj in range(Nx):
    cAm[jj] = np.mean(cA[Pbin[:, iT] == jj+1])
cAm

array([18.60148554, 18.57019077, 18.55491346, 18.55011831, 18.52758127,
       18.5256682 , 18.53360616, 18.6002194 , 18.66921464, 18.74652309,
       18.75860608, 18.71290433, 18.68013297, 18.65824344, 18.6609928 ,
       18.64698748])