# Imports

In [114]:
import os
import re
import gc
import sys

from loguru import logger
import numpy as np
import random

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# %matplotlib qt
%matplotlib qt

# Detect device: first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG this notebook draws from (stdlib `random` for sample picking,
# NumPy's global state, PyTorch for weight init / reparameterisation noise /
# DataLoader shuffling) so Restart & Run All reproduces the same results.
# NOTE(review): bit-exact determinism on GPU may additionally need
# torch.use_deterministic_algorithms(True).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Input Layer

In [115]:
def SearchELE(rootPath, ele_pattern = re.compile(r"(.+?)_归档")):
    '''==================================================
        Search all electrode directories in rootPath (non-recursive).
        Parameters:
            rootPath: directory whose entries are scanned
            ele_pattern: compiled regex applied with .match (anchored at
                the start of the entry name only); group 1 is taken as the
                electrode ID. Entries are not checked to be directories.
        Returns:
            ele_list: list of [full_path, electrode_id] pairs, ordered by
                entry name
        ==================================================
    '''
    ele_list = []
    # os.listdir order is arbitrary (filesystem dependent). The position of an
    # electrode in ele_list is later used as its numeric ID, so sort to make
    # that ID stable across machines and runs.
    for entry in sorted(os.listdir(rootPath)):
        match_ele = ele_pattern.match(entry)
        if match_ele:
            ele_list.append([os.path.join(rootPath, entry), match_ele.group(1)])
    return ele_list



In [116]:
# Archive root. Override with the EISNN_ARCHIVE_ROOT environment variable
# instead of editing this machine-specific default.
rootPath = os.environ.get("EISNN_ARCHIVE_ROOT", "D:/Baihm/EISNN/Archive/")
ele_list = SearchELE(rootPath)
n_ele = len(ele_list)
logger.info(f"Search in {rootPath} and find {n_ele:03d} electrodes")


