# Import

In [61]:
import os
import re
import gc
import sys

from loguru import logger
import numpy as np
import random

import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection


from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# %matplotlib qt
%matplotlib qt

# Detect device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input Layer

In [81]:
# Input Layer
def SearchELE(rootPath, ele_pattern = re.compile(r"(.+?)_归档")):
    '''==================================================
        Search all electrode directories in the rootPath
        Parameter: 
            rootPath: current search path
            ele_pattern: electrode dir name patten
        Returen:
            ele_list: list of electrode directories
        ==================================================
    '''
    ele_list = []
    for i in os.listdir(rootPath):
        match_ele = ele_pattern.match(i)
        if match_ele:
            ele_list.append([os.path.join(rootPath, i),match_ele.group(1)])
    return ele_list



In [82]:

rootPath = "D:/Baihm/EISNN/Archive/"
ele_list = SearchELE(rootPath)
n_ele = len(ele_list)
logger.info(f"Search in {rootPath} and find {n_ele:03d} electrodes")


[32m2025-04-25 11:21:24.464[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mSearch in D:/Baihm/EISNN/Archive/ and find 218 electrodes[0m


In [83]:
Whitelist = [
    '06017758',
    '06017760',
    '01037162',
    '10080601',
    '22017368',
    '01067095',
    '02027373',
    '05087164',
]


Blacklist = [
    '01067093',     # Not look like EIS
    '01067094',     # Connection Error
    '02017385',     # Connection Error
    '05127177',     # Open to Short
    '06047729',     # Open to Short
    '06047730',     # Open to Short
    '06047731',     # Open to Short
    '09207024',     # Connection Error
    '10017038',     # Connection Error
    '10037050',     # Connection Error
    '10047056',     # Connection Error
    '10057069',     # Connection Error
    '10057083',     # Always Open
    '10057084',     # Chaos
    '10057087',     # Connection Error
    '22017367',     # Connection Error
    '22017371',     # Chaos
]

GrayList = [
    '10037051',     # Connection Error
    '10037052',     # Connection Error
    '10057071',     # Connection Error
    '10067077',     # Wired Shape like connection error
    '10150201',     # Wired Shape
    '10150202',     # Wired Shape
    '10150203',     # Wired Shape
    '20037515',     # Wired Shape
    '20037516',     # Wired Shape
    '20037517',     # Wired Shape
    '22037378',     # Connection Error
    '22037380',     # Connection Error
    '22047376',     # Connection Error

]


In [84]:

MODEL_SUFFIX = "Matern12_Ver01"

all_data_list = []
all_id_list = []

white_id_list = []

_ch_pattern = re.compile(r"ch_(\d{3})")

for i in range(n_ele):
# for i in range(3):
    if ele_list[i][1] in Blacklist:
        continue

    if ele_list[i][1] in Whitelist:
        white_id_list.append(int(i))

    fd_pt = os.path.join(ele_list[i][0], MODEL_SUFFIX, f"{ele_list[i][1]}_{MODEL_SUFFIX}.pt")
    if not os.path.exists(fd_pt):
        # logger.warning(f"{fd_pt} does not exist")
        continue
    data_pt = torch.load(fd_pt, weights_only=False)
    _meta_group = data_pt["meta_group"]
    _data_group = data_pt["data_group"]

    n_day       = _meta_group["n_day"]
    n_ch        = _meta_group["n_ch"]
    n_valid_ch  = len(_data_group["Channels"])

    # ignore abnormal ele
    if n_ch != 128 or n_valid_ch != n_ch:
        if n_day < 5 or n_valid_ch <= 100:
            continue

    logger.info(f"ELE [{i}/{n_ele}]: {ele_list[i][0]}")


    ele_data_list = []
    ele_id_list = []
    # Iteration by channel
    for j in _data_group['Channels']:
        _ch_data = _data_group[j]["y_eval"]
        ele_data_list.append(_ch_data)

        _ch_id = _ch_pattern.match(j)
        _ch_id = int(_ch_id.group(1))

        # _id = [int(ele_list[i][1]), _ch_id] * np.shape(_ch_data)[0]
        _id = [int(i), _ch_id] * np.shape(_ch_data)[0]
        _id = np.array(_id).reshape(-1,2)

        _cluster_id = _data_group[j]['eis_cluster_eval']
        _id = np.hstack((_id, _cluster_id.reshape(-1,1)))

        ele_id_list.append(_id)
        
    
    all_data_list.append(ele_data_list)
    all_id_list.append(ele_id_list)
    
    # ele_data_list = np.vstack(ele_data_list)
    # all_data_list.append(ele_data_list)

# all_data_list = np.vstack(all_data_list)
# all_id_list = np.vstack(all_id_list)


del data_pt, _meta_group, _data_group, _ch_data
gc.collect()



[32m2025-04-25 11:21:24.637[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m35[0m - [1mELE [0/218]: D:/Baihm/EISNN/Archive/01037160_归档[0m
[32m2025-04-25 11:21:24.676[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m35[0m - [1mELE [1/218]: D:/Baihm/EISNN/Archive/01037161_归档[0m
[32m2025-04-25 11:21:24.877[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m35[0m - [1mELE [2/218]: D:/Baihm/EISNN/Archive/01037162_归档[0m
[32m2025-04-25 11:21:24.919[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m35[0m - [1mELE [5/218]: D:/Baihm/EISNN/Archive/01067095_归档[0m
[32m2025-04-25 11:21:24.962[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m35[0m - [1mELE [9/218]: D:/Baihm/EISNN/Archive/02027373_归档[0m
[32m2025-04-25 11:21:24.978[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m35[0m - [1mELE [10/218]: D:/Baihm/EISNN/Archive/02027390_归档[0m
[32m2025-04-25 11:21:25.011[0m | [1m

216

# Helper

In [16]:
# Helper
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def load_all2seq(data_list, id_list = None):
    seq_data_list    = []
    seq_id_list     = []
    for i in range(len(data_list)):
        for j in range(len(data_list[i])):
            seq_data_list.append(data_list[i][j])
            if id_list is not None:
                seq_id_list.append(id_list[i][j])
    return seq_data_list, seq_id_list

def load_all2ch(data_list, id_list = None):
    ch_data_list, ch_id_list = load_all2seq(data_list, id_list)
    ch_data_list = np.vstack(ch_data_list)
    if id_list is not None:
        ch_id_list = np.vstack(ch_id_list)
    return ch_data_list, ch_id_list

## Plot Latent Space
def VAE_latent(model, ds, batch_size=64):
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    _len_data = ds.__len__()
    _poi = 0

    latent_space_inst = []

    model.eval()
    with torch.no_grad():
        for x in loader:
            x = x.to(device)
            mu, lv = model.encoder(x)
            latent_space_inst.append(mu.cpu().numpy())

            _poi = _poi + x.size(0)
            if _poi % 1000 == 0:
                logger.info(f"[{_poi}]/[{_len_data}]")

    latent_space_inst = np.concatenate(latent_space_inst, axis=0)  # [B,z_dim]


    _pca_inst = PCA(n_components=latent_space_inst.shape[1])
    latent_dd = _pca_inst.fit_transform(latent_space_inst)
    
    
    explained = _pca_inst.explained_variance_ratio_
    eff_dim = (explained.cumsum() < 0.99).sum() + 1


    fig, axis = plt.subplots(2,1,
                gridspec_kw={'height_ratios': [4,1]},
                figsize=(9, 9))
    axis[0].scatter(latent_dd[:, 0], -latent_dd[:, 1], alpha=0.5, s = 0.001)

    axis[0].set_aspect('equal', adjustable='box')
    axis[0].set_box_aspect(1)
    axis[0].set_title("Latent Space")
    
    axis[1].plot(_pca_inst.explained_variance_ratio_,
                 label = f"Valid Dimension = {eff_dim}")
    axis[1].legend()
    fig.show()



    return latent_dd, eff_dim


# VAE

## Model Define

In [33]:
class EISDataset_Manifold(Dataset):
    def __init__(self, data_list, id_list = None):
        # data_list: n x m x k x l x 2 list
        # n: number of electrodes
        # m: number of channels
        # k: number of timestamps
        # l: number of freq as dimensions
        # 2: real and imaginary parts after logrithm
        _data = data_list
        _id = id_list
        _data = [torch.tensor(x, dtype=torch.float32) for x in _data]

        self.data = _data
        self.id = _id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Return [2,101] for Conv1D
        return self.data[idx].permute(1,0)  # [2,101] [in_ch, in_dim]

class Curve2VecEncoder_Ver01(nn.Module):
    def __init__(self, in_ch, in_dim, hid_ch, 
                 z_dim, kernel_size):
        super().__init__()


        _layers = []

        pre_ch = in_ch
        poi_ch = hid_ch
        _layers.append(nn.Conv1d(pre_ch, poi_ch, kernel_size=kernel_size))
        _layers.append(nn.ReLU())
        # _layers.append(nn.BatchNorm1d(poi_ch))
        
        pre_ch = poi_ch
        poi_ch = poi_ch * 2
        _layers.append(nn.Conv1d(pre_ch, poi_ch, kernel_size=kernel_size))
        _layers.append(nn.ReLU())
        # _layers.append(nn.BatchNorm1d(poi_ch))
        
        pre_ch = poi_ch
        poi_ch = poi_ch * 2
        _layers.append(nn.Conv1d(pre_ch, poi_ch, kernel_size=kernel_size))
        _layers.append(nn.ReLU())
        # _layers.append(nn.BatchNorm1d(poi_ch))


        self.conv = nn.Sequential(*_layers)
        self.pool = nn.AdaptiveAvgPool1d(1)


        self.fc_mu = nn.Linear(poi_ch, z_dim)
        self.fc_lv = nn.Linear(poi_ch, z_dim)


    def forward(self, x):
        h = self.conv(x)                # [B,ch,in_dim]
        h = self.pool(h).squeeze(-1)    # [B,ch]
        return self.fc_mu(h), self.fc_lv(h) 


class Curve2VecDecoder_Ver01(nn.Module):
    def __init__(self, out_ch, out_dim, hid_ch, 
                 z_dim, kernel_size):
        super().__init__()
        self.hid_ch = hid_ch
        self.out_dim = out_dim


        self.fc_expand = nn.Linear(z_dim, hid_ch * out_dim)


        _layers = []
        _layers.append(nn.ReLU())

        pre_ch = hid_ch
        poi_ch = hid_ch//2
        _layers.append(nn.ConvTranspose1d(pre_ch, poi_ch, kernel_size=kernel_size, padding=kernel_size//2))
        _layers.append(nn.ReLU())
        # _layers.append(nn.BatchNorm1d(poi_ch))
        
        # pre_ch = poi_ch
        # poi_ch = poi_ch//2
        # _layers.append(nn.ConvTranspose1d(pre_ch, poi_ch, kernel_size=kernel_size, padding=kernel_size//2))
        # _layers.append(nn.ReLU())
        # # _layers.append(nn.BatchNorm1d(poi_ch))

        pre_ch = poi_ch
        poi_ch = out_ch
        _layers.append(nn.Conv1d(pre_ch, poi_ch, kernel_size=kernel_size, padding=kernel_size//2))


        # pre_ch = hid_ch
        # poi_ch = out_ch
        # _layers.append(nn.Conv1d(pre_ch, poi_ch, kernel_size=kernel_size, padding=kernel_size//2))


        
        self.deconv = nn.Sequential(*_layers)


    def forward(self, z):
        h = self.fc_expand(z)           # [B,in_ch*in_dim]
        h = h.view(-1, self.hid_ch, self.out_dim)
        h = self.deconv(h)               # [B,in_ch,in_dim]
        return h                        # [B,in_ch,in_dim]

class Curve2VecVAE_Ver01(nn.Module):
    def __init__(self, in_ch=2, in_dim=101, 
                 enc_hid_ch = 16,
                 dec_hid_ch = 16,
                 z_dim = 16, kernel_size = 13):
        super().__init__()
        self.encoder = Curve2VecEncoder_Ver01(in_ch, in_dim, enc_hid_ch, z_dim, kernel_size)
        self.decoder = Curve2VecDecoder_Ver01(in_ch, in_dim, dec_hid_ch, z_dim, kernel_size)

    def reparam(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, lv = self.encoder(x)
        z = self.reparam(mu, lv)
        x_rec = self.decoder(z)
        return x_rec, mu, lv 



## Load Model

In [13]:
eis2vec_save_path = "D:/Baihm/EISNN/PredictionModel/model/Convx2_z_ConvTx1_Convx1.pt"
vae_model_dick = torch.load(eis2vec_save_path)
vae_model = Curve2VecVAE_Ver01().to(device)
vae_model.load_state_dict(vae_model_dick)
vae_model.eval()

  vae_model_dick = torch.load(eis2vec_save_path)


Curve2VecVAE_Ver01(
  (encoder): Curve2VecEncoder_Ver01(
    (conv): Sequential(
      (0): Conv1d(2, 16, kernel_size=(13,), stride=(1,))
      (1): ReLU()
      (2): Conv1d(16, 32, kernel_size=(13,), stride=(1,))
      (3): ReLU()
      (4): Conv1d(32, 64, kernel_size=(13,), stride=(1,))
      (5): ReLU()
    )
    (pool): AdaptiveAvgPool1d(output_size=1)
    (fc_mu): Linear(in_features=64, out_features=16, bias=True)
    (fc_lv): Linear(in_features=64, out_features=16, bias=True)
  )
  (decoder): Curve2VecDecoder_Ver01(
    (fc_expand): Linear(in_features=16, out_features=1616, bias=True)
    (deconv): Sequential(
      (0): ReLU()
      (1): ConvTranspose1d(16, 8, kernel_size=(13,), stride=(1,), padding=(6,))
      (2): ReLU()
      (3): Conv1d(8, 2, kernel_size=(13,), stride=(1,), padding=(6,))
    )
  )
)

# Manifold

## Dimensionallity Reduction

In [86]:
def VAE_latent(model, ds, batch_size=64):
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    _len_data = ds.__len__()
    _poi = 0

    latent_space_inst = []

    model.eval()
    with torch.no_grad():
        for x in loader:
            x = x.to(device)
            mu, lv = model.encoder(x)
            latent_space_inst.append(mu.cpu().numpy())

            _poi = _poi + x.size(0)
            if _poi % 1000 == 0:
                logger.info(f"[{_poi}]/[{_len_data}]")

    latent_space_inst = np.concatenate(latent_space_inst, axis=0)  # [B,z_dim]


    _pca_inst = PCA(n_components=latent_space_inst.shape[1])
    latent_dd = _pca_inst.fit_transform(latent_space_inst)

    if True:
        explained = _pca_inst.explained_variance_ratio_
        eff_dim = (explained.cumsum() < 0.99).sum() + 1


        fig, axis = plt.subplots(2,1,
                    gridspec_kw={'height_ratios': [4,1]},
                    figsize=(9, 9))
        axis[0].scatter(latent_dd[:, 0], -latent_dd[:, 1], alpha=0.5, s = 0.001)

        axis[0].set_aspect('equal', adjustable='box')
        axis[0].set_box_aspect(1)
        axis[0].set_title("Latent Space")
        
        # axis[1].plot(_pca_inst.explained_variance_ratio_,
        #             label = f"Valid Dimension = {eff_dim}")
        # axis[1].legend()
        fig.show()


    return latent_dd

In [87]:

ch_data_list, ch_id_list  = load_all2ch(all_data_list, all_id_list)
all_data_ds = EISDataset_Manifold(ch_data_list)

latent_dd = VAE_latent(vae_model, all_data_ds, batch_size=64)

[32m2025-04-25 11:21:56.865[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[8000]/[333535][0m
[32m2025-04-25 11:21:56.963[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[16000]/[333535][0m
[32m2025-04-25 11:21:57.050[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[24000]/[333535][0m
[32m2025-04-25 11:21:57.130[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[32000]/[333535][0m
[32m2025-04-25 11:21:57.198[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[40000]/[333535][0m
[32m2025-04-25 11:21:57.263[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[48000]/[333535][0m
[32m2025-04-25 11:21:57.326[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[56000]/[333535][0m
[32m2025-04-25 11:21:57.380[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_lat

## Plot Manifold

In [None]:

SAVE_FLAG = False
manifold_fig_save_path = "D:/Baihm/EISNN/PredictionModel/Manifold"


white_id_list


uq_id_list = np.unique(ch_id_list[:,0])
uq_id_max = np.max(uq_id_list)



cmap = plt.colormaps.get_cmap("rainbow_r")

# for i in range(len(uq_id_list)):
for i in range(0,2):
    # if uq_id_list[i] not in white_id_list:
    #     continue
    fig, axis = plt.subplots(1,1, figsize = (9,9))
    axis.scatter(latent_dd[:,0],latent_dd[:,1], color = 'lightgray', s=0.05)
    # plt.scatter(_pca_start[:,0],_pca_start[:,1],s=0.1)


    _ele_id = uq_id_list[i]

    ele_mask = ch_id_list[:,0] == _ele_id
    _ch_list = np.unique(ch_id_list[ele_mask,1])
    # for j in _ch_list:
    for j in _ch_list:
        _ch_mask = ch_id_list[:,:2] == [_ele_id,j]
        _ch_mask = _ch_mask[:,0] & _ch_mask[:,1]
        _ch_data = latent_dd[_ch_mask,:2]


        # _c = cmap(_ele_id / uq_id_max)
        # axis.plot(_ch_data[:,0],_ch_data[:,1], color = _c, alpha = 0.5)

        _cluster_list = np.unique(ch_id_list[_ch_mask,2])

        _seq_all_len = ch_id_list[_ch_mask,2].shape[0]
        _seg_poi = 0

        for k in _cluster_list:
            _cluster_mask = ch_id_list[:,:] == [_ele_id,j,k]
            _cluster_mask = _cluster_mask[:,0] & _cluster_mask[:,1] & _cluster_mask[:,2]
            _cluster_data = latent_dd[_cluster_mask,:2]

            _seg_data = _cluster_data.reshape(-1,1,2)
            _seg_data = np.concatenate([_seg_data[:-1], _seg_data[1:]], axis=1)

            _seg_len = _cluster_data.shape[0]
            
            color_range = np.linspace(_seg_poi/_seq_all_len, (_seg_poi+_seg_len)/_seq_all_len, _seg_len - 1)
            colors = cmap(color_range)

            _seg_poi = _seg_poi+_seg_len
            lc = LineCollection(_seg_data, colors=colors, linewidth=2)
            axis.add_collection(lc)

    axis.set_title(f"{ele_list[int(_ele_id)][1]}_Manifold")
    if SAVE_FLAG:
        _fig_name = f"{ele_list[int(_ele_id)][1]}_Manifold.png"
        _fig_save_path = os.path.join(manifold_fig_save_path, _fig_name)

        fig.savefig(_fig_save_path)
        plt.close(fig) 

        logger.info(f"{i}/{len(uq_id_list)} Saved")
    else:
        fig.show()



## Plot all manifold

In [98]:




fig, axis = plt.subplots(1,1, figsize = (16,9))
axis.scatter(latent_dd[:,0],latent_dd[:,1], color = 'lightgray', s=0.05)
# plt.scatter(_pca_start[:,0],_pca_start[:,1],s=0.1)



uq_id_list = np.unique(ch_id_list[:,0])
uq_id_max = np.max(uq_id_list)


cmap = plt.colormaps.get_cmap("rainbow_r")

for i in range(len(uq_id_list)):
# for i in range(0,6):
    _ele_id = uq_id_list[i]

    ele_mask = ch_id_list[:,0] == _ele_id
    _ch_list = np.unique(ch_id_list[ele_mask,1])


    for j in _ch_list:
        _ch_mask = ch_id_list[:,:2] == [_ele_id,j]
        _ch_mask = _ch_mask[:,0] & _ch_mask[:,1]
        _ch_data = latent_dd[_ch_mask,:2]

        # _c = cmap(_ele_id / uq_id_max)
        # axis.plot(_ch_data[:,0],_ch_data[:,1], color = _c, alpha = 0.5)

        _cluster_list = np.unique(ch_id_list[_ch_mask,2])

        _seq_all_len = ch_id_list[_ch_mask,2].shape[0]
        _seg_poi = 0

        for k in _cluster_list:
            _cluster_mask = ch_id_list[:,:] == [_ele_id,j,k]
            _cluster_mask = _cluster_mask[:,0] & _cluster_mask[:,1] & _cluster_mask[:,2]
            _cluster_data = latent_dd[_cluster_mask,:2]

            _seg_data = _cluster_data.reshape(-1,1,2)
            _seg_data = np.concatenate([_seg_data[:-1], _seg_data[1:]], axis=1)

            _seg_len = _cluster_data.shape[0]
            
            color_range = np.linspace(_seg_poi/_seq_all_len, (_seg_poi+_seg_len)/_seq_all_len, _seg_len - 1)
            colors = cmap(color_range)

            _seg_poi = _seg_poi+_seg_len
            lc = LineCollection(_seg_data, colors=colors, linewidth=1, alpha = 0.1)
            axis.add_collection(lc)

fig.show()

