# Import

In [2]:
import os
import re
import gc
import sys

from loguru import logger
import numpy as np
import random

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

%matplotlib qt

# Detect device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input Layer

In [3]:
def SearchELE(rootPath, ele_pattern = re.compile(r"(.+?)_归档")):
    '''==================================================
        Search all electrode directories in the rootPath
        Parameter: 
            rootPath: current search path
            ele_pattern: electrode dir name patten
        Returen:
            ele_list: list of electrode directories
        ==================================================
    '''
    ele_list = []
    for i in os.listdir(rootPath):
        match_ele = ele_pattern.match(i)
        if match_ele:
            ele_list.append([os.path.join(rootPath, i),match_ele.group(1)])
    return ele_list

In [4]:
rootPath = "D:/Baihm/EISNN/Archive/"
ele_list = SearchELE(rootPath)
n_ele = len(ele_list)
logger.info(f"Search in {rootPath} and find {n_ele:03d} electrodes")

MODEL_SUFFIX = "Matern12_Ver01"

all_data_list = []
all_id_list = []
_ch_pattern = re.compile(r"ch_(\d{3})")

for i in range(n_ele):
# for i in range(3):
    fd_pt = os.path.join(ele_list[i][0], MODEL_SUFFIX, f"{ele_list[i][1]}_{MODEL_SUFFIX}.pt")
    if not os.path.exists(fd_pt):
        # logger.warning(f"{fd_pt} does not exist")
        continue
    data_pt = torch.load(fd_pt, weights_only=False)
    _meta_group = data_pt["meta_group"]
    _data_group = data_pt["data_group"]

    n_day       = _meta_group["n_day"]
    n_ch        = _meta_group["n_ch"]
    n_valid_ch  = len(_data_group["Channels"])

    # ignore abnormal ele
    if n_ch != 128 or n_valid_ch != n_ch:
        if n_day < 5 or n_valid_ch <= 100:
            continue

    logger.info(f"ELE [{i}/{n_ele}]: {ele_list[i][0]}")


    # Iteration by channel
    for j in _data_group['Channels']:
        _ch_data = _data_group[j]["y_eval"]
        # _ch_data_log = np.log(_ch_data[:,:,0] + 1j*_ch_data[:,:,1])
        # _ch_data[:,:,0] = np.real(_ch_data_log)
        # _ch_data[:,:,1] = np.imag(_ch_data_log)
        all_data_list.append(_ch_data)

        _ch_id = _ch_pattern.match(j)
        _ch_id = int(_ch_id.group(1))

        _id = [i, _ch_id] * np.shape(_ch_data)[0]
        _id = np.array(_id).reshape(-1,2)
        all_id_list.append(_id)

# all_data_list = np.vstack(all_data_list)
# all_id_list = np.vstack(all_id_list)


del data_pt, _meta_group, _data_group, _ch_data
gc.collect()