[32m2025-04-25 09:45:00.072[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mSearch in D:/Baihm/EISNN/Archive/ and find 218 electrodes[0m


In [117]:
# Electrode IDs skipped entirely by the loading loop (checked against the
# electrode ID from SearchELE); each comment records why it was rejected.
Blacklist = [
    '01067093',     # Does not look like EIS
    '01067094',     # Connection Error
    '02017385',     # Connection Error
    '05127177',     # Open to Short
    '06047729',     # Open to Short
    '06047730',     # Open to Short
    '06047731',     # Open to Short
    '09207024',     # Connection Error
    '10017038',     # Connection Error
    '10037050',     # Connection Error
    '10047056',     # Connection Error
    '10057069',     # Connection Error
    '10057083',     # Always Open
    '10057084',     # Chaos
    '10057087',     # Connection Error
    '22017367',     # Connection Error
    '22017371',     # Chaos
]

# Electrodes with suspicious but not clearly broken data.
# NOTE(review): this list is not checked by the loading loop in this notebook
# section -- confirm whether these electrodes should also be filtered.
GrayList = [
    '10037051',     # Connection Error
    '10037052',     # Connection Error
    '10057071',     # Connection Error
    '10067077',     # Weird shape, looks like a connection error
    '10150201',     # Weird shape
    '10150202',     # Weird shape
    '10150203',     # Weird shape
    '20037515',     # Weird shape
    '20037516',     # Weird shape
    '20037517',     # Weird shape
    '22037378',     # Connection Error
    '22037380',     # Connection Error
    '22047376',     # Connection Error

]

In [118]:

MODEL_SUFFIX = "Matern12_Ver01"

# Filled per kept electrode, channel by channel:
#   all_data_list[e][c] : the channel's "y_eval" entry from the .pt archive
#   all_id_list[e][c]   : int array [n_rows, 2], each row [electrode index i, channel number]
all_data_list = []
all_id_list = []

# Channel keys look like "ch_XYZ"; the three digits are the channel number.
_ch_pattern = re.compile(r"ch_(\d{3})")

# Pre-bind the per-file temporaries so the cleanup `del` after the loop does
# not raise NameError when no electrode archive was loaded at all.
data_pt = _meta_group = _data_group = _ch_data = None

for i in range(n_ele):
    ele_dir, ele_name = ele_list[i]
    if ele_name in Blacklist:
        continue

    fd_pt = os.path.join(ele_dir, MODEL_SUFFIX, f"{ele_name}_{MODEL_SUFFIX}.pt")
    if not os.path.exists(fd_pt):
        continue
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load archives produced by this project.
    data_pt = torch.load(fd_pt, weights_only=False)
    _meta_group = data_pt["meta_group"]
    _data_group = data_pt["data_group"]

    n_day       = _meta_group["n_day"]
    n_ch        = _meta_group["n_ch"]
    n_valid_ch  = len(_data_group["Channels"])

    # An electrode that is not a complete 128-channel recording is kept only
    # when it has at least 5 days and more than 100 valid channels.
    if n_ch != 128 or n_valid_ch != n_ch:
        if n_day < 5 or n_valid_ch <= 100:
            continue

    logger.info(f"ELE [{i}/{n_ele}]: {ele_dir}")

    ele_data_list = []
    ele_id_list = []
    # Iterate over the channels listed in the archive.
    for j in _data_group['Channels']:
        _ch_data = _data_group[j]["y_eval"]
        _ch_match = _ch_pattern.match(j)
        if _ch_match is None:
            # Fail with a clear message instead of AttributeError on None.group().
            raise ValueError(f"Unexpected channel key {j!r} in {fd_pt}")
        _ch_id = int(_ch_match.group(1))

        ele_data_list.append(_ch_data)
        # One [electrode, channel] row per row (first axis) of the channel data.
        ele_id_list.append(np.tile([i, _ch_id], (np.shape(_ch_data)[0], 1)))

    all_data_list.append(ele_data_list)
    all_id_list.append(ele_id_list)

del data_pt, _meta_group, _data_group, _ch_data
gc.collect()



[32m2025-04-25 09:45:00.323[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mELE [0/218]: D:/Baihm/EISNN/Archive/01037160_归档[0m
[32m2025-04-25 09:45:00.519[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mELE [1/218]: D:/Baihm/EISNN/Archive/01037161_归档[0m
[32m2025-04-25 09:45:00.734[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mELE [2/218]: D:/Baihm/EISNN/Archive/01037162_归档[0m
[32m2025-04-25 09:45:00.913[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mELE [5/218]: D:/Baihm/EISNN/Archive/01067095_归档[0m
[32m2025-04-25 09:45:01.271[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mELE [9/218]: D:/Baihm/EISNN/Archive/02027373_归档[0m
[32m2025-04-25 09:45:01.378[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mELE [10/218]: D:/Baihm/EISNN/Archive/02027390_归档[0m
[32m2025-04-25 09:45:01.526[0m | [1m

240000

# Helper

In [119]:
def count_parameters(model):
    """Return the total number of scalar parameters in `model` that have requires_grad=True."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


In [120]:
def load_all2seq(data_list, id_list = None):
    """Flatten a two-level nested list (electrode -> channel) into flat lists.

    Returns (seq_data_list, seq_id_list). seq_id_list holds id_list[i][j] for
    every data_list[i][j], in the same order; it is an empty list when
    id_list is None.
    """
    seq_data_list = [item for group in data_list for item in group]
    seq_id_list = []
    if id_list is not None:
        seq_id_list = [
            id_list[g][k]
            for g, group in enumerate(data_list)
            for k in range(len(group))
        ]
    return seq_data_list, seq_id_list

def load_all2ch(data_list, id_list = None):
    """Flatten like load_all2seq, then stack the per-channel arrays row-wise.

    Returns (stacked_data, ids): stacked_data is np.vstack of every channel
    array. ids is the np.vstack of the matching id arrays, or the empty list
    from load_all2seq when id_list is None.
    """
    seq_data, seq_ids = load_all2seq(data_list, id_list)
    stacked_data = np.vstack(seq_data)
    stacked_ids = seq_ids if id_list is None else np.vstack(seq_ids)
    return stacked_data, stacked_ids

## Train & Eval Functions

In [121]:
def vae_loss(x_rec, x, mu, logvar, beta=1e-3):
    """VAE objective: mean-squared reconstruction error + beta * KL term.

    The KL divergence to a standard normal is averaged over every element of
    mu/logvar (not summed over latent dimensions), so beta scales that mean.
    """
    reconstruction = F.mse_loss(x_rec, x)
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * kl_terms.mean()
    return reconstruction + beta * kl

# ===== Training function =====
def train_model(model, train_ds, val_ds, num_epochs=20, batch_size=64, lr=1e-3):
    """Train a VAE-style model with Adam on vae_loss.

    model(x) must return (x_rec, mu, logvar). Each batch is moved to the
    module-level `device`; the model itself is expected to already be there.
    The validation pass also calls model(x), so it includes sampling noise.

    Returns (model, train_losses, val_losses) where the loss lists hold one
    per-sample average per epoch.
    """
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_history, val_history = [], []

    def _batch_loss(batch):
        # Forward one batch and return (loss, batch size).
        batch = batch.to(device)
        x_rec, mu, lv = model(batch)
        return vae_loss(x_rec, batch, mu, lv), batch.size(0)

    for epoch in range(1, num_epochs + 1):
        model.train()
        running_train = 0
        for batch in train_loader:
            loss, n = _batch_loss(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_train += loss.item() * n

        # Validation
        model.eval()
        running_val = 0
        with torch.no_grad():
            for batch in val_loader:
                loss, n = _batch_loss(batch)
                running_val += loss.item() * n

        train_loss = running_train / len(train_ds)
        val_loss = running_val / len(val_ds)
        train_history.append(train_loss)
        val_history.append(val_loss)
        print(f"Epoch {epoch}/{num_epochs}  Train: {train_loss:.6f}  Val: {val_loss:.6f}")

    return model, train_history, val_history


# ===== Reconstruction visualization =====
def visualize_EISVAECNN(model, ds, num=5):
    """Plot original vs. reconstructed curves for up to `num` random samples of ds.

    ds items are expected channels-first, [2, L] (EISDataset_CNN layout);
    channel 0 is drawn as "Real", channel 1 as "Imag". One figure per sample.
    Batches are moved to the module-level `device`.
    """
    loader = DataLoader(ds, batch_size=num, shuffle=True)
    x = next(iter(loader)).to(device)   # [n, 2, L] with n <= num
    model.eval()
    with torch.no_grad():
        x_rec, _, _ = model(x)

    x = x.cpu().numpy()
    x_rec = x_rec.cpu().numpy()

    # The batch is smaller than `num` when ds has fewer samples, so loop over
    # what was actually drawn (range(num) would raise IndexError).
    for i in range(x.shape[0]):
        fig, (ax_re, ax_im) = plt.subplots(1, 2, figsize=(6, 3))
        for ax, ch, part in ((ax_re, 0, "Real"), (ax_im, 1, "Imag")):
            ax.plot(x[i, ch], label="orig", alpha=0.5)
            ax.plot(x_rec[i, ch], '--', label="rec")
            ax.set_title(f"Sample {i} {part}")
            ax.legend()
        fig.tight_layout()
        plt.show()



## Dataloader

In [122]:
class EISDataset_DNN(Dataset):
    """Per-curve dataset for fully connected models.

    data_list is nested electrode -> channel -> array of curves (author's
    layout notes: n electrodes x m channels x k timestamps x l frequencies x 2,
    the 2 being real and imaginary parts after a logarithm). load_all2ch stacks
    every curve into one array; each curve is stored as a float32 tensor and
    returned unchanged, i.e. [l, 2] (e.g. [101, 2]).
    """
    def __init__(self, data_list, id_list = None):
        curves, ids = load_all2ch(data_list, id_list)
        self.data = [torch.tensor(curve, dtype=torch.float32) for curve in curves]
        self.id = ids

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # [l, 2] layout, for Linear layers.
        sample = self.data[idx]
        return sample
   
class EISDataset_CNN(Dataset):
    """Per-curve dataset for Conv1d models.

    data_list is nested electrode -> channel -> array of curves (author's
    layout notes: n electrodes x m channels x k timestamps x l frequencies x 2,
    the 2 being real and imaginary parts after a logarithm). load_all2ch stacks
    every curve into one array; each curve is stored as a float32 [l, 2]
    tensor and served channels-first.
    """
    def __init__(self, data_list, id_list = None):
        curves, ids = load_all2ch(data_list, id_list)
        self.data = [torch.tensor(curve, dtype=torch.float32) for curve in curves]
        self.id = ids

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # [l, 2] -> [2, l], i.e. [in_ch, in_dim] for Conv1d (e.g. [2, 101]).
        sample = self.data[idx]
        return sample.permute(1, 0)

class EISDataset_SEQ(Dataset):
    """Per-channel time-series dataset for sequence models.

    data_list is nested electrode -> channel -> array (author's layout notes:
    n electrodes x m channels x k timestamps x l frequencies x 2, the 2 being
    real and imaginary parts after a logarithm). Unlike the per-curve datasets,
    channels are NOT stacked: one item is one channel's whole [k, l, 2] array,
    served as [k, 2, l].
    """
    def __init__(self, data_list, id_list = None):
        channel_arrays, ids = load_all2seq(data_list, id_list)
        self.data = [torch.tensor(arr, dtype=torch.float32) for arr in channel_arrays]
        self.id = ids

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # [k, l, 2] -> [k, 2, l]: per timestamp [in_ch, in_dim].
        sequence = self.data[idx]
        return sequence.permute(0, 2, 1)
        

## Plot Latent Space

In [161]:
def VAE_latent(model, ds, batch_size=64, log_every=1000):
    """Encode all of ds with model.encoder, PCA-rotate the latent means and plot them.

    Parameters:
        model: VAE exposing .encoder(x) -> (mu, logvar)
        ds: dataset whose items model.encoder accepts (e.g. EISDataset_CNN)
        batch_size: encoding batch size
        log_every: log progress whenever the processed-sample count crosses a
            multiple of this value, and once at the end
    Returns:
        latent_dd: latent means in PCA coordinates, shape [len(ds), z_dim]
        eff_dim: number of principal components needed for 99% explained variance

    Batches are moved to the module-level `device`.
    """
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
    n_total = len(ds)
    n_done = 0
    mu_chunks = []

    model.eval()
    with torch.no_grad():
        for x in loader:
            x = x.to(device)
            mu, _ = model.encoder(x)          # means only, no sampling
            mu_chunks.append(mu.cpu().numpy())

            n_prev, n_done = n_done, n_done + x.size(0)
            # A plain `n_done % 1000 == 0` only fires when the running count lands
            # exactly on a multiple (every 8000 samples at batch_size=64) and
            # misses the final count; compare bucket indices instead.
            if n_done // log_every > n_prev // log_every or n_done == n_total:
                logger.info(f"[{n_done}]/[{n_total}]")

    latent_mu = np.concatenate(mu_chunks, axis=0)  # [N, z_dim]

    pca = PCA(n_components=latent_mu.shape[1])
    latent_dd = pca.fit_transform(latent_mu)
    explained = pca.explained_variance_ratio_
    # Smallest number of components whose cumulative explained variance reaches 99%.
    eff_dim = (explained.cumsum() < 0.99).sum() + 1

    fig, axis = plt.subplots(2, 1,
                             gridspec_kw={'height_ratios': [4, 1]},
                             figsize=(9, 9))
    axis[0].scatter(latent_dd[:, 0], latent_dd[:, 1], alpha=0.5, s=0.001)
    axis[0].set_aspect('equal', adjustable='box')
    axis[0].set_box_aspect(1)
    axis[0].set_title("Latent Space")
    axis[0].set_xlabel("PC 1")
    axis[0].set_ylabel("PC 2")

    axis[1].plot(explained, label=f"Valid Dimension = {eff_dim}")
    axis[1].set_xlabel("Principal component index")
    axis[1].set_ylabel("Explained variance ratio")
    axis[1].legend()
    fig.show()

    return latent_dd, eff_dim


## Kernel Analysis

In [124]:
def get_all_conv1d_layers(model, layer_type = nn.Conv1d):
    """Return [(qualified_name, module), ...] for every module in
    model.named_modules() (model itself included) that is an instance of
    layer_type, in traversal order."""
    return [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, layer_type)
    ]


In [125]:
def visualize_conv1d_kernel_importance(conv_layer, layer_name=None, mode='l2'):
    """
    Bar-plot one norm per output kernel of a Conv1d layer.

    Each output channel's weights ([in_ch, kernel_size]) are flattened and
    reduced to a single L1 or L2 norm, used as a rough importance score.

    conv_layer: layer whose .weight is [out_ch, in_ch, kernel_size]
    layer_name: optional name for the plot title
    mode: 'l2' or 'l1'; anything else raises ValueError
    """
    norm_orders = {'l2': 2, 'l1': 1}
    with torch.no_grad():
        w = conv_layer.weight.cpu()  # [out_ch, in_ch, kernel_size]
        if mode not in norm_orders:
            raise ValueError("Only 'l2' and 'l1' are supported")
        per_kernel = w.view(w.size(0), -1)
        importance = torch.norm(per_kernel, p=norm_orders[mode], dim=1).numpy()

    # Plot
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(np.arange(len(importance)), importance)
    ax.set_xlabel('Kernel Index')
    ax.set_ylabel(f'{mode.upper()} Norm')
    ax.set_title(f'Kernel Importance in {layer_name or "Conv1d Layer"}')
    ax.grid(True)
    fig.tight_layout()
    plt.show()


def visualize_conv1d_kernels(conv, title='Conv1d Kernels'):
    """Plot every kernel of a Conv1d layer on a near-square grid of subplots.

    One subplot per output channel; inside it, one line per input channel
    showing that input channel's kernel_size weights. Assumes the Conv1d weight
    layout [out_ch, in_ch, kernel_size].
    """
    weight = conv.weight.data.cpu()  # [out_ch, in_ch, kernel_size]
    out_ch, in_ch, _ = weight.shape

    n_rows = int(np.floor(np.sqrt(out_ch)))
    n_cols = int(np.ceil(out_ch / n_rows))

    # squeeze=False always returns a 2-D array of Axes. Without it, a 1x1 grid
    # (out_ch == 1) returns a bare Axes and `.flatten()` raises AttributeError.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    for idx in range(out_ch):
        ax = axes[idx]
        kernel = weight[idx]  # [in_ch, kernel_size]

        for ic in range(in_ch):
            ax.plot(kernel[ic].numpy(), label=f'InCh {ic}', alpha=0.7)

        ax.set_title(f'OutCh {idx}', fontsize=8)
        # Keep y tick labels only in the left column, x tick labels only in the last row.
        if idx % n_cols != 0:
            ax.set_yticklabels([])
        if idx < (n_rows - 1) * n_cols:
            ax.set_xticklabels([])

    # Remove grid cells beyond out_ch.
    for extra in axes[out_ch:]:
        fig.delaxes(extra)

    fig.suptitle(title, fontsize=14)
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()


In [126]:
def kernel_dimensionality(conv, plot=True):
    """Estimate how many directions the kernels of a conv layer actually span.

    Each output channel's weights are flattened into one vector
    ([out_ch, in_ch * kernel_size]) and a PCA is fitted over those vectors.
    PCA cannot fit more components than min(n_samples, n_features) and raises
    otherwise, so the component count is min(out_ch, in_ch * kernel_size).

    Returns:
        eff_dim: minimum number of components explaining 99% of the variance
        explained: per-component explained variance ratios
    """
    W = conv.weight.data.cpu().numpy()  # [out_ch, in_ch, k_size]
    W_flat = W.reshape(W.shape[0], -1)  # [out_ch, in_ch * k_size]

    pca = PCA(n_components=min(W_flat.shape))
    pca.fit(W_flat)
    explained = pca.explained_variance_ratio_

    if plot:
        plt.figure(figsize=(5, 3))
        # x = number of components (1-based), y = cumulative explained variance.
        plt.plot(np.arange(1, len(explained) + 1), explained.cumsum(), marker='o')
        plt.xlabel('Number of Components')
        plt.ylabel('Cumulative Explained Variance Ratio')
        plt.title('Cumulative PCA on Conv Kernels')
        plt.grid(True)
        plt.show()

    eff_dim = (explained.cumsum() < 0.99).sum() + 1
    return eff_dim, explained

# Vectorization Model Design

## Ver01 - 3 x CNN

### Model

In [135]:

class Curve2VecEncoder_Ver01(nn.Module):
    """Three unpadded Conv1d+ReLU stages (widths hid_ch, 2*hid_ch, 4*hid_ch),
    global average pooling, then two Linear heads for mu and logvar.

    in_dim is accepted for interface symmetry but unused: pooling makes the
    encoder length-agnostic. Because each conv is unpadded with stride 1, the
    input length must be at least 3 * (kernel_size - 1) + 1.
    forward(x: [B, in_ch, L]) -> (mu [B, z_dim], logvar [B, z_dim]).
    """
    def __init__(self, in_ch, in_dim, hid_ch,
                 z_dim, kernel_size):
        super().__init__()

        widths = [in_ch, hid_ch, hid_ch * 2, hid_ch * 4]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Conv1d(c_in, c_out, kernel_size=kernel_size), nn.ReLU()]
            # (BatchNorm1d(c_out) was tried here and left disabled.)

        self.conv = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool1d(1)

        feat = widths[-1]
        self.fc_mu = nn.Linear(feat, z_dim)
        self.fc_lv = nn.Linear(feat, z_dim)

    def forward(self, x):
        pooled = self.pool(self.conv(x)).squeeze(-1)   # [B, 4*hid_ch]
        return self.fc_mu(pooled), self.fc_lv(pooled)


class Curve2VecDecoder_Ver01(nn.Module):
    """Linear expansion of z to [hid_ch, out_dim], then
    ReLU -> ConvTranspose1d(hid_ch -> hid_ch//2) -> ReLU -> Conv1d(-> out_ch).

    Both convolutions use stride 1 and padding kernel_size // 2, which keeps the
    length at out_dim for odd kernel_size (even sizes change the length).
    forward(z: [B, z_dim]) -> [B, out_ch, out_dim].
    """
    def __init__(self, out_ch, out_dim, hid_ch,
                 z_dim, kernel_size):
        super().__init__()
        self.hid_ch = hid_ch
        self.out_dim = out_dim

        self.fc_expand = nn.Linear(z_dim, hid_ch * out_dim)

        half = kernel_size // 2
        mid_ch = hid_ch // 2
        # (A second ConvTranspose stage and BatchNorm were tried and left disabled.)
        self.deconv = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose1d(hid_ch, mid_ch, kernel_size=kernel_size, padding=half),
            nn.ReLU(),
            nn.Conv1d(mid_ch, out_ch, kernel_size=kernel_size, padding=half),
        )

    def forward(self, z):
        grid = self.fc_expand(z).view(-1, self.hid_ch, self.out_dim)  # [B, hid_ch, out_dim]
        return self.deconv(grid)                                       # [B, out_ch, out_dim]

class Curve2VecVAE_Ver01(nn.Module):
    """Convolutional VAE pairing Curve2VecEncoder_Ver01 and Curve2VecDecoder_Ver01.

    forward(x) returns (x_rec, mu, logvar), where x_rec is decoded from a
    reparameterised sample of the latent posterior.
    """
    def __init__(self, in_ch=2, in_dim=101,
                 enc_hid_ch = 16,
                 dec_hid_ch = 16,
                 z_dim = 16, kernel_size = 13):
        super().__init__()
        self.encoder = Curve2VecEncoder_Ver01(in_ch, in_dim, enc_hid_ch, z_dim, kernel_size)
        self.decoder = Curve2VecDecoder_Ver01(in_ch, in_dim, dec_hid_ch, z_dim, kernel_size)

    def reparam(self, mu, logvar):
        """Reparameterisation trick: mu + sigma * eps with eps ~ N(0, I), sigma = exp(logvar / 2)."""
        sigma = logvar.mul(0.5).exp()
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        mu, logvar = self.encoder(x)
        return self.decoder(self.reparam(mu, logvar)), mu, logvar



### Running

In [136]:
# Instantiate the Ver01 VAE with default hyper-parameters on the detected device.
model_ver01 = Curve2VecVAE_Ver01()
model_ver01 = model_ver01.to(device)
print(count_parameters(model_ver01))

65242


In [137]:
# Training hyper-parameters.
num_epochs = 150
batch_size = 128
lr = 1e-3
random_state = None   # None: the split is not pinned by this cell

# Split at the electrode level (all_data_list holds one entry per electrode),
# so channels of one electrode never end up in both train and validation.
train_list, val_list = train_test_split(all_data_list, test_size=0.2, random_state=random_state)

all_ds = EISDataset_CNN(all_data_list)
train_ds = EISDataset_CNN(train_list)
val_ds = EISDataset_CNN(val_list)


In [138]:

model_ver01, train_loss, eval_loss = train_model(
    model_ver01, train_ds, val_ds,
    num_epochs=num_epochs, batch_size=batch_size, lr=lr,
)



Epoch 1/150  Train: 0.668775  Val: 0.065801
Epoch 2/150  Train: 0.044477  Val: 0.038572
Epoch 3/150  Train: 0.033252  Val: 0.023493
Epoch 4/150  Train: 0.025913  Val: 0.036400
Epoch 5/150  Train: 0.023307  Val: 0.018160
Epoch 6/150  Train: 0.021148  Val: 0.030491
Epoch 7/150  Train: 0.019845  Val: 0.014824
Epoch 8/150  Train: 0.018673  Val: 0.025329
Epoch 9/150  Train: 0.017164  Val: 0.013390
Epoch 10/150  Train: 0.015754  Val: 0.012688
Epoch 11/150  Train: 0.014678  Val: 0.012148
Epoch 12/150  Train: 0.013849  Val: 0.011507
Epoch 13/150  Train: 0.013463  Val: 0.011766
Epoch 14/150  Train: 0.012242  Val: 0.008790
Epoch 15/150  Train: 0.011915  Val: 0.010949
Epoch 16/150  Train: 0.011485  Val: 0.016782
Epoch 17/150  Train: 0.013469  Val: 0.009198
Epoch 18/150  Train: 0.010976  Val: 0.009268
Epoch 19/150  Train: 0.010512  Val: 0.007781
Epoch 20/150  Train: 0.010166  Val: 0.010468
Epoch 21/150  Train: 0.009866  Val: 0.007529
Epoch 22/150  Train: 0.009772  Val: 0.008094
Epoch 23/150  Train

### Plot Loss & Eval Samples

In [139]:
# Per-epoch training / validation loss on a log scale.
fig, ax = plt.subplots()
_epochs = range(1, len(train_loss) + 1)   # epochs are 1-based in the training log
ax.semilogy(_epochs, train_loss, label="train")
ax.semilogy(_epochs, eval_loss, label="eval")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("VAE loss (train vs. eval)")
ax.legend()
plt.show()


In [140]:
visualize_EISVAECNN(model_ver01, val_ds, num=5)

### Plot Latent Space

In [None]:
latent_expr, _ = VAE_latent(model_ver01, all_ds, batch_size=64)

[32m2025-04-25 12:43:06.554[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[8000]/[333535][0m
[32m2025-04-25 12:43:06.640[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[16000]/[333535][0m
[32m2025-04-25 12:43:06.720[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[24000]/[333535][0m
[32m2025-04-25 12:43:06.799[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[32000]/[333535][0m
[32m2025-04-25 12:43:06.879[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[40000]/[333535][0m
[32m2025-04-25 12:43:06.942[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[48000]/[333535][0m
[32m2025-04-25 12:43:07.006[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m18[0m - [1m[56000]/[333535][0m
[32m2025-04-25 12:43:07.070[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_lat

: 

### Plot Kernel

In [142]:
conv_layers = get_all_conv1d_layers(model_ver01)   # default layer_type is nn.Conv1d

In [143]:
# Print each Conv1d layer and plot the cumulative PCA spectrum of its kernels.
for layer_name, layer in conv_layers:
    print(layer_name, layer)
    kernel_dimensionality(layer)

encoder.conv.0 Conv1d(2, 16, kernel_size=(13,), stride=(1,))
encoder.conv.2 Conv1d(16, 32, kernel_size=(13,), stride=(1,))
encoder.conv.4 Conv1d(32, 64, kernel_size=(13,), stride=(1,))
decoder.deconv.3 Conv1d(8, 2, kernel_size=(13,), stride=(1,), padding=(6,))


### Save Model

In [147]:
# Persist the trained Ver01 weights; set SAVE_EIS2VEC = False to avoid
# overwriting the checkpoint on re-runs.
# NOTE(review): the file name says "Convx2", but the encoder printed above has
# three Conv1d layers -- confirm the name is intended.
SAVE_EIS2VEC = True
if SAVE_EIS2VEC:
    eis2vec_save_path = "D:/Baihm/EISNN/PredictionModel/model/Convx2_z_ConvTx1_Convx1.pt"
    # Make sure the target directory exists before torch.save writes into it.
    os.makedirs(os.path.dirname(eis2vec_save_path), exist_ok=True)
    torch.save(model_ver01.state_dict(), eis2vec_save_path)

# Prediction Model

## Ver01 - LSTM

In [None]:

# ── 1. Dataset & Collate ─────────────────────────────────────────────


def collate_seq(batch):
    """
    Zero-pad variable-length sequences along time into one batch tensor.

    batch: list of tensors shaped [T_i, C, D] (all sharing C and D)
    returns:
      seqs: [B, T_max, C, D] in the default dtype, zero after each T_i
      lengths: [B] int64 tensor holding each T_i
    """
    lengths = torch.tensor([item.shape[0] for item in batch], dtype=torch.long)
    C, D = batch[0].shape[1:]
    seqs = torch.zeros(len(batch), lengths.max().item(), C, D)
    for row, item in enumerate(batch):
        seqs[row, :item.shape[0]] = item
    return seqs, lengths

# ── 2. LSTM autoregressive forecaster ─────────────────────────────────────────────

class LatentForecaster(nn.Module):
    """One-step-ahead LSTM forecaster in latent space.

    The LSTM output at step t (after consuming z[0..t]) is mapped by a Linear
    head to a prediction of z[t+1]; forward() trains that mapping and
    predict() rolls it out autoregressively.
    """
    def __init__(self, z_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size=z_dim, hidden_size=hidden_dim,
                            num_layers=num_layers, batch_first=True)
        self.head = nn.Linear(hidden_dim, z_dim)

    def forward(self, z_seq, lengths):
        """
        z_seq: [B, T, z_dim] (zero-padded), lengths: [B] valid lengths
        returns:
          pred_seq: [B, T-1, z_dim]; pred_seq[:, t] predicts z_seq[:, t+1]
        Every length must be >= 2 (pack_padded_sequence rejects zero lengths).
        Positions beyond lengths-1 are padding and should be masked by the caller.
        """
        z_in = z_seq[:, :-1]      # [B, T-1, z_dim]
        packed = pack_padded_sequence(z_in, (lengths - 1).cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        # total_length keeps the time axis equal to z_in's, even when every
        # sequence in the batch is shorter than the padded length.
        out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=z_in.size(1))
        return self.head(out)     # [B, T-1, z_dim]

    def predict(self, z_init, steps):
        """
        Autoregressive prediction for inference.
        z_init: [B, L, z_dim] observed context
        returns:
          preds: [B, steps, z_dim]; preds[:, 0] forecasts the step right after z_init
        """
        batch, _, z_dim = z_init.shape
        if steps <= 0:
            return z_init.new_zeros(batch, 0, z_dim)
        # After consuming the whole context, the output at its last step is already
        # the one-step-ahead forecast (same alignment as forward()). Feeding
        # z_init[:, -1] through the LSTM a second time would double-count it.
        out, state = self.lstm(z_init)
        z_next = self.head(out[:, -1])          # [B, z_dim]
        preds = [z_next]
        for _ in range(steps - 1):
            out_step, state = self.lstm(z_next.unsqueeze(1), state)  # [B, 1, hidden]
            z_next = self.head(out_step[:, -1])
            preds.append(z_next)
        return torch.stack(preds, dim=1)        # [B, steps, z_dim]

# ── 3. Unified training loop ──────────────────────────────────────────────

def train_forecasting(eis2vec_model, seq_model,
                      seq_dataloader, num_epochs,
                      optimizer, criterion,
                      device='cuda',
                      freeze_eis2vec=False,
                      freeze_seq=False):
    """
    Train a latent one-step-ahead forecaster on top of an EIS encoder.

    eis2vec_model: exposes .encoder(x) -> (mu, logvar), e.g. Curve2VecVAE_Ver01
    seq_model: called as seq_model(z_seq, lengths) -> [B, T-1, z_dim], e.g. LatentForecaster
    seq_dataloader: yields (seqs [B, T, C, D], lengths [B]),
                    e.g. DataLoader(EISDataset_SEQ, collate_fn=collate_seq)
    criterion: elementwise loss such as nn.MSELoss(reduction='none'); padded
               target steps are masked out before averaging
    freeze_eis2vec / freeze_seq: set requires_grad=False on that model's
               parameters (applied in place and NOT restored afterwards)

    Prints the masked mean loss per epoch; returns (eis2vec_model, seq_model).
    """
    eis2vec_model.to(device)
    seq_model.to(device)

    # Freeze / unfreeze modules.
    for module, frozen in ((eis2vec_model, freeze_eis2vec), (seq_model, freeze_seq)):
        for param in module.parameters():
            param.requires_grad = not frozen

    for epoch in range(1, num_epochs + 1):
        # The encoder always runs in eval mode, trainable or not.
        eis2vec_model.eval()
        if not freeze_seq:
            seq_model.train()
        epoch_loss_sum = 0.0
        epoch_steps = 0

        for seqs, lengths in seq_dataloader:
            seqs = seqs.to(device)         # [B, T, C, D]
            lengths = lengths.to(device)
            B, T, C, D = seqs.shape

            # 1) Encode every (padded) step: [B*T, C, D] -> [B, T, z_dim], means only.
            with torch.set_grad_enabled(not freeze_eis2vec):
                mu, _ = eis2vec_model.encoder(seqs.view(-1, C, D))
            z_seq = mu.view(B, T, -1)

            # 2) Predict the next latent at every step.
            z_pred = seq_model(z_seq, lengths)   # [B, T-1, z_dim]
            z_target = z_seq[:, 1:, :]

            # 3) Average the loss over valid (non-padded) target steps only.
            valid = torch.arange(T - 1, device=device)[None, :] < (lengths - 1)[:, None]
            per_step = criterion(z_pred, z_target).mean(dim=-1)   # [B, T-1]
            loss = (per_step * valid.float()).sum() / valid.sum()

            # 4) Backpropagate.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n_valid = valid.sum().item()
            epoch_loss_sum += loss.item() * n_valid
            epoch_steps += n_valid

        avg_loss = epoch_loss_sum / epoch_steps
        print(f"Epoch {epoch}/{num_epochs}, Forecast MSE: {avg_loss:.6f}")

    return eis2vec_model, seq_model


def evaluate_and_plot(eis2vec_model, seq_model, val_list,
                      init_len=3, device='cuda',
                      num_samples=3):
    """
    Autoregressively forecast each validation sequence and report curve-space MSE.

    eis2vec_model: trained model exposing .encoder(x) -> (mu, logvar) and
                   .decoder(z) -> [N, C, D] (e.g. Curve2VecVAE_Ver01)
    seq_model: trained model exposing .predict(z_init, steps) (e.g. LatentForecaster)
    val_list: indexable collection of tensors [T, C, D] (a list or EISDataset_SEQ)
    init_len: number of leading observed steps used as forecasting context;
              sequences with T <= init_len have nothing to predict and are skipped
    num_samples: how many randomly chosen sequences to plot (channel 0 only)

    Returns the MSE averaged over all predicted steps, or nan if no sequence
    was longer than init_len.
    """
    eis2vec_model.eval().to(device)
    seq_model.eval().to(device)
    mse = nn.MSELoss()

    total_loss = 0.
    total_count = 0
    n_evaluated = 0

    # Randomly pick sequences to visualise.
    n_seq = len(val_list)
    vis_indices = set(random.sample(range(n_seq), min(num_samples, n_seq)))
    vis_data = []

    with torch.no_grad():
        for idx in range(n_seq):
            seq = val_list[idx].to(device)   # [T, C, D]
            T, C, D = seq.shape
            pred_len = T - init_len
            if pred_len <= 0:
                # Nothing left to forecast (predict() / plotting would fail).
                continue

            # 1) Encode the whole sequence (latent means).
            mu, _ = eis2vec_model.encoder(seq.reshape(-1, C, D))
            z_seq = mu.view(1, T, -1)                         # [1, T, z_dim]

            # 2) Forecast the remaining pred_len steps from the first init_len latents.
            z_pred = seq_model.predict(z_seq[:, :init_len, :], pred_len)   # [1, pred_len, z_dim]

            # 3) Decode back to EIS curves.
            x_pred = eis2vec_model.decoder(z_pred.reshape(-1, z_pred.size(-1)))  # [pred_len, C, D]

            # 4) MSE against the true remaining steps, weighted by step count.
            target = seq[init_len:]                            # [pred_len, C, D]
            loss = mse(x_pred, target)
            total_loss += loss.item() * pred_len
            total_count += pred_len
            n_evaluated += 1

            if idx in vis_indices:
                vis_data.append({
                    'real': target.cpu().numpy(),                          # [pred_len, C, D]
                    'pred': x_pred.view(pred_len, C, D).cpu().numpy(),     # [pred_len, C, D]
                    'idx': idx,
                })

    if total_count == 0:
        print(f"No sequence in val_list is longer than init_len={init_len}; nothing to evaluate.")
        return float('nan')

    avg_mse = total_loss / total_count
    print(f"Validation MSE over {n_evaluated} sequences: {avg_mse:.6f}")

    # —— Visualization ——
    for sample in vis_data:
        real = sample['real']
        pred = sample['pred']
        n_steps = real.shape[0]

        # squeeze=False: with a single predicted step plt.subplots would return
        # a bare Axes and axes[t] would fail.
        fig, axes = plt.subplots(1, n_steps, figsize=(n_steps * 3, 3),
                                 sharex=True, sharey=True, squeeze=False)
        for t, ax in enumerate(axes[0]):
            # Real part (channel 0).
            ax.plot(real[t, 0], label='Real', linestyle='-')
            ax.plot(pred[t, 0], label='Pred', linestyle='--')
            ax.set_title(f"Step {init_len + t}")
            if t == 0:
                ax.set_ylabel("Amplitude")
            if t == n_steps - 1:
                ax.legend(loc='upper right', fontsize='small')
        fig.suptitle(f"Seq idx {sample['idx']} Prediction vs True (Real part)")
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        plt.show()

    return avg_mse



### Running

In [None]:

# ── 4. Usage example ────────────────────────────────────────────────

# Assumes:
#   train_list: per-electrode nested data from the VAE train/val split
#   model_ver01: the trained Curve2VecVAE_Ver01
# Dataset / DataLoader: one item is one channel's [T, 2, 101] sequence.
dataset = EISDataset_SEQ(train_list)
loader  = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_seq)

# Forecaster. Take z_dim from the trained encoder instead of repeating a
# literal that can silently drift from the VAE's configuration.
z_dim = model_ver01.encoder.fc_mu.out_features
seq_model = LatentForecaster(z_dim=z_dim, hidden_dim=64, num_layers=1)

# Optimizer & loss. Encoder parameters are included so the same optimizer also
# works with freeze_eis2vec=False; when frozen they receive no gradient
# (.grad is None) and torch.optim skips them.
optimizer = torch.optim.Adam(
    list(model_ver01.encoder.parameters()) + list(seq_model.parameters()),
    lr=1e-3
)
# Per-element loss ('none'); train_forecasting masks padded steps before averaging.
criterion = nn.MSELoss(reduction='none')
count_parameters(seq_model)

In [None]:
# Debug: materialise the first (up to) 11 sequences of `dataset` for inspection.
a = [dataset[k] for k in range(min(11, len(dataset)))]

In [None]:
# Sanity check of the packed-LSTM path: push one collated batch through the
# frozen encoder and LatentForecaster.forward, expecting [B, T_max - 1, z_dim].
# (z_in / lengths / self.lstm only exist inside LatentForecaster.forward, so
# the check goes through the model instead of copying its body.)
with torch.no_grad():
    _seqs, _lengths = next(iter(loader))
    _B, _T, _C, _D = _seqs.shape
    _mu, _lv = model_ver01.encoder(_seqs.view(-1, _C, _D).to(device))
    _seq_device = next(seq_model.parameters()).device
    _z_pred = seq_model(_mu.view(_B, _T, -1).to(_seq_device), _lengths)
_z_pred.shape
        

In [None]:

# Train: freeze EIS2Vec (the VAE encoder) and train only the sequence forecaster.
model_ver01, seq_model = train_forecasting(
    model_ver01,
    seq_model,
    loader,
    num_epochs=200,
    optimizer=optimizer,
    criterion=criterion,
    device=device,   # detected device; a hard-coded 'cuda' fails on CPU-only machines
    freeze_eis2vec=True,
    freeze_seq=False
)



In [None]:
# Evaluate the forecaster on the channel sequences of the held-out electrodes.
val_seq_ds = EISDataset_SEQ(val_list)
avg_mse = evaluate_and_plot(
    eis2vec_model=model_ver01,
    seq_model=seq_model,
    val_list=val_seq_ds,
    init_len=3,
    device=device,   # detected device; a hard-coded 'cuda' fails on CPU-only machines
    num_samples=3
)

## Ver02 - TCN

### Model

In [159]:
# —— Helper: drop the extra trailing time steps ——
class Chomp1d(nn.Module):
    """Remove the last `chomp_size` steps of the time axis of a [B, C, L] tensor.

    Used after a Conv1d padded by `chomp_size` on both sides, so the result is
    causal and keeps the input length. chomp_size == 0 is a no-op.
    """
    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # x: [B, C, L + chomp_size]
        if self.chomp_size == 0:
            # x[:, :, :-0] is x[:, :, :0], an EMPTY tensor (kernel_size == 1 case).
            return x
        return x[:, :, :-self.chomp_size]

# —— Causal dilated convolution block ——
class CausalConvBlock(nn.Module):
    """Conv1d -> right-trim -> ReLU -> Dropout with unchanged sequence length.

    ``nn.Conv1d`` pads both ends by ``(kernel_size - 1) * dilation`` steps and
    ``Chomp1d`` removes the same amount from the right end, so each output
    step only depends on the current and earlier input steps.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        kernel_size: temporal kernel size.
        dilation: dilation factor of the convolution.
        dropout: dropout probability after the activation.
    """
    def __init__(self, in_ch, out_ch, kernel_size, dilation, dropout):
        super().__init__()
        causal_pad = dilation * (kernel_size - 1)
        conv = nn.Conv1d(in_ch, out_ch, kernel_size,
                         padding=causal_pad, dilation=dilation)
        self.net = nn.Sequential(conv, Chomp1d(causal_pad),
                                 nn.ReLU(), nn.Dropout(dropout))

    def forward(self, x):
        # x: [B, in_ch, L] -> [B, out_ch, L]
        return self.net(x)

# —— TCN Forecaster ——
class TCNForecaster(nn.Module):
    """Stack of causal dilated conv blocks mapping a latent sequence to a
    latent sequence of the same length.

    Args:
        z_dim: latent dimension (input and output channels).
        num_channels: hidden width of each block. For example, [32, 32, 32]
            gives three blocks with dilations 1, 2 and 4.
        kernel_size: temporal kernel size of every block.
        dropout: dropout probability inside each block.
    """
    def __init__(self, z_dim, num_channels, kernel_size=3, dropout=0.2):
        super().__init__()
        # Block k reads the width produced by block k-1 (z_dim for the first)
        in_widths = [z_dim] + list(num_channels[:-1])
        blocks = [
            CausalConvBlock(c_in, c_out, kernel_size, 2 ** level, dropout)
            for level, (c_in, c_out) in enumerate(zip(in_widths, num_channels))
        ]
        # Final 1x1 conv projects the last hidden width back to z_dim
        blocks.append(nn.Conv1d(num_channels[-1], z_dim, kernel_size=1))
        self.network = nn.Sequential(*blocks)

    def forward(self, z_seq):
        # Conv1d expects channels second: [B, T, z_dim] -> [B, z_dim, T] -> back
        return self.network(z_seq.transpose(1, 2)).transpose(1, 2)

# —— Freezing helper ——
def set_requires_grad(model, req):
    """Enable or disable gradient tracking for every parameter of ``model``.

    Args:
        model: an ``nn.Module``.
        req: True makes all parameters trainable; False freezes them.
    """
    for param in model.parameters():
        param.requires_grad_(req)

# —— Shared training loop for the TCN forecaster ——
def _tcn_next_step_loss(eis2vec_model, tcn_model, seqs, lengths, criterion,
                        encoder_grad):
    """Masked one-step-ahead latent loss for one padded batch (private helper).

    Args:
        eis2vec_model: model whose ``encoder(x)`` returns a pair; only the first
            element (the latent mean ``mu``) is used.
        tcn_model: maps a latent sequence [B, T, z_dim] to [B, T, z_dim].
        seqs: padded batch [B, T, C, D], already on the target device.
        lengths: [B] true sequence lengths on the same device.
        criterion: element-wise loss (``reduction='none'``).
        encoder_grad: whether to track gradients through the encoder.

    Returns:
        (loss, n_valid): ``loss`` is the mean over the ``n_valid`` real
        (non-padding) prediction steps, or None when ``n_valid == 0``.
    """
    B, T, C, D = seqs.shape

    # 1) Encode every curve of every sequence to its latent mean
    with torch.set_grad_enabled(encoder_grad):
        mu, _ = eis2vec_model.encoder(seqs.reshape(-1, C, D))
    z_seq = mu.view(B, T, -1)                          # [B, T, z_dim]

    # 2) Output at step t is used as the forecast of the latent at step t+1
    z_pred = tcn_model(z_seq)
    pred, target = z_pred[:, :-1], z_seq[:, 1:]        # [B, T-1, z_dim]

    # 3) Keep only steps whose target lies inside the real sequence
    mask = (torch.arange(T - 1, device=seqs.device)[None, :]
            < (lengths - 1)[:, None])                  # [B, T-1]
    n_valid = int(mask.sum().item())
    if n_valid == 0:
        # Would otherwise be 0/0 = NaN and poison the weights via backward()
        return None, 0
    loss_mat = criterion(pred, target).mean(dim=-1)    # [B, T-1]
    return (loss_mat * mask.float()).sum() / n_valid, n_valid


def train_tcn_forecasting(eis2vec_model,
                          tcn_model,
                          train_loader,
                          val_loader,
                          num_epochs=20,
                          lr=1e-3,
                          device='cuda',
                          freeze_eis2vec=False,
                          freeze_tcn=False):
    """Train ``tcn_model`` to forecast the next latent of each EIS sequence.

    Each batch from the loaders is ``(seqs, lengths)`` with ``seqs`` of shape
    [B, T, C, D] (zero-padded) and ``lengths`` the true lengths. The loss is
    the MSE between ``tcn_model``'s output at step t and the encoder latent at
    step t+1, averaged over non-padding steps only.

    Args:
        eis2vec_model: pretrained model exposing ``.encoder``. It is kept in
            eval mode throughout.
        tcn_model: forecaster, e.g. ``TCNForecaster``.
        train_loader / val_loader: iterables of ``(seqs, lengths)`` batches.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: device to move the models and batches to.
        freeze_eis2vec: if True, the encoder is frozen and excluded from Adam.
        freeze_tcn: if True, the TCN parameters are frozen.

    Returns:
        (train_losses, val_losses): per-epoch MSE over valid prediction steps.
        An epoch with no valid step (e.g. only length-1 sequences) reports nan.

    Side effects:
        Leaves ``requires_grad`` of the encoder and TCN set per the freeze flags.
    """
    eis2vec_model.to(device).eval()
    tcn_model.to(device)
    criterion = nn.MSELoss(reduction='none')

    # Freeze / unfreeze modules (same effect as set_requires_grad on each module)
    eis2vec_model.encoder.requires_grad_(not freeze_eis2vec)
    tcn_model.requires_grad_(not freeze_tcn)

    # The optimizer only manages the trainable parameters
    params = list(tcn_model.parameters())
    if not freeze_eis2vec:
        params += list(eis2vec_model.encoder.parameters())
    optimizer = optim.Adam(params, lr=lr)

    train_losses, val_losses = [], []

    for epoch in range(1, num_epochs + 1):
        # —— Training ——
        tcn_model.train()
        tot_loss, cnt = 0.0, 0
        for seqs, lengths in train_loader:
            seqs, lengths = seqs.to(device), lengths.to(device)
            loss, n_valid = _tcn_next_step_loss(
                eis2vec_model, tcn_model, seqs, lengths, criterion,
                encoder_grad=not freeze_eis2vec)
            if loss is None:
                continue  # no real prediction step in this batch
            # No graph exists when nothing is trainable (e.g. both modules frozen)
            if loss.requires_grad:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            tot_loss += loss.item() * n_valid
            cnt += n_valid
        train_losses.append(tot_loss / cnt if cnt else float('nan'))

        # —— Validation ——
        tcn_model.eval()
        tot_loss, cnt = 0.0, 0
        with torch.no_grad():
            for seqs, lengths in val_loader:
                seqs, lengths = seqs.to(device), lengths.to(device)
                loss, n_valid = _tcn_next_step_loss(
                    eis2vec_model, tcn_model, seqs, lengths, criterion,
                    encoder_grad=False)
                if loss is None:
                    continue
                tot_loss += loss.item() * n_valid
                cnt += n_valid
        val_losses.append(tot_loss / cnt if cnt else float('nan'))

        print(f"Epoch {epoch}/{num_epochs}  "
              f"Train MSE: {train_losses[-1]:.6f}  "
              f"Val MSE:   {val_losses[-1]:.6f}")

    return train_losses, val_losses

def _tcn_rollout(tcn_model, z_init, steps):
    """Autoregressively extend a latent prefix by ``steps`` forecasts (private helper).

    Args:
        tcn_model: maps [1, L, z_dim] -> [1, L, z_dim]; its last output step
            is taken as the next latent.
        z_init: seed latents [1, init_len, z_dim].
        steps: number of future latents to generate (>= 1).

    Returns:
        [steps, z_dim] tensor of predicted latents.
    """
    curr = z_init.clone()
    preds = []
    for _ in range(steps):
        z_next = tcn_model(curr)[:, -1, :]             # [1, z_dim]
        preds.append(z_next)
        curr = torch.cat([curr, z_next.unsqueeze(1)], dim=1)
    return torch.cat(preds, dim=0)


def _plot_tcn_predictions(vis_data, init_len, max_cols=5):
    """Draw one figure per sample with a real-vs-predicted subplot per step (private helper).

    Args:
        vis_data: list of dicts with 'real' / 'pred' arrays [pred_len, D]
            (real-part channel) and the sequence 'idx'.
        init_len: number of seed steps, used to label absolute step indices.
        max_cols: maximum number of subplots per row.
    """
    for sample in vis_data:
        real, pred = sample['real'], sample['pred']    # [pred_len, D]
        pred_len = real.shape[0]
        ncols = min(pred_len, max_cols)
        nrows = int(np.ceil(pred_len / max_cols))

        # squeeze=False always yields a 2-D array, even for a single subplot
        fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 3, nrows * 3),
                                 sharex=True, sharey=True, squeeze=False)
        axes = axes.ravel()

        for t in range(pred_len):
            ax = axes[t]
            ax.plot(real[t, :], label='Real', linestyle='-')
            ax.plot(pred[t, :], label='Pred', linestyle='--')
            ax.set_title(f"Step {init_len + t}")
        axes[0].legend(loc='upper right', fontsize='small')

        # Hide unused grid cells
        for ax in axes[pred_len:]:
            ax.axis('off')

        fig.suptitle(f"Seq idx {sample['idx']} Prediction vs True (Real part)")
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        plt.show()


def evaluate_tcn_and_plot(eis2vec_model, tcn_model, eval_list,
                          init_len=3, device='cuda',
                          num_samples=5):
    """Autoregressive evaluation of a latent forecaster in curve space.

    Each sequence is encoded. The first ``init_len`` latents seed an
    autoregressive rollout of the remaining ``T - init_len`` steps. The
    predicted latents are decoded and compared to the true curves with MSE
    (all channels). A random subset of sequences is plotted (real-part
    channel only).

    Args:
        eis2vec_model: trained model exposing ``encoder`` (returns a pair whose
            first element is the latent mean) and ``decoder``. The description
            names Curve2VecVAE_Ver01.
        tcn_model: trained forecaster, e.g. ``TCNForecaster``.
        eval_list: indexable collection of tensors, each [T, C, D].
        init_len: number of true steps used to seed the rollout (>= 1).
        device: device to run on.
        num_samples: how many random sequences to visualise.

    Returns:
        Average per-step MSE over all forecast steps, or nan when no sequence
        is longer than ``init_len``. Sequences with ``T <= init_len`` are
        skipped (nothing to forecast).

    Raises:
        ValueError: if ``init_len < 1``.
    """
    if init_len < 1:
        raise ValueError(f"init_len must be >= 1, got {init_len}")

    eis2vec_model.to(device).eval()
    tcn_model.to(device).eval()
    mse = nn.MSELoss()

    total_loss = 0.0
    total_count = 0

    # Randomly pick a few sequences to visualise
    vis_idxs = set(random.sample(range(len(eval_list)),
                                 min(num_samples, len(eval_list))))
    vis_data = []

    with torch.no_grad():
        for idx, seq in enumerate(eval_list):
            seq = seq.to(device)                        # [T, C, D]
            T, C, D = seq.shape
            pred_len = T - init_len
            if pred_len <= 0:
                continue  # too short: nothing left to forecast after the seed

            # 1) Encode the whole sequence to latents
            mu, _ = eis2vec_model.encoder(seq)          # [T, z_dim]
            z_seq = mu.view(1, T, -1)                   # [1, T, z_dim]

            # 2) Forecast the remaining steps from the first init_len latents
            z_pred = _tcn_rollout(tcn_model, z_seq[:, :init_len, :], pred_len)

            # 3) Decode latents back to curves
            x_pred = eis2vec_model.decoder(z_pred).view(pred_len, C, D)

            # 4) MSE over all channels, weighted by number of forecast steps
            real = seq[init_len:]                       # [pred_len, C, D]
            total_loss += mse(x_pred, real).item() * pred_len
            total_count += pred_len

            # 5) Keep the real-part channel for plotting
            if idx in vis_idxs:
                vis_data.append({
                    'real': real[:, 0, :].cpu().numpy(),    # [pred_len, D]
                    'pred': x_pred[:, 0, :].cpu().numpy(),  # [pred_len, D]
                    'idx' : idx
                })

    avg_mse = total_loss / total_count if total_count else float('nan')
    print(f"Eval set average MSE: {avg_mse:.6f}")

    # —— Visualisation ——
    _plot_tcn_predictions(vis_data, init_len)

    return avg_mse


### Running

In [150]:


# ===== TCN setup =====
# Relies on earlier cells for: train_list / val_list (lists of sequence tensors,
# presumably [T_i, C, D] each), EISDataset_SEQ, collate_seq, and the pretrained
# model_ver01 (Curve2VecVAE_Ver01).

# Sequence datasets for the train / validation splits
train_ds_tcn = EISDataset_SEQ(train_list)
val_ds_tcn = EISDataset_SEQ(val_list)

# Batches of 16; only the training loader is shuffled
train_loader_tcn = DataLoader(
    train_ds_tcn,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_seq,
)
val_loader_tcn = DataLoader(
    val_ds_tcn,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_seq,
)

# Three causal conv blocks of width 32 (dilations 1, 2, 4) over 16-dim latents.
# z_dim presumably has to match model_ver01's latent size — TODO confirm.
z_dim = 16
tcn_model = TCNForecaster(
    z_dim=z_dim,
    num_channels=[32] * 3,
    kernel_size=3,
    dropout=0.2,
)


In [151]:

# Train the TCN on frozen EIS2Vec latents. `device` is the one detected in the
# setup cell (CUDA when available, else CPU) rather than a hard-coded 'cuda'.
train_losses, val_losses = train_tcn_forecasting(
    eis2vec_model=model_ver01,
    tcn_model=tcn_model,
    train_loader=train_loader_tcn,
    val_loader=val_loader_tcn,
    num_epochs=100,
    lr=1e-3,
    device=device,
    freeze_eis2vec=True,  # freeze encoder
    freeze_tcn=False      # train TCN
)

# train_losses and val_losses now hold per-epoch latent-space MSE values for plotting


Epoch 1/100  Train MSE: 0.157689  Val MSE:   0.067138
Epoch 2/100  Train MSE: 0.099325  Val MSE:   0.056973
Epoch 3/100  Train MSE: 0.089993  Val MSE:   0.053014
Epoch 4/100  Train MSE: 0.085545  Val MSE:   0.052763
Epoch 5/100  Train MSE: 0.082646  Val MSE:   0.051614
Epoch 6/100  Train MSE: 0.080332  Val MSE:   0.050302
Epoch 7/100  Train MSE: 0.079096  Val MSE:   0.049890
Epoch 8/100  Train MSE: 0.078180  Val MSE:   0.049420
Epoch 9/100  Train MSE: 0.077682  Val MSE:   0.049870
Epoch 10/100  Train MSE: 0.077307  Val MSE:   0.047972
Epoch 11/100  Train MSE: 0.077051  Val MSE:   0.048089
Epoch 12/100  Train MSE: 0.076701  Val MSE:   0.049274
Epoch 13/100  Train MSE: 0.076719  Val MSE:   0.048564
Epoch 14/100  Train MSE: 0.076591  Val MSE:   0.048796
Epoch 15/100  Train MSE: 0.076276  Val MSE:   0.048145
Epoch 16/100  Train MSE: 0.076296  Val MSE:   0.047130
Epoch 17/100  Train MSE: 0.076220  Val MSE:   0.048154
Epoch 18/100  Train MSE: 0.076174  Val MSE:   0.048750
Epoch 19/100  Train

### Plot

In [160]:
# Autoregressive evaluation on the validation split: seed with the first 3 steps,
# forecast the rest, decode, and plot 5 random sequences. Note this MSE is in
# decoded-curve space, whereas the training loop reports latent-space MSE.
val_seq_ds = EISDataset_SEQ(val_list)
avg_mse = evaluate_tcn_and_plot(
    eis2vec_model=model_ver01,
    tcn_model=tcn_model,
    eval_list=val_seq_ds,
    init_len=3,
    device=device,
    num_samples=5
)

Eval set average MSE: 0.430405
