# Import

In [258]:
import os
import re
import gc
import sys

from loguru import logger
import numpy as np
import random

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# %matplotlib qt
%matplotlib qt

# Detect device: use CUDA when PyTorch reports it as available, otherwise the CPU.
# `device` is read as a global by the training / visualisation / latent helpers below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input Layer

In [259]:
def SearchELE(rootPath, ele_pattern = re.compile(r"(.+?)_归档")):
    '''==================================================
        Search all electrode directories in the rootPath
        Parameter: 
            rootPath: directory whose immediate entries are scanned
            ele_pattern: compiled regex matched (from the start) against
                         each entry name; group(1) is used as electrode id
        Return:
            ele_list: list of [full_path, electrode_id] pairs, ordered by
                      entry name
        Note:
            Entries are matched by name only; no check is made that a
            matching entry is actually a directory.
        ==================================================
    '''
    ele_list = []
    # os.listdir() returns entries in arbitrary order. Callers use the list
    # index as an electrode index, so sort to keep that index reproducible.
    for entry in sorted(os.listdir(rootPath)):
        match_ele = ele_pattern.match(entry)
        if match_ele:
            ele_list.append([os.path.join(rootPath, entry), match_ele.group(1)])
    return ele_list



In [None]:
# Archive root that is scanned for electrode folders. The original absolute path is
# kept as the default so existing runs behave the same; set EISNN_ARCHIVE_ROOT to
# point the notebook at another location without editing this cell.
rootPath = os.environ.get("EISNN_ARCHIVE_ROOT", "D:/Baihm/EISNN/Archive/")
ele_list = SearchELE(rootPath)   # list of [electrode_dir, electrode_id]
n_ele = len(ele_list)
logger.info(f"Search in {rootPath} and find {n_ele:03d} electrodes")


