In [15]:
import sys
sys.path.append('/Users/tereza/nishant/atlas/atlas_work_terez/atlas_harmonization/external_dependencies/CNT_research_tools/python')

import os
import pandas as pd
import numpy as np

from CNTtools import settings
from CNTtools import iEEGPreprocess  # Direct import from CNTtools

# To access iEEGData class
from CNTtools import iEEGData  # If this is a direct import
# OR
from CNTtools.iEEGPreprocess import iEEGData  # If iEEGData is in the iEEGPreprocess module

In [14]:
# Debug aid: list the names the CNTtools package exposes (the captured output
# below shows iEEGData, iEEGPreprocess, settings and tools among them).
import CNTtools
print(dir(CNTtools))

['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'iEEGData', 'iEEGPreprocess', 'settings', 'tools']


In [18]:
# Local (mostly dummy) stand-ins for the helpers in CNTtools/tools.py, defined in the notebook namespace.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def login_config():
    """Ask for an ieeg.org username on stdin and return it.

    Returns
    -------
    str
        Whatever the user typed at the ``Enter username:`` prompt.

    Notes
    -----
    No configuration file is written yet; that step is still a TODO.
    """
    # TODO: persist the username to a config file once the format is decided.
    return input("Enter username: ")

def get_ieeg_data(username, password, filename, start, stop, select_elecs, ignore_elecs):
    """Dummy stand-in for downloading an iEEG.org clip; returns random noise.

    No network access happens. ``username``, ``password``, ``filename``,
    ``select_elecs`` and ``ignore_elecs`` are accepted only for signature
    compatibility and are ignored.

    Parameters
    ----------
    start, stop : float
        Clip bounds in seconds; ``stop`` must not be earlier than ``start``.

    Returns
    -------
    data : np.ndarray, shape (n_samples, 16)
        Standard-normal noise, ``n_samples = round((stop - start) * fs)``.
    fs : float
        Sampling rate, fixed at 1000.0 Hz.
    ch_names : np.ndarray of str
        ``"Ch0"`` ... ``"Ch15"``.

    Raises
    ------
    ValueError
        If ``stop < start``.
    """
    fs = 1000.0
    if stop < start:
        raise ValueError(f"stop ({stop}) must be >= start ({start})")
    # round() instead of plain int(): (0.3 - 0.1) * 1000 == 199.99999999999997,
    # which int() would truncate to 199 samples instead of 200.
    n_samples = int(round((stop - start) * fs))
    n_channels = 16  # dummy number
    data = np.random.randn(n_samples, n_channels)
    ch_names = np.array([f"Ch{i}" for i in range(n_channels)])
    return data, fs, ch_names

def clean_labels(ch_names):
    """Normalize channel labels by trimming surrounding whitespace and upper-casing.

    Parameters
    ----------
    ch_names : iterable of str
        Raw channel labels.

    Returns
    -------
    np.ndarray
        The normalized labels, in the original order.
    """
    normalized = map(lambda label: label.strip().upper(), ch_names)
    return np.array(list(normalized))

def find_non_ieeg(ch_names):
    """Dummy non-iEEG detector: marks every channel as iEEG.

    Parameters
    ----------
    ch_names : sized sequence
        Channel labels; only their count is used.

    Returns
    -------
    np.ndarray of bool, shape (len(ch_names),)
        All ``False``. Built with an explicit bool dtype so that an empty
        input still yields a usable boolean mask rather than a float64 array.
    """
    return np.zeros(len(ch_names), dtype=bool)

def identify_bad_chs(data, fs):
    """Flag channels (columns of ``data``) whose standard deviation is below 0.5.

    Parameters
    ----------
    data : array-like
        Signal matrix; statistics are taken along axis 0, one per column.
    fs : float
        Sampling rate; unused by this simple criterion.

    Returns
    -------
    bad : np.ndarray of bool
        ``True`` for each column whose std is below 0.5.
    reject_details : dict[int, str]
        Column index -> ``"Low variance"`` for every flagged column.
    """
    channel_std = np.std(data, axis=0)
    bad = channel_std < 0.5
    reject_details = {}
    for idx, is_bad in enumerate(bad):
        if is_bad:
            reject_details[idx] = "Low variance"
    return bad, reject_details

def bandpass_filter(data, fs, low_freq, high_freq):
    """Placeholder band-pass filter.

    No filtering is performed: ``fs``, ``low_freq`` and ``high_freq`` are
    ignored and the very same ``data`` object is returned.
    """
    passthrough = data  # dummy: identity
    return passthrough

def notch_filter(data, fs, notch_freq):
    """Placeholder notch filter.

    No filtering is performed: ``fs`` and ``notch_freq`` are ignored and the
    very same ``data`` object is returned.
    """
    passthrough = data  # dummy: identity
    return passthrough

def car(data, ch_names):
    """Common average reference: subtract the across-channel mean at every sample.

    Parameters
    ----------
    data : np.ndarray
        Signal matrix; the mean is taken along axis 1 (across columns).
    ch_names : sequence
        Channel labels, returned unchanged as a new array.

    Returns
    -------
    new_data : np.ndarray
        ``data`` minus its row-wise mean.
    ref_chnames : np.ndarray
        ``np.array(ch_names)``.
    """
    common_avg = np.mean(data, axis=1, keepdims=True)
    return data - common_avg, np.array(ch_names)

def bipolar(data, ch_names):
    """Bipolar re-reference: difference between each pair of adjacent channels.

    The output pair labelled ``"A-B"`` holds channel A minus channel B, in line
    with the conventional reading of bipolar labels. (``np.diff`` would compute
    B minus A, contradicting the label.) Adjacency is purely positional; no
    check is made that neighbouring columns belong to the same electrode shaft.

    Parameters
    ----------
    data : array-like, 2-D
        Signal matrix with one column per channel (indexed as ``data[:, i]``).
    ch_names : sequence
        Channel labels, one per column of ``data``.

    Returns
    -------
    new_data : np.ndarray, shape (data.shape[0], data.shape[1] - 1)
        ``data[:, i] - data[:, i + 1]`` for each adjacent pair.
    ref_chnames : np.ndarray of str
        ``"<ch_i>-<ch_i+1>"`` labels matching the columns of ``new_data``.
    """
    data = np.asarray(data)
    new_data = data[:, :-1] - data[:, 1:]
    ref_chnames = np.array(
        [f"{first}-{second}" for first, second in zip(ch_names[:-1], ch_names[1:])],
        dtype=str,
    )
    return new_data, ref_chnames

def get_elec_locs(filename, ch_names, loc_file):
    """Look up each channel's row in an electrode-location CSV.

    Parameters
    ----------
    filename : str
        Unused; kept for signature compatibility.
    ch_names : iterable of str
        Channel names to look up in the CSV's ``electrodeName`` column.
    loc_file : str or file-like
        Anything ``pd.read_csv`` accepts.

    Returns
    -------
    dict[str, dict]
        Channel name -> that row as a column->value dict (including
        ``electrodeName``). When a name appears more than once in the CSV, the
        first occurrence wins.

    Raises
    ------
    IndexError
        If any channel name is absent from the CSV. The type matches the error
        the per-name ``.iloc[0]`` lookup used to raise; the message now names
        the missing channels.
    """
    locs_df = pd.read_csv(loc_file)
    # Build the name -> row lookup once instead of boolean-filtering the whole
    # frame per channel; keep='first' reproduces the old ".iloc[0]" choice.
    first_rows = locs_df.drop_duplicates(subset='electrodeName', keep='first')
    lookup = {rec['electrodeName']: rec for rec in first_rows.to_dict('records')}
    missing = [name for name in ch_names if name not in lookup]
    if missing:
        raise IndexError(f"No location found in {loc_file!r} for electrode(s): {missing}")
    # dict(...) gives each caller-visible entry its own copy.
    return {name: dict(lookup[name]) for name in ch_names}

def laplacian(data, ch_names, locs, radius):
    """Placeholder Laplacian re-reference.

    ``locs`` and ``radius`` are ignored. The same ``data`` object is returned
    together with ``ch_names`` converted to a new array.
    """
    return data, np.array(ch_names)

def pre_whiten(data):
    """Placeholder pre-whitening step.

    No autocorrelation is removed; the very same ``data`` object is returned.
    """
    passthrough = data  # dummy: identity
    return passthrough

def bandpower(data, fs, band, window, relative):
    """Placeholder band power.

    All arguments are ignored. Returns a single float drawn uniformly from
    [0, 1) using NumPy's global random state.
    """
    return np.random.random_sample()

def line_length(data):
    """Line length per channel: sum of absolute sample-to-sample differences.

    Parameters
    ----------
    data : array-like
        Signal with time along axis 0 (differences are taken along axis 0).

    Returns
    -------
    np.ndarray
        One line-length value per column.
    """
    step_sizes = np.abs(np.diff(data, axis=0))
    return step_sizes.sum(axis=0)

def pearson(data, fs, win, win_size):
    """Pearson correlation matrix between the columns of ``data``.

    ``fs``, ``win`` and ``win_size`` are ignored: one correlation is computed
    over the whole signal, with no windowing.
    """
    channels_by_time = np.transpose(data)
    return np.corrcoef(channels_by_time)

def squared_pearson(data, fs, win, win_size):
    """Element-wise square of the Pearson correlation between columns of ``data``.

    ``fs``, ``win`` and ``win_size`` are ignored; there is no windowing.
    """
    return np.square(np.corrcoef(np.transpose(data)))

def cross_correlation(data, fs, win, win_size):
    """Zero-lag, un-normalized cross-correlation between every pair of channels.

    For each pair of columns ``i, j`` this is ``sum(data[:, i] * data[:, j])``.
    That is exactly what ``np.correlate(x, y)`` returns in its default 'valid'
    mode for equal-length real inputs, so the full matrix is ``X.T @ X``. It is
    computed in one vectorized product instead of a Python double loop.
    ``fs``, ``win`` and ``win_size`` are ignored; no windowing or lag search
    is done.

    Parameters
    ----------
    data : array-like, 2-D
        Signal matrix with one column per channel.

    Returns
    -------
    corr : np.ndarray of float, shape (n_channels, n_channels)
        Symmetric matrix of zero-lag inner products.
    lags : None
        Placeholder; lags are not computed.
    """
    # Cast to float so integer input yields a float matrix, as before.
    arr = np.asarray(data, dtype=float)
    corr = arr.T @ arr
    return corr, None

def coherence(data, fs, win, win_size, segment, overlap):
    """Placeholder coherence.

    Only the channel count (``data.shape[1]``) is used. Returns an
    ``(n_channels, n_channels)`` matrix of uniform [0, 1) draws from NumPy's
    global random state.
    """
    n_ch = np.shape(data)[1]
    return np.random.random_sample((n_ch, n_ch))

def plv(data, fs, win, win_size):
    """Placeholder phase-locking value.

    Only the channel count (``data.shape[1]``) is used. Returns an
    ``(n_channels, n_channels)`` matrix of uniform [0, 1) draws from NumPy's
    global random state.
    """
    n_ch = np.shape(data)[1]
    return np.random.random_sample((n_ch, n_ch))

def relative_entropy(data, fs, win, win_size):
    """Placeholder relative entropy.

    Only the channel count (``data.shape[1]``) is used. Returns an
    ``(n_channels, n_channels)`` matrix of uniform [0, 1) draws from NumPy's
    global random state.
    """
    n_ch = np.shape(data)[1]
    return np.random.random_sample((n_ch, n_ch))

def plot_ieeg_data(data, ch_names, t, offset=5):
    """Plot every channel as a vertically offset trace against time.

    Parameters
    ----------
    data : np.ndarray, 2-D
        One column per channel (indexed as ``data[:, i]``), rows aligned with ``t``.
    ch_names : sequence
        Channel labels. When there is exactly one per column of ``data``, they
        label each trace on the y-axis at that trace's baseline.
    t : array-like
        Time axis in seconds; the x values for every trace.
    offset : float, optional
        Vertical spacing between consecutive traces. Defaults to 5, the
        previously hard-coded value.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure; it is neither shown nor closed here.
    """
    fig, ax = plt.subplots()
    n_channels = data.shape[1]
    baselines = [i * offset for i in range(n_channels)]
    for i, base in enumerate(baselines):
        ax.plot(t, data[:, i] + base)  # Offset each channel for clarity.
    # Only label when names and columns line up one-to-one; otherwise
    # matplotlib would reject a tick/label count mismatch.
    if len(ch_names) == n_channels:
        ax.set_yticks(baselines)
        ax.set_yticklabels([str(name) for name in ch_names])
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amplitude")
    ax.set_title("iEEG Data Plot")
    return fig

In [9]:
def get_clean_hup_file_paths(base_path, expected_n_files=20,
                             prefix="interictal_eeg_bipolar_clean_"):
    """Map each ``sub-*`` folder under ``base_path`` to its numbered epoch files.

    Parameters
    ----------
    base_path : str
        Directory containing one ``sub-*`` folder per subject.
    expected_n_files : int, optional
        Number of epoch files each subject should have. A warning is printed
        when a subject deviates. Defaults to 20.
    prefix : str, optional
        Filename prefix of the epoch pickles (``.pkl`` suffix is required).

    Returns
    -------
    dict[str, dict[int, str]]
        ``{subject: {1: path, 2: path, ...}}``. Subjects are in sorted order.
        Epoch numbers follow the *numeric* order of the filenames, so
        ``..._2.pkl`` is epoch 2 even when ``..._10.pkl`` exists. A plain
        string sort would place ``_10`` before ``_2``.
    """
    import re

    def _natural_key(name):
        # re.split with a capturing group alternates text / digit runs, so the
        # odd positions are always digit runs; compare those as integers.
        return [int(tok) if pos % 2 else tok
                for pos, tok in enumerate(re.split(r"(\d+)", name))]

    subjects_dict = {}
    subject_dirs = sorted(d for d in os.listdir(base_path)
                          if os.path.isdir(os.path.join(base_path, d)) and d.startswith("sub-"))
    for subject in subject_dirs:
        subject_path = os.path.join(base_path, subject)
        pkl_files = sorted((f for f in os.listdir(subject_path)
                            if f.startswith(prefix) and f.endswith('.pkl')),
                           key=_natural_key)
        if len(pkl_files) != expected_n_files:
            print(f"WARNING: {subject} has {len(pkl_files)} files instead of {expected_n_files}.")
        subjects_dict[subject] = {idx: os.path.join(subject_path, filename)
                                  for idx, filename in enumerate(pkl_files, start=1)}
    return subjects_dict

def load_epoch(file_path):
    """Load one cleaned epoch pickle into a per-channel DataFrame.

    If the pickle holds a dict, its ``'metadata'`` frame is copied and given a
    ``'data'`` column with one array per row:

    * ``'data'`` as a DataFrame: it is transposed when it has more rows than
      columns, so the longer axis (presumably samples; TODO confirm) runs
      along the columns. Each row's values then become one entry. The
      assignment aligns on index labels, so the data frame's row labels
      should match the metadata index.
    * any other ``'data'``: treated as an iterable with one item per metadata
      row; each item is converted with ``np.asarray``.

    Any non-dict object is returned unchanged.

    Raises
    ------
    KeyError
        If the dict has no ``'metadata'`` entry.

    Notes
    -----
    ``pd.read_pickle`` can execute arbitrary code; only load trusted files.
    """
    obj = pd.read_pickle(file_path)
    if not isinstance(obj, dict):
        return obj

    metadata = obj.get('metadata')
    if metadata is None:
        raise KeyError(f"{file_path}: epoch pickle has no 'metadata' entry")
    df = metadata.copy()
    data_obj = obj.get('data')
    if isinstance(data_obj, pd.DataFrame):
        if data_obj.shape[0] > data_obj.shape[1]:
            data_obj = data_obj.T
        df['data'] = data_obj.apply(lambda row: row.values, axis=1)
    else:
        # Fill a 1-D object array element by element. A list of equal-length
        # arrays can otherwise be interpreted by pandas as a 2-D block, which
        # cannot be stored in a single column.
        series = [np.asarray(x) for x in data_obj]
        column = np.empty(len(series), dtype=object)
        for i, arr in enumerate(series):
            column[i] = arr
        df['data'] = column
    return df

# Root of the cleaned HUP derivatives. Set the HUP_CLEAN_DIR environment
# variable to point elsewhere instead of editing this machine-specific path.
base_path = os.environ.get(
    "HUP_CLEAN_DIR",
    "/Users/tereza/nishant/atlas/atlas_work_terez/atlas_harmonization/Data/hup/derivatives/clean",
)

# Get file paths for each subject.
subject_files = get_clean_hup_file_paths(base_path)

# subject -> number of rows in that subject's first epoch.
subject_electrode_counts = {}

# For each subject, load the first epoch and count electrodes.
for subject, epochs in subject_files.items():
    if 1 in epochs:
        df = load_epoch(epochs[1])
        # Assumes each row in df is one (bipolar) channel — TODO confirm.
        electrode_count = df.shape[0]
        subject_electrode_counts[subject] = electrode_count
        print(f"{subject}: {electrode_count} electrodes")
    else:
        print(f"{subject}: No epoch 1 found.")

results_df = pd.DataFrame.from_dict(subject_electrode_counts, orient='index', columns=['Electrode Count'])
print(results_df)

sub-RID0031: 81 electrodes
sub-RID0032: 29 electrodes
sub-RID0033: 90 electrodes
sub-RID0050: 32 electrodes
sub-RID0051: 61 electrodes
sub-RID0064: 102 electrodes
sub-RID0089: 104 electrodes
sub-RID0101: 45 electrodes
sub-RID0117: 95 electrodes
sub-RID0143: 88 electrodes
sub-RID0167: 69 electrodes
sub-RID0175: 8 electrodes
sub-RID0179: 82 electrodes
sub-RID0238: 78 electrodes
sub-RID0301: 67 electrodes
sub-RID0320: 58 electrodes
sub-RID0381: 75 electrodes
sub-RID0405: 102 electrodes
sub-RID0424: 122 electrodes
sub-RID0508: 47 electrodes
sub-RID0562: 31 electrodes
sub-RID0589: 83 electrodes
sub-RID0658: 126 electrodes
             Electrode Count
sub-RID0031               81
sub-RID0032               29
sub-RID0033               90
sub-RID0050               32
sub-RID0051               61
sub-RID0064              102
sub-RID0089              104
sub-RID0101               45
sub-RID0117               95
sub-RID0143               88
sub-RID0167               69
sub-RID0175                

In [17]:
# Re-define the file-path and epoch-loading helpers from the earlier cell so this cell can run on its own.
def get_clean_hup_file_paths(base_path, expected_n_files=20,
                             prefix="interictal_eeg_bipolar_clean_"):
    """Map each ``sub-*`` folder under ``base_path`` to its numbered epoch files.

    Parameters
    ----------
    base_path : str
        Directory containing one ``sub-*`` folder per subject.
    expected_n_files : int, optional
        Number of epoch files each subject should have. A warning is printed
        when a subject deviates. Defaults to 20.
    prefix : str, optional
        Filename prefix of the epoch pickles (``.pkl`` suffix is required).

    Returns
    -------
    dict[str, dict[int, str]]
        ``{subject: {1: path, 2: path, ...}}``. Subjects are in sorted order.
        Epoch numbers follow the *numeric* order of the filenames, so
        ``..._2.pkl`` is epoch 2 even when ``..._10.pkl`` exists. A plain
        string sort would place ``_10`` before ``_2``.
    """
    import re

    def _natural_key(name):
        # re.split with a capturing group alternates text / digit runs, so the
        # odd positions are always digit runs; compare those as integers.
        return [int(tok) if pos % 2 else tok
                for pos, tok in enumerate(re.split(r"(\d+)", name))]

    subjects_dict = {}
    subject_dirs = sorted(d for d in os.listdir(base_path)
                          if os.path.isdir(os.path.join(base_path, d)) and d.startswith("sub-"))
    for subject in subject_dirs:
        subject_path = os.path.join(base_path, subject)
        pkl_files = sorted((f for f in os.listdir(subject_path)
                            if f.startswith(prefix) and f.endswith('.pkl')),
                           key=_natural_key)
        if len(pkl_files) != expected_n_files:
            print(f"WARNING: {subject} has {len(pkl_files)} files instead of {expected_n_files}.")
        subjects_dict[subject] = {idx: os.path.join(subject_path, filename)
                                  for idx, filename in enumerate(pkl_files, start=1)}
    return subjects_dict

def load_epoch(file_path):
    """Load one cleaned epoch pickle into a per-channel DataFrame.

    If the pickle holds a dict, its ``'metadata'`` frame is copied and given a
    ``'data'`` column with one array per row:

    * ``'data'`` as a DataFrame: it is transposed when it has more rows than
      columns, so the longer axis (presumably samples; TODO confirm) runs
      along the columns. Each row's values then become one entry. The
      assignment aligns on index labels, so the data frame's row labels
      should match the metadata index.
    * any other ``'data'``: treated as an iterable with one item per metadata
      row; each item is converted with ``np.asarray``.

    Any non-dict object is returned unchanged.

    Raises
    ------
    KeyError
        If the dict has no ``'metadata'`` entry.

    Notes
    -----
    ``pd.read_pickle`` can execute arbitrary code; only load trusted files.
    """
    obj = pd.read_pickle(file_path)
    if not isinstance(obj, dict):
        return obj

    metadata = obj.get('metadata')
    if metadata is None:
        raise KeyError(f"{file_path}: epoch pickle has no 'metadata' entry")
    df = metadata.copy()
    data_obj = obj.get('data')
    if isinstance(data_obj, pd.DataFrame):
        if data_obj.shape[0] > data_obj.shape[1]:
            data_obj = data_obj.T
        df['data'] = data_obj.apply(lambda row: row.values, axis=1)
    else:
        # Fill a 1-D object array element by element. A list of equal-length
        # arrays can otherwise be interpreted by pandas as a 2-D block, which
        # cannot be stored in a single column.
        series = [np.asarray(x) for x in data_obj]
        column = np.empty(len(series), dtype=object)
        for i, arr in enumerate(series):
            column[i] = arr
        df['data'] = column
    return df

# Root of the cleaned HUP derivatives. Set the HUP_CLEAN_DIR environment
# variable to point elsewhere instead of editing this machine-specific path.
base_path = os.environ.get(
    "HUP_CLEAN_DIR",
    "/Users/tereza/nishant/atlas/atlas_work_terez/atlas_harmonization/Data/hup/derivatives/clean",
)
file_paths = get_clean_hup_file_paths(base_path)

# Fail with a readable message instead of an IndexError from list(...)[0].
if not file_paths:
    raise FileNotFoundError(f"No 'sub-*' subject folders found under {base_path}")
subject = next(iter(file_paths))               # choose first subject
if not file_paths[subject]:
    raise FileNotFoundError(f"{subject} has no clean epoch files")
epoch_key = next(iter(file_paths[subject]))    # choose first epoch
epoch_file = file_paths[subject][epoch_key]

# Load the epoch data
df = load_epoch(epoch_file)

# Stack each electrode's time series (each row in df['data']) into a 2D array:
# data shape: (time, channels)
data = np.column_stack(df['data'].values)
ch_names = df.index.values  # assuming the DataFrame index has electrode labels

# Sampling frequency is assumed, not read from the file — TODO confirm the true rate.
fs = 1000.0  # Hz
start = 0
stop = data.shape[0] / fs  # duration in seconds

# Create an iEEGData instance; note that filename is arbitrary here
ieeg = iEEGData(filename="example_epoch", start=start, stop=stop, data=data, fs=fs, ch_names=ch_names)

# Compute cross-correlation (windowed with 2-sec window)
ieeg.cross_corr(win=True, win_size=2)

# Compute coherence (using a 2-sec window, 1-sec segments with 0.5-sec overlap)
ieeg.coherence(win=True, win_size=2, segment=1, overlap=0.5)

# Compute Pearson correlation (windowed)
ieeg.pearson(win=True, win_size=2)

# Compute line length
ll = ieeg.line_length()

# Print out some results
print("Cross-correlation shape:", ieeg.conn["cross_corr"].shape)
print("Coherence shape:", ieeg.conn["coh"].shape)
print("Pearson correlation shape:", ieeg.conn["pearson"].shape)
print("Line length:", ll)


Cross-correlation shape: (81, 81)
Coherence shape: (81, 81, 7)
Pearson correlation shape: (81, 81)
Line length: [ 4.82027033  2.38358265  1.75585806  2.4995317   3.16329884  5.91096975
 12.50842765  8.87184142  4.28047809  2.5909629   2.32105354  2.09084481
  2.26818624  2.38404131  2.87973314  2.39244793  4.46663646 12.46826165
  6.17056843  2.75519093  1.76225699  1.46379304  1.83400456  4.45243274
  4.32839807  2.12048065  7.56739053  9.12891943  5.96646244  3.05332233
  1.38949696 10.92783461 10.43728509  8.86625762  5.40695825  3.25816594
  3.22013808  2.7233554   4.30113789  3.07454238  5.92775183 14.6232645
 14.35938695  4.13933805  3.29334034  5.96492859  3.74507173  2.70169612
  9.10757212  4.78196428  5.9187558   3.42845945  4.76673177  6.96830275
  7.74207147  7.99881656 17.71962455 18.45744645 17.56824361  3.82787304
  5.77206987  8.54549252  4.99019639  3.42664813  4.14815821  4.26452196
  4.25385373  4.9443768   3.13908282  2.75740188  2.7844904   2.97315544
  3.37875696 

In [None]:
# Rebuild the (time, channels) matrix: one column per row of df['data'].
data = np.column_stack(list(df['data']))