# Import

In [1]:
import os
import re
import gc
import sys

from loguru import logger
import numpy as np
import random

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# %matplotlib qt
%matplotlib qt

# Pick the compute device once for the whole notebook: CUDA when PyTorch
# can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Input Layer

## Definition

In [12]:
# When loading raw data, keep only the sweeps listed in each channel's "eis_seq".
ONLY_SEQ_FLAG = True
# False -> load the pre-packed SEQData.npz instead of re-reading the raw .pt files.
READ_RAW_FLAG = False
# 101 integer indices spread evenly over [0, 4999]; used to subsample
# spectra that have 5000 frequency points down to 101.
freq_list = np.linspace(start=0, stop=4999, num=101, endpoint=True, dtype=int)

In [13]:
def SearchELE(rootPath, ele_pattern = re.compile(r"(.+?)_归档")):
    '''==================================================
        Recursively search rootPath for electrode directories.

        A sub-directory whose name matches ``ele_pattern`` (``re.match``,
        i.e. anchored at the start of the name) is recorded and NOT descended
        into. Any other sub-directory is searched recursively. Plain files
        are ignored.

        Parameter:
            rootPath: directory to start searching from
            ele_pattern: compiled regex for electrode directory names; its
                first capture group is used as the electrode name
        Return:
            ele_list: list of [dir_path, electrode_name] pairs. Entries are
                visited in sorted name order at every level, so the result
                order (and therefore the electrode index derived from it)
                is reproducible across machines.
        ==================================================
    '''
    ele_list = []
    # os.listdir returns entries in arbitrary, filesystem-dependent order;
    # sort them so the electrode ordering is deterministic.
    for name in sorted(os.listdir(rootPath)):
        _path = os.path.join(rootPath, name)
        if not os.path.isdir(_path):
            continue
        match_ele = ele_pattern.match(name)
        if match_ele:
            ele_list.append([_path, match_ele.group(1)])
        else:
            ele_list.extend(SearchELE(_path, ele_pattern))

    return ele_list


## Archive Old

In [14]:
# Electrode IDs presumably to be excluded, each with the defect observed in its data.
# NOTE(review): neither Blacklist nor GrayList is referenced in the visible part
# of this notebook — confirm where (or whether) these IDs are filtered out.
Blacklist = [
    '01067093',     # Does not look like an EIS spectrum
    '01067094',     # Connection Error
    '02017385',     # Connection Error
    '05127177',     # Open to Short
    '06047729',     # Open to Short
    '06047730',     # Open to Short
    '06047731',     # Open to Short
    '09207024',     # Connection Error
    '10017038',     # Connection Error
    '10037050',     # Connection Error
    '10047056',     # Connection Error
    '10057069',     # Connection Error
    '10057083',     # Always Open
    '10057084',     # Chaos
    '10057087',     # Connection Error
    '22017367',     # Connection Error
    '22017371',     # Chaos
]

# Electrode IDs whose data looks suspicious; presumably reviewed case by case
# rather than dropped outright — TODO confirm intended use.
GrayList = [
    '10037051',     # Connection Error
    '10037052',     # Connection Error
    '10057071',     # Connection Error
    '10067077',     # Weird shape, resembles a connection error
    '10150201',     # Weird shape
    '10150202',     # Weird shape
    '10150203',     # Weird shape
    '20037515',     # Weird shape
    '20037516',     # Weird shape
    '20037517',     # Weird shape
    '22037378',     # Connection Error
    '22037380',     # Connection Error
    '22047376',     # Connection Error

]

In [15]:
if READ_RAW_FLAG:
    # Old in-vitro archive: electrode folders are named "<electrode>_归档" ("archived").
    rootPath = "D:/Baihm/EISNN/Archive/"
    ele_list = SearchELE(rootPath=rootPath, ele_pattern=re.compile(r"(.+?)_归档"))
    n_ele = len(ele_list)
    logger.info(f"Search in {rootPath} and find {n_ele:03d} electrodes")


