# Import

In [281]:
import os
import re
import gc
import sys

from loguru import logger
import numpy as np
import random

import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection


from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE, Isomap, LocallyLinearEmbedding, MDS

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# %matplotlib qt
%matplotlib qt

# Detect device: use CUDA when torch reports it as available, otherwise fall back to CPU.
# The model-loading cell moves the VAE onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input Layer

## Definition

In [282]:
# When False, the "Load Data" cell reads the cached SEQData.npz archive.
# NOTE(review): no branch for READ_RAW_FLAG = True is visible in this notebook,
# so setting it to True leaves the vitro/vivo arrays undefined.
READ_RAW_FLAG = False

## Load Data

In [48]:

if not READ_RAW_FLAG:
    # Should always be Seq because not all data is in manifold
    Data_Path = "D:/Baihm/EISNN/Feature/SEQData.npz"
    if not os.path.exists(Data_Path):
        # Fail fast. Falling through to the ID calibration below would raise a
        # NameError on a fresh kernel, or (worse) add the offsets a second time
        # to arrays left over from a previous run of this cell.
        logger.warning(f"{Data_Path} does not exist")
        raise FileNotFoundError(f"{Data_Path} does not exist")

    AllData = np.load(Data_Path)

    # Each dataset provides: data rows, per-row ID rows, "start" rows,
    # per-start-row ID rows, and an electrode list.
    vitro0_data_list = AllData["vitro0_data_list"]
    vitro0_id_list = AllData["vitro0_id_list"]
    vitro0_start_list = AllData["vitro0_start_list"]
    vitro0_start_id_list = AllData["vitro0_start_id_list"]
    vitro0_ele_list = AllData["vitro0_ele_list"]

    vitro1_data_list = AllData["vitro1_data_list"]
    vitro1_id_list = AllData["vitro1_id_list"]
    vitro1_start_list = AllData["vitro1_start_list"]
    vitro1_start_id_list = AllData["vitro1_start_id_list"]
    vitro1_ele_list = AllData["vitro1_ele_list"]

    vivo0_data_list = AllData["vivo0_data_list"]
    vivo0_id_list = AllData["vivo0_id_list"]
    vivo0_start_list = AllData["vivo0_start_list"]
    vivo0_start_id_list = AllData["vivo0_start_id_list"]
    vivo0_ele_list = AllData["vivo0_ele_list"]

    logger.info(f"Vitro0:\t{vitro0_data_list.shape}\t{vitro0_start_list.shape}")
    logger.info(f"vitro1:\t{vitro1_data_list.shape}\t{vitro1_start_list.shape}")
    logger.info(f"Vivo0:\t{vivo0_data_list.shape}\t{vivo0_start_list.shape}")

    # Calibrate ID List for concated list:
    # column 0 of every ID array is shifted by the number of electrodes in the
    # datasets that precede it, so that it indexes the concatenated electrode
    # list. vitro0 comes first and therefore keeps its IDs unchanged.
    # The arrays are freshly loaded above, so re-running this cell is idempotent.
    _vitro1_offset = vitro0_ele_list.shape[0]
    _vivo0_offset = vitro0_ele_list.shape[0] + vitro1_ele_list.shape[0]

    vitro1_id_list[:,0] = vitro1_id_list[:,0] + _vitro1_offset
    vivo0_id_list[:,0] = vivo0_id_list[:,0] + _vivo0_offset

    vitro1_start_id_list[:,0] = vitro1_start_id_list[:,0] + _vitro1_offset
    vivo0_start_id_list[:,0] = vivo0_start_id_list[:,0] + _vivo0_offset

