In [None]:
!pip install elephant



In [None]:
import pandas as pd
import numpy as np
import json
import os
import torch

from google.colab import drive
import pickle

import plotly.express as plx
import matplotlib.pyplot as plt


import seaborn as sns

import gc

import scipy.signal as signal
from scipy.signal import welch, find_peaks

from neo import SpikeTrain
from elephant.conversion import BinnedSpikeTrain
import elephant.spike_train_correlation as elstc
import elephant.statistics as elstat
import quantities as pq
from elephant.statistics import isi

In [None]:
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Loading training dataset 

**LOADING DATASETS**

In [None]:
import pickle

# List of PCB indices
pcbs = [9, 14, 15, 16, 17, 18, 42, 48, 49]

# List to store DataFrames - for each PCB
df_list = []
# path to where you saved the feature dataset
path = "/content/drive/MyDrive/trainning_dataset"
# Loop through the PCB indices and load the corresponding pickle files
for pcb in pcbs:
    try:
        # Load the pickle file into a DataFrame
        file_path = os.path.join(path, f"pcb{pcb}.pkl")
        with open(file_path, "rb") as file:
            final_features = pickle.load(file)
            if pcb == 42 or pcb == 48 or pcb == 49:
                final_features = final_features.set_index(
                    ["day", "time_offset", "strain"]
                )
            df_list.append(final_features)
            print(f"Loaded PCB {pcb} successfully.")
    except FileNotFoundError:
        print(f"File for PCB {pcb} not found.")
    except Exception as e:
        print(f"Error loading PCB {pcb}: {e}")

# df_list contains all the features extracted for each pcb

Loaded PCB 9 successfully.
Loaded PCB 14 successfully.
Loaded PCB 15 successfully.
Loaded PCB 16 successfully.
Loaded PCB 17 successfully.
Loaded PCB 18 successfully.
Loaded PCB 42 successfully.
Loaded PCB 48 successfully.
Loaded PCB 49 successfully.


In [None]:
def element_wise_average(column):
    """
    Parameters:
        column: for matching rows the column where we need to calculate the element-wise avergae
    Returns: element-wise average
    """
    stacked = np.stack(column.to_numpy())
    # Calculate the mean along rows, not taking into count the nan values
    return np.nanmean(stacked, axis=0).tolist()

In [None]:
def row_element_wise_average(row, columns):
    """
    Parameters:
        row: the row which we are calculating the average for
        columns: the columns from which we need to calculate the average
    Returns: average across the columns in the row
    """
    stacked = np.stack([row[col] for col in columns])
    # Compute the element-wise mean
    return np.mean(stacked, axis=0).tolist()

In [1]:
def has_list_with_nan(element):
    """
    Parameters:
        element: element from dataframe, list or float
    Returns: Bool, check if list contains nan values
    """
    if isinstance(element, list):
        return any(
            pd.isna(x) for x in element
        )  # Check if any element in the list is NaN
    return False

**The section below processes the features from each PCB, getting rid of all the NaN values**

1. If there is a combination of (day,time_offset, strain) that matches the one of the element with NaN then we calculate the average across PCBs

2. If there are no other combinations of (day, time_offset,strain) that match from other PCBs then we calculate the average across the same PCB but amongs different MEA with data, because they are lists of features, we do element wise average across other MEAs

In [None]:
match_columns = ["day", "time_offset", "strain"]