In [16]:
if READ_RAW_FLAG:
    # Load every electrode's "Outlier_Ver03" .pt file (electrodes come from
    # ele_list / n_ele set by the preceding search cell) and flatten all of
    # its channels into:
    #   vitro0_data_list     - one row per sweep: [real part | imag part] of
    #                          log(chData[:,1] + 1j*chData[:,2]); presumably
    #                          log|Z| and phase if rows 1/2 hold Re(Z)/Im(Z)
    #   vitro0_id_list       - one row per sweep: [electrode index i, channel key j, eis_cluster]
    #   vitro0_start_list    - the first row of every channel
    #   vitro0_start_id_list - the id row matching each vitro0_start_list row
    DATASET_SUFFIX = "Outlier_Ver03"

    vitro0_start_list = []
    vitro0_start_id_list = []
    vitro0_data_list = []
    vitro0_id_list = []

    # Number of electrodes whose .pt file was found and loaded.
    n_avaliable = 0

    for i in range(n_ele):
    # for i in range(3):
        fd_pt = os.path.join(ele_list[i][0], DATASET_SUFFIX, f"{ele_list[i][1]}_{DATASET_SUFFIX}.pt")
        if not os.path.exists(fd_pt):
            logger.warning(f"{fd_pt} does not exist")
            continue
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load .pt files from a trusted source.
        data_pt = torch.load(fd_pt, weights_only=False)
        _meta_group = data_pt["meta_group"]
        _data_group = data_pt["data_group"]

        # Not used again within this cell.
        n_day       = _meta_group["n_day"]
        n_ch        = _meta_group["n_ch"]
        n_valid_ch  = len(_data_group["Channels"])


        logger.info(f"ELE [{i}/{n_ele}]: {ele_list[i][0]}")

        n_avaliable = n_avaliable + 1



        # Iteration by channel
        for j in _data_group['Channels']:
            _ch_data = _data_group[j]["chData"]

            if ONLY_SEQ_FLAG:
                # Keep only the sweeps selected by this channel's eis_seq index.
                eis_seq = _data_group[j]["eis_seq"]
                _ch_data = _ch_data[eis_seq,:,:]


            # Complex log of rows 1 (real) and 2 (imag), written back into the
            # same rows. When ONLY_SEQ_FLAG is False this modifies the loaded
            # chData array in place.
            _ch_data_log = np.log(_ch_data[:,1,:] + 1j*_ch_data[:,2,:])
            _ch_data[:,1,:] = np.real(_ch_data_log)
            _ch_data[:,2,:] = np.imag(_ch_data_log)
            if _ch_data.shape[2] == 5000:
                # Subsample 5000-point sweeps to the 101 indices in freq_list.
                _ch_data = np.hstack((_ch_data[:,1,freq_list],_ch_data[:,2,freq_list]))
            else:
                # Other lengths are kept as-is. NOTE(review): np.vstack below
                # fails unless every channel ends up with the same width.
                _ch_data = np.hstack((_ch_data[:,1,:],_ch_data[:,2,:]))
            vitro0_data_list.append(_ch_data)
            # NOTE(review): raises IndexError if a channel has zero kept sweeps.
            vitro0_start_list.append(_ch_data[0,:])


            _ch_id = j

            # List repetition [i, j, i, j, ...] -> one (i, j) row per sweep.
            _id = [i, _ch_id] * np.shape(_ch_data)[0]
            _id = np.array(_id).reshape(-1,2)
            _eis_cluster = _data_group[j]['eis_cluster']
            # NOTE(review): assumes eis_cluster has one entry per *kept* sweep
            # (already aligned with eis_seq); otherwise this hstack fails.
            _id = np.hstack((_id, _eis_cluster.reshape(-1,1)))
            
            vitro0_id_list.append(_id)
            vitro0_start_id_list.append(_id[0,:])



    # np.vstack raises ValueError if no electrode contributed any channel.
    vitro0_data_list = np.vstack(vitro0_data_list)
    vitro0_id_list = np.vstack(vitro0_id_list)
    vitro0_start_list = np.vstack(vitro0_start_list)
    vitro0_start_id_list = np.vstack(vitro0_start_id_list)

    # Indexed by the electrode index i stored in id column 0; includes
    # electrodes whose .pt file was missing, so indices stay aligned.
    vitro0_ele_list = [i[1] for i in ele_list]

    logger.info(f"Total {vitro0_data_list.shape[0]} data points from {n_avaliable} electrodes")

    # Drop references to the last loaded file to free memory.
    del data_pt, _meta_group, _data_group, _ch_data



## Archive New

In [17]:
if READ_RAW_FLAG:
    # New in-vitro archive: electrode folders are named "<electrode>_归档" ("archived").
    rootPath = "D:/Baihm/EISNN/Archive_New/"
    ele_list = SearchELE(rootPath=rootPath, ele_pattern=re.compile(r"(.+?)_归档"))
    n_ele = len(ele_list)
    logger.info(f"Search in {rootPath} and find {n_ele:03d} electrodes")


