# Import

In [21]:
import os
import re
import gc
import sys

from loguru import logger
import numpy as np
import random

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# %matplotlib qt
%matplotlib qt

# Pick the compute device once: CUDA when PyTorch reports an available GPU, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input Layer

## Definition

In [22]:
# Keep only the measurements indexed by each channel's "eis_seq"
ONLY_SEQ_FLAG = True
# True: rebuild features from the per-electrode .pt files; False: load the cached .npz
READ_RAW_FLAG = True
# 101 evenly spaced integer indices spanning 0..4999 (both endpoints included),
# used to subsample 5000-point spectra
freq_list = np.linspace(start=0, stop=5000 - 1, num=101, endpoint=True, dtype=int)

In [23]:
def SearchELE(rootPath, ele_pattern = re.compile(r"(.+?)_归档")):
    '''==================================================
        Recursively search rootPath for electrode directories.

        A directory whose name matches ``ele_pattern`` is recorded as an
        electrode and is NOT descended into; any other directory is
        searched recursively. Plain files are ignored.

        Entries are visited in sorted name order so the returned list is
        reproducible across machines and filesystems (os.listdir order is
        arbitrary). This matters because callers store the position in
        this list as the electrode index in their id arrays.

        Parameter:
            rootPath: directory to search
            ele_pattern: compiled regex applied with ``match`` (anchored at
                the start of the directory name); group(1) is taken as the
                electrode name
        Return:
            ele_list: list of [electrode_dir_path, electrode_name] pairs
        ==================================================
    '''
    ele_list = []
    for entry in sorted(os.listdir(rootPath)):
        _path = os.path.join(rootPath, entry)
        if not os.path.isdir(_path):
            continue
        match_ele = ele_pattern.match(entry)
        if match_ele:
            ele_list.append([_path, match_ele.group(1)])
        else:
            # Not an electrode directory itself: look for electrodes inside it
            ele_list.extend(SearchELE(_path, ele_pattern))

    return ele_list


## Archive Old

In [24]:
# Electrodes whose archived EIS recordings are unusable, mapped to the reason
# noted during manual review. Only the IDs are kept: list(dict) yields the keys
# in insertion order.
# NOTE(review): neither list is referenced by the loading cells shown here.
Blacklist = list({
    '01067093': "Doesn't look like EIS",
    '01067094': 'Connection error',
    '02017385': 'Connection error',
    '05127177': 'Open to short',
    '06047729': 'Open to short',
    '06047730': 'Open to short',
    '06047731': 'Open to short',
    '09207024': 'Connection error',
    '10017038': 'Connection error',
    '10037050': 'Connection error',
    '10047056': 'Connection error',
    '10057069': 'Connection error',
    '10057083': 'Always open',
    '10057084': 'Chaos',
    '10057087': 'Connection error',
    '22017367': 'Connection error',
    '22017371': 'Chaos',
})

# Electrodes with questionable, but not clearly broken, recordings.
GrayList = list({
    '10037051': 'Connection error',
    '10037052': 'Connection error',
    '10057071': 'Connection error',
    '10067077': 'Weird shape, like a connection error',
    '10150201': 'Weird shape',
    '10150202': 'Weird shape',
    '10150203': 'Weird shape',
    '20037515': 'Weird shape',
    '20037516': 'Weird shape',
    '20037517': 'Weird shape',
    '22037378': 'Connection error',
    '22037380': 'Connection error',
    '22047376': 'Connection error',
})

In [25]:
if READ_RAW_FLAG:
    # Old archive: electrode directories are recognised by the "<electrode>_归档" name
    rootPath = "D:/Baihm/EISNN/Archive/"
    ele_list = SearchELE(rootPath, ele_pattern=re.compile(r"(.+?)_归档"))
    n_ele = len(ele_list)
    logger.info(f"Search in {rootPath} and find {n_ele:03d} electrodes")