# Loop through each DataFrame and fill NaNs
for idx, df in enumerate(df_list):
    print(f"Processing PCB {idx+1}...")
    nan_rows = df[df.isna().any(axis=1)]
    # look for columns with NaN values (either float NaN or any NaN inside a list)
    columns_with_nan_in_lists = df.applymap(has_list_with_nan).any()
    columns_with_nan = columns_with_nan_in_lists[
        columns_with_nan_in_lists
    ].index.tolist()

    for row_index, row in nan_rows.iterrows():
        nan_columns = row[row.isna()].index.tolist()
        nan_columns.extend(columns_with_nan)
        first = True
        for ind, other_df in enumerate(df_list):
            # looking for PCBs with the same (day,time_offset,strain)
            if ind != idx:

                mask = (
                    (other_df.index.get_level_values("day") == row_index[0])
                    & (other_df.index.get_level_values("time_offset") == row_index[1])
                    & (other_df.index.get_level_values("strain") == row_index[2])
                )
                if first:
                    first = False
                    matching_rows = other_df.loc[mask]
                    matching_rows = matching_rows[nan_columns]
                else:
                    current = other_df.loc[mask]
                    current = current[nan_columns]
                    matching_rows = pd.concat([matching_rows, current], axis=0)
        # case 2
        if matching_rows.empty:
            # get the avergae of the surrounding MEA for MEA replacement and PSD for PSD replacement and Interactions for interactions
            for col in nan_columns:
                if col.startswith("MEA"):
                    # mea
                    other_mea = [
                        x
                        for x in df.columns
                        if x.startswith("MEA") and x not in nan_columns
                    ]
                    avg_value = row_element_wise_average(row, other_mea)
                    df.at[row_index, col] = avg_value
                elif col.startswith("PSD"):
                    # psd
                    other_psd = [
                        x
                        for x in df.columns
                        if x.startswith("PSD") and x not in nan_columns
                    ]
                    avg_value = row_element_wise_average(row, other_psd)
                    df.at[row_index, col] = avg_value
                else:
                    # interaction
                    other_int = [
                        x
                        for x in df.columns
                        if x.startswith("Interaction") and x not in nan_columns
                    ]
                    avg_value = row_element_wise_average(row, other_int)
                    df.at[row_index, col] = avg_value
        # case 1
        else:
            averages = {
                col: element_wise_average(matching_rows[col])
                for col in matching_rows.columns
            }
            for col in nan_columns:
                if not np.any(np.isnan(averages[col])):
                    df.at[row_index, col] = averages[col]
                else:
                    # Example Interactions where all interactions are nan
                    other_int = [
                        x
                        for x in df.columns
                        if x.startswith("Interaction") and x not in nan_columns
                    ]
                    avg_value = row_element_wise_average(row, other_int)
                    df.at[row_index, col] = avg_value

print("NaN values processed successfully.")

Processing PCB 1...
Processing PCB 2...
Processing PCB 3...
Processing PCB 4...
Processing PCB 5...
Processing PCB 6...
Processing PCB 7...
Processing PCB 8...


  columns_with_nan_in_lists = df.applymap(has_list_with_nan).any()
  columns_with_nan_in_lists = df.applymap(has_list_with_nan).any()
  columns_with_nan_in_lists = df.applymap(has_list_with_nan).any()
  columns_with_nan_in_lists = df.applymap(has_list_with_nan).any()
  columns_with_nan_in_lists = df.applymap(has_list_with_nan).any()
  columns_with_nan_in_lists = df.applymap(has_list_with_nan).any()
  columns_with_nan_in_lists = df.applymap(has_list_with_nan).any()
  columns_with_nan_in_lists = df.applymap(has_list_with_nan).any()
  return np.nanmean(stacked, axis=0).tolist()
  return np.nanmean(stacked, axis=0).tolist()
  return np.nanmean(stacked, axis=0).tolist()
  return np.nanmean(stacked, axis=0).tolist()
  return np.nanmean(stacked, axis=0).tolist()
  return np.nanmean(stacked, axis=0).tolist()
  columns_with_nan_in_lists = df.applymap(has_list_with_nan).any()


Processing PCB 9...
NaN values processed successfully.