In [18]:
if READ_RAW_FLAG:
    # Load every electrode's "Outlier_Ver02" .pt file (electrodes come from
    # ele_list / n_ele set by the preceding search cell) and flatten all of
    # its channels into:
    #   vitro1_data_list     - one row per sweep: [real part | imag part] of
    #                          log(chData[:,1] + 1j*chData[:,2]); presumably
    #                          log|Z| and phase if rows 1/2 hold Re(Z)/Im(Z)
    #   vitro1_id_list       - one row per sweep: [electrode index i, channel key j, eis_cluster]
    #   vitro1_start_list    - the first row of every channel
    #   vitro1_start_id_list - the id row matching each vitro1_start_list row
    DATASET_SUFFIX = "Outlier_Ver02"

    vitro1_start_list = []
    vitro1_start_id_list = []
    vitro1_data_list = []
    vitro1_id_list = []

    # Number of electrodes whose .pt file was found and loaded.
    n_avaliable = 0

    for i in range(n_ele):
    # for i in range(3):
        fd_pt = os.path.join(ele_list[i][0], DATASET_SUFFIX, f"{ele_list[i][1]}_{DATASET_SUFFIX}.pt")
        if not os.path.exists(fd_pt):
            logger.warning(f"{fd_pt} does not exist")
            continue
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load .pt files from a trusted source.
        data_pt = torch.load(fd_pt, weights_only=False)
        _meta_group = data_pt["meta_group"]
        _data_group = data_pt["data_group"]

        # Not used again within this cell.
        n_day       = _meta_group["n_day"]
        n_ch        = _meta_group["n_ch"]
        n_valid_ch  = len(_data_group["Channels"])


        logger.info(f"ELE [{i}/{n_ele}]: {ele_list[i][0]}")

        n_avaliable = n_avaliable + 1

        # Iteration by channel
        for j in _data_group['Channels']:
            _ch_data = _data_group[j]["chData"]

            if ONLY_SEQ_FLAG:
                # Keep only the sweeps selected by this channel's eis_seq index.
                eis_seq = _data_group[j]["eis_seq"]
                _ch_data = _ch_data[eis_seq,:,:]

            # Complex log of rows 1 (real) and 2 (imag), written back into the
            # same rows. When ONLY_SEQ_FLAG is False this modifies the loaded
            # chData array in place.
            _ch_data_log = np.log(_ch_data[:,1,:] + 1j*_ch_data[:,2,:])
            _ch_data[:,1,:] = np.real(_ch_data_log)
            _ch_data[:,2,:] = np.imag(_ch_data_log)
            if _ch_data.shape[2] == 5000:
                # Subsample 5000-point sweeps to the 101 indices in freq_list.
                _ch_data = np.hstack((_ch_data[:,1,freq_list],_ch_data[:,2,freq_list]))
            else:
                # Other lengths are kept as-is. NOTE(review): np.vstack below
                # fails unless every channel ends up with the same width.
                _ch_data = np.hstack((_ch_data[:,1,:],_ch_data[:,2,:]))
            vitro1_data_list.append(_ch_data)
            # NOTE(review): raises IndexError if a channel has zero kept sweeps.
            vitro1_start_list.append(_ch_data[0,:])


            _ch_id = j

            # List repetition [i, j, i, j, ...] -> one (i, j) row per sweep.
            _id = [i, _ch_id] * np.shape(_ch_data)[0]
            _id = np.array(_id).reshape(-1,2)
            _eis_cluster = _data_group[j]['eis_cluster']
            # NOTE(review): assumes eis_cluster has one entry per *kept* sweep
            # (already aligned with eis_seq); otherwise this hstack fails.
            _id = np.hstack((_id, _eis_cluster.reshape(-1,1)))
            
            vitro1_id_list.append(_id)
            vitro1_start_id_list.append(_id[0,:])

    # np.vstack raises ValueError if no electrode contributed any channel.
    vitro1_data_list = np.vstack(vitro1_data_list)
    vitro1_id_list = np.vstack(vitro1_id_list)
    vitro1_start_list = np.vstack(vitro1_start_list)
    vitro1_start_id_list = np.vstack(vitro1_start_id_list)

    # Indexed by the electrode index i stored in id column 0; includes
    # electrodes whose .pt file was missing, so indices stay aligned.
    vitro1_ele_list = [i[1] for i in ele_list]

    logger.info(f"Total {vitro1_data_list.shape[0]} data points from {n_avaliable} electrodes")

    # Drop references to the last loaded file to free memory.
    del data_pt, _meta_group, _data_group, _ch_data



## In vivo

In [19]:
if READ_RAW_FLAG:
    # In-vivo recordings: electrode folders are named "<electrode>_Ver02".
    rootPath = "D:/Baihm/EISNN/Invivo/"
    ele_list = SearchELE(rootPath=rootPath, ele_pattern=re.compile(r"(.+?)_Ver02"))
    n_ele = len(ele_list)
    logger.info(f"Search in {rootPath} and find {n_ele:03d} electrodes")


In [20]:
if READ_RAW_FLAG:
    # Load every electrode's "Outlier_Ver04" .pt file (electrodes come from
    # ele_list / n_ele set by the preceding search cell) and flatten all of
    # its channels into:
    #   vivo0_data_list     - one row per sweep: [real part | imag part] of
    #                         log(chData[:,1] + 1j*chData[:,2]); presumably
    #                         log|Z| and phase if rows 1/2 hold Re(Z)/Im(Z)
    #   vivo0_id_list       - one row per sweep: [electrode index i, channel key j, eis_cluster]
    #   vivo0_start_list    - the first row of every channel
    #   vivo0_start_id_list - the id row matching each vivo0_start_list row
    DATASET_SUFFIX = "Outlier_Ver04"

    vivo0_start_list = []
    vivo0_start_id_list = []
    vivo0_data_list = []
    vivo0_id_list = []

    # Number of electrodes whose .pt file was found and loaded.
    n_avaliable = 0

    for i in range(n_ele):
    # for i in range(3):
        fd_pt = os.path.join(ele_list[i][0], DATASET_SUFFIX, f"{ele_list[i][1]}_{DATASET_SUFFIX}.pt")
        if not os.path.exists(fd_pt):
            logger.warning(f"{fd_pt} does not exist")
            continue
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load .pt files from a trusted source.
        data_pt = torch.load(fd_pt, weights_only=False)
        _meta_group = data_pt["meta_group"]
        _data_group = data_pt["data_group"]

        # Not used again within this cell.
        n_day       = _meta_group["n_day"]
        n_ch        = _meta_group["n_ch"]
        n_valid_ch  = len(_data_group["Channels"])


        logger.info(f"ELE [{i}/{n_ele}]: {ele_list[i][0]}")

        n_avaliable = n_avaliable + 1

        # Iteration by channel
        for j in _data_group['Channels']:
            _ch_data = _data_group[j]["chData"]

            if ONLY_SEQ_FLAG:
                # Keep only the sweeps selected by this channel's eis_seq index.
                eis_seq = _data_group[j]["eis_seq"]
                _ch_data = _ch_data[eis_seq,:,:]

            # Complex log of rows 1 (real) and 2 (imag), written back into the
            # same rows. When ONLY_SEQ_FLAG is False this modifies the loaded
            # chData array in place.
            _ch_data_log = np.log(_ch_data[:,1,:] + 1j*_ch_data[:,2,:])
            _ch_data[:,1,:] = np.real(_ch_data_log)
            _ch_data[:,2,:] = np.imag(_ch_data_log)
            if _ch_data.shape[2] == 5000:
                # Subsample 5000-point sweeps to the 101 indices in freq_list.
                _ch_data = np.hstack((_ch_data[:,1,freq_list],_ch_data[:,2,freq_list]))
            else:
                # Other lengths are kept as-is. NOTE(review): np.vstack below
                # fails unless every channel ends up with the same width.
                _ch_data = np.hstack((_ch_data[:,1,:],_ch_data[:,2,:]))
            vivo0_data_list.append(_ch_data)
            # NOTE(review): raises IndexError if a channel has zero kept sweeps.
            vivo0_start_list.append(_ch_data[0,:])


            _ch_id = j

            # List repetition [i, j, i, j, ...] -> one (i, j) row per sweep.
            _id = [i, _ch_id] * np.shape(_ch_data)[0]
            _id = np.array(_id).reshape(-1,2)
            _eis_cluster = _data_group[j]['eis_cluster']
            # NOTE(review): assumes eis_cluster has one entry per *kept* sweep
            # (already aligned with eis_seq); otherwise this hstack fails.
            _id = np.hstack((_id, _eis_cluster.reshape(-1,1)))
            
            vivo0_id_list.append(_id)
            vivo0_start_id_list.append(_id[0,:])

    # np.vstack raises ValueError if no electrode contributed any channel.
    vivo0_data_list = np.vstack(vivo0_data_list)
    vivo0_id_list = np.vstack(vivo0_id_list)
    vivo0_start_list = np.vstack(vivo0_start_list)
    vivo0_start_id_list = np.vstack(vivo0_start_id_list)

    # Indexed by the electrode index i stored in id column 0; includes
    # electrodes whose .pt file was missing, so indices stay aligned.
    vivo0_ele_list = [i[1] for i in ele_list]

    logger.info(f"Total {vivo0_data_list.shape[0]} data points from {n_avaliable} electrodes")

    # Drop references to the last loaded file to free memory.
    del data_pt, _meta_group, _data_group, _ch_data