[32m2025-05-19 18:54:23.728[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m25[0m - [1mVitro0:	(98690, 202)	(12170, 202)[0m
[32m2025-05-19 18:54:23.728[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mvitro1:	(81674, 202)	(9708, 202)[0m
[32m2025-05-19 18:54:23.729[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m27[0m - [1mVivo0:	(9406, 202)	(719, 202)[0m


## All Data

In [49]:
# Electrode lists of the three datasets, in the same order used for the ID offsets.
all_ele_list = np.concatenate((vitro0_ele_list, vitro1_ele_list, vivo0_ele_list), axis=0)

# Row-wise stacking of the per-dataset 2-D arrays (vitro0, vitro1, vivo0).
all_data_list = np.concatenate((vitro0_data_list, vitro1_data_list, vivo0_data_list), axis=0)
all_id_list = np.concatenate((vitro0_id_list, vitro1_id_list, vivo0_id_list), axis=0)
all_start_list = np.concatenate((vitro0_start_list, vitro1_start_list, vivo0_start_list), axis=0)
all_start_id_list = np.concatenate(
    (vitro0_start_id_list, vitro1_start_id_list, vivo0_start_id_list), axis=0
)
logger.info(f"All:\t{all_data_list.shape}\t{all_start_list.shape}")



[32m2025-05-19 18:54:23.849[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mAll:	(189770, 202)	(22597, 202)[0m


In [50]:
# Sanity check 1: number of distinct values in ID column 0 for each dataset and for the
# concatenated ID array (the per-dataset counts should add up to the combined count).
print(np.unique(vitro0_id_list[:,0]).shape, np.unique(vitro1_id_list[:,0]).shape, np.unique(vivo0_id_list[:,0]).shape, np.unique(all_id_list[:,0]).shape)
# Sanity check 2: length of each electrode list and of the concatenated list.
print(vitro0_ele_list.shape,vitro1_ele_list.shape,vivo0_ele_list.shape,all_ele_list.shape)

(153,) (128,) (6,) (287,)
(218,) (187,) (6,) (411,)


# VAE

## Data Loader

In [51]:
# Helper
def count_parameters(model):
    """Return the total element count of all parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def load_all2ch(data_list, id_list = None):
    '''==================================================
        Split every row of the data into 2 channels
        Parameter:
            data_list: data list    n x 202
                       (columns 0..100 form channel 0, columns 101.. form channel 1)
            id_list: id list, passed through unchanged (may be None)
        Return:
            ch_data_list: channel data list     n x 101 x 2
            ch_id_list: the very same object that was given as id_list
        ==================================================
    '''
    first_half = data_list[:, :101]
    second_half = data_list[:, 101:]

    # Put the channel index on the last axis: [n, 101, 2]
    ch_data_list = np.stack([first_half, second_half], axis=-1)

    return ch_data_list, id_list

## Model Define

In [52]:
class EISDataset_Manifold(Dataset):
    """Dataset wrapper that serves each sample transposed for Conv1d.

    Every element of ``data_list`` is converted to a float32 tensor once, at
    construction time. In this notebook the elements are the [101, 2] rows
    produced by ``load_all2ch``, so ``__getitem__`` yields [2, 101].
    """

    def __init__(self, data_list, id_list = None):
        # One float32 tensor per sample; torch.tensor copies the input data.
        self.data = list(map(lambda sample: torch.tensor(sample, dtype=torch.float32), data_list))
        # IDs are stored as given; they are not returned by __getitem__.
        self.id = id_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Swap the two axes of the stored 2-D sample: [in_dim, in_ch] -> [in_ch, in_dim]
        sample = self.data[idx]
        return sample.transpose(0, 1)

class Curve2VecEncoder_Ver01(nn.Module):
    """Convolutional encoder mapping a [B, in_ch, L] curve to (mu, logvar).

    Three unpadded Conv1d+ReLU stages widen the channels
    in_ch -> hid_ch -> 2*hid_ch -> 4*hid_ch, an adaptive average pool collapses
    the length axis, and two linear heads produce the latent mean and
    log-variance. ``in_dim`` is accepted for interface symmetry with the
    decoder but is not used here.
    """

    def __init__(self, in_ch, in_dim, hid_ch,
                 z_dim, kernel_size):
        super().__init__()

        # Channel width at the input and after each of the three conv stages.
        widths = [in_ch, hid_ch, hid_ch * 2, hid_ch * 4]

        stages = []
        for src_ch, dst_ch in zip(widths[:-1], widths[1:]):
            # No padding: each stage shortens the sequence by kernel_size - 1.
            stages.append(nn.Conv1d(src_ch, dst_ch, kernel_size=kernel_size))
            stages.append(nn.ReLU())

        self.conv = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool1d(1)

        feat_ch = widths[-1]
        self.fc_mu = nn.Linear(feat_ch, z_dim)
        self.fc_lv = nn.Linear(feat_ch, z_dim)

    def forward(self, x):
        feats = self.conv(x)                    # [B, 4*hid_ch, L']
        feats = self.pool(feats).squeeze(-1)    # [B, 4*hid_ch]
        mu = self.fc_mu(feats)
        logvar = self.fc_lv(feats)
        return mu, logvar


class Curve2VecDecoder_Ver01(nn.Module):
    """Decoder mapping a latent vector [B, z_dim] back to a [B, out_ch, out_dim] curve.

    A linear layer expands z to hid_ch * out_dim values that are reshaped to
    [B, hid_ch, out_dim]. A ReLU, one ConvTranspose1d (hid_ch -> hid_ch // 2),
    another ReLU and a final Conv1d (-> out_ch) follow; both convolutions use
    padding = kernel_size // 2.
    """

    def __init__(self, out_ch, out_dim, hid_ch,
                 z_dim, kernel_size):
        super().__init__()
        self.hid_ch = hid_ch
        self.out_dim = out_dim

        self.fc_expand = nn.Linear(z_dim, hid_ch * out_dim)

        half_kernel = kernel_size // 2
        mid_ch = hid_ch // 2

        # Index positions (0..3) inside the Sequential match the saved state_dict keys.
        self.deconv = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose1d(hid_ch, mid_ch, kernel_size=kernel_size, padding=half_kernel),
            nn.ReLU(),
            nn.Conv1d(mid_ch, out_ch, kernel_size=kernel_size, padding=half_kernel),
        )

    def forward(self, z):
        expanded = self.fc_expand(z)                                # [B, hid_ch*out_dim]
        expanded = expanded.view(-1, self.hid_ch, self.out_dim)     # [B, hid_ch, out_dim]
        return self.deconv(expanded)                                # [B, out_ch, out_dim]

class Curve2VecVAE_Ver01(nn.Module):
    """Variational auto-encoder built from Curve2VecEncoder_Ver01 and Curve2VecDecoder_Ver01.

    ``forward`` returns the reconstruction together with the latent mean and
    log-variance produced by the encoder.
    """

    def __init__(self, in_ch=2, in_dim=101,
                 enc_hid_ch = 16,
                 dec_hid_ch = 16,
                 z_dim = 16, kernel_size = 13):
        super().__init__()
        # Encoder is created first so parameter initialisation order is stable.
        self.encoder = Curve2VecEncoder_Ver01(in_ch, in_dim, enc_hid_ch, z_dim, kernel_size)
        self.decoder = Curve2VecDecoder_Ver01(in_ch, in_dim, dec_hid_ch, z_dim, kernel_size)

    def reparam(self, mu, logvar):
        """Reparameterisation trick: sample mu + eps * sigma with eps ~ N(0, I)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        mu, lv = self.encoder(x)
        latent = self.reparam(mu, lv)
        return self.decoder(latent), mu, lv



## Load Model

In [54]:
eis2vec_save_path = "D:/Baihm/EISNN/Feature/SeqData_Convx2_z_ConvTx1_Convx1.pt"
# The checkpoint is a plain state_dict (it is passed to load_state_dict below), so:
#   * weights_only=True restricts unpickling to tensors/containers and silences the
#     torch.load warning this cell used to emit;
#   * map_location=device lets a checkpoint saved from CUDA load on a CPU-only host.
vae_model_dick = torch.load(eis2vec_save_path, map_location=device, weights_only=True)
vae_model = Curve2VecVAE_Ver01().to(device)
vae_model.load_state_dict(vae_model_dick)
# eval() returns the module, so the cell also displays the architecture.
vae_model.eval()

  vae_model_dick = torch.load(eis2vec_save_path)


Curve2VecVAE_Ver01(
  (encoder): Curve2VecEncoder_Ver01(
    (conv): Sequential(
      (0): Conv1d(2, 16, kernel_size=(13,), stride=(1,))
      (1): ReLU()
      (2): Conv1d(16, 32, kernel_size=(13,), stride=(1,))
      (3): ReLU()
      (4): Conv1d(32, 64, kernel_size=(13,), stride=(1,))
      (5): ReLU()
    )
    (pool): AdaptiveAvgPool1d(output_size=1)
    (fc_mu): Linear(in_features=64, out_features=16, bias=True)
    (fc_lv): Linear(in_features=64, out_features=16, bias=True)
  )
  (decoder): Curve2VecDecoder_Ver01(
    (fc_expand): Linear(in_features=16, out_features=1616, bias=True)
    (deconv): Sequential(
      (0): ReLU()
      (1): ConvTranspose1d(16, 8, kernel_size=(13,), stride=(1,), padding=(6,))
      (2): ReLU()
      (3): Conv1d(8, 2, kernel_size=(13,), stride=(1,), padding=(6,))
    )
  )
)

# Manifold

## Dimensionality Reduction

### Definition

In [56]:
def VAE_latent(model, ds, batch_size=64, log_every=8000):
    """Encode every sample of ``ds`` and return the latent means.

    Parameter:
        model: module exposing ``encoder(x) -> (mu, logvar)``
        ds: dataset whose items can be batched by a DataLoader
        batch_size: samples per forward pass
        log_every: emit a progress line each time the processed-sample count
            passes another multiple of this value; 0/None disables logging.
            The default reproduces the cadence the previous exact-multiple
            check happened to give with batch_size=64, but now works for any
            batch size.
    Return:
        numpy array [len(ds), z_dim] with the encoder means, in dataset order
    Raises:
        ValueError: if ``ds`` yields no batches
    """
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
    n_total = len(ds)

    # Run on whatever device the model lives on rather than a notebook global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    latent_chunks = []
    n_done = 0
    next_report = log_every

    model.eval()
    with torch.no_grad():
        for x in loader:
            x = x.to(run_device)
            mu, _ = model.encoder(x)    # log-variance is not needed here
            latent_chunks.append(mu.cpu().numpy())

            n_done += x.size(0)
            if log_every and n_done >= next_report:
                logger.info(f"[{n_done}]/[{n_total}]")
                # Next threshold is the first multiple of log_every above n_done.
                next_report = (n_done // log_every + 1) * log_every

    if not latent_chunks:
        raise ValueError("VAE_latent received an empty dataset; nothing to encode")

    return np.concatenate(latent_chunks, axis=0)  # [N, z_dim]

def VAE_PCA_Plot(_pca_inst, latent_dd, alpha = 0.5, s = 0.001):
    """Scatter the first two columns of ``latent_dd`` and plot the PCA variance ratios.

    Parameter:
        _pca_inst: fitted PCA object (its ``explained_variance_ratio_`` is read)
        latent_dd: 2-D array; column 0 is drawn sign-flipped on x, column 1 on y
        alpha, s: transparency and marker size of the scatter
    Return:
        latent_dd, unchanged
    """
    var_ratio = _pca_inst.explained_variance_ratio_
    # Components needed for the cumulative explained variance to reach 90 %.
    eff_dim = np.count_nonzero(np.cumsum(var_ratio) < 0.90) + 1

    fig, (ax_latent, ax_var) = plt.subplots(
        2, 1,
        figsize=(9, 9),
        gridspec_kw={'height_ratios': [4, 1]},
    )

    # Upper panel: latent scatter on a fixed, equal-aspect window.
    ax_latent.scatter(-latent_dd[:, 0], latent_dd[:, 1], alpha=alpha, s = s)
    ax_latent.set_xlim(-2, 3)
    ax_latent.set_ylim(-3, 3)
    ax_latent.set_aspect('equal', adjustable='box')
    ax_latent.set_title("Latent Space")

    # Lower panel: per-component explained variance ratio.
    ax_var.plot(var_ratio, label = f"Valid Dimension = {eff_dim}")
    ax_var.legend()
    fig.show()

    return latent_dd

### Run DR

In [146]:
# Boundaries in electrode-ID space between the three concatenated datasets:
#   ID <  seg0          -> vitro0
#   seg0 <= ID < seg1   -> vitro1
#   ID >= seg1          -> vivo0
seg0 = vitro0_ele_list.shape[0]
seg1 = seg0 + vitro1_ele_list.shape[0]

# Alternative selections (enable exactly one load_all2ch line):
# Single Dataset
# ch_data_list, ch_id_list  = load_all2ch(all_data_list[all_id_list[:,0]<seg0], all_id_list[all_id_list[:,0]<seg0])
# ch_data_list, ch_id_list  = load_all2ch(all_data_list[(all_id_list[:,0]>=seg0) & (all_id_list[:,0]<seg1)], all_id_list[(all_id_list[:,0]>=seg0) & (all_id_list[:,0]<seg1)])
# ch_data_list, ch_id_list  = load_all2ch(all_data_list[all_id_list[:,0]>=seg1], all_id_list[all_id_list[:,0]>=seg1])

# Couple Dataset
# ch_data_list, ch_id_list  = load_all2ch(all_data_list[all_id_list[:,0]<seg1], all_id_list[all_id_list[:,0]<seg1])

# Active selection: every row of every dataset.
ch_data_list, ch_id_list = load_all2ch(all_data_list[:], all_id_list[:])

# Encode all rows with the pretrained VAE encoder.
all_data_ds = EISDataset_Manifold(ch_data_list)
latent_space_inst = VAE_latent(vae_model, all_data_ds, batch_size=64)




[32m2025-05-19 21:12:10.225[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[8000]/[189770][0m
[32m2025-05-19 21:12:10.354[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[16000]/[189770][0m
[32m2025-05-19 21:12:10.421[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[24000]/[189770][0m
[32m2025-05-19 21:12:10.488[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[32000]/[189770][0m
[32m2025-05-19 21:12:10.552[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[40000]/[189770][0m
[32m2025-05-19 21:12:10.614[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[48000]/[189770][0m
[32m2025-05-19 21:12:10.682[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[56000]/[189770][0m
[32m2025-05-19 21:12:10.748[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_lat

In [58]:
# Fit a PCA with as many components as latent dimensions on the encoder means and keep
# the projected coordinates. Later cells reuse _pca_inst.transform() so that every plot
# shares this one basis.
_pca_inst = PCA(n_components=latent_space_inst.shape[1])
latent_dd = _pca_inst.fit_transform(latent_space_inst)

## Plot Manifold

In [None]:
# Choose which rows to project with the already-fitted PCA (enable exactly one line).
# The commented variants select a single dataset / dataset pair via the ID boundaries.
# latent_dd = _pca_inst.transform(latent_space_inst[all_id_list[:,0]<seg0])
# latent_dd = _pca_inst.transform(latent_space_inst[(all_id_list[:,0]>=seg0) & (all_id_list[:,0]<seg1)])
# latent_dd = _pca_inst.transform(latent_space_inst[all_id_list[:,0]>=seg1])

# latent_dd = _pca_inst.transform(latent_space_inst[all_id_list[:,0]<seg1])
latent_dd = _pca_inst.transform(latent_space_inst[:])

# Scatter of the first two components plus the explained-variance curve.
latent_dd = VAE_PCA_Plot(_pca_inst, latent_dd, alpha = 0.5, s = 0.001)


In [79]:

# One figure per electrode: all latent points in gray, with that electrode's
# trajectories drawn on top as colored line segments.
# NOTE(review): this cell masks latent_dd with masks built from ch_id_list, so latent_dd
# must be the projection of *all* rows of ch_data_list (same row order).
# NOTE(review): unlike VAE_PCA_Plot and the later overlay cells, x is NOT sign-flipped
# here — confirm whether that is intended.
SAVE_FLAG = False
if SAVE_FLAG:
    # Output folder for the PNGs; only defined (and only used) when SAVE_FLAG is True.
    manifold_fig_save_path = "D:/Baihm/EISNN/Feature/Manifold"
    if not os.path.exists(manifold_fig_save_path):
        os.makedirs(manifold_fig_save_path)

# Distinct values of ID column 0 (used below as index into all_ele_list).
uq_id_list = np.unique(ch_id_list[:,0])
uq_id_max = np.max(uq_id_list)

cmap = plt.colormaps.get_cmap("rainbow_r")

# Full run over every electrode is commented out; only the first two are drawn.
# for i in range(len(uq_id_list)):
for i in range(0,2):
    # if uq_id_list[i] not in white_id_list:
    #     continue
    fig, axis = plt.subplots(1,1, figsize = (9,9))
    # Background: every latent point.
    axis.scatter(latent_dd[:,0],latent_dd[:,1], color = 'lightgray', s=0.05)
    # plt.scatter(_pca_start[:,0],_pca_start[:,1],s=0.1)


    _ele_id = uq_id_list[i]

    # Rows of this electrode, and the distinct values of ID column 1 among them
    # (presumably the channel index — TODO confirm).
    ele_mask = ch_id_list[:,0] == _ele_id
    _ch_list = np.unique(ch_id_list[ele_mask,1])
    # for j in _ch_list:
    for j in _ch_list:
        # Rows whose first two ID columns equal (_ele_id, j).
        _ch_mask = ch_id_list[:,:2] == [_ele_id,j]
        _ch_mask = _ch_mask[:,0] & _ch_mask[:,1]


        # _ch_data = latent_dd[_ch_mask,:2]
        # _c = cmap(_ele_id / uq_id_max)
        # axis.plot(_ch_data[:,0],_ch_data[:,1], color = _c, alpha = 0.5)

        # Distinct values of ID column 2 within this channel
        # (presumably a cluster/sequence index — TODO confirm).
        _cluster_list = np.unique(ch_id_list[_ch_mask,2])

        # Total number of rows in this channel; used to normalise the color position.
        _seq_all_len = ch_id_list[_ch_mask,2].shape[0]
        _seg_poi = 0

        for k in _cluster_list:
            # Rows whose first three ID columns equal (_ele_id, j, k).
            _cluster_mask = ch_id_list[:,:3] == [_ele_id,j,k]
            _cluster_mask = _cluster_mask[:,0] & _cluster_mask[:,1] & _cluster_mask[:,2]
            _cluster_data = latent_dd[_cluster_mask,:2]

            # Pair each point with its successor: [n-1, 2 (start/end), 2 (x/y)].
            _seg_data = _cluster_data.reshape(-1,1,2)
            _seg_data = np.concatenate([_seg_data[:-1], _seg_data[1:]], axis=1)

            # Drop segments whose x displacement is 1 or more.
            _dx = np.abs(_seg_data[:,1,0] - _seg_data[:,0,0])
            _seg_data = _seg_data[_dx < 1,:,:]

            _seg_len = _cluster_data.shape[0]
            
            # Color runs from this cluster's start position to its end position,
            # both expressed as fractions of the whole channel sequence.
            color_range = np.linspace(_seg_poi/_seq_all_len, (_seg_poi+_seg_len)/_seq_all_len, _seg_data.shape[0])
            colors = cmap(color_range)

            _seg_poi = _seg_poi+_seg_len
            lc = LineCollection(_seg_data, colors=colors, linewidth=2)
            axis.add_collection(lc)

    axis.set_title(f"{all_ele_list[int(_ele_id)]}_Manifold")
    if SAVE_FLAG:
        _fig_name = f"{all_ele_list[int(_ele_id)]}_Manifold.png"
        _fig_save_path = os.path.join(manifold_fig_save_path, _fig_name)

        fig.savefig(_fig_save_path)
        # Close saved figures so a long loop does not accumulate open windows.
        plt.close(fig) 

        logger.info(f"{i}/{len(uq_id_list)} Saved")
    else:
        fig.show()



[32m2025-05-19 15:11:21.615[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m0/287 Saved[0m
[32m2025-05-19 15:11:22.262[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m1/287 Saved[0m
[32m2025-05-19 15:11:22.950[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m2/287 Saved[0m
[32m2025-05-19 15:11:23.305[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m3/287 Saved[0m
[32m2025-05-19 15:11:23.630[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m4/287 Saved[0m
[32m2025-05-19 15:11:24.073[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m5/287 Saved[0m
[32m2025-05-19 15:11:24.236[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m6/287 Saved[0m
[32m2025-05-19 15:11:24.551[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m70[0m - [1m7/287 Saved[0m
[32m202

## Plot all manifold

### 1 vs 2

In [261]:
# Re-project the rows to draw with the already-fitted PCA (enable exactly one line).
# latent_dd = _pca_inst.transform(latent_space_inst[all_id_list[:,0]<seg0])
# latent_dd = _pca_inst.transform(latent_space_inst[(all_id_list[:,0]>=seg0) & (all_id_list[:,0]<seg1)])
# latent_dd = _pca_inst.transform(latent_space_inst[all_id_list[:,0]>=seg1])

# latent_dd = _pca_inst.transform(latent_space_inst[all_id_list[:,0]<seg1])
latent_dd = _pca_inst.transform(latent_space_inst[:])


# Overview plot: components 1 and 2 plus the explained-variance curve.
latent_dd = VAE_PCA_Plot(_pca_inst, latent_dd, alpha = 0.5, s = 0.001)


In [None]:




# Single figure: every latent point in gray (x sign-flipped, matching VAE_PCA_Plot)
# with the trajectories of ALL electrodes overlaid as faint colored segments.
fig, axis = plt.subplots(1,1, figsize = (9,9))
axis.scatter(-latent_dd[:,0],latent_dd[:,1], color = 'lightgray', s=0.05)
# plt.scatter(_pca_start[:,0],_pca_start[:,1],s=0.1)

# latent_dd is indexed below with masks built from ch_id_list, so the two must be
# row-aligned (i.e. latent_dd must be the projection of every row of ch_data_list).
assert latent_dd.shape[0] == ch_id_list.shape[0], \
    "latent_dd must be the projection of all rows described by ch_id_list"

uq_id_list = np.unique(ch_id_list[:,0])
uq_id_max = np.max(uq_id_list)


cmap = plt.colormaps.get_cmap("rainbow_r")

# The electrode ID is bound by the loop itself. Previously the assignment was
# commented out, so every iteration silently reused whatever `_ele_id` another cell
# had left in the kernel (or raised NameError on a fresh kernel).
for i, _ele_id in enumerate(uq_id_list):
    # Optional dataset filters (enable at most one):
    # if _ele_id >= seg0: break
    # if  _ele_id < seg0 or _ele_id >= seg1: continue
    # if _ele_id < seg1: continue

    # if _ele_id >= seg1: break

    # Rows of this electrode and the distinct values of ID column 1 among them.
    ele_mask = ch_id_list[:,0] == _ele_id
    _ch_list = np.unique(ch_id_list[ele_mask,1])

    for j in _ch_list:
        # Rows whose first two ID columns equal (_ele_id, j).
        _ch_mask = ch_id_list[:,:2] == [_ele_id,j]
        _ch_mask = _ch_mask[:,0] & _ch_mask[:,1]
        # _ch_data = latent_dd[_ch_mask,:2]

        # _c = cmap(_ele_id / uq_id_max)
        # axis.plot(_ch_data[:,0],_ch_data[:,1], color = _c, alpha = 0.5)

        _cluster_list = np.unique(ch_id_list[_ch_mask,2])

        # Total row count of this channel; normalises the color position.
        _seq_all_len = ch_id_list[_ch_mask,2].shape[0]
        _seg_poi = 0

        for k in _cluster_list:
            # Rows whose first three ID columns equal (_ele_id, j, k).
            _cluster_mask = ch_id_list[:,:3] == [_ele_id,j,k]
            _cluster_mask = _cluster_mask[:,0] & _cluster_mask[:,1] & _cluster_mask[:,2]
            # _cluster_data = latent_dd[_cluster_mask,:2]
            # Same sign flip on x as the background scatter.
            _cluster_data = np.stack([-latent_dd[_cluster_mask,0],latent_dd[_cluster_mask,1]], axis=1)

            # Pair each point with its successor: [n-1, 2 (start/end), 2 (x/y)].
            _seg_data = _cluster_data.reshape(-1,1,2)
            _seg_data = np.concatenate([_seg_data[:-1], _seg_data[1:]], axis=1)

            # Drop segments whose x displacement is 1 or more.
            _dx = np.abs(_seg_data[:,1,0] - _seg_data[:,0,0])
            _seg_data = _seg_data[_dx < 1,:,:]

            _seg_len = _cluster_data.shape[0]

            # Color encodes the position of the segment within the channel sequence.
            color_range = np.linspace(_seg_poi/_seq_all_len, (_seg_poi+_seg_len)/_seq_all_len, _seg_data.shape[0])
            colors = cmap(color_range)

            _seg_poi = _seg_poi+_seg_len
            lc = LineCollection(_seg_data, colors=colors, linewidth=1, alpha = 0.05)
            axis.add_collection(lc)


axis.set_xlim(-2, 3)
axis.set_ylim(-3, 3)
axis.set_aspect('equal', adjustable='box')
# axis[0].set_box_aspect(1)
axis.set_title("Latent Space")

fig.show()



### 2 vs 3

In [280]:
# Project every latent vector, keep only rows whose first component exceeds -0.5,
# then move components 2 and 3 into columns 0 and 1 so the generic 2-D plotting
# code shows "component 2 vs component 3". Columns from index 2 on are left as-is.
latent_dd = _pca_inst.transform(latent_space_inst[:])

latent_mask = latent_dd[:,0] > -0.5
latent_dd = latent_dd[latent_mask]

latent_dd_tmp = latent_dd.copy()
latent_dd_tmp[:, :2] = latent_dd[:, 1:3]
latent_dd = latent_dd_tmp

latent_dd = VAE_PCA_Plot(_pca_inst, latent_dd, alpha = 0.5, s = 0.001)



In [279]:

# Single figure for the filtered "component 2 vs 3" view: all retained points in gray
# with the trajectories of every electrode overlaid as faint colored segments.
fig, axis = plt.subplots(1,1, figsize = (9,9))
axis.scatter(-latent_dd[:,0],latent_dd[:,1], color = 'lightgray', s=0.05)
# plt.scatter(_pca_start[:,0],_pca_start[:,1],s=0.1)


# ID rows restricted with the same mask that produced latent_dd.
_poi_id_list= ch_id_list[latent_mask,:]

# latent_dd is indexed below with masks built from _poi_id_list, so both must be
# row-aligned; a mismatch means latent_mask / latent_dd come from different cells.
assert _poi_id_list.shape[0] == latent_dd.shape[0], \
    "latent_mask does not match latent_dd; re-run the cell that builds them"

uq_id_list = np.unique(_poi_id_list[:,0])
uq_id_max = np.max(uq_id_list)


cmap = plt.colormaps.get_cmap("rainbow_r")

# The electrode ID is bound by the loop itself. Previously the assignment was
# commented out, so every iteration silently reused whatever `_ele_id` another cell
# had left in the kernel (or raised NameError on a fresh kernel).
for i, _ele_id in enumerate(uq_id_list):
    # Optional dataset filters (enable at most one):
    # if _ele_id >= seg0: break
    # if  _ele_id < seg0 or _ele_id >= seg1: continue
    # if _ele_id < seg1: continue

    # if _ele_id >= seg1: break

    # Rows of this electrode and the distinct values of ID column 1 among them.
    ele_mask = _poi_id_list[:,0] == _ele_id
    _ch_list = np.unique(_poi_id_list[ele_mask,1])


    for j in _ch_list:
        # Rows whose first two ID columns equal (_ele_id, j).
        _ch_mask = _poi_id_list[:,:2] == [_ele_id,j]
        _ch_mask = _ch_mask[:,0] & _ch_mask[:,1]
        # _ch_data = latent_dd[_ch_mask,:2]

        # _c = cmap(_ele_id / uq_id_max)
        # axis.plot(_ch_data[:,0],_ch_data[:,1], color = _c, alpha = 0.5)

        _cluster_list = np.unique(_poi_id_list[_ch_mask,2])

        # Total retained row count of this channel; normalises the color position.
        _seq_all_len = _poi_id_list[_ch_mask,2].shape[0]
        _seg_poi = 0

        for k in _cluster_list:
            # Rows whose first three ID columns equal (_ele_id, j, k).
            _cluster_mask = _poi_id_list[:,:3] == [_ele_id,j,k]
            _cluster_mask = _cluster_mask[:,0] & _cluster_mask[:,1] & _cluster_mask[:,2]
            # _cluster_data = latent_dd[_cluster_mask,:2]
            # Same sign flip on x as the background scatter.
            _cluster_data = np.stack([-latent_dd[_cluster_mask,0],latent_dd[_cluster_mask,1]], axis=1)

            # Pair each point with its successor: [n-1, 2 (start/end), 2 (x/y)].
            _seg_data = _cluster_data.reshape(-1,1,2)
            _seg_data = np.concatenate([_seg_data[:-1], _seg_data[1:]], axis=1)

            # Drop segments whose x displacement is 1 or more.
            _dx = np.abs(_seg_data[:,1,0] - _seg_data[:,0,0])
            _seg_data = _seg_data[_dx < 1,:,:]

            _seg_len = _cluster_data.shape[0]

            # Color encodes the position of the segment within the channel sequence.
            color_range = np.linspace(_seg_poi/_seq_all_len, (_seg_poi+_seg_len)/_seq_all_len, _seg_data.shape[0])
            colors = cmap(color_range)

            _seg_poi = _seg_poi+_seg_len
            lc = LineCollection(_seg_data, colors=colors, linewidth=1, alpha = 0.05)
            axis.add_collection(lc)


axis.set_xlim(-2, 3)
axis.set_ylim(-3, 3)
axis.set_aspect('equal', adjustable='box')
# axis[0].set_box_aspect(1)
axis.set_title("Latent Space")

fig.show()



## Plot Start

In [243]:
# Encode the "start" rows the same way as the main data: split into 2 channels,
# wrap in the dataset class, and collect the encoder means.
ch_start_list, ch_start_id_list  = load_all2ch(all_start_list[:], all_start_id_list[:])
all_start_ds = EISDataset_Manifold(ch_start_list)
latent_space_start_inst = VAE_latent(vae_model, all_start_ds, batch_size=64)


[32m2025-05-19 22:02:51.873[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[8000]/[22597][0m
[32m2025-05-19 22:02:51.965[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[16000]/[22597][0m


### 2D - 1 vs 2

In [289]:
# Background: every latent vector projected with the already-fitted PCA.
latent_dd = _pca_inst.transform(latent_space_inst[:])

# Optional: show component 3 on the y axis instead of component 2.
# latent_dd_tmp = latent_dd.copy()
# latent_dd_tmp[:,1] = latent_dd[:,2]
# latent_dd = latent_dd_tmp 



# Start points to highlight (enable exactly one line). Active: IDs below seg0 only,
# i.e. the first dataset in the concatenation order.
latent_start_dd = _pca_inst.transform(latent_space_start_inst[all_start_id_list[:,0]<seg0])
# latent_start_dd = _pca_inst.transform(latent_space_start_inst[(all_start_id_list[:,0]>=seg0) & (all_start_id_list[:,0]<seg1)])
# latent_start_dd = _pca_inst.transform(latent_space_start_inst[all_start_id_list[:,0]>=seg1])

# latent_start_dd = _pca_inst.transform(latent_space_start_inst[all_start_id_list[:,0]<seg1])
# latent_start_dd = _pca_inst.transform(latent_space_start_inst[:])


# Optional: same component swap for the start points.
# latent_start_dd_tmp = latent_start_dd.copy()
# latent_start_dd_tmp[:,1] = latent_start_dd[:,2]
# latent_start_dd = latent_start_dd_tmp 



In [None]:


# All latent points (gray) with the selected start points (red) on top.
# x is sign-flipped, matching VAE_PCA_Plot.
fig, axis = plt.subplots(1, 1, figsize=(9, 9))
axis.scatter(-latent_dd[:, 0], latent_dd[:, 1], color='lightgray', s=0.005)
axis.scatter(-latent_start_dd[:, 0], latent_start_dd[:, 1], color='red', s=0.1)

# Fixed viewing window so that figures from different cells are comparable.
axis.set(xlim=(-2, 3), ylim=(-3, 3))
axis.set_aspect('equal', adjustable='box')
axis.set_title("Latent Space")

fig.show()



### 2D - 2 vs 3

In [293]:
# --- Main data: project, filter on component 1, show components 2 and 3 -----------
latent_dd = _pca_inst.transform(latent_space_inst[:])

# latent_mask stays row-aligned with ch_id_list; the 2-vs-3 manifold cell uses it to
# select the matching ID rows, so it must not be overwritten further down.
latent_mask = latent_dd[:,0]>-0.5
latent_dd = latent_dd[latent_mask]

# Move components 2 and 3 into columns 0 and 1 for the generic 2-D plotting code.
latent_dd_tmp = latent_dd.copy()
latent_dd_tmp[:,0] = latent_dd[:,1]
latent_dd_tmp[:,1] = latent_dd[:,2]
latent_dd = latent_dd_tmp

# --- Start points: same projection, filter and column shift ------------------------
latent_start_dd = _pca_inst.transform(latent_space_start_inst[:])

# Separate name for the start-point mask. Reusing `latent_mask` here replaced the
# data mask with one of a different length, which broke any later cell that indexes
# ch_id_list with latent_mask.
latent_start_mask = latent_start_dd[:,0]>-0.5
latent_start_dd = latent_start_dd[latent_start_mask]

latent_start_dd_tmp = latent_start_dd.copy()
latent_start_dd_tmp[:,0] = latent_start_dd[:,1]
latent_start_dd_tmp[:,1] = latent_start_dd[:,2]
latent_start_dd = latent_start_dd_tmp


latent_dd = VAE_PCA_Plot(_pca_inst, latent_dd, alpha = 0.5, s = 0.001)



In [296]:




# Filtered latent points (gray) with the filtered start points (red) on top.
# x is sign-flipped, matching VAE_PCA_Plot.
fig, axis = plt.subplots(1, 1, figsize=(9, 9))
axis.scatter(-latent_dd[:, 0], latent_dd[:, 1], color='lightgray', s=0.005)
axis.scatter(-latent_start_dd[:, 0], latent_start_dd[:, 1], color='red', s=0.005)

# Fixed viewing window so that figures from different cells are comparable.
axis.set(xlim=(-2, 3), ylim=(-3, 3))
axis.set_aspect('equal', adjustable='box')
axis.set_title("Latent Space")

fig.show()



### 3D

In [290]:
# Unfiltered projections for the 3-D views: every latent vector as background, and the
# start points whose ID (column 0) is below seg0 as highlights.
latent_dd = _pca_inst.transform(latent_space_inst[:])
latent_start_dd = _pca_inst.transform(latent_space_start_inst[all_start_id_list[:,0]<seg0])



In [291]:
# Static 3-D view of the first three PCA components (matplotlib).
fig = plt.figure(figsize=(9, 9))
axis = fig.add_subplot(111, projection='3d')

# 3D scatter: all latent points as gray background
axis.scatter(
    -latent_dd[:, 0],  # keep the sign flip on the first component (same as the 2-D plots)
    latent_dd[:, 1],
    latent_dd[:, 2],
    color='lightgray',
    s=0.05
)

# Selected start points in red, with the same sign flip on x
axis.scatter(
    -latent_start_dd[:, 0],
    latent_start_dd[:, 1],
    latent_start_dd[:, 2],
    color='red',
    s=0.05
)

# Set the axis ranges (adjust to the actual data as needed)
axis.set_xlim(-2, 3)
axis.set_ylim(-3, 3)
axis.set_zlim(-3, 3)

axis.set_xlabel("Latent Dimension 1")
axis.set_ylabel("Latent Dimension 2")
axis.set_zlabel("Latent Dimension 3")

axis.set_title("Latent Space (3D)")
plt.tight_layout()
plt.show()

In [292]:
# Interactive 3-D view of the same data, rendered by plotly in the browser.
# NOTE(review): this import lives mid-notebook; consider moving it to the import cell.
import plotly.graph_objects as go
# Downsampled variant (currently disabled):
# step = 100
# points = latent_dd[::step]
# starts = latent_start_dd[::step]
# NOTE(review): `step` is assigned but the active slices below take every point.
step = 100
points = latent_dd[:]
starts = latent_start_dd[:]

fig = go.Figure()

# Background points in light gray (x sign-flipped, as in the matplotlib plots)
fig.add_trace(go.Scatter3d(
    x=-points[:, 0], y=points[:, 1], z=points[:, 2],
    mode='markers',
    marker=dict(size=0.5, color='lightgray'),
    name='All points'
))

# Start points in red
fig.add_trace(go.Scatter3d(
    x=-starts[:, 0], y=starts[:, 1], z=starts[:, 2],
    mode='markers',
    marker=dict(size=0.5, color='red'),
    name='Start points'
))

fig.update_layout(
    title="Latent Space (3D)",
    scene=dict(
        xaxis_title='Latent Dim 1',
        yaxis_title='Latent Dim 2',
        zaxis_title='Latent Dim 3'
    ),
    height=800,
)
# Opens the figure in the system browser instead of inline.
fig.show(renderer="browser")



## Plot Velocity

### Data Selector

In [368]:
# Encode the "start" rows: split into 2 channels, wrap in the dataset class, and collect
# the encoder means.
# NOTE(review): these three statements also appear verbatim in the "Plot Start" section;
# on a top-to-bottom run this cell only recomputes the same result.
ch_start_list, ch_start_id_list  = load_all2ch(all_start_list[:], all_start_id_list[:])
all_start_ds = EISDataset_Manifold(ch_start_list)
latent_space_start_inst = VAE_latent(vae_model, all_start_ds, batch_size=64)


[32m2025-05-19 23:28:50.813[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[8000]/[22597][0m
[32m2025-05-19 23:28:50.914[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[16000]/[22597][0m


In [369]:
# When True, the velocity cells copy PCA column 2 into column 1 before plotting, so the
# y axis shows the third component while x stays the first component.
# NOTE(review): the name suggests "component 2 vs 3", but column 0 is never replaced in
# those cells — confirm which view is intended.
FLAG23 = True

In [370]:
# Project every latent vector with the already-fitted PCA. The result stays row-aligned
# with ch_id_list, which the velocity computation relies on.
latent_dd = _pca_inst.transform(latent_space_inst[:])

if FLAG23:
    # Replace column 1 with the third component; column 0 is left untouched.
    latent_dd_tmp = latent_dd.copy()
    latent_dd_tmp[:,1] = latent_dd[:,2]
    latent_dd = latent_dd_tmp 

latent_dd = VAE_PCA_Plot(_pca_inst, latent_dd, alpha = 0.5, s = 0.001)


### Calculate Velocity

In [371]:

# Collect, for the selected electrodes, every pair of consecutive latent points
# (manifold_vector_list) together with the difference of ID column 3 between the two
# points (manifold_time_list; presumably a timestamp — TODO confirm).
# latent_dd must be row-aligned with ch_id_list (projection of all rows).
uq_id_list = np.unique(ch_id_list[:,0])
uq_id_max = np.max(uq_id_list)


# NOTE(review): cmap and uq_id_max are not used by the active code in this cell.
cmap = plt.colormaps.get_cmap("rainbow_r")

manifold_vector_list    = []
manifold_time_list      = []

for i in range(len(uq_id_list)):
# for i in range(1):
    # Logged for every electrode, including the ones skipped by the filter below.
    logger.info(f"[{i}/{len(uq_id_list)}]")
    _ele_id = uq_id_list[i]

    # Dataset filter (enable at most one). Active: skip IDs below seg1, i.e. keep only
    # the last dataset in the concatenation order.
    # if _ele_id >= seg0: break
    # if  _ele_id < seg0 or _ele_id >= seg1: continue
    if _ele_id < seg1: continue

    # if _ele_id >= seg1: break
    

    # Rows of this electrode and the distinct values of ID column 1 among them.
    ele_mask = ch_id_list[:,0] == _ele_id
    _ch_list = np.unique(ch_id_list[ele_mask,1])


    for j in _ch_list:
        # Rows whose first two ID columns equal (_ele_id, j).
        _ch_mask = ch_id_list[:,:2] == [_ele_id,j]
        _ch_mask = _ch_mask[:,0] & _ch_mask[:,1]
        # _ch_data = latent_dd[_ch_mask,:2]

        # _c = cmap(_ele_id / uq_id_max)
        # axis.plot(_ch_data[:,0],_ch_data[:,1], color = _c, alpha = 0.5)

        _cluster_list = np.unique(ch_id_list[_ch_mask,2])
        for k in _cluster_list:
            # Rows whose first three ID columns equal (_ele_id, j, k).
            _cluster_mask = ch_id_list[:,:3] == [_ele_id,j,k]
            _cluster_mask = _cluster_mask[:,0] & _cluster_mask[:,1] & _cluster_mask[:,2]
            # _cluster_data = latent_dd[_cluster_mask,:2]
            # x is sign-flipped so the vectors live in the same frame as the plots.
            _cluster_data = np.stack([-latent_dd[_cluster_mask,0],latent_dd[_cluster_mask,1]], axis=1)

            # Seg Data: pair each point with its successor -> [n-1, 2 (start/end), 2 (x/y)]
            _seg_data = _cluster_data.reshape(-1,1,2)
            _seg_data = np.concatenate([_seg_data[:-1], _seg_data[1:]], axis=1)

            # Keep only segments whose x displacement is below 1.
            _dx = np.abs(_seg_data[:,1,0] - _seg_data[:,0,0])
            _seg_data = _seg_data[_dx < 1,:,:]


            # Seg Time: difference of ID column 3 between consecutive rows, filtered
            # with the same mask so it stays aligned with _seg_data.
            _seg_time = ch_id_list[_cluster_mask,3]
            _seg_time = np.diff(_seg_time)
            _seg_time = _seg_time[_dx < 1]

            manifold_vector_list.append(_seg_data)
            manifold_time_list.append(_seg_time)

            
# Flatten the per-cluster pieces into one [N, 2, 2] array and one [N] array.
# NOTE(review): np.concatenate raises ValueError if the filter above selected nothing.
manifold_vector_list = np.concatenate(manifold_vector_list, axis=0)
manifold_time_list = np.concatenate(manifold_time_list, axis=0)



[32m2025-05-19 23:28:51.146[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1m[0/287][0m
[32m2025-05-19 23:28:51.147[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1m[1/287][0m
[32m2025-05-19 23:28:51.147[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1m[2/287][0m
[32m2025-05-19 23:28:51.147[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1m[3/287][0m
[32m2025-05-19 23:28:51.147[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1m[4/287][0m
[32m2025-05-19 23:28:51.148[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1m[5/287][0m
[32m2025-05-19 23:28:51.148[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1m[6/287][0m
[32m2025-05-19 23:28:51.148[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1m[7/287][0m
[32m2025-05-19 23:28:51.148[0m | [1mI

In [372]:
# Velocity of each segment: displacement (end point minus start point) divided by the
# corresponding entry of manifold_time_list, broadcast over the x/y columns.
# NOTE(review): an entry of 0 in manifold_time_list yields inf/nan here — confirm that
# consecutive rows can never share the same value in ID column 3.
manifold_speed_list = (manifold_vector_list[:,1,:] - manifold_vector_list[:,0,:])
manifold_speed_list = manifold_speed_list/manifold_time_list[:,np.newaxis]
manifold_speed_list



array([[-0.005848  ,  0.01975725],
       [-0.0023479 , -0.01078043],
       [-0.00854843,  0.04511589],
       ...,
       [-0.00510339, -0.00956862],
       [ 0.00388857, -0.03006081],
       [-0.00377305,  0.08917225]], shape=(7447, 2))

### Plot Velocity Field

In [373]:
# Background points for the velocity-field figure (enable exactly one line).
# Active: rows with ID >= seg1, matching the electrode filter of the velocity cell.
# latent_dd = _pca_inst.transform(latent_space_inst[all_id_list[:,0]<seg0])
# latent_dd = _pca_inst.transform(latent_space_inst[(all_id_list[:,0]>=seg0) & (all_id_list[:,0]<seg1)])
latent_dd = _pca_inst.transform(latent_space_inst[all_id_list[:,0]>=seg1])

# latent_dd = _pca_inst.transform(latent_space_inst[all_id_list[:,0]<seg1])
# latent_dd = _pca_inst.transform(latent_space_inst[:])


# Start points for the same selection (enable exactly one line).
# latent_start_dd = _pca_inst.transform(latent_space_start_inst[all_start_id_list[:,0]<seg0])
# latent_start_dd = _pca_inst.transform(latent_space_start_inst[(all_start_id_list[:,0]>=seg0) & (all_start_id_list[:,0]<seg1)])
latent_start_dd = _pca_inst.transform(latent_space_start_inst[all_start_id_list[:,0]>=seg1])

# latent_start_dd = _pca_inst.transform(latent_space_start_inst[all_start_id_list[:,0]<seg1])
# latent_start_dd = _pca_inst.transform(latent_space_start_inst[:])



if FLAG23:
    # Same column replacement as used when the velocity segments were computed:
    # column 1 takes the third component, column 0 is left untouched.
    latent_dd_tmp = latent_dd.copy()
    latent_dd_tmp[:,1] = latent_dd[:,2]
    latent_dd = latent_dd_tmp 

    
    latent_dd_tmp = latent_start_dd.copy()
    latent_dd_tmp[:,1] = latent_start_dd[:,2]
    latent_start_dd = latent_dd_tmp 


In [374]:
# 获取所有点的坐标
all_points = manifold_vector_list.reshape(-1, 2)
x_min, x_max = all_points[:, 0].min(), all_points[:, 0].max()
y_min, y_max = all_points[:, 1].min(), all_points[:, 1].max()

# 定义网格大小
grid_size = 50  # 可根据需要调整
x_bins = np.linspace(x_min, x_max, grid_size + 1)
y_bins = np.linspace(y_min, y_max, grid_size + 1)


# 获取起点坐标
start_points = manifold_vector_list[:, 0, :]

# 计算每个起点所在的网格索引
x_indices = np.digitize(start_points[:, 0], x_bins) - 1
y_indices = np.digitize(start_points[:, 1], y_bins) - 1

# 初始化速度场和计数器
velocity_field = np.zeros((grid_size, grid_size, 2))
count = np.zeros((grid_size, grid_size))

# 累加速度向量
for xi, yi, v in zip(x_indices, y_indices, manifold_speed_list):
    if 0 <= xi < grid_size and 0 <= yi < grid_size:
        velocity_field[yi, xi] += v
        count[yi, xi] += 1

# 计算平均速度
with np.errstate(divide='ignore', invalid='ignore'):
    average_velocity = np.divide(velocity_field, count[:, :, np.newaxis])
    average_velocity[np.isnan(average_velocity)] = 0  # 将 NaN 替换为 0

# 去除噪声样本导致的向量统计
threshold = 10  # 最小样本数量阈值
average_velocity[count < threshold] = 0  # 将低于阈值的单元速度设为零




In [375]:

# Cell centres of the velocity grid, expanded to 2-D coordinate arrays.
x_centers = 0.5 * (x_bins[:-1] + x_bins[1:])
y_centers = 0.5 * (y_bins[:-1] + y_bins[1:])
X, Y = np.meshgrid(x_centers, y_centers)

# Per-cell opacity derived from the sample count (only used by the disabled quiver call).
alpha = np.clip(count / count.max(), 0.2, 1.0)

# Velocity components, and the cells that carry a non-zero mean velocity.
U, V = average_velocity[:, :, 0], average_velocity[:, :, 1]
speed_mask = (U != 0) | (V != 0)

# Plot Speed Field
fig, axis = plt.subplots(1, 1, figsize=(9, 9))

# Background: projected latent points (x sign-flipped like the velocity vectors).
axis.scatter(-latent_dd[:, 0], latent_dd[:, 1], color='lightgray', s=0.05)

# Start points in red.
axis.scatter(-latent_start_dd[:, 0], latent_start_dd[:, 1], color='red', s=0.001)

# Arrows in data units (scale=1, xy units/angles), drawn only for non-zero cells.
axis.quiver(
    X[speed_mask], Y[speed_mask], U[speed_mask], V[speed_mask],
    alpha = 0.7, scale=1, scale_units='xy', angles='xy',
)
# axis.quiver(X, Y, U, V, scale=1, alpha = alpha, scale_units='xy', angles='xy')

# Same fixed window as the other latent-space figures.
axis.set_xlim(-2, 3)
axis.set_ylim(-3, 3)
axis.set_aspect('equal', adjustable='box')

axis.set_xlabel('Latent Dimension 1')
axis.set_ylabel('Latent Dimension 2')
axis.set_title('Velocity Field in Latent Space')
# plt.grid(True)
fig.show()