In [None]:
df_list[-1]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,MEA1,PSD1,MEA2,PSD2,MEA3,PSD3,MEA4,PSD4,Network,Interaction 01,Interaction 02,Interaction 03,Interaction 12,Interaction 13,Interaction 23
day,time_offset,strain,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
1,0min,0.0,"[1.30995920110016, 192.7524065314985, 1.267024...","[-1.1405462424308969e-39, 2.8566506315285286e-...","[1.1641854290494424, 12.384501945440306, 1.436...","[3.356178462780562e-39, -6.594476728901217e-35...","[0.302428180649822, 17.073545420441246, 1.2782...","[-2.8555206713371538e-39, 6.072248837700805e-3...","[1.6008194032508645, 11.141809618035008, 1.852...","[2.6121950687323693e-39, -5.13223902646633e-35...","[3.333333347654843e-05, 0, 0.0017183548768257453]","[0.022914814389475607, 1505.0, 432051.0]","[-0.0015074103301396787, 932.625, 18063.875]","[-0.0020173882932514923, 1953.0, 687992.0]","[-0.0006231851587171551, 3126.0, 104810.142857...","[-0.002914926263648261, 22088.0, 760878.0]","[-0.0006340271773205388, 83.57142857142857, 63..."
2,0min,0.0,"[10.918426887350787, 23.804986855757974, 2.899...","[-1.2048258781108663e-37, 2.497412429553849e-3...","[0.19540149742189908, 53.80203000224559, 0.850...","[4.000805583265674e-39, -7.673274280490842e-35...","[1.7744209145886662, 14.899361927220424, 1.744...","[-3.787907097885168e-39, 8.034385254476293e-35...","[1.8293606018686883, 33.5825755217952, 2.32177...","[3.5241983578533977e-39, -6.90493018400456e-35...","[3.3333333324965984e-05, 110.99088666666667, 0...","[-0.020436004275817717, 20805.0, 1466615.0]","[0.014077331815813422, 8371.75, 409758.0]","[-0.01948636279415361, 108864.0, 6512493.0]","[-0.008182767485627273, 4540.857142857143, 127...","[0.01136509231565518, 33345.0, 342971.0]","[0.0070614521812397335, 362.42857142857144, 17..."
4,0min,0.0,"[0.16259933021963616, 20.465872897887927, 0.96...","[-5.811042296430387e-39, 1.2097192028952048e-3...","[0.6643589542527538, 28.715984393287727, 1.116...","[-7.514373226963374e-39, 1.5598181663453119e-3...","[0.8316563478499999, 4.8098141230798905, 1.413...","[-2.888025901166683e-39, 6.098748785634295e-35...","[2.559489589031948, 17.202194785103583, 1.1389...","[-1.185124723590988e-39, 2.4994678202452747e-3...","[0, 0, 0]","[-0.005582650739614038, 0.0, 214048.0]","[0.06464519108008718, 2999.0, 199843.57142857142]","[0.006226830914323413, 44505.0, 702852.0]","[0.01205354903658526, 8627.714285714286, 15786...","[0.005586010577192652, 110079.0, 2053493.0]","[0.05845903874489627, 801.6666666666667, 9477...."
7,0min,0.0,"[0.3499000117936424, 88.00854966312755, 0.6601...","[1.8392113387261172e-39, -3.1525243683371016e-...","[1.3991704584553633, 14.889025272242911, 1.073...","[-9.75456956189694e-38, 2.0322376578694627e-33...","[6.02082909591794, 0.7640875969387015, 4.79230...","[9.990516023047597e-39, -1.974665413264555e-34...","[2.472927462111851, 3.021164735592773, 1.38302...","[2.7598109886680564e-39, -5.259750443388845e-3...","[3.333333337910257e-05, 49.21561111111104, 0.0...","[0.0013429033596390812, 42525.0, 709865.0]","[-0.010840342942434425, 101959.0, 10444062.0]","[0.021470444625041954, 54180.0, 1132248.0]","[0.020881620369460188, 259200.0, 4423058.0]","[-0.0038439779556556615, 96336.0, 3607968.0]","[0.00019678131185787406, 120846.33333333333, 5..."
9,0min,0.0,"[0.1729094594805906, 32.129878875564835, 1.241...","[-1.4418760541403777e-39, 3.6519011591023956e-...","[0.2821538707747874, 38.8650243897466, 0.97098...","[-4.896915173999666e-39, 1.0951206393014804e-3...","[4.4442384964303185, 4.375610116809677, 1.8031...","[8.315145405675651e-39, -1.6140552683861525e-3...","[3.350792009810203, 6.850001780488239, 1.34689...","[1.763926999394718e-39, -3.142663925250267e-35...","[3.3333333306018176e-05, 81.64580000000001, 0....","[-0.014264051323746215, 3283.0, 63129.0]","[0.01265715909071463, 112056.0, 12540492.0]","[0.01304791248827665, 6490.0, 806197.0]","[0.015911006533044943, 36478.0, 5752650.0]","[0.0015277733802312268, 0.0, 948694.0]","[0.006151312433403697, 212434.33333333334, 882..."
9,10min,0.0,"[1.806635358622668, 0.9898133635550284, 1.4715...","[-1.0883773526485608e-38, 2.320527333042545e-3...","[0.7590218970207435, 79.28780832879193, 0.8687...","[-5.3754748504380086e-39, 1.196064571634092e-3...","[3.759474078015259, 0.5975498004187171, 2.1065...","[6.353538568016288e-39, -1.2304774044646087e-3...","[5.407233208651413, 1.5583007467772894, 1.8254...","[-8.597568712546851e-39, 1.825201428571871e-34...","[3.3333333337597854e-05, 16.297147619047614, 0...","[0.005297968909133699, 9315.0, 1632607.0]","[-0.0163402545812258, 23722.0, 3212950.0]","[0.005869502224520241, 106056.0, 9543477.0]","[-0.011563679336314851, 990.0, 8982885.0]","[0.002275463921877388, 117952.0, 4156812.0]","[-0.0037351422908292304, 40802.333333333336, 5..."
9,1h,0.0,"[2.3762700002904023, 0.9743846926820419, 1.200...","[-2.927041326196673e-38, 6.1076210389472645e-3...","[1.7097631024416162, 22.252688731386858, 0.957...","[-7.515809249821597e-39, 1.6487863559146889e-3...","[6.211852879725068, 0.6666931554561157, 1.3574...","[8.099019260495287e-39, -1.560743833410846e-34...","[7.671810141948143, 0.9938154168425007, 1.7697...","[-1.0585021154452379e-38, 2.2416989878039803e-...","[3.33333333284276e-05, 7.17387532467533, 0.134...","[-0.00574472791854121, 1455.0, 3668318.0]","[-0.006792232102475748, 562176.0, 19880277.0]","[-0.0058394017368592396, 137696.0, 15017068.0]","[0.014448689938527794, 49275.0, 8887947.0]","[-0.0012649349019775448, 234094.0, 10006450.0]","[0.0027429869441061724, 216297.0, 14699787.0]"
9,2h,0.0,"[3.2481203059198647, 0.4827270171126714, 1.194...","[-1.590653710252958e-38, 3.358969150352845e-34...","[2.9361775259589615, 82.2269892319733, 0.96233...","[-1.0371109164244281e-38, 2.229548157891138e-3...","[5.779113843244452, 17.112356920195065, 1.0720...","[2.1803681004509772e-39, -3.4622013978813017e-...","[9.929513655737567, 0.4819130132644648, 2.0903...","[-1.1096073833009987e-38, 2.341215044352428e-3...","[3.3333333336470006e-05, 2.9718165811965775, 0...","[0.01872344670088715, 9198.0, 7016372.0]","[0.02402154288051678, 296770.0, 29104141.0]","[0.016797209501524395, 230584.0, 22353783.0]","[0.02793060792099104, 150220.0, 9295209.0]","[0.009998465559291233, 323948.0, 18618914.0]","[0.02425220809044273, 407748.0, 19507939.0]"
10,0min,0.0,"[0.7919468038570654, 12.554064983975929, 1.176...","[-1.5437716296736285e-39, 3.723196508352809e-3...","[0.8660508579201756, 35.91080316128227, 1.1653...","[-1.662862155524021e-39, 4.181521279916767e-35...","[8.247530623804634, 0.8604135635499476, 1.4325...","[3.976277779435765e-39, -7.243634354452856e-35...","[6.44017928398329, 0.8234169546789734, 1.73644...","[-5.067960771507789e-39, 1.093465544881222e-34...","[3.33333333074757e-05, 14.854453508771957, 0.0...","[-0.0004961659697460689, 33000.0, 786520.0]","[0.0038036109894972903, 148242.0, 36983461.0]","[0.05367348456351087, 236089.0, 4930254.0]","[-0.0137726505154558, 292123.0, 13277953.0]","[-0.0012049121680575202, 218762.0, 5069327.0]","[-0.008710066827262443, 173897.66666666666, 23..."
10,10min,7.5,"[0.18386397182035474, 34.014689030181756, 1.04...","[2.004476423392174e-39, -3.4609135380589816e-3...","[1.1903289711040055, 18.06018626113619, 1.2305...","[2.0579003802317676e-39, -3.3862099835991817e-...","[2.9134502370523747, 17.579942236177395, 1.372...","[-2.527645417273867e-41, 7.561509062015487e-36...","[7.366157768232764, 0.6649514172142311, 1.8422...","[-4.1382061661421575e-39, 9.115576240262809e-3...","[3.3333333345929514e-05, 12.123489855072453, 0...","[-0.0097839479760609, 4940.0, 283038.0]","[-0.006197693496714517, 150333.33333333334, 31...","[0.0024967280482304606, 36040.0, 1755019.0]","[-0.006197693496714517, 150333.33333333334, 31...","[-0.011305860562313112, 410020.0, 7271153.0]","[-0.006197693496714517, 150333.33333333334, 31..."