[32m2025-04-16 19:39:46.508[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mSearch in D:/Baihm/EISNN/Archive/ and find 218 electrodes[0m


In [261]:
# Electrode ids that the loading loop skips entirely (checked with `in Blacklist`).
# The trailing comment on each entry records why it was excluded.
Blacklist = [
    '01067093',     # Does not look like EIS
    '01067094',     # Connection Error
    '02017385',     # Connection Error
    '05127177',     # Open to Short
    '06047729',     # Open to Short
    '06047730',     # Open to Short
    '06047731',     # Open to Short
    '09207024',     # Connection Error
    '10017038',     # Connection Error
    '10037050',     # Connection Error
    '10047056',     # Connection Error
    '10057069',     # Connection Error
    '10057083',     # Always Open
    '10057084',     # Chaos
    '10057087',     # Connection Error
    '22017367',     # Connection Error
    '22017371',     # Chaos
]

# Electrode ids flagged as suspicious.
# NOTE(review): GrayList is not referenced by any code visible in this notebook.
GrayList = [
    '10037051',     # Connection Error
    '10037052',     # Connection Error
    '10057071',     # Connection Error
    '10067077',     # Weird shape, looks like a connection error
    '10150201',     # Weird shape
    '10150202',     # Weird shape
    '10150203',     # Weird shape
    '20037515',     # Weird shape
    '20037516',     # Weird shape
    '20037517',     # Weird shape
    '22037378',     # Connection Error
    '22037380',     # Connection Error
    '22047376',     # Connection Error

]

In [262]:

MODEL_SUFFIX = "Matern12_Ver01"

# Nested result containers: outer index = loaded electrode, inner index = channel.
#   all_data_list[e][c] -> the channel's "y_eval" entry
#   all_id_list[e][c]   -> int array with one [electrode_index, channel_id] row per
#                          first-axis element of that "y_eval" entry
all_data_list = []
all_id_list = []

_ch_pattern = re.compile(r"ch_(\d{3})")

# Pre-bind the scratch names so the final `del` cannot raise NameError when
# every electrode is skipped (blacklisted, file missing, or filtered out).
data_pt = _meta_group = _data_group = _ch_data = None

for i in range(n_ele):
    if ele_list[i][1] in Blacklist:
        continue

    fd_pt = os.path.join(ele_list[i][0], MODEL_SUFFIX, f"{ele_list[i][1]}_{MODEL_SUFFIX}.pt")
    if not os.path.exists(fd_pt):
        # Electrodes without a fitted-model file are silently skipped.
        continue
    # NOTE(security): weights_only=False makes torch.load unpickle arbitrary
    # Python objects. Only load files from a trusted source.
    data_pt = torch.load(fd_pt, weights_only=False)
    _meta_group = data_pt["meta_group"]
    _data_group = data_pt["data_group"]

    n_day       = _meta_group["n_day"]
    n_ch        = _meta_group["n_ch"]
    n_valid_ch  = len(_data_group["Channels"])

    # Ignore abnormal electrodes: when the electrode is not a complete 128-channel
    # set, keep it only if n_day >= 5 and more than 100 channels are valid.
    if n_ch != 128 or n_valid_ch != n_ch:
        if n_day < 5 or n_valid_ch <= 100:
            continue

    logger.info(f"ELE [{i}/{n_ele}]: {ele_list[i][0]}")

    ele_data_list = []
    ele_id_list = []
    # Iterate channel by channel
    for j in _data_group['Channels']:
        _ch_data = _data_group[j]["y_eval"]
        ele_data_list.append(_ch_data)

        _ch_match = _ch_pattern.match(j)
        if _ch_match is None:
            # Fail with a readable message instead of an AttributeError on None.
            raise ValueError(f"Unexpected channel key {j!r} in {fd_pt}")
        _ch_id = int(_ch_match.group(1))

        # One [electrode_index, channel_id] row per first-axis element of _ch_data.
        _id = [i, _ch_id] * np.shape(_ch_data)[0]
        _id = np.array(_id).reshape(-1,2)
        ele_id_list.append(_id)

    all_data_list.append(ele_data_list)
    all_id_list.append(ele_id_list)


# Drop the references to the last loaded file so its memory can be reclaimed.
del data_pt, _meta_group, _data_group, _ch_data
gc.collect()



[32m2025-04-16 19:39:46.641[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mELE [0/218]: D:/Baihm/EISNN/Archive/01037160_归档[0m
[32m2025-04-16 19:39:46.692[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mELE [1/218]: D:/Baihm/EISNN/Archive/01037161_归档[0m
[32m2025-04-16 19:39:46.756[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mELE [2/218]: D:/Baihm/EISNN/Archive/01037162_归档[0m
[32m2025-04-16 19:39:46.808[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mELE [5/218]: D:/Baihm/EISNN/Archive/01067095_归档[0m
[32m2025-04-16 19:39:46.878[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mELE [9/218]: D:/Baihm/EISNN/Archive/02027373_归档[0m
[32m2025-04-16 19:39:46.901[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mELE [10/218]: D:/Baihm/EISNN/Archive/02027390_归档[0m
[32m2025-04-16 19:39:46.944[0m | [1m

847

# Helper

In [263]:
def load_all2seq(data_list, id_list = None):
    """Flatten a two-level nested list into a single flat list.

    Parameters:
        data_list: outer sequence of inner sequences; every inner element
                   becomes one entry of the result, in nested order.
        id_list:   optional structure indexed exactly like `data_list`.

    Returns:
        (seq_data_list, seq_id_list). `seq_id_list` is an empty list when
        `id_list` is None.
    """
    seq_data_list = [item for group in data_list for item in group]

    seq_id_list = []
    if id_list is not None:
        seq_id_list = [
            id_list[g][c]
            for g, group in enumerate(data_list)
            for c in range(len(group))
        ]
    return seq_data_list, seq_id_list

def load_all2ch(data_list, id_list = None):
    """Flatten the nested lists and stack the pieces along the first axis.

    Parameters:
        data_list: two-level nested list of array-likes (see load_all2seq).
        id_list:   optional nested list indexed like `data_list`.

    Returns:
        (data, ids): `data` is np.vstack of all flattened entries; `ids` is
        np.vstack of the flattened ids, or an empty list when `id_list` is None.
    """
    flat_data, flat_ids = load_all2seq(data_list, id_list)
    stacked_data = np.vstack(flat_data)
    stacked_ids = flat_ids if id_list is None else np.vstack(flat_ids)
    return stacked_data, stacked_ids

In [264]:
def count_parameters(model):
    """Return the total element count of all parameters with requires_grad set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# EISVAE_CNN

## Model

In [None]:
class EISDataset_CNN(Dataset):
    """Dataset yielding one channel-first spectrum per index for the Conv1d models.

    Per the original notes, `data_list` is an n x m x k x l x 2 nested list:
        n: number of electrodes
        m: number of channels
        k: number of timestamps
        l: number of frequencies (used as the signal length)
        2: real and imaginary parts after logarithm
    """
    def __init__(self, data_list, id_list = None):
        _data, _id  = load_all2ch(data_list, id_list)

        # Keep one stacked float32 tensor rather than a Python list holding one
        # small tensor per row: same values, far fewer objects to create.
        self.data = torch.tensor(_data, dtype=torch.float32)
        # Stacked ids, or an empty list when no id_list was supplied.
        self.id = _id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Swap the two axes of the stored sample so channels come first,
        # e.g. [101,2] -> [2,101] = [in_ch, in_dim] for Conv1d.
        return self.data[idx].permute(1,0)
    
class EISVAE_CNN(nn.Module):
    """Small convolutional VAE for fixed-length multi-channel curves.

    Encoder: Conv1d -> BatchNorm1d -> ReLU -> global max pool -> two linear
    heads (mu, log-variance). Decoder: linear expansion to
    [hid_ch, in_dim] followed by one Conv1d back to `in_ch` channels.

    Parameters (defaults reproduce the original hard-coded configuration):
        in_ch:  number of input channels
        in_dim: signal length; the decoder always emits this length
        hid_ch: number of hidden convolution channels
        z_dim:  latent dimension
        k:      convolution kernel size; padding is k//2, which preserves the
                length only for odd k
    """
    def __init__(self, in_ch=2, in_dim=101, hid_ch=16, z_dim=32, k=5):
        super().__init__()
        # Layer Parameters
        self.in_ch  = in_ch
        self.in_dim = in_dim
        self.hid_ch = hid_ch
        self.z_dim  = z_dim
        self.k      = k

        # Encoder
        self.conv1  = nn.Conv1d(self.in_ch, self.hid_ch, kernel_size=self.k, stride=1, padding=self.k//2)
        self.bn1    = nn.BatchNorm1d(self.hid_ch)
        self.pool1  = nn.AdaptiveMaxPool1d(1)

        self.fc_mu  = nn.Linear(self.hid_ch, self.z_dim)
        self.fc_lv  = nn.Linear(self.hid_ch, self.z_dim)

        # Decoder ("deconv" is a plain Conv1d; the name is kept so existing
        # state_dict keys stay valid)
        self.fc_dec = nn.Linear(self.z_dim, self.hid_ch * self.in_dim)
        self.deconv = nn.Conv1d(self.hid_ch, self.in_ch, kernel_size=self.k, padding=self.k//2)

    def encode(self, x):
        """Map x [B,in_ch,L] to (mu, logvar), each [B,z_dim]."""
        h = F.relu(self.bn1(self.conv1(x)))  # [B,hid_ch,L]
        h = self.pool1(h).squeeze(-1)        # [B,hid_ch]
        return self.fc_mu(h), self.fc_lv(h)  # [B,z_dim], [B,z_dim]

    def reparam(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map z [B,z_dim] to a reconstruction [B,in_ch,in_dim]."""
        h = self.fc_dec(z)                          # [B,hid_ch*in_dim]
        h = h.view(-1, self.hid_ch, self.in_dim)    # [B,hid_ch,in_dim]
        return self.deconv(h)                       # [B,in_ch,in_dim]

    def forward(self, x):
        """Return (x_rec, mu, logvar); z is sampled in both train and eval mode."""
        mu, lv = self.encode(x)
        z = self.reparam(mu, lv)
        x_rec = self.decode(z)
        return x_rec, mu, lv
        

## Train & Eval

In [209]:
def vae_loss(x_rec, x, mu, logvar, kld_weight=1e-3):
    """Mean-squared reconstruction error plus `kld_weight` times the KL term.

    The KL term is -0.5 * mean(1 + logvar - mu^2 - exp(logvar)), averaged
    over every element of `mu` / `logvar`.
    """
    reconstruction = F.mse_loss(x_rec, x)
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.mean(kl_elements)
    return reconstruction + kld_weight * kl_term

# ===== Training function =====
def train_EISVAECNN(model, train_list, val_list, num_epochs=20, batch_size=64, lr=1e-3):
    """Train `model` with Adam on `train_list` and evaluate on `val_list` every epoch.

    Parameters:
        model:       VAE whose forward returns (x_rec, mu, logvar)
        train_list:  nested data list wrapped by EISDataset_CNN for training
        val_list:    nested data list wrapped by EISDataset_CNN for validation
        num_epochs:  number of passes over the training set
        batch_size:  mini-batch size for both loaders
        lr:          Adam learning rate

    Returns:
        (model, train_loss_recorder, eval_loss_recorder) where the recorders
        hold the per-sample average loss of each epoch.
    """
    train_ds, val_ds = EISDataset_CNN(train_list), EISDataset_CNN(val_list)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_loss_recorder, eval_loss_recorder = [], []

    for epoch in range(1, num_epochs + 1):
        # ---- optimisation pass ----
        model.train()
        running_train = 0
        for batch in train_loader:
            batch = batch.to(device)

            recon, mean, log_var = model(batch)
            batch_loss = vae_loss(recon, batch, mean, log_var)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch value is a per-sample mean.
            running_train += batch_loss.item() * batch.size(0)

        # ---- validation pass (no gradients; z is still sampled by forward) ----
        model.eval()
        running_val = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)

                recon, mean, log_var = model(batch)
                batch_loss = vae_loss(recon, batch, mean, log_var)

                running_val += batch_loss.item() * batch.size(0)

        epoch_train = running_train / len(train_ds)
        epoch_val = running_val / len(val_ds)

        train_loss_recorder.append(epoch_train)
        eval_loss_recorder.append(epoch_val)
        print(f"Epoch {epoch}/{num_epochs}  Train: {epoch_train:.6f}  Val: {epoch_val:.6f}")

    return model, train_loss_recorder, eval_loss_recorder


# ===== Reconstruction visualisation =====
def visualize_EISVAECNN(model, data_list, num=5):
    """Plot original vs. reconstructed curves for up to `num` random samples.

    Parameters:
        model:     VAE whose forward returns (x_rec, mu, logvar)
        data_list: nested data list wrapped by EISDataset_CNN
        num:       maximum number of samples (one figure each) to draw

    One figure is opened per sample: channel 0 ("Real") on the left and
    channel 1 ("Imag") on the right.
    """
    ds = EISDataset_CNN(data_list)
    loader = DataLoader(ds, batch_size=num, shuffle=True)
    x = next(iter(loader)).to(device)   # [<=num, 2, L]
    model.eval()
    with torch.no_grad():
        x_rec, mu, lv = model(x)

    x = x.cpu().numpy()
    x_rec = x_rec.cpu().numpy()

    # The batch holds fewer than `num` rows when the dataset is smaller than
    # `num`; iterate over what was actually returned to avoid an IndexError.
    for i in range(x.shape[0]):
        plt.figure(figsize=(6,3))
        # Real part
        plt.subplot(1,2,1)
        plt.plot(x[i,0], label="orig")
        plt.plot(x_rec[i,0], '--', label="rec")
        plt.title(f"Sample {i} Real")
        plt.legend()
        # Imaginary part
        plt.subplot(1,2,2)
        plt.plot(x[i,1], label="orig")
        plt.plot(x_rec[i,1], '--', label="rec")
        plt.title(f"Sample {i} Imag")
        plt.legend()
        plt.tight_layout()
        plt.show()



## Running

In [210]:
# Build the CNN VAE with its default configuration on the selected device
# and report the number of trainable parameters.
vae_cnn = EISVAE_CNN().to(device)
print(count_parameters(vae_cnn))

54786


In [211]:
# Training configuration for the CNN VAE.
num_epochs = 50
batch_size=128
lr=1e-3
# NOTE(review): random_state=None gives a different split on every run;
# set an int here for a reproducible split.
random_state = None

# The split acts on the outer list, so whole electrodes (all their channels)
# go to either the training or the validation side.
train_list, val_list = train_test_split(all_data_list, test_size=0.2, random_state=random_state)
vae_cnn, train_loss, eval_loss = train_EISVAECNN(vae_cnn, train_list, val_list, num_epochs=num_epochs, batch_size=batch_size, lr=lr)

Epoch 1/50  Train: 0.934713  Val: 0.077396
Epoch 2/50  Train: 0.064415  Val: 0.043070
Epoch 3/50  Train: 0.046682  Val: 0.034819
Epoch 4/50  Train: 0.041548  Val: 0.034531
Epoch 5/50  Train: 0.038677  Val: 0.026297
Epoch 6/50  Train: 0.037090  Val: 0.025684
Epoch 7/50  Train: 0.034826  Val: 0.029447
Epoch 8/50  Train: 0.033338  Val: 0.022685
Epoch 9/50  Train: 0.032089  Val: 0.031085
Epoch 10/50  Train: 0.029957  Val: 0.034249
Epoch 11/50  Train: 0.028945  Val: 0.025520
Epoch 12/50  Train: 0.028416  Val: 0.022151
Epoch 13/50  Train: 0.027731  Val: 0.021714
Epoch 14/50  Train: 0.026826  Val: 0.020643
Epoch 15/50  Train: 0.026423  Val: 0.020434
Epoch 16/50  Train: 0.025923  Val: 0.019325
Epoch 17/50  Train: 0.025127  Val: 0.018264
Epoch 18/50  Train: 0.025183  Val: 0.021697
Epoch 19/50  Train: 0.024773  Val: 0.020805
Epoch 20/50  Train: 0.024102  Val: 0.019732
Epoch 21/50  Train: 0.023932  Val: 0.018427
Epoch 22/50  Train: 0.023751  Val: 0.025145
Epoch 23/50  Train: 0.023142  Val: 0.0198

### Plot Loss

In [217]:
plt.figure()
# Per-epoch average losses from the last training run, on a log-scaled y axis.
plt.semilogy(train_loss, label="train")
plt.semilogy(eval_loss, label="eval")
plt.legend()
plt.show()


### Plot Eval

In [213]:
# Compare originals and reconstructions on randomly drawn validation samples.
visualize_EISVAECNN(vae_cnn, val_list)

## Analysis

In [238]:
def VAE_latent(model, dataset, batch_size=64):
    """Encode every sample of `dataset` and return a 2-D view of the latent means.

    Parameters:
        model:      VAE exposing `encode(x) -> (mu, logvar)`
        dataset:    nested data list wrapped by EISDataset_CNN
        batch_size: number of samples encoded per forward pass

    Returns:
        Array with one row per sample, in dataset order. When z_dim > 2 the
        means are projected onto their first two principal components;
        otherwise the raw means are returned unchanged.
    """
    # x: [B,2,101]
    ds = EISDataset_CNN(dataset)
    # shuffle=False keeps row i of the result aligned with sample i of `ds`
    # (and with ds.id), which a shuffled loader would silently destroy.
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    _len_data = len(ds)
    _poi = 0

    latent_space_inst = []

    model.eval()
    with torch.no_grad():
        for x in loader:
            x = x.to(device)
            mu, _ = model.encode(x)
            latent_space_inst.append(mu.cpu().numpy())

            # Progress is logged only when the running count is an exact
            # multiple of 1000 (every 8000 samples with batch_size=64).
            _poi = _poi + x.size(0)
            if _poi % 1000 == 0:
                logger.info(f"[{_poi}]/[{_len_data}]")

    latent_space_inst = np.concatenate(latent_space_inst, axis=0)  # [N,z_dim]

    if latent_space_inst.shape[1] > 2:
        _pca_inst = PCA(n_components=2)
        latent_dd = _pca_inst.fit_transform(latent_space_inst)
    else:
        latent_dd = latent_space_inst

    return latent_dd


In [239]:
# 2-D latent representation of every loaded sample (training and validation alike).
latent_expr = VAE_latent(vae_cnn, all_data_list)

[32m2025-04-15 17:56:07.477[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[8000]/[333535][0m
[32m2025-04-15 17:56:07.569[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[16000]/[333535][0m
[32m2025-04-15 17:56:07.635[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[24000]/[333535][0m
[32m2025-04-15 17:56:07.707[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[32000]/[333535][0m
[32m2025-04-15 17:56:07.775[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[40000]/[333535][0m
[32m2025-04-15 17:56:07.842[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[48000]/[333535][0m
[32m2025-04-15 17:56:07.905[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[56000]/[333535][0m
[32m2025-04-15 17:56:07.966[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_lat

In [240]:

plt.figure(figsize=(9, 9))
# The second coordinate is negated, which flips the plot vertically.
# Tiny markers (s=0.001) keep the dense cloud readable.
plt.scatter(latent_expr[:, 0], -latent_expr[:, 1], alpha=0.5, s = 0.001)

# Same scale on both axes.
plt.gca().set_aspect('equal', adjustable='box')
plt.title("Latent Space")
# NOTE(review): when z_dim > 2 these axes are PCA components, not raw latent dimensions.
plt.xlabel("Latent Dimension 1")
plt.ylabel("Latent Dimension 2")
# plt.grid()
plt.show()

# EISVAE_DNN

In [218]:
# Scratch check of the flattening used by EISDataset_DNN: the last-axis index 0 and
# index 1 slices are concatenated along axis 1 (recorded result: (333535, 202)).
a,_ = load_all2ch(all_data_list, all_id_list)
b = np.concatenate((a[:,:,0],a[:,:,1]), axis=1)
b.shape

(333535, 202)

## Model

In [219]:
class EISDataset_DNN(Dataset):
    """Dataset yielding one flattened spectrum per index for the dense VAE.

    Per the original notes, `data_list` is an n x m x k x l x 2 nested list:
        n: number of electrodes
        m: number of channels
        k: number of timestamps
        l: number of frequencies
        2: real and imaginary parts after logarithm
    """
    def __init__(self, data_list, id_list = None):
        _data, _id  = load_all2ch(data_list, id_list)
        # Put the two last-axis components side by side: [N,l,2] -> [N,2*l]
        # (index-0 values first, then index-1 values).
        _data       = np.concatenate((_data[:,:,0],_data[:,:,1]), axis=1)

        # Keep one stacked float32 tensor rather than a Python list holding one
        # small tensor per row: same values, far fewer objects to create.
        self.data   = torch.tensor(_data, dtype=torch.float32)
        # Stacked ids, or an empty list when no id_list was supplied.
        self.id     = _id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]  # [2*l], e.g. [202]
    
class EISVAE_DNN(nn.Module):
    """Fully connected VAE with a four-layer encoder and a mirrored decoder.

    Parameters (defaults reproduce the original hard-coded configuration):
        input_dim:   length of the flattened input vector
        hidden_dims: exactly four hidden widths, listed from the input side
                     towards the latent side
        z_dim:       latent dimension

    Raises:
        ValueError: if `hidden_dims` does not contain exactly four entries.
    """
    def __init__(self, input_dim=202, hidden_dims=(512,256,128,64), z_dim=2):
        super().__init__()
        hidden_dims = list(hidden_dims)
        if len(hidden_dims) != 4:
            raise ValueError(f"hidden_dims must have exactly 4 entries, got {len(hidden_dims)}")

        # Layer Parameters
        self.input_dim   = input_dim
        self.hidden_dims = hidden_dims
        self.z_dim       = z_dim

        # --- Encoder ---
        self.enc_fc1 = nn.Linear(input_dim, hidden_dims[0])       # 202 -> 512
        self.enc_fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])  # 512 -> 256
        self.enc_fc3 = nn.Linear(hidden_dims[1], hidden_dims[2])  # 256 -> 128
        self.enc_fc4 = nn.Linear(hidden_dims[2], hidden_dims[3])  # 128 -> 64
        # mean & logvar heads
        self.fc_mu    = nn.Linear(hidden_dims[3], z_dim)          # 64 -> 2
        self.fc_logvar= nn.Linear(hidden_dims[3], z_dim)          # 64 -> 2

        # --- Decoder ---
        self.dec_fc1 = nn.Linear(z_dim,          hidden_dims[3])  # 2 -> 64
        self.dec_fc2 = nn.Linear(hidden_dims[3], hidden_dims[2])  # 64 -> 128
        self.dec_fc3 = nn.Linear(hidden_dims[2], hidden_dims[1])  # 128 -> 256
        self.dec_fc4 = nn.Linear(hidden_dims[1], hidden_dims[0])  # 256 -> 512
        self.dec_fc5 = nn.Linear(hidden_dims[0], input_dim)       # 512 -> 202

    def encode(self, x):
        """Map x [B,input_dim] to (mu, logvar), each [B,z_dim]."""
        h = F.relu(self.enc_fc1(x))
        h = F.relu(self.enc_fc2(h))
        h = F.relu(self.enc_fc3(h))
        h = F.relu(self.enc_fc4(h))
        mu     = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparam(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map z [B,z_dim] to a reconstruction [B,input_dim] (linear output layer)."""
        h = F.relu(self.dec_fc1(z))
        h = F.relu(self.dec_fc2(h))
        h = F.relu(self.dec_fc3(h))
        h = F.relu(self.dec_fc4(h))
        x_rec = self.dec_fc5(h)
        return x_rec             # [B,input_dim]

    def forward(self, x):
        """Return (x_rec, mu, logvar); z is sampled in both train and eval mode."""
        mu, lv = self.encode(x)
        z = self.reparam(mu, lv)
        x_rec = self.decode(z)
        return x_rec, mu, lv
        

## Train & Eval

In [220]:

def vae_loss(x_rec, x, mu, logvar, kld_weight=1e-3):
    """
    Reconstruction error (element-wise mean squared error) plus the weighted
    KL divergence term -0.5 * mean(1 + logvar - mu^2 - exp(logvar)).
    """
    kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
    return F.mse_loss(x_rec, x) + kld_weight * kld

def train_EISVAEDNN(model, train_list, val_list, num_epochs=20, batch_size=64, lr=1e-3):
    """Train the dense VAE with Adam and evaluate it after every epoch.

    Parameters:
        model:       VAE whose forward returns (x_rec, mu, logvar)
        train_list:  nested data list wrapped by EISDataset_DNN for training
        val_list:    nested data list wrapped by EISDataset_DNN for validation
        num_epochs:  number of passes over the training set
        batch_size:  mini-batch size for both loaders
        lr:          Adam learning rate

    Returns:
        (model, train_loss_recorder, eval_loss_recorder) where the recorders
        hold the per-sample average loss of each epoch.
    """
    train_ds = EISDataset_DNN(train_list)
    val_ds = EISDataset_DNN(val_list)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_loss_recorder = []
    eval_loss_recorder = []

    def _batch_loss(batch):
        # Move the batch to the device, run the model, return (loss, batch size).
        batch = batch.to(device)
        recon, mean, log_var = model(batch)
        return vae_loss(recon, batch, mean, log_var), batch.size(0)

    for epoch in range(1, num_epochs + 1):
        model.train()
        train_total = 0
        for batch in train_loader:
            loss, n_rows = _batch_loss(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_total += loss.item() * n_rows

        # Validation
        model.eval()
        val_total = 0
        with torch.no_grad():
            for batch in val_loader:
                loss, n_rows = _batch_loss(batch)
                val_total += loss.item() * n_rows

        train_avg = train_total / len(train_ds)
        val_avg = val_total / len(val_ds)
        print(f"Epoch {epoch}/{num_epochs}  Train: {train_avg:.4f}  Val: {val_avg:.4f}")

        train_loss_recorder.append(train_avg)
        eval_loss_recorder.append(val_avg)

    return model, train_loss_recorder, eval_loss_recorder


# ===== Reconstruction visualisation =====
def visualize_EISVAEDNN(model, data_list, num=5):
    """Plot original vs. reconstructed vectors for up to `num` random samples.

    Parameters:
        model:     VAE whose forward returns (x_rec, mu, logvar)
        data_list: nested data list wrapped by EISDataset_DNN
        num:       maximum number of samples (one figure each) to draw

    Each flattened sample is split in the middle: the first half is drawn as
    "Real", the second half as "Imag" (the order EISDataset_DNN concatenates them).
    """
    ds = EISDataset_DNN(data_list)
    loader = DataLoader(ds, batch_size=num, shuffle=True)
    x = next(iter(loader)).to(device)   # [<=num, 2*l], e.g. [num,202]
    model.eval()
    with torch.no_grad():
        x_rec, mu, lv = model(x)

    x = x.cpu().numpy()
    x_rec = x_rec.cpu().numpy()

    # Split point between the two concatenated halves (101 for 202-long vectors).
    half = x.shape[1] // 2

    # The batch holds fewer than `num` rows when the dataset is smaller than
    # `num`; iterate over what was actually returned to avoid an IndexError.
    for i in range(x.shape[0]):
        plt.figure(figsize=(6,3))
        # Real part
        plt.subplot(1,2,1)
        plt.plot(x[i,:half], label="orig")
        plt.plot(x_rec[i,:half], '--', label="rec")
        plt.title(f"Sample {i} Real")
        plt.legend()
        # Imaginary part
        plt.subplot(1,2,2)
        plt.plot(x[i,half:], label="orig")
        plt.plot(x_rec[i,half:], '--', label="rec")
        plt.title(f"Sample {i} Imag")
        plt.legend()
        plt.tight_layout()
        plt.show()




## Running

In [221]:
# Build the dense VAE with its default configuration on the selected device
# and report the number of trainable parameters.
vae_dnn = EISVAE_DNN().to(device)
print(count_parameters(vae_dnn))

553422


In [222]:
# Training configuration for the dense VAE.
num_epochs = 50
batch_size=128
lr=1e-3
# NOTE(review): random_state=None gives a different split on every run;
# set an int here for a reproducible split.
random_state = None

# The split acts on the outer list, so whole electrodes (all their channels)
# go to either the training or the validation side.
train_list, val_list = train_test_split(all_data_list, test_size=0.2, random_state=random_state)
vae_dnn, train_loss, eval_loss = train_EISVAEDNN(vae_dnn, train_list, val_list, num_epochs=num_epochs, batch_size=batch_size, lr=lr)

Epoch 1/50  Train: 0.5635  Val: 0.0800
Epoch 2/50  Train: 0.0654  Val: 0.0504
Epoch 3/50  Train: 0.0431  Val: 0.0436
Epoch 4/50  Train: 0.0363  Val: 0.0290
Epoch 5/50  Train: 0.0322  Val: 0.0465
Epoch 6/50  Train: 0.0284  Val: 0.0277
Epoch 7/50  Train: 0.0270  Val: 0.0237
Epoch 8/50  Train: 0.0251  Val: 0.0235
Epoch 9/50  Train: 0.0236  Val: 0.0222
Epoch 10/50  Train: 0.0239  Val: 0.0209
Epoch 11/50  Train: 0.0311  Val: 0.0198
Epoch 12/50  Train: 0.0222  Val: 0.0220
Epoch 13/50  Train: 0.0222  Val: 0.0219
Epoch 14/50  Train: 0.0324  Val: 0.0233
Epoch 15/50  Train: 0.0239  Val: 0.0218
Epoch 16/50  Train: 0.0230  Val: 0.0226
Epoch 17/50  Train: 0.0218  Val: 0.0216
Epoch 18/50  Train: 0.0214  Val: 0.0233
Epoch 19/50  Train: 0.0220  Val: 0.0265
Epoch 20/50  Train: 0.0209  Val: 0.0252
Epoch 21/50  Train: 0.0259  Val: 0.0271
Epoch 22/50  Train: 0.0209  Val: 0.0212
Epoch 23/50  Train: 0.0211  Val: 0.0208
Epoch 24/50  Train: 0.0209  Val: 0.0358
Epoch 25/50  Train: 0.0214  Val: 0.0211
Epoch 26/

### Plot

In [223]:
plt.figure()
# Per-epoch average losses from the last training run, on a log-scaled y axis.
plt.semilogy(train_loss, label="train")
plt.semilogy(eval_loss, label="eval")
plt.legend()
plt.show()


In [233]:
# Compare originals and reconstructions on randomly drawn validation samples.
visualize_EISVAEDNN(vae_dnn, val_list)

## Analysis

In [231]:
def VAE_latent(model, dataset, batch_size=64):
    """Encode every sample of `dataset` and return a 2-D view of the latent means.

    Parameters:
        model:      VAE exposing `encode(x) -> (mu, logvar)`
        dataset:    nested data list wrapped by EISDataset_DNN
        batch_size: number of samples encoded per forward pass

    Returns:
        Array with one row per sample, in dataset order. When z_dim > 2 the
        means are projected onto their first two principal components;
        otherwise the raw means are returned unchanged.
    """
    # x: [B,202] (flattened vectors from EISDataset_DNN)
    ds = EISDataset_DNN(dataset)
    # shuffle=False keeps row i of the result aligned with sample i of `ds`
    # (and with ds.id), which a shuffled loader would silently destroy.
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    _len_data = len(ds)
    _poi = 0

    latent_space_inst = []

    model.eval()
    with torch.no_grad():
        for x in loader:
            x = x.to(device)
            mu, _ = model.encode(x)
            latent_space_inst.append(mu.cpu().numpy())

            # Progress is logged only when the running count is an exact
            # multiple of 1000 (every 8000 samples with batch_size=64).
            _poi = _poi + x.size(0)
            if _poi % 1000 == 0:
                logger.info(f"[{_poi}]/[{_len_data}]")

    latent_space_inst = np.concatenate(latent_space_inst, axis=0)  # [N,z_dim]

    if latent_space_inst.shape[1] > 2:
        _pca_inst = PCA(n_components=2)
        latent_dd = _pca_inst.fit_transform(latent_space_inst)
    else:
        latent_dd = latent_space_inst

    return latent_dd


In [234]:
# 2-D latent representation of every loaded sample (training and validation alike).
latent_expr = VAE_latent(vae_dnn, all_data_list)

[32m2025-04-15 17:33:34.252[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[8000]/[333535][0m
[32m2025-04-15 17:33:34.316[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[16000]/[333535][0m
[32m2025-04-15 17:33:34.388[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[24000]/[333535][0m
[32m2025-04-15 17:33:34.447[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[32000]/[333535][0m
[32m2025-04-15 17:33:34.504[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[40000]/[333535][0m
[32m2025-04-15 17:33:34.555[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[48000]/[333535][0m
[32m2025-04-15 17:33:34.612[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[56000]/[333535][0m
[32m2025-04-15 17:33:34.662[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_lat

In [237]:

plt.figure(figsize=(9, 9))
# The second coordinate is negated, which flips the plot vertically.
# Tiny markers (s=0.001) keep the dense cloud readable.
plt.scatter(latent_expr[:, 0], -latent_expr[:, 1], alpha=0.5, s = 0.001)

# Equal-aspect scaling is left disabled for this plot.
# plt.gca().set_aspect('equal', adjustable='box')
plt.title("Latent Space")
plt.xlabel("Latent Dimension 1")
plt.ylabel("Latent Dimension 2")
# plt.grid()
plt.show()

# EISVAE_CNN_Ver02

## Model

In [None]:
class EISDataset_CNN(Dataset):
    """Dataset yielding one channel-first spectrum per index for the Conv1d models.

    Per the original notes, `data_list` is an n x m x k x l x 2 nested list:
        n: number of electrodes
        m: number of channels
        k: number of timestamps
        l: number of frequencies (used as the signal length)
        2: real and imaginary parts after logarithm
    """
    def __init__(self, data_list, id_list = None):
        _data, _id  = load_all2ch(data_list, id_list)

        # Keep one stacked float32 tensor rather than a Python list holding one
        # small tensor per row: same values, far fewer objects to create.
        self.data = torch.tensor(_data, dtype=torch.float32)
        # Stacked ids, or an empty list when no id_list was supplied.
        self.id = _id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Swap the two axes of the stored sample so channels come first,
        # e.g. [101,2] -> [2,101] = [in_ch, in_dim] for Conv1d.
        return self.data[idx].permute(1,0)
   
class Curve2VecEncoder(nn.Module):
    """Convolutional encoder mapping a curve [B,in_ch,in_dim] to (mu, logvar).

    Parameters:
        in_ch:       number of input channels
        in_dim:      signal length (only needed when `hid_ch` is empty)
        hid_ch:      sequence of channel widths, one Conv1d+BatchNorm1d+ReLU
                     stage per entry; may be empty
        z_dim:       latent dimension
        kernel_size: Conv1d kernel size (padding = kernel_size // 2)

    With a non-empty `hid_ch` the conv output is average-pooled over the
    signal axis; with an empty `hid_ch` the raw input is flattened instead.
    """
    def __init__(self, in_ch, in_dim, hid_ch, 
                 z_dim, kernel_size):
        super().__init__()

        _layers = []
        poi_ch = in_ch
        for _ch in hid_ch:
            _layers.append(nn.Conv1d(poi_ch, _ch, kernel_size=kernel_size, padding=kernel_size//2))
            _layers.append(nn.BatchNorm1d(_ch))
            _layers.append(nn.ReLU())
            poi_ch = _ch

        if len(hid_ch) > 0:
            self.conv = nn.Sequential(*_layers)
            self.pool = nn.AdaptiveAvgPool1d(1)
            feat_dim = poi_ch               # pooled: one value per channel
        else:
            self.conv = nn.Identity()
            self.pool = nn.Flatten(start_dim=1)
            # Flatten emits in_ch*in_dim features, so the heads must be sized
            # to that (sizing them to in_ch caused a shape error).
            feat_dim = in_ch * in_dim

        self.fc_mu = nn.Linear(feat_dim, z_dim)
        self.fc_lv = nn.Linear(feat_dim, z_dim)

    def forward(self, x):
        h = self.conv(x)                            # [B,ch,in_dim]
        # flatten() turns the pooled [B,ch,1] into [B,ch] and leaves the
        # already flat [B,in_ch*in_dim] of the no-conv branch untouched.
        h = self.pool(h).flatten(start_dim=1)
        return self.fc_mu(h), self.fc_lv(h) 


class Curve2VecDecoder(nn.Module):
    """Decoder mapping a latent vector [B,z_dim] back to a curve [B,out_ch,out_dim].

    Parameters:
        out_ch:      number of output channels
        out_dim:     output signal length
        hid_ch:      sequence of hidden channel widths; the first entry is the
                     width produced by the linear expansion. `out_ch` is
                     appended internally as the final width.
        z_dim:       latent dimension
        kernel_size: ConvTranspose1d kernel size (padding = kernel_size // 2)
    """
    def __init__(self, out_ch, out_dim, hid_ch, 
                 z_dim, kernel_size):
        super().__init__()
        # Work on a copy: appending to the caller's list would grow it on every
        # construction (and corrupt a shared default list).
        self.hid_ch = list(hid_ch) + [out_ch]
        self.out_dim = out_dim

        poi_ch = self.hid_ch[0]

        self.fc_expand = nn.Linear(z_dim, poi_ch * out_dim)

        # One transposed convolution per width transition, without
        # normalisation or activation in between.
        _layers = []
        for _ch in self.hid_ch[1:]:
            _layers.append(nn.ConvTranspose1d(poi_ch, _ch, kernel_size=kernel_size, padding=kernel_size//2))
            poi_ch = _ch
        
        self.deconv = nn.Sequential(*_layers)


    def forward(self, z):
        h = self.fc_expand(z)                           # [B,ch*out_dim]
        h = h.view(-1, self.hid_ch[0], self.out_dim)    # [B,ch,out_dim]
        h = self.deconv(h)                              # [B,out_ch,out_dim]
        return h

class Curve2VecVAE(nn.Module):
    """VAE composed of a Curve2VecEncoder and a Curve2VecDecoder.

    Parameters:
        in_ch:       number of input / output channels
        in_dim:      signal length
        enc_hid_ch:  encoder channel widths
        dec_hid_ch:  decoder channel widths (before the final output width)
        z_dim:       latent dimension
        kernel_size: convolution kernel size shared by both halves
    """
    def __init__(self, in_ch=2, in_dim=101, 
                 enc_hid_ch = (16,32),
                 dec_hid_ch = (16,),
                 z_dim = 16, kernel_size = 5):
        super().__init__()
        # Defaults are immutable tuples and each half receives its own list copy,
        # so no state is shared between instances or with the caller's sequences.
        self.encoder = Curve2VecEncoder(in_ch, in_dim, list(enc_hid_ch), z_dim, kernel_size)
        self.decoder = Curve2VecDecoder(in_ch, in_dim, list(dec_hid_ch), z_dim, kernel_size)

    def reparam(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Return (x_rec, mu, logvar); z is sampled in both train and eval mode."""
        mu, lv = self.encoder(x)
        z = self.reparam(mu, lv)
        x_rec = self.decoder(z)
        return x_rec, mu, lv 



## Train & Eval

In [309]:
def vae_loss(x_rec, x, mu, logvar, kld_weight=1e-3):
    """Return MSE(x_rec, x) + kld_weight * KL, with KL averaged over all latent elements."""
    mse = F.mse_loss(x_rec, x)
    inner = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * inner.mean()
    total = mse + kld_weight * divergence
    return total

# ===== Training function =====
def train_EISVAECNN(model, train_list, val_list, num_epochs=20, batch_size=64, lr=1e-3):
    """Fit `model` with Adam and report train / validation loss after each epoch.

    Parameters:
        model:       VAE whose forward returns (x_rec, mu, logvar)
        train_list:  nested data list wrapped by EISDataset_CNN for training
        val_list:    nested data list wrapped by EISDataset_CNN for validation
        num_epochs:  number of passes over the training set
        batch_size:  mini-batch size for both loaders
        lr:          Adam learning rate

    Returns:
        (model, train_loss_recorder, eval_loss_recorder) where the recorders
        hold the per-sample average loss of each epoch.
    """
    train_ds = EISDataset_CNN(train_list)
    val_ds = EISDataset_CNN(val_list)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_loss_recorder = []
    eval_loss_recorder = []

    def _forward_loss(samples):
        # Move the batch to the device, run the VAE, return (loss, batch size).
        samples = samples.to(device)
        recon, mean, log_var = model(samples)
        return vae_loss(recon, samples, mean, log_var), samples.size(0)

    for epoch in range(1, num_epochs + 1):
        model.train()
        accumulated_train = 0
        for samples in train_loader:
            loss, n_rows = _forward_loss(samples)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            accumulated_train += loss.item() * n_rows

        # Validation (forward still samples z, so this value is stochastic)
        model.eval()
        accumulated_val = 0
        with torch.no_grad():
            for samples in val_loader:
                loss, n_rows = _forward_loss(samples)
                accumulated_val += loss.item() * n_rows

        mean_train = accumulated_train / len(train_ds)
        mean_val = accumulated_val / len(val_ds)

        train_loss_recorder.append(mean_train)
        eval_loss_recorder.append(mean_val)
        print(f"Epoch {epoch}/{num_epochs}  Train: {mean_train:.6f}  Val: {mean_val:.6f}")

    return model, train_loss_recorder, eval_loss_recorder


# ===== Reconstruction visualisation =====
def visualize_EISVAECNN(model, data_list, num=5):
    """Plot original vs. reconstructed curves for up to `num` random samples.

    Parameters:
        model:     VAE whose forward returns (x_rec, mu, logvar)
        data_list: nested data list wrapped by EISDataset_CNN
        num:       maximum number of samples (one figure each) to draw

    One figure is opened per sample: channel 0 ("Real") on the left and
    channel 1 ("Imag") on the right.
    """
    ds = EISDataset_CNN(data_list)
    loader = DataLoader(ds, batch_size=num, shuffle=True)
    x = next(iter(loader)).to(device)   # [<=num, 2, L]
    model.eval()
    with torch.no_grad():
        x_rec, mu, lv = model(x)

    x = x.cpu().numpy()
    x_rec = x_rec.cpu().numpy()

    # The batch holds fewer than `num` rows when the dataset is smaller than
    # `num`; iterate over what was actually returned to avoid an IndexError.
    for i in range(x.shape[0]):
        plt.figure(figsize=(6,3))
        # Real part
        plt.subplot(1,2,1)
        plt.plot(x[i,0], label="orig")
        plt.plot(x_rec[i,0], '--', label="rec")
        plt.title(f"Sample {i} Real")
        plt.legend()
        # Imaginary part
        plt.subplot(1,2,2)
        plt.plot(x[i,1], label="orig")
        plt.plot(x_rec[i,1], '--', label="rec")
        plt.title(f"Sample {i} Imag")
        plt.legend()
        plt.tight_layout()
        plt.show()



## Running

In [310]:
# Build the Curve2Vec VAE with its default configuration on the selected device
# and report the number of trainable parameters.
# NOTE(review): the model cell above carries no execution count, so the recorded
# count below may come from an earlier version of the class; re-run to confirm.
vae_cnn = Curve2VecVAE().to(device)
print(count_parameters(vae_cnn))

114274


In [311]:
# Training configuration for the Curve2Vec VAE.
num_epochs = 20
batch_size=128
lr=1e-3
# NOTE(review): random_state=None gives a different split on every run;
# set an int here for a reproducible split.
random_state = None

# The split acts on the outer list, so whole electrodes (all their channels)
# go to either the training or the validation side.
train_list, val_list = train_test_split(all_data_list, test_size=0.2, random_state=random_state)
# Number of electrodes held out for validation.
len(val_list)

25

In [312]:
# Train the Curve2Vec VAE; the recorders hold one average loss value per epoch.
vae_cnn, train_loss, eval_loss = train_EISVAECNN(vae_cnn, train_list, val_list, num_epochs=num_epochs, batch_size=batch_size, lr=lr)

Epoch 1/20  Train: 0.485089  Val: 0.048830
Epoch 2/20  Train: 0.042080  Val: 0.039679
Epoch 3/20  Train: 0.035384  Val: 0.044900
Epoch 4/20  Train: 0.029256  Val: 0.028124
Epoch 5/20  Train: 0.025810  Val: 0.026033
Epoch 6/20  Train: 0.023547  Val: 0.070596
Epoch 7/20  Train: 0.021427  Val: 0.023325
Epoch 8/20  Train: 0.019186  Val: 0.022017
Epoch 9/20  Train: 0.017826  Val: 0.026279
Epoch 10/20  Train: 0.016813  Val: 0.026739
Epoch 11/20  Train: 0.016373  Val: 0.022048
Epoch 12/20  Train: 0.015846  Val: 0.020739
Epoch 13/20  Train: 0.015374  Val: 0.018574
Epoch 14/20  Train: 0.014692  Val: 0.022532
Epoch 15/20  Train: 0.014402  Val: 0.020956
Epoch 16/20  Train: 0.013980  Val: 0.026401
Epoch 17/20  Train: 0.013617  Val: 0.015903
Epoch 18/20  Train: 0.013200  Val: 0.014884
Epoch 19/20  Train: 0.012993  Val: 0.013660
Epoch 20/20  Train: 0.012549  Val: 0.018641


### Plot Loss

In [313]:
plt.figure()
# Per-epoch average losses from the last training run, on a log-scaled y axis.
plt.semilogy(train_loss, label="train")
plt.semilogy(eval_loss, label="eval")
plt.legend()
plt.show()


### Plot Eval

In [314]:
# Compare originals and reconstructions on randomly drawn validation samples.
visualize_EISVAECNN(vae_cnn, val_list)

## Analysis

In [315]:
def VAE_latent(model, dataset, batch_size=64, log_every=8000):
    '''==================================================
        Encode a dataset with the VAE encoder and return a
        2-D view of the latent means.
        Parameter:
            model: VAE whose model.encoder(x) returns (mu, logvar)
            dataset: raw data accepted by EISDataset_CNN
            batch_size: number of samples encoded per forward pass
            log_every: log progress each time the running sample
                       count passes another multiple of this value
                       (0 or None disables progress logging)
        Return:
            latent_dd: numpy array with one row per sample, in
                       dataset order. When the latent dimension is
                       larger than 2 the rows are the first two PCA
                       components of the latent means, otherwise
                       the latent means themselves.
        Raise:
            ValueError: if the dataset yields no samples
        Note:
            The model is left in eval mode.
        ==================================================
    '''
    ds = EISDataset_CNN(dataset)
    # Inference only: keep the loader sequential so that row i of the
    # result belongs to ds[i]. Shuffling here would make the output
    # impossible to match back to the input samples.
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    n_total = len(ds)
    n_done = 0
    next_log = log_every if log_every else None

    latent_chunks = []

    model.eval()
    with torch.no_grad():
        for x in loader:
            # x: [B,2,101] according to the original author's note
            x = x.to(device)
            # Only the mean is used as the embedding; the log-variance
            # returned by the encoder is ignored.
            mu, _ = model.encoder(x)
            latent_chunks.append(mu.cpu().numpy())

            n_done += x.size(0)
            # Threshold-based progress report: fires whenever the count
            # passes the next multiple of log_every, regardless of whether
            # batch_size divides it evenly.
            if next_log is not None and n_done >= next_log:
                logger.info(f"[{n_done}]/[{n_total}]")
                next_log = (n_done // log_every + 1) * log_every

    if not latent_chunks:
        raise ValueError("VAE_latent: dataset yielded no samples")

    latent_space_inst = np.concatenate(latent_chunks, axis=0)  # [N,z_dim]

    # Reduce to two dimensions for plotting when the latent space is wider.
    if latent_space_inst.shape[1] > 2:
        pca = PCA(n_components=2)
        latent_dd = pca.fit_transform(latent_space_inst)
    else:
        latent_dd = latent_space_inst

    return latent_dd


In [316]:
# Encode every entry of all_data_list (train and validation together) and keep
# the 2-D latent representation returned by VAE_latent for plotting.
latent_expr = VAE_latent(vae_cnn, all_data_list)

[32m2025-04-17 10:40:04.765[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[8000]/[333535][0m
[32m2025-04-17 10:40:04.870[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[16000]/[333535][0m
[32m2025-04-17 10:40:04.954[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[24000]/[333535][0m
[32m2025-04-17 10:40:05.036[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[32000]/[333535][0m
[32m2025-04-17 10:40:05.113[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[40000]/[333535][0m
[32m2025-04-17 10:40:05.184[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[48000]/[333535][0m
[32m2025-04-17 10:40:05.253[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m20[0m - [1m[56000]/[333535][0m
[32m2025-04-17 10:40:05.328[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_lat

In [317]:

# Scatter plot of the 2-D latent representation, one very small,
# semi-transparent marker per sample so that dense regions show up as clouds.
# NOTE(review): the second coordinate is plotted negated, and when VAE_latent
# applied PCA the two axes are principal components rather than raw latent
# dimensions - confirm that the axis labels are still the intended ones.
fig, ax = plt.subplots(figsize=(9, 9))
ax.scatter(latent_expr[:, 0], -latent_expr[:, 1], alpha=0.5, s=0.001)

# Same scale on both axes so distances are not visually distorted.
ax.set_aspect('equal', adjustable='box')
ax.set_title("Latent Space")
ax.set_xlabel("Latent Dimension 1")
ax.set_ylabel("Latent Dimension 2")
plt.show()

# Reference

## CNN + DNN

In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class Encoder(nn.Module):
#     def __init__(self, in_ch, in_dim, conv_channels, dense_dims, z_dim, kernel_size=3):
#         super().__init__()
#         layers = []
#         prev_ch = in_ch
#         for ch in conv_channels:
#             layers.append(nn.Conv1d(prev_ch, ch, kernel_size=kernel_size, padding=kernel_size//2))
#             layers.append(nn.BatchNorm1d(ch))
#             layers.append(nn.ReLU())
#             prev_ch = ch
#         self.conv = nn.Sequential(*layers)
#         self.pool = nn.AdaptiveAvgPool1d(1)

#         prev_dim = conv_channels[-1]
#         dense_layers = []
#         for dim in dense_dims:
#             dense_layers.append(nn.Linear(prev_dim, dim))
#             dense_layers.append(nn.ReLU())
#             prev_dim = dim
#         self.dense = nn.Sequential(*dense_layers)

#         self.fc_mu = nn.Linear(prev_dim, z_dim)
#         self.fc_lv = nn.Linear(prev_dim, z_dim)

#     def forward(self, x):
#         h = self.conv(x)              # [B, ch, in_dim]
#         h = self.pool(h).squeeze(-1)  # [B, ch]
#         h = self.dense(h)             # [B, dense_dim]
#         return self.fc_mu(h), self.fc_lv(h)


# class Decoder(nn.Module):
#     def __init__(self, z_dim, out_ch, out_dim, dense_dims, conv_channels, kernel_size=3):
#         super().__init__()
#         prev_dim = z_dim
#         dense_layers = []
#         for dim in dense_dims:
#             dense_layers.append(nn.Linear(prev_dim, dim))
#             dense_layers.append(nn.ReLU())
#             prev_dim = dim
#         self.dense = nn.Sequential(*dense_layers)

#         self.fc_expand = nn.Linear(prev_dim, conv_channels[0] * out_dim)

#         layers = []
#         prev_ch = conv_channels[0]
#         for ch in conv_channels[1:]:
#             layers.append(nn.Conv1d(prev_ch, ch, kernel_size=kernel_size, padding=kernel_size//2))
#             layers.append(nn.BatchNorm1d(ch))
#             layers.append(nn.ReLU())
#             prev_ch = ch
#         layers.append(nn.Conv1d(prev_ch, out_ch, kernel_size=kernel_size, padding=kernel_size//2))
#         self.conv = nn.Sequential(*layers)

#         self.out_dim = out_dim
#         self.conv_channels = conv_channels

#     def forward(self, z):
#         h = self.dense(z)
#         h = self.fc_expand(h)                            # [B, ch * dim]
#         h = h.view(-1, self.conv_channels[0], self.out_dim)
#         return self.conv(h)


# class EISVAE(nn.Module):
#     def __init__(self, in_ch=2, in_dim=101,
#                  enc_hid_ch=[16, 32], enc_hid_dim=[64], 
#                  dec_hid_dim=[64], dec_hid_ch=[32, 16],
#                  z_dim=32, kernel_size=3):
#         super().__init__()
#         self.encoder = Encoder(in_ch, in_dim, enc_hid_ch, enc_hid_dim, z_dim, kernel_size)
#         self.decoder = Decoder(z_dim, in_ch, in_dim, dec_hid_dim, dec_hid_ch, kernel_size)

#     def reparam(self, mu, logvar):
#         std = torch.exp(0.5 * logvar)
#         eps = torch.randn_like(std)
#         return mu + eps * std

#     def forward(self, x):
#         mu, lv = self.encoder(x)
#         z = self.reparam(mu, lv)
#         x_rec = self.decoder(z)
#         return x_rec, mu, lv