[32m2025-05-19 18:40:07.968[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1mSearch in D:/Baihm/EISNN/Archive/ and find 218 electrodes[0m


In [26]:
if READ_RAW_FLAG:
    # Load the vitro0 (old archive) features from "<ele_dir>/<suffix>/<name>_<suffix>.pt".
    #
    # For each valid channel, Z = chData[:,1,:] + 1j*chData[:,2,:] (presumably Re/Im of
    # the impedance — confirm) is mapped to log Z = log|Z| + 1j*arg(Z). Each measurement
    # becomes one row [Re(log Z) | Im(log Z)]; 5000-point spectra are subsampled to
    # `freq_list` first.
    #
    # Globals produced:
    #   vitro0_data_list      feature rows of every kept measurement
    #   vitro0_id_list        per row: [index into ele_list, channel key, eis_cluster, day offset]
    #   vitro0_start_list     first feature row of each channel
    #   vitro0_start_id_list  id row matching each vitro0_start_list row
    #   vitro0_ele_list       electrode names, indexed by id column 0
    #
    # NOTE(review): Blacklist / GrayList are not applied here (e.g. blacklisted 01067093
    # is loaded); confirm the filtering happens downstream.
    DATASET_SUFFIX = "Outlier_Ver03"

    vitro0_start_list = []
    vitro0_start_id_list = []
    vitro0_data_list = []
    vitro0_id_list = []

    n_avaliable = 0

    for i in range(n_ele):
        fd_pt = os.path.join(ele_list[i][0], DATASET_SUFFIX, f"{ele_list[i][1]}_{DATASET_SUFFIX}.pt")
        if not os.path.exists(fd_pt):
            logger.warning(f"{fd_pt} does not exist")
            continue
        # weights_only=False unpickles arbitrary objects: only load trusted files.
        data_pt = torch.load(fd_pt, weights_only=False)
        _meta_group = data_pt["meta_group"]
        _data_group = data_pt["data_group"]

        # Not used in this cell; kept as globals in case later cells read them.
        n_day       = _meta_group["n_day"]
        n_ch        = _meta_group["n_ch"]
        n_valid_ch  = len(_data_group["Channels"])

        TimeSpan    = _meta_group["TimeSpan"]
        # Day offset (.days) of every TimeSpan entry relative to TimeSpan[0]
        _x_date = np.array([(poi - TimeSpan[0]).days for poi in TimeSpan])

        logger.info(f"ELE [{i}/{n_ele}]: {ele_list[i][0]}")

        n_avaliable = n_avaliable + 1

        # Iteration by channel
        for j in _data_group['Channels']:
            _ch_data = _data_group[j]["chData"]
            _id_date = np.array(_x_date)

            if ONLY_SEQ_FLAG:
                eis_seq = _data_group[j]["eis_seq"]
                _ch_data = _ch_data[eis_seq,:,:]
                _id_date = _id_date[eis_seq]

            if _ch_data.shape[0] == 0:
                # Nothing selected for this channel; _ch_data[0,:] below would raise IndexError
                logger.warning(f"ELE [{i}/{n_ele}] channel {j}: no measurements, skipped")
                continue

            # Build the features in new arrays instead of writing them back into _ch_data,
            # which is the loaded chData itself when ONLY_SEQ_FLAG is False (or a view of it
            # when eis_seq is a slice).
            _ch_data_log = np.log(_ch_data[:,1,:] + 1j*_ch_data[:,2,:])
            _log_re = np.real(_ch_data_log)
            _log_im = np.imag(_ch_data_log)
            if _ch_data.shape[2] == 5000:
                _log_re = _log_re[:, freq_list]
                _log_im = _log_im[:, freq_list]
            _ch_data = np.hstack((_log_re, _log_im))
            vitro0_data_list.append(_ch_data)
            vitro0_start_list.append(_ch_data[0,:])

            _ch_id = j

            # id columns: [electrode index, channel key, eis_cluster, day offset]
            _id = [i, _ch_id] * np.shape(_ch_data)[0]
            _id = np.array(_id).reshape(-1,2)
            _eis_cluster = _data_group[j]['eis_cluster']
            _id = np.hstack((_id, _eis_cluster.reshape(-1,1)))
            _id = np.hstack((_id, _id_date.reshape(-1,1)))

            vitro0_id_list.append(_id)
            vitro0_start_id_list.append(_id[0,:])

    if not vitro0_data_list:
        raise RuntimeError(f"No {DATASET_SUFFIX} channel data loaded from {n_ele} electrode directories")

    vitro0_data_list = np.vstack(vitro0_data_list)
    vitro0_id_list = np.vstack(vitro0_id_list)
    vitro0_start_list = np.vstack(vitro0_start_list)
    vitro0_start_id_list = np.vstack(vitro0_start_id_list)

    vitro0_ele_list = [ele[1] for ele in ele_list]

    logger.info(f"Total {vitro0_data_list.shape[0]} data points from {n_avaliable} electrodes")

    # Drop references to the last loaded file and release the memory
    del data_pt, _meta_group, _data_group, _ch_data, _ch_data_log, _log_re, _log_im
    gc.collect()



[32m2025-05-19 18:40:08.084[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [0/218]: D:/Baihm/EISNN/Archive/01037160_归档[0m
[32m2025-05-19 18:40:08.169[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [1/218]: D:/Baihm/EISNN/Archive/01037161_归档[0m
[32m2025-05-19 18:40:08.258[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [2/218]: D:/Baihm/EISNN/Archive/01037162_归档[0m
[32m2025-05-19 18:40:08.339[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [3/218]: D:/Baihm/EISNN/Archive/01067093_归档[0m
[32m2025-05-19 18:40:08.404[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [4/218]: D:/Baihm/EISNN/Archive/01067094_归档[0m
[32m2025-05-19 18:40:08.468[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [5/218]: D:/Baihm/EISNN/Archive/01067095_归档[0m
[32m2025-05-19 18:40:08.500[0m | [1mI

## Archive New

In [27]:
if READ_RAW_FLAG:
    # New archive (nested year/month folders): same "<electrode>_归档" naming rule
    rootPath = "D:/Baihm/EISNN/Archive_New/"
    ele_list = SearchELE(rootPath, ele_pattern=re.compile(r"(.+?)_归档"))
    n_ele = len(ele_list)
    logger.info(f"Search in {rootPath} and find {n_ele:03d} electrodes")


[32m2025-05-19 18:40:19.736[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1mSearch in D:/Baihm/EISNN/Archive_New/ and find 187 electrodes[0m


In [28]:
if READ_RAW_FLAG:
    # Load the vitro1 (new archive) features from "<ele_dir>/<suffix>/<name>_<suffix>.pt".
    #
    # For each valid channel, Z = chData[:,1,:] + 1j*chData[:,2,:] (presumably Re/Im of
    # the impedance — confirm) is mapped to log Z = log|Z| + 1j*arg(Z). Each measurement
    # becomes one row [Re(log Z) | Im(log Z)]; 5000-point spectra are subsampled to
    # `freq_list` first.
    #
    # Globals produced:
    #   vitro1_data_list      feature rows of every kept measurement
    #   vitro1_id_list        per row: [index into ele_list, channel key, eis_cluster, day offset]
    #   vitro1_start_list     first feature row of each channel
    #   vitro1_start_id_list  id row matching each vitro1_start_list row
    #   vitro1_ele_list       electrode names, indexed by id column 0
    #
    # NOTE(review): Blacklist / GrayList are not applied here; confirm the filtering
    # happens downstream.
    DATASET_SUFFIX = "Outlier_Ver02"

    vitro1_start_list = []
    vitro1_start_id_list = []
    vitro1_data_list = []
    vitro1_id_list = []

    n_avaliable = 0

    for i in range(n_ele):
        fd_pt = os.path.join(ele_list[i][0], DATASET_SUFFIX, f"{ele_list[i][1]}_{DATASET_SUFFIX}.pt")
        if not os.path.exists(fd_pt):
            logger.warning(f"{fd_pt} does not exist")
            continue
        # weights_only=False unpickles arbitrary objects: only load trusted files.
        data_pt = torch.load(fd_pt, weights_only=False)
        _meta_group = data_pt["meta_group"]
        _data_group = data_pt["data_group"]

        # Not used in this cell; kept as globals in case later cells read them.
        n_day       = _meta_group["n_day"]
        n_ch        = _meta_group["n_ch"]
        n_valid_ch  = len(_data_group["Channels"])

        TimeSpan    = _meta_group["TimeSpan"]
        # Day offset (.days) of every TimeSpan entry relative to TimeSpan[0]
        _x_date = np.array([(poi - TimeSpan[0]).days for poi in TimeSpan])

        logger.info(f"ELE [{i}/{n_ele}]: {ele_list[i][0]}")

        n_avaliable = n_avaliable + 1

        # Iteration by channel
        for j in _data_group['Channels']:
            _ch_data = _data_group[j]["chData"]
            _id_date = np.array(_x_date)

            if ONLY_SEQ_FLAG:
                eis_seq = _data_group[j]["eis_seq"]
                _ch_data = _ch_data[eis_seq,:,:]
                _id_date = _id_date[eis_seq]

            if _ch_data.shape[0] == 0:
                # Nothing selected for this channel; _ch_data[0,:] below would raise IndexError
                logger.warning(f"ELE [{i}/{n_ele}] channel {j}: no measurements, skipped")
                continue

            # Build the features in new arrays instead of writing them back into _ch_data,
            # which is the loaded chData itself when ONLY_SEQ_FLAG is False (or a view of it
            # when eis_seq is a slice).
            _ch_data_log = np.log(_ch_data[:,1,:] + 1j*_ch_data[:,2,:])
            _log_re = np.real(_ch_data_log)
            _log_im = np.imag(_ch_data_log)
            if _ch_data.shape[2] == 5000:
                _log_re = _log_re[:, freq_list]
                _log_im = _log_im[:, freq_list]
            _ch_data = np.hstack((_log_re, _log_im))
            vitro1_data_list.append(_ch_data)
            vitro1_start_list.append(_ch_data[0,:])

            _ch_id = j

            # id columns: [electrode index, channel key, eis_cluster, day offset]
            _id = [i, _ch_id] * np.shape(_ch_data)[0]
            _id = np.array(_id).reshape(-1,2)
            _eis_cluster = _data_group[j]['eis_cluster']
            _id = np.hstack((_id, _eis_cluster.reshape(-1,1)))
            _id = np.hstack((_id, _id_date.reshape(-1,1)))

            vitro1_id_list.append(_id)
            vitro1_start_id_list.append(_id[0,:])

    if not vitro1_data_list:
        raise RuntimeError(f"No {DATASET_SUFFIX} channel data loaded from {n_ele} electrode directories")

    vitro1_data_list = np.vstack(vitro1_data_list)
    vitro1_id_list = np.vstack(vitro1_id_list)
    vitro1_start_list = np.vstack(vitro1_start_list)
    vitro1_start_id_list = np.vstack(vitro1_start_id_list)

    vitro1_ele_list = [ele[1] for ele in ele_list]

    logger.info(f"Total {vitro1_data_list.shape[0]} data points from {n_avaliable} electrodes")

    # Drop references to the last loaded file and release the memory
    del data_pt, _meta_group, _data_group, _ch_data, _ch_data_log, _log_re, _log_im
    gc.collect()



[32m2025-05-19 18:40:19.873[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [0/187]: D:/Baihm/EISNN/Archive_New/2025\1月\02027452_归档[0m
[32m2025-05-19 18:40:19.982[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [1/187]: D:/Baihm/EISNN/Archive_New/2025\1月\02027453_归档[0m
[32m2025-05-19 18:40:20.076[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [6/187]: D:/Baihm/EISNN/Archive_New/2025\1月\11037287_归档[0m
[32m2025-05-19 18:40:20.203[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [10/187]: D:/Baihm/EISNN/Archive_New/2025\1月\16057219_归档[0m
[32m2025-05-19 18:40:20.316[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [11/187]: D:/Baihm/EISNN/Archive_New/2025\1月\16057220_归档[0m
[32m2025-05-19 18:40:20.409[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [12/187]: D:/Baihm/EISNN/A

## In vivo

In [29]:
if READ_RAW_FLAG:
    # In vivo recordings: electrode directories are recognised by the "<electrode>_Ver02" name
    rootPath = "D:/Baihm/EISNN/Invivo/"
    ele_list = SearchELE(rootPath, ele_pattern=re.compile(r"(.+?)_Ver02"))
    n_ele = len(ele_list)
    logger.info(f"Search in {rootPath} and find {n_ele:03d} electrodes")


[32m2025-05-19 18:40:29.421[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1mSearch in D:/Baihm/EISNN/Invivo/ and find 006 electrodes[0m


In [30]:
if READ_RAW_FLAG:
    # Load the vivo0 (in vivo) features from "<ele_dir>/<suffix>/<name>_<suffix>.pt".
    #
    # For each valid channel, Z = chData[:,1,:] + 1j*chData[:,2,:] (presumably Re/Im of
    # the impedance — confirm) is mapped to log Z = log|Z| + 1j*arg(Z). Each measurement
    # becomes one row [Re(log Z) | Im(log Z)]; 5000-point spectra are subsampled to
    # `freq_list` first.
    #
    # Globals produced:
    #   vivo0_data_list      feature rows of every kept measurement
    #   vivo0_id_list        per row: [index into ele_list, channel key, eis_cluster, day offset]
    #   vivo0_start_list     first feature row of each channel
    #   vivo0_start_id_list  id row matching each vivo0_start_list row
    #   vivo0_ele_list       electrode names, indexed by id column 0
    DATASET_SUFFIX = "Outlier_Ver04"

    vivo0_start_list = []
    vivo0_start_id_list = []
    vivo0_data_list = []
    vivo0_id_list = []

    n_avaliable = 0

    for i in range(n_ele):
        fd_pt = os.path.join(ele_list[i][0], DATASET_SUFFIX, f"{ele_list[i][1]}_{DATASET_SUFFIX}.pt")
        if not os.path.exists(fd_pt):
            logger.warning(f"{fd_pt} does not exist")
            continue
        # weights_only=False unpickles arbitrary objects: only load trusted files.
        data_pt = torch.load(fd_pt, weights_only=False)
        _meta_group = data_pt["meta_group"]
        _data_group = data_pt["data_group"]

        # Not used in this cell; kept as globals in case later cells read them.
        n_day       = _meta_group["n_day"]
        n_ch        = _meta_group["n_ch"]
        n_valid_ch  = len(_data_group["Channels"])

        TimeSpan    = _meta_group["TimeSpan"]
        # Day offset (.days) of every TimeSpan entry relative to TimeSpan[0]
        _x_date = np.array([(poi - TimeSpan[0]).days for poi in TimeSpan])

        logger.info(f"ELE [{i}/{n_ele}]: {ele_list[i][0]}")

        n_avaliable = n_avaliable + 1

        # Iteration by channel
        for j in _data_group['Channels']:
            _ch_data = _data_group[j]["chData"]
            _id_date = np.array(_x_date)

            if ONLY_SEQ_FLAG:
                eis_seq = _data_group[j]["eis_seq"]
                _ch_data = _ch_data[eis_seq,:,:]
                _id_date = _id_date[eis_seq]

            if _ch_data.shape[0] == 0:
                # Nothing selected for this channel; _ch_data[0,:] below would raise IndexError
                logger.warning(f"ELE [{i}/{n_ele}] channel {j}: no measurements, skipped")
                continue

            # Build the features in new arrays instead of writing them back into _ch_data,
            # which is the loaded chData itself when ONLY_SEQ_FLAG is False (or a view of it
            # when eis_seq is a slice).
            _ch_data_log = np.log(_ch_data[:,1,:] + 1j*_ch_data[:,2,:])
            _log_re = np.real(_ch_data_log)
            _log_im = np.imag(_ch_data_log)
            if _ch_data.shape[2] == 5000:
                _log_re = _log_re[:, freq_list]
                _log_im = _log_im[:, freq_list]
            _ch_data = np.hstack((_log_re, _log_im))
            vivo0_data_list.append(_ch_data)
            vivo0_start_list.append(_ch_data[0,:])

            _ch_id = j

            # id columns: [electrode index, channel key, eis_cluster, day offset]
            _id = [i, _ch_id] * np.shape(_ch_data)[0]
            _id = np.array(_id).reshape(-1,2)
            _eis_cluster = _data_group[j]['eis_cluster']
            _id = np.hstack((_id, _eis_cluster.reshape(-1,1)))
            _id = np.hstack((_id, _id_date.reshape(-1,1)))

            vivo0_id_list.append(_id)
            vivo0_start_id_list.append(_id[0,:])

    if not vivo0_data_list:
        raise RuntimeError(f"No {DATASET_SUFFIX} channel data loaded from {n_ele} electrode directories")

    vivo0_data_list = np.vstack(vivo0_data_list)
    vivo0_id_list = np.vstack(vivo0_id_list)
    vivo0_start_list = np.vstack(vivo0_start_list)
    vivo0_start_id_list = np.vstack(vivo0_start_id_list)

    vivo0_ele_list = [ele[1] for ele in ele_list]

    logger.info(f"Total {vivo0_data_list.shape[0]} data points from {n_avaliable} electrodes")

    # Drop references to the last loaded file and release the memory
    del data_pt, _meta_group, _data_group, _ch_data, _ch_data_log, _log_re, _log_im
    gc.collect()



[32m2025-05-19 18:40:30.918[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [0/6]: D:/Baihm/EISNN/Invivo/S5877_Ver02[0m
[32m2025-05-19 18:40:35.163[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [1/6]: D:/Baihm/EISNN/Invivo/S6005_Ver02[0m
[32m2025-05-19 18:40:39.720[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [2/6]: D:/Baihm/EISNN/Invivo/S6006_Ver02[0m
[32m2025-05-19 18:40:43.407[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [3/6]: D:/Baihm/EISNN/Invivo/S6072_Ver02[0m
[32m2025-05-19 18:40:47.271[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [4/6]: D:/Baihm/EISNN/Invivo/S6106_Ver02[0m
[32m2025-05-19 18:40:48.198[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mELE [5/6]: D:/Baihm/EISNN/Invivo/S6175_Ver02[0m
[32m2025-05-19 18:40:48.248[0m | [1mINFO    [0m | [36

## Data Summary

In [31]:
if not READ_RAW_FLAG:
    # Restore every vitro0 / vitro1 / vivo0 array from a cached .npz instead of
    # re-reading the per-electrode .pt files (presumably written by the commented-out
    # np.savez_compressed cell at the end of this notebook).
    # NOTE: *_ele_list come back as NumPy string arrays here, not Python lists.
    Data_Path = "D:/Baihm/EISNN/Feature/AllData.npz"
    # Data_Path = "D:/Baihm/EISNN/Feature/SEQData.npz"
    if os.path.exists(Data_Path):
        # np.load on an .npz returns an NpzFile that keeps the archive open; the `with`
        # closes it. Each AllData[key] access reads that array fully into memory, so
        # the arrays remain valid after the block exits.
        with np.load(Data_Path) as AllData:
            vitro0_data_list = AllData["vitro0_data_list"]
            vitro0_id_list = AllData["vitro0_id_list"]
            vitro0_start_list = AllData["vitro0_start_list"]
            vitro0_start_id_list = AllData["vitro0_start_id_list"]
            vitro0_ele_list = AllData["vitro0_ele_list"]

            vitro1_data_list = AllData["vitro1_data_list"]
            vitro1_id_list = AllData["vitro1_id_list"]
            vitro1_start_list = AllData["vitro1_start_list"]
            vitro1_start_id_list = AllData["vitro1_start_id_list"]
            vitro1_ele_list = AllData["vitro1_ele_list"]

            vivo0_data_list = AllData["vivo0_data_list"]
            vivo0_id_list = AllData["vivo0_id_list"]
            vivo0_start_list = AllData["vivo0_start_list"]
            vivo0_start_id_list = AllData["vivo0_start_id_list"]
            vivo0_ele_list = AllData["vivo0_ele_list"]

        logger.info(f"Vitro0:\t{vitro0_data_list.shape}\t{vitro0_start_list.shape}")
        logger.info(f"vitro1:\t{vitro1_data_list.shape}\t{vitro1_start_list.shape}")
        logger.info(f"Vivo0:\t{vivo0_data_list.shape}\t{vivo0_start_list.shape}")

    else:
        logger.warning(f"{Data_Path} does not exist")

In [33]:
# Stack the three datasets row-wise, always in the order vitro0, vitro1, vivo0.
# Id column 0 is the electrode index *within its own dataset's* ele_list, so it is
# ambiguous after stacking. all_source_list / all_start_source_list record which
# dataset each row came from (0 = vitro0, 1 = vitro1, 2 = vivo0). For example, a row
# with source 1 belongs to electrode vitro1_ele_list[id[0]].
all_data_list = np.vstack((vitro0_data_list, vitro1_data_list, vivo0_data_list))
all_id_list = np.vstack((vitro0_id_list, vitro1_id_list, vivo0_id_list))
all_start_list = np.vstack((vitro0_start_list, vitro1_start_list, vivo0_start_list))
all_start_id_list = np.vstack((vitro0_start_id_list, vitro1_start_id_list, vivo0_start_id_list))

all_source_list = np.repeat(
    np.arange(3), [len(vitro0_data_list), len(vitro1_data_list), len(vivo0_data_list)]
)
all_start_source_list = np.repeat(
    np.arange(3), [len(vitro0_start_list), len(vitro1_start_list), len(vivo0_start_list)]
)

# Every feature row must have exactly one id row (and one source label)
assert all_data_list.shape[0] == all_id_list.shape[0] == all_source_list.shape[0]
assert all_start_list.shape[0] == all_start_id_list.shape[0] == all_start_source_list.shape[0]


In [34]:
# One log line per dataset: (label, all measurement rows, first-row-per-channel matrix)
_shape_summary = (
    ("Vitro0:\t", vitro0_data_list, vitro0_start_list),
    ("vitro1:\t", vitro1_data_list, vitro1_start_list),
    ("Vivo0:\t", vivo0_data_list, vivo0_start_list),
    ("All:\t\t", all_data_list, all_start_list),
)
for _label, _rows, _starts in _shape_summary:
    logger.info(f"{_label}{_rows.shape}\t{_starts.shape}")

[32m2025-05-19 18:41:09.303[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mVitro0:	(98690, 202)	(12170, 202)[0m
[32m2025-05-19 18:41:09.303[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mvitro1:	(81674, 202)	(9708, 202)[0m
[32m2025-05-19 18:41:09.304[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mVivo0:	(9406, 202)	(719, 202)[0m
[32m2025-05-19 18:41:09.304[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mAll:		(189770, 202)	(22597, 202)[0m


In [None]:
# # AllData_Path = "D:/Baihm/EISNN/Feature/AllData.npz"
# SEQData_Path = "D:/Baihm/EISNN/Feature/SeqData.npz"
# np.savez_compressed(SEQData_Path, 
#                     vitro0_data_list=vitro0_data_list, vitro0_id_list=vitro0_id_list,
#                     vitro0_start_list=vitro0_start_list, vitro0_start_id_list=vitro0_start_id_list,
#                     vitro0_ele_list = vitro0_ele_list,

#                     vitro1_data_list=vitro1_data_list, vitro1_id_list=vitro1_id_list,
#                     vitro1_start_list=vitro1_start_list, vitro1_start_id_list=vitro1_start_id_list,
#                     vitro1_ele_list = vitro1_ele_list,

#                     vivo0_data_list=vivo0_data_list, vivo0_id_list=vivo0_id_list,
#                     vivo0_start_list=vivo0_start_list, vivo0_start_id_list=vivo0_start_id_list,
#                     vivo0_ele_list = vivo0_ele_list,)