In [None]:
def concat_data(df):
    """
    Parameters:
      df: dataframe with features for one PCB
    Returns:
      df: dataframe with each feature in an individual column
    """
    data = []
    index = []

    for header, row in df.iterrows():
        day = header[0]
        time_offset = header[1]
        strain = header[2]
        pcb_row = df.loc[day, time_offset, strain]

        # Flatten the data in pcb_row
        new_row = []
        for col_name, value in pcb_row.items():
            value = np.array(value)
            if not np.isnan(value).any():
                if col_name.startswith("MEA"):
                    mea_len = len(value)
                elif col_name.startswith("PSD"):
                    psd_len = len(value)
                new_row.extend(list(value))
            else:
                if col_name.startswith("MEA"):
                    new_row.extend([np.nan] * mea_len)
                elif col_name.startswith("PSD"):
                    new_row.extend([np.nan] * psd_len)

        data.append(new_row)
        index.append((day, time_offset, strain))

    # Construct the new DataFrame
    new_df = pd.DataFrame(
        data,
        index=pd.MultiIndex.from_tuples(index, names=["day", "time_offset", "strain"]),
    )

    return new_df

In [None]:
df_strained_list = df_list[-3:]

In [None]:
store_path = "/content/drive/MyDrive/data"
for idx, df in enumerate(df_strained_list):
    df = concat_data(df)  # separating each feature into individual columns
    store_path = os.path.join(store_path, f"data_strained_{idx}.pkl")
    df.to_pickle(store_path)

                             0          1         2          3          4    \
day time_offset strain                                                        
1   0min        0.0     0.624106  71.917299  1.969052  11.848490  14.442471   
2   0min        0.0     1.377777  98.514155  0.711785  11.713569  18.086366   
5   0min        0.0     1.660347  37.993136  0.986114  10.524614  14.940747   
6   0min        0.0     1.453041  60.086386  0.796751   5.417581  16.669587   
    10min       5.0     1.304653  47.345269  1.380794   6.521497  16.628831   
    1h          5.0     4.495904   1.811941  0.880507  12.075826  19.114785   
    2h          5.0     4.376994  70.615723  0.868934   3.974497  17.674250   
    10min       7.5     1.289625  36.121969  1.837080   4.417029  17.694718   
    1h          7.5     4.460764  30.703112  0.889857   9.269768  15.778724   
    2h          7.5     3.105497   5.085435  0.818294  13.548748  15.821928   
    10min       10.0    3.081869  18.517430  1.44926

In [None]:
store_path = "/content/drive/MyDrive/data"
for idx, df in enumerate(df_list):
    df = concat_data(df)  # separating each feature into individual columns
    store_path = os.path.join(store_path, f"data_{idx}.pkl")
    df.to_pickle(store_path)  # save the fully processed datasets