## Data Summary

In [21]:
if not READ_RAW_FLAG:
    # Load the previously packed arrays instead of re-reading the raw .pt archives.
    # Data_Path = "D:/Baihm/EISNN/Feature/AllData.npz"
    Data_Path = "D:/Baihm/EISNN/Feature/SEQData.npz"
    if os.path.exists(Data_Path):
        # np.load on an .npz returns a lazy NpzFile that holds an open file
        # handle. Indexing it reads each array fully into memory, so the handle
        # can be closed once every array has been extracted. The `with` block
        # does that; without it the file stays open for the whole session.
        with np.load(Data_Path) as AllData:
            vitro0_data_list = AllData["vitro0_data_list"]
            vitro0_id_list = AllData["vitro0_id_list"]
            vitro0_start_list = AllData["vitro0_start_list"]
            vitro0_start_id_list = AllData["vitro0_start_id_list"]
            vitro0_ele_list = AllData["vitro0_ele_list"]

            vitro1_data_list = AllData["vitro1_data_list"]
            vitro1_id_list = AllData["vitro1_id_list"]
            vitro1_start_list = AllData["vitro1_start_list"]
            vitro1_start_id_list = AllData["vitro1_start_id_list"]
            vitro1_ele_list = AllData["vitro1_ele_list"]

            vivo0_data_list = AllData["vivo0_data_list"]
            vivo0_id_list = AllData["vivo0_id_list"]
            vivo0_start_list = AllData["vivo0_start_list"]
            vivo0_start_id_list = AllData["vivo0_start_id_list"]
            vivo0_ele_list = AllData["vivo0_ele_list"]

        logger.info(f"Vitro0:\t{vitro0_data_list.shape}\t{vitro0_start_list.shape}")
        logger.info(f"vitro1:\t{vitro1_data_list.shape}\t{vitro1_start_list.shape}")
        logger.info(f"Vivo0:\t{vivo0_data_list.shape}\t{vivo0_start_list.shape}")
        
    else:
        # NOTE(review): the next cell stacks these arrays unconditionally, so
        # a missing file only surfaces there as a NameError.
        logger.warning(f"{Data_Path} does not exist")