[32m2025-04-11 11:54:34.504[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mSearch in D:/Baihm/EISNN/Archive/ and find 218 electrodes[0m
[32m2025-04-11 11:54:34.557[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m31[0m - [1mELE [0/218]: D:/Baihm/EISNN/Archive/01037160_归档[0m
[32m2025-04-11 11:54:34.594[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m31[0m - [1mELE [1/218]: D:/Baihm/EISNN/Archive/01037161_归档[0m
[32m2025-04-11 11:54:34.646[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m31[0m - [1mELE [2/218]: D:/Baihm/EISNN/Archive/01037162_归档[0m
[32m2025-04-11 11:54:34.689[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m31[0m - [1mELE [3/218]: D:/Baihm/EISNN/Archive/01067093_归档[0m
[32m2025-04-11 11:54:34.747[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m31[0m - [1mELE [4/218]: D:/Baihm/EISNN/Archive/01067094_归档[0m
[32m2025-04-11 11:54:34.788[0

337

In [5]:
_sample_sum = 0
for i in range(len(all_data_list)):
    # print(np.shape(all_data_list[i]))
    _sample_sum = _sample_sum + np.shape(all_data_list[i])[0]
_sample_sum // 2

193431

# BiLSTM Seq2Seq

## Model

In [None]:
'''
Implementations of:
1. BiLSTM Seq2Seq model for spectral sequence prediction
2. Conditional Diffusion model for spectral sequence generation

Usage:
- Data: list of numpy arrays or torch tensors, each shape [m_i, 101, 2]
- Save this file and import classes/functions as needed.
'''

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# Detect device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===== Dataset and Collate Function =====
class SpectrumDataset(Dataset):
    def __init__(self, data_list):
        # data_list: list of arrays or tensors shape [m_i, 101, 2]
        self.data = [torch.tensor(x, dtype=torch.float32) for x in data_list]  # [m_i, 101, 2]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def collate_fn(batch):
    # batch: list of [m_i, 101, 2] tensors
    lengths = torch.tensor([seq.size(0) for seq in batch], dtype=torch.long)
    padded = pad_sequence(batch, batch_first=True)  # [batch, max_len, 101, 2]
    return padded.to(device), lengths.to(device)

# ===== BiLSTM Seq2Seq Model =====
class Encoder(nn.Module):
    def __init__(self, input_dim=202, hidden_dim=128, num_layers=1, bidirectional=True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_directions = 2 if bidirectional else 1
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, bidirectional=bidirectional)

    def forward(self, x, lengths):
        # x: [batch, seq_len, input_dim]
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, (h, c) = self.lstm(packed)
        # h: [num_layers * num_directions, batch, hidden_dim]
        h = h.view(self.num_layers, self.num_directions, x.size(0), self.hidden_dim)
        c = c.view(self.num_layers, self.num_directions, x.size(0), self.hidden_dim)
        h = torch.cat([h[-1, i] for i in range(self.num_directions)], dim=1)  # [batch, hidden_dim*num_directions]
        c = torch.cat([c[-1, i] for i in range(self.num_directions)], dim=1)
        return (h.unsqueeze(0), c.unsqueeze(0))

class Decoder(nn.Module):
    def __init__(self, input_dim=202, hidden_dim=128, output_dim=202, num_layers=1):
        super().__init__()
        self.hidden_dim = hidden_dim * 2
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, self.hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(self.hidden_dim, output_dim)

    def forward(self, input_step, hidden):
        out, hidden = self.lstm(input_step, hidden)
        out = self.fc(out.squeeze(1))  # [batch, output_dim]
        # out = self.fc(out)  # [batch, seq_len, output_dim]
        return out, hidden



class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device, teacher_forcing_ratio=0.5):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.teacher_forcing_ratio = teacher_forcing_ratio

    def forward(self, src, src_lengths, trg=None, trg_len=None):
        batch_size = src.size(0)
        input_dim = src.size(2) * src.size(3)
        src = src.view(batch_size, src.size(1), -1)
        if trg is not None:
            trg = trg.view(batch_size, trg.size(1), -1)
        encoder_hidden = self.encoder(src, src_lengths)
        idx = (src_lengths - 1).view(-1, 1, 1).expand(-1, 1, input_dim)
        decoder_input = src.gather(1, idx).squeeze(1).unsqueeze(1)
        hidden = encoder_hidden
        outputs = []
        max_trg_len = trg_len if trg is None else trg.size(1)
        for t in range(max_trg_len):
            out, hidden = self.decoder(decoder_input, hidden)
            outputs.append(out.unsqueeze(1))
            if trg is not None and torch.rand(1).item() < self.teacher_forcing_ratio:
                decoder_input = trg[:, t].unsqueeze(1)
            else:
                decoder_input = out.unsqueeze(1)
        return torch.cat(outputs, dim=1)

# ===== Training BiLSTM =====
def train_bilstm(data_list, num_epochs=10, batch_size=32, lr=1e-3):
    dataset = SpectrumDataset(data_list)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    encoder = Encoder().to(device)
    decoder = Decoder().to(device)
    model = Seq2Seq(encoder, decoder, device).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(1, num_epochs+1):
        model.train()
        epoch_loss = 0
        for src, lengths in dataloader:
            trg = src  # demo: target same as input
            output = model(src, lengths, trg=trg)
            loss = criterion(output, trg.view(trg.size(0), trg.size(1), -1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"[BiLSTM] Epoch {epoch}, Loss: {epoch_loss/len(dataloader):.4f}")
    return model


# ========= 保存模型函数 =========
def save_model(model, path="bilstm_model.pt"):
    torch.save(model.state_dict(), path)

# ========= 测试与评估函数 =========
def evaluate_bilstm(model, test_data, batch_size=16, plot=True, num_samples_to_plot=3):
    test_dataset = SpectrumDataset(test_data)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    criterion = nn.MSELoss()
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch_idx, (src, lengths) in enumerate(test_loader):
            trg = src  # 目标仍然是原始输入
            output = model(src, lengths, trg_len=trg.size(1))
            loss = criterion(output, trg.view(trg.size(0), trg.size(1), -1))
            total_loss += loss.item()

            print(f"{trg.size(1)}")

            # 可视化：只绘制前 num_samples_to_plot 个样本
            if plot and batch_idx == 0:
                for i in range(min(num_samples_to_plot, src.size(0))):
                    true_seq = trg[i].cpu().view(trg.size(1), 101, 2)
                    pred_seq = output[i].cpu().view(trg.size(1), 101, 2)

                    fig, axis = plt.subplots(1,2,figsize = (8,4))
                    cmap = plt.colormaps.get_cmap('rainbow_r')
                    for j in range(np.shape(true_seq)[0]):
                        axis[0].plot(true_seq[j, :, 0], color=cmap(j / np.shape(true_seq)[0]))
                        axis[1].plot(true_seq[j, :, 1], color=cmap(j / np.shape(true_seq)[0]))
                    for j in range(np.shape(pred_seq)[0]):
                        axis[0].plot(pred_seq[j, :, 0], color=cmap(j / np.shape(pred_seq)[0]))
                        axis[1].plot(pred_seq[j, :, 1], color=cmap(j / np.shape(pred_seq)[0]))
                    # fig.show()

                    # # 示例：第50个频点的实部
                    # freq_idx = 50
                    # plt.figure(figsize=(10, 4))
                    # plt.plot(true_seq[:, freq_idx, 0], label='True Real', color='blue')
                    # plt.plot(pred_seq[:, freq_idx, 0], label='Pred Real', color='red', linestyle='--')
                    # plt.title(f"Sample {i} - Frequency {freq_idx} Real Part")
                    # plt.xlabel("Time Step")
                    # plt.ylabel("Value")
                    # plt.legend()
                    # plt.grid(True)
                    # plt.tight_layout()
                    # plt.show()

    avg_loss = total_loss / len(test_loader)
    print(f"[BiLSTM Eval] Test Loss (MSE): {avg_loss:.4f}")
    return avg_loss

# ========= 示例运行流程 =========
def run_bilstm_pipeline(all_data_list, num_epochs=10):
    # 数据划分
    train_list, test_list = train_test_split(all_data_list, test_size=0.2, random_state=42)
    
    # 训练模型
    model = train_bilstm(train_list, num_epochs=num_epochs)

    # 测试与评估
    test_loss = evaluate_bilstm(model, test_list, plot=True)

    # 保存模型（可选）
    # save_model(model, "bilstm_spectrum.pt")

    return model


## Training

## Pipeline Dessemble

In [148]:

train_list, test_list = train_test_split(all_data_list, test_size=0.2, random_state=42)

In [149]:
num_epochs=1
batch_size=32
lr=1e-3

In [52]:
dataset = SpectrumDataset(train_list)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)


In [None]:
encoder = Encoder().to(device)
decoder = Decoder().to(device)

torch.Size([11, 101, 2])

In [None]:

model = Seq2Seq(encoder, decoder, device).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()


In [None]:

# for epoch in range(1, num_epochs+1):
#     model.train()
#     epoch_loss = 0
#     for src, lengths in dataloader:
#         trg = src  # demo: target same as input
#         output = model(src, lengths, trg=trg)
#         loss = criterion(output, trg.view(trg.size(0), trg.size(1), -1))
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         epoch_loss += loss.item()
#     print(f"[BiLSTM] Epoch {epoch}, Loss: {epoch_loss/len(dataloader):.4f}")
# return model

In [150]:
num_epochs = 50

model_inst = run_bilstm_pipeline(all_data_list, num_epochs=num_epochs)

[BiLSTM] Epoch 1, Loss: 2.6931
[BiLSTM] Epoch 2, Loss: 0.1850
[BiLSTM] Epoch 3, Loss: 0.1399
[BiLSTM] Epoch 4, Loss: 0.1269
[BiLSTM] Epoch 5, Loss: 0.0986
[BiLSTM] Epoch 6, Loss: 0.0897
[BiLSTM] Epoch 7, Loss: 0.0832
[BiLSTM] Epoch 8, Loss: 0.0799
[BiLSTM] Epoch 9, Loss: 0.0762
[BiLSTM] Epoch 10, Loss: 0.0720
[BiLSTM] Epoch 11, Loss: 0.0773
[BiLSTM] Epoch 12, Loss: 0.0656
[BiLSTM] Epoch 13, Loss: 0.0680
[BiLSTM] Epoch 14, Loss: 0.0867
[BiLSTM] Epoch 15, Loss: 0.0690
[BiLSTM] Epoch 16, Loss: 0.0592
[BiLSTM] Epoch 17, Loss: 0.0617
[BiLSTM] Epoch 18, Loss: 0.0649
[BiLSTM] Epoch 19, Loss: 0.0657
[BiLSTM] Epoch 20, Loss: 0.0605
[BiLSTM] Epoch 21, Loss: 0.0586
[BiLSTM] Epoch 22, Loss: 0.0585
[BiLSTM] Epoch 23, Loss: 0.0576
[BiLSTM] Epoch 24, Loss: 0.0588
[BiLSTM] Epoch 25, Loss: 0.0587
[BiLSTM] Epoch 26, Loss: 0.0597
[BiLSTM] Epoch 27, Loss: 0.0548
[BiLSTM] Epoch 28, Loss: 0.0581
[BiLSTM] Epoch 29, Loss: 0.0517
[BiLSTM] Epoch 30, Loss: 0.0529
[BiLSTM] Epoch 31, Loss: 0.0502
[BiLSTM] Epoch 32

In [None]:
train_list, test_list = train_test_split(all_data_list, test_size=0.2, random_state=42)

test_loss = evaluate_bilstm(model, test_list, plot=True)

  

# BiLSTM + CNN

In [80]:
# ===== 数据集与 CollateFn =====
class SpectrumDataset(Dataset):
    def __init__(self, data_list):
        # data_list: list of arrays or tensors shape [m_i, 101, 2]
        self.data = [torch.tensor(x, dtype=torch.float32) for x in data_list]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def collate_fn(batch):
    # batch: list of [m_i,101,2]
    lengths = torch.tensor([seq.size(0) for seq in batch], dtype=torch.long)
    padded = pad_sequence(batch, batch_first=True)  # [B, L_max, 101, 2]
    return padded.to(device), lengths.to(device)


# ===== 频谱卷积子网 =====
class SpecCNNEncoder(nn.Module):
    def __init__(self, in_ch=2, feat_ch=64, k=5):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, feat_ch, kernel_size=k, padding=k//2)
        self.bn   = nn.BatchNorm1d(feat_ch)

    def forward(self, x):
        # x: [B, 101, 2] → [B, 2, 101]
        x = x.permute(0, 2, 1)
        x = F.relu(self.bn(self.conv(x)))   # [B, feat_ch, 101]
        x = x.mean(-1)                      # [B, feat_ch]
        return x


class SpecCNNDecoder(nn.Module):
    def __init__(self, in_hidden_dim, feat_ch=64, out_ch=2, k=5):
        super().__init__()
        self.feat_ch = feat_ch
        self.fc     = nn.Linear(in_hidden_dim, feat_ch * 101)
        self.deconv = nn.Conv1d(feat_ch, out_ch, kernel_size=k, padding=k//2)

    def forward(self, h):
        # h: [B, hidden_dim]
        x = self.fc(h)                    # [B, feat_ch*101]
        x = x.view(-1, self.feat_ch, 101) # [B, feat_ch, 101]
        x = self.deconv(x)                # [B, out_ch, 101]
        return x.permute(0, 2, 1)         # [B, 101, out_ch]


# ===== BiLSTM Encoder/Decoder =====
class Encoder(nn.Module):
    def __init__(self, in_ch=2, feat_ch=64, hidden_dim=128,
                 num_layers=1, bidirectional=True, k=5):
        super().__init__()
        self.spec_encoder = SpecCNNEncoder(in_ch, feat_ch, k)
        self.hidden_dim   = hidden_dim
        self.num_layers   = num_layers
        self.num_directions = 2 if bidirectional else 1

        self.lstm = nn.LSTM(
            input_size=feat_ch,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional
        )

    def forward(self, x, lengths):
        """
        x: [B, seq_len, 101, 2]
        lengths: [B]
        return:
          hidden: (h, c) each [num_layers, B, hidden_dim * num_directions]
          feat_seq: [B, seq_len, feat_ch]
        """
        B, L, F, C = x.shape
        x_flat = x.view(B * L, F, C)               # [B*L,101,2]
        feat_flat = self.spec_encoder(x_flat)      # [B*L, feat_ch]
        feat_seq  = feat_flat.view(B, L, -1)       # [B, L, feat_ch]

        packed = pack_padded_sequence(
            feat_seq, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, (h, c) = self.lstm(packed)
        # h: [num_layers * num_directions, B, hidden_dim]
        # 重塑并拼接双向
        h = h.view(self.num_layers, self.num_directions, B, self.hidden_dim)
        c = c.view(self.num_layers, self.num_directions, B, self.hidden_dim)
        h = torch.cat([h[-1, i] for i in range(self.num_directions)], dim=1)  # [B, hidden_dim*2]
        c = torch.cat([c[-1, i] for i in range(self.num_directions)], dim=1)

        # 为 Decoder 返回 feat_seq 以做 teacher forcing
        return (h.unsqueeze(0), c.unsqueeze(0)), feat_seq


class Decoder(nn.Module):
    def __init__(self, feat_ch=64, hidden_dim=128,
                 num_layers=1, bidirectional=True, k=5):
        super().__init__()
        self.hidden_dim = hidden_dim * (2 if bidirectional else 1)
        self.lstm = nn.LSTM(
            input_size=feat_ch,
            hidden_size=self.hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        self.spec_decoder = SpecCNNDecoder(
            in_hidden_dim=self.hidden_dim,
            feat_ch=feat_ch,
            out_ch=2,
            k=k
        )

    def forward(self, input_feat, hidden):
        """
        input_feat: [B, 1, feat_ch]
        hidden: (h, c)
        return:
          spec: [B, 101, 2]
          hidden: new hidden state
        """
        out, hidden = self.lstm(input_feat, hidden)  # out: [B,1,hidden_dim]
        h_step = out.squeeze(1)                      # [B, hidden_dim]
        spec   = self.spec_decoder(h_step)           # [B,101,2]
        return spec, hidden


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device, teacher_forcing_ratio=0.5):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device  = device
        self.teacher_forcing_ratio = teacher_forcing_ratio

    def forward(self, src, src_lengths, trg=None, trg_len=None):
        """
        src: [B, seq,101,2]
        trg: [B, seq,101,2] or None
        """
        B, L, F, C = src.shape
        max_len = trg_len if trg is None else trg.size(1)

        # Encoder
        (h, c), feat_seq = self.encoder(src, src_lengths)
        hidden = (h, c)

        # 初始 decoder 输入：每条序列最后一个真实帧的特征
        idx = (src_lengths - 1).unsqueeze(1).expand(-1, feat_seq.size(2))  # [B, feat_ch]
        idx = idx.long()
        decoder_input = feat_seq.gather(1, idx.unsqueeze(1)).float()      # [B,1,feat_ch]

        outputs = []
        for t in range(max_len):
            pred_spec, hidden = self.decoder(decoder_input, hidden)
            outputs.append(pred_spec.unsqueeze(1))  # [B,1,101,2]

            # 下步输入：teacher forcing 或 用当前输出
            if trg is not None and random.random() < self.teacher_forcing_ratio:
                next_frame = trg[:, t, :, :]  # [B,101,2]
            else:
                next_frame = pred_spec         # [B,101,2]

            # 计算下步特征
            feat_next = self.encoder.spec_encoder(next_frame)  # [B, feat_ch]
            decoder_input = feat_next.unsqueeze(1)             # [B,1,feat_ch]

        return torch.cat(outputs, dim=1)  # [B, seq,101,2]


# ===== 平滑正则 Loss =====
def spec_smooth_loss(y_pred, beta=1e-3):
    # 在频谱维度上鼓励相邻频点平滑
    diff = y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :]
    return beta * (diff ** 2).mean()


# ===== 训练函数 =====
def train_bilstm(data_list, num_epochs=10, batch_size=32, lr=1e-3):
    dataset    = SpectrumDataset(data_list)
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            shuffle=True, collate_fn=collate_fn)

    encoder = Encoder().to(device)
    decoder = Decoder().to(device)
    model   = Seq2Seq(encoder, decoder, device).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0
        for src, lengths in dataloader:
            trg = src  # demo: 预测自己
            output = model(src, lengths, trg=trg)

            mse = criterion(output.view(output.size(0), -1),
                            trg.view(trg.size(0), -1))
            smooth = spec_smooth_loss(output)
            loss = mse + smooth

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"[BiLSTM] Epoch {epoch}/{num_epochs}  Loss: {total_loss/len(dataloader):.4f}")

    return model


# ===== 测试与评估 =====
def evaluate_bilstm(model, test_data, batch_size=16,
                    plot=True, num_samples_to_plot=3):
    test_dataset = SpectrumDataset(test_data)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size,
                              shuffle=False, collate_fn=collate_fn)
    criterion = nn.MSELoss()
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch_idx, (src, lengths) in enumerate(test_loader):
            trg    = src
            output = model(src, lengths, trg_len=trg.size(1))

            loss = criterion(output.view(output.size(0), -1),
                             trg.view(trg.size(0), -1))
            total_loss += loss.item()

            # 可视化前几条样本的实/虚部随频点变化
            if plot and batch_idx == 0:
                for i in range(min(num_samples_to_plot, src.size(0))):
                    true_seq = trg[i].cpu().numpy()   # [L,101,2]
                    pred_seq = output[i].cpu().numpy()

                    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
                    cmap = plt.get_cmap('rainbow_r')
                    for t in range(true_seq.shape[0]):
                        axes[0].plot(true_seq[t, :, 0], color=cmap(t / true_seq.shape[0]))
                        axes[1].plot(true_seq[t, :, 1], color=cmap(t / true_seq.shape[0]))
                    for t in range(pred_seq.shape[0]):
                        axes[0].plot(pred_seq[t, :, 0], '--', color=cmap(t / pred_seq.shape[0]))
                        axes[1].plot(pred_seq[t, :, 1], '--', color=cmap(t / pred_seq.shape[0]))
                    axes[0].set_title(f"Sample {i} Real Part")
                    axes[1].set_title(f"Sample {i} Imag Part")
                    plt.tight_layout()
                    plt.show()

    avg_loss = total_loss / len(test_loader)
    print(f"[BiLSTM Eval] Test MSE: {avg_loss:.4f}")
    return avg_loss


# ===== 保存模型 =====
def save_model(model, path="bilstm_spectrum.pt"):
    torch.save(model.state_dict(), path)


# ===== 一键运行流水线 =====
def run_bilstm_pipeline(all_data_list, num_epochs=10):
    train_list, test_list = train_test_split(all_data_list,
                                             test_size=0.2,
                                             random_state=42)
    model = train_bilstm(train_list, num_epochs=num_epochs)
    _ = evaluate_bilstm(model, test_list, plot=True)
    # save_model(model, "bilstm_spectrum.pt")
    return model


# 如果作为脚本运行，可在此处添加：
# if __name__ == "__main__":
#     import your_data_loader
#     data = your_data_loader.load()
#     run_bilstm_pipeline(data, num_epochs=20)


In [82]:
num_epochs = 10

model_inst = run_bilstm_pipeline(all_data_list, num_epochs=num_epochs)

[BiLSTM] Epoch 1/10  Loss: 0.0557
[BiLSTM] Epoch 2/10  Loss: 0.0023
[BiLSTM] Epoch 3/10  Loss: 0.0020
[BiLSTM] Epoch 4/10  Loss: 0.0017
[BiLSTM] Epoch 5/10  Loss: 0.0018
[BiLSTM] Epoch 6/10  Loss: 0.0015
[BiLSTM] Epoch 7/10  Loss: 0.0014
[BiLSTM] Epoch 8/10  Loss: 0.0012
[BiLSTM] Epoch 9/10  Loss: 0.0012
[BiLSTM] Epoch 10/10  Loss: 0.0010
[BiLSTM Eval] Test MSE: 0.0141


# Curve‑to‑Vec Autoencoderv

## Model

In [169]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# 设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===== 数据集 =====
class CurveDataset(Dataset):
    def __init__(self, data_list):
        # data_list: list of [101,2] numpy or torch.Tensor
        self.data = [torch.tensor(x, dtype=torch.float32) for x in data_list]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # 返回 [2,101] 方便 Conv1d (in_ch=2)
        return self.data[idx].permute(1,0)  # [2,101]


# ===== Autoencoder =====
class CurveAutoencoder(nn.Module):
    def __init__(self, in_ch=2, hid_ch=32, z_dim=64, k=5):
        super().__init__()
        # Encoder: Conv1d -> ReLU -> Pool -> Flatten -> FC -> z
        self.conv1 = nn.Conv1d(in_ch, hid_ch, kernel_size=k, padding=k//2)
        self.bn1   = nn.BatchNorm1d(hid_ch)
        self.pool  = nn.AdaptiveAvgPool1d(1)  # 全局池化到长度1
        self.fc_enc= nn.Linear(hid_ch, z_dim)

        # Decoder: FC -> Unpool模拟 -> ConvTranspose1d -> ReLU
        self.fc_dec= nn.Linear(z_dim, hid_ch * 101)
        self.deconv= nn.Conv1d(hid_ch, in_ch, kernel_size=k, padding=k//2)

    def encode(self, x):
        # x: [B,2,101]
        h = F.relu(self.bn1(self.conv1(x)))  # [B,hid_ch,101]
        h = self.pool(h).squeeze(-1)         # [B,hid_ch]
        z = self.fc_enc(h)                   # [B,z_dim]
        return z

    def decode(self, z):
        # z: [B,z_dim]
        h = self.fc_dec(z)                   # [B,hid_ch*101]
        h = h.view(-1, self.conv1.out_channels, 101)  # [B,hid_ch,101]
        x_rec = self.deconv(h)               # [B,2,101]
        return x_rec

    def forward(self, x):
        z = self.encode(x)
        x_rec = self.decode(z)
        return x_rec, z


# ===== VAE =====
class CurveVAE(nn.Module):
    def __init__(self, in_ch=2, hid_ch=32, z_dim=64, k=5):
        super().__init__()
        # 与上面类似，但 encode 出 mu/logvar
        self.conv1 = nn.Conv1d(in_ch, hid_ch, kernel_size=k, padding=k//2)
        # self.bn1   = nn.BatchNorm1d(hid_ch)
        self.ln1   = nn.LayerNorm(hid_ch)
        self.pool  = nn.AdaptiveAvgPool1d(1)
        self.fc_mu = nn.Linear(hid_ch, z_dim)
        self.fc_lv = nn.Linear(hid_ch, z_dim)

        self.fc_dec= nn.Linear(z_dim, hid_ch * 101)
        self.deconv= nn.Conv1d(hid_ch, in_ch, kernel_size=k, padding=k//2)

    def encode(self, x):
        # h = F.relu(self.bn1(self.conv1(x)))  # [B,hid_ch,101]
        # h = self.pool(h).squeeze(-1)         # [B,hid_ch]

        h = self.conv1(x)                  # [B,hid_ch,101]
        h = h.permute(0, 2, 1)             # [B,101,hid_ch]
        h = self.ln1(h)                    # [B,101,hid_ch]
        h = F.relu(h)                  # [B,101,hid_ch]
        h = h.permute(0, 2, 1)             # [B,hid_ch,101]

        h = self.pool(h).squeeze(-1)         # [B,hid_ch]

        return self.fc_mu(h), self.fc_lv(h)  # mu, logvar

    def reparam(self, mu, logvar):
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.fc_dec(z)
        h = h.view(-1, self.conv1.out_channels, 101)
        return self.deconv(h)

    def forward(self, x):
        mu, lv = self.encode(x)
        z = self.reparam(mu, lv)
        x_rec = self.decode(z)
        return x_rec, mu, lv


# ===== Losses =====
def recon_loss(x_rec, x):
    # x_rec, x: [B,2,101]
    return F.mse_loss(x_rec, x)

def vae_loss(x_rec, x, mu, logvar, kld_weight=1e-3):
    rec = recon_loss(x_rec, x)
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return rec + kld_weight * kld


# ===== 训练函数 =====
def train_curve2vec(model, data_list, num_epochs=20, batch_size=64, lr=1e-3, is_vae=False):
    train_list, val_list = train_test_split(data_list, test_size=0.2, random_state=42)
    train_ds = CurveDataset(train_list)
    val_ds   = CurveDataset(val_list)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_loss_recorder = []
    eval_loss_recorder = []

    for epoch in range(1, num_epochs+1):
        model.train()
        train_loss = 0
        for x in train_loader:
            x = x.to(device)
            if is_vae:
                x_rec, mu, lv = model(x)
                loss = vae_loss(x_rec, x, mu, lv)
            else:
                x_rec, _ = model(x)
                loss = recon_loss(x_rec, x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * x.size(0)

        # 验证
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x in val_loader:
                x = x.to(device)
                if is_vae:
                    # x_rec, mu, lv = model(x)
                    # loss = vae_loss(x_rec, x, mu, lv)
                    mu, lv = model.encode(x)
                    x_rec = model.decode(mu)
                    loss = vae_loss(x_rec, x, mu, lv)
                else:
                    x_rec, _ = model(x)
                    loss = recon_loss(x_rec, x)
                val_loss += loss.item() * x.size(0)
        train_loss /= len(train_ds)
        val_loss   /= len(val_ds)

        train_loss_recorder.append(train_loss)
        eval_loss_recorder.append(val_loss)
        print(f"Epoch {epoch}/{num_epochs}  Train: {train_loss:.4f}  Val: {val_loss:.4f}")

    return model, train_loss_recorder, eval_loss_recorder


# ===== 可视化重建 =====
def visualize_recon(model, data_list, num=5, is_vae=False):
    ds = CurveDataset(data_list)
    loader = DataLoader(ds, batch_size=num, shuffle=True)
    x = next(iter(loader)).to(device)   # [num,2,101]
    model.eval()
    with torch.no_grad():
        if is_vae:
            x_rec, mu, lv = model(x)
        else:
            x_rec, _ = model(x)

    x = x.cpu().numpy()
    x_rec = x_rec.cpu().numpy()

    for i in range(num):
        plt.figure(figsize=(6,3))
        # 实部
        plt.subplot(1,2,1)
        plt.plot(x[i,0], label="orig")
        plt.plot(x_rec[i,0], '--', label="rec")
        plt.title(f"Sample {i} Real")
        plt.legend()
        # 虚部
        plt.subplot(1,2,2)
        plt.plot(x[i,1], label="orig")
        plt.plot(x_rec[i,1], '--', label="rec")
        plt.title(f"Sample {i} Imag")
        plt.legend()
        plt.tight_layout()
        plt.show()



In [170]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


## Input Layer

In [166]:
all_curves_list = np.concatenate(all_data_list, axis=0)
np.shape(all_curves_list)


(386862, 101, 2)

In [171]:

# ===== 主流程示例 =====

# 1) Autoencoder
# ae = CurveAutoencoder().to(device)
# ae = train_curve2vec(ae, all_curves_list, is_vae=False)
# visualize_recon(ae, all_curves_list, is_vae=False)


vae = CurveVAE().to(device)
print(count_parameters(vae))


215042


In [172]:
vae, train_loss, eval_loss = train_curve2vec(vae, all_curves_list, num_epochs=50,
                is_vae=True)


Epoch 1/50  Train: 0.7187  Val: 0.0461
Epoch 2/50  Train: 0.0337  Val: 0.0243
Epoch 3/50  Train: 0.0238  Val: 0.0204
Epoch 4/50  Train: 0.0209  Val: 0.0186
Epoch 5/50  Train: 0.0196  Val: 0.0178
Epoch 6/50  Train: 0.0182  Val: 0.0152
Epoch 7/50  Train: 0.0174  Val: 0.0163
Epoch 8/50  Train: 0.0165  Val: 0.0159
Epoch 9/50  Train: 0.0160  Val: 0.0165
Epoch 10/50  Train: 0.0157  Val: 0.0180
Epoch 11/50  Train: 0.0152  Val: 0.0129
Epoch 12/50  Train: 0.0147  Val: 0.0205
Epoch 13/50  Train: 0.0143  Val: 0.0119
Epoch 14/50  Train: 0.0139  Val: 0.0159
Epoch 15/50  Train: 0.0136  Val: 0.0121
Epoch 16/50  Train: 0.0134  Val: 0.0121
Epoch 17/50  Train: 0.0130  Val: 0.0152
Epoch 18/50  Train: 0.0126  Val: 0.0103
Epoch 19/50  Train: 0.0125  Val: 0.0147
Epoch 20/50  Train: 0.0120  Val: 0.0101
Epoch 21/50  Train: 0.0118  Val: 0.0112
Epoch 22/50  Train: 0.0116  Val: 0.0121
Epoch 23/50  Train: 0.0115  Val: 0.0101
Epoch 24/50  Train: 0.0113  Val: 0.0106
Epoch 25/50  Train: 0.0111  Val: 0.0101
Epoch 26/

In [177]:
plt.figure()
plt.semilogy(train_loss, label="train")
plt.semilogy(eval_loss, label="eval")
plt.legend()


<matplotlib.legend.Legend at 0x19a958149d0>

In [190]:
train_list, val_list = train_test_split(all_curves_list, test_size=0.2, random_state=42)
    


In [195]:
visualize_recon(vae, val_list, is_vae=True)
# torch.save(vae.state_dict(), "curve2vec_vae.pt")


## Model Parameter

553422

## Visualize latent space

In [112]:
def VAE_latent(model, dataset, batch_size=64):
    # x: [B,2,101]
    ds = CurveDataset(dataset)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True)

    _len_data = ds.__len__()
    _poi = 0

    latent_space_inst = []

    model.eval()
    with torch.no_grad():
        for x in loader:
            x = x.to(device)
            mu, lv = model.encode(x)
            latent_space_inst.append(mu.cpu().numpy())

            _poi = _poi + x.size(0)
            logger.info(f"[{_poi}]/[{_len_data}]")

    latent_space_inst = np.concatenate(latent_space_inst, axis=0)  # [B,z_dim]

    if latent_space_inst.shape[1] > 2:
        _pca_inst = PCA(n_components=2)
        latent_dd = _pca_inst.fit_transform(latent_space_inst)
    else:
        latent_dd = latent_space_inst

    return latent_dd


In [113]:
latent_expr = VAE_latent(vae, all_curves_list)

[32m2025-04-10 23:19:46.061[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m19[0m - [1m[64]/[386862][0m
[32m2025-04-10 23:19:46.063[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m19[0m - [1m[128]/[386862][0m
[32m2025-04-10 23:19:46.065[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m19[0m - [1m[192]/[386862][0m
[32m2025-04-10 23:19:46.067[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m19[0m - [1m[256]/[386862][0m
[32m2025-04-10 23:19:46.068[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m19[0m - [1m[320]/[386862][0m
[32m2025-04-10 23:19:46.073[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m19[0m - [1m[384]/[386862][0m
[32m2025-04-10 23:19:46.075[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m19[0m - [1m[448]/[386862][0m
[32m2025-04-10 23:19:46.077[0m | [1mINFO    [0m | [36m__main__[0m:[36mVAE_latent[0m:[36m1

In [123]:

plt.figure(figsize=(9, 9))
plt.scatter(latent_expr[:, 0], -latent_expr[:, 1], alpha=0.5, s = 0.001)

plt.gca().set_aspect('equal', adjustable='box')
plt.title("Latent Space")
plt.xlabel("Latent Dimension 1")
plt.ylabel("Latent Dimension 2")
# plt.grid()
plt.show()

# Curve-to-Vec 2xConv+4xFC

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Curve2VecAE(nn.Module):
    def __init__(self, in_ch=2, hid_ch=32, z_dim=64, k=5):
        super().__init__()
        # ===== Encoder =====
        # 两层 Conv1d
        self.conv1 = nn.Conv1d(in_ch, hid_ch, kernel_size=k, padding=k//2)
        self.bn1   = nn.BatchNorm1d(hid_ch)
        self.conv2 = nn.Conv1d(hid_ch, hid_ch*2, kernel_size=k, padding=k//2)
        self.bn2   = nn.BatchNorm1d(hid_ch*2)
        # 全局池化到长度 1 → [B, 64]
        self.pool  = nn.AdaptiveAvgPool1d(1)

        # 四层全连接：64→128→64→32→z_dim
        self.fc1 = nn.Linear(hid_ch*2, 128)
        self.fc2 = nn.Linear(128,       64)
        self.fc3 = nn.Linear(64,        32)
        self.fc4 = nn.Linear(32,        z_dim)

        # ===== Decoder =====
        # 四层全连接：z_dim→32→64→128→(64*101)
        self.fc5 = nn.Linear(z_dim,     32)
        self.fc6 = nn.Linear(32,        64)
        self.fc7 = nn.Linear(64,       128)
        self.fc8 = nn.Linear(128, hid_ch*2 * 101)

        # 两层 Conv1d 重建：64→32→2
        self.deconv1 = nn.Conv1d(hid_ch*2, hid_ch,   kernel_size=k, padding=k//2)
        self.bn3     = nn.BatchNorm1d(hid_ch)
        self.deconv2 = nn.Conv1d(hid_ch,   in_ch,    kernel_size=k, padding=k//2)

    def encode(self, x):
        """
        x: [B, 2, 101]
        returns: z [B, z_dim]
        """
        h = F.relu(self.bn1(self.conv1(x)))   # [B,32,101]
        h = F.relu(self.bn2(self.conv2(h)))   # [B,64,101]
        h = self.pool(h).squeeze(-1)          # [B,64]
        h = F.relu(self.fc1(h))               # [B,128]
        h = F.relu(self.fc2(h))               # [B,64]
        h = F.relu(self.fc3(h))               # [B,32]
        z = self.fc4(h)                       # [B,z_dim]
        return z

    def decode(self, z):
        """
        z: [B, z_dim]
        returns: x_rec [B,2,101]
        """
        h = F.relu(self.fc5(z))               # [B,32]
        h = F.relu(self.fc6(h))               # [B,64]
        h = F.relu(self.fc7(h))               # [B,128]
        h = self.fc8(h)                       # [B,64*101]
        h = h.view(-1, 64, 101)               # [B,64,101]
        h = F.relu(self.bn3(self.deconv1(h))) # [B,32,101]
        x_rec = self.deconv2(h)               # [B,2,101]
        return x_rec

    def forward(self, x):
        z     = self.encode(x)
        x_rec = self.decode(z)
        return x_rec, z


# VAE in journal

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class CurveDataset(Dataset):
    def __init__(self, data_list):
        # data_list: list of [202] numpy or torch.Tensor
        self.data = [torch.tensor(x, dtype=torch.float32) for x in data_list]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

class VAE(nn.Module):
    def __init__(self, input_dim=202, hidden_dims=[512,256,128,64], z_dim=2):
        super().__init__()
        # --- Encoder ---
        self.enc_fc1 = nn.Linear(input_dim, hidden_dims[0])  # 160 → 512
        self.enc_fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])  # 512 → 256
        self.enc_fc3 = nn.Linear(hidden_dims[1], hidden_dims[2])  # 256 → 128
        self.enc_fc4 = nn.Linear(hidden_dims[2], hidden_dims[3])  # 128 → 64
        # mean & logvar
        self.fc_mu    = nn.Linear(hidden_dims[3], z_dim)     # 64 → 2
        self.fc_logvar= nn.Linear(hidden_dims[3], z_dim)     # 64 → 2

        # --- Decoder ---
        self.dec_fc1 = nn.Linear(z_dim,        hidden_dims[3])  # 2 → 64
        self.dec_fc2 = nn.Linear(hidden_dims[3], hidden_dims[2])# 64 → 128
        self.dec_fc3 = nn.Linear(hidden_dims[2], hidden_dims[1])# 128 → 256
        self.dec_fc4 = nn.Linear(hidden_dims[1], hidden_dims[0])# 256 → 512
        self.dec_fc5 = nn.Linear(hidden_dims[0], input_dim)     # 512 → 160

    def encode(self, x):
        """
        x: [B, 160]
        returns: mu [B,2], logvar [B,2]
        """
        h = F.relu(self.enc_fc1(x))
        h = F.relu(self.enc_fc2(h))
        h = F.relu(self.enc_fc3(h))
        h = F.relu(self.enc_fc4(h))
        mu     = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """
        z = mu + eps * sigma
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """
        z: [B,2]
        returns: x_rec [B,160]
        """
        h = F.relu(self.dec_fc1(z))
        h = F.relu(self.dec_fc2(h))
        h = F.relu(self.dec_fc3(h))
        h = F.relu(self.dec_fc4(h))
        x_rec = self.dec_fc5(h)
        return x_rec

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_rec = self.decode(z)
        return x_rec, mu, logvar


def vae_loss(x_rec, x, mu, logvar, kld_weight=1e-3):
    """
    重建误差 + KL 散度
    """
    rec_loss = F.mse_loss(x_rec, x, reduction='mean')
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return rec_loss + kld_weight * kld

def train_curve2vec(model, data_list, num_epochs=20, batch_size=64, lr=1e-3):
    train_list, val_list = train_test_split(data_list, test_size=0.2)
    train_ds = CurveDataset(train_list)
    val_ds   = CurveDataset(val_list)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_loss_recorder = []
    eval_loss_recorder = []

    for epoch in range(1, num_epochs+1):
        model.train()
        train_loss = 0
        for x in train_loader:
            x = x.to(device)
            x_rec, mu, lv = model(x)
            loss = vae_loss(x_rec, x, mu, lv)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * x.size(0)

        # 验证
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x in val_loader:
                x = x.to(device)
                x_rec, mu, lv = model(x)
                loss = vae_loss(x_rec, x, mu, lv)
                val_loss += loss.item() * x.size(0)
        train_loss /= len(train_ds)
        val_loss   /= len(val_ds)
        print(f"Epoch {epoch}/{num_epochs}  Train: {train_loss:.4f}  Val: {val_loss:.4f}")


        train_loss_recorder.append(train_loss)
        eval_loss_recorder.append(val_loss)

    return model, train_loss_recorder, eval_loss_recorder


# ===== 可视化重建 =====
def visualize_recon(model, data_list, num=5):
    ds = CurveDataset(data_list)
    loader = DataLoader(ds, batch_size=num, shuffle=True)
    x = next(iter(loader)).to(device)   # [num,2,101]
    model.eval()
    with torch.no_grad():
        x_rec, mu, lv = model(x)

    x = x.cpu().numpy()
    x_rec = x_rec.cpu().numpy()


    for i in range(num):
        plt.figure(figsize=(6,3))
        # 实部
        plt.subplot(1,2,1)
        plt.plot(x[i,:101], label="orig")
        plt.plot(x_rec[i,:101], '--', label="rec")
        plt.title(f"Sample {i} Real")
        plt.legend()
        # 虚部
        plt.subplot(1,2,2)
        plt.plot(x[i,101:], label="orig")
        plt.plot(x_rec[i,101:], '--', label="rec")
        plt.title(f"Sample {i} Imag")
        plt.legend()
        plt.tight_layout()
        plt.show()




In [7]:
all_curves_list = np.concatenate(all_data_list, axis=0)

_lz = all_curves_list[:,:,0]
_phi = all_curves_list[:,:,1]

# _z = np.exp(_lz + 1j*_phi)
_z = _lz + 1j*_phi

all_curves_list = np.concatenate((np.real(_z), -np.imag(_z)), axis=1)
# all_curves_list = np.concatenate((all_curves_list[:,:,0], all_curves_list[:,:,1]), axis=1)
np.shape(all_curves_list)


(386862, 202)

In [8]:
plt.figure()
# plt.semilogy(all_curves_list[0,:101], label="real")
# plt.semilogy(all_curves_list[0,101:][::-1], label="imag")
plt.plot(all_curves_list[0,:101], label="real")
plt.plot(all_curves_list[0,101:][::-1], label="imag")

[<matplotlib.lines.Line2D at 0x2094e02fa10>]

In [12]:

vae = VAE().to(device)
vae, train_loss, eval_loss = train_curve2vec(vae, all_curves_list, num_epochs=100)


Epoch 1/100  Train: 0.2968  Val: 0.1103
Epoch 2/100  Train: 0.0594  Val: 0.0290
Epoch 3/100  Train: 0.0326  Val: 0.0355
Epoch 4/100  Train: 0.0288  Val: 0.0314
Epoch 5/100  Train: 0.0276  Val: 0.0266
Epoch 6/100  Train: 0.0267  Val: 0.0213
Epoch 7/100  Train: 0.0255  Val: 0.0242
Epoch 8/100  Train: 0.0254  Val: 0.0237
Epoch 9/100  Train: 0.0254  Val: 0.0225
Epoch 10/100  Train: 0.0238  Val: 0.0255
Epoch 11/100  Train: 0.0232  Val: 0.0281
Epoch 12/100  Train: 0.0226  Val: 0.0213
Epoch 13/100  Train: 0.0219  Val: 0.0279
Epoch 14/100  Train: 0.0216  Val: 0.0207
Epoch 15/100  Train: 0.0212  Val: 0.0216
Epoch 16/100  Train: 0.0211  Val: 0.0211
Epoch 17/100  Train: 0.0212  Val: 0.0205
Epoch 18/100  Train: 0.0222  Val: 0.0283
Epoch 19/100  Train: 0.0207  Val: 0.0195
Epoch 20/100  Train: 0.0219  Val: 0.0228
Epoch 21/100  Train: 0.0223  Val: 0.0207
Epoch 22/100  Train: 0.0212  Val: 0.0209
Epoch 23/100  Train: 0.0207  Val: 0.0176
Epoch 24/100  Train: 0.0215  Val: 0.0217
Epoch 25/100  Train: 0.02

In [13]:
plt.figure()
plt.semilogy(train_loss, label="train")
plt.semilogy(eval_loss, label="eval")
plt.legend()


<matplotlib.legend.Legend at 0x20a13b81750>

In [14]:
visualize_recon(vae, all_curves_list)