[32m2025-05-25 23:02:35.515[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m25[0m - [1mVitro0:	(98690, 202)	(12170, 202)[0m
[32m2025-05-25 23:02:35.515[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mvitro1:	(81674, 202)	(9708, 202)[0m
[32m2025-05-25 23:02:35.515[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m27[0m - [1mVivo0:	(9406, 202)	(719, 202)[0m


In [22]:
# Stack the three cohorts row-wise, always in the order
# old in-vitro archive -> new in-vitro archive -> in-vivo.
all_data_list     = np.vstack([vitro0_data_list,     vitro1_data_list,     vivo0_data_list])
all_id_list       = np.vstack([vitro0_id_list,       vitro1_id_list,       vivo0_id_list])
all_start_list    = np.vstack([vitro0_start_list,    vitro1_start_list,    vivo0_start_list])
all_start_id_list = np.vstack([vitro0_start_id_list, vitro1_start_id_list, vivo0_start_id_list])


In [23]:
# Report the (rows, features) shapes of each cohort and of the combined set.
for _msg in (
    f"Vitro0:\t{vitro0_data_list.shape}\t{vitro0_start_list.shape}",
    f"vitro1:\t{vitro1_data_list.shape}\t{vitro1_start_list.shape}",
    f"Vivo0:\t{vivo0_data_list.shape}\t{vivo0_start_list.shape}",
    f"All:\t\t{all_data_list.shape}\t{all_start_list.shape}",
):
    logger.info(_msg)

[32m2025-05-25 23:02:35.597[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mVitro0:	(98690, 202)	(12170, 202)[0m
[32m2025-05-25 23:02:35.597[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mvitro1:	(81674, 202)	(9708, 202)[0m
[32m2025-05-25 23:02:35.597[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mVivo0:	(9406, 202)	(719, 202)[0m
[32m2025-05-25 23:02:35.598[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mAll:		(189770, 202)	(22597, 202)[0m


# Helper

In [25]:
def count_parameters(model):
    """Return the total number of scalar parameters in *model* with requires_grad=True."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def load_all2ch(data_list, id_list = None):
    '''==================================================
        Split flattened spectra into a 2-channel layout.

        Each input row is [channel-0 values | channel-1 values] placed side
        by side. In this notebook the halves are the real and imaginary parts
        of the complex log, 101 values each, i.e. 202 columns. The two halves
        are stacked on a new trailing axis.

        Parameter:
            data_list: array-like, n x (2*L)   (L = 101 in this notebook)
            id_list: optional per-row ids; returned unchanged
        Return:
            ch_data_list: channel data array    n x L x 2 (dtype preserved)
            ch_id_list: id_list, unchanged
        Raises:
            ValueError: if data_list is not 2-D with an even number of columns
        ==================================================
    '''
    data = np.asarray(data_list)
    if data.ndim != 2 or data.shape[1] % 2:
        raise ValueError(
            f"expected a 2-D array with an even number of columns, got shape {data.shape}"
        )
    half = data.shape[1] // 2

    # out[i, k, c] == (first half if c == 0 else second half)[i, k]
    ch_data_list = np.stack((data[:, :half], data[:, half:]), axis=-1)

    ch_id_list = id_list

    return ch_data_list, ch_id_list

## Train & Eval Functions

In [26]:
def vae_loss(x_rec, x, mu, logvar, beta=1e-3):
    """Return the beta-weighted VAE objective.

    reconstruction = mean squared error between x_rec and x
    kl             = -0.5 * mean(1 + logvar - mu**2 - exp(logvar))
                     (element-wise mean over every entry of mu/logvar)
    result         = reconstruction + beta * kl
    """
    reconstruction = F.mse_loss(x_rec, x)
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * kl_terms.mean()
    return reconstruction + beta * kl

# ===== Training function =====
def train_model(model, train_ds, val_ds, num_epochs=20, batch_size=64, lr=1e-3):
    """Fit *model* with Adam on train_ds, evaluating on val_ds after every epoch.

    The model is called as ``x_rec, mu, logvar = model(x)`` and optimised with
    vae_loss at its default beta. For each epoch the train and validation
    losses are batch-size-weighted means over the whole dataset, and one
    progress line is printed.

    Returns:
        (model, train_loss_history, val_loss_history): the same model object
        plus one float per epoch in each list.
    """
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_history = []
    val_history = []

    for epoch in range(1, num_epochs + 1):
        # ---- training pass: shuffled batches with gradient updates ----
        model.train()
        train_sum = 0
        for batch in train_loader:
            batch = batch.to(device)
            recon, mu, logvar = model(batch)
            batch_loss = vae_loss(recon, batch, mu, logvar)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            train_sum += batch_loss.item() * batch.size(0)

        # ---- validation pass: eval mode, no gradients ----
        # The full forward pass is used, so any sampling the model performs in
        # forward (e.g. reparameterisation) also affects the validation loss.
        model.eval()
        val_sum = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                recon, mu, logvar = model(batch)
                val_sum += vae_loss(recon, batch, mu, logvar).item() * batch.size(0)

        train_loss = train_sum / len(train_ds)
        val_loss = val_sum / len(val_ds)

        train_history.append(train_loss)
        val_history.append(val_loss)
        print(f"Epoch {epoch}/{num_epochs}  Train: {train_loss:.6f}  Val: {val_loss:.6f}")

    return model, train_history, val_history


# ===== Visualize reconstructions =====
def visualize_EISVAECNN(model, ds, num=5):
    """Plot original vs. reconstructed curves for a random batch drawn from *ds*.

    One figure per sample, with two panels: channel 0 (titled "Real") and
    channel 1 (titled "Imag"). The original is drawn solid and semi-transparent,
    the reconstruction dashed.

    Args:
        model: VAE whose forward returns (x_rec, mu, logvar).
        ds: dataset whose items are [2, L] tensors (e.g. EISDataset_CNN).
        num: number of samples to plot; fewer are plotted if len(ds) < num.
    """
    loader = DataLoader(ds, batch_size=num, shuffle=True)
    x = next(iter(loader)).to(device)   # [B, 2, L] with B = min(num, len(ds))
    model.eval()
    with torch.no_grad():
        x_rec, _, _ = model(x)

    x = x.cpu().numpy()
    x_rec = x_rec.cpu().numpy()

    # Iterate over the batch actually returned: when the dataset has fewer
    # than `num` items, indexing up to `num` would raise IndexError.
    for i in range(x.shape[0]):
        plt.figure(figsize=(6,3))
        # Channel 0 (real part)
        plt.subplot(1,2,1)
        plt.plot(x[i,0], label="orig", alpha = 0.5)
        plt.plot(x_rec[i,0], '--', label="rec")
        plt.title(f"Sample {i} Real")
        plt.legend()
        # Channel 1 (imaginary part)
        plt.subplot(1,2,2)
        plt.plot(x[i,1], label="orig", alpha = 0.5)
        plt.plot(x_rec[i,1], '--', label="rec")
        plt.title(f"Sample {i} Imag")
        plt.legend()
        plt.tight_layout()
        plt.show()



## Dataloader

In [27]:
class EISDataset_CNN(Dataset):
    """Dataset serving EIS spectra as [2, L] float32 tensors for Conv1d models.

    data_list is the flattened n x 2L array used in this notebook (e.g.
    all_data_list, n x 202: real part then imaginary part of the complex log).
    load_all2ch reshapes it to n x L x 2. id_list is passed through to self.id.
    """
    def __init__(self, data_list, id_list = None):
        _data, _id = load_all2ch(data_list, id_list)
        # Convert once into a single [n, L, 2] float32 tensor. Building one
        # small tensor per row in a Python loop is slow and adds per-tensor
        # memory overhead for datasets of ~190k rows; rows index identically.
        self.data = torch.tensor(_data, dtype=torch.float32)
        self.id = _id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # [L, 2] -> [2, L], i.e. [in_ch, in_dim] as Conv1d expects
        return self.data[idx].permute(1,0)

        

## Plot Latent Space

In [28]:
def VAE_latent(model, ds, batch_size=64, log_every=1000):
    '''Encode every sample of *ds*, PCA-rotate the latent means and plot them.

    Parameter:
        model: VAE exposing ``model.encoder(x) -> (mu, logvar)``
        ds: dataset yielding model inputs (order is preserved)
        batch_size: encoding batch size
        log_every: log progress each time another ``log_every`` samples
            have been encoded (<= 0 disables progress logging)
    Return:
        latent_dd: [N, z_dim] encoder means projected onto their principal axes
        eff_dim: number of components whose cumulative explained variance
            stays below 99%, plus one
    '''
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    _len_data = len(ds)
    _poi = 0

    latent_space_inst = []

    model.eval()
    with torch.no_grad():
        for x in loader:
            x = x.to(device)
            mu, lv = model.encoder(x)   # use the posterior mean; no sampling
            latent_space_inst.append(mu.cpu().numpy())

            _prev = _poi
            _poi = _poi + x.size(0)
            # Log whenever a multiple of log_every is crossed. An exact
            # `_poi % 1000 == 0` test only fires when a batch boundary lands on
            # a multiple of 1000 (e.g. every 8000 samples for batch_size=64).
            if log_every > 0 and _poi // log_every > _prev // log_every:
                logger.info(f"[{_poi}]/[{_len_data}]")

    latent_space_inst = np.concatenate(latent_space_inst, axis=0)  # [N, z_dim]

    _pca_inst = PCA(n_components=latent_space_inst.shape[1])
    latent_dd = _pca_inst.fit_transform(latent_space_inst)

    explained = _pca_inst.explained_variance_ratio_
    eff_dim = (explained.cumsum() < 0.99).sum() + 1

    # Top: first two principal components (requires z_dim >= 2).
    # Bottom: explained-variance ratio per component.
    fig, axis = plt.subplots(2,1,
                gridspec_kw={'height_ratios': [4,1]},
                figsize=(9, 9))
    axis[0].scatter(latent_dd[:, 0], latent_dd[:, 1], alpha=0.5, s = 0.001)

    axis[0].set_aspect('equal', adjustable='box')
    axis[0].set_box_aspect(1)
    axis[0].set_title("Latent Space")

    axis[1].plot(explained,
                 label = f"Valid Dimension = {eff_dim}")
    axis[1].legend()
    fig.show()

    return latent_dd, eff_dim


## Kernel Analysis

In [29]:
def get_all_conv1d_layers(model, layer_type = nn.Conv1d):
    """Return [(qualified_name, module), ...] for every module in
    model.named_modules() (the model itself included, named "") that is an
    instance of *layer_type*, in traversal order."""
    return [
        (qualified_name, submodule)
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, layer_type)
    ]


In [30]:
def visualize_conv1d_kernel_importance(conv_layer, layer_name=None, mode='l2'):
    """
    Bar-plot the norm of each output-channel filter of a Conv1d layer.

    mode: 'l2' or 'l1'; any other value raises ValueError.
    """
    if mode not in ('l2', 'l1'):
        raise ValueError("Only 'l2' and 'l1' are supported")
    order = 2 if mode == 'l2' else 1

    with torch.no_grad():
        w = conv_layer.weight.cpu()                 # [out_ch, in_ch, kernel_size]
        per_filter = w.view(w.size(0), -1)          # one row per output channel
        importance = torch.norm(per_filter, p=order, dim=1).numpy()

    # Plot
    plt.figure(figsize=(10, 4))
    plt.bar(np.arange(len(importance)), importance)
    plt.xlabel('Kernel Index')
    plt.ylabel(f'{mode.upper()} Norm')
    plt.title(f'Kernel Importance in {layer_name or "Conv1d Layer"}')
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def visualize_conv1d_kernels(conv, title='Conv1d Kernels'):
    """Plot every filter of a Conv1d layer on a near-square grid of subplots.

    One subplot per output channel; each subplot overlays that filter's
    kernel for every input channel. Only the left column shows y tick labels,
    and only the lowest subplot in each column shows x tick labels.
    """
    weight = conv.weight.data.cpu()  # [out_ch, in_ch, kernel_size]
    out_ch, in_ch, k_size = weight.shape

    n_rows = max(1, int(np.floor(np.sqrt(out_ch))))
    n_cols = int(np.ceil(out_ch / n_rows))

    # squeeze=False always returns a 2-D array of Axes, so flatten() also
    # works for a 1x1 grid (out_ch == 1), where a bare Axes would be returned.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols*2, n_rows*2),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    for idx in range(out_ch):
        ax = axes[idx]
        kernel = weight[idx]  # [in_ch, k_size]

        for ic in range(in_ch):
            ax.plot(kernel[ic].numpy(), label=f'InCh {ic}', alpha=0.7)

        ax.set_title(f'OutCh {idx}', fontsize=8)
        # Toggle label visibility per axes. set_*ticklabels([]) would install
        # an empty formatter on the tick machinery that sharex/sharey axes
        # share, blanking the labels on every subplot.
        ax.tick_params(labelleft=(idx % n_cols == 0),
                       labelbottom=(idx + n_cols >= out_ch))

    # Remove unused subplots
    for i in range(out_ch, len(axes)):
        fig.delaxes(axes[i])

    fig.suptitle(title, fontsize=14)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()


In [31]:
def kernel_dimensionality(conv, plot=True):
    """Estimate how many independent directions a conv layer's filters span.

    Each output-channel filter is flattened to a vector of length
    in_ch * kernel_size, and PCA is fitted across the out_ch filters.

    Returns:
        eff_dim: number of components whose cumulative explained variance
            stays below 99%, plus one (i.e. the components needed for 99%)
        explained: explained-variance ratio of each fitted component
    """
    W = conv.weight.data.cpu().numpy()  # [out_ch, in_ch, k_size]
    W_flat = W.reshape(W.shape[0], -1)  # [out_ch, in_ch * k_size]

    # PCA cannot fit more components than min(n_samples, n_features); a layer
    # with more filters than weights per filter would otherwise raise.
    pca = PCA(n_components=min(W_flat.shape))
    pca.fit(W_flat)
    explained = pca.explained_variance_ratio_

    if plot:
        plt.figure(figsize=(5,3))
        plt.plot(explained.cumsum(), marker='o')
        plt.xlabel('Number of Components')
        plt.ylabel('Explained Variance Ratio')
        plt.title('Cumulative PCA on Conv Kernels')
        plt.grid(True)
        plt.show()

    eff_dim = (explained.cumsum() < 0.99).sum() + 1
    return eff_dim, explained

# Vectorization Model Design

## Ver01 - 3 x CNN

### Model

In [32]:

class Curve2VecEncoder_Ver01(nn.Module):
    """Conv1d encoder mapping a [B, in_ch, L] curve to Gaussian posterior parameters.

    Three unpadded Conv1d + ReLU stages widen the channels
    in_ch -> hid_ch -> 2*hid_ch -> 4*hid_ch. Global average pooling then
    collapses the length axis, and two linear heads produce mu and logvar
    of size z_dim. in_dim is accepted for signature symmetry but is unused.
    """
    def __init__(self, in_ch, in_dim, hid_ch, 
                 z_dim, kernel_size):
        super().__init__()

        widths = [in_ch, hid_ch, hid_ch * 2, hid_ch * 4]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Conv1d(c_in, c_out, kernel_size=kernel_size), nn.ReLU()]

        self.conv = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool1d(1)

        feat_ch = widths[-1]
        self.fc_mu = nn.Linear(feat_ch, z_dim)
        self.fc_lv = nn.Linear(feat_ch, z_dim)

    def forward(self, x):
        # conv output: [B, 4*hid_ch, L - 3*(kernel_size - 1)] (no padding)
        feats = self.pool(self.conv(x)).squeeze(-1)   # [B, 4*hid_ch]
        mu = self.fc_mu(feats)
        logvar = self.fc_lv(feats)
        return mu, logvar


class Curve2VecDecoder_Ver01(nn.Module):
    """Decoder mapping a latent vector z [B, z_dim] back to a curve [B, out_ch, out_dim].

    A linear layer expands z to hid_ch * out_dim values, which are reshaped to
    [B, hid_ch, out_dim]. They then pass through
    ReLU -> ConvTranspose1d(hid_ch -> hid_ch//2) -> ReLU -> Conv1d(hid_ch//2 -> out_ch).
    Both convolutions use stride 1 and padding kernel_size//2, so for an odd
    kernel_size each one preserves the length out_dim.
    """
    def __init__(self, out_ch, out_dim, hid_ch, 
                 z_dim, kernel_size):
        super().__init__()
        self.hid_ch = hid_ch
        self.out_dim = out_dim

        self.fc_expand = nn.Linear(z_dim, hid_ch * out_dim)

        pad = kernel_size // 2
        mid_ch = hid_ch // 2
        self.deconv = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose1d(hid_ch, mid_ch, kernel_size=kernel_size, padding=pad),
            nn.ReLU(),
            nn.Conv1d(mid_ch, out_ch, kernel_size=kernel_size, padding=pad),
        )

    def forward(self, z):
        grid = self.fc_expand(z).view(-1, self.hid_ch, self.out_dim)   # [B, hid_ch, out_dim]
        return self.deconv(grid)                                      # [B, out_ch, out_dim] for odd kernel_size

class Curve2VecVAE_Ver01(nn.Module):
    """VAE pairing Curve2VecEncoder_Ver01 with Curve2VecDecoder_Ver01.

    forward(x) returns (x_rec, mu, logvar). The latent fed to the decoder is
    sampled with the reparameterization trick on every call, regardless of
    train/eval mode.
    """
    def __init__(self, in_ch=2, in_dim=101, 
                 enc_hid_ch = 16,
                 dec_hid_ch = 16,
                 z_dim = 16, kernel_size = 13):
        super().__init__()
        self.encoder = Curve2VecEncoder_Ver01(in_ch, in_dim, enc_hid_ch, z_dim, kernel_size)
        self.decoder = Curve2VecDecoder_Ver01(in_ch, in_dim, dec_hid_ch, z_dim, kernel_size)

    def reparam(self, mu, logvar):
        """Draw z = mu + eps * exp(logvar / 2) with eps ~ N(0, I)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        mu, logvar = self.encoder(x)
        x_rec = self.decoder(self.reparam(mu, logvar))
        return x_rec, mu, logvar



### Running

In [33]:
# Build the default Ver01 VAE on the selected device and report its trainable size.
model_ver01 = Curve2VecVAE_Ver01()
model_ver01 = model_ver01.to(device)
print(count_parameters(model=model_ver01))

65242


In [34]:
# Training hyper-parameters.
num_epochs = 200
batch_size = 128
lr = 1e-3
random_state = None   # None -> the split is not fixed between runs

# Random 80/20 split at the row (single-sweep) level.
# NOTE(review): consecutive sweeps of the same electrode/channel can land in
# both train and val, so the validation loss may be optimistic; consider
# splitting by the electrode id in all_id_list instead.
train_list, val_list = train_test_split(
    all_data_list, test_size=0.2, random_state=random_state
)

all_ds = EISDataset_CNN(all_data_list)
train_ds = EISDataset_CNN(train_list)
val_ds = EISDataset_CNN(val_list)


In [35]:

# Train the Ver01 VAE; keep the per-epoch loss histories for plotting.
model_ver01, train_loss, eval_loss = train_model(
    model_ver01, train_ds, val_ds,
    num_epochs=num_epochs, batch_size=batch_size, lr=lr,
)



Epoch 1/200  Train: 2.440765  Val: 0.282994
Epoch 2/200  Train: 0.244326  Val: 0.203530
Epoch 3/200  Train: 0.162992  Val: 0.126925
Epoch 4/200  Train: 0.109234  Val: 0.094743
Epoch 5/200  Train: 0.075916  Val: 0.057839
Epoch 6/200  Train: 0.058495  Val: 0.058915
Epoch 7/200  Train: 0.055690  Val: 0.067451
Epoch 8/200  Train: 0.046415  Val: 0.037308
Epoch 9/200  Train: 0.037960  Val: 0.027565
Epoch 10/200  Train: 0.029589  Val: 0.032086
Epoch 11/200  Train: 0.025367  Val: 0.030340
Epoch 12/200  Train: 0.022151  Val: 0.027276
Epoch 13/200  Train: 0.020039  Val: 0.022645
Epoch 14/200  Train: 0.019006  Val: 0.016399
Epoch 15/200  Train: 0.018001  Val: 0.017722
Epoch 16/200  Train: 0.017381  Val: 0.027798
Epoch 17/200  Train: 0.016484  Val: 0.015474
Epoch 18/200  Train: 0.015688  Val: 0.018571
Epoch 19/200  Train: 0.015306  Val: 0.015113
Epoch 20/200  Train: 0.014591  Val: 0.012225
Epoch 21/200  Train: 0.014228  Val: 0.014339
Epoch 22/200  Train: 0.013377  Val: 0.083636
Epoch 23/200  Train

In [None]:
# if False:
#     # eis2vec_save_path = "D:\Baihm\EISNN\Feature\SeqData_Convx2_z_ConvTx1_Convx1.pt"
#     torch.save(model_ver01.state_dict(), eis2vec_save_path)

### Plot Loss & Eval Samples

In [36]:
# Per-epoch train / validation loss on a logarithmic y-axis.
fig_loss, ax_loss = plt.subplots()
ax_loss.semilogy(train_loss, label="train")
ax_loss.semilogy(eval_loss, label="eval")
ax_loss.legend()
plt.show()


In [None]:
# Compare originals and reconstructions for 5 random validation samples.
visualize_EISVAECNN(model_ver01, val_ds, num=5)

### Plot Latent Space

In [None]:
# Encode the full dataset and plot the PCA-rotated latent means.
latent_expr, _ = VAE_latent(model_ver01, all_ds, batch_size=64)

### Plot Kernel

In [None]:
# All Conv1d layers of the model (ConvTranspose1d is not a Conv1d subclass).
conv_layers = get_all_conv1d_layers(model_ver01)

In [None]:
# Print each Conv1d layer and plot the cumulative PCA spectrum of its filters.
for layer_name, layer in conv_layers:
    print(layer_name, layer)
    kernel_dimensionality(layer)

### Save Model

In [None]:
# Set to True to write the trained Ver01 weights (state_dict only) to disk.
SAVE_MODEL_FLAG = False
if SAVE_MODEL_FLAG:
    # NOTE(review): the file name says "Convx2" but the Ver01 encoder has three
    # Conv1d stages — confirm the intended naming before relying on it.
    eis2vec_save_path = "D:/Baihm/EISNN/PredictionModel/model/Convx2_z_ConvTx1_Convx1.pt"
    # Create the target folder if needed so torch.save does not fail on a fresh machine.
    os.makedirs(os.path.dirname(eis2vec_save_path), exist_ok=True)
    torch.save(model_ver01.state_dict(), eis2vec_save_path)