# HW4P2: Automatic Speech Recognition with an Encoder-Decoder Transformer

# Schedule:
- Checkpoint Submission (DUE 21 November 2025 @ 11:59PM EST)
- Kaggle Submission (DUE 5 December 2025 @ 11:59PM EST | Slack Deadline is 11 December 2025 @ 11:59PM EST)
- Code Submission (DUE 7 December 2025 @ 11:59PM EST OR Day-of Slack submission)


## Requirement Acknowledgement
Setting the below flag to True indicates full understanding and acceptance of the following:
1. Slack days may ONLY be used on P2 FINAL (not checkpoint) submission. I.e. you may use slack days to submit final P2 kaggle scores (such as this one) later on the **SLACK KAGGLE COMPETITION** at the expense of your Slack days.
2. The final autolab **code submission is due 48 hours after** the conclusion of the Kaggle Deadline (or, the same day as your final kaggle submission).
3. We will require your kaggle username here, and then we will pull your official PRIVATE kaggle leaderboard score. This submission may result in slight variance in scores/code, but we will check for acceptable discrepancies. Any discrepancies related to modifying the submission code (at the bottom of the notebook) will result in an AIV.
4. You are NOT allowed to use any code that will pre-load models (such as those from Hugging Face, etc.).
   You MAY use models described by papers or articles, but you MUST implement them yourself through fundamental PyTorch operations (i.e. Linear, Conv2d, etc.).
5. You are NOT allowed to use any external data/datasets at ANY point of this assignment.
6. You may work with teammates to run ablations/experiments, BUT you must submit your OWN code and your OWN results.
7. Failure to comply with the prior rules will be considered an Academic Integrity Violation (AIV).
8. Late submissions MUST be submitted through the Slack Kaggle (see writeup for details). Any submissions made to the regular Kaggle after the original deadline will NOT be considered, no matter how many slack days remain for the student.

In [None]:
ACKNOWLEDGED = True #TODO: Only set Acknowledged to True if you have read the above acknowlegements and agree to ALL of them.

# Setup
-  Follow the setup instructions based on your preferred environment!

## Local

One of our key goals in designing this assignment is to allow you to complete most of the preliminary implementation work locally.  
We highly recommend that you **pass all tests locally** using the provided `hw4_data_subset` before moving to a GPU runtime.  
To do this, simply:

### Create a new conda environment
```bash
# Be sure to deactivate any active environments first
conda create -n hw4 python=3.12.4
```

### Activate the conda environment
```bash
conda activate hw4
```

### Install the dependencies using the provided `requirements.txt`
```bash
pip install --no-cache-dir --ignore-installed -r requirements.txt
```

### Ensure that your notebook is in the same working directory as the `Handout`
This can be achieved by:
1. Physically moving the notebook into the handout directory.
2. Changing the notebook’s current working directory to the handout directory using the os.chdir() function.

### Open the notebook and select the newly created environment from the kernel selector.

If everything was done correctly, You should see atleast the following files in your current working directory after running `!ls`:
```
.
├── README.md
├── requirements.txt
├── hw4lib/
├── mytorch/
├── tests/
└── hw4_data_subset/
```

## Colab

### Step 1: Get your handout
- See writeup for recommended approaches.

In [None]:
# Example: My preferred approach
import os
# Settings -> Developer Settings -> Personal Access Tokens -> Token (classic)
os.environ['GITHUB_TOKEN'] = "your_github_token_here"

GITHUB_USERNAME = "your_github_username_here"
REPO_NAME       = "your_github_repo_name_here"
TOKEN = os.environ.get("GITHUB_TOKEN")
repo_url        = f"https://{TOKEN}@github.com/{GITHUB_USERNAME}/{REPO_NAME}.git"
!git clone {repo_url}

In [None]:
# To pull latest changes (Must be in the repo dir, use pwd/ls to verify)
!cd {REPO_NAME} && git pull

### Step 2: Install Dependencies
- `NOTE`: Your runtime will be restarted to ensure all dependencies are updated.
- `NOTE`: You will see a runtime crashed message, this was intentionally done. Simply move on to the next cell.

In [None]:
!pwd
%cd /content/

In [None]:
%pip install --no-deps -r IDL-HW4/requirements.txt
import os
os.kill(os.getpid(), 9) # NOTE: This will restart the your colab Python runtime (required)!

In [None]:
!pip install transformers -U

# Start

In [1]:
# Uncomment this if you want to use Google Drive
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
import os
import zipfile
from tqdm import tqdm

# 1. 在 Google Drive 里的 zip 文件路径（你给的路径）
gdrive_zip_path = "/content/gdrive/MyDrive/11785Project/ProjFreq.zip"

# 2. 拷贝到本地 /content 下的目标路径
local_zip_path = "/content/ProjFreq.zip"

# 3. 解压目标目录
extract_dir = "/content/ProjFreq"

# ========== Step 1: 带进度条的复制 ==========
print("Step 1: 复制 zip 文件到本地 /content ...")

if not os.path.exists(gdrive_zip_path):
    raise FileNotFoundError(f"找不到文件: {gdrive_zip_path}")

os.makedirs("/content", exist_ok=True)

file_size = os.path.getsize(gdrive_zip_path)
chunk_size = 1024 * 1024  # 1 MB 一块

with open(gdrive_zip_path, "rb") as src, open(local_zip_path, "wb") as dst, \
     tqdm(total=file_size, unit="B", unit_scale=True, unit_divisor=1024, desc="Copying") as pbar:
    while True:
        chunk = src.read(chunk_size)
        if not chunk:
            break
        dst.write(chunk)
        pbar.update(len(chunk))

print(f"✅ 复制完成: {local_zip_path}\n")

# ========== Step 2: 带进度条的解压 ==========
print("Step 2: 解压到本地目录:", extract_dir)
os.makedirs(extract_dir, exist_ok=True)

with zipfile.ZipFile(local_zip_path, 'r') as zf:
    members = zf.infolist()
    for member in tqdm(members, desc="Unzipping", unit="file"):
        zf.extract(member, extract_dir)

print("\n✅ 解压完成！")
print("你现在可以从这里读取数据：", extract_dir)


Step 1: 复制 zip 文件到本地 /content ...


Copying: 100%|██████████| 6.01G/6.01G [02:29<00:00, 43.1MB/s]


✅ 复制完成: /content/ProjFreq.zip

Step 2: 解压到本地目录: /content/ProjFreq


Unzipping: 100%|██████████| 2647185/2647185 [05:42<00:00, 7738.41file/s]


✅ 解压完成！
你现在可以从这里读取数据： /content/ProjFreq





In [3]:
import os
from pathlib import Path
import pandas as pd

# 你的有效数据目录（不是 __MACOSX 那个）
DATA_ROOT = Path("/content/ProjFreq/ProjFreq")

print("==== 有效数据父目录结构 ====\n")
for root, dirs, files in os.walk(DATA_ROOT):
    level = root.replace(str(DATA_ROOT), "").count(os.sep)
    indent = " " * 2 * level
    print(f"{indent}{os.path.basename(root)}/")
    subindent = " " * 2 * (level + 1)
    for f in sorted(files)[:5]:
        print(f"{subindent}{f}")
print("\n======================================\n")

def show_first_real_csv(subdir_name: str):
    subdir = DATA_ROOT / subdir_name
    if not subdir.exists():
        print(f"[{subdir_name}] 不存在\n")
        return

    # 过滤掉以 ._ 开头的 Mac 资源文件，只要真正的 csv
    files = sorted([
        p for p in subdir.iterdir()
        if p.is_file()
        and p.suffix.lower() == ".csv"
        and not p.name.startswith("._")
    ])
    if not files:
        print(f"[{subdir_name}] 目录下没有真正的 csv 文件\n")
        return

    first_file = files[0]
    print(f"--- 目录 [{subdir_name}] 中第一个文件: {first_file} ---")
    try:
        df_head = pd.read_csv(first_file, nrows=5)
        print(df_head)
    except Exception as e:
        print(f"读取 {first_file} 出错: {e}")
    print("\n")

print("==== 看一下 trainFre / valFre / testFre 的第一个样本文件格式 ====\n")
show_first_real_csv("trainFre")
show_first_real_csv("valFre")
show_first_real_csv("testFre")
print("==== 结束 ====")


==== 有效数据父目录结构 ====

ProjFreq/
  .DS_Store
  testFre/
    3019644_0001_seg0001_freq.csv
    3019644_0001_seg0003_freq.csv
    3019644_0001_seg0006_freq.csv
    3019644_0001_seg0010_freq.csv
    3019644_0001_seg0014_freq.csv
  valFre/
    3019644_0001_seg0002_freq.csv
    3019644_0001_seg0005_freq.csv
    3019644_0001_seg0008_freq.csv
    3019644_0001_seg0011_freq.csv
    3019644_0001_seg0012_freq.csv
  trainFre/
    3019644_0001_seg0000_freq.csv
    3019644_0001_seg0004_freq.csv
    3019644_0001_seg0007_freq.csv
    3019644_0001_seg0009_freq.csv
    3019644_0001_seg0013_freq.csv


==== 看一下 trainFre / valFre / testFre 的第一个样本文件格式 ====

--- 目录 [trainFre] 中第一个文件: /content/ProjFreq/ProjFreq/trainFre/3019644_0001_seg0000_freq.csv ---
    freq_hz    ECG_real  ECG_imag    PPG_real   PPG_imag    ABP_real  \
0  0.000000  193.816400  0.000000  355.615840   0.000000  90617.2660   
1  0.166667    6.917574  4.894201   -8.122787 -23.515665  -2955.0700   
2  0.333333  -12.258882 -2.112530   19.733555 

# model1 ECG PPG 一起推测ABP

In [None]:
import os
from pathlib import Path
import glob
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

DATA_ROOT = Path("/content/ProjFreq/ProjFreq")

# 1. 简单看一下每个 split 有多少个样本
def count_files(split_name: str):
    folder = DATA_ROOT / f"{split_name}Fre"
    files = sorted(glob.glob(str(folder / "*.csv")))
    print(f"{split_name}Fre: {len(files)} files")
    return files

train_files = count_files("train")
val_files   = count_files("val")
test_files  = count_files("test")


# 2. 定义频域 Dataset：输入 ECG/PPG 频谱，输出 ABP 频谱
class FreqBPDataset(Dataset):
    """
    每个样本：
      X: [num_freq, 4]  -> (ECG_real, ECG_imag, PPG_real, PPG_imag)
      y: [num_freq, 2]  -> (ABP_real, ABP_imag)

    你后面可以：
      - flatten 成一维向量喂 MLP
      - 或者当成一维“频谱序列”喂 1D-CNN / Transformer
    """
    def __init__(self, split: str, data_root: Path = DATA_ROOT):
        assert split in ["train", "val", "test"]
        folder = data_root / f"{split}Fre"
        self.files = sorted(
            p for p in folder.iterdir()
            if p.is_file() and p.suffix.lower() == ".csv"
        )
        if not self.files:
            raise RuntimeError(f"No csv files found in {folder}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        csv_path = self.files[idx]
        df = pd.read_csv(csv_path)

        # 频率列可以保留也可以丢弃，看你后面要不要用
        # freq = df["freq_hz"].values.astype("float32")

        # 输入特征：ECG/PPG 的实部+虚部
        x = df[["ECG_real", "ECG_imag", "PPG_real", "PPG_imag"]].values.astype("float32")

        # 标签：ABP 的实部+虚部（频域建模）
        y = df[["ABP_real", "ABP_imag"]].values.astype("float32")

        # 转成 tensor
        x = torch.from_numpy(x)    # shape: [num_freq, 4]
        y = torch.from_numpy(y)    # shape: [num_freq, 2]

        return x, y, str(csv_path)  # 把路径也返回，方便 debug

# 3. 简单测试一下 Dataset
train_dataset = FreqBPDataset("train")
val_dataset   = FreqBPDataset("val")
test_dataset  = FreqBPDataset("test")

print("\nSample counts:")
print("  train:", len(train_dataset))
print("  val  :", len(val_dataset))
print("  test :", len(test_dataset))

# 看一个样本的形状
x0, y0, path0 = train_dataset[0]
print(f"\nFirst train sample: {path0}")
print("  X shape:", x0.shape)   # [num_freq, 4]
print("  y shape:", y0.shape)   # [num_freq, 2]

import torch
from torch.utils.data import DataLoader

# --------- 自定义 collate_fn：对变长频谱做 padding ----------
def freqbp_collate_fn(batch):
    """
    batch 是一个 list，里面每个元素是 Dataset.__getitem__ 的返回：
        x: [L_i, 4]
        y: [L_i, 2]
        path: str
    我们返回：
        batch_x: [B, L_max, 4]
        batch_y: [B, L_max, 2]
        mask:    [B, L_max]  True 表示有效位置，False 表示 padding
        paths:   list[str]
    """
    xs, ys, paths = zip(*batch)  # xs, ys 是长度为 B 的 list，每个元素是 [L_i, D]

    lengths = [x.shape[0] for x in xs]
    B = len(xs)
    L_max = max(lengths)

    # 准备 padding 容器
    batch_x = torch.zeros(B, L_max, xs[0].shape[1], dtype=xs[0].dtype)
    batch_y = torch.zeros(B, L_max, ys[0].shape[1], dtype=ys[0].dtype)
    mask = torch.zeros(B, L_max, dtype=torch.bool)

    for i, (x, y) in enumerate(zip(xs, ys)):
        L = x.shape[0]
        batch_x[i, :L, :] = x
        batch_y[i, :L, :] = y
        mask[i, :L] = True   # 前 L 个位置是有效的

    return batch_x, batch_y, mask, paths

# --------- 用新的 collate_fn 创建 DataLoader ----------
train_dataset = FreqBPDataset("train")
val_dataset   = FreqBPDataset("val")
test_dataset  = FreqBPDataset("test")

train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,           # 先设 0，稳定了再考虑开多进程
    collate_fn=freqbp_collate_fn,
)

# 试取一个 batch 看看形状
batch_x, batch_y, batch_mask, batch_paths = next(iter(train_loader))
print("batch_x shape:", batch_x.shape)   # [B, L_max, 4]
print("batch_y shape:", batch_y.shape)   # [B, L_max, 2]
print("mask shape   :", batch_mask.shape)
print("paths[0]     :", batch_paths[0])
print("有效长度示例:", batch_mask[0].sum().item())


trainFre: 441206 files
valFre: 441185 files
testFre: 441197 files

Sample counts:
  train: 441206
  val  : 441185
  test : 441197

First train sample: /content/ProjFreq/ProjFreq/trainFre/3019644_0001_seg0000_freq.csv
  X shape: torch.Size([121, 4])
  y shape: torch.Size([121, 2])
batch_x shape: torch.Size([16, 121, 4])
batch_y shape: torch.Size([16, 121, 2])
mask shape   : torch.Size([16, 121])
paths[0]     : /content/ProjFreq/ProjFreq/trainFre/3644535_0010_seg2619_freq.csv
有效长度示例: 120


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm   # 用 notebook 版的 tqdm

# 复用之前的 FreqBPDataset 和 freqbp_collate_fn

class FreqBPDatasetSubset(FreqBPDataset):
    """在原来的 Dataset 上加一个 max_files，只用前 max_files 个文件调试。"""
    def __init__(self, split: str, data_root=DATA_ROOT, max_files: int = None):
        super().__init__(split, data_root)
        if max_files is not None and max_files > 0:
            self.files = self.files[:max_files]
        print(f"[{split}] Using {len(self.files)} files for this run.")

# 先用小一点的数据量调试，比如每个 split 2000 个样本
train_dataset = FreqBPDatasetSubset("train", max_files=2000)
val_dataset   = FreqBPDatasetSubset("val",   max_files=500)  # 验证更少点也行

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


class FreqBPModel(nn.Module):
    def __init__(self, in_channels=4, hidden_dim=64, num_layers=3, out_channels=2):
        super().__init__()
        self.input_proj = nn.Linear(in_channels, hidden_dim)

        convs = []
        for i in range(num_layers):
            convs.append(
                nn.Conv1d(
                    in_channels=hidden_dim,
                    out_channels=hidden_dim,
                    kernel_size=3,
                    padding=1,
                )
            )
            convs.append(nn.ReLU())
        self.conv_net = nn.Sequential(*convs)

        self.output_proj = nn.Linear(hidden_dim, out_channels)

    def forward(self, x, mask=None):
        # x: [B, L, 4]
        x = self.input_proj(x)     # [B, L, H]
        x = x.transpose(1, 2)      # [B, H, L]
        x = self.conv_net(x)       # [B, H, L]
        x = x.transpose(1, 2)      # [B, L, H]
        out = self.output_proj(x)  # [B, L, 2]
        return out


def masked_mae(pred, target, mask):
    mask_expanded = mask.unsqueeze(-1)  # [B, L, 1]
    diff = (pred - target).abs() * mask_expanded
    valid_count = mask_expanded.sum()
    mae = diff.sum() / (valid_count + 1e-8)
    return mae


train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    collate_fn=freqbp_collate_fn,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=0,
    collate_fn=freqbp_collate_fn,
)

model = FreqBPModel(in_channels=4, hidden_dim=64, num_layers=3, out_channels=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def train_one_epoch(model, loader, optimizer, device, epoch):
    model.train()
    total_loss = 0.0
    count = 0

    # 给 DataLoader 加 tqdm，看 batch 进度
    pbar = tqdm(loader, desc=f"Train epoch {epoch}", leave=False)
    for batch_idx, (batch_x, batch_y, batch_mask, batch_paths) in enumerate(pbar):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_mask = batch_mask.to(device)

        optimizer.zero_grad()
        pred = model(batch_x, mask=batch_mask)
        loss = masked_mae(pred, batch_y, batch_mask)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        count += 1
        pbar.set_postfix({"loss": loss.item()})

    return total_loss / max(count, 1)


@torch.no_grad()
def eval_one_epoch(model, loader, device, epoch):
    model.eval()
    total_loss = 0.0
    count = 0

    pbar = tqdm(loader, desc=f"Val epoch {epoch}", leave=False)
    for batch_idx, (batch_x, batch_y, batch_mask, batch_paths) in enumerate(pbar):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_mask = batch_mask.to(device)

        pred = model(batch_x, mask=batch_mask)
        loss = masked_mae(pred, batch_y, batch_mask)

        total_loss += loss.item()
        count += 1
        pbar.set_postfix({"loss": loss.item()})

    return total_loss / max(count, 1)


# 先小跑几轮，确认整个 pipeline 是通的
num_epochs = 2
for epoch in range(1, num_epochs + 1):
    print(f"\n===== Epoch {epoch} =====")
    train_loss = train_one_epoch(model, train_loader, optimizer, device, epoch)
    val_loss = eval_one_epoch(model, val_loader, device, epoch)
    print(f"Epoch {epoch}: train MAE = {train_loss:.4f}, val MAE = {val_loss:.4f}")


[train] Using 2000 files for this run.
[val] Using 500 files for this run.
Using device: cuda

===== Epoch 1 =====


Train epoch 1:   0%|          | 0/32 [00:00<?, ?it/s]

Val epoch 1:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 1: train MAE = 1050.9876, val MAE = 1108.9745

===== Epoch 2 =====


Train epoch 2:   0%|          | 0/32 [00:00<?, ?it/s]

Val epoch 2:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 2: train MAE = 908.3920, val MAE = 641.0414


In [None]:
import os
from pathlib import Path
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# ====================== 路径设置 ======================
DATA_ROOT = Path("/content/ProjFreq/ProjFreq")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ====================== Dataset：加上归一化 & 返回 freq_hz ======================
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from pathlib import Path
from tqdm.auto import tqdm

DATA_ROOT = Path("/content/ProjFreq/ProjFreq")  # 确保和你原来的一致

class FreqBPDataset(Dataset):
    """
    预加载版 Dataset：
      - 在 __init__ 里把所有 CSV 一次性读进内存
      - 做 per-file 标准化（ECG/PPG）
      - 存成 tensor 列表，后续 __getitem__ 不再读文件

    每个样本：
      x_norm: [L, 4] -> (ECG_real, ECG_imag, PPG_real, PPG_imag)，已标准化
      y:      [L, 2] -> (ABP_real, ABP_imag)，原始尺度（频域 MAE 用）
      freqs:  [L]    -> 频率轴
      path:   str    -> 文件路径字符串
    """
    def __init__(self, split: str, data_root: Path = DATA_ROOT):
        assert split in ["train", "val", "test"]
        folder = data_root / f"{split}Fre"
        files = sorted(
            p for p in folder.iterdir()
            if p.is_file() and p.suffix.lower() == ".csv"
        )
        if not files:
            raise RuntimeError(f"No csv files found in {folder}")

        self.split = split
        self.paths = []
        self.x_list = []
        self.y_list = []
        self.freqs_list = []

        print(f"[{split}] Found {len(files)} files. Loading into RAM...")

        # 预加载所有样本到内存
        for csv_path in tqdm(files, desc=f"Loading {split}", ncols=100):
            df = pd.read_csv(csv_path)

            freqs = df["freq_hz"].values.astype("float32")

            # 输入特征：ECG/PPG 实部 + 虚部
            x = df[["ECG_real", "ECG_imag", "PPG_real", "PPG_imag"]].values.astype("float32")
            # per-file 统计归一化
            mean = x.mean(axis=0, keepdims=True)
            std = x.std(axis=0, keepdims=True)
            x_norm = (x - mean) / (std + 1e-6)

            # 标签：ABP 频谱（保持原始尺度）
            y = df[["ABP_real", "ABP_imag"]].values.astype("float32")

            # 转成 tensor 并缓存
            self.x_list.append(torch.from_numpy(x_norm))   # [L, 4]
            self.y_list.append(torch.from_numpy(y))        # [L, 2]
            self.freqs_list.append(torch.from_numpy(freqs))# [L]
            self.paths.append(csv_path.as_posix())

        print(f"[{split}] Finished loading {len(self.paths)} samples into RAM.")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # 不做 clone，一般训练不会对样本 in-place 修改，避免额外开销
        x = self.x_list[idx]
        y = self.y_list[idx]
        freqs = self.freqs_list[idx]
        path = self.paths[idx]
        return x, y, freqs, path


# ====================== collate_fn：对 X, y, freqs 做 padding ======================
def freqbp_collate_fn(batch):
    """
    batch 中每个元素： (x[L_i,4], y[L_i,2], freqs[L_i], path)
    返回：
      batch_x:     [B, L_max, 4]
      batch_y:     [B, L_max, 2]
      batch_freqs: [B, L_max]
      mask:        [B, L_max]  True 表示有效
      paths:       list[str]
    """
    xs, ys, freqs_list, paths = zip(*batch)
    lengths = [x.shape[0] for x in xs]
    B = len(xs)
    L_max = max(lengths)

    D_x = xs[0].shape[1]
    D_y = ys[0].shape[1]

    batch_x = torch.zeros(B, L_max, D_x, dtype=xs[0].dtype)
    batch_y = torch.zeros(B, L_max, D_y, dtype=ys[0].dtype)
    batch_freqs = torch.zeros(B, L_max, dtype=freqs_list[0].dtype)
    mask = torch.zeros(B, L_max, dtype=torch.bool)

    for i, (x, y, freqs) in enumerate(zip(xs, ys, freqs_list)):
        L = x.shape[0]
        batch_x[i, :L, :] = x
        batch_y[i, :L, :] = y
        batch_freqs[i, :L] = freqs
        mask[i, :L] = True

    return batch_x, batch_y, batch_freqs, mask, paths

# ====================== 模型 ======================
class FreqBPModel(nn.Module):
    """
    输入: x: [B, L, 4]（已归一化）
    输出: y_hat: [B, L, 2]（预测 ABP_real, ABP_imag）
    """
    def __init__(self, in_channels=4, hidden_dim=64, num_layers=3, out_channels=2):
        super().__init__()
        self.input_proj = nn.Linear(in_channels, hidden_dim)

        convs = []
        for i in range(num_layers):
            convs.append(
                nn.Conv1d(
                    in_channels=hidden_dim,
                    out_channels=hidden_dim,
                    kernel_size=3,
                    padding=1,   # 保持 L 不变
                )
            )
            convs.append(nn.ReLU())
        self.conv_net = nn.Sequential(*convs)

        self.output_proj = nn.Linear(hidden_dim, out_channels)

    def forward(self, x, mask=None):
        # x: [B, L, 4]
        x = self.input_proj(x)     # [B, L, H]
        x = x.transpose(1, 2)      # [B, H, L]
        x = self.conv_net(x)       # [B, H, L]
        x = x.transpose(1, 2)      # [B, L, H]
        out = self.output_proj(x)  # [B, L, 2]
        return out


class BigFreqBPModel(nn.Module):
    """
    更大一点的频域 1D-CNN 模型：
      输入:  x [B, L, 4]   (ECG_real, ECG_imag, PPG_real, PPG_imag)
      输出:  y [B, L, 2]   (ABP_real, ABP_imag)

    结构：
      - Linear:   4  -> 256  (频点特征升维)
      - 8 个 conv block，每个：
          Conv1d(256 -> 256, kernel_size=5, padding=2)
          BatchNorm1d(256)
          ReLU
          Dropout(0.1)
      - Linear:   256 -> 2

    参数量大约 2.6M（可以打印确认）
    """
    def __init__(self,
                 in_channels: int = 4,
                 hidden_dim: int = 256,
                 num_blocks: int = 8,
                 out_channels: int = 2,
                 kernel_size: int = 5,
                 dropout: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim

        # 1) 先把每个频点的 4 维特征升维到 256
        self.input_proj = nn.Linear(in_channels, hidden_dim)

        # 2) 沿频率轴堆 8 个 Conv1d block
        padding = kernel_size // 2  # 保持长度不变
        blocks = []
        for i in range(num_blocks):
            blocks.append(
                nn.Conv1d(
                    in_channels=hidden_dim,
                    out_channels=hidden_dim,
                    kernel_size=kernel_size,
                    padding=padding,
                )
            )
            blocks.append(nn.BatchNorm1d(hidden_dim))
            blocks.append(nn.ReLU())
            blocks.append(nn.Dropout(dropout))
        self.conv_net = nn.Sequential(*blocks)

        # 3) 输出投影到 ABP 频谱 (real, imag)
        self.output_proj = nn.Linear(hidden_dim, out_channels)

    def forward(self, x, mask=None):
        # x: [B, L, 4]
        B, L, C = x.shape

        # [B, L, 4] -> [B, L, 256]
        x = self.input_proj(x)

        # Conv1d 期望 [B, C, L]，所以换一下顺序
        x = x.transpose(1, 2)      # [B, 256, L]
        x = self.conv_net(x)       # [B, 256, L]
        x = x.transpose(1, 2)      # [B, L, 256]

        # [B, L, 256] -> [B, L, 2]
        out = self.output_proj(x)
        return out


# ====================== 损失：频域 masked MAE ======================
def masked_mae(pred, target, mask):
    """
    pred:   [B, L, 2]
    target: [B, L, 2]
    mask:   [B, L]
    """
    mask_expanded = mask.unsqueeze(-1)  # [B, L, 1]
    diff = (pred - target).abs() * mask_expanded
    valid_count = mask_expanded.sum()
    mae = diff.sum() / (valid_count + 1e-8)
    return mae

# ====================== 时域 MAE：iFFT 重建 ======================
def batch_time_mae(freqs, pred_y, true_y, mask, dt=0.008):
    """
    freqs:   [B, L]     频率轴（只到 20 Hz 的截断部分）
    pred_y:  [B, L, 2]  (real, imag)
    true_y:  [B, L, 2]
    mask:    [B, L]     True 为有效频点
    dt:      采样间隔（根据原始数据，0.008 秒）

    思路：
      - 对每个样本 b：
         * 取有效的频率 + 对应 real/imag
         * 根据 Δf 和 dt 估计 N（原始时域长度）
         * 构造 full_fft_pred / full_fft_true，填充到 k 索引处，其余高频为 0
         * irfft 得到时域波形
      - 对所有样本的时域点做 MAE
    """
    freqs = freqs.cpu().numpy()
    pred_y = pred_y.cpu().numpy()
    true_y = true_y.cpu().numpy()
    mask = mask.cpu().numpy()

    B, L = freqs.shape
    total_abs_err = 0.0
    total_count = 0

    for b in range(B):
        valid_idx = mask[b]  # [L]
        if not valid_idx.any():
            continue

        f = freqs[b, valid_idx]                   # [L_valid]
        pred_real = pred_y[b, valid_idx, 0]       # [L_valid]
        pred_imag = pred_y[b, valid_idx, 1]
        true_real = true_y[b, valid_idx, 0]
        true_imag = true_y[b, valid_idx, 1]

        if len(f) < 2:
            continue

        df = f[1] - f[0]                          # 频率分辨率
        # Δf = 1 / (N * dt)  =>  N = 1 / (dt * Δf)
        N = int(round(1.0 / (dt * df)))
        n_fft = N // 2 + 1

        full_pred_fft = np.zeros(n_fft, dtype=np.complex64)
        full_true_fft = np.zeros(n_fft, dtype=np.complex64)

        # 这些截断频率对应的索引 k = round(f / df)
        k_indices = np.round(f / df).astype(int)
        k_indices = np.clip(k_indices, 0, n_fft - 1)

        full_pred_fft[k_indices] = pred_real + 1j * pred_imag
        full_true_fft[k_indices] = true_real + 1j * true_imag

        pred_time = np.fft.irfft(full_pred_fft, n=N)
        true_time = np.fft.irfft(full_true_fft, n=N)

        diff = np.abs(pred_time - true_time)
        total_abs_err += diff.sum()
        total_count += N

    if total_count == 0:
        return 0.0
    return total_abs_err / total_count

# ====================== 构建 DataLoader ======================
train_dataset = FreqBPDataset("train")
val_dataset   = FreqBPDataset("val")

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    collate_fn=freqbp_collate_fn,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=0,
    collate_fn=freqbp_collate_fn,
)

# ====================== 初始化模型 & 优化器 ======================
#model = FreqBPModel(in_channels=4, hidden_dim=64, num_layers=3, out_channels=2).to(device)
model = BigFreqBPModel(
    in_channels=4,
    hidden_dim=256,   # 可以改大/改小
    num_blocks=8,     # 卷积 block 数量
    out_channels=2,
    kernel_size=5,
    dropout=0.1
).to(device)

# ===== 打印模型参数量 =====
def count_parameters(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable, total - trainable

total, trainable, non_trainable = count_parameters(model)
print("===== Model Parameter Count =====")
print(f"Total params        : {total:,}")
print(f"Trainable params    : {trainable:,}")
print(f"Non-trainable params: {non_trainable:,}")
print("=================================")

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# ====================== 训练 & 验证 ======================
def train_one_epoch(model, loader, optimizer, device, epoch):
    model.train()
    total_loss = 0.0
    count = 0
    pbar = tqdm(loader, desc=f"Train epoch {epoch}", leave=False)
    for batch_x, batch_y, batch_freqs, batch_mask, batch_paths in pbar:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_mask = batch_mask.to(device)

        optimizer.zero_grad()
        pred = model(batch_x, mask=batch_mask)
        loss = masked_mae(pred, batch_y, batch_mask)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        count += 1
        pbar.set_postfix({"freq_MAE": loss.item()})

    return total_loss / max(count, 1)


@torch.no_grad()
def eval_one_epoch(model, loader, device, epoch, max_time_batches=200):
    model.eval()
    total_freq_loss = 0.0
    count = 0

    total_time_mae = 0.0
    time_batches = 0

    pbar = tqdm(loader, desc=f"Val epoch {epoch}", leave=False)
    for batch_idx, (batch_x, batch_y, batch_freqs, batch_mask, batch_paths) in enumerate(pbar):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_mask = batch_mask.to(device)
        batch_freqs = batch_freqs.to(device)

        pred = model(batch_x, mask=batch_mask)
        freq_loss = masked_mae(pred, batch_y, batch_mask)

        total_freq_loss += freq_loss.item()
        count += 1

        # 只在前 max_time_batches 个 batch 上计算时域 MAE，避免太耗时
        if batch_idx < max_time_batches:
            time_mae_batch = batch_time_mae(batch_freqs, pred, batch_y, batch_mask, dt=0.008)
            total_time_mae += time_mae_batch
            time_batches += 1

        pbar.set_postfix({"freq_MAE": freq_loss.item()})

    avg_freq_mae = total_freq_loss / max(count, 1)
    avg_time_mae = total_time_mae / max(time_batches, 1) if time_batches > 0 else 0.0
    return avg_freq_mae, avg_time_mae

# ====================== 跑几个 epoch，打印频域 + 时域 MAE ======================
num_epochs = 20  # 你可以自己改大，比如 5 或 10
for epoch in range(1, num_epochs + 1):
    print(f"\n===== Epoch {epoch} =====")
    train_freq_mae = train_one_epoch(model, train_loader, optimizer, device, epoch)
    val_freq_mae, val_time_mae = eval_one_epoch(model, val_loader, device, epoch, max_time_batches=200)

    print(
        f"Epoch {epoch}: "
        f"train_freq_MAE = {train_freq_mae:.4f}, "
        f"val_freq_MAE = {val_freq_mae:.4f}, "
        f"val_time_MAE = {val_time_mae:.4f}"
    )


Using device: cuda
[train] Found 441206 files. Loading into RAM...


Loading train:   0%|                                                     | 0/441206 [00:00<?, ?it/s]

[train] Finished loading 441206 samples into RAM.
[val] Found 441185 files. Loading into RAM...


Loading val:   0%|                                                       | 0/441185 [00:00<?, ?it/s]

[val] Finished loading 441185 samples into RAM.
===== Model Parameter Count =====
Total params        : 2,629,378
Trainable params    : 2,629,378
Non-trainable params: 0

===== Epoch 1 =====


Train epoch 1:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 1:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 1: train_freq_MAE = 600.8208, val_freq_MAE = 356.7817, val_time_MAE = 11.7897

===== Epoch 2 =====


Train epoch 2:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 2:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 2: train_freq_MAE = 368.8335, val_freq_MAE = 335.6141, val_time_MAE = 10.0327

===== Epoch 3 =====


Train epoch 3:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 3:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 3: train_freq_MAE = 352.9663, val_freq_MAE = 325.1351, val_time_MAE = 10.0075

===== Epoch 4 =====


Train epoch 4:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 4:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 4: train_freq_MAE = 343.8804, val_freq_MAE = 315.8017, val_time_MAE = 8.7173

===== Epoch 5 =====


Train epoch 5:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 5:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 5: train_freq_MAE = 337.7901, val_freq_MAE = 312.3817, val_time_MAE = 8.9072

===== Epoch 6 =====


Train epoch 6:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 6:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 6: train_freq_MAE = 333.1203, val_freq_MAE = 306.8314, val_time_MAE = 8.7606

===== Epoch 7 =====


Train epoch 7:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 7:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 7: train_freq_MAE = 329.1629, val_freq_MAE = 303.8865, val_time_MAE = 8.6737

===== Epoch 8 =====


Train epoch 8:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 8:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 8: train_freq_MAE = 326.1599, val_freq_MAE = 299.8609, val_time_MAE = 8.0805

===== Epoch 9 =====


Train epoch 9:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 9:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 9: train_freq_MAE = 323.4499, val_freq_MAE = 297.9818, val_time_MAE = 8.2544

===== Epoch 10 =====


Train epoch 10:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 10:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 10: train_freq_MAE = 320.9752, val_freq_MAE = 296.5232, val_time_MAE = 8.1958

===== Epoch 11 =====


Train epoch 11:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 11:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 11: train_freq_MAE = 318.9637, val_freq_MAE = 293.2947, val_time_MAE = 8.3464

===== Epoch 12 =====


Train epoch 12:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 12:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 12: train_freq_MAE = 317.2409, val_freq_MAE = 293.6512, val_time_MAE = 7.9169

===== Epoch 13 =====


Train epoch 13:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 13:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 13: train_freq_MAE = 315.5618, val_freq_MAE = 291.5097, val_time_MAE = 8.0723

===== Epoch 14 =====


Train epoch 14:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 14:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 14: train_freq_MAE = 314.2550, val_freq_MAE = 289.7225, val_time_MAE = 7.8884

===== Epoch 15 =====


Train epoch 15:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 15:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 15: train_freq_MAE = 312.8352, val_freq_MAE = 288.1395, val_time_MAE = 7.7059

===== Epoch 16 =====


Train epoch 16:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 16:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 16: train_freq_MAE = 311.6268, val_freq_MAE = 286.8479, val_time_MAE = 7.5720

===== Epoch 17 =====


Train epoch 17:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 17:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 17: train_freq_MAE = 310.5819, val_freq_MAE = 285.3398, val_time_MAE = 7.7064

===== Epoch 18 =====


Train epoch 18:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 18:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 18: train_freq_MAE = 309.4820, val_freq_MAE = 284.7138, val_time_MAE = 7.6368

===== Epoch 19 =====


Train epoch 19:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 19:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 19: train_freq_MAE = 308.6093, val_freq_MAE = 284.6449, val_time_MAE = 7.7685

===== Epoch 20 =====


Train epoch 20:   0%|          | 0/6894 [00:00<?, ?it/s]

Val epoch 20:   0%|          | 0/6894 [00:00<?, ?it/s]

Epoch 20: train_freq_MAE = 307.6938, val_freq_MAE = 282.5302, val_time_MAE = 7.6637


# Basic Facilities

In [4]:
# ==== Block 1: 基础设置 ====

import os
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# 数据根目录
DATA_ROOT = Path("/content/ProjFreq/ProjFreq")

# 设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [5]:
# ==== Block 2: Dataset + DataLoader ====

class FreqPPGECGABPDatasetCached(Dataset):
    """
    频域数据集（一次性读入内存）：
      - 每个样本来自一个 *_freq.csv
      - 对 ECG / PPG 做 per-file 标准化
      - ABP 频谱保持原始尺度
    """

    def __init__(self, split: str, data_root: Path = DATA_ROOT, max_files: int | None = None):
        """
        split: "train" / "val" / "test"
        max_files: 方便调试时限制文件个数，None 表示全量
        """
        assert split in ["train", "val", "test"]
        folder = data_root / f"{split}Fre"
        files = sorted(
            p for p in folder.iterdir()
            if p.is_file() and p.suffix.lower() == ".csv"
        )
        if not files:
            raise RuntimeError(f"No csv files found in {folder}")

        if max_files is not None:
            files = files[:max_files]

        self.paths = []
        self.ppg_list = []
        self.ecg_list = []
        self.abp_list = []
        self.freqs_list = []

        print(f"[{split}] Found {len(files)} files. Loading into RAM...")
        for csv_path in tqdm(files, desc=f"Loading {split}", ncols=100):
            df = pd.read_csv(csv_path)

            # freq 轴
            freqs = df["freq_hz"].values.astype("float32")  # [L]

            # [L,2]
            ecg_spec = df[["ECG_real", "ECG_imag"]].values.astype("float32")
            ppg_spec = df[["PPG_real", "PPG_imag"]].values.astype("float32")
            abp_spec = df[["ABP_real", "ABP_imag"]].values.astype("float32")

            # 对 ECG / PPG 做 per-file z-score 标准化
            # 这样 diffusion & BiLSTM 都在 ~N(0,1) 的尺度上工作，不用硬编码 scale
            def zscore(x: np.ndarray):
                m = x.mean(axis=0, keepdims=True)
                s = x.std(axis=0, keepdims=True)
                return (x - m) / (s + 1e-6)

            ecg_norm = zscore(ecg_spec)   # [L,2]
            ppg_norm = zscore(ppg_spec)   # [L,2]

            # 转成 tensor 并转置为 [2,L]
            ecg_norm = torch.from_numpy(ecg_norm).T.contiguous()
            ppg_norm = torch.from_numpy(ppg_norm).T.contiguous()
            abp_spec = torch.from_numpy(abp_spec).T.contiguous()
            freqs = torch.from_numpy(freqs)

            self.ecg_list.append(ecg_norm)    # [2,L]
            self.ppg_list.append(ppg_norm)    # [2,L]
            self.abp_list.append(abp_spec)    # [2,L]
            self.freqs_list.append(freqs)     # [L]
            self.paths.append(csv_path.as_posix())

        print(f"[{split}] Finished loading {len(self.paths)} samples into RAM.")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        return (
            self.ppg_list[idx],   # [2,L]
            self.ecg_list[idx],   # [2,L]
            self.abp_list[idx],   # [2,L]
            self.freqs_list[idx], # [L]
            self.paths[idx],
        )


def freq_collate_fn(batch):
    """
    把不同 L 的样本 pad 成同一长度 L_max.

    batch: list of (ppg[2,L_i], ecg[2,L_i], abp[2,L_i], freqs[L_i], path)
    返回：
      ppg_batch   : [B, 2, L_max]
      ecg_batch   : [B, 2, L_max]
      abp_batch   : [B, 2, L_max]
      freqs_batch : [B, L_max]
      mask        : [B, L_max]   True = 有效位置
      paths       : list[str]
    """
    ppg_list, ecg_list, abp_list, freqs_list, paths = zip(*batch)
    B = len(batch)
    lengths = [f.size(0) for f in freqs_list]
    L_max = max(lengths)

    ppg_batch = torch.zeros(B, 2, L_max, dtype=ppg_list[0].dtype)
    ecg_batch = torch.zeros(B, 2, L_max, dtype=ecg_list[0].dtype)
    abp_batch = torch.zeros(B, 2, L_max, dtype=abp_list[0].dtype)
    freqs_batch = torch.zeros(B, L_max, dtype=freqs_list[0].dtype)
    mask = torch.zeros(B, L_max, dtype=torch.bool)

    for i, (ppg, ecg, abp, freqs) in enumerate(zip(ppg_list, ecg_list, abp_list, freqs_list)):
        L = freqs.size(0)
        ppg_batch[i, :, :L] = ppg
        ecg_batch[i, :, :L] = ecg
        abp_batch[i, :, :L] = abp
        freqs_batch[i, :L] = freqs
        mask[i, :L] = True

    return ppg_batch, ecg_batch, abp_batch, freqs_batch, mask, paths


In [6]:
# ==== Block 2.1: 构建 DataLoader ====

# 为了先调通代码，可以先用少量文件，比如 2000 / 500
MAX_TRAIN_FILES = 400000  # 比如设为 2000 调试，None 表示全量
MAX_VAL_FILES   = 100000  # 比如设为 500  调试

train_dataset = FreqPPGECGABPDatasetCached("train", max_files=MAX_TRAIN_FILES)
val_dataset   = FreqPPGECGABPDatasetCached("val",   max_files=MAX_VAL_FILES)

BATCH_SIZE = 256

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=freq_collate_fn,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=freq_collate_fn,
)

print("Train samples:", len(train_dataset))
print("Val samples  :", len(val_dataset))

# 看一眼一个 batch 的 shape
ppg_b, ecg_b, abp_b, freqs_b, mask_b, paths_b = next(iter(train_loader))
print("ppg_batch:", ppg_b.shape)   # [B,2,L_max]
print("ecg_batch:", ecg_b.shape)
print("abp_batch:", abp_b.shape)
print("freqs_batch:", freqs_b.shape)
print("mask:", mask_b.shape)
print("example path:", paths_b[0])


[train] Found 400000 files. Loading into RAM...


Loading train:   0%|                                                     | 0/400000 [00:00<?, ?it/s]

[train] Finished loading 400000 samples into RAM.
[val] Found 100000 files. Loading into RAM...


Loading val:   0%|                                                       | 0/100000 [00:00<?, ?it/s]

[val] Finished loading 100000 samples into RAM.
Train samples: 400000
Val samples  : 100000
ppg_batch: torch.Size([256, 2, 121])
ecg_batch: torch.Size([256, 2, 121])
abp_batch: torch.Size([256, 2, 121])
freqs_batch: torch.Size([256, 121])
mask: torch.Size([256, 121])
example path: /content/ProjFreq/ProjFreq/trainFre/3553816_0020_seg6337_freq.csv


In [7]:
# ==== Block 3: 工具函数 ====

def count_parameters(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable, total - trainable


def masked_l1_loss(pred, target, mask):
    """
    pred, target: [B, C, L]
    mask       : [B, L]
    """
    mask = mask.unsqueeze(1)  # [B,1,L]
    diff = (pred - target).abs() * mask
    denom = mask.sum()
    return diff.sum() / (denom + 1e-8)


@torch.no_grad()
def batch_time_mae_from_spec(freqs, pred_spec, true_spec, mask, dt=0.008):
    """
    把 ABP 频谱 (real, imag) 还原成时域波形，然后计算 MAE.

    freqs     : [B, L]
    pred_spec : [B, 2, L]  (real, imag)
    true_spec : [B, 2, L]
    mask      : [B, L]     True=有效频点
    dt        : 采样间隔（原始数据是 0.008 s）
    """
    freqs = freqs.cpu().numpy()
    pred = pred_spec.cpu().numpy()
    true = true_spec.cpu().numpy()
    mask = mask.cpu().numpy()

    B, _, L = pred.shape
    total_err = 0.0
    total_count = 0

    for b in range(B):
        valid = mask[b]
        if not valid.any():
            continue

        f = freqs[b, valid]          # [L_valid]
        pred_r = pred[b, 0, valid]
        pred_i = pred[b, 1, valid]
        true_r = true[b, 0, valid]
        true_i = true[b, 1, valid]

        if len(f) < 2:
            continue

        df = f[1] - f[0]             # 频率分辨率
        N = int(round(1.0 / (dt * df)))   # Δf = 1 / (N * dt)
        n_fft = N // 2 + 1

        full_pred = np.zeros(n_fft, dtype=np.complex64)
        full_true = np.zeros(n_fft, dtype=np.complex64)

        k_idx = np.round(f / df).astype(int)
        k_idx = np.clip(k_idx, 0, n_fft - 1)

        full_pred[k_idx] = pred_r + 1j * pred_i
        full_true[k_idx] = true_r + 1j * true_i

        pred_t = np.fft.irfft(full_pred, n=N)
        true_t = np.fft.irfft(full_true, n=N)

        diff = np.abs(pred_t - true_t)
        total_err += diff.sum()
        total_count += N

    if total_count == 0:
        return 0.0
    return total_err / total_count


# Model 2 一起训练

In [None]:
# ==== Block 4: 模型定义（大号版） ====

class ResConvBlock1D(nn.Module):
    """
    1D 残差卷积块: [B, C, L] -> [B, C, L]
    比原来的 ConvBlock1D 更深一点，并且有 skip connection，表达力更强。
    """
    def __init__(self, ch, k=3, p=1, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=k, padding=p),
            nn.BatchNorm1d(ch),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv1d(ch, ch, kernel_size=k, padding=p),
            nn.BatchNorm1d(ch),
            nn.SiLU(),
        )

    def forward(self, x):
        return x + self.net(x)


class FreqConditionalEpsNet1D(nn.Module):
    """
    更大号的频域 eps 模型（不下采样，不上采样）：
      - base_ch 默认 512
      - 深度：1 个输入卷积 + 6 个残差块
      输入:
        ecg_t : [B, 2, L]
        ppg   : [B, 2, L]
        t     : [B]
      输出:
        eps   : [B, 2, L]
    """
    def __init__(self, base_ch=512, time_emb_dim=512, dropout=0.1):
        super().__init__()
        self.time_emb_dim = time_emb_dim
        self.base_ch = base_ch

        # 时间步嵌入: sinusoidal -> MLP -> base_ch 维
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, base_ch),
            nn.ReLU(),
        )
        self.to_t = nn.Linear(base_ch, base_ch)

        in_ch = 4  # 2(ECG) + 2(PPG)

        # 输入卷积：4 -> base_ch
        self.conv_in = nn.Sequential(
            nn.Conv1d(in_ch, base_ch, kernel_size=3, padding=1),
            nn.BatchNorm1d(base_ch),
            nn.SiLU(),
        )

        # 6 个残差卷积块，全是 same-length
        self.blocks = nn.ModuleList([
            ResConvBlock1D(base_ch, k=3, p=1, dropout=dropout)
            for _ in range(6)
        ])

        # 输出 eps_pred: 2 通道（real, imag）
        self.out_conv = nn.Conv1d(base_ch, 2, kernel_size=1)

    def sinusoidal_embedding(self, t, dim):
        device = t.device
        half_dim = dim // 2
        emb_factor = torch.exp(
            torch.arange(half_dim, device=device)
            * (-torch.log(torch.tensor(10000.0)) / half_dim)
        )
        emb = t[:, None].float() * emb_factor[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
        return emb

    def forward(self, ecg_t, ppg, t):
        """
        ecg_t: [B, 2, L]
        ppg  : [B, 2, L]
        t    : [B]
        """
        # 拼接 ECG_t 和 PPG => [B, 4, L]
        x = torch.cat([ecg_t, ppg], dim=1)  # [B,4,L]

        # 时间 embedding
        t_sin  = self.sinusoidal_embedding(t, self.time_emb_dim)  # [B, D]
        t_base = self.time_mlp(t_sin)                             # [B, base_ch]
        t_feat = self.to_t(t_base)[:, :, None]                    # [B, base_ch, 1]

        # 输入卷积
        x = self.conv_in(x)  # [B, base_ch, L]

        # 6 个 ResBlock，每层都加一次时间条件
        for block in self.blocks:
            x = block(x + t_feat)

        # 输出 eps_pred [B, 2, L]
        eps = self.out_conv(x)
        return eps


class FreqPPG2ECGDiffusion(nn.Module):
    """
    频域 DDPM：在 ECG 频谱 (2ch) 上做扩散，PPG 频谱 (2ch) 作为条件。
    """
    def __init__(self, eps_model: FreqConditionalEpsNet1D, timesteps: int = 50):
        super().__init__()
        self.eps_model = eps_model
        self.T = timesteps

        beta_start, beta_end = 1e-4, 0.02
        betas = torch.linspace(beta_start, beta_end, timesteps)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bars', alpha_bars)

    def q_sample(self, x0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x0)
        alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1)
        return torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noise

    def forward(self, ecg_spec, ppg_spec):
        """
        训练 diffusion:
        ecg_spec: [B,2,L]
        ppg_spec: [B,2,L]
        返回 diff_loss
        """
        B = ecg_spec.size(0)
        device = ecg_spec.device
        t = torch.randint(0, self.T, (B,), device=device)

        noise = torch.randn_like(ecg_spec)
        x_t = self.q_sample(ecg_spec, t, noise)

        eps_pred = self.eps_model(x_t, ppg_spec, t)
        diff_loss = F.mse_loss(eps_pred, noise)
        return diff_loss

    @torch.no_grad()
    def p_sample(self, x_t, ppg_spec, t):
        beta_t = self.betas[t].view(-1, 1, 1)
        alpha_t = self.alphas[t].view(-1, 1, 1)
        alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1)

        eps_theta = self.eps_model(x_t, ppg_spec, t)
        mean = (1 / torch.sqrt(alpha_t)) * (
            x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * eps_theta
        )

        if t[0] == 0:
            return mean
        else:
            noise = torch.randn_like(x_t)
            return mean + torch.sqrt(beta_t) * noise

    @torch.no_grad()
    def sample_ecg_spec(self, ppg_spec):
        """
        ppg_spec: [B,2,L]
        返回: 生成的 ECG 频谱 [B,2,L]
        """
        B, C, L = ppg_spec.shape
        device = ppg_spec.device
        x_t = torch.randn((B, 2, L), device=device)

        for step in reversed(range(self.T)):
            t = torch.full((B,), step, device=device, dtype=torch.long)
            x_t = self.p_sample(x_t, ppg_spec, t)

        return x_t


class FreqBPBiLSTM(nn.Module):
    """
    放大版频域 BP 估计器:
      输入:  PPG_spec (2ch) + ECG_spec (2ch) -> [B, L, 4]
      输出:  ABP_spec (2ch) -> [B, 2, L]
      结构:
        - BiLSTM hidden_dim=512
        - num_layers=4
    """
    def __init__(self, hidden_dim=512, num_layers=4, dropout=0.3):
        super().__init__()
        in_dim = 4

        self.lstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        # 可加一个中间层再输出，增大容量
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, 2),
        )

    def forward(self, ppg_spec, ecg_spec):
        # [B,2,L] -> [B,L,2]
        ppg = ppg_spec.transpose(1, 2)
        ecg = ecg_spec.transpose(1, 2)
        x = torch.cat([ppg, ecg], dim=-1)  # [B,L,4]

        out, _ = self.lstm(x)             # [B,L,2H]
        abp_hat = self.fc(out)            # [B,L,2]
        return abp_hat.transpose(1, 2)    # [B,2,L]


class FreqPPGGModel(nn.Module):
    def __init__(self,
                 timesteps=50,
                 base_ch=512,
                 time_emb_dim=512,
                 lstm_hidden=512,
                 lstm_layers=4,
                 lambda_bp=1.0):
        super().__init__()
        eps_net = FreqConditionalEpsNet1D(
            base_ch=base_ch,
            time_emb_dim=time_emb_dim,
        )
        self.diffusion = FreqPPG2ECGDiffusion(eps_net, timesteps=timesteps)
        self.bp_estimator = FreqBPBiLSTM(
            hidden_dim=lstm_hidden,
            num_layers=lstm_layers,
        )
        self.lambda_bp = lambda_bp

    def forward(self, ppg_spec, ecg_spec, abp_spec):
        diff_loss = self.diffusion(ecg_spec, ppg_spec)
        abp_hat = self.bp_estimator(ppg_spec, ecg_spec)
        # 注意这里的 mask 在 train_one_epoch 里会重新算，这里先用全 True
        bp_loss = masked_l1_loss(
            abp_hat,
            abp_spec,
            mask=torch.ones_like(abp_spec[:, 0, :]).bool()
        )
        total_loss = diff_loss + self.lambda_bp * bp_loss
        return total_loss, diff_loss, bp_loss, abp_hat


In [None]:
# ==== Block 5: 初始化模型 & 优化器 ====

model = FreqPPGGModel(
    timesteps=50,       # 先保持 50，不然 diffusion 太慢
    base_ch=512,
    time_emb_dim=512,
    lstm_hidden=512,
    lstm_layers=4,
    lambda_bp=1.0,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

total, trainable, nontrain = count_parameters(model)
print("===== Model Parameter Count =====")
print(f"Total params        : {total:,}")
print(f"Trainable params    : {trainable:,}")
print(f"Non-trainable params: {nontrain:,}")
print("=================================")


===== Model Parameter Count =====
Total params        : 32,061,956
Trainable params    : 32,061,956
Non-trainable params: 0


In [None]:
# ==== Block 6: 训练 & 验证 ====

from math import inf

def train_one_epoch(model, loader, optimizer, device, epoch):
    model.train()
    total_loss = 0.0
    total_diff = 0.0
    total_bp = 0.0
    steps = 0

    pbar = tqdm(loader, desc=f"Train epoch {epoch}", ncols=120)
    for ppg_b, ecg_b, abp_b, freqs_b, mask_b, paths_b in pbar:
        ppg_b = ppg_b.to(device)   # [B,2,L]
        ecg_b = ecg_b.to(device)
        abp_b = abp_b.to(device)
        mask_b = mask_b.to(device)

        optimizer.zero_grad()
        # 在 bp_loss 里我们自己传 mask
        total_l, diff_l, bp_l, abp_hat = model(ppg_b, ecg_b, abp_b)

        # 用实际 mask 更新 bp_loss（上面简化为全 True，这里再修正）
        bp_l = masked_l1_loss(abp_hat, abp_b, mask_b)
        total_l = diff_l + model.lambda_bp * bp_l

        total_l.backward()
        optimizer.step()

        total_loss += total_l.item()
        total_diff += diff_l.item()
        total_bp   += bp_l.item()
        steps += 1

        pbar.set_postfix({
            "loss": f"{total_loss/steps:.3f}",
            "diff": f"{total_diff/steps:.3f}",
            "bp":   f"{total_bp/steps:.3f}",
        })

    return total_loss / max(1, steps), total_diff / max(1, steps), total_bp / max(1, steps)


@torch.no_grad()
def eval_one_epoch(model, loader, device, epoch,
                   max_gen_batches: int = 10):
    """
    验证：
      - diff_loss_val: 用真实 ECG_norm 计算的 diffusion loss
      - bp_loss_teacher: 用真实 ECG_norm → ABP 的频域 MAE
      - bp_loss_gen_freq: 用生成 ECG 频谱 → ABP 的频域 MAE（前 max_gen_batches 个 batch）
      - bp_loss_gen_time: 用生成 ECG 频谱 → ABP 时域 MAE
    """
    model.eval()
    total_diff = 0.0
    total_bp_teacher = 0.0
    steps = 0

    total_bp_gen_freq = 0.0
    total_bp_gen_time = 0.0
    gen_steps = 0

    pbar = tqdm(loader, desc=f"Val epoch {epoch}", ncols=120)
    for batch_idx, (ppg_b, ecg_b, abp_b, freqs_b, mask_b, paths_b) in enumerate(pbar):
        ppg_b = ppg_b.to(device)
        ecg_b = ecg_b.to(device)
        abp_b = abp_b.to(device)
        freqs_b = freqs_b.to(device)
        mask_b = mask_b.to(device)

        # 1) diffusion loss + teacher forcing BP
        diff_l = model.diffusion(ecg_b, ppg_b)
        abp_hat_teacher = model.bp_estimator(ppg_b, ecg_b)
        bp_teacher = masked_l1_loss(abp_hat_teacher, abp_b, mask_b)

        total_diff += diff_l.item()
        total_bp_teacher += bp_teacher.item()
        steps += 1

        # 2) 用生成 ECG 评估（只在前 max_gen_batches 个 batch 上做，避免太慢）
        if batch_idx < max_gen_batches:
            ecg_gen = model.diffusion.sample_ecg_spec(ppg_b)  # [B,2,L]
            abp_hat_gen = model.bp_estimator(ppg_b, ecg_gen)

            bp_gen_freq = masked_l1_loss(abp_hat_gen, abp_b, mask_b)
            bp_gen_time = batch_time_mae_from_spec(freqs_b, abp_hat_gen, abp_b, mask_b)

            total_bp_gen_freq += bp_gen_freq
            total_bp_gen_time += bp_gen_time
            gen_steps += 1

        pbar.set_postfix({
            "diff": f"{total_diff/steps:.3f}",
            "bp_t": f"{total_bp_teacher/steps:.3f}",
            "bp_g_f": f"{(total_bp_gen_freq/max(1,gen_steps)):.3f}",
        })

    avg_diff = total_diff / max(1, steps)
    avg_bp_teacher = total_bp_teacher / max(1, steps)
    avg_bp_gen_freq = total_bp_gen_freq / max(1, gen_steps) if gen_steps > 0 else 0.0
    avg_bp_gen_time = total_bp_gen_time / max(1, gen_steps) if gen_steps > 0 else 0.0

    return avg_diff, avg_bp_teacher, avg_bp_gen_freq, avg_bp_gen_time


In [None]:
# ==== Block 7: 主训练循环 ====

NUM_EPOCHS = 80   # 先小一点试，确认能跑
BEST_VAL = inf

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\n===== Epoch {epoch} =====")
    train_loss, train_diff, train_bp = train_one_epoch(model, train_loader, optimizer, device, epoch)

    val_diff, val_bp_teacher, val_bp_gen_freq, val_bp_gen_time = eval_one_epoch(
        model, val_loader, device, epoch,
        max_gen_batches=10,   # 只对前 10 个 batch 做生成 ECG 的评估
    )

    print(
        f"Epoch {epoch} summary:\n"
        f"  Train - loss: {train_loss:.4f}, diff: {train_diff:.4f}, bp: {train_bp:.4f}\n"
        f"  Val   - diff: {val_diff:.4f}, "
        f"bp_teacher_freq: {val_bp_teacher:.4f}, "
        f"bp_gen_freq: {val_bp_gen_freq:.4f}, "
        f"bp_gen_time: {val_bp_gen_time:.4f}"
    )

    # 简单保存最好模型（按生成版时域 MAE）
    if val_bp_gen_time < BEST_VAL:
        BEST_VAL = val_bp_gen_time
        torch.save(model.state_dict(), "best_freq_ppgg_model.pt")
        print(f"  ↳ Saved best model (val_bp_gen_time = {BEST_VAL:.4f})")



===== Epoch 1 =====


Train epoch 1:   0%|                                                                           | 0/6250 [00:00…

Val epoch 1:   0%|                                                                              | 0/782 [00:00…

Epoch 1 summary:
  Train - loss: 588.4030, diff: 0.4579, bp: 587.9452
  Val   - diff: 0.3297, bp_teacher_freq: 267.7180, bp_gen_freq: 469.1256, bp_gen_time: 24.1713
  ↳ Saved best model (val_bp_gen_time = 24.1713)

===== Epoch 2 =====


Train epoch 2:   0%|                                                                           | 0/6250 [00:00…

Val epoch 2:   0%|                                                                              | 0/782 [00:00…

Epoch 2 summary:
  Train - loss: 311.7880, diff: 0.3944, bp: 311.3937
  Val   - diff: 0.3110, bp_teacher_freq: 213.7534, bp_gen_freq: 436.1187, bp_gen_time: 17.8576
  ↳ Saved best model (val_bp_gen_time = 17.8576)

===== Epoch 3 =====


Train epoch 3:   0%|                                                                           | 0/6250 [00:00…

Val epoch 3:   0%|                                                                              | 0/782 [00:00…

Epoch 3 summary:
  Train - loss: 276.5565, diff: 0.3763, bp: 276.1802
  Val   - diff: 0.3019, bp_teacher_freq: 188.8838, bp_gen_freq: 474.9619, bp_gen_time: 21.0113

===== Epoch 4 =====


Train epoch 4:   0%|                                                                           | 0/6250 [00:00…

Val epoch 4:   0%|                                                                              | 0/782 [00:00…

Epoch 4 summary:
  Train - loss: 256.6006, diff: 0.3674, bp: 256.2332
  Val   - diff: 0.2978, bp_teacher_freq: 177.8533, bp_gen_freq: 478.5466, bp_gen_time: 19.3416

===== Epoch 5 =====


Train epoch 5:   0%|                                                                           | 0/6250 [00:00…

Val epoch 5:   0%|                                                                              | 0/782 [00:00…

Epoch 5 summary:
  Train - loss: 242.2896, diff: 0.3606, bp: 241.9290
  Val   - diff: 0.2932, bp_teacher_freq: 169.5335, bp_gen_freq: 489.8767, bp_gen_time: 20.7537

===== Epoch 6 =====


Train epoch 6:   0%|                                                                           | 0/6250 [00:00…

Val epoch 6:   0%|                                                                              | 0/782 [00:00…

Epoch 6 summary:
  Train - loss: 232.2383, diff: 0.3561, bp: 231.8821
  Val   - diff: 0.2896, bp_teacher_freq: 164.6244, bp_gen_freq: 497.5477, bp_gen_time: 20.6190

===== Epoch 7 =====


Train epoch 7:   0%|                                                                           | 0/6250 [00:00…

Val epoch 7:   0%|                                                                              | 0/782 [00:00…

Epoch 7 summary:
  Train - loss: 224.4222, diff: 0.3522, bp: 224.0700
  Val   - diff: 0.2899, bp_teacher_freq: 159.7113, bp_gen_freq: 483.8061, bp_gen_time: 19.6502

===== Epoch 8 =====


Train epoch 8:   0%|                                                                           | 0/6250 [00:00…

Val epoch 8:   0%|                                                                              | 0/782 [00:00…

Epoch 8 summary:
  Train - loss: 217.9471, diff: 0.3478, bp: 217.5993
  Val   - diff: 0.2841, bp_teacher_freq: 156.7506, bp_gen_freq: 445.8974, bp_gen_time: 17.9446

===== Epoch 9 =====


Train epoch 9:   0%|                                                                           | 0/6250 [00:00…

Val epoch 9:   0%|                                                                              | 0/782 [00:00…

Epoch 9 summary:
  Train - loss: 212.6841, diff: 0.3447, bp: 212.3394
  Val   - diff: 0.2831, bp_teacher_freq: 153.0419, bp_gen_freq: 454.5662, bp_gen_time: 18.1611

===== Epoch 10 =====


Train epoch 10:   0%|                                                                          | 0/6250 [00:00…

Val epoch 10:   0%|                                                                             | 0/782 [00:00…

Epoch 10 summary:
  Train - loss: 207.9213, diff: 0.3425, bp: 207.5788
  Val   - diff: 0.2789, bp_teacher_freq: 149.4441, bp_gen_freq: 450.0887, bp_gen_time: 17.8258
  ↳ Saved best model (val_bp_gen_time = 17.8258)

===== Epoch 11 =====


Train epoch 11:   0%|                                                                          | 0/6250 [00:00…

Val epoch 11:   0%|                                                                             | 0/782 [00:00…

Epoch 11 summary:
  Train - loss: 203.6651, diff: 0.3397, bp: 203.3254
  Val   - diff: 0.2797, bp_teacher_freq: 147.7624, bp_gen_freq: 480.8780, bp_gen_time: 19.5027

===== Epoch 12 =====


Train epoch 12:   0%|                                                                          | 0/6250 [00:00…

Val epoch 12:   0%|                                                                             | 0/782 [00:00…

Epoch 12 summary:
  Train - loss: 199.9334, diff: 0.3379, bp: 199.5954
  Val   - diff: 0.2771, bp_teacher_freq: 146.8239, bp_gen_freq: 438.8532, bp_gen_time: 17.0847
  ↳ Saved best model (val_bp_gen_time = 17.0847)

===== Epoch 13 =====


Train epoch 13:   0%|                                                                          | 0/6250 [00:00…

Val epoch 13:   0%|                                                                             | 0/782 [00:00…

Epoch 13 summary:
  Train - loss: 196.8941, diff: 0.3354, bp: 196.5587
  Val   - diff: 0.2750, bp_teacher_freq: 144.3054, bp_gen_freq: 443.6897, bp_gen_time: 17.7677

===== Epoch 14 =====


Train epoch 14:   0%|                                                                          | 0/6250 [00:00…

Val epoch 14:   0%|                                                                             | 0/782 [00:00…

Epoch 14 summary:
  Train - loss: 193.6283, diff: 0.3332, bp: 193.2951
  Val   - diff: 0.2766, bp_teacher_freq: 142.5295, bp_gen_freq: 468.1628, bp_gen_time: 18.7813

===== Epoch 15 =====


Train epoch 15:   0%|                                                                          | 0/6250 [00:00…

Val epoch 15:   0%|                                                                             | 0/782 [00:00…

Epoch 15 summary:
  Train - loss: 191.2145, diff: 0.3318, bp: 190.8827
  Val   - diff: 0.2738, bp_teacher_freq: 141.1816, bp_gen_freq: 435.1964, bp_gen_time: 16.6942
  ↳ Saved best model (val_bp_gen_time = 16.6942)

===== Epoch 16 =====


Train epoch 16:   0%|                                                                          | 0/6250 [00:00…

Val epoch 16:   0%|                                                                             | 0/782 [00:00…

Epoch 16 summary:
  Train - loss: 188.7021, diff: 0.3301, bp: 188.3720
  Val   - diff: 0.2722, bp_teacher_freq: 141.1099, bp_gen_freq: 439.0503, bp_gen_time: 16.8691

===== Epoch 17 =====


Train epoch 17:   0%|                                                                          | 0/6250 [00:00…

Val epoch 17:   0%|                                                                             | 0/782 [00:00…

Epoch 17 summary:
  Train - loss: 186.9143, diff: 0.3286, bp: 186.5857
  Val   - diff: 0.2706, bp_teacher_freq: 139.8598, bp_gen_freq: 427.4621, bp_gen_time: 15.6375
  ↳ Saved best model (val_bp_gen_time = 15.6375)

===== Epoch 18 =====


Train epoch 18:   0%|                                                                          | 0/6250 [00:00…

Val epoch 18:   0%|                                                                             | 0/782 [00:00…

Epoch 18 summary:
  Train - loss: 184.2186, diff: 0.3276, bp: 183.8911
  Val   - diff: 0.2694, bp_teacher_freq: 139.5728, bp_gen_freq: 467.3811, bp_gen_time: 18.9157

===== Epoch 19 =====


Train epoch 19:   0%|                                                                          | 0/6250 [00:00…

Val epoch 19:   0%|                                                                             | 0/782 [00:00…

Epoch 19 summary:
  Train - loss: 182.1991, diff: 0.3258, bp: 181.8733
  Val   - diff: 0.2703, bp_teacher_freq: 138.3432, bp_gen_freq: 430.5440, bp_gen_time: 17.0828

===== Epoch 20 =====


Train epoch 20:   0%|                                                                          | 0/6250 [00:00…

Val epoch 20:   0%|                                                                             | 0/782 [00:00…

Epoch 20 summary:
  Train - loss: 180.2115, diff: 0.3247, bp: 179.8868
  Val   - diff: 0.2693, bp_teacher_freq: 138.3903, bp_gen_freq: 433.8893, bp_gen_time: 17.1758

===== Epoch 21 =====


Train epoch 21:   0%|                                                                          | 0/6250 [00:00…

Val epoch 21:   0%|                                                                             | 0/782 [00:00…

Epoch 21 summary:
  Train - loss: 178.7730, diff: 0.3231, bp: 178.4499
  Val   - diff: 0.2682, bp_teacher_freq: 137.0628, bp_gen_freq: 466.7571, bp_gen_time: 18.3382

===== Epoch 22 =====


Train epoch 22:   0%|                                                                          | 0/6250 [00:00…

Val epoch 22:   0%|                                                                             | 0/782 [00:00…

Epoch 22 summary:
  Train - loss: 176.4404, diff: 0.3226, bp: 176.1178
  Val   - diff: 0.2683, bp_teacher_freq: 136.2470, bp_gen_freq: 426.1801, bp_gen_time: 16.9885

===== Epoch 23 =====


Train epoch 23:   0%|                                                                          | 0/6250 [00:00…

Val epoch 23:   0%|                                                                             | 0/782 [00:00…

Epoch 23 summary:
  Train - loss: 174.7701, diff: 0.3218, bp: 174.4483
  Val   - diff: 0.2664, bp_teacher_freq: 137.6161, bp_gen_freq: 447.6800, bp_gen_time: 18.0665

===== Epoch 24 =====


Train epoch 24:   0%|                                                                          | 0/6250 [00:00…

Val epoch 24:   0%|                                                                             | 0/782 [00:00…

Epoch 24 summary:
  Train - loss: 173.1532, diff: 0.3213, bp: 172.8319
  Val   - diff: 0.2652, bp_teacher_freq: 138.8592, bp_gen_freq: 422.6772, bp_gen_time: 16.2833

===== Epoch 25 =====


Train epoch 25:   0%|                                                                          | 0/6250 [00:00…

Val epoch 25:   0%|                                                                             | 0/782 [00:00…

Epoch 25 summary:
  Train - loss: 171.5774, diff: 0.3199, bp: 171.2575
  Val   - diff: 0.2648, bp_teacher_freq: 136.4398, bp_gen_freq: 416.4814, bp_gen_time: 16.3833

===== Epoch 26 =====


Train epoch 26:   0%|                                                                          | 0/6250 [00:00…

Val epoch 26:   0%|                                                                             | 0/782 [00:00…

Epoch 26 summary:
  Train - loss: 169.8297, diff: 0.3189, bp: 169.5109
  Val   - diff: 0.2654, bp_teacher_freq: 135.5659, bp_gen_freq: 441.7944, bp_gen_time: 17.6878

===== Epoch 27 =====


Train epoch 27:   0%|                                                                          | 0/6250 [00:00…

Val epoch 27:   0%|                                                                             | 0/782 [00:00…

Epoch 27 summary:
  Train - loss: 168.6554, diff: 0.3189, bp: 168.3365
  Val   - diff: 0.2639, bp_teacher_freq: 135.2358, bp_gen_freq: 434.5617, bp_gen_time: 17.1582

===== Epoch 28 =====


Train epoch 28:   0%|                                                                          | 0/6250 [00:00…

Val epoch 28:   0%|                                                                             | 0/782 [00:00…

Epoch 28 summary:
  Train - loss: 167.4015, diff: 0.3184, bp: 167.0832
  Val   - diff: 0.2647, bp_teacher_freq: 133.5261, bp_gen_freq: 451.7281, bp_gen_time: 17.2389

===== Epoch 29 =====


Train epoch 29:   0%|                                                                          | 0/6250 [00:00…

Val epoch 29:   0%|                                                                             | 0/782 [00:00…

Epoch 29 summary:
  Train - loss: 166.2620, diff: 0.3170, bp: 165.9450
  Val   - diff: 0.2626, bp_teacher_freq: 134.2345, bp_gen_freq: 407.1714, bp_gen_time: 16.3986

===== Epoch 30 =====


Train epoch 30:   0%|                                                                          | 0/6250 [00:00…

Val epoch 30:   0%|                                                                             | 0/782 [00:00…

Epoch 30 summary:
  Train - loss: 164.4302, diff: 0.3165, bp: 164.1137
  Val   - diff: 0.2613, bp_teacher_freq: 135.1024, bp_gen_freq: 427.8509, bp_gen_time: 17.2851

===== Epoch 31 =====


Train epoch 31:   0%|                                                                          | 0/6250 [00:00…

Val epoch 31:   0%|                                                                             | 0/782 [00:00…

Epoch 31 summary:
  Train - loss: 163.6104, diff: 0.3156, bp: 163.2948
  Val   - diff: 0.2621, bp_teacher_freq: 131.8356, bp_gen_freq: 419.1346, bp_gen_time: 17.1551

===== Epoch 32 =====


Train epoch 32:   0%|                                                                          | 0/6250 [00:00…

Val epoch 32:   0%|                                                                             | 0/782 [00:00…

Epoch 32 summary:
  Train - loss: 162.0126, diff: 0.3153, bp: 161.6973
  Val   - diff: 0.2631, bp_teacher_freq: 131.9107, bp_gen_freq: 447.8568, bp_gen_time: 17.7344

===== Epoch 33 =====


Train epoch 33:   0%|                                                                          | 0/6250 [00:00…

Val epoch 33:   0%|                                                                             | 0/782 [00:00…

Epoch 33 summary:
  Train - loss: 161.5811, diff: 0.3152, bp: 161.2659
  Val   - diff: 0.2608, bp_teacher_freq: 131.7235, bp_gen_freq: 430.3392, bp_gen_time: 16.8923

===== Epoch 34 =====


Train epoch 34:   0%|                                                                          | 0/6250 [00:00…

Val epoch 34:   0%|                                                                             | 0/782 [00:00…

Epoch 34 summary:
  Train - loss: 159.5491, diff: 0.3145, bp: 159.2346
  Val   - diff: 0.2615, bp_teacher_freq: 131.4423, bp_gen_freq: 433.4615, bp_gen_time: 17.1825

===== Epoch 35 =====


Train epoch 35:   0%|                                                                          | 0/6250 [00:00…

Val epoch 35:   0%|                                                                             | 0/782 [00:00…

Epoch 35 summary:
  Train - loss: 158.6711, diff: 0.3138, bp: 158.3573
  Val   - diff: 0.2599, bp_teacher_freq: 133.3288, bp_gen_freq: 448.5592, bp_gen_time: 17.4060

===== Epoch 36 =====


Train epoch 36:   0%|                                                                          | 0/6250 [00:00…

Val epoch 36:   0%|                                                                             | 0/782 [00:00…

Epoch 36 summary:
  Train - loss: 157.7260, diff: 0.3133, bp: 157.4127
  Val   - diff: 0.2611, bp_teacher_freq: 130.6096, bp_gen_freq: 421.2996, bp_gen_time: 16.8481

===== Epoch 37 =====


Train epoch 37:   0%|                                                                          | 0/6250 [00:00…

Val epoch 37:   0%|                                                                             | 0/782 [00:00…

Epoch 37 summary:
  Train - loss: 156.0537, diff: 0.3129, bp: 155.7408
  Val   - diff: 0.2593, bp_teacher_freq: 130.3454, bp_gen_freq: 428.7973, bp_gen_time: 17.5981

===== Epoch 38 =====


Train epoch 38:   0%|                                                                          | 0/6250 [00:00…

Val epoch 38:   0%|                                                                             | 0/782 [00:00…

Epoch 38 summary:
  Train - loss: 154.9267, diff: 0.3126, bp: 154.6141
  Val   - diff: 0.2607, bp_teacher_freq: 129.3422, bp_gen_freq: 434.3642, bp_gen_time: 17.5394

===== Epoch 39 =====


Train epoch 39:   0%|                                                                          | 0/6250 [00:00…

Val epoch 39:   0%|                                                                             | 0/782 [00:00…

Epoch 39 summary:
  Train - loss: 154.3995, diff: 0.3123, bp: 154.0871
  Val   - diff: 0.2609, bp_teacher_freq: 129.6406, bp_gen_freq: 450.0760, bp_gen_time: 17.9559

===== Epoch 40 =====


Train epoch 40:   0%|                                                                          | 0/6250 [00:00…

Val epoch 40:   0%|                                                                             | 0/782 [00:00…

Epoch 40 summary:
  Train - loss: 153.4382, diff: 0.3117, bp: 153.1265
  Val   - diff: 0.2600, bp_teacher_freq: 129.8998, bp_gen_freq: 429.3064, bp_gen_time: 16.9671

===== Epoch 41 =====


Train epoch 41:   0%|                                                                          | 0/6250 [00:00…

Val epoch 41:   0%|                                                                             | 0/782 [00:00…

Epoch 41 summary:
  Train - loss: 153.0337, diff: 0.3111, bp: 152.7225
  Val   - diff: 0.2596, bp_teacher_freq: 128.9692, bp_gen_freq: 457.0310, bp_gen_time: 17.6391

===== Epoch 42 =====


Train epoch 42:   0%|                                                                          | 0/6250 [00:00…

Val epoch 42:   0%|                                                                             | 0/782 [00:00…

Epoch 42 summary:
  Train - loss: 151.3587, diff: 0.3108, bp: 151.0479
  Val   - diff: 0.2608, bp_teacher_freq: 128.2627, bp_gen_freq: 436.4220, bp_gen_time: 16.6273

===== Epoch 43 =====


Train epoch 43:   0%|                                                                          | 0/6250 [00:00…

Val epoch 43:   0%|                                                                             | 0/782 [00:00…

Epoch 43 summary:
  Train - loss: 150.7196, diff: 0.3107, bp: 150.4089
  Val   - diff: 0.2590, bp_teacher_freq: 128.0516, bp_gen_freq: 442.6404, bp_gen_time: 18.4406

===== Epoch 44 =====


Train epoch 44:   0%|                                                                          | 0/6250 [00:00…

Val epoch 44:   0%|                                                                             | 0/782 [00:00…

Epoch 44 summary:
  Train - loss: 149.6150, diff: 0.3104, bp: 149.3045
  Val   - diff: 0.2584, bp_teacher_freq: 127.2828, bp_gen_freq: 437.2430, bp_gen_time: 17.4797

===== Epoch 45 =====


Train epoch 45:   0%|                                                                          | 0/6250 [00:00…

Val epoch 45:   0%|                                                                             | 0/782 [00:00…

Epoch 45 summary:
  Train - loss: 148.9716, diff: 0.3096, bp: 148.6620
  Val   - diff: 0.2586, bp_teacher_freq: 127.7467, bp_gen_freq: 435.8069, bp_gen_time: 17.9200

===== Epoch 46 =====


Train epoch 46:   0%|                                                                          | 0/6250 [00:00…

# model 3

In [8]:
import torch
import torch.nn.functional as F



import numpy as np
import torch

def abp_spec_to_sbp_dbp(
    abp_spec: torch.Tensor,
    freqs: torch.Tensor,
    mask: torch.Tensor,
    dt: float = 0.008,
    q_high: float = 0.95,
    q_low: float = 0.05,
) -> torch.Tensor:
    """
    严格版：用 freqs + mask 还原时域 ABP，再用“分位数”算 SBP/DBP

    参数：
      abp_spec : [B, 2, L]  频域 ABP (real, imag)，已经截到 0~20 Hz
      freqs    : [B, L]     对应的频率轴，来源是 np.fft.rfftfreq(...)[mask]
      mask     : [B, L]     有效频点的 mask（True/1 = 有效）
      dt       : 采样间隔，MIMIC 这边是 0.008s (125 Hz)
      q_high   : 近似 SBP 的分位数，比如 0.95
      q_low    : 近似 DBP 的分位数，比如 0.05

    返回：
      sbp_dbp  : [B, 2]，第 0 列 SBP，第 1 列 DBP
    """
    device = abp_spec.device

    # 统一搬到 CPU+numpy 做频域复原，反正只是算 label，不需要梯度
    abp_spec_np = abp_spec.detach().cpu().numpy()   # [B, 2, L]
    freqs_np    = freqs.detach().cpu().numpy()      # [B, L]
    mask_np     = mask.detach().cpu().numpy().astype(bool)  # [B, L]

    B, _, L = abp_spec_np.shape

    sbp_list = []
    dbp_list = []

    for b in range(B):
        f_valid = freqs_np[b][mask_np[b]]   # [L_valid]
        if f_valid.shape[0] < 2:
            # 太少的有效频点就先用 NaN 占位，你也可以改成 0 或者直接 continue
            sbp_list.append(np.nan)
            dbp_list.append(np.nan)
            continue

        real = abp_spec_np[b, 0, mask_np[b]]   # [L_valid]
        imag = abp_spec_np[b, 1, mask_np[b]]   # [L_valid]

        # 频率分辨率 Δf
        df = f_valid[1] - f_valid[0]

        # 利用 Δt * Δf = 1 / N 反推出原始 N
        N = int(round(1.0 / (dt * df)))   # 这里的 N 就是当初 rFFT 的时域长度
        n_fft = N // 2 + 1

        # 重建完整 one-sided 频谱（0 ~ Nyquist）
        full = np.zeros(n_fft, dtype=np.complex64)

        # 当前保留的频率 f_valid 对应的 bin 索引 k
        k_idx = np.round(f_valid / df).astype(int)
        k_idx = np.clip(k_idx, 0, n_fft - 1)

        full[k_idx] = real + 1j * imag

        # 严格版 irfft：用原始 N 还原低通后的 ABP 波形
        abp_time = np.fft.irfft(full, n=N)   # [N]

        # 用分位数估 SBP / DBP，而不是用 max/min（对噪声不敏感）
        sbp = np.quantile(abp_time, q_high)
        dbp = np.quantile(abp_time, q_low)

        sbp_list.append(sbp)
        dbp_list.append(dbp)

    sbp_arr = np.asarray(sbp_list, dtype=np.float32)   # [B]
    dbp_arr = np.asarray(dbp_list, dtype=np.float32)   # [B]

    sbp_dbp = np.stack([sbp_arr, dbp_arr], axis=-1)    # [B, 2]
    sbp_dbp = torch.from_numpy(sbp_dbp).to(device)     # 转回 GPU

    return sbp_dbp   # [B, 2]




class ResBlock1D(nn.Module):
    """
    简单的 1D ResNet Block:
      输入 / 输出: [B, C, L]，C 不变
      结构: Conv1d -> BN -> ReLU -> Conv1d -> BN -> Dropout -> 残差加和 -> ReLU
    """
    def __init__(self, channels: int, kernel_size: int = 5, dropout: float = 0.1):
        super().__init__()
        padding = kernel_size // 2

        self.conv1 = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.bn1 = nn.BatchNorm1d(channels)

        self.conv2 = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.bn2 = nn.BatchNorm1d(channels)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        x: [B, C, L]
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.dropout(out)

        out = out + identity     # 残差连接
        out = F.relu(out)
        return out               # [B, C, L]

class Freq2SBPDBPModel(nn.Module):
    """
    用频域特征 (PPG/ECG/PPG+ECG) 直接回归 SBP / DBP 两个标量

    输入: x [B, L, C_in]   (C_in=2: PPG 或 ECG;  C_in=4: PPG+ECG)
    输出: y [B, 2]         (SBP, DBP)
    """
    def __init__(self,
                 in_channels: int = 4,
                 hidden_dim: int = 256,
                 num_blocks: int = 8,
                 kernel_size: int = 5,
                 dropout: float = 0.1,
                 print_num_params: bool = True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.in_channels = in_channels

        # 1) 频点特征升维: [B, L, C_in] -> [B, L, hidden_dim]
        self.input_proj = nn.Linear(in_channels, hidden_dim)

        # 2) ResNet 风格的 Conv1d blocks (在 [B, hidden_dim, L] 上操作)
        blocks = []
        for i in range(num_blocks):
            blocks.append(ResBlock1D(hidden_dim, kernel_size=kernel_size, dropout=dropout))
        self.conv_net = nn.Sequential(*blocks)

        # 3) Head: 频率轴上的全局 pooling 之后，用一个小 MLP 回归 SBP/DBP
        #    [B, hidden_dim] -> [B, hidden_dim] -> [B, 2]
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2),
        )

        # 4) 可选：在初始化时打印模型参数量
        if print_num_params:
            num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
            print(f"[Freq2SBPDBPModel] Trainable params: {num_params} "
                  f"({num_params/1e6:.3f} M)")

    def forward(self, x, mask=None):
        """
        x:    [B, L, C_in]
        mask: [B, L]，可选；1 表示有效频点，0 表示无效
        """
        B, L, C = x.shape
        assert C == self.in_channels, f"expect {self.in_channels} channels, got {C}"

        # [B, L, C_in] -> [B, L, hidden_dim]
        x = self.input_proj(x)

        # Conv1d 期望 [B, C, L]
        x = x.transpose(1, 2)          # [B, hidden_dim, L]
        x = self.conv_net(x)           # [B, hidden_dim, L]

        # 频率轴上做 masked mean pooling
        x = x.transpose(1, 2)          # [B, L, hidden_dim]

        if mask is not None:
            m = mask.unsqueeze(-1).float()          # [B, L, 1]
            x = x * m                               # 无效频点置 0
            denom = m.sum(dim=1).clamp(min=1.0)     # [B,1] 避免除 0
            x = x.sum(dim=1) / denom                # [B, hidden_dim]
        else:
            x = x.mean(dim=1)                       # [B, hidden_dim]

        # 映射到 SBP / DBP
        out = self.head(x)                          # [B, 2]
        return out



def sbp_dbp_loss(pred_bp: torch.Tensor, true_bp: torch.Tensor):
    """
    pred_bp: [B, 2] -> (SBP_pred, DBP_pred)
    true_bp: [B, 2] -> (SBP_true, DBP_true)
    """
    sbp_pred = pred_bp[:, 0]
    dbp_pred = pred_bp[:, 1]
    sbp_true = true_bp[:, 0]
    dbp_true = true_bp[:, 1]

    sbp_mae = F.l1_loss(sbp_pred, sbp_true)
    dbp_mae = F.l1_loss(dbp_pred, dbp_true)
    loss = sbp_mae + dbp_mae   # 也可以改成加权和

    return loss, sbp_mae, dbp_mae


from tqdm import tqdm

def train_one_epoch_bpindex(model, loader, optimizer, device, epoch,
                            input_mode: str = "ppg"):
    """
    直接训练 SBP/DBP baseline，不改 Dataset 结构：
    loader 仍然产出: ppg_b, ecg_b, abp_b, freqs_b, mask_b, paths_b
    """
    model.train()
    total_loss = 0.0
    total_sbp = 0.0
    total_dbp = 0.0
    steps = 0

    pbar = tqdm(loader, desc=f"[BPIndex-{input_mode}] Train epoch {epoch}", ncols=120)
    for ppg_b, ecg_b, abp_b, freqs_b, mask_b, paths_b in pbar:
        ppg_b = ppg_b.to(device)      # [B, 2, L]
        ecg_b = ecg_b.to(device)      # [B, 2, L]
        abp_b = abp_b.to(device)      # [B, 2, L]
        mask_b = mask_b.to(device)    # [B, L]

        # 1) 从 ABP 频谱算出当前 batch 的 SBP/DBP 真值
        with torch.no_grad():
            sbp_dbp_true = abp_spec_to_sbp_dbp(
                abp_b,         # [B, 2, L]
                freqs_b,       # [B, L]
                mask_b,        # [B, L]
                dt=0.008,      # 和你预处理时的一致
                q_high=0.95,
                q_low=0.05,
            )

        # 2) 选择输入频谱
        if input_mode == "ppg":
            x_spec = ppg_b                # [B, 2, L]
        elif input_mode == "ecg":
            x_spec = ecg_b                # [B, 2, L]
        elif input_mode == "ppg_ecg":
            x_spec = torch.cat([ppg_b, ecg_b], dim=1)  # [B, 4, L]
        else:
            raise ValueError(f"Unknown input_mode: {input_mode}")

        # [B, C, L] -> [B, L, C]
        x = x_spec.transpose(1, 2)

        optimizer.zero_grad()
        pred_bp = model(x, mask_b)        # [B, 2] -> (SBP_pred, DBP_pred)

        loss, sbp_mae, dbp_mae = sbp_dbp_loss(pred_bp, sbp_dbp_true)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_sbp += sbp_mae.item()
        total_dbp += dbp_mae.item()
        steps += 1

        pbar.set_postfix({
            "loss": f"{total_loss/steps:.3f}",
            "sbp":  f"{total_sbp/steps:.3f}",
            "dbp":  f"{total_dbp/steps:.3f}",
        })

    return (total_loss / max(1, steps),
            total_sbp / max(1, steps),
            total_dbp / max(1, steps))


@torch.no_grad()
def eval_one_epoch_bpindex(model, loader, device, epoch,
                           input_mode: str = "ppg"):
    model.eval()
    total_loss = 0.0
    total_sbp = 0.0
    total_dbp = 0.0
    steps = 0

    pbar = tqdm(loader, desc=f"[BPIndex-{input_mode}] Val epoch {epoch}", ncols=120)
    for ppg_b, ecg_b, abp_b, freqs_b, mask_b, paths_b in pbar:
        ppg_b = ppg_b.to(device)
        ecg_b = ecg_b.to(device)
        abp_b = abp_b.to(device)
        mask_b = mask_b.to(device)

        sbp_dbp_true = sbp_dbp_true = abp_spec_to_sbp_dbp(
                abp_b,         # [B, 2, L]
                freqs_b,       # [B, L]
                mask_b,        # [B, L]
                dt=0.008,      # 和你预处理时的一致
                q_high=0.95,
                q_low=0.05,
            )

        if input_mode == "ppg":
            x_spec = ppg_b
        elif input_mode == "ecg":
            x_spec = ecg_b
        elif input_mode == "ppg_ecg":
            x_spec = torch.cat([ppg_b, ecg_b], dim=1)
        else:
            raise ValueError(f"Unknown input_mode: {input_mode}")

        x = x_spec.transpose(1, 2)          # [B, L, C]
        pred_bp = model(x, mask_b)          # [B, 2]

        loss, sbp_mae, dbp_mae = sbp_dbp_loss(pred_bp, sbp_dbp_true)

        total_loss += loss.item()
        total_sbp += sbp_mae.item()
        total_dbp += dbp_mae.item()
        steps += 1

        pbar.set_postfix({
            "loss": f"{total_loss/steps:.3f}",
            "sbp":  f"{total_sbp/steps:.3f}",
            "dbp":  f"{total_dbp/steps:.3f}",
        })

    return (total_loss / max(1, steps),
            total_sbp / max(1, steps),
            total_dbp / max(1, steps))



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS_BPINDEX = 50

from torch.optim.lr_scheduler import ReduceLROnPlateau

# ---------- Baseline 1: PPG_freq -> SBP/DBP ----------
bpidx_ppg = Freq2SBPDBPModel(
    in_channels=2,   # PPG_real, PPG_imag
    hidden_dim=512,
    num_blocks=8,
    kernel_size=3,
    dropout=0.1,
).to(device)

optimizer_ppg = torch.optim.AdamW(
    bpidx_ppg.parameters(),
    lr=5e-4,
    weight_decay=1e-2
)

# Reduce LR on Plateau for PPG
scheduler_ppg = ReduceLROnPlateau(
    optimizer_ppg,
    mode="min",      # 监控的 metric 越小越好
    factor=0.5,      # 每次降低为原来的 0.5
    patience=4,      # 连续 5 个 epoch 没提升就降 LR
    min_lr=1e-6
)

BEST_PPG = float("inf")

for epoch in range(1, EPOCHS_BPINDEX + 1):
    current_lr = optimizer_ppg.param_groups[0]["lr"]
    print(f"\n===== BPIndex PPG - Epoch {epoch}, lr={current_lr:.6g} =====")

    train_loss, train_sbp, train_dbp = train_one_epoch_bpindex(
        bpidx_ppg, train_loader, optimizer_ppg, device, epoch,
        input_mode="ppg"
    )
    val_loss, val_sbp, val_dbp = eval_one_epoch_bpindex(
        bpidx_ppg, val_loader, device, epoch,
        input_mode="ppg"
    )
    print(f"PPG BPIndex Epoch {epoch}: "
          f"train_loss={train_loss:.4f}, train_sbp_mae={train_sbp:.4f}, train_dbp_mae={train_dbp:.4f}, "
          f"val_loss={val_loss:.4f}, val_sbp_mae={val_sbp:.4f}, val_dbp_mae={val_dbp:.4f}")

    # 先根据 val_loss 更新 scheduler
    scheduler_ppg.step(val_loss)

    # 再根据 val_loss 判断是否保存 best model
    if val_loss < BEST_PPG:
        BEST_PPG = val_loss
        torch.save(bpidx_ppg.state_dict(), "bpindex_ppg_cnn.pt")
        print(f"  ↳ Saved best PPG BPIndex model (val_loss = {BEST_PPG:.4f})")


# ---------- Baseline 2: ECG_freq -> SBP/DBP ----------
bpidx_ecg = Freq2SBPDBPModel(
    in_channels=2,   # ECG_real, ECG_imag
    hidden_dim=512,
    num_blocks=8,
    kernel_size=3,
    dropout=0.1,
).to(device)

optimizer_ecg = torch.optim.AdamW(
    bpidx_ecg.parameters(),
    lr=5e-4,
    weight_decay=1e-2
)

scheduler_ecg = ReduceLROnPlateau(
    optimizer_ecg,
    mode="min",
    factor=0.5,
    patience=4,
    min_lr=1e-6
)

BEST_ECG = float("inf")

for epoch in range(1, EPOCHS_BPINDEX + 1):
    current_lr = optimizer_ecg.param_groups[0]["lr"]
    print(f"\n===== BPIndex ECG - Epoch {epoch}, lr={current_lr:.6g} =====")

    train_loss, train_sbp, train_dbp = train_one_epoch_bpindex(
        bpidx_ecg, train_loader, optimizer_ecg, device, epoch,
        input_mode="ecg"
    )
    val_loss, val_sbp, val_dbp = eval_one_epoch_bpindex(
        bpidx_ecg, val_loader, device, epoch,
        input_mode="ecg"
    )
    print(f"ECG BPIndex Epoch {epoch}: "
          f"train_loss={train_loss:.4f}, train_sbp_mae={train_sbp:.4f}, train_dbp_mae={train_dbp:.4f}, "
          f"val_loss={val_loss:.4f}, val_sbp_mae={val_sbp:.4f}, val_dbp_mae={val_dbp:.4f}")

    scheduler_ecg.step(val_loss)

    if val_loss < BEST_ECG:
        BEST_ECG = val_loss
        torch.save(bpidx_ecg.state_dict(), "bpindex_ecg_cnn.pt")
        print(f"  ↳ Saved best ECG BPIndex model (val_loss = {BEST_ECG:.4f})")


# ---------- Baseline 3: PPG+ECG_freq -> SBP/DBP ----------
bpidx_ppg_ecg = Freq2SBPDBPModel(
    in_channels=4,   # PPG_real, PPG_imag, ECG_real, ECG_imag
    hidden_dim=512,
    num_blocks=8,
    kernel_size=3,
    dropout=0.1,
).to(device)

optimizer_ppg_ecg = torch.optim.AdamW(
    bpidx_ppg_ecg.parameters(),
    lr=5e-4,
    weight_decay=1e-2
)

scheduler_ppg_ecg = ReduceLROnPlateau(
    optimizer_ppg_ecg,
    mode="min",
    factor=0.5,
    patience=4,
    min_lr=1e-6
)

BEST_PPG_ECG = float("inf")

for epoch in range(1, EPOCHS_BPINDEX + 1):
    current_lr = optimizer_ppg_ecg.param_groups[0]["lr"]
    print(f"\n===== BPIndex PPG+ECG - Epoch {epoch}, lr={current_lr:.6g} =====")

    train_loss, train_sbp, train_dbp = train_one_epoch_bpindex(
        bpidx_ppg_ecg, train_loader, optimizer_ppg_ecg, device, epoch,
        input_mode="ppg_ecg"
    )
    val_loss, val_sbp, val_dbp = eval_one_epoch_bpindex(
        bpidx_ppg_ecg, val_loader, device, epoch,
        input_mode="ppg_ecg"
    )
    print(f"PPG+ECG BPIndex Epoch {epoch}: "
          f"train_loss={train_loss:.4f}, train_sbp_mae={train_sbp:.4f}, train_dbp_mae={train_dbp:.4f}, "
          f"val_loss={val_loss:.4f}, val_sbp_mae={val_sbp:.4f}, val_dbp_mae={val_dbp:.4f}")

    scheduler_ppg_ecg.step(val_loss)

    if val_loss < BEST_PPG_ECG:
        BEST_PPG_ECG = val_loss
        torch.save(bpidx_ppg_ecg.state_dict(), "bpindex_ppg_ecg_cnn.pt")
        print(f"  ↳ Saved best PPG+ECG BPIndex model (val_loss = {BEST_PPG_ECG:.4f})")



[Freq2SBPDBPModel] Trainable params: 12872706 (12.873 M)

===== BPIndex PPG - Epoch 1, lr=0.0005 =====


[BPIndex-ppg] Train epoch 1: 100%|█████████████| 1563/1563 [03:21<00:00,  7.75it/s, loss=35.315, sbp=19.904, dbp=15.411]
[BPIndex-ppg] Val epoch 1: 100%|█████████████████| 391/391 [00:36<00:00, 10.82it/s, loss=41.519, sbp=20.533, dbp=20.986]


PPG BPIndex Epoch 1: train_loss=35.3147, train_sbp_mae=19.9041, train_dbp_mae=15.4106, val_loss=41.5186, val_sbp_mae=20.5328, val_dbp_mae=20.9858
  ↳ Saved best PPG BPIndex model (val_loss = 41.5186)

===== BPIndex PPG - Epoch 2, lr=0.0005 =====


[BPIndex-ppg] Train epoch 2: 100%|█████████████| 1563/1563 [03:19<00:00,  7.85it/s, loss=31.452, sbp=17.502, dbp=13.950]
[BPIndex-ppg] Val epoch 2: 100%|█████████████████| 391/391 [00:36<00:00, 10.86it/s, loss=42.701, sbp=19.287, dbp=23.415]


PPG BPIndex Epoch 2: train_loss=31.4519, train_sbp_mae=17.5016, train_dbp_mae=13.9503, val_loss=42.7015, val_sbp_mae=19.2866, val_dbp_mae=23.4149

===== BPIndex PPG - Epoch 3, lr=0.0005 =====


[BPIndex-ppg] Train epoch 3: 100%|█████████████| 1563/1563 [03:19<00:00,  7.84it/s, loss=29.991, sbp=16.572, dbp=13.419]
[BPIndex-ppg] Val epoch 3: 100%|█████████████████| 391/391 [00:36<00:00, 10.86it/s, loss=38.002, sbp=18.484, dbp=19.518]


PPG BPIndex Epoch 3: train_loss=29.9910, train_sbp_mae=16.5722, train_dbp_mae=13.4188, val_loss=38.0025, val_sbp_mae=18.4845, val_dbp_mae=19.5180
  ↳ Saved best PPG BPIndex model (val_loss = 38.0025)

===== BPIndex PPG - Epoch 4, lr=0.0005 =====


[BPIndex-ppg] Train epoch 4: 100%|█████████████| 1563/1563 [03:19<00:00,  7.82it/s, loss=29.095, sbp=16.020, dbp=13.075]
[BPIndex-ppg] Val epoch 4: 100%|█████████████████| 391/391 [00:36<00:00, 10.85it/s, loss=38.233, sbp=18.432, dbp=19.801]


PPG BPIndex Epoch 4: train_loss=29.0954, train_sbp_mae=16.0201, train_dbp_mae=13.0752, val_loss=38.2331, val_sbp_mae=18.4320, val_dbp_mae=19.8011

===== BPIndex PPG - Epoch 5, lr=0.0005 =====


[BPIndex-ppg] Train epoch 5: 100%|█████████████| 1563/1563 [03:19<00:00,  7.82it/s, loss=28.501, sbp=15.659, dbp=12.842]
[BPIndex-ppg] Val epoch 5: 100%|█████████████████| 391/391 [00:36<00:00, 10.84it/s, loss=39.591, sbp=19.490, dbp=20.101]


PPG BPIndex Epoch 5: train_loss=28.5010, train_sbp_mae=15.6588, train_dbp_mae=12.8421, val_loss=39.5914, val_sbp_mae=19.4900, val_dbp_mae=20.1014

===== BPIndex PPG - Epoch 6, lr=0.0005 =====


[BPIndex-ppg] Train epoch 6: 100%|█████████████| 1563/1563 [03:19<00:00,  7.82it/s, loss=28.053, sbp=15.395, dbp=12.658]
[BPIndex-ppg] Val epoch 6: 100%|█████████████████| 391/391 [00:36<00:00, 10.85it/s, loss=37.512, sbp=17.834, dbp=19.678]


PPG BPIndex Epoch 6: train_loss=28.0535, train_sbp_mae=15.3953, train_dbp_mae=12.6582, val_loss=37.5120, val_sbp_mae=17.8344, val_dbp_mae=19.6776
  ↳ Saved best PPG BPIndex model (val_loss = 37.5120)

===== BPIndex PPG - Epoch 7, lr=0.0005 =====


[BPIndex-ppg] Train epoch 7: 100%|█████████████| 1563/1563 [03:19<00:00,  7.83it/s, loss=27.650, sbp=15.155, dbp=12.495]
[BPIndex-ppg] Val epoch 7: 100%|█████████████████| 391/391 [00:36<00:00, 10.85it/s, loss=39.096, sbp=18.853, dbp=20.244]


PPG BPIndex Epoch 7: train_loss=27.6500, train_sbp_mae=15.1555, train_dbp_mae=12.4945, val_loss=39.0963, val_sbp_mae=18.8526, val_dbp_mae=20.2437

===== BPIndex PPG - Epoch 8, lr=0.0005 =====


[BPIndex-ppg] Train epoch 8: 100%|█████████████| 1563/1563 [03:19<00:00,  7.82it/s, loss=27.302, sbp=14.955, dbp=12.347]
[BPIndex-ppg] Val epoch 8: 100%|█████████████████| 391/391 [00:36<00:00, 10.86it/s, loss=36.599, sbp=17.688, dbp=18.911]


PPG BPIndex Epoch 8: train_loss=27.3020, train_sbp_mae=14.9552, train_dbp_mae=12.3468, val_loss=36.5989, val_sbp_mae=17.6882, val_dbp_mae=18.9108
  ↳ Saved best PPG BPIndex model (val_loss = 36.5989)

===== BPIndex PPG - Epoch 9, lr=0.0005 =====


[BPIndex-ppg] Train epoch 9: 100%|█████████████| 1563/1563 [03:19<00:00,  7.83it/s, loss=27.016, sbp=14.788, dbp=12.227]
[BPIndex-ppg] Val epoch 9: 100%|█████████████████| 391/391 [00:36<00:00, 10.84it/s, loss=37.245, sbp=17.932, dbp=19.312]


PPG BPIndex Epoch 9: train_loss=27.0158, train_sbp_mae=14.7884, train_dbp_mae=12.2274, val_loss=37.2448, val_sbp_mae=17.9324, val_dbp_mae=19.3124

===== BPIndex PPG - Epoch 10, lr=0.0005 =====


[BPIndex-ppg] Train epoch 10: 100%|████████████| 1563/1563 [03:19<00:00,  7.85it/s, loss=26.701, sbp=14.609, dbp=12.092]
[BPIndex-ppg] Val epoch 10: 100%|████████████████| 391/391 [00:35<00:00, 10.88it/s, loss=36.247, sbp=17.483, dbp=18.765]


PPG BPIndex Epoch 10: train_loss=26.7009, train_sbp_mae=14.6089, train_dbp_mae=12.0920, val_loss=36.2474, val_sbp_mae=17.4826, val_dbp_mae=18.7648
  ↳ Saved best PPG BPIndex model (val_loss = 36.2474)

===== BPIndex PPG - Epoch 11, lr=0.0005 =====


[BPIndex-ppg] Train epoch 11: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=26.495, sbp=14.485, dbp=12.010]
[BPIndex-ppg] Val epoch 11: 100%|████████████████| 391/391 [00:36<00:00, 10.81it/s, loss=35.931, sbp=17.123, dbp=18.808]


PPG BPIndex Epoch 11: train_loss=26.4950, train_sbp_mae=14.4850, train_dbp_mae=12.0100, val_loss=35.9310, val_sbp_mae=17.1226, val_dbp_mae=18.8085
  ↳ Saved best PPG BPIndex model (val_loss = 35.9310)

===== BPIndex PPG - Epoch 12, lr=0.0005 =====


[BPIndex-ppg] Train epoch 12: 100%|████████████| 1563/1563 [03:19<00:00,  7.82it/s, loss=26.300, sbp=14.370, dbp=11.930]
[BPIndex-ppg] Val epoch 12: 100%|████████████████| 391/391 [00:35<00:00, 10.86it/s, loss=35.531, sbp=17.041, dbp=18.490]


PPG BPIndex Epoch 12: train_loss=26.2998, train_sbp_mae=14.3699, train_dbp_mae=11.9300, val_loss=35.5312, val_sbp_mae=17.0411, val_dbp_mae=18.4901
  ↳ Saved best PPG BPIndex model (val_loss = 35.5312)

===== BPIndex PPG - Epoch 13, lr=0.0005 =====


[BPIndex-ppg] Train epoch 13: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=26.082, sbp=14.234, dbp=11.847]
[BPIndex-ppg] Val epoch 13: 100%|████████████████| 391/391 [00:35<00:00, 10.86it/s, loss=35.775, sbp=17.165, dbp=18.610]


PPG BPIndex Epoch 13: train_loss=26.0816, train_sbp_mae=14.2343, train_dbp_mae=11.8473, val_loss=35.7751, val_sbp_mae=17.1653, val_dbp_mae=18.6098

===== BPIndex PPG - Epoch 14, lr=0.0005 =====


[BPIndex-ppg] Train epoch 14: 100%|████████████| 1563/1563 [03:19<00:00,  7.83it/s, loss=25.881, sbp=14.121, dbp=11.760]
[BPIndex-ppg] Val epoch 14: 100%|████████████████| 391/391 [00:36<00:00, 10.86it/s, loss=38.414, sbp=18.156, dbp=20.258]


PPG BPIndex Epoch 14: train_loss=25.8806, train_sbp_mae=14.1209, train_dbp_mae=11.7597, val_loss=38.4145, val_sbp_mae=18.1565, val_dbp_mae=20.2580

===== BPIndex PPG - Epoch 15, lr=0.0005 =====


[BPIndex-ppg] Train epoch 15: 100%|████████████| 1563/1563 [03:19<00:00,  7.83it/s, loss=25.752, sbp=14.044, dbp=11.708]
[BPIndex-ppg] Val epoch 15: 100%|████████████████| 391/391 [00:36<00:00, 10.85it/s, loss=35.628, sbp=16.485, dbp=19.143]


PPG BPIndex Epoch 15: train_loss=25.7518, train_sbp_mae=14.0436, train_dbp_mae=11.7082, val_loss=35.6282, val_sbp_mae=16.4854, val_dbp_mae=19.1427

===== BPIndex PPG - Epoch 16, lr=0.0005 =====


[BPIndex-ppg] Train epoch 16: 100%|████████████| 1563/1563 [03:19<00:00,  7.82it/s, loss=25.557, sbp=13.931, dbp=11.626]
[BPIndex-ppg] Val epoch 16: 100%|████████████████| 391/391 [00:35<00:00, 10.87it/s, loss=38.590, sbp=19.485, dbp=19.105]


PPG BPIndex Epoch 16: train_loss=25.5570, train_sbp_mae=13.9309, train_dbp_mae=11.6261, val_loss=38.5903, val_sbp_mae=19.4848, val_dbp_mae=19.1055

===== BPIndex PPG - Epoch 17, lr=0.0005 =====


[BPIndex-ppg] Train epoch 17: 100%|████████████| 1563/1563 [03:19<00:00,  7.83it/s, loss=25.436, sbp=13.860, dbp=11.576]
[BPIndex-ppg] Val epoch 17: 100%|████████████████| 391/391 [00:36<00:00, 10.85it/s, loss=35.059, sbp=16.809, dbp=18.250]


PPG BPIndex Epoch 17: train_loss=25.4360, train_sbp_mae=13.8602, train_dbp_mae=11.5759, val_loss=35.0589, val_sbp_mae=16.8094, val_dbp_mae=18.2495
  ↳ Saved best PPG BPIndex model (val_loss = 35.0589)

===== BPIndex PPG - Epoch 18, lr=0.0005 =====


[BPIndex-ppg] Train epoch 18: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=25.256, sbp=13.753, dbp=11.503]
[BPIndex-ppg] Val epoch 18: 100%|████████████████| 391/391 [00:36<00:00, 10.86it/s, loss=35.396, sbp=17.021, dbp=18.375]


PPG BPIndex Epoch 18: train_loss=25.2562, train_sbp_mae=13.7531, train_dbp_mae=11.5031, val_loss=35.3955, val_sbp_mae=17.0208, val_dbp_mae=18.3747

===== BPIndex PPG - Epoch 19, lr=0.0005 =====


[BPIndex-ppg] Train epoch 19: 100%|████████████| 1563/1563 [03:19<00:00,  7.82it/s, loss=25.107, sbp=13.656, dbp=11.451]
[BPIndex-ppg] Val epoch 19: 100%|████████████████| 391/391 [00:36<00:00, 10.84it/s, loss=35.328, sbp=16.636, dbp=18.691]


PPG BPIndex Epoch 19: train_loss=25.1072, train_sbp_mae=13.6563, train_dbp_mae=11.4509, val_loss=35.3277, val_sbp_mae=16.6365, val_dbp_mae=18.6912

===== BPIndex PPG - Epoch 20, lr=0.0005 =====


[BPIndex-ppg] Train epoch 20: 100%|████████████| 1563/1563 [03:19<00:00,  7.83it/s, loss=24.980, sbp=13.574, dbp=11.406]
[BPIndex-ppg] Val epoch 20: 100%|████████████████| 391/391 [00:35<00:00, 10.86it/s, loss=34.377, sbp=16.383, dbp=17.994]


PPG BPIndex Epoch 20: train_loss=24.9798, train_sbp_mae=13.5736, train_dbp_mae=11.4062, val_loss=34.3773, val_sbp_mae=16.3828, val_dbp_mae=17.9945
  ↳ Saved best PPG BPIndex model (val_loss = 34.3773)

===== BPIndex PPG - Epoch 21, lr=0.0005 =====


[BPIndex-ppg] Train epoch 21: 100%|████████████| 1563/1563 [03:19<00:00,  7.83it/s, loss=24.851, sbp=13.509, dbp=11.343]
[BPIndex-ppg] Val epoch 21: 100%|████████████████| 391/391 [00:36<00:00, 10.86it/s, loss=34.833, sbp=16.626, dbp=18.206]


PPG BPIndex Epoch 21: train_loss=24.8511, train_sbp_mae=13.5086, train_dbp_mae=11.3426, val_loss=34.8328, val_sbp_mae=16.6264, val_dbp_mae=18.2064

===== BPIndex PPG - Epoch 22, lr=0.0005 =====


[BPIndex-ppg] Train epoch 22: 100%|████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=24.714, sbp=13.414, dbp=11.300]
[BPIndex-ppg] Val epoch 22: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=34.831, sbp=16.609, dbp=18.221]


PPG BPIndex Epoch 22: train_loss=24.7142, train_sbp_mae=13.4144, train_dbp_mae=11.2997, val_loss=34.8310, val_sbp_mae=16.6095, val_dbp_mae=18.2215

===== BPIndex PPG - Epoch 23, lr=0.0005 =====


[BPIndex-ppg] Train epoch 23: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=24.573, sbp=13.334, dbp=11.239]
[BPIndex-ppg] Val epoch 23: 100%|████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=34.944, sbp=16.661, dbp=18.283]


PPG BPIndex Epoch 23: train_loss=24.5732, train_sbp_mae=13.3344, train_dbp_mae=11.2389, val_loss=34.9438, val_sbp_mae=16.6612, val_dbp_mae=18.2826

===== BPIndex PPG - Epoch 24, lr=0.0005 =====


[BPIndex-ppg] Train epoch 24: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=24.466, sbp=13.274, dbp=11.192]
[BPIndex-ppg] Val epoch 24: 100%|████████████████| 391/391 [00:36<00:00, 10.82it/s, loss=34.128, sbp=16.345, dbp=17.783]


PPG BPIndex Epoch 24: train_loss=24.4660, train_sbp_mae=13.2735, train_dbp_mae=11.1925, val_loss=34.1277, val_sbp_mae=16.3447, val_dbp_mae=17.7830
  ↳ Saved best PPG BPIndex model (val_loss = 34.1277)

===== BPIndex PPG - Epoch 25, lr=0.0005 =====


[BPIndex-ppg] Train epoch 25: 100%|████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=24.331, sbp=13.188, dbp=11.143]
[BPIndex-ppg] Val epoch 25: 100%|████████████████| 391/391 [00:36<00:00, 10.79it/s, loss=34.663, sbp=16.559, dbp=18.104]


PPG BPIndex Epoch 25: train_loss=24.3314, train_sbp_mae=13.1880, train_dbp_mae=11.1434, val_loss=34.6630, val_sbp_mae=16.5593, val_dbp_mae=18.1037

===== BPIndex PPG - Epoch 26, lr=0.0005 =====


[BPIndex-ppg] Train epoch 26: 100%|████████████| 1563/1563 [03:20<00:00,  7.80it/s, loss=24.177, sbp=13.095, dbp=11.082]
[BPIndex-ppg] Val epoch 26: 100%|████████████████| 391/391 [00:36<00:00, 10.79it/s, loss=34.893, sbp=16.677, dbp=18.216]


PPG BPIndex Epoch 26: train_loss=24.1765, train_sbp_mae=13.0949, train_dbp_mae=11.0816, val_loss=34.8935, val_sbp_mae=16.6771, val_dbp_mae=18.2164

===== BPIndex PPG - Epoch 27, lr=0.0005 =====


[BPIndex-ppg] Train epoch 27: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=24.071, sbp=13.029, dbp=11.042]
[BPIndex-ppg] Val epoch 27: 100%|████████████████| 391/391 [00:36<00:00, 10.82it/s, loss=36.932, sbp=18.137, dbp=18.795]


PPG BPIndex Epoch 27: train_loss=24.0710, train_sbp_mae=13.0295, train_dbp_mae=11.0416, val_loss=36.9316, val_sbp_mae=18.1371, val_dbp_mae=18.7945

===== BPIndex PPG - Epoch 28, lr=0.0005 =====


[BPIndex-ppg] Train epoch 28: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=23.954, sbp=12.957, dbp=10.996]
[BPIndex-ppg] Val epoch 28: 100%|████████████████| 391/391 [00:36<00:00, 10.81it/s, loss=33.882, sbp=16.153, dbp=17.729]


PPG BPIndex Epoch 28: train_loss=23.9538, train_sbp_mae=12.9574, train_dbp_mae=10.9964, val_loss=33.8818, val_sbp_mae=16.1526, val_dbp_mae=17.7291
  ↳ Saved best PPG BPIndex model (val_loss = 33.8818)

===== BPIndex PPG - Epoch 29, lr=0.0005 =====


[BPIndex-ppg] Train epoch 29: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=23.806, sbp=12.862, dbp=10.944]
[BPIndex-ppg] Val epoch 29: 100%|████████████████| 391/391 [00:36<00:00, 10.78it/s, loss=33.985, sbp=16.162, dbp=17.824]


PPG BPIndex Epoch 29: train_loss=23.8064, train_sbp_mae=12.8623, train_dbp_mae=10.9440, val_loss=33.9855, val_sbp_mae=16.1620, val_dbp_mae=17.8235

===== BPIndex PPG - Epoch 30, lr=0.0005 =====


[BPIndex-ppg] Train epoch 30: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=23.657, sbp=12.774, dbp=10.883]
[BPIndex-ppg] Val epoch 30: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=34.418, sbp=16.574, dbp=17.844]


PPG BPIndex Epoch 30: train_loss=23.6572, train_sbp_mae=12.7741, train_dbp_mae=10.8831, val_loss=34.4176, val_sbp_mae=16.5739, val_dbp_mae=17.8437

===== BPIndex PPG - Epoch 31, lr=0.0005 =====


[BPIndex-ppg] Train epoch 31: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=23.532, sbp=12.698, dbp=10.834]
[BPIndex-ppg] Val epoch 31: 100%|████████████████| 391/391 [00:36<00:00, 10.82it/s, loss=34.474, sbp=16.180, dbp=18.293]


PPG BPIndex Epoch 31: train_loss=23.5319, train_sbp_mae=12.6976, train_dbp_mae=10.8343, val_loss=34.4737, val_sbp_mae=16.1803, val_dbp_mae=18.2934

===== BPIndex PPG - Epoch 32, lr=0.0005 =====


[BPIndex-ppg] Train epoch 32: 100%|████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=23.379, sbp=12.605, dbp=10.774]
[BPIndex-ppg] Val epoch 32: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=34.775, sbp=16.509, dbp=18.266]


PPG BPIndex Epoch 32: train_loss=23.3792, train_sbp_mae=12.6050, train_dbp_mae=10.7742, val_loss=34.7749, val_sbp_mae=16.5092, val_dbp_mae=18.2658

===== BPIndex PPG - Epoch 33, lr=0.0005 =====


[BPIndex-ppg] Train epoch 33: 100%|████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=23.246, sbp=12.527, dbp=10.720]
[BPIndex-ppg] Val epoch 33: 100%|████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=33.794, sbp=16.090, dbp=17.704]


PPG BPIndex Epoch 33: train_loss=23.2463, train_sbp_mae=12.5268, train_dbp_mae=10.7195, val_loss=33.7935, val_sbp_mae=16.0896, val_dbp_mae=17.7040
  ↳ Saved best PPG BPIndex model (val_loss = 33.7935)

===== BPIndex PPG - Epoch 34, lr=0.0005 =====


[BPIndex-ppg] Train epoch 34: 100%|████████████| 1563/1563 [03:20<00:00,  7.80it/s, loss=23.091, sbp=12.433, dbp=10.658]
[BPIndex-ppg] Val epoch 34: 100%|████████████████| 391/391 [00:36<00:00, 10.85it/s, loss=33.839, sbp=16.197, dbp=17.642]


PPG BPIndex Epoch 34: train_loss=23.0907, train_sbp_mae=12.4329, train_dbp_mae=10.6578, val_loss=33.8388, val_sbp_mae=16.1967, val_dbp_mae=17.6421

===== BPIndex PPG - Epoch 35, lr=0.0005 =====


[BPIndex-ppg] Train epoch 35: 100%|████████████| 1563/1563 [03:20<00:00,  7.80it/s, loss=22.965, sbp=12.358, dbp=10.607]
[BPIndex-ppg] Val epoch 35: 100%|████████████████| 391/391 [00:36<00:00, 10.79it/s, loss=33.977, sbp=16.471, dbp=17.507]


PPG BPIndex Epoch 35: train_loss=22.9646, train_sbp_mae=12.3576, train_dbp_mae=10.6069, val_loss=33.9771, val_sbp_mae=16.4706, val_dbp_mae=17.5065

===== BPIndex PPG - Epoch 36, lr=0.0005 =====


[BPIndex-ppg] Train epoch 36: 100%|████████████| 1563/1563 [03:20<00:00,  7.80it/s, loss=22.837, sbp=12.284, dbp=10.552]
[BPIndex-ppg] Val epoch 36: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=33.050, sbp=15.636, dbp=17.414]


PPG BPIndex Epoch 36: train_loss=22.8366, train_sbp_mae=12.2842, train_dbp_mae=10.5523, val_loss=33.0498, val_sbp_mae=15.6357, val_dbp_mae=17.4141
  ↳ Saved best PPG BPIndex model (val_loss = 33.0498)

===== BPIndex PPG - Epoch 37, lr=0.0005 =====


[BPIndex-ppg] Train epoch 37: 100%|████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=22.668, sbp=12.181, dbp=10.487]
[BPIndex-ppg] Val epoch 37: 100%|████████████████| 391/391 [00:36<00:00, 10.81it/s, loss=34.306, sbp=16.706, dbp=17.600]


PPG BPIndex Epoch 37: train_loss=22.6684, train_sbp_mae=12.1811, train_dbp_mae=10.4873, val_loss=34.3056, val_sbp_mae=16.7057, val_dbp_mae=17.5999

===== BPIndex PPG - Epoch 38, lr=0.0005 =====


[BPIndex-ppg] Train epoch 38: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=22.576, sbp=12.140, dbp=10.436]
[BPIndex-ppg] Val epoch 38: 100%|████████████████| 391/391 [00:36<00:00, 10.82it/s, loss=34.030, sbp=16.296, dbp=17.734]


PPG BPIndex Epoch 38: train_loss=22.5760, train_sbp_mae=12.1399, train_dbp_mae=10.4361, val_loss=34.0298, val_sbp_mae=16.2963, val_dbp_mae=17.7335

===== BPIndex PPG - Epoch 39, lr=0.0005 =====


[BPIndex-ppg] Train epoch 39: 100%|████████████| 1563/1563 [03:19<00:00,  7.82it/s, loss=22.451, sbp=12.049, dbp=10.402]
[BPIndex-ppg] Val epoch 39: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=33.628, sbp=16.009, dbp=17.618]


PPG BPIndex Epoch 39: train_loss=22.4506, train_sbp_mae=12.0489, train_dbp_mae=10.4017, val_loss=33.6277, val_sbp_mae=16.0092, val_dbp_mae=17.6185

===== BPIndex PPG - Epoch 40, lr=0.0005 =====


[BPIndex-ppg] Train epoch 40: 100%|████████████| 1563/1563 [03:20<00:00,  7.80it/s, loss=22.297, sbp=11.961, dbp=10.336]
[BPIndex-ppg] Val epoch 40: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=34.364, sbp=16.586, dbp=17.778]


PPG BPIndex Epoch 40: train_loss=22.2973, train_sbp_mae=11.9609, train_dbp_mae=10.3364, val_loss=34.3644, val_sbp_mae=16.5860, val_dbp_mae=17.7784

===== BPIndex PPG - Epoch 41, lr=0.0005 =====


[BPIndex-ppg] Train epoch 41: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=22.159, sbp=11.883, dbp=10.276]
[BPIndex-ppg] Val epoch 41: 100%|████████████████| 391/391 [00:36<00:00, 10.85it/s, loss=34.787, sbp=16.711, dbp=18.076]


PPG BPIndex Epoch 41: train_loss=22.1588, train_sbp_mae=11.8831, train_dbp_mae=10.2756, val_loss=34.7870, val_sbp_mae=16.7114, val_dbp_mae=18.0756

===== BPIndex PPG - Epoch 42, lr=0.00025 =====


[BPIndex-ppg] Train epoch 42: 100%|█████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=21.197, sbp=11.310, dbp=9.887]
[BPIndex-ppg] Val epoch 42: 100%|████████████████| 391/391 [00:36<00:00, 10.82it/s, loss=33.657, sbp=16.081, dbp=17.576]


PPG BPIndex Epoch 42: train_loss=21.1974, train_sbp_mae=11.3102, train_dbp_mae=9.8871, val_loss=33.6567, val_sbp_mae=16.0810, val_dbp_mae=17.5757

===== BPIndex PPG - Epoch 43, lr=0.00025 =====


[BPIndex-ppg] Train epoch 43: 100%|█████████████| 1563/1563 [03:20<00:00,  7.80it/s, loss=20.923, sbp=11.154, dbp=9.770]
[BPIndex-ppg] Val epoch 43: 100%|████████████████| 391/391 [00:36<00:00, 10.82it/s, loss=32.668, sbp=15.410, dbp=17.259]


PPG BPIndex Epoch 43: train_loss=20.9232, train_sbp_mae=11.1535, train_dbp_mae=9.7697, val_loss=32.6684, val_sbp_mae=15.4099, val_dbp_mae=17.2586
  ↳ Saved best PPG BPIndex model (val_loss = 32.6684)

===== BPIndex PPG - Epoch 44, lr=0.00025 =====


[BPIndex-ppg] Train epoch 44: 100%|█████████████| 1563/1563 [03:20<00:00,  7.80it/s, loss=20.727, sbp=11.037, dbp=9.690]
[BPIndex-ppg] Val epoch 44: 100%|████████████████| 391/391 [00:36<00:00, 10.81it/s, loss=32.900, sbp=15.588, dbp=17.311]


PPG BPIndex Epoch 44: train_loss=20.7272, train_sbp_mae=11.0374, train_dbp_mae=9.6898, val_loss=32.8996, val_sbp_mae=15.5883, val_dbp_mae=17.3113

===== BPIndex PPG - Epoch 45, lr=0.00025 =====


[BPIndex-ppg] Train epoch 45: 100%|█████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=20.559, sbp=10.939, dbp=9.620]
[BPIndex-ppg] Val epoch 45: 100%|████████████████| 391/391 [00:36<00:00, 10.77it/s, loss=33.025, sbp=15.529, dbp=17.496]


PPG BPIndex Epoch 45: train_loss=20.5590, train_sbp_mae=10.9388, train_dbp_mae=9.6202, val_loss=33.0252, val_sbp_mae=15.5292, val_dbp_mae=17.4960

===== BPIndex PPG - Epoch 46, lr=0.00025 =====


[BPIndex-ppg] Train epoch 46: 100%|█████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=20.419, sbp=10.860, dbp=9.559]
[BPIndex-ppg] Val epoch 46: 100%|████████████████| 391/391 [00:36<00:00, 10.79it/s, loss=32.746, sbp=15.519, dbp=17.228]


PPG BPIndex Epoch 46: train_loss=20.4191, train_sbp_mae=10.8603, train_dbp_mae=9.5587, val_loss=32.7465, val_sbp_mae=15.5189, val_dbp_mae=17.2276

===== BPIndex PPG - Epoch 47, lr=0.00025 =====


[BPIndex-ppg] Train epoch 47: 100%|█████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=20.272, sbp=10.788, dbp=9.484]
[BPIndex-ppg] Val epoch 47: 100%|████████████████| 391/391 [00:36<00:00, 10.77it/s, loss=33.850, sbp=16.084, dbp=17.766]


PPG BPIndex Epoch 47: train_loss=20.2721, train_sbp_mae=10.7883, train_dbp_mae=9.4837, val_loss=33.8503, val_sbp_mae=16.0842, val_dbp_mae=17.7660

===== BPIndex PPG - Epoch 48, lr=0.00025 =====


[BPIndex-ppg] Train epoch 48: 100%|█████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=20.105, sbp=10.685, dbp=9.420]
[BPIndex-ppg] Val epoch 48: 100%|████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=33.741, sbp=16.166, dbp=17.574]


PPG BPIndex Epoch 48: train_loss=20.1054, train_sbp_mae=10.6853, train_dbp_mae=9.4201, val_loss=33.7408, val_sbp_mae=16.1663, val_dbp_mae=17.5745

===== BPIndex PPG - Epoch 49, lr=0.000125 =====


[BPIndex-ppg] Train epoch 49: 100%|█████████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=19.450, sbp=10.309, dbp=9.141]
[BPIndex-ppg] Val epoch 49: 100%|████████████████| 391/391 [00:36<00:00, 10.79it/s, loss=33.567, sbp=16.092, dbp=17.475]


PPG BPIndex Epoch 49: train_loss=19.4496, train_sbp_mae=10.3086, train_dbp_mae=9.1410, val_loss=33.5672, val_sbp_mae=16.0918, val_dbp_mae=17.4754

===== BPIndex PPG - Epoch 50, lr=0.000125 =====


[BPIndex-ppg] Train epoch 50: 100%|█████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=19.262, sbp=10.211, dbp=9.051]
[BPIndex-ppg] Val epoch 50: 100%|████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=33.117, sbp=15.895, dbp=17.222]


PPG BPIndex Epoch 50: train_loss=19.2621, train_sbp_mae=10.2110, train_dbp_mae=9.0511, val_loss=33.1171, val_sbp_mae=15.8949, val_dbp_mae=17.2222
[Freq2SBPDBPModel] Trainable params: 12872706 (12.873 M)

===== BPIndex ECG - Epoch 1, lr=0.0005 =====


[BPIndex-ecg] Train epoch 1: 100%|█████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=35.326, sbp=20.000, dbp=15.326]
[BPIndex-ecg] Val epoch 1: 100%|█████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=41.471, sbp=20.974, dbp=20.497]


ECG BPIndex Epoch 1: train_loss=35.3258, train_sbp_mae=20.0001, train_dbp_mae=15.3257, val_loss=41.4705, val_sbp_mae=20.9740, val_dbp_mae=20.4965
  ↳ Saved best ECG BPIndex model (val_loss = 41.4705)

===== BPIndex ECG - Epoch 2, lr=0.0005 =====


[BPIndex-ecg] Train epoch 2: 100%|█████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=30.843, sbp=17.386, dbp=13.457]
[BPIndex-ecg] Val epoch 2: 100%|█████████████████| 391/391 [00:36<00:00, 10.79it/s, loss=39.586, sbp=19.872, dbp=19.714]


ECG BPIndex Epoch 2: train_loss=30.8432, train_sbp_mae=17.3862, train_dbp_mae=13.4570, val_loss=39.5858, val_sbp_mae=19.8717, val_dbp_mae=19.7142
  ↳ Saved best ECG BPIndex model (val_loss = 39.5858)

===== BPIndex ECG - Epoch 3, lr=0.0005 =====


[BPIndex-ecg] Train epoch 3: 100%|█████████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=29.247, sbp=16.390, dbp=12.856]
[BPIndex-ecg] Val epoch 3: 100%|█████████████████| 391/391 [00:36<00:00, 10.81it/s, loss=37.844, sbp=18.852, dbp=18.992]


ECG BPIndex Epoch 3: train_loss=29.2468, train_sbp_mae=16.3903, train_dbp_mae=12.8564, val_loss=37.8442, val_sbp_mae=18.8518, val_dbp_mae=18.9924
  ↳ Saved best ECG BPIndex model (val_loss = 37.8442)

===== BPIndex ECG - Epoch 4, lr=0.0005 =====


[BPIndex-ecg] Train epoch 4: 100%|█████████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=28.144, sbp=15.704, dbp=12.440]
[BPIndex-ecg] Val epoch 4: 100%|█████████████████| 391/391 [00:36<00:00, 10.84it/s, loss=37.151, sbp=18.494, dbp=18.657]


ECG BPIndex Epoch 4: train_loss=28.1436, train_sbp_mae=15.7036, train_dbp_mae=12.4400, val_loss=37.1507, val_sbp_mae=18.4937, val_dbp_mae=18.6570
  ↳ Saved best ECG BPIndex model (val_loss = 37.1507)

===== BPIndex ECG - Epoch 5, lr=0.0005 =====


[BPIndex-ecg] Train epoch 5: 100%|█████████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=27.412, sbp=15.247, dbp=12.166]
[BPIndex-ecg] Val epoch 5: 100%|█████████████████| 391/391 [00:36<00:00, 10.79it/s, loss=36.162, sbp=17.856, dbp=18.306]


ECG BPIndex Epoch 5: train_loss=27.4125, train_sbp_mae=15.2466, train_dbp_mae=12.1659, val_loss=36.1622, val_sbp_mae=17.8558, val_dbp_mae=18.3064
  ↳ Saved best ECG BPIndex model (val_loss = 36.1622)

===== BPIndex ECG - Epoch 6, lr=0.0005 =====


[BPIndex-ecg] Train epoch 6: 100%|█████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=26.804, sbp=14.882, dbp=11.922]
[BPIndex-ecg] Val epoch 6: 100%|█████████████████| 391/391 [00:36<00:00, 10.79it/s, loss=36.060, sbp=18.044, dbp=18.016]


ECG BPIndex Epoch 6: train_loss=26.8039, train_sbp_mae=14.8816, train_dbp_mae=11.9222, val_loss=36.0596, val_sbp_mae=18.0437, val_dbp_mae=18.0159
  ↳ Saved best ECG BPIndex model (val_loss = 36.0596)

===== BPIndex ECG - Epoch 7, lr=0.0005 =====


[BPIndex-ecg] Train epoch 7: 100%|█████████████| 1563/1563 [03:20<00:00,  7.80it/s, loss=26.278, sbp=14.561, dbp=11.717]
[BPIndex-ecg] Val epoch 7: 100%|█████████████████| 391/391 [00:36<00:00, 10.82it/s, loss=35.410, sbp=17.329, dbp=18.082]


ECG BPIndex Epoch 7: train_loss=26.2776, train_sbp_mae=14.5610, train_dbp_mae=11.7166, val_loss=35.4104, val_sbp_mae=17.3286, val_dbp_mae=18.0817
  ↳ Saved best ECG BPIndex model (val_loss = 35.4104)

===== BPIndex ECG - Epoch 8, lr=0.0005 =====


[BPIndex-ecg] Train epoch 8: 100%|█████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=25.840, sbp=14.291, dbp=11.549]
[BPIndex-ecg] Val epoch 8: 100%|█████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=34.832, sbp=17.088, dbp=17.744]


ECG BPIndex Epoch 8: train_loss=25.8398, train_sbp_mae=14.2909, train_dbp_mae=11.5489, val_loss=34.8324, val_sbp_mae=17.0879, val_dbp_mae=17.7445
  ↳ Saved best ECG BPIndex model (val_loss = 34.8324)

===== BPIndex ECG - Epoch 9, lr=0.0005 =====


[BPIndex-ecg] Train epoch 9: 100%|█████████████| 1563/1563 [03:20<00:00,  7.80it/s, loss=25.433, sbp=14.042, dbp=11.391]
[BPIndex-ecg] Val epoch 9: 100%|█████████████████| 391/391 [00:36<00:00, 10.84it/s, loss=34.708, sbp=17.044, dbp=17.664]


ECG BPIndex Epoch 9: train_loss=25.4329, train_sbp_mae=14.0423, train_dbp_mae=11.3906, val_loss=34.7083, val_sbp_mae=17.0440, val_dbp_mae=17.6643
  ↳ Saved best ECG BPIndex model (val_loss = 34.7083)

===== BPIndex ECG - Epoch 10, lr=0.0005 =====


[BPIndex-ecg] Train epoch 10: 100%|████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=25.104, sbp=13.843, dbp=11.261]
[BPIndex-ecg] Val epoch 10: 100%|████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=33.983, sbp=16.556, dbp=17.427]


ECG BPIndex Epoch 10: train_loss=25.1040, train_sbp_mae=13.8433, train_dbp_mae=11.2608, val_loss=33.9825, val_sbp_mae=16.5556, val_dbp_mae=17.4269
  ↳ Saved best ECG BPIndex model (val_loss = 33.9825)

===== BPIndex ECG - Epoch 11, lr=0.0005 =====


[BPIndex-ecg] Train epoch 11: 100%|████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=24.788, sbp=13.657, dbp=11.131]
[BPIndex-ecg] Val epoch 11: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=33.349, sbp=16.142, dbp=17.207]


ECG BPIndex Epoch 11: train_loss=24.7879, train_sbp_mae=13.6566, train_dbp_mae=11.1313, val_loss=33.3489, val_sbp_mae=16.1421, val_dbp_mae=17.2069
  ↳ Saved best ECG BPIndex model (val_loss = 33.3489)

===== BPIndex ECG - Epoch 12, lr=0.0005 =====


[BPIndex-ecg] Train epoch 12: 100%|████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=24.484, sbp=13.474, dbp=11.010]
[BPIndex-ecg] Val epoch 12: 100%|████████████████| 391/391 [00:36<00:00, 10.75it/s, loss=33.734, sbp=16.551, dbp=17.183]


ECG BPIndex Epoch 12: train_loss=24.4837, train_sbp_mae=13.4735, train_dbp_mae=11.0102, val_loss=33.7341, val_sbp_mae=16.5512, val_dbp_mae=17.1829

===== BPIndex ECG - Epoch 13, lr=0.0005 =====


[BPIndex-ecg] Train epoch 13: 100%|████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=24.252, sbp=13.330, dbp=10.922]
[BPIndex-ecg] Val epoch 13: 100%|████████████████| 391/391 [00:36<00:00, 10.81it/s, loss=33.533, sbp=16.267, dbp=17.266]


ECG BPIndex Epoch 13: train_loss=24.2519, train_sbp_mae=13.3299, train_dbp_mae=10.9220, val_loss=33.5332, val_sbp_mae=16.2672, val_dbp_mae=17.2660

===== BPIndex ECG - Epoch 14, lr=0.0005 =====


[BPIndex-ecg] Train epoch 14: 100%|████████████| 1563/1563 [03:20<00:00,  7.80it/s, loss=23.958, sbp=13.155, dbp=10.803]
[BPIndex-ecg] Val epoch 14: 100%|████████████████| 391/391 [00:36<00:00, 10.84it/s, loss=32.746, sbp=15.712, dbp=17.034]


ECG BPIndex Epoch 14: train_loss=23.9580, train_sbp_mae=13.1546, train_dbp_mae=10.8033, val_loss=32.7459, val_sbp_mae=15.7118, val_dbp_mae=17.0341
  ↳ Saved best ECG BPIndex model (val_loss = 32.7459)

===== BPIndex ECG - Epoch 15, lr=0.0005 =====


[BPIndex-ecg] Train epoch 15: 100%|████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=23.731, sbp=13.013, dbp=10.718]
[BPIndex-ecg] Val epoch 15: 100%|████████████████| 391/391 [00:36<00:00, 10.82it/s, loss=33.403, sbp=16.212, dbp=17.190]


ECG BPIndex Epoch 15: train_loss=23.7313, train_sbp_mae=13.0132, train_dbp_mae=10.7181, val_loss=33.4026, val_sbp_mae=16.2124, val_dbp_mae=17.1902

===== BPIndex ECG - Epoch 16, lr=0.0005 =====


[BPIndex-ecg] Train epoch 16: 100%|████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=23.485, sbp=12.871, dbp=10.614]
[BPIndex-ecg] Val epoch 16: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=33.509, sbp=16.078, dbp=17.431]


ECG BPIndex Epoch 16: train_loss=23.4846, train_sbp_mae=12.8709, train_dbp_mae=10.6137, val_loss=33.5094, val_sbp_mae=16.0782, val_dbp_mae=17.4312

===== BPIndex ECG - Epoch 17, lr=0.0005 =====


[BPIndex-ecg] Train epoch 17: 100%|████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=23.264, sbp=12.728, dbp=10.536]
[BPIndex-ecg] Val epoch 17: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=32.639, sbp=15.523, dbp=17.116]


ECG BPIndex Epoch 17: train_loss=23.2638, train_sbp_mae=12.7282, train_dbp_mae=10.5357, val_loss=32.6390, val_sbp_mae=15.5228, val_dbp_mae=17.1163
  ↳ Saved best ECG BPIndex model (val_loss = 32.6390)

===== BPIndex ECG - Epoch 18, lr=0.0005 =====


[BPIndex-ecg] Train epoch 18: 100%|████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=23.059, sbp=12.607, dbp=10.452]
[BPIndex-ecg] Val epoch 18: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=32.250, sbp=15.489, dbp=16.762]


ECG BPIndex Epoch 18: train_loss=23.0589, train_sbp_mae=12.6073, train_dbp_mae=10.4515, val_loss=32.2504, val_sbp_mae=15.4888, val_dbp_mae=16.7616
  ↳ Saved best ECG BPIndex model (val_loss = 32.2504)

===== BPIndex ECG - Epoch 19, lr=0.0005 =====


[BPIndex-ecg] Train epoch 19: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=22.818, sbp=12.462, dbp=10.356]
[BPIndex-ecg] Val epoch 19: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=32.724, sbp=15.696, dbp=17.028]


ECG BPIndex Epoch 19: train_loss=22.8180, train_sbp_mae=12.4616, train_dbp_mae=10.3564, val_loss=32.7244, val_sbp_mae=15.6961, val_dbp_mae=17.0283

===== BPIndex ECG - Epoch 20, lr=0.0005 =====


[BPIndex-ecg] Train epoch 20: 100%|████████████| 1563/1563 [03:20<00:00,  7.81it/s, loss=22.598, sbp=12.324, dbp=10.274]
[BPIndex-ecg] Val epoch 20: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=32.119, sbp=15.352, dbp=16.766]


ECG BPIndex Epoch 20: train_loss=22.5984, train_sbp_mae=12.3240, train_dbp_mae=10.2744, val_loss=32.1186, val_sbp_mae=15.3523, val_dbp_mae=16.7663
  ↳ Saved best ECG BPIndex model (val_loss = 32.1186)

===== BPIndex ECG - Epoch 21, lr=0.0005 =====


[BPIndex-ecg] Train epoch 21: 100%|████████████| 1563/1563 [03:20<00:00,  7.80it/s, loss=22.375, sbp=12.203, dbp=10.172]
[BPIndex-ecg] Val epoch 21: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=31.702, sbp=15.114, dbp=16.588]


ECG BPIndex Epoch 21: train_loss=22.3751, train_sbp_mae=12.2026, train_dbp_mae=10.1725, val_loss=31.7020, val_sbp_mae=15.1142, val_dbp_mae=16.5878
  ↳ Saved best ECG BPIndex model (val_loss = 31.7020)

===== BPIndex ECG - Epoch 22, lr=0.0005 =====


[BPIndex-ecg] Train epoch 22: 100%|████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=22.175, sbp=12.082, dbp=10.093]
[BPIndex-ecg] Val epoch 22: 100%|████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=32.096, sbp=15.410, dbp=16.687]


ECG BPIndex Epoch 22: train_loss=22.1755, train_sbp_mae=12.0824, train_dbp_mae=10.0931, val_loss=32.0965, val_sbp_mae=15.4099, val_dbp_mae=16.6865

===== BPIndex ECG - Epoch 23, lr=0.0005 =====


[BPIndex-ecg] Train epoch 23: 100%|████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=21.964, sbp=11.964, dbp=10.000]
[BPIndex-ecg] Val epoch 23: 100%|████████████████| 391/391 [00:36<00:00, 10.79it/s, loss=31.575, sbp=15.032, dbp=16.543]


ECG BPIndex Epoch 23: train_loss=21.9642, train_sbp_mae=11.9638, train_dbp_mae=10.0004, val_loss=31.5747, val_sbp_mae=15.0318, val_dbp_mae=16.5429
  ↳ Saved best ECG BPIndex model (val_loss = 31.5747)

===== BPIndex ECG - Epoch 24, lr=0.0005 =====


[BPIndex-ecg] Train epoch 24: 100%|█████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=21.766, sbp=11.851, dbp=9.916]
[BPIndex-ecg] Val epoch 24: 100%|████████████████| 391/391 [00:36<00:00, 10.76it/s, loss=31.277, sbp=14.824, dbp=16.453]


ECG BPIndex Epoch 24: train_loss=21.7664, train_sbp_mae=11.8506, train_dbp_mae=9.9158, val_loss=31.2768, val_sbp_mae=14.8242, val_dbp_mae=16.4527
  ↳ Saved best ECG BPIndex model (val_loss = 31.2768)

===== BPIndex ECG - Epoch 25, lr=0.0005 =====


[BPIndex-ecg] Train epoch 25: 100%|█████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=21.520, sbp=11.710, dbp=9.810]
[BPIndex-ecg] Val epoch 25: 100%|████████████████| 391/391 [00:36<00:00, 10.82it/s, loss=31.883, sbp=15.380, dbp=16.503]


ECG BPIndex Epoch 25: train_loss=21.5202, train_sbp_mae=11.7100, train_dbp_mae=9.8102, val_loss=31.8830, val_sbp_mae=15.3796, val_dbp_mae=16.5034

===== BPIndex ECG - Epoch 26, lr=0.0005 =====


[BPIndex-ecg] Train epoch 26: 100%|█████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=21.365, sbp=11.624, dbp=9.741]
[BPIndex-ecg] Val epoch 26: 100%|████████████████| 391/391 [00:36<00:00, 10.79it/s, loss=31.067, sbp=14.814, dbp=16.253]


ECG BPIndex Epoch 26: train_loss=21.3654, train_sbp_mae=11.6241, train_dbp_mae=9.7413, val_loss=31.0669, val_sbp_mae=14.8135, val_dbp_mae=16.2534
  ↳ Saved best ECG BPIndex model (val_loss = 31.0669)

===== BPIndex ECG - Epoch 27, lr=0.0005 =====


[BPIndex-ecg] Train epoch 27: 100%|█████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=21.196, sbp=11.551, dbp=9.645]
[BPIndex-ecg] Val epoch 27: 100%|████████████████| 391/391 [00:36<00:00, 10.79it/s, loss=30.999, sbp=14.691, dbp=16.308]


ECG BPIndex Epoch 27: train_loss=21.1957, train_sbp_mae=11.5509, train_dbp_mae=9.6448, val_loss=30.9989, val_sbp_mae=14.6908, val_dbp_mae=16.3081
  ↳ Saved best ECG BPIndex model (val_loss = 30.9989)

===== BPIndex ECG - Epoch 28, lr=0.0005 =====


[BPIndex-ecg] Train epoch 28: 100%|█████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=20.950, sbp=11.412, dbp=9.539]
[BPIndex-ecg] Val epoch 28: 100%|████████████████| 391/391 [00:36<00:00, 10.81it/s, loss=30.709, sbp=14.739, dbp=15.970]


ECG BPIndex Epoch 28: train_loss=20.9505, train_sbp_mae=11.4115, train_dbp_mae=9.5390, val_loss=30.7095, val_sbp_mae=14.7394, val_dbp_mae=15.9701
  ↳ Saved best ECG BPIndex model (val_loss = 30.7095)

===== BPIndex ECG - Epoch 29, lr=0.0005 =====


[BPIndex-ecg] Train epoch 29: 100%|█████████████| 1563/1563 [03:21<00:00,  7.78it/s, loss=20.780, sbp=11.320, dbp=9.460]
[BPIndex-ecg] Val epoch 29: 100%|████████████████| 391/391 [00:36<00:00, 10.73it/s, loss=30.977, sbp=14.657, dbp=16.321]


ECG BPIndex Epoch 29: train_loss=20.7799, train_sbp_mae=11.3199, train_dbp_mae=9.4599, val_loss=30.9772, val_sbp_mae=14.6566, val_dbp_mae=16.3206

===== BPIndex ECG - Epoch 30, lr=0.0005 =====


[BPIndex-ecg] Train epoch 30: 100%|█████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=20.610, sbp=11.224, dbp=9.385]
[BPIndex-ecg] Val epoch 30: 100%|████████████████| 391/391 [00:36<00:00, 10.81it/s, loss=31.602, sbp=14.867, dbp=16.735]


ECG BPIndex Epoch 30: train_loss=20.6096, train_sbp_mae=11.2243, train_dbp_mae=9.3853, val_loss=31.6019, val_sbp_mae=14.8666, val_dbp_mae=16.7353

===== BPIndex ECG - Epoch 31, lr=0.0005 =====


[BPIndex-ecg] Train epoch 31: 100%|█████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=20.438, sbp=11.139, dbp=9.299]
[BPIndex-ecg] Val epoch 31: 100%|████████████████| 391/391 [00:36<00:00, 10.82it/s, loss=31.646, sbp=14.998, dbp=16.649]


ECG BPIndex Epoch 31: train_loss=20.4381, train_sbp_mae=11.1387, train_dbp_mae=9.2994, val_loss=31.6464, val_sbp_mae=14.9977, val_dbp_mae=16.6487

===== BPIndex ECG - Epoch 32, lr=0.0005 =====


[BPIndex-ecg] Train epoch 32: 100%|█████████████| 1563/1563 [03:20<00:00,  7.80it/s, loss=20.252, sbp=11.040, dbp=9.211]
[BPIndex-ecg] Val epoch 32: 100%|████████████████| 391/391 [00:36<00:00, 10.77it/s, loss=31.033, sbp=14.706, dbp=16.328]


ECG BPIndex Epoch 32: train_loss=20.2518, train_sbp_mae=11.0405, train_dbp_mae=9.2114, val_loss=31.0333, val_sbp_mae=14.7057, val_dbp_mae=16.3275

===== BPIndex ECG - Epoch 33, lr=0.0005 =====


[BPIndex-ecg] Train epoch 33: 100%|█████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=20.122, sbp=10.976, dbp=9.147]
[BPIndex-ecg] Val epoch 33: 100%|████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=30.067, sbp=14.238, dbp=15.829]


ECG BPIndex Epoch 33: train_loss=20.1223, train_sbp_mae=10.9756, train_dbp_mae=9.1466, val_loss=30.0670, val_sbp_mae=14.2382, val_dbp_mae=15.8289
  ↳ Saved best ECG BPIndex model (val_loss = 30.0670)

===== BPIndex ECG - Epoch 34, lr=0.0005 =====


[BPIndex-ecg] Train epoch 34: 100%|█████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=19.944, sbp=10.889, dbp=9.055]
[BPIndex-ecg] Val epoch 34: 100%|████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=30.565, sbp=14.670, dbp=15.895]


ECG BPIndex Epoch 34: train_loss=19.9443, train_sbp_mae=10.8894, train_dbp_mae=9.0549, val_loss=30.5651, val_sbp_mae=14.6698, val_dbp_mae=15.8953

===== BPIndex ECG - Epoch 35, lr=0.0005 =====


[BPIndex-ecg] Train epoch 35: 100%|█████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=19.793, sbp=10.813, dbp=8.981]
[BPIndex-ecg] Val epoch 35: 100%|████████████████| 391/391 [00:36<00:00, 10.79it/s, loss=30.412, sbp=14.787, dbp=15.625]


ECG BPIndex Epoch 35: train_loss=19.7932, train_sbp_mae=10.8125, train_dbp_mae=8.9807, val_loss=30.4120, val_sbp_mae=14.7868, val_dbp_mae=15.6252

===== BPIndex ECG - Epoch 36, lr=0.0005 =====


[BPIndex-ecg] Train epoch 36: 100%|█████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=19.590, sbp=10.705, dbp=8.886]
[BPIndex-ecg] Val epoch 36: 100%|████████████████| 391/391 [00:36<00:00, 10.81it/s, loss=30.624, sbp=14.450, dbp=16.174]


ECG BPIndex Epoch 36: train_loss=19.5905, train_sbp_mae=10.7048, train_dbp_mae=8.8856, val_loss=30.6240, val_sbp_mae=14.4496, val_dbp_mae=16.1744

===== BPIndex ECG - Epoch 37, lr=0.0005 =====


[BPIndex-ecg] Train epoch 37: 100%|█████████████| 1563/1563 [03:20<00:00,  7.80it/s, loss=19.437, sbp=10.623, dbp=8.814]
[BPIndex-ecg] Val epoch 37: 100%|████████████████| 391/391 [00:36<00:00, 10.83it/s, loss=29.712, sbp=14.143, dbp=15.568]


ECG BPIndex Epoch 37: train_loss=19.4374, train_sbp_mae=10.6234, train_dbp_mae=8.8140, val_loss=29.7115, val_sbp_mae=14.1433, val_dbp_mae=15.5682
  ↳ Saved best ECG BPIndex model (val_loss = 29.7115)

===== BPIndex ECG - Epoch 38, lr=0.0005 =====


[BPIndex-ecg] Train epoch 38: 100%|█████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=19.294, sbp=10.550, dbp=8.744]
[BPIndex-ecg] Val epoch 38: 100%|████████████████| 391/391 [00:36<00:00, 10.78it/s, loss=30.409, sbp=14.467, dbp=15.942]


ECG BPIndex Epoch 38: train_loss=19.2939, train_sbp_mae=10.5497, train_dbp_mae=8.7442, val_loss=30.4094, val_sbp_mae=14.4675, val_dbp_mae=15.9420

===== BPIndex ECG - Epoch 39, lr=0.0005 =====


[BPIndex-ecg] Train epoch 39: 100%|█████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=19.125, sbp=10.471, dbp=8.654]
[BPIndex-ecg] Val epoch 39: 100%|████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=29.727, sbp=14.160, dbp=15.567]


ECG BPIndex Epoch 39: train_loss=19.1248, train_sbp_mae=10.4707, train_dbp_mae=8.6541, val_loss=29.7271, val_sbp_mae=14.1596, val_dbp_mae=15.5674

===== BPIndex ECG - Epoch 40, lr=0.0005 =====


[BPIndex-ecg] Train epoch 40: 100%|█████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=18.942, sbp=10.395, dbp=8.547]
[BPIndex-ecg] Val epoch 40: 100%|████████████████| 391/391 [00:36<00:00, 10.82it/s, loss=30.050, sbp=14.324, dbp=15.726]


ECG BPIndex Epoch 40: train_loss=18.9418, train_sbp_mae=10.3947, train_dbp_mae=8.5471, val_loss=30.0504, val_sbp_mae=14.3243, val_dbp_mae=15.7261

===== BPIndex ECG - Epoch 41, lr=0.0005 =====


[BPIndex-ecg] Train epoch 41: 100%|█████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=18.809, sbp=10.321, dbp=8.488]
[BPIndex-ecg] Val epoch 41: 100%|████████████████| 391/391 [00:36<00:00, 10.82it/s, loss=30.133, sbp=14.512, dbp=15.622]


ECG BPIndex Epoch 41: train_loss=18.8086, train_sbp_mae=10.3210, train_dbp_mae=8.4876, val_loss=30.1332, val_sbp_mae=14.5116, val_dbp_mae=15.6216

===== BPIndex ECG - Epoch 42, lr=0.0005 =====


[BPIndex-ecg] Train epoch 42: 100%|█████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=18.593, sbp=10.200, dbp=8.393]
[BPIndex-ecg] Val epoch 42: 100%|████████████████| 391/391 [00:36<00:00, 10.76it/s, loss=29.857, sbp=14.274, dbp=15.583]


ECG BPIndex Epoch 42: train_loss=18.5933, train_sbp_mae=10.2005, train_dbp_mae=8.3928, val_loss=29.8571, val_sbp_mae=14.2738, val_dbp_mae=15.5832

===== BPIndex ECG - Epoch 43, lr=0.00025 =====


[BPIndex-ecg] Train epoch 43: 100%|██████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=17.524, sbp=9.640, dbp=7.885]
[BPIndex-ecg] Val epoch 43: 100%|████████████████| 391/391 [00:36<00:00, 10.81it/s, loss=29.611, sbp=14.080, dbp=15.532]


ECG BPIndex Epoch 43: train_loss=17.5241, train_sbp_mae=9.6395, train_dbp_mae=7.8846, val_loss=29.6115, val_sbp_mae=14.0799, val_dbp_mae=15.5316
  ↳ Saved best ECG BPIndex model (val_loss = 29.6115)

===== BPIndex ECG - Epoch 44, lr=0.00025 =====


[BPIndex-ecg] Train epoch 44: 100%|██████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=17.151, sbp=9.449, dbp=7.702]
[BPIndex-ecg] Val epoch 44: 100%|████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=29.741, sbp=14.153, dbp=15.588]


ECG BPIndex Epoch 44: train_loss=17.1506, train_sbp_mae=9.4491, train_dbp_mae=7.7015, val_loss=29.7411, val_sbp_mae=14.1531, val_dbp_mae=15.5880

===== BPIndex ECG - Epoch 45, lr=0.00025 =====


[BPIndex-ecg] Train epoch 45: 100%|██████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=16.990, sbp=9.378, dbp=7.612]
[BPIndex-ecg] Val epoch 45: 100%|████████████████| 391/391 [00:36<00:00, 10.74it/s, loss=29.548, sbp=14.049, dbp=15.499]


ECG BPIndex Epoch 45: train_loss=16.9901, train_sbp_mae=9.3780, train_dbp_mae=7.6121, val_loss=29.5477, val_sbp_mae=14.0489, val_dbp_mae=15.4988
  ↳ Saved best ECG BPIndex model (val_loss = 29.5477)

===== BPIndex ECG - Epoch 46, lr=0.00025 =====


[BPIndex-ecg] Train epoch 46: 100%|██████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=16.809, sbp=9.288, dbp=7.521]
[BPIndex-ecg] Val epoch 46: 100%|████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=29.746, sbp=14.049, dbp=15.697]


ECG BPIndex Epoch 46: train_loss=16.8092, train_sbp_mae=9.2881, train_dbp_mae=7.5211, val_loss=29.7458, val_sbp_mae=14.0487, val_dbp_mae=15.6972

===== BPIndex ECG - Epoch 47, lr=0.00025 =====


[BPIndex-ecg] Train epoch 47: 100%|██████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=16.630, sbp=9.197, dbp=7.432]
[BPIndex-ecg] Val epoch 47: 100%|████████████████| 391/391 [00:36<00:00, 10.80it/s, loss=30.101, sbp=14.236, dbp=15.865]


ECG BPIndex Epoch 47: train_loss=16.6298, train_sbp_mae=9.1973, train_dbp_mae=7.4325, val_loss=30.1014, val_sbp_mae=14.2360, val_dbp_mae=15.8655

===== BPIndex ECG - Epoch 48, lr=0.00025 =====


[BPIndex-ecg] Train epoch 48: 100%|██████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=16.487, sbp=9.131, dbp=7.356]
[BPIndex-ecg] Val epoch 48: 100%|████████████████| 391/391 [00:36<00:00, 10.77it/s, loss=29.774, sbp=14.111, dbp=15.663]


ECG BPIndex Epoch 48: train_loss=16.4870, train_sbp_mae=9.1308, train_dbp_mae=7.3562, val_loss=29.7737, val_sbp_mae=14.1112, val_dbp_mae=15.6625

===== BPIndex ECG - Epoch 49, lr=0.00025 =====


[BPIndex-ecg] Train epoch 49: 100%|██████████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=16.366, sbp=9.073, dbp=7.292]
[BPIndex-ecg] Val epoch 49: 100%|████████████████| 391/391 [00:36<00:00, 10.81it/s, loss=29.423, sbp=13.926, dbp=15.497]


ECG BPIndex Epoch 49: train_loss=16.3656, train_sbp_mae=9.0733, train_dbp_mae=7.2922, val_loss=29.4234, val_sbp_mae=13.9260, val_dbp_mae=15.4974
  ↳ Saved best ECG BPIndex model (val_loss = 29.4234)

===== BPIndex ECG - Epoch 50, lr=0.00025 =====


[BPIndex-ecg] Train epoch 50: 100%|██████████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=16.240, sbp=9.015, dbp=7.225]
[BPIndex-ecg] Val epoch 50: 100%|████████████████| 391/391 [00:36<00:00, 10.81it/s, loss=30.096, sbp=14.239, dbp=15.857]


ECG BPIndex Epoch 50: train_loss=16.2396, train_sbp_mae=9.0147, train_dbp_mae=7.2249, val_loss=30.0960, val_sbp_mae=14.2395, val_dbp_mae=15.8565
[Freq2SBPDBPModel] Trainable params: 12873730 (12.874 M)

===== BPIndex PPG+ECG - Epoch 1, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 1: 100%|█████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=34.728, sbp=19.652, dbp=15.076]
[BPIndex-ppg_ecg] Val epoch 1: 100%|█████████████| 391/391 [00:36<00:00, 10.80it/s, loss=42.450, sbp=21.027, dbp=21.423]


PPG+ECG BPIndex Epoch 1: train_loss=34.7280, train_sbp_mae=19.6520, train_dbp_mae=15.0760, val_loss=42.4500, val_sbp_mae=21.0272, val_dbp_mae=21.4227
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 42.4500)

===== BPIndex PPG+ECG - Epoch 2, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 2: 100%|█████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=29.599, sbp=16.529, dbp=13.070]
[BPIndex-ppg_ecg] Val epoch 2: 100%|█████████████| 391/391 [00:36<00:00, 10.79it/s, loss=38.306, sbp=19.294, dbp=19.013]


PPG+ECG BPIndex Epoch 2: train_loss=29.5993, train_sbp_mae=16.5292, train_dbp_mae=13.0701, val_loss=38.3064, val_sbp_mae=19.2935, val_dbp_mae=19.0129
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 38.3064)

===== BPIndex PPG+ECG - Epoch 3, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 3: 100%|█████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=27.046, sbp=14.873, dbp=12.173]
[BPIndex-ppg_ecg] Val epoch 3: 100%|█████████████| 391/391 [00:36<00:00, 10.79it/s, loss=35.105, sbp=16.808, dbp=18.297]


PPG+ECG BPIndex Epoch 3: train_loss=27.0458, train_sbp_mae=14.8725, train_dbp_mae=12.1733, val_loss=35.1052, val_sbp_mae=16.8083, val_dbp_mae=18.2970
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 35.1052)

===== BPIndex PPG+ECG - Epoch 4, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 4: 100%|█████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=25.564, sbp=13.942, dbp=11.622]
[BPIndex-ppg_ecg] Val epoch 4: 100%|█████████████| 391/391 [00:36<00:00, 10.79it/s, loss=33.945, sbp=16.269, dbp=17.676]


PPG+ECG BPIndex Epoch 4: train_loss=25.5638, train_sbp_mae=13.9421, train_dbp_mae=11.6217, val_loss=33.9453, val_sbp_mae=16.2692, val_dbp_mae=17.6760
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 33.9453)

===== BPIndex PPG+ECG - Epoch 5, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 5: 100%|█████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=24.604, sbp=13.337, dbp=11.266]
[BPIndex-ppg_ecg] Val epoch 5: 100%|█████████████| 391/391 [00:36<00:00, 10.79it/s, loss=34.720, sbp=15.963, dbp=18.756]


PPG+ECG BPIndex Epoch 5: train_loss=24.6039, train_sbp_mae=13.3374, train_dbp_mae=11.2665, val_loss=34.7196, val_sbp_mae=15.9634, val_dbp_mae=18.7562

===== BPIndex PPG+ECG - Epoch 6, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 6: 100%|█████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=23.816, sbp=12.868, dbp=10.948]
[BPIndex-ppg_ecg] Val epoch 6: 100%|█████████████| 391/391 [00:36<00:00, 10.78it/s, loss=32.109, sbp=15.036, dbp=17.074]


PPG+ECG BPIndex Epoch 6: train_loss=23.8156, train_sbp_mae=12.8677, train_dbp_mae=10.9480, val_loss=32.1095, val_sbp_mae=15.0358, val_dbp_mae=17.0736
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 32.1095)

===== BPIndex PPG+ECG - Epoch 7, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 7: 100%|█████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=23.228, sbp=12.515, dbp=10.713]
[BPIndex-ppg_ecg] Val epoch 7: 100%|█████████████| 391/391 [00:36<00:00, 10.81it/s, loss=31.173, sbp=14.512, dbp=16.661]


PPG+ECG BPIndex Epoch 7: train_loss=23.2277, train_sbp_mae=12.5152, train_dbp_mae=10.7125, val_loss=31.1734, val_sbp_mae=14.5123, val_dbp_mae=16.6611
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 31.1734)

===== BPIndex PPG+ECG - Epoch 8, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 8: 100%|█████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=22.763, sbp=12.239, dbp=10.524]
[BPIndex-ppg_ecg] Val epoch 8: 100%|█████████████| 391/391 [00:36<00:00, 10.83it/s, loss=30.911, sbp=14.348, dbp=16.563]


PPG+ECG BPIndex Epoch 8: train_loss=22.7631, train_sbp_mae=12.2392, train_dbp_mae=10.5239, val_loss=30.9113, val_sbp_mae=14.3481, val_dbp_mae=16.5631
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 30.9113)

===== BPIndex PPG+ECG - Epoch 9, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 9: 100%|█████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=22.296, sbp=11.966, dbp=10.330]
[BPIndex-ppg_ecg] Val epoch 9: 100%|█████████████| 391/391 [00:36<00:00, 10.77it/s, loss=30.718, sbp=14.463, dbp=16.256]


PPG+ECG BPIndex Epoch 9: train_loss=22.2960, train_sbp_mae=11.9656, train_dbp_mae=10.3304, val_loss=30.7182, val_sbp_mae=14.4626, val_dbp_mae=16.2555
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 30.7182)

===== BPIndex PPG+ECG - Epoch 10, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 10: 100%|████████| 1563/1563 [03:21<00:00,  7.76it/s, loss=21.903, sbp=11.738, dbp=10.165]
[BPIndex-ppg_ecg] Val epoch 10: 100%|████████████| 391/391 [00:36<00:00, 10.78it/s, loss=31.211, sbp=14.493, dbp=16.718]


PPG+ECG BPIndex Epoch 10: train_loss=21.9026, train_sbp_mae=11.7379, train_dbp_mae=10.1646, val_loss=31.2114, val_sbp_mae=14.4935, val_dbp_mae=16.7179

===== BPIndex PPG+ECG - Epoch 11, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 11: 100%|████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=21.587, sbp=11.562, dbp=10.026]
[BPIndex-ppg_ecg] Val epoch 11: 100%|████████████| 391/391 [00:36<00:00, 10.76it/s, loss=30.267, sbp=13.816, dbp=16.452]


PPG+ECG BPIndex Epoch 11: train_loss=21.5874, train_sbp_mae=11.5618, train_dbp_mae=10.0256, val_loss=30.2673, val_sbp_mae=13.8158, val_dbp_mae=16.4516
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 30.2673)

===== BPIndex PPG+ECG - Epoch 12, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 12: 100%|█████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=21.276, sbp=11.382, dbp=9.894]
[BPIndex-ppg_ecg] Val epoch 12: 100%|████████████| 391/391 [00:36<00:00, 10.80it/s, loss=29.610, sbp=13.597, dbp=16.013]


PPG+ECG BPIndex Epoch 12: train_loss=21.2756, train_sbp_mae=11.3821, train_dbp_mae=9.8936, val_loss=29.6101, val_sbp_mae=13.5972, val_dbp_mae=16.0129
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 29.6101)

===== BPIndex PPG+ECG - Epoch 13, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 13: 100%|█████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=20.924, sbp=11.179, dbp=9.745]
[BPIndex-ppg_ecg] Val epoch 13: 100%|████████████| 391/391 [00:36<00:00, 10.79it/s, loss=29.789, sbp=13.537, dbp=16.252]


PPG+ECG BPIndex Epoch 13: train_loss=20.9238, train_sbp_mae=11.1789, train_dbp_mae=9.7449, val_loss=29.7891, val_sbp_mae=13.5371, val_dbp_mae=16.2520

===== BPIndex PPG+ECG - Epoch 14, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 14: 100%|█████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=20.567, sbp=10.970, dbp=9.597]
[BPIndex-ppg_ecg] Val epoch 14: 100%|████████████| 391/391 [00:36<00:00, 10.77it/s, loss=29.364, sbp=13.460, dbp=15.903]


PPG+ECG BPIndex Epoch 14: train_loss=20.5665, train_sbp_mae=10.9697, train_dbp_mae=9.5969, val_loss=29.3636, val_sbp_mae=13.4601, val_dbp_mae=15.9034
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 29.3636)

===== BPIndex PPG+ECG - Epoch 15, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 15: 100%|█████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=20.307, sbp=10.821, dbp=9.486]
[BPIndex-ppg_ecg] Val epoch 15: 100%|████████████| 391/391 [00:36<00:00, 10.79it/s, loss=29.283, sbp=13.271, dbp=16.012]


PPG+ECG BPIndex Epoch 15: train_loss=20.3067, train_sbp_mae=10.8210, train_dbp_mae=9.4856, val_loss=29.2833, val_sbp_mae=13.2712, val_dbp_mae=16.0121
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 29.2833)

===== BPIndex PPG+ECG - Epoch 16, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 16: 100%|█████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=20.020, sbp=10.667, dbp=9.353]
[BPIndex-ppg_ecg] Val epoch 16: 100%|████████████| 391/391 [00:36<00:00, 10.79it/s, loss=28.523, sbp=13.098, dbp=15.424]


PPG+ECG BPIndex Epoch 16: train_loss=20.0199, train_sbp_mae=10.6667, train_dbp_mae=9.3532, val_loss=28.5225, val_sbp_mae=13.0985, val_dbp_mae=15.4240
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 28.5225)

===== BPIndex PPG+ECG - Epoch 17, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 17: 100%|█████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=19.819, sbp=10.554, dbp=9.265]
[BPIndex-ppg_ecg] Val epoch 17: 100%|████████████| 391/391 [00:36<00:00, 10.77it/s, loss=28.740, sbp=13.366, dbp=15.374]


PPG+ECG BPIndex Epoch 17: train_loss=19.8191, train_sbp_mae=10.5538, train_dbp_mae=9.2653, val_loss=28.7398, val_sbp_mae=13.3658, val_dbp_mae=15.3741

===== BPIndex PPG+ECG - Epoch 18, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 18: 100%|█████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=19.554, sbp=10.419, dbp=9.136]
[BPIndex-ppg_ecg] Val epoch 18: 100%|████████████| 391/391 [00:36<00:00, 10.80it/s, loss=28.208, sbp=12.815, dbp=15.393]


PPG+ECG BPIndex Epoch 18: train_loss=19.5542, train_sbp_mae=10.4187, train_dbp_mae=9.1355, val_loss=28.2078, val_sbp_mae=12.8151, val_dbp_mae=15.3926
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 28.2078)

===== BPIndex PPG+ECG - Epoch 19, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 19: 100%|█████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=19.338, sbp=10.297, dbp=9.040]
[BPIndex-ppg_ecg] Val epoch 19: 100%|████████████| 391/391 [00:36<00:00, 10.77it/s, loss=28.283, sbp=13.157, dbp=15.126]


PPG+ECG BPIndex Epoch 19: train_loss=19.3376, train_sbp_mae=10.2972, train_dbp_mae=9.0405, val_loss=28.2835, val_sbp_mae=13.1571, val_dbp_mae=15.1263

===== BPIndex PPG+ECG - Epoch 20, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 20: 100%|█████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=19.058, sbp=10.160, dbp=8.898]
[BPIndex-ppg_ecg] Val epoch 20: 100%|████████████| 391/391 [00:36<00:00, 10.77it/s, loss=28.497, sbp=13.169, dbp=15.328]


PPG+ECG BPIndex Epoch 20: train_loss=19.0581, train_sbp_mae=10.1603, train_dbp_mae=8.8978, val_loss=28.4969, val_sbp_mae=13.1691, val_dbp_mae=15.3278

===== BPIndex PPG+ECG - Epoch 21, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 21: 100%|█████████| 1563/1563 [03:22<00:00,  7.72it/s, loss=18.856, sbp=10.064, dbp=8.791]
[BPIndex-ppg_ecg] Val epoch 21: 100%|████████████| 391/391 [00:36<00:00, 10.66it/s, loss=27.390, sbp=12.629, dbp=14.761]


PPG+ECG BPIndex Epoch 21: train_loss=18.8557, train_sbp_mae=10.0643, train_dbp_mae=8.7914, val_loss=27.3896, val_sbp_mae=12.6285, val_dbp_mae=14.7611
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 27.3896)

===== BPIndex PPG+ECG - Epoch 22, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 22: 100%|██████████| 1563/1563 [03:22<00:00,  7.70it/s, loss=18.612, sbp=9.944, dbp=8.669]
[BPIndex-ppg_ecg] Val epoch 22: 100%|████████████| 391/391 [00:36<00:00, 10.66it/s, loss=26.834, sbp=12.300, dbp=14.534]


PPG+ECG BPIndex Epoch 22: train_loss=18.6122, train_sbp_mae=9.9437, train_dbp_mae=8.6685, val_loss=26.8339, val_sbp_mae=12.3004, val_dbp_mae=14.5335
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 26.8339)

===== BPIndex PPG+ECG - Epoch 23, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 23: 100%|██████████| 1563/1563 [03:23<00:00,  7.70it/s, loss=18.450, sbp=9.861, dbp=8.590]
[BPIndex-ppg_ecg] Val epoch 23: 100%|████████████| 391/391 [00:36<00:00, 10.81it/s, loss=27.187, sbp=12.546, dbp=14.641]


PPG+ECG BPIndex Epoch 23: train_loss=18.4505, train_sbp_mae=9.8609, train_dbp_mae=8.5896, val_loss=27.1871, val_sbp_mae=12.5461, val_dbp_mae=14.6410

===== BPIndex PPG+ECG - Epoch 24, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 24: 100%|██████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=18.171, sbp=9.719, dbp=8.453]
[BPIndex-ppg_ecg] Val epoch 24: 100%|████████████| 391/391 [00:36<00:00, 10.76it/s, loss=26.380, sbp=12.198, dbp=14.182]


PPG+ECG BPIndex Epoch 24: train_loss=18.1714, train_sbp_mae=9.7188, train_dbp_mae=8.4526, val_loss=26.3803, val_sbp_mae=12.1982, val_dbp_mae=14.1822
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 26.3803)

===== BPIndex PPG+ECG - Epoch 25, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 25: 100%|██████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=17.999, sbp=9.630, dbp=8.369]
[BPIndex-ppg_ecg] Val epoch 25: 100%|████████████| 391/391 [00:36<00:00, 10.79it/s, loss=27.798, sbp=12.760, dbp=15.037]


PPG+ECG BPIndex Epoch 25: train_loss=17.9988, train_sbp_mae=9.6299, train_dbp_mae=8.3689, val_loss=27.7979, val_sbp_mae=12.7605, val_dbp_mae=15.0374

===== BPIndex PPG+ECG - Epoch 26, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 26: 100%|██████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=17.790, sbp=9.525, dbp=8.265]
[BPIndex-ppg_ecg] Val epoch 26: 100%|████████████| 391/391 [00:36<00:00, 10.77it/s, loss=26.826, sbp=12.386, dbp=14.441]


PPG+ECG BPIndex Epoch 26: train_loss=17.7900, train_sbp_mae=9.5248, train_dbp_mae=8.2652, val_loss=26.8263, val_sbp_mae=12.3855, val_dbp_mae=14.4407

===== BPIndex PPG+ECG - Epoch 27, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 27: 100%|██████████| 1563/1563 [03:23<00:00,  7.69it/s, loss=17.642, sbp=9.446, dbp=8.196]
[BPIndex-ppg_ecg] Val epoch 27: 100%|████████████| 391/391 [00:36<00:00, 10.65it/s, loss=25.952, sbp=12.144, dbp=13.808]


PPG+ECG BPIndex Epoch 27: train_loss=17.6419, train_sbp_mae=9.4459, train_dbp_mae=8.1960, val_loss=25.9525, val_sbp_mae=12.1442, val_dbp_mae=13.8083
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 25.9525)

===== BPIndex PPG+ECG - Epoch 28, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 28: 100%|██████████| 1563/1563 [03:23<00:00,  7.69it/s, loss=17.403, sbp=9.326, dbp=8.077]
[BPIndex-ppg_ecg] Val epoch 28: 100%|████████████| 391/391 [00:36<00:00, 10.66it/s, loss=26.470, sbp=12.334, dbp=14.136]


PPG+ECG BPIndex Epoch 28: train_loss=17.4027, train_sbp_mae=9.3256, train_dbp_mae=8.0772, val_loss=26.4700, val_sbp_mae=12.3340, val_dbp_mae=14.1360

===== BPIndex PPG+ECG - Epoch 29, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 29: 100%|██████████| 1563/1563 [03:22<00:00,  7.72it/s, loss=17.236, sbp=9.235, dbp=8.000]
[BPIndex-ppg_ecg] Val epoch 29: 100%|████████████| 391/391 [00:36<00:00, 10.81it/s, loss=25.849, sbp=11.810, dbp=14.040]


PPG+ECG BPIndex Epoch 29: train_loss=17.2356, train_sbp_mae=9.2354, train_dbp_mae=8.0002, val_loss=25.8492, val_sbp_mae=11.8096, val_dbp_mae=14.0396
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 25.8492)

===== BPIndex PPG+ECG - Epoch 30, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 30: 100%|██████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=17.050, sbp=9.148, dbp=7.902]
[BPIndex-ppg_ecg] Val epoch 30: 100%|████████████| 391/391 [00:36<00:00, 10.81it/s, loss=26.414, sbp=12.281, dbp=14.133]


PPG+ECG BPIndex Epoch 30: train_loss=17.0497, train_sbp_mae=9.1480, train_dbp_mae=7.9017, val_loss=26.4141, val_sbp_mae=12.2807, val_dbp_mae=14.1334

===== BPIndex PPG+ECG - Epoch 31, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 31: 100%|██████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=16.965, sbp=9.107, dbp=7.858]
[BPIndex-ppg_ecg] Val epoch 31: 100%|████████████| 391/391 [00:36<00:00, 10.81it/s, loss=25.984, sbp=12.025, dbp=13.959]


PPG+ECG BPIndex Epoch 31: train_loss=16.9649, train_sbp_mae=9.1068, train_dbp_mae=7.8581, val_loss=25.9838, val_sbp_mae=12.0248, val_dbp_mae=13.9589

===== BPIndex PPG+ECG - Epoch 32, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 32: 100%|██████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=16.743, sbp=8.997, dbp=7.747]
[BPIndex-ppg_ecg] Val epoch 32: 100%|████████████| 391/391 [00:36<00:00, 10.82it/s, loss=25.366, sbp=11.707, dbp=13.659]


PPG+ECG BPIndex Epoch 32: train_loss=16.7434, train_sbp_mae=8.9965, train_dbp_mae=7.7469, val_loss=25.3661, val_sbp_mae=11.7069, val_dbp_mae=13.6591
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 25.3661)

===== BPIndex PPG+ECG - Epoch 33, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 33: 100%|██████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=16.553, sbp=8.905, dbp=7.648]
[BPIndex-ppg_ecg] Val epoch 33: 100%|████████████| 391/391 [00:36<00:00, 10.83it/s, loss=24.781, sbp=11.430, dbp=13.351]


PPG+ECG BPIndex Epoch 33: train_loss=16.5529, train_sbp_mae=8.9048, train_dbp_mae=7.6480, val_loss=24.7812, val_sbp_mae=11.4299, val_dbp_mae=13.3513
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 24.7812)

===== BPIndex PPG+ECG - Epoch 34, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 34: 100%|██████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=16.433, sbp=8.846, dbp=7.587]
[BPIndex-ppg_ecg] Val epoch 34: 100%|████████████| 391/391 [00:36<00:00, 10.78it/s, loss=25.035, sbp=11.509, dbp=13.526]


PPG+ECG BPIndex Epoch 34: train_loss=16.4333, train_sbp_mae=8.8463, train_dbp_mae=7.5869, val_loss=25.0345, val_sbp_mae=11.5089, val_dbp_mae=13.5257

===== BPIndex PPG+ECG - Epoch 35, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 35: 100%|██████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=16.267, sbp=8.759, dbp=7.507]
[BPIndex-ppg_ecg] Val epoch 35: 100%|████████████| 391/391 [00:36<00:00, 10.73it/s, loss=24.491, sbp=11.248, dbp=13.242]


PPG+ECG BPIndex Epoch 35: train_loss=16.2667, train_sbp_mae=8.7593, train_dbp_mae=7.5073, val_loss=24.4907, val_sbp_mae=11.2483, val_dbp_mae=13.2424
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 24.4907)

===== BPIndex PPG+ECG - Epoch 36, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 36: 100%|██████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=16.090, sbp=8.671, dbp=7.418]
[BPIndex-ppg_ecg] Val epoch 36: 100%|████████████| 391/391 [00:36<00:00, 10.80it/s, loss=25.787, sbp=11.912, dbp=13.875]


PPG+ECG BPIndex Epoch 36: train_loss=16.0895, train_sbp_mae=8.6715, train_dbp_mae=7.4180, val_loss=25.7873, val_sbp_mae=11.9124, val_dbp_mae=13.8749

===== BPIndex PPG+ECG - Epoch 37, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 37: 100%|██████████| 1563/1563 [03:22<00:00,  7.72it/s, loss=15.971, sbp=8.615, dbp=7.357]
[BPIndex-ppg_ecg] Val epoch 37: 100%|████████████| 391/391 [00:36<00:00, 10.60it/s, loss=28.106, sbp=12.780, dbp=15.326]


PPG+ECG BPIndex Epoch 37: train_loss=15.9713, train_sbp_mae=8.6147, train_dbp_mae=7.3567, val_loss=28.1063, val_sbp_mae=12.7801, val_dbp_mae=15.3262

===== BPIndex PPG+ECG - Epoch 38, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 38: 100%|██████████| 1563/1563 [03:23<00:00,  7.70it/s, loss=15.910, sbp=8.593, dbp=7.317]
[BPIndex-ppg_ecg] Val epoch 38: 100%|████████████| 391/391 [00:36<00:00, 10.64it/s, loss=27.060, sbp=12.607, dbp=14.452]


PPG+ECG BPIndex Epoch 38: train_loss=15.9101, train_sbp_mae=8.5932, train_dbp_mae=7.3169, val_loss=27.0599, val_sbp_mae=12.6074, val_dbp_mae=14.4525

===== BPIndex PPG+ECG - Epoch 39, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 39: 100%|██████████| 1563/1563 [03:23<00:00,  7.69it/s, loss=15.714, sbp=8.493, dbp=7.220]
[BPIndex-ppg_ecg] Val epoch 39: 100%|████████████| 391/391 [00:36<00:00, 10.83it/s, loss=25.426, sbp=11.593, dbp=13.833]


PPG+ECG BPIndex Epoch 39: train_loss=15.7137, train_sbp_mae=8.4933, train_dbp_mae=7.2203, val_loss=25.4259, val_sbp_mae=11.5928, val_dbp_mae=13.8332

===== BPIndex PPG+ECG - Epoch 40, lr=0.0005 =====


[BPIndex-ppg_ecg] Train epoch 40: 100%|██████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=15.590, sbp=8.425, dbp=7.165]
[BPIndex-ppg_ecg] Val epoch 40: 100%|████████████| 391/391 [00:36<00:00, 10.78it/s, loss=25.073, sbp=11.484, dbp=13.589]


PPG+ECG BPIndex Epoch 40: train_loss=15.5904, train_sbp_mae=8.4252, train_dbp_mae=7.1653, val_loss=25.0733, val_sbp_mae=11.4840, val_dbp_mae=13.5893

===== BPIndex PPG+ECG - Epoch 41, lr=0.00025 =====


[BPIndex-ppg_ecg] Train epoch 41: 100%|██████████| 1563/1563 [03:21<00:00,  7.77it/s, loss=14.474, sbp=7.852, dbp=6.622]
[BPIndex-ppg_ecg] Val epoch 41: 100%|████████████| 391/391 [00:36<00:00, 10.65it/s, loss=24.030, sbp=11.069, dbp=12.962]


PPG+ECG BPIndex Epoch 41: train_loss=14.4743, train_sbp_mae=7.8524, train_dbp_mae=6.6219, val_loss=24.0302, val_sbp_mae=11.0687, val_dbp_mae=12.9615
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 24.0302)

===== BPIndex PPG+ECG - Epoch 42, lr=0.00025 =====


[BPIndex-ppg_ecg] Train epoch 42: 100%|██████████| 1563/1563 [03:23<00:00,  7.69it/s, loss=14.158, sbp=7.690, dbp=6.469]
[BPIndex-ppg_ecg] Val epoch 42: 100%|████████████| 391/391 [00:36<00:00, 10.64it/s, loss=25.048, sbp=11.454, dbp=13.594]


PPG+ECG BPIndex Epoch 42: train_loss=14.1583, train_sbp_mae=7.6896, train_dbp_mae=6.4687, val_loss=25.0482, val_sbp_mae=11.4545, val_dbp_mae=13.5937

===== BPIndex PPG+ECG - Epoch 43, lr=0.00025 =====


[BPIndex-ppg_ecg] Train epoch 43: 100%|██████████| 1563/1563 [03:23<00:00,  7.68it/s, loss=14.020, sbp=7.614, dbp=6.405]
[BPIndex-ppg_ecg] Val epoch 43: 100%|████████████| 391/391 [00:36<00:00, 10.64it/s, loss=24.120, sbp=10.962, dbp=13.158]


PPG+ECG BPIndex Epoch 43: train_loss=14.0199, train_sbp_mae=7.6145, train_dbp_mae=6.4054, val_loss=24.1198, val_sbp_mae=10.9621, val_dbp_mae=13.1578

===== BPIndex PPG+ECG - Epoch 44, lr=0.00025 =====


[BPIndex-ppg_ecg] Train epoch 44: 100%|██████████| 1563/1563 [03:22<00:00,  7.73it/s, loss=13.847, sbp=7.540, dbp=6.307]
[BPIndex-ppg_ecg] Val epoch 44: 100%|████████████| 391/391 [00:36<00:00, 10.78it/s, loss=24.212, sbp=10.934, dbp=13.278]


PPG+ECG BPIndex Epoch 44: train_loss=13.8466, train_sbp_mae=7.5397, train_dbp_mae=6.3069, val_loss=24.2124, val_sbp_mae=10.9342, val_dbp_mae=13.2782

===== BPIndex PPG+ECG - Epoch 45, lr=0.00025 =====


[BPIndex-ppg_ecg] Train epoch 45: 100%|██████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=13.719, sbp=7.474, dbp=6.245]
[BPIndex-ppg_ecg] Val epoch 45: 100%|████████████| 391/391 [00:36<00:00, 10.80it/s, loss=24.489, sbp=11.097, dbp=13.393]


PPG+ECG BPIndex Epoch 45: train_loss=13.7193, train_sbp_mae=7.4744, train_dbp_mae=6.2449, val_loss=24.4894, val_sbp_mae=11.0966, val_dbp_mae=13.3928

===== BPIndex PPG+ECG - Epoch 46, lr=0.00025 =====


[BPIndex-ppg_ecg] Train epoch 46: 100%|██████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=13.590, sbp=7.419, dbp=6.171]
[BPIndex-ppg_ecg] Val epoch 46: 100%|████████████| 391/391 [00:36<00:00, 10.81it/s, loss=24.297, sbp=11.100, dbp=13.197]


PPG+ECG BPIndex Epoch 46: train_loss=13.5899, train_sbp_mae=7.4192, train_dbp_mae=6.1707, val_loss=24.2971, val_sbp_mae=11.0998, val_dbp_mae=13.1973

===== BPIndex PPG+ECG - Epoch 47, lr=0.000125 =====


[BPIndex-ppg_ecg] Train epoch 47: 100%|██████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=12.975, sbp=7.097, dbp=5.878]
[BPIndex-ppg_ecg] Val epoch 47: 100%|████████████| 391/391 [00:36<00:00, 10.78it/s, loss=23.928, sbp=10.887, dbp=13.041]


PPG+ECG BPIndex Epoch 47: train_loss=12.9753, train_sbp_mae=7.0974, train_dbp_mae=5.8780, val_loss=23.9278, val_sbp_mae=10.8873, val_dbp_mae=13.0406
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 23.9278)

===== BPIndex PPG+ECG - Epoch 48, lr=0.000125 =====


[BPIndex-ppg_ecg] Train epoch 48: 100%|██████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=12.796, sbp=7.008, dbp=5.789]
[BPIndex-ppg_ecg] Val epoch 48: 100%|████████████| 391/391 [00:36<00:00, 10.82it/s, loss=23.861, sbp=10.789, dbp=13.072]


PPG+ECG BPIndex Epoch 48: train_loss=12.7964, train_sbp_mae=7.0077, train_dbp_mae=5.7887, val_loss=23.8607, val_sbp_mae=10.7887, val_dbp_mae=13.0720
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 23.8607)

===== BPIndex PPG+ECG - Epoch 49, lr=0.000125 =====


[BPIndex-ppg_ecg] Train epoch 49: 100%|██████████| 1563/1563 [03:20<00:00,  7.79it/s, loss=12.679, sbp=6.942, dbp=5.736]
[BPIndex-ppg_ecg] Val epoch 49: 100%|████████████| 391/391 [00:36<00:00, 10.79it/s, loss=23.547, sbp=10.698, dbp=12.850]


PPG+ECG BPIndex Epoch 49: train_loss=12.6786, train_sbp_mae=6.9424, train_dbp_mae=5.7362, val_loss=23.5474, val_sbp_mae=10.6977, val_dbp_mae=12.8496
  ↳ Saved best PPG+ECG BPIndex model (val_loss = 23.5474)

===== BPIndex PPG+ECG - Epoch 50, lr=0.000125 =====


[BPIndex-ppg_ecg] Train epoch 50: 100%|██████████| 1563/1563 [03:20<00:00,  7.78it/s, loss=12.576, sbp=6.901, dbp=5.676]
[BPIndex-ppg_ecg] Val epoch 50: 100%|████████████| 391/391 [00:36<00:00, 10.78it/s, loss=23.690, sbp=10.688, dbp=13.002]

PPG+ECG BPIndex Epoch 50: train_loss=12.5764, train_sbp_mae=6.9008, train_dbp_mae=5.6755, val_loss=23.6901, val_sbp_mae=10.6881, val_dbp_mae=13.0020





In [9]:
#保存前面三个baseline model
import os
from pathlib import Path
import shutil

# 1) 确保已经挂载好 Google Drive（如果之前已经 mount 过可以跳过）
try:
    from google.colab import drive
    drive.mount("/content/gdrive", force_remount=False)
except Exception as e:
    print("Drive mount skipped or already mounted:", e)

# 2) 本地模型所在目录（一般就是当前工作目录）
SRC_DIR = Path(".")   # 如果你明确知道在别的地方，可以改这里

# 3) 目标保存目录（你的 Google Drive 目录）
DST_DIR = Path("/content/gdrive/MyDrive/11785Project")
DST_DIR.mkdir(parents=True, exist_ok=True)

print("Source dir :", SRC_DIR.resolve())
print("Target dir :", DST_DIR.resolve())

# 4) 要备份的模型文件名列表
# 按你之前的命名来填，如果你有两套名字，可以都列上：
candidate_model_files = [
    # 旧的频谱 → ABP baseline 名字（如果你有用）
    "baseline_ppg_cnn.pt",
    "baseline_ecg_cnn.pt",
    "baseline_ppg_ecg_cnn.pt",

    # 新的 SBP/DBP baseline 名字（如果你按我之前建议的叫法来）
    "bpindex_ppg_cnn.pt",
    "bpindex_ecg_cnn.pt",
    "bpindex_ppg_ecg_cnn.pt",
]

# 5) 逐个检查并复制
for fname in candidate_model_files:
    src = SRC_DIR / fname
    if src.exists():
        dst = DST_DIR / fname
        shutil.copy2(src, dst)
        print(f"Copied {src}  ->  {dst}")
    else:
        print(f"[Skip] {src} 不存在，跳过")


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
Source dir : /content
Target dir : /content/gdrive/MyDrive/11785Project
[Skip] baseline_ppg_cnn.pt 不存在，跳过
[Skip] baseline_ecg_cnn.pt 不存在，跳过
[Skip] baseline_ppg_ecg_cnn.pt 不存在，跳过
Copied bpindex_ppg_cnn.pt  ->  /content/gdrive/MyDrive/11785Project/bpindex_ppg_cnn.pt
Copied bpindex_ecg_cnn.pt  ->  /content/gdrive/MyDrive/11785Project/bpindex_ecg_cnn.pt
Copied bpindex_ppg_ecg_cnn.pt  ->  /content/gdrive/MyDrive/11785Project/bpindex_ppg_ecg_cnn.pt


In [None]:
# ==== Block B-1: Diffusion 模型定义（和你上面的大号版一致） ====

class ResConvBlock1D(nn.Module):
    def __init__(self, ch, k=3, p=1, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=k, padding=p),
            nn.BatchNorm1d(ch),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv1d(ch, ch, kernel_size=k, padding=p),
            nn.BatchNorm1d(ch),
            nn.SiLU(),
        )

    def forward(self, x):
        return x + self.net(x)


class FreqConditionalEpsNet1D(nn.Module):
    def __init__(self, base_ch=512, time_emb_dim=512, dropout=0.1):
        super().__init__()
        self.time_emb_dim = time_emb_dim
        self.base_ch = base_ch

        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, base_ch),
            nn.ReLU(),
        )
        self.to_t = nn.Linear(base_ch, base_ch)

        in_ch = 4  # 2(ECG) + 2(PPG)

        self.conv_in = nn.Sequential(
            nn.Conv1d(in_ch, base_ch, kernel_size=3, padding=1),
            nn.BatchNorm1d(base_ch),
            nn.SiLU(),
        )

        self.blocks = nn.ModuleList([
            ResConvBlock1D(base_ch, k=3, p=1, dropout=dropout)
            for _ in range(6)
        ])

        self.out_conv = nn.Conv1d(base_ch, 2, kernel_size=1)

    def sinusoidal_embedding(self, t, dim):
        device = t.device
        half_dim = dim // 2
        emb_factor = torch.exp(
            torch.arange(half_dim, device=device)
            * (-torch.log(torch.tensor(10000.0)) / half_dim)
        )
        emb = t[:, None].float() * emb_factor[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
        return emb

    def forward(self, ecg_t, ppg, t):
        # ecg_t, ppg: [B,2,L]
        x = torch.cat([ecg_t, ppg], dim=1)  # [B,4,L]

        t_sin  = self.sinusoidal_embedding(t, self.time_emb_dim)  # [B,D]
        t_base = self.time_mlp(t_sin)                             # [B,base_ch]
        t_feat = self.to_t(t_base)[:, :, None]                    # [B,base_ch,1]

        x = self.conv_in(x)  # [B,base_ch,L]

        for block in self.blocks:
            x = block(x + t_feat)

        eps = self.out_conv(x)  # [B,2,L]
        return eps


class FreqPPG2ECGDiffusion(nn.Module):
    def __init__(self, eps_model: FreqConditionalEpsNet1D, timesteps: int = 50):
        super().__init__()
        self.eps_model = eps_model
        self.T = timesteps

        beta_start, beta_end = 1e-4, 0.02
        betas = torch.linspace(beta_start, beta_end, timesteps)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bars', alpha_bars)

    def q_sample(self, x0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x0)
        alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1)
        return torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noise

    def forward(self, ecg_spec, ppg_spec):
        B = ecg_spec.size(0)
        device = ecg_spec.device
        t = torch.randint(0, self.T, (B,), device=device)

        noise = torch.randn_like(ecg_spec)
        x_t = self.q_sample(ecg_spec, t, noise)

        eps_pred = self.eps_model(x_t, ppg_spec, t)
        diff_loss = F.mse_loss(eps_pred, noise)
        return diff_loss

    @torch.no_grad()
    def p_sample(self, x_t, ppg_spec, t):
        beta_t = self.betas[t].view(-1, 1, 1)
        alpha_t = self.alphas[t].view(-1, 1, 1)
        alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1)

        eps_theta = self.eps_model(x_t, ppg_spec, t)
        mean = (1 / torch.sqrt(alpha_t)) * (
            x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * eps_theta
        )

        if t[0] == 0:
            return mean
        else:
            noise = torch.randn_like(x_t)
            return mean + torch.sqrt(beta_t) * noise

    @torch.no_grad()
    def sample_ecg_spec(self, ppg_spec):
        B, C, L = ppg_spec.shape
        device = ppg_spec.device
        x_t = torch.randn((B, 2, L), device=device)

        for step in reversed(range(self.T)):
            t = torch.full((B,), step, device=device, dtype=torch.long)
            x_t = self.p_sample(x_t, ppg_spec, t)

        return x_t



# ==== Block B-2: 训练 diffusion（只优化 diff_loss） ====

diff_eps_net = FreqConditionalEpsNet1D(
    base_ch=512,
    time_emb_dim=512,
    dropout=0.1,
)
diffusion = FreqPPG2ECGDiffusion(diff_eps_net, timesteps=50).to(device)

optimizer_diff = torch.optim.AdamW(diffusion.parameters(), lr=1e-4, weight_decay=1e-2)
BEST_DIFF = inf
EPOCHS_DIFF = 50   # 可以自己调

def train_one_epoch_diffusion(model, loader, optimizer, device, epoch):
    model.train()
    total_diff = 0.0
    steps = 0
    pbar = tqdm(loader, desc=f"[Diff] Train epoch {epoch}", ncols=120)
    for ppg_b, ecg_b, abp_b, freqs_b, mask_b, paths_b in pbar:
        ppg_b = ppg_b.to(device)
        ecg_b = ecg_b.to(device)

        optimizer.zero_grad()
        diff_l = model(ecg_b, ppg_b)
        diff_l.backward()
        optimizer.step()

        total_diff += diff_l.item()
        steps += 1
        pbar.set_postfix({"diff": f"{total_diff/steps:.4f}"})
    return total_diff / max(1, steps)


@torch.no_grad()
def eval_one_epoch_diffusion(model, loader, device, epoch):
    model.eval()
    total_diff = 0.0
    steps = 0
    pbar = tqdm(loader, desc=f"[Diff] Val epoch {epoch}", ncols=120)
    for ppg_b, ecg_b, abp_b, freqs_b, mask_b, paths_b in pbar:
        ppg_b = ppg_b.to(device)
        ecg_b = ecg_b.to(device)

        diff_l = model(ecg_b, ppg_b)
        total_diff += diff_l.item()
        steps += 1
        pbar.set_postfix({"diff": f"{total_diff/steps:.4f}"})
    return total_diff / max(1, steps)


for epoch in range(1, EPOCHS_DIFF + 1):
    print(f"\n===== Diffusion Epoch {epoch} =====")
    train_diff = train_one_epoch_diffusion(diffusion, train_loader, optimizer_diff, device, epoch)
    val_diff = eval_one_epoch_diffusion(diffusion, val_loader, device, epoch)
    print(f"Epoch {epoch}: train_diff={train_diff:.4f}, val_diff={val_diff:.4f}")

    if val_diff < BEST_DIFF:
        BEST_DIFF = val_diff
        torch.save(diffusion.state_dict(), "diffusion_ppg2ecg.pt")
        print(f"  ↳ Saved best diffusion (val_diff = {BEST_DIFF:.4f})")



===== Diffusion Epoch 1 =====


[Diff] Train epoch 1: 100%|████████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.00it/s, diff=0.4590]
[Diff] Val epoch 1: 100%|███████████████████████████████████████████████| 782/782 [00:05<00:00, 148.10it/s, diff=0.3293]


Epoch 1: train_diff=0.4590, val_diff=0.3293
  ↳ Saved best diffusion (val_diff = 0.3293)

===== Diffusion Epoch 2 =====


[Diff] Train epoch 2: 100%|████████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.32it/s, diff=0.3936]
[Diff] Val epoch 2: 100%|███████████████████████████████████████████████| 782/782 [00:05<00:00, 145.45it/s, diff=0.3122]


Epoch 2: train_diff=0.3936, val_diff=0.3122
  ↳ Saved best diffusion (val_diff = 0.3122)

===== Diffusion Epoch 3 =====


[Diff] Train epoch 3: 100%|████████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.81it/s, diff=0.3765]
[Diff] Val epoch 3: 100%|███████████████████████████████████████████████| 782/782 [00:05<00:00, 147.77it/s, diff=0.3045]


Epoch 3: train_diff=0.3765, val_diff=0.3045
  ↳ Saved best diffusion (val_diff = 0.3045)

===== Diffusion Epoch 4 =====


[Diff] Train epoch 4: 100%|████████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.93it/s, diff=0.3666]
[Diff] Val epoch 4: 100%|███████████████████████████████████████████████| 782/782 [00:05<00:00, 149.36it/s, diff=0.2963]


Epoch 4: train_diff=0.3666, val_diff=0.2963
  ↳ Saved best diffusion (val_diff = 0.2963)

===== Diffusion Epoch 5 =====


[Diff] Train epoch 5: 100%|████████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.42it/s, diff=0.3602]
[Diff] Val epoch 5: 100%|███████████████████████████████████████████████| 782/782 [00:05<00:00, 147.55it/s, diff=0.2937]


Epoch 5: train_diff=0.3602, val_diff=0.2937
  ↳ Saved best diffusion (val_diff = 0.2937)

===== Diffusion Epoch 6 =====


[Diff] Train epoch 6: 100%|████████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.09it/s, diff=0.3554]
[Diff] Val epoch 6: 100%|███████████████████████████████████████████████| 782/782 [00:05<00:00, 147.07it/s, diff=0.2899]


Epoch 6: train_diff=0.3554, val_diff=0.2899
  ↳ Saved best diffusion (val_diff = 0.2899)

===== Diffusion Epoch 7 =====


[Diff] Train epoch 7: 100%|████████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.29it/s, diff=0.3519]
[Diff] Val epoch 7: 100%|███████████████████████████████████████████████| 782/782 [00:05<00:00, 147.21it/s, diff=0.2871]


Epoch 7: train_diff=0.3519, val_diff=0.2871
  ↳ Saved best diffusion (val_diff = 0.2871)

===== Diffusion Epoch 8 =====


[Diff] Train epoch 8: 100%|████████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.00it/s, diff=0.3480]
[Diff] Val epoch 8: 100%|███████████████████████████████████████████████| 782/782 [00:05<00:00, 146.53it/s, diff=0.2858]


Epoch 8: train_diff=0.3480, val_diff=0.2858
  ↳ Saved best diffusion (val_diff = 0.2858)

===== Diffusion Epoch 9 =====


[Diff] Train epoch 9: 100%|████████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.41it/s, diff=0.3451]
[Diff] Val epoch 9: 100%|███████████████████████████████████████████████| 782/782 [00:05<00:00, 147.16it/s, diff=0.2822]


Epoch 9: train_diff=0.3451, val_diff=0.2822
  ↳ Saved best diffusion (val_diff = 0.2822)

===== Diffusion Epoch 10 =====


[Diff] Train epoch 10: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.21it/s, diff=0.3416]
[Diff] Val epoch 10: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 148.50it/s, diff=0.2813]


Epoch 10: train_diff=0.3416, val_diff=0.2813
  ↳ Saved best diffusion (val_diff = 0.2813)

===== Diffusion Epoch 11 =====


[Diff] Train epoch 11: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.20it/s, diff=0.3396]
[Diff] Val epoch 11: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 149.00it/s, diff=0.2804]


Epoch 11: train_diff=0.3396, val_diff=0.2804
  ↳ Saved best diffusion (val_diff = 0.2804)

===== Diffusion Epoch 12 =====


[Diff] Train epoch 12: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.38it/s, diff=0.3375]
[Diff] Val epoch 12: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 149.00it/s, diff=0.2796]


Epoch 12: train_diff=0.3375, val_diff=0.2796
  ↳ Saved best diffusion (val_diff = 0.2796)

===== Diffusion Epoch 13 =====


[Diff] Train epoch 13: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.46it/s, diff=0.3357]
[Diff] Val epoch 13: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 148.42it/s, diff=0.2772]


Epoch 13: train_diff=0.3357, val_diff=0.2772
  ↳ Saved best diffusion (val_diff = 0.2772)

===== Diffusion Epoch 14 =====


[Diff] Train epoch 14: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.86it/s, diff=0.3340]
[Diff] Val epoch 14: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.87it/s, diff=0.2762]


Epoch 14: train_diff=0.3340, val_diff=0.2762
  ↳ Saved best diffusion (val_diff = 0.2762)

===== Diffusion Epoch 15 =====


[Diff] Train epoch 15: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.23it/s, diff=0.3317]
[Diff] Val epoch 15: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.58it/s, diff=0.2725]


Epoch 15: train_diff=0.3317, val_diff=0.2725
  ↳ Saved best diffusion (val_diff = 0.2725)

===== Diffusion Epoch 16 =====


[Diff] Train epoch 16: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.12it/s, diff=0.3306]
[Diff] Val epoch 16: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.29it/s, diff=0.2721]


Epoch 16: train_diff=0.3306, val_diff=0.2721
  ↳ Saved best diffusion (val_diff = 0.2721)

===== Diffusion Epoch 17 =====


[Diff] Train epoch 17: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.17it/s, diff=0.3290]
[Diff] Val epoch 17: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 145.71it/s, diff=0.2716]


Epoch 17: train_diff=0.3290, val_diff=0.2716
  ↳ Saved best diffusion (val_diff = 0.2716)

===== Diffusion Epoch 18 =====


[Diff] Train epoch 18: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.21it/s, diff=0.3275]
[Diff] Val epoch 18: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 148.56it/s, diff=0.2691]


Epoch 18: train_diff=0.3275, val_diff=0.2691
  ↳ Saved best diffusion (val_diff = 0.2691)

===== Diffusion Epoch 19 =====


[Diff] Train epoch 19: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.31it/s, diff=0.3262]
[Diff] Val epoch 19: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 148.00it/s, diff=0.2695]


Epoch 19: train_diff=0.3262, val_diff=0.2695

===== Diffusion Epoch 20 =====


[Diff] Train epoch 20: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.13it/s, diff=0.3251]
[Diff] Val epoch 20: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 148.07it/s, diff=0.2674]


Epoch 20: train_diff=0.3251, val_diff=0.2674
  ↳ Saved best diffusion (val_diff = 0.2674)

===== Diffusion Epoch 21 =====


[Diff] Train epoch 21: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.93it/s, diff=0.3234]
[Diff] Val epoch 21: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.94it/s, diff=0.2674]


Epoch 21: train_diff=0.3234, val_diff=0.2674
  ↳ Saved best diffusion (val_diff = 0.2674)

===== Diffusion Epoch 22 =====


[Diff] Train epoch 22: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.17it/s, diff=0.3226]
[Diff] Val epoch 22: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 145.05it/s, diff=0.2677]


Epoch 22: train_diff=0.3226, val_diff=0.2677

===== Diffusion Epoch 23 =====


[Diff] Train epoch 23: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.41it/s, diff=0.3216]
[Diff] Val epoch 23: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.06it/s, diff=0.2671]


Epoch 23: train_diff=0.3216, val_diff=0.2671
  ↳ Saved best diffusion (val_diff = 0.2671)

===== Diffusion Epoch 24 =====


[Diff] Train epoch 24: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.46it/s, diff=0.3209]
[Diff] Val epoch 24: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 147.61it/s, diff=0.2658]


Epoch 24: train_diff=0.3209, val_diff=0.2658
  ↳ Saved best diffusion (val_diff = 0.2658)

===== Diffusion Epoch 25 =====


[Diff] Train epoch 25: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.88it/s, diff=0.3205]
[Diff] Val epoch 25: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 147.09it/s, diff=0.2639]


Epoch 25: train_diff=0.3205, val_diff=0.2639
  ↳ Saved best diffusion (val_diff = 0.2639)

===== Diffusion Epoch 26 =====


[Diff] Train epoch 26: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.14it/s, diff=0.3197]
[Diff] Val epoch 26: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 148.21it/s, diff=0.2648]


Epoch 26: train_diff=0.3197, val_diff=0.2648

===== Diffusion Epoch 27 =====


[Diff] Train epoch 27: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.29it/s, diff=0.3185]
[Diff] Val epoch 27: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 147.48it/s, diff=0.2635]


Epoch 27: train_diff=0.3185, val_diff=0.2635
  ↳ Saved best diffusion (val_diff = 0.2635)

===== Diffusion Epoch 28 =====


[Diff] Train epoch 28: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.19it/s, diff=0.3184]
[Diff] Val epoch 28: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.77it/s, diff=0.2633]


Epoch 28: train_diff=0.3184, val_diff=0.2633
  ↳ Saved best diffusion (val_diff = 0.2633)

===== Diffusion Epoch 29 =====


[Diff] Train epoch 29: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.11it/s, diff=0.3174]
[Diff] Val epoch 29: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 148.22it/s, diff=0.2638]


Epoch 29: train_diff=0.3174, val_diff=0.2638

===== Diffusion Epoch 30 =====


[Diff] Train epoch 30: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.40it/s, diff=0.3166]
[Diff] Val epoch 30: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 147.43it/s, diff=0.2634]


Epoch 30: train_diff=0.3166, val_diff=0.2634

===== Diffusion Epoch 31 =====


[Diff] Train epoch 31: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.37it/s, diff=0.3163]
[Diff] Val epoch 31: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.72it/s, diff=0.2602]


Epoch 31: train_diff=0.3163, val_diff=0.2602
  ↳ Saved best diffusion (val_diff = 0.2602)

===== Diffusion Epoch 32 =====


[Diff] Train epoch 32: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.36it/s, diff=0.3153]
[Diff] Val epoch 32: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.27it/s, diff=0.2626]


Epoch 32: train_diff=0.3153, val_diff=0.2626

===== Diffusion Epoch 33 =====


[Diff] Train epoch 33: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.30it/s, diff=0.3150]
[Diff] Val epoch 33: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 147.59it/s, diff=0.2610]


Epoch 33: train_diff=0.3150, val_diff=0.2610

===== Diffusion Epoch 34 =====


[Diff] Train epoch 34: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 59.32it/s, diff=0.3142]
[Diff] Val epoch 34: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 148.06it/s, diff=0.2613]


Epoch 34: train_diff=0.3142, val_diff=0.2613

===== Diffusion Epoch 35 =====


[Diff] Train epoch 35: 100%|███████████████████████████████████████████| 6250/6250 [01:45<00:00, 58.99it/s, diff=0.3145]
[Diff] Val epoch 35: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.45it/s, diff=0.2606]


Epoch 35: train_diff=0.3145, val_diff=0.2606

===== Diffusion Epoch 36 =====


[Diff] Train epoch 36: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.84it/s, diff=0.3131]
[Diff] Val epoch 36: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 147.00it/s, diff=0.2606]


Epoch 36: train_diff=0.3131, val_diff=0.2606

===== Diffusion Epoch 37 =====


[Diff] Train epoch 37: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.73it/s, diff=0.3131]
[Diff] Val epoch 37: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 145.91it/s, diff=0.2617]


Epoch 37: train_diff=0.3131, val_diff=0.2617

===== Diffusion Epoch 38 =====


[Diff] Train epoch 38: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.92it/s, diff=0.3119]
[Diff] Val epoch 38: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 144.35it/s, diff=0.2601]


Epoch 38: train_diff=0.3119, val_diff=0.2601
  ↳ Saved best diffusion (val_diff = 0.2601)

===== Diffusion Epoch 39 =====


[Diff] Train epoch 39: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.80it/s, diff=0.3124]
[Diff] Val epoch 39: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.31it/s, diff=0.2605]


Epoch 39: train_diff=0.3124, val_diff=0.2605

===== Diffusion Epoch 40 =====


[Diff] Train epoch 40: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.81it/s, diff=0.3120]
[Diff] Val epoch 40: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.78it/s, diff=0.2592]


Epoch 40: train_diff=0.3120, val_diff=0.2592
  ↳ Saved best diffusion (val_diff = 0.2592)

===== Diffusion Epoch 41 =====


[Diff] Train epoch 41: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.78it/s, diff=0.3111]
[Diff] Val epoch 41: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 147.51it/s, diff=0.2585]


Epoch 41: train_diff=0.3111, val_diff=0.2585
  ↳ Saved best diffusion (val_diff = 0.2585)

===== Diffusion Epoch 42 =====


[Diff] Train epoch 42: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.71it/s, diff=0.3110]
[Diff] Val epoch 42: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 144.85it/s, diff=0.2570]


Epoch 42: train_diff=0.3110, val_diff=0.2570
  ↳ Saved best diffusion (val_diff = 0.2570)

===== Diffusion Epoch 43 =====


[Diff] Train epoch 43: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.91it/s, diff=0.3107]
[Diff] Val epoch 43: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 145.50it/s, diff=0.2602]


Epoch 43: train_diff=0.3107, val_diff=0.2602

===== Diffusion Epoch 44 =====


[Diff] Train epoch 44: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.63it/s, diff=0.3101]
[Diff] Val epoch 44: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 144.96it/s, diff=0.2592]


Epoch 44: train_diff=0.3101, val_diff=0.2592

===== Diffusion Epoch 45 =====


[Diff] Train epoch 45: 100%|███████████████████████████████████████████| 6250/6250 [01:48<00:00, 57.73it/s, diff=0.3100]
[Diff] Val epoch 45: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 147.29it/s, diff=0.2573]


Epoch 45: train_diff=0.3100, val_diff=0.2573

===== Diffusion Epoch 46 =====


[Diff] Train epoch 46: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.62it/s, diff=0.3095]
[Diff] Val epoch 46: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 147.09it/s, diff=0.2579]


Epoch 46: train_diff=0.3095, val_diff=0.2579

===== Diffusion Epoch 47 =====


[Diff] Train epoch 47: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.93it/s, diff=0.3087]
[Diff] Val epoch 47: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.03it/s, diff=0.2584]


Epoch 47: train_diff=0.3087, val_diff=0.2584

===== Diffusion Epoch 48 =====


[Diff] Train epoch 48: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.78it/s, diff=0.3088]
[Diff] Val epoch 48: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.05it/s, diff=0.2562]


Epoch 48: train_diff=0.3088, val_diff=0.2562
  ↳ Saved best diffusion (val_diff = 0.2562)

===== Diffusion Epoch 49 =====


[Diff] Train epoch 49: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.70it/s, diff=0.3085]
[Diff] Val epoch 49: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 146.76it/s, diff=0.2566]


Epoch 49: train_diff=0.3085, val_diff=0.2566

===== Diffusion Epoch 50 =====


[Diff] Train epoch 50: 100%|███████████████████████████████████████████| 6250/6250 [01:46<00:00, 58.58it/s, diff=0.3083]
[Diff] Val epoch 50: 100%|██████████████████████████████████████████████| 782/782 [00:05<00:00, 145.90it/s, diff=0.2585]

Epoch 50: train_diff=0.3083, val_diff=0.2585





In [None]:
# ==== Block C-1: 加载冻结 diffusion + 定义 2-stage CNN 模型 ====

# 1) 先构建同样结构的 diffusion，并加载 Block B 训练好的权重
diff_eps_net_stage2 = FreqConditionalEpsNet1D(
    base_ch=512,
    time_emb_dim=512,
    dropout=0.1,
)
diffusion_stage2 = FreqPPG2ECGDiffusion(diff_eps_net_stage2, timesteps=50).to(device)
diffusion_stage2.load_state_dict(torch.load("diffusion_ppg2ecg.pt", map_location=device))

# 冻结参数
for p in diffusion_stage2.parameters():
    p.requires_grad = False
diffusion_stage2.eval()   # 训练 CNN 时就一直 eval

# 2) 定义 CNN：输入 4 通道 (PPG+生成 ECG)，输出 ABP 频谱
two_stage_cnn = BigFreqBPModel(
    in_channels=4,    # [PPG_real, PPG_imag, ECG_gen_real, ECG_gen_imag]
    hidden_dim=256,
    num_blocks=8,
    out_channels=2,
    kernel_size=5,
    dropout=0.1,
).to(device)

optimizer_two_stage = torch.optim.AdamW(two_stage_cnn.parameters(), lr=1e-4, weight_decay=1e-2)
BEST_2STAGE = inf
EPOCHS_2STAGE = 40   # 看时间/效果自己调


# ==== Block C-2: 训练 2-stage 模型（冻结 diffusion，只训 CNN） ====

def train_one_epoch_two_stage(diffusion, cnn, loader, optimizer, device, epoch):
    diffusion.eval()   # 保持冻结
    cnn.train()

    total_bp_f = 0.0
    total_bp_t = 0.0
    steps = 0

    pbar = tqdm(loader, desc=f"[2-Stage] Train epoch {epoch}", ncols=120)
    for ppg_b, ecg_b, abp_b, freqs_b, mask_b, paths_b in pbar:
        ppg_b = ppg_b.to(device)
        abp_b = abp_b.to(device)
        freqs_b = freqs_b.to(device)
        mask_b = mask_b.to(device)

        # 1) 用 frozen diffusion 生成 ECG 频谱（不需要梯度）
        with torch.no_grad():
            ecg_gen = diffusion.sample_ecg_spec(ppg_b)  # [B,2,L]

        # 2) 拼出 CNN 的输入：PPG + 生成 ECG
        x_spec = torch.cat([ppg_b, ecg_gen], dim=1)  # [B,4,L]
        x = x_spec.transpose(1, 2)                  # [B,L,4]

        optimizer.zero_grad()
        abp_hat = cnn(x)                  # [B,L,2]
        abp_hat = abp_hat.transpose(1, 2) # [B,2,L]

        bp_freq = masked_l1_loss(abp_hat, abp_b, mask_b)
        bp_time = batch_time_mae_from_spec(freqs_b, abp_hat, abp_b, mask_b)

        loss = bp_freq   # 这里只用频域 MAE 来优化；也可以加一部分 time MAE

        loss.backward()
        optimizer.step()

        total_bp_f += bp_freq.item()
        total_bp_t += bp_time
        steps += 1
        pbar.set_postfix({
            "bp_f": f"{total_bp_f/steps:.3f}",
            "bp_t": f"{total_bp_t/steps:.3f}",
        })

    return (total_bp_f / max(1, steps),
            total_bp_t / max(1, steps))


@torch.no_grad()
def eval_one_epoch_two_stage(diffusion, cnn, loader, device, epoch):
    diffusion.eval()
    cnn.eval()

    total_bp_f = 0.0
    total_bp_t = 0.0
    steps = 0

    pbar = tqdm(loader, desc=f"[2-Stage] Val epoch {epoch}", ncols=120)
    for ppg_b, ecg_b, abp_b, freqs_b, mask_b, paths_b in pbar:
        ppg_b = ppg_b.to(device)
        abp_b = abp_b.to(device)
        freqs_b = freqs_b.to(device)
        mask_b = mask_b.to(device)

        ecg_gen = diffusion.sample_ecg_spec(ppg_b)  # [B,2,L]

        x_spec = torch.cat([ppg_b, ecg_gen], dim=1)  # [B,4,L]
        x = x_spec.transpose(1, 2)                   # [B,L,4]

        abp_hat = cnn(x)                  # [B,L,2]
        abp_hat = abp_hat.transpose(1, 2) # [B,2,L]

        bp_freq = masked_l1_loss(abp_hat, abp_b, mask_b)
        bp_time = batch_time_mae_from_spec(freqs_b, abp_hat, abp_b, mask_b)

        total_bp_f += bp_freq
        total_bp_t += bp_time
        steps += 1

        pbar.set_postfix({
            "bp_f": f"{total_bp_f/steps:.3f}",
            "bp_t": f"{total_bp_t/steps:.3f}",
        })

    return (total_bp_f / max(1, steps),
            total_bp_t / max(1, steps))


# ==== Block C-3: 主训练循环（2-stage） ====

for epoch in range(1, EPOCHS_2STAGE + 1):
    print(f"\n===== 2-Stage CNN Epoch {epoch} =====")
    train_bp_f, train_bp_t = train_one_epoch_two_stage(
        diffusion_stage2, two_stage_cnn, train_loader, optimizer_two_stage, device, epoch
    )
    val_bp_f, val_bp_t = eval_one_epoch_two_stage(
        diffusion_stage2, two_stage_cnn, val_loader, device, epoch
    )

    print(f"2-Stage Epoch {epoch}: "
          f"train_bp_freq={train_bp_f:.4f}, train_bp_time={train_bp_t:.4f}, "
          f"val_bp_freq={val_bp_f:.4f}, val_bp_time={val_bp_t:.4f}")

    if val_bp_t < BEST_2STAGE:
        BEST_2STAGE = val_bp_t
        torch.save(two_stage_cnn.state_dict(), "two_stage_ppg_ecg_gen_cnn.pt")
        print(f"  ↳ Saved best 2-stage CNN (val_bp_time = {BEST_2STAGE:.4f})")



===== 2-Stage CNN Epoch 1 =====


[2-Stage] Train epoch 1: 100%|███████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=984.277, bp_t=90.349]
[2-Stage] Val epoch 1: 100%|███████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=926.752, bp_t=93.272]


2-Stage Epoch 1: train_bp_freq=984.2772, train_bp_time=90.3487, val_bp_freq=926.7523, val_bp_time=93.2716
  ↳ Saved best 2-stage CNN (val_bp_time = 93.2716)

===== 2-Stage CNN Epoch 2 =====


[2-Stage] Train epoch 2: 100%|███████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=911.065, bp_t=88.123]
[2-Stage] Val epoch 2: 100%|███████████████████████████████| 782/782 [02:39<00:00,  4.92it/s, bp_f=858.508, bp_t=89.759]


2-Stage Epoch 2: train_bp_freq=911.0653, train_bp_time=88.1228, val_bp_freq=858.5085, val_bp_time=89.7590
  ↳ Saved best 2-stage CNN (val_bp_time = 89.7590)

===== 2-Stage CNN Epoch 3 =====


[2-Stage] Train epoch 3: 100%|███████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=851.362, bp_t=83.747]
[2-Stage] Val epoch 3: 100%|███████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=801.740, bp_t=84.206]


2-Stage Epoch 3: train_bp_freq=851.3624, train_bp_time=83.7466, val_bp_freq=801.7403, val_bp_time=84.2061
  ↳ Saved best 2-stage CNN (val_bp_time = 84.2061)

===== 2-Stage CNN Epoch 4 =====


[2-Stage] Train epoch 4: 100%|███████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=800.070, bp_t=77.447]
[2-Stage] Val epoch 4: 100%|███████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=746.965, bp_t=76.927]


2-Stage Epoch 4: train_bp_freq=800.0702, train_bp_time=77.4468, val_bp_freq=746.9650, val_bp_time=76.9272
  ↳ Saved best 2-stage CNN (val_bp_time = 76.9272)

===== 2-Stage CNN Epoch 5 =====


[2-Stage] Train epoch 5: 100%|███████████████████████████| 6250/6250 [21:51<00:00,  4.76it/s, bp_f=746.122, bp_t=69.667]
[2-Stage] Val epoch 5: 100%|███████████████████████████████| 782/782 [02:39<00:00,  4.92it/s, bp_f=690.309, bp_t=69.156]


2-Stage Epoch 5: train_bp_freq=746.1219, train_bp_time=69.6667, val_bp_freq=690.3088, val_bp_time=69.1555
  ↳ Saved best 2-stage CNN (val_bp_time = 69.1555)

===== 2-Stage CNN Epoch 6 =====


[2-Stage] Train epoch 6: 100%|███████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=685.761, bp_t=60.566]
[2-Stage] Val epoch 6: 100%|███████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=620.337, bp_t=58.593]


2-Stage Epoch 6: train_bp_freq=685.7610, train_bp_time=60.5656, val_bp_freq=620.3366, val_bp_time=58.5933
  ↳ Saved best 2-stage CNN (val_bp_time = 58.5933)

===== 2-Stage CNN Epoch 7 =====


[2-Stage] Train epoch 7: 100%|███████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=618.338, bp_t=50.208]
[2-Stage] Val epoch 7: 100%|███████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=548.793, bp_t=46.936]


2-Stage Epoch 7: train_bp_freq=618.3378, train_bp_time=50.2078, val_bp_freq=548.7933, val_bp_time=46.9358
  ↳ Saved best 2-stage CNN (val_bp_time = 46.9358)

===== 2-Stage CNN Epoch 8 =====


[2-Stage] Train epoch 8: 100%|███████████████████████████| 6250/6250 [21:53<00:00,  4.76it/s, bp_f=543.861, bp_t=38.858]
[2-Stage] Val epoch 8: 100%|███████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=466.235, bp_t=34.285]


2-Stage Epoch 8: train_bp_freq=543.8608, train_bp_time=38.8576, val_bp_freq=466.2347, val_bp_time=34.2846
  ↳ Saved best 2-stage CNN (val_bp_time = 34.2846)

===== 2-Stage CNN Epoch 9 =====


[2-Stage] Train epoch 9: 100%|███████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=469.114, bp_t=27.935]
[2-Stage] Val epoch 9: 100%|███████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=364.790, bp_t=19.325]


2-Stage Epoch 9: train_bp_freq=469.1139, train_bp_time=27.9353, val_bp_freq=364.7899, val_bp_time=19.3249
  ↳ Saved best 2-stage CNN (val_bp_time = 19.3249)

===== 2-Stage CNN Epoch 10 =====


[2-Stage] Train epoch 10: 100%|██████████████████████████| 6250/6250 [21:54<00:00,  4.76it/s, bp_f=414.486, bp_t=20.482]
[2-Stage] Val epoch 10: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.90it/s, bp_f=317.205, bp_t=13.162]


2-Stage Epoch 10: train_bp_freq=414.4856, train_bp_time=20.4817, val_bp_freq=317.2050, val_bp_time=13.1622
  ↳ Saved best 2-stage CNN (val_bp_time = 13.1622)

===== 2-Stage CNN Epoch 11 =====


[2-Stage] Train epoch 11: 100%|██████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=397.316, bp_t=18.459]
[2-Stage] Val epoch 11: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=313.630, bp_t=12.640]


2-Stage Epoch 11: train_bp_freq=397.3162, train_bp_time=18.4587, val_bp_freq=313.6296, val_bp_time=12.6395
  ↳ Saved best 2-stage CNN (val_bp_time = 12.6395)

===== 2-Stage CNN Epoch 12 =====


[2-Stage] Train epoch 12: 100%|██████████████████████████| 6250/6250 [21:53<00:00,  4.76it/s, bp_f=393.878, bp_t=18.194]
[2-Stage] Val epoch 12: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=309.829, bp_t=12.463]


2-Stage Epoch 12: train_bp_freq=393.8781, train_bp_time=18.1938, val_bp_freq=309.8287, val_bp_time=12.4628
  ↳ Saved best 2-stage CNN (val_bp_time = 12.4628)

===== 2-Stage CNN Epoch 13 =====


[2-Stage] Train epoch 13: 100%|██████████████████████████| 6250/6250 [21:53<00:00,  4.76it/s, bp_f=391.532, bp_t=18.036]
[2-Stage] Val epoch 13: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=311.041, bp_t=12.758]


2-Stage Epoch 13: train_bp_freq=391.5319, train_bp_time=18.0355, val_bp_freq=311.0407, val_bp_time=12.7584

===== 2-Stage CNN Epoch 14 =====


[2-Stage] Train epoch 14: 100%|██████████████████████████| 6250/6250 [21:53<00:00,  4.76it/s, bp_f=389.487, bp_t=17.900]
[2-Stage] Val epoch 14: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.90it/s, bp_f=302.573, bp_t=12.219]


2-Stage Epoch 14: train_bp_freq=389.4871, train_bp_time=17.8997, val_bp_freq=302.5732, val_bp_time=12.2190
  ↳ Saved best 2-stage CNN (val_bp_time = 12.2190)

===== 2-Stage CNN Epoch 15 =====


[2-Stage] Train epoch 15: 100%|██████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=387.677, bp_t=17.784]
[2-Stage] Val epoch 15: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.90it/s, bp_f=301.623, bp_t=12.065]


2-Stage Epoch 15: train_bp_freq=387.6769, train_bp_time=17.7840, val_bp_freq=301.6232, val_bp_time=12.0653
  ↳ Saved best 2-stage CNN (val_bp_time = 12.0653)

===== 2-Stage CNN Epoch 16 =====


[2-Stage] Train epoch 16: 100%|██████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=385.957, bp_t=17.686]
[2-Stage] Val epoch 16: 100%|██████████████████████████████| 782/782 [02:38<00:00,  4.92it/s, bp_f=306.774, bp_t=12.545]


2-Stage Epoch 16: train_bp_freq=385.9572, train_bp_time=17.6855, val_bp_freq=306.7737, val_bp_time=12.5454

===== 2-Stage CNN Epoch 17 =====


[2-Stage] Train epoch 17: 100%|██████████████████████████| 6250/6250 [21:54<00:00,  4.76it/s, bp_f=384.428, bp_t=17.589]
[2-Stage] Val epoch 17: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=302.829, bp_t=12.233]


2-Stage Epoch 17: train_bp_freq=384.4280, train_bp_time=17.5889, val_bp_freq=302.8291, val_bp_time=12.2329

===== 2-Stage CNN Epoch 18 =====


[2-Stage] Train epoch 18: 100%|██████████████████████████| 6250/6250 [21:53<00:00,  4.76it/s, bp_f=383.159, bp_t=17.515]
[2-Stage] Val epoch 18: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=297.627, bp_t=11.798]


2-Stage Epoch 18: train_bp_freq=383.1589, train_bp_time=17.5150, val_bp_freq=297.6271, val_bp_time=11.7975
  ↳ Saved best 2-stage CNN (val_bp_time = 11.7975)

===== 2-Stage CNN Epoch 19 =====


[2-Stage] Train epoch 19: 100%|██████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=381.781, bp_t=17.436]
[2-Stage] Val epoch 19: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=298.252, bp_t=11.974]


2-Stage Epoch 19: train_bp_freq=381.7808, train_bp_time=17.4363, val_bp_freq=298.2516, val_bp_time=11.9744

===== 2-Stage CNN Epoch 20 =====


[2-Stage] Train epoch 20: 100%|██████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=380.727, bp_t=17.383]
[2-Stage] Val epoch 20: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.92it/s, bp_f=298.693, bp_t=11.956]


2-Stage Epoch 20: train_bp_freq=380.7272, train_bp_time=17.3829, val_bp_freq=298.6927, val_bp_time=11.9556

===== 2-Stage CNN Epoch 21 =====


[2-Stage] Train epoch 21: 100%|██████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=379.513, bp_t=17.302]
[2-Stage] Val epoch 21: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=302.160, bp_t=12.122]


2-Stage Epoch 21: train_bp_freq=379.5131, train_bp_time=17.3023, val_bp_freq=302.1604, val_bp_time=12.1224

===== 2-Stage CNN Epoch 22 =====


[2-Stage] Train epoch 22: 100%|██████████████████████████| 6250/6250 [21:51<00:00,  4.77it/s, bp_f=378.372, bp_t=17.241]
[2-Stage] Val epoch 22: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=300.228, bp_t=12.245]


2-Stage Epoch 22: train_bp_freq=378.3717, train_bp_time=17.2413, val_bp_freq=300.2276, val_bp_time=12.2455

===== 2-Stage CNN Epoch 23 =====


[2-Stage] Train epoch 23: 100%|██████████████████████████| 6250/6250 [22:16<00:00,  4.68it/s, bp_f=377.333, bp_t=17.184]
[2-Stage] Val epoch 23: 100%|██████████████████████████████| 782/782 [02:48<00:00,  4.65it/s, bp_f=295.164, bp_t=11.923]


2-Stage Epoch 23: train_bp_freq=377.3331, train_bp_time=17.1840, val_bp_freq=295.1638, val_bp_time=11.9231

===== 2-Stage CNN Epoch 24 =====


[2-Stage] Train epoch 24: 100%|██████████████████████████| 6250/6250 [22:43<00:00,  4.59it/s, bp_f=376.435, bp_t=17.140]
[2-Stage] Val epoch 24: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=295.083, bp_t=11.805]


2-Stage Epoch 24: train_bp_freq=376.4352, train_bp_time=17.1398, val_bp_freq=295.0834, val_bp_time=11.8051

===== 2-Stage CNN Epoch 25 =====


[2-Stage] Train epoch 25: 100%|██████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=375.573, bp_t=17.088]
[2-Stage] Val epoch 25: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=296.170, bp_t=11.691]


2-Stage Epoch 25: train_bp_freq=375.5728, train_bp_time=17.0878, val_bp_freq=296.1698, val_bp_time=11.6905
  ↳ Saved best 2-stage CNN (val_bp_time = 11.6905)

===== 2-Stage CNN Epoch 26 =====


[2-Stage] Train epoch 26: 100%|██████████████████████████| 6250/6250 [21:53<00:00,  4.76it/s, bp_f=374.854, bp_t=17.042]
[2-Stage] Val epoch 26: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.90it/s, bp_f=296.384, bp_t=11.888]


2-Stage Epoch 26: train_bp_freq=374.8540, train_bp_time=17.0424, val_bp_freq=296.3840, val_bp_time=11.8878

===== 2-Stage CNN Epoch 27 =====


[2-Stage] Train epoch 27: 100%|██████████████████████████| 6250/6250 [21:54<00:00,  4.75it/s, bp_f=374.102, bp_t=16.992]
[2-Stage] Val epoch 27: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.90it/s, bp_f=294.722, bp_t=11.953]


2-Stage Epoch 27: train_bp_freq=374.1023, train_bp_time=16.9924, val_bp_freq=294.7224, val_bp_time=11.9534

===== 2-Stage CNN Epoch 28 =====


[2-Stage] Train epoch 28: 100%|██████████████████████████| 6250/6250 [21:54<00:00,  4.76it/s, bp_f=373.653, bp_t=16.961]
[2-Stage] Val epoch 28: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.90it/s, bp_f=292.400, bp_t=11.734]


2-Stage Epoch 28: train_bp_freq=373.6526, train_bp_time=16.9614, val_bp_freq=292.3996, val_bp_time=11.7340

===== 2-Stage CNN Epoch 29 =====


[2-Stage] Train epoch 29: 100%|██████████████████████████| 6250/6250 [21:54<00:00,  4.75it/s, bp_f=373.025, bp_t=16.928]
[2-Stage] Val epoch 29: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.90it/s, bp_f=291.620, bp_t=11.627]


2-Stage Epoch 29: train_bp_freq=373.0251, train_bp_time=16.9278, val_bp_freq=291.6199, val_bp_time=11.6274
  ↳ Saved best 2-stage CNN (val_bp_time = 11.6274)

===== 2-Stage CNN Epoch 30 =====


[2-Stage] Train epoch 30: 100%|██████████████████████████| 6250/6250 [21:54<00:00,  4.75it/s, bp_f=372.511, bp_t=16.892]
[2-Stage] Val epoch 30: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.90it/s, bp_f=292.174, bp_t=11.694]


2-Stage Epoch 30: train_bp_freq=372.5109, train_bp_time=16.8916, val_bp_freq=292.1736, val_bp_time=11.6935

===== 2-Stage CNN Epoch 31 =====


[2-Stage] Train epoch 31: 100%|██████████████████████████| 6250/6250 [21:54<00:00,  4.76it/s, bp_f=371.748, bp_t=16.843]
[2-Stage] Val epoch 31: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=292.835, bp_t=11.633]


2-Stage Epoch 31: train_bp_freq=371.7479, train_bp_time=16.8430, val_bp_freq=292.8347, val_bp_time=11.6327

===== 2-Stage CNN Epoch 32 =====


[2-Stage] Train epoch 32: 100%|██████████████████████████| 6250/6250 [21:55<00:00,  4.75it/s, bp_f=371.290, bp_t=16.803]
[2-Stage] Val epoch 32: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.90it/s, bp_f=291.774, bp_t=11.651]


2-Stage Epoch 32: train_bp_freq=371.2896, train_bp_time=16.8029, val_bp_freq=291.7743, val_bp_time=11.6506

===== 2-Stage CNN Epoch 33 =====


[2-Stage] Train epoch 33: 100%|██████████████████████████| 6250/6250 [21:53<00:00,  4.76it/s, bp_f=370.818, bp_t=16.776]
[2-Stage] Val epoch 33: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=290.406, bp_t=11.752]


2-Stage Epoch 33: train_bp_freq=370.8177, train_bp_time=16.7757, val_bp_freq=290.4062, val_bp_time=11.7522

===== 2-Stage CNN Epoch 34 =====


[2-Stage] Train epoch 34: 100%|██████████████████████████| 6250/6250 [21:51<00:00,  4.76it/s, bp_f=370.265, bp_t=16.737]
[2-Stage] Val epoch 34: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=289.524, bp_t=11.451]


2-Stage Epoch 34: train_bp_freq=370.2648, train_bp_time=16.7372, val_bp_freq=289.5244, val_bp_time=11.4507
  ↳ Saved best 2-stage CNN (val_bp_time = 11.4507)

===== 2-Stage CNN Epoch 35 =====


[2-Stage] Train epoch 35: 100%|██████████████████████████| 6250/6250 [21:52<00:00,  4.76it/s, bp_f=369.785, bp_t=16.708]
[2-Stage] Val epoch 35: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=290.689, bp_t=11.659]


2-Stage Epoch 35: train_bp_freq=369.7852, train_bp_time=16.7076, val_bp_freq=290.6888, val_bp_time=11.6588

===== 2-Stage CNN Epoch 36 =====


[2-Stage] Train epoch 36: 100%|██████████████████████████| 6250/6250 [21:53<00:00,  4.76it/s, bp_f=369.490, bp_t=16.691]
[2-Stage] Val epoch 36: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=292.507, bp_t=11.853]


2-Stage Epoch 36: train_bp_freq=369.4902, train_bp_time=16.6909, val_bp_freq=292.5070, val_bp_time=11.8533

===== 2-Stage CNN Epoch 37 =====


[2-Stage] Train epoch 37: 100%|██████████████████████████| 6250/6250 [21:53<00:00,  4.76it/s, bp_f=368.892, bp_t=16.651]
[2-Stage] Val epoch 37: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.90it/s, bp_f=291.586, bp_t=11.842]


2-Stage Epoch 37: train_bp_freq=368.8921, train_bp_time=16.6506, val_bp_freq=291.5860, val_bp_time=11.8420

===== 2-Stage CNN Epoch 38 =====


[2-Stage] Train epoch 38: 100%|██████████████████████████| 6250/6250 [21:53<00:00,  4.76it/s, bp_f=368.509, bp_t=16.621]
[2-Stage] Val epoch 38: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=289.912, bp_t=11.537]


2-Stage Epoch 38: train_bp_freq=368.5092, train_bp_time=16.6207, val_bp_freq=289.9122, val_bp_time=11.5367

===== 2-Stage CNN Epoch 39 =====


[2-Stage] Train epoch 39: 100%|██████████████████████████| 6250/6250 [21:53<00:00,  4.76it/s, bp_f=368.173, bp_t=16.602]
[2-Stage] Val epoch 39: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=291.765, bp_t=11.582]


2-Stage Epoch 39: train_bp_freq=368.1730, train_bp_time=16.6022, val_bp_freq=291.7650, val_bp_time=11.5819

===== 2-Stage CNN Epoch 40 =====


[2-Stage] Train epoch 40: 100%|██████████████████████████| 6250/6250 [21:53<00:00,  4.76it/s, bp_f=367.693, bp_t=16.571]
[2-Stage] Val epoch 40: 100%|██████████████████████████████| 782/782 [02:39<00:00,  4.91it/s, bp_f=291.029, bp_t=11.585]

2-Stage Epoch 40: train_bp_freq=367.6930, train_bp_time=16.5707, val_bp_freq=291.0290, val_bp_time=11.5855





In [None]:
# ==== Block D-1: 频域 → 时域 + SBP / DBP 提取工具函数 ====

import torch
import torch.nn.functional as F

@torch.no_grad()
def spec_to_time_abp(freqs, abp_spec, mask=None, eps=1e-8):
    """
    简单版：从双通道频谱 [real, imag] 恢复时域 ABP 波形。
    注意：
      - 假设 abp_spec 是 RFFT 的 one-sided 结果：[B, 2, L]
      - 如果你在 batch_time_mae_from_spec 里有更精细的 inverse / 去归一化逻辑，
        建议直接把那段 copy 到这里，替换掉这个函数内部。

    freqs   : [B, L]  (目前只用长度 L 来推回时域长度)
    abp_spec: [B, 2, L]  (real, imag)
    mask    : [B, L] / [B, T]，这里仅用于 broadcast，不改变值
    返回:
      abp_time: [B, T]，T ≈ 2*(L-1)
    """
    B, C, L = abp_spec.shape
    assert C == 2, "abp_spec 应该是 [B, 2, L] (real, imag)"

    # 如果你知道精确的原始时域长度 T，可以用它来代替 2*(L-1)
    T = 2 * (L - 1)  # RFFT 的典型关系 n_fft = 2*(L-1)

    # 组装复数频谱
    real = abp_spec[:, 0, :]  # [B, L]
    imag = abp_spec[:, 1, :]
    complex_spec = torch.complex(real, imag)  # [B, L]

    # irfft: 得到时域波形 [B, T]
    abp_time = torch.fft.irfft(complex_spec, n=T, dim=-1)

    # 如果训练时对 ABP 做过归一化（均值/方差、min-max 等），
    # 在这里需要做一次 inverse normalization。
    # 举例：
    # abp_time = abp_time * std_abp + mean_abp

    if mask is not None:
        # 将 mask 从频域长度 L 映射到时域长度 T 时，
        # 这里先作简单 nearest neighbor 的缩放（也可以直接忽略）。
        # 为了不引入额外误差，这里默认不处理，直接返回 abp_time。
        pass

    return abp_time  # [B, T]


def extract_sbp_dbp_from_wave(abp_time, mask_time=None):
    """
    从时域 ABP 波形里提取 SBP / DBP。
    简化版：直接在一个样本的有效时间段上取 max / min：
      SBP = max_t ABP(t)
      DBP = min_t ABP(t)

    abp_time : [B, T]
    mask_time: [B, T] (bool)，True 表示有效位置；如果没有就用全 True。

    返回:
      sbp: [B]
      dbp: [B]
    """
    B, T = abp_time.shape
    if mask_time is None:
        mask_time = torch.ones_like(abp_time, dtype=torch.bool, device=abp_time.device)

    # 把无效位置置成很极端的值，避免影响 max/min
    very_neg = -1e9
    very_pos = 1e9

    valid_abp_for_max = torch.where(mask_time, abp_time, torch.full_like(abp_time, very_neg))
    valid_abp_for_min = torch.where(mask_time, abp_time, torch.full_like(abp_time, very_pos))

    sbp, _ = valid_abp_for_max.max(dim=-1)  # [B]
    dbp, _ = valid_abp_for_min.min(dim=-1)  # [B]
    return sbp, dbp


@torch.no_grad()
def sbp_dbp_mae(pred_sbp, pred_dbp, gt_sbp, gt_dbp):
    """
    计算 SBP / DBP 的 MAE。
    输入都是 [B] 或 [N] 向量。
    返回:
      mae_sbp, mae_dbp (标量 float)
    """
    mae_sbp = F.l1_loss(pred_sbp, gt_sbp).item()
    mae_dbp = F.l1_loss(pred_dbp, gt_dbp).item()
    return mae_sbp, mae_dbp


# ==== Block D-2: Two-Stage 模型的 SBP / DBP Evaluation ====

@torch.no_grad()
def eval_two_stage_sbp_dbp(
    diffusion,
    cnn,
    loader,
    device,
    desc: str = "[2-Stage] SBP/DBP Eval"
):
    """
    流程：
      对每个 batch：
        1) 用 PPG 频谱通过 diffusion 生成 ECG 频谱
        2) (PPG, ECG_gen) 拼在一起送入 two_stage CNN 得到 ABP_pred 频谱
        3) 频域 ABP_pred / ABP_gt -> 时域波形 (spec_to_time_abp)
        4) 各自提取 SBP / DBP (extract_sbp_dbp_from_wave)
        5) 累积 SBP / DBP 的 MAE

    最后打印整体:
      - SBP MAE (生成路径)
      - DBP MAE (生成路径)
    """
    diffusion.eval()
    cnn.eval()

    all_pred_sbp = []
    all_pred_dbp = []
    all_gt_sbp   = []
    all_gt_dbp   = []

    pbar = tqdm(loader, desc=desc, ncols=120)
    for ppg_b, ecg_b, abp_b, freqs_b, mask_b, paths_b in pbar:
        ppg_b   = ppg_b.to(device)    # [B,2,L]
        abp_b   = abp_b.to(device)    # [B,2,L]
        freqs_b = freqs_b.to(device)  # [B,L]
        mask_b  = mask_b.to(device)   # [B,L]  (频域 mask)

        B, _, L = abp_b.shape

        # ---- 1) 用 diffusion 生成 ECG 频谱 ----
        ecg_gen = diffusion.sample_ecg_spec(ppg_b)  # [B,2,L]

        # ---- 2) Two-Stage CNN 生成 ABP 频谱 ----
        x_spec = torch.cat([ppg_b, ecg_gen], dim=1)  # [B,4,L]
        x = x_spec.transpose(1, 2)                  # [B,L,4]
        abp_hat_spec = cnn(x).transpose(1, 2)       # [B,2,L]

        # ---- 3) 频域 -> 时域 ABP 波形 ----
        abp_pred_time = spec_to_time_abp(freqs_b, abp_hat_spec, mask_b)  # [B,T]
        abp_gt_time   = spec_to_time_abp(freqs_b, abp_b,       mask_b)   # [B,T]

        # 如果你有专门的时域 mask（比如对应原始时域长度），可以在这里构造；
        # 目前先用全 True，或者简单根据 mask_b 进行插值映射（上面留了接口）。
        mask_time = None  # 先用 None，表示整段都算

        # ---- 4) 提取每个样本的 SBP / DBP ----
        pred_sbp, pred_dbp = extract_sbp_dbp_from_wave(abp_pred_time, mask_time)  # [B], [B]
        gt_sbp,   gt_dbp   = extract_sbp_dbp_from_wave(abp_gt_time,   mask_time)  # [B], [B]

        all_pred_sbp.append(pred_sbp.cpu())
        all_pred_dbp.append(pred_dbp.cpu())
        all_gt_sbp.append(gt_sbp.cpu())
        all_gt_dbp.append(gt_dbp.cpu())

        # 临时计算当前为止的 MAE，放到 tqdm 上，心里有数
        cur_pred_sbp = torch.cat(all_pred_sbp, dim=0)
        cur_pred_dbp = torch.cat(all_pred_dbp, dim=0)
        cur_gt_sbp   = torch.cat(all_gt_sbp,   dim=0)
        cur_gt_dbp   = torch.cat(all_gt_dbp,   dim=0)

        mae_sbp, mae_dbp = sbp_dbp_mae(cur_pred_sbp, cur_pred_dbp, cur_gt_sbp, cur_gt_dbp)
        pbar.set_postfix({
            "SBP_MAE": f"{mae_sbp:.3f}",
            "DBP_MAE": f"{mae_dbp:.3f}",
        })

    # ---- 5) 全局 MAE ----
    all_pred_sbp = torch.cat(all_pred_sbp, dim=0)
    all_pred_dbp = torch.cat(all_pred_dbp, dim=0)
    all_gt_sbp   = torch.cat(all_gt_sbp,   dim=0)
    all_gt_dbp   = torch.cat(all_gt_dbp,   dim=0)

    mae_sbp, mae_dbp = sbp_dbp_mae(all_pred_sbp, all_pred_dbp, all_gt_sbp, all_gt_dbp)

    print("\n========== SBP / DBP Evaluation (Two-Stage, Generated ECG) ==========")
    print(f"SBP MAE (generated path): {mae_sbp:.4f}")
    print(f"DBP MAE (generated path): {mae_dbp:.4f}")
    print("=================================================================\n")

    return mae_sbp, mae_dbp



# ==== Block D-3: 实际跑一遍（用 val_loader 或 test_loader） ====

# 假设你现在 Notebook 里已经有:
#   diffusion_stage2, two_stage_cnn, val_loader, device
#   并且已经 load 好 two_stage_cnn 最优权重:
#   two_stage_cnn.load_state_dict(torch.load("two_stage_ppg_ecg_gen_cnn.pt"))

mae_sbp, mae_dbp = eval_two_stage_sbp_dbp(
    diffusion_stage2,
    two_stage_cnn,
    val_loader,   # 你也可以换成 test_loader
    device,
    desc="[2-Stage] SBP/DBP Eval on VAL"
)



# No Use Blow

In [None]:
!pwd
#%cd /content/gdrive/MyDrive/Colab Notebooks/IDL-HW4/
!cp -r "/content/gdrive/MyDrive/Colab Notebooks/IDL-HW4" /content/
%cd /content/IDL-HW4/
!ls

In [None]:
#需要重启，colab可以重启的，重启后只是变量，import没了（用于重新import，不然直接重跑import代码是没有用的，colab会自动检测是否已经import过，有的话就不会刷新）
#，之类的的没了，gdrive还有安装的包都在
%pip install --no-deps -r requirements.txt
import os
# os.kill(os.getpid(), 9) # NOTE: This will restart the your colab Python runtime (required)!

In [None]:
#会有一些包冲突，但是只要不是project用到的就不管，有用的就重装
%pip install "tokenizers>=0.22.0,<0.24.0" \
             "huggingface-hub>=0.34.0,<1.0" \
             "torchmetrics>=1.5.0,<2.0"

Collecting tokenizers<0.24.0,>=0.22.0
  Downloading tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting huggingface-hub<1.0,>=0.34.0
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->torchmetrics<2.0,>=1.5.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->torchmetrics<2.0,>=1.5.0)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->torchmetrics<2.0,>=1.5.0)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->torchmetrics<2.0,>=1.5.0)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB

### Step 3: Obtain Data

- `NOTE`: This process will automatically download and unzip data for both `HW4P1` and `HW4P2`.  


In [None]:
!curl -L -o /content/f25-hw4-data.zip https://www.kaggle.com/api/v1/datasets/download/cmu11785/f25-11785-hw4-data
!unzip -q -o /content/f25-hw4-data.zip -d /content/IDL-HW4/hw4_data
!rm -rf /content/f25-hw4-data.zip
#!tar -xf "/content/gdrive/MyDrive/Colab Notebooks/IDL-HW4/hw4p1_data.tar" -C "/content/gdrive/MyDrive/Colab Notebooks/IDL-HW4/hw4_data/hw4p1_data"
!du -h --max-depth=2 /content/hw4_data

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 11.9G  100 11.9G    0     0  21.6M      0  0:09:25  0:09:25 --:--:-- 20.8M
du: cannot access '/content/hw4_data': No such file or directory


### Step 4: Move to Handout Directory
You must be within the handout directory for the library imports to work!

- `NOTE`: You may have to repeat running this command anytime you restart your runtime.
- `NOTE`: You can do a `pwd` to check if you are in the right directory.
- `NOTE`: The way it is setup currently, Your data directory should be one level up from your project directory. Keep this in mind when you are setting your `root` in the config file.

If everything was done correctly, You should see atleast the following files in your current working directory after running `!ls`:
```
.
├── README.md
├── requirements.txt
├── hw4lib/
├── mytorch/
├── tests/
└── hw4_data_subset/

```

In [None]:
# import os
# os.chdir('IDL-HW4')
# !ls

## PSC

### 1️⃣ **Step 1 Setting Up Your Environment on Bridges2**

❗️⚠️ For this homework, we are **providing shared Datasets and a shared Conda environment** for the entire class.

❗️⚠️ So for PSC users, **do not download the data yourself** and **do not need to manually install the packages**!


Follow these steps to set up the environment and start a Jupyter notebook on Bridges2:

To run your notebook more efficiently on PSC, we need to use a **Jupyter Server** hosted on a compute node.

You can use your prefered way of connecting to the Jupyter Server. Your options should be covered in the docs linked in post 558 @ piazza.

**The recommended way of connecting is:**

#### **Connect in VSCode**
SSH into Bridges2 and navigate to your **Jet directory** (`Jet/home/<your_psc_username>`). Upload your notebook there, and then connect to the Jupyter Server from that directory.

#### **1. SSH into Bridges2**
1）Open VS Code and click on the `Extensions` icon in the left sidebar. Make sure the "**Remote - SSH**" extension is installed.

2）Open the command palette (**Shift+Command+P** on Mac, **Ctrl+Shift+P** on Windows). A search box will appear at the top center. Choose `"Remote-SSH: Add New SSH Host"`, then enter:

```bash
ssh <your_username>@bridges2.psc.edu #change <your_username> to your username
```

Next, choose `"/Users/<your_username>/.ssh/config"` as the config file. A dialog will appear in the bottom right saying "Host Added". Click `"Connect"`, and then enter your password.

(Note: After adding the host once, you can later use `"Remote-SSH: Connect to Host"` and select "bridges2.psc.edu" from the list.)

3）Once connected, click `"Explorer"` in the left sidebar > "Open Folder", and navigate to your home directory under the project grant:
```bash
/jet/home/<your_username>  #change <your_username> to your username
```

4）You can now drag your notebook files directly into the right-hand pane (your remote home directory), or upload them using `scp` into your folder.

> ❗️⚠️ The following steps should be executed in the **VSCode integrated terminal**.

#### **2. Navigate to Your Directory**
Make sure to use this `/jet/home/<your_username>` as your working directory, since all subsequent operations (up to submission) are based on this path.
```bash
cd /jet/home/<your_username>  #change <your_username> to your username
```

#### **3. Request a Compute Node**
```bash
interact -p GPU-shared --gres=gpu:v100-32:1 -t 8:00:00 -A cis250019p
```

#### **4. Load the Anaconda Module**
```bash
module load anaconda3
```

#### **5. Activate the provided HW4 Environment**
```bash
conda deactivate # First, deactivate any existing Conda environment
######## [need to be updated] conda activate /ocean/projects/cis240101p/mzhang23/TA/HW4/envs/hw4_env && export PYTHONNOUSERSITE=1
```

#### **6. Start Jupyter Notebook**
Launch Jupyter Notebook:
```bash
jupyter notebook --no-browser --ip=0.0.0.0
```

Go to **Kernel** → **Select Another Kernel** → **Existing Jupyter Server**
   Enter the URL of the Jupyter Server:```http://{hostname}:{port}/tree?token={token}```
   
   *(Usually, this URL appears in the terminal output after you run `jupyter notebook --no-browser --ip=0.0.0.0`, in a line like:  “Jupyter Server is running at: http://...”)*

   - eg: `http://v011.ib.bridges2.psc.edu:8888/tree?token=e4b302434e68990f28bc2b4ae8d216eb87eecb7090526249`

> **Note**: Replace `{hostname}`, `{port}` and `{token}` with your actual values from the Jupyter output.

After launching the Jupyter notebook, you can run the cells directly inside the notebook — no need to use the terminal for the remaining steps.

### 2️⃣ Step 2: Get Repo

In [None]:
#Make sure you are in your directory
!pwd #should be /jet/home/<your_username>, if not, uncomment the following line and replace with your actual username:
# %cd /jet/home/<your_username>
#TODO: replace the "<your_username>" to yours

In [None]:
# Example: My preferred approach
import os
# Settings -> Developer Settings -> Personal Access Tokens -> Token (classic)
os.environ['GITHUB_TOKEN'] = "your_github_token_here"

GITHUB_USERNAME = "your_github_username_here"
REPO_NAME       = "your_github_repo_name_here"
TOKEN = os.environ.get("GITHUB_TOKEN")
repo_url        = f"https://{TOKEN}@github.com/{GITHUB_USERNAME}/{REPO_NAME}.git"
!git clone {repo_url}

In [None]:
# To pull latest changes (Must be in the repo dir, use pwd/ls to verify)
!cd {REPO_NAME} && git pull

#### **Move to Project Directory**
- `NOTE`: You may have to repeat this on anytime you restart your runtime. You can do a `pwd` or `ls` to check if you are in the right directory.

In [None]:
import os
os.chdir('IDL-HW4')
!ls

### 3️⃣ **Step 3: Set up Kaggle API Authentication**

In [None]:
# TODO: Use the same Kaggle code from HW3P2
!mkdir /jet/home/<your_username>/.kaggle #TODO: replace the "<your_username>" to yours

with open("/jet/home/<your_username>/.kaggle/kaggle.json", "w+") as f: #TODO: replace the "<your_username>" to yours
    f.write('{"username":"<your_username>","key":"<your_key>"}')
    # TODO: Put your kaggle username & key here

!chmod 600 /jet/home/<your_username>/.kaggle/kaggle.json #TODO: replace the "<your_username>" to yours

### 4️⃣ **Step 4: Get Data**

❗️⚠️ The data used in this assignment is **already stored in a shared, read-only folder, so you do not need to manually download anything**.

Instead, just make sure to replace the dataset path in your notebook code with the correct path from the shared directory.

You can run the following block to explore the shared directory structure:

In [None]:
import os
data_path = "/ocean/projects/cis240101p/mzhang23/TA/HW4/hw4_data/hw4p2_data" #Shared data path, do not need to change the username to yours
print("Files in shared hw4p2 dataset:", os.listdir(data_path))

In [None]:
!apt-get install tree
!tree -L 2 /ocean/projects/cis240101p/mzhang23/TA/HW4/hw4_data/hw4p2_data

# Imports
- If your setup was done correctly, you should be able to run the following cell without any issues.

In [None]:
#重启后需要先跑第一二个框（link gdrive and copy），再run 这个框，不然hwlib不存在
from hw4lib.data import (
    H4Tokenizer,
    ASRDataset,
    verify_dataloader
)
from hw4lib.model import (
    DecoderOnlyTransformer,
    EncoderDecoderTransformer
)
from hw4lib.utils import (
    create_scheduler,
    create_optimizer,
    plot_lr_schedule
)
from hw4lib.trainers import (
    ASRTrainer,
    ProgressiveTrainer
)
from torch.utils.data import DataLoader
import yaml
import gc
import torch
from torchinfo import summary
import os
import json
import wandb
import pandas as pd
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


# Implementations
- `NOTE`: All of these implementations have detailed specification, implementation details, and hints in their respective source files. Make sure to read all of them in their entirety to understand the implementation details!

## Dataset Implementation
- Implement the `ASRDataset` class in `hw4lib/data/asr_dataset.py`.
- You will have to implement parts of `__init__` and completely implement the `__len__`, `__getitem__` and `collate_fn` methods.
- Run the cell below to check your implementation.


In [None]:
!python -m tests.test_dataset_asr

Loading data for train-clean-100 partition...
  0% 0/28 [00:00<?, ?it/s]100% 28/28 [00:00<00:00, 698.79it/s]
Loading data for test-clean partition...
  0% 0/2 [00:00<?, ?it/s]100% 2/2 [00:00<00:00, 3806.08it/s]

Running tests for category: ASRDataset Train
--------------------------------------------------------------------------------[0m

[94m[01/01]    Running:  Test a Train instance of ASRDataset class[0m
Testing __init__ method ...
Test Passed: Dataset length matches FBANK files.
Test Passed: Dataset length matches TRANSCRIPT files.
Test Passed: Order alignment between FBANK files and TRANSCRIPT files is correct.
Test Passed: Alignment between features and transcripts is correct.
Test Passed: All features have the correct number of dimensions (num_feats).
Test Passed: All transcripts are decoded correctly after removing SOS and EOS tokens.
Testing __getitem__ method ...
Test Passed: All samples have correct feature dimensions and transcript alignment.
Testing collate_fn meth

## Model Implementations

Overview:

- Implement the `CrossAttentionLayer` class in `hw4lib/model/sublayers.py`.
- Implement the `CrossAttentionDecoderLayer` class in `hw4lib/model/decoder_layers.py`.
- Implement the `SelfAttentionEncoderLayer` class in `hw4lib/model/encoder_layers.py`. This will be mostly a copy-paste of the `SelfAttentionDecoderLayer` class in `hw4lib/model/decoder_layers.py` with one minor diffrence: it can attend to all positions in the input sequence.
- Implement the `EncoderDecoderTransformer` class in `hw4lib/model/transformers.py`.

### Transformer Sublayers
- Now, Implement the `CrossAttentionLayer` class in `hw4lib/model/sublayers.py`.
- `NOTE`: You should have already implemented the `SelfAttentionLayer`, and `FeedForwardLayer` classes in `hw4lib/model/sublayers.py`.
- Run the cell below to check your implementation.

In [None]:
!python -m tests.test_sublayer_crossattention


Running tests for category: CrossAttentionLayer
--------------------------------------------------------------------------------[0m

[94m[01/01]    Running:  Test the cross-attention sublayer[0m
Testing initialization ...
Test Passed: All layers exist and are instantiated correctly
Testing forward shapes ...
Test Passed: Forward pass returns the correct shapes
Testing padding mask behaviour ...
Test Passed: Padding mask is applied correctly
Testing cross-attention behaviour ...
Test Passed: Cross-attention behavior is correct
Testing residual connection ...
Test Passed: Residual connection is applied correctly
[92m[01/01]    PASSED:   Test the cross-attention sublayer[0m


                                  Test Summary                                  
[93mCategory:    CrossAttentionLayer           
Results:     1/1 tests passed (100.0%)[0m


### Transformer Cross-Attention Decoder Layer
- Implement the `CrossAttentionDecoderLayer` class in `hw4lib/model/decoder_layers.py`.
- Then run the cell below to check your implementation.


In [None]:
!python -m tests.test_decoderlayer_crossattention


Running tests for category: CrossAttentionDecoderLayer
--------------------------------------------------------------------------------[0m

[94m[01/01]    Running:  Test the cross-attention decoder layer[0m
Testing initialization ...
Test Passed: All sublayers exist and are initialized correctly
Testing forward shapes ...
Test Passed: Forward shapes are as expected
Testing sublayer integration ...
Test Passed: Sublayers interact correctly
Testing cross-attention behavior ...
Test Passed: Cross-attention behaves correctly
[92m[01/01]    PASSED:   Test the cross-attention decoder layer[0m


                                  Test Summary                                  
[93mCategory:    CrossAttentionDecoderLayer    
Results:     1/1 tests passed (100.0%)[0m


### Transformer Self-Attention Encoder Layer
- Implement the `SelfAttentionEncoderLayer` class in `hw4lib/model/encoder_layers.py`.
- Then run the cell below to check your implementation.




In [None]:
!python -m tests.test_encoderlayer_selfattention


Running tests for category: SelfAttentionEncoderLayer
--------------------------------------------------------------------------------[0m

[94m[01/01]    Running:  Test the self-attention encoder layer[0m
Testing initialization ...
Test Passed: All sublayers exist and are initialized correctly
Testing forward shapes ...
Test Passed: Forward shapes are as expected
Testing sublayer interaction ...
Test Passed: Sublayers interact correctly
Testing bidirectional attention ...
Test Passed: Bidirectional attention is working correctly
[92m[01/01]    PASSED:   Test the self-attention encoder layer[0m


                                  Test Summary                                  
[93mCategory:    SelfAttentionEncoderLayer     
Results:     1/1 tests passed (100.0%)[0m


### Encoder-Decoder Transformer

- Implement the  `EncoderDecoderTransformer` class in `hw4lib/model/transformers.py`.
- Then run the cell below to check your implementation.

In [None]:
!python -m tests.test_transformer_encoder_decoder


Running tests for category: EncoderDecoderTransformer
--------------------------------------------------------------------------------[0m

[94m[01/01]    Running:  Test the encoder-decoder transformer[0m
Testing initialization...
Test Passed: All components initialized correctly
Testing encode method...
Test Passed: Encode method works correctly
Testing decode method...
Test Passed: Decode method works correctly
Testing forward pass...
Test Passed: Forward pass works correctly
Testing encoder-decoder integration...
Test Passed: Encoder-decoder integration works correctly
Testing CTC integration...
Test Passed: CTC integration works correctly
Testing forward propagation order...
Test Passed: Forward propagation order is correct
[92m[01/01]    PASSED:   Test the encoder-decoder transformer[0m


                                  Test Summary                                  
[93mCategory:    EncoderDecoderTransformer     
Results:     1/1 tests passed (100.0%)[0m


## Decoding Implementation
- We highly recommend you to implement the `generate_beam` method of the `SequenceGenerator` class in `hw4lib/decoding/sequence_generator.py`.
- Then run the cell below to check your implementation.
- `NOTE`: This is an optional but highly recommended task for `HW4P2` to ease the journey to high cutoffs!

In [None]:
!python -m tests.test_decoding --mode beam


Running tests for category: Decoding
--------------------------------------------------------------------------------[0m

[94m[01/01]    Running:  Test beam decoding[0m
Testing Single Batch Beam Search ...
Beam 0  : generated: HELLO WORLD  | expected: HELLO WORLD 
Beam 1  : generated: YELLOW WORLD | expected: YELLOW WORLD
Beam 2  : generated: MELLOW WORLD | expected: MELLOW WORLD
Testing Multi Batch Beam Search ...
Batch 0  : Beam 0  : generated: HELLO WORLD  | expected: HELLO WORLD 
Batch 0  : Beam 1  : generated: YELLOW WORLD | expected: YELLOW WORLD
Batch 0  : Beam 2  : generated: MELLOW WORLD | expected: MELLOW WORLD
Batch 1  : Beam 0  : generated: GOOD BYE     | expected: GOOD BYE    
Batch 1  : Beam 1  : generated: GREAT DAY    | expected: GREAT DAY   
Batch 1  : Beam 2  : generated: GUD NIGHT    | expected: GUD NIGHT   
[92m[01/01]    PASSED:   Test beam decoding[0m


                                  Test Summary                                  
[93mCategory:    Decodin

## Trainer Implementation
You will have to do some minor in-filling for the `ASRTrainer` class in `hw4lib/trainers/asr_trainer.py` before you can use it.
- Fill in the `TODO`s in the `__init__`.
- Fill in the `TODO`s in the `_train_epoch`.
- Fill in the `TODO`s in the `recognize` method.
- Fill in the `TODO`s in the `_validate_epoch`.
- Fill in the `TODO`s in the `train` method.
- Fill in the `TODO`s in the `evaluate` method.

`WARNING`: There are no test's for this. Implement carefully!

# Experiments
From this point onwards you may want to switch to a `GPU` runtime.
- `OBJECTIVE`: Optimize your model for `CER` on the test set.

## Config
- You can use the `config.yaml` file to set your config for your ablation study.

---
### Notes:

- Set `tokenization: token_type:` to specify your desired tokenization strategy
- You will need to set the root path to your `hw4p1_data` folder in `data: root:`. This will depend on your setup. For eg. if you are following out setup instruction:
  - `PSC`: `"/local/hw4_data/hw4p1_data"`
  - `Colab:`: `"/content/hw4_data/hw4p1_data"`
- There's extra configurations in the `optimizer` section which will only be relevant if you decide to use the `create_optimizer` function we've provided in `hw4lib/utils/create_optimizer.py`.
- `BE CAREFUL` while setting numeric values. Eg. `1e-4` will get serialized to a `str` while `1.0e-4` gets serialized to float.

In [None]:
%%writefile config.yaml

Name                      : "HW4P2_encoder8_decoder4_d256"

###### Tokenization ------------------------------------------------------------
tokenization:
  token_type                : "5k"       # [char, 1k, 5k, 10k]
  token_map :
      'char': 'hw4lib/data/tokenizer_jsons/tokenizer_char.json'
      '1k'  : 'hw4lib/data/tokenizer_jsons/tokenizer_1000.json'
      '5k'  : 'hw4lib/data/tokenizer_jsons/tokenizer_5000.json'
      '10k' : 'hw4lib/data/tokenizer_jsons/tokenizer_10000.json'

###### Dataset -----------------------------------------------------------------
data:
  root                 : "hw4_data/hw4p2_data"  # 数据路径不变
  train_partition      : "train-clean-100"
  val_partition        : "dev-clean"
  test_partition       : "test-clean"
  subset               : 1.0                # 全量训练
  batch_size           : 32                 # A100 可以再往上调，但保持 32 更稳
  NUM_WORKERS          : 4
  norm                 : 'global_mvn'       # 默认即可
  num_feats            : 80

  ###### SpecAugment -----------------------------------------------------------
  specaug                   : True          # ✅ 开启 SpecAugment
  specaug_conf:
    apply_freq_mask         : True
    freq_mask_width_range   : 15           # 适中频率掩码
    num_freq_mask           : 2
    apply_time_mask         : True
    time_mask_width_range   : 80           # 更长时间掩码，增强鲁棒性
    num_time_mask           : 4

###### Network Specs -------------------------------------------------------------
model: # Encoder-Decoder Transformer (HW4P2)
  # Speech embedding parameters
  input_dim: 80              # Speech feature dimension
  time_reduction: 2          # 保持 2，下采样不太激进
  reduction_method: 'conv'   # conv 下采样

  # Architecture parameters
  d_model: 256               # 与 baseline 相同，方便训练稳定
  num_encoder_layers: 8      # ✅ 增强 encoder（从 6 -> 8）
  num_decoder_layers: 4      # ✅ 减少 decoder（从 6 -> 4）
  num_encoder_heads: 8
  num_decoder_heads: 8
  d_ff_encoder: 1536         # ✅ 放大香农网络宽度（原来 1024，1536）， （G：一般是d_model 的4倍，参数放d_model更有效率）
  d_ff_decoder: 1536
  skip_encoder_pe: False
  skip_decoder_pe: False

  # Common parameters
  dropout: 0.10              # ✅ 稍微加大 dropout，配合更大模型和 SpecAugment
  layer_drop_rate: 0.05      # 保留原来的 layer drop
  weight_tying: True         # ✅ 开启 weight tying，省参数、易收敛

###### Common Training Parameters ------------------------------------------------
training:
  use_wandb                   : True    # 你如果不想用 wandb 可以改成 False
  wandb_run_id                : "none"
  resume                      : False
  gradient_accumulation_steps : 1       # A100 + bs=32 足够
  wandb_project               : "HW4part2"

###### Loss ----------------------------------------------------------------------
loss:
  label_smoothing: 0.1        # ✅ 适度 label smoothing，减轻过拟合
  ctc_weight: 0.3             # ✅ 稍微提高 CTC 权重，辅助对齐

###### Optimizer -----------------------------------------------------------------
optimizer:
  name: "adamw"
  lr: 0.0002                  # ✅ 比原来 1e-4 略大，配合 Cosine + warmup

  # Common parameters
  weight_decay: 0.00005

  param_groups:
    - name: self_attn
      patterns: []
      lr: 0.0002
      layer_decay:
        enabled: False
        decay_rate: 0.8

    - name: ffn
      patterns: []
      lr: 0.0002
      layer_decay:
        enabled: False
        decay_rate: 0.8

  layer_decay:
    enabled: False
    decay_rate: 0.75

  sgd:
    momentum: 0.9
    nesterov: True
    dampening: 0

  adam:
    betas: [0.9, 0.999]
    eps: 1.0e-8
    amsgrad: False

  adamw:
    betas: [0.9, 0.999]
    eps: 1.0e-8
    amsgrad: False

###### Scheduler -----------------------------------------------------------------
scheduler:
  name: "cosine"  # 继续用 cosine

  reduce_lr:
    mode: "min"
    factor: 0.1
    patience: 10
    threshold: 0.0001
    threshold_mode: "rel"
    cooldown: 0
    min_lr: 0.0000001
    eps: 1e-8

  cosine:
    T_max: 80                # ✅ 调整到与训练轮数同量级，而不是 15
    eta_min: 0.0000001
    last_epoch: -1

  cosine_warm:
    T_0: 10
    T_mult: 10
    eta_min: 0.0000001
    last_epoch: -1

  warmup:
    enabled: True
    type: "linear"           # ✅ 对 transformer 更自然
    epochs: 8                # 比原来略长一点
    start_factor: 0.1
    end_factor: 1.0


Overwriting config.yaml


In [None]:
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)

## Tokenizer

In [None]:
Tokenizer = H4Tokenizer(
    token_map  = config['tokenization']['token_map'],
    token_type = config['tokenization']['token_type']
)

                          Tokenizer Configuration (5k)                          
--------------------------------------------------------------------------------
Vocabulary size:     5000

Special Tokens:
PAD:              0
UNK:              1
MASK:             2
SOS:              3
EOS:              4
BLANK:            5

Validation Example:
--------------------------------------------------------------------------------
Input text:  [SOS]HI DEEP LEARNERS[EOS]
Tokens:      ['[SOS]', 'H', 'I', 'ĠDEEP', 'ĠLEARN', 'ERS', '[EOS]']
Token IDs:   [3, 14, 15, 1169, 2545, 214, 4]
Decoded:     [SOS]HI DEEP LEARNERS[EOS]


## Datasets

In [None]:

train_dataset = ASRDataset(
    partition=config['data']['train_partition'],
    config=config['data'],
    tokenizer=Tokenizer,
    isTrainPartition=True,
    global_stats=None  # Will compute stats from training data
)

# TODO: Get the computed global stats from training set
global_stats = None
if config['data']['norm'] == 'global_mvn':
    global_stats = (train_dataset.global_mean, train_dataset.global_std)
    print(f"Global stats computed from training set.")

val_dataset = ASRDataset(
    partition=config['data']['val_partition'],
    config=config['data'],
    tokenizer=Tokenizer,
    isTrainPartition=False,
    global_stats=global_stats
)

test_dataset = ASRDataset(
    partition=config['data']['test_partition'],
    config=config['data'],
    tokenizer=Tokenizer,
    isTrainPartition=False,
    global_stats=global_stats
)

gc.collect()

Loading data for train-clean-100 partition...


100%|██████████| 28539/28539 [00:34<00:00, 823.56it/s]


Global stats computed from training set.
Loading data for dev-clean partition...


100%|██████████| 2703/2703 [00:01<00:00, 1507.97it/s]


Loading data for test-clean partition...


100%|██████████| 2620/2620 [00:00<00:00, 2882.59it/s]


2253588

## Dataloaders

In [None]:
train_loader    = DataLoader(
    dataset     = train_dataset,
    batch_size  = config['data']['batch_size'],
    shuffle     = True,
    num_workers = config['data']['NUM_WORKERS'] if device == 'cuda' else 0,
    pin_memory  = True,
    collate_fn  = train_dataset.collate_fn
)

val_loader      = DataLoader(
    dataset     = val_dataset,
    batch_size  = config['data']['batch_size'],
    shuffle     = False,
    num_workers = config['data']['NUM_WORKERS'] if device == 'cuda' else 0,
    pin_memory  = True,
    collate_fn  = val_dataset.collate_fn
)

test_loader     = DataLoader(
    dataset     = test_dataset,
    batch_size  = config['data']['batch_size'],
    shuffle     = False,
    num_workers = config['data']['NUM_WORKERS'] if device == 'cuda' else 0,
    pin_memory  = True,
    collate_fn  = test_dataset.collate_fn
)

torch.cuda.empty_cache()
gc.collect()

0

### Dataloader Verification

In [None]:
verify_dataloader(train_loader)

             Dataloader Verification              
Dataloader Partition     : train-clean-100
--------------------------------------------------
Number of Batches        : 892
Batch Size               : 32
--------------------------------------------------
Checking shapes of the data...                    

Feature Shape            : [32, 2059, 80]
Shifted Transcript Shape : [32, 73]
Golden Transcript Shape  : [32, 73]
Feature Lengths Shape    : [32]
Transcript Lengths Shape : [32]
--------------------------------------------------
Max Feature Length       : 3066
Max Transcript Length    : 100
Avg. Chars per Token     : 4.24


In [None]:
verify_dataloader(val_loader)

             Dataloader Verification              
Dataloader Partition     : dev-clean
--------------------------------------------------
Number of Batches        : 85
Batch Size               : 32
--------------------------------------------------
Checking shapes of the data...                    

Feature Shape            : [32, 3676, 80]
Shifted Transcript Shape : [32, 104]
Golden Transcript Shape  : [32, 104]
Feature Lengths Shape    : [32]
Transcript Lengths Shape : [32]
--------------------------------------------------
Max Feature Length       : 4081
Max Transcript Length    : 138
Avg. Chars per Token     : 4.17


In [None]:
verify_dataloader(test_loader)

             Dataloader Verification              
Dataloader Partition     : test-clean
--------------------------------------------------
Number of Batches        : 82
Batch Size               : 32
--------------------------------------------------
Checking shapes of the data...                    

Feature Shape            : [32, 2099, 80]
Feature Lengths Shape    : [32]
--------------------------------------------------
Max Feature Length       : 4370
Max Transcript Length    : 0
Avg. Chars per Token     : 0.00


## Calculate Max Lengths
Calculating the maximum transcript length across your dataset is a crucial step when working with certain transformer models.
-  We'll use sinusoidal positional encodings that must be precomputed up to a fixed maximum length.
- This maximum length is a hyperparameter that determines:
  - How long of a sequence your model can process
  - The size of your positional encoding matrix
  - Memory requirements during training and inference
- `Requirements`: For this assignment, ensure your positional encodings can accommodate at least the longest sequence in your dataset to prevent truncation. However, you can set this value higher if you anticipate using your languagemodel to work with longer sequences in future tasks (hint: this might be useful for P2! 😉).
- `NOTE`: We'll be using the same positional encoding matrix for all sequences in your dataset. Take this into account when setting your maximum length.

In [None]:
max_feat_len       = max(train_dataset.feat_max_len, val_dataset.feat_max_len, test_dataset.feat_max_len)
max_transcript_len = max(train_dataset.text_max_len, val_dataset.text_max_len, test_dataset.text_max_len)
max_len            = max(max_feat_len, max_transcript_len)

print("="*50)
print(f"{'Max Feature Length':<30} : {max_feat_len}")
print(f"{'Max Transcript Length':<30} : {max_transcript_len}")
print(f"{'Overall Max Length':<30} : {max_len}")
print("="*50)

Max Feature Length             : 4370
Max Transcript Length          : 138
Overall Max Length             : 4370


## Wandb

In [None]:
wandb.login(key="5c102bb14269774e676d3197d391926898a84841")

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

## Training

Every time you run the trainer, it will create a new directory in the `expts` folder with the following structure:
```
expts/
    └── {run_name}/
        ├── config.yaml
        ├── model_arch.txt
        ├── checkpoints/
        │   ├── checkpoint-best-metric-model.pth
        │   └── checkpoint-last-epoch-model.pth
        ├── attn/
        │   └── {attention visualizations}
        └── text/
            └── {generated text outputs}
```


In [None]:
torch.cuda.empty_cache()
gc.collect()

163

### Training Strategy 1: Cold-Start Trainer

#### Model Load (Default)

In [None]:
model_config = config['model'].copy()
model_config.update({
    'max_len': max_len,
    'num_classes': Tokenizer.vocab_size
})

model = EncoderDecoderTransformer(**model_config)

# Get some inputs from the train dataloader
for batch in train_loader:
    padded_feats, padded_shifted, padded_golden, feat_lengths, transcript_lengths = batch
    break

total_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
assert total_param < 30_000_000, f"Total trainable parameters ({total_param}) exceeds 30 million."

model_stats = summary(model, input_data=[padded_feats, padded_shifted, feat_lengths, transcript_lengths])
print(model_stats)

AssertionError: Total trainable parameters (56860432) exceeds 30 million.

#### Initialize Trainer

If you need to reload the model from a checkpoint, you can do so by calling the `load_checkpoint` method.

```python
checkpoint_path = "path/to/checkpoint.pth"
trainer.load_checkpoint(checkpoint_path)
```


In [None]:
trainer = ASRTrainer(
    model=model,
    tokenizer=Tokenizer,
    config=config,
    run_name="HW4Part2Run01",
    config_file="config.yaml",
    device=device
)

### Setup Optimizer and Scheduler

You can set your own optimizer and scheduler by setting the class members in the `LMTrainer` class.
Eg:
```python
trainer.optimizer = optim.AdamW(model.parameters(), lr=config['optimizer']['lr'], weight_decay=config['optimizer']['weight_decay'])
trainer.scheduler = optim.lr_scheduler.CosineAnnealingLR(trainer.optimizer, T_max=config['training']['epochs'])
```

We also provide a utility function to create your own optimizer and scheduler with the congig and some extra bells and whistles. You are free to use it or not. Do read their code and documentation to understand how it works (`hw4lib/utils/*`).


#### Setting up the optimizer

In [None]:
trainer.optimizer = create_optimizer(
    model=model,
    opt_config=config['optimizer']
)

#### Creating a test scheduler and plotting the learning rate schedule

In [None]:
test_scheduler = create_scheduler(
    optimizer=trainer.optimizer,
    scheduler_config=config['scheduler'],
    train_loader=train_loader,
    gradient_accumulation_steps=config['training']['gradient_accumulation_steps']
)

plot_lr_schedule(
    scheduler=test_scheduler,
    num_epochs=60,
    train_loader=train_loader,
    gradient_accumulation_steps=config['training']['gradient_accumulation_steps']
)

#### Setting up the scheduler

In [None]:
trainer.scheduler = create_scheduler(
    optimizer=trainer.optimizer,
    scheduler_config=config['scheduler'],
    train_loader=train_loader,
    gradient_accumulation_steps=config['training']['gradient_accumulation_steps']
)

#### Train
- Set your epochs and start training!
- `NOTE`: A `scheduler` gets initialized in this call based on the config.

In [None]:
trainer.train(train_loader, val_loader, epochs=100)

#### Inference



In [None]:
# Define the recognition config: Greedy search
recognition_config = {
    'num_batches': None,
    'temperature': 1.0,
    'repeat_penalty': 1.0,
    'lm_weight': None,
    'lm_model': None,
    'beam_width': 1, # Beam width of 1 reverts to greedy
}

# Recognize with the shallow fusion config
config_name = "test"
print(f"Evaluating with {config_name} config")
results = trainer.recognize(test_loader, recognition_config, config_name=config_name, max_length=max_transcript_len)


# Calculate metrics on full batch
generated = [r['generated'] for r in results]
results_df = pd.DataFrame(
    {
        'id': range(len(generated)),
        'transcription': generated
    }
)

# Cleanup (Will end wandb run)
trainer.cleanup()

## Submit to Kaggle

### Authenticate Kaggle
In order to use the Kaggle’s public API, you must first authenticate using an API token. Go to the 'Account' tab of your user profile and select 'Create New Token'. This will trigger the download of kaggle.json, a file containing your API credentials.
- `TODO`: Set your kaggle username and api key here based on the API credentials listed in the kaggle.json




In [None]:
import os
os.environ["KAGGLE_USERNAME"] = "bowu1224"
os.environ["KAGGLE_KEY"] = "8503ee53b6b0907c223a1d085a7bb28e"

In [None]:
results_df.head()

Unnamed: 0,id,transcription
0,0,HE HOPED THERE WOULD BE STOO FOR DINNER TURNIP...
1,1,STUFFET INTO YOU HIS BELLY COUNSELLED HIM
2,2,AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L...
3,3,HOW BERTY ANY GOOD IN YOUR MIND
4,4,NAME BERTAN FRESH NELLIE IS WAITING ON YOU COU...


### Submit

In [None]:
results_df.to_csv("results.csv", index=False)
!kaggle competitions submit -c 11-785-hw-4-p-2-automatic-speech-recognition-f-25 -f results.csv -m "My Submission"

100% 288k/288k [00:02<00:00, 132kB/s]
Successfully submitted to 11-785 HW4P2: Automatic Speech Recognition -F25

#### TODO: Generate a model_metadata.json file to save your model's data (due 48 hours after Kaggle submission deadline OR the day of slack submission)

In [None]:
import json, os, sys, torch, datetime
################################
# TODO: Keep the model_metadata.json
# file safe for submission ater.
################################
def is_colab():
    return "google.colab" in sys.modules and "COLAB_GPU" in os.environ

def is_kaggle():
    return "KAGGLE_KERNEL_RUN_TYPE" in os.environ or "KAGGLE_URL_BASE" in os.environ

def generate_model_submission_file(model):
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    json_filename = f"model_metadata_{timestamp}.json"

    # Create JSON with parameter count, model architecture, and predictions
    output_json = {
        "parameter_count": sum(p.numel() for p in model.parameters() if p.requires_grad),
        "model_architecture": str(model),
    }

    # Save metadata JSON
    with open(json_filename, "w") as f:
        json.dump(output_json, f, indent=2)

    # Download / display link depending on environment
    if is_colab():
        from google.colab import files
        print(f"OK: Saved as {json_filename}. Downloading in Colab...")
        files.download(json_filename)

    elif is_kaggle():
        from IPython.display import FileLink, display
        print("#" * 100)
        print(f"OK: Your submission file `{json_filename}` has been generated.")
        print("TODO: Click the link below.")
        print("1. The file will open in a new tab.")
        print("2. Right-click anywhere in the new tab and select 'Save As...'")
        print("3. Save the file to your computer with the `.json` extension.")
        print("You MUST submit this file to Autolab if this is your best submission.")
        print("#" * 100 + "\n")
        display(FileLink(json_filename))

    else:
        print(f"OK: saved model data saved to: '{json_filename}'")
        print("REQUIRED to submit to Autolab if these are the best model weights.")

generate_model_submission_file(model)
#### IMPORTANT: Do NOT change the name of the model_metadata_....json file!!

OK: Saved as model_metadata_2025-11-26_23-54.json. Downloading in Colab...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## TODO: fill in your submission requirements

### Notes:

- You will need to set the root path to your submission files (eg. MODEL_METADATA_JSON, NOTEBOOK_PATH, HW4LIB_PATH). This will depend on your setup. For eg. if you are following our setup instruction:
  - `Colab:`: `"/content/..."`In the left file pane, right-click the desired file or folder and select “Copy path”.
  - `PSC`: `"/jet/home/<your_username>/..."` You can check the files in this path by running: ```!ls /jet/home/<your_username>/```

Kindly modify your configurations to suit your ablations and be keen to include your name.

In [None]:
####################################
#             README
####################################

# TODO: Please complete all components of this README
README = """
- **Model**: Model archtiecture description. Anything unique? Any specific architecture shapes or strategies?
- **Training Strategy**: optimizer + scheduler + loss function + any other unique ideas
- **Augmentations**: augmentations if used. If augmentations weren't used, then ignore
- **Notebook Execution**: Any instructions required to run your notebook.
"""

####################################
#       Credentials (Optional)
####################################

# These are not required **IF** you have run the cells to declare these variables above.
# If you would like to paste your credentials here again, feel free to:
# OPTIONAL: Fill these out if you do not want to re-run previous cells to re-initialize these credential variables

KAGGLE_USERNAME = "bowu1224" #TODO
KAGGLE_API_KEY = "8503ee53b6b0907c223a1d085a7bb28e" #TODO
WANDB_API_KEY = "5c102bb14269774e676d3197d391926898a84841" #TODO


####################################
#             Wandb Logs
####################################

# TODO: Your wandb project url should look like https://wandb.ai/username-or-team-name/project-name
#(Take these parameters and put them in the variables below)

WANDB_USERNAME_OR_TEAMNAME = "wubo0004-nangyang-technological-university" # TODO: Put your username-or-team-name here
WANDB_PROJECT = "HW4part2" # TODO: Put your project-name

####################################
#         Notebook & Files
####################################

# TODO: Download HW4P2 Notebook (if on colab or kaggle) and upload both your HW4P2 notebook + model_metadata_*.json to your file system.
# TODO: For each file, obtain the file paths and put them below.

# TODO: COLAB INSTRUCTIONS:
# * With Colab, upload your desired file (notebook or model_metadata.json) to "Files"
# * Right-click the file, click "Copy Path,"
# * Paste the path below.

# TODO: KAGGLE INSTRUCTIONS:
# * First download a copy of your notebook with "File > Download Notebook"
# Then...
# * Click "File" in the top left of the screen
# * Go to "Upload Input > Upload Model"
# * Upload your notebook file.
# * For "Model Name" put HW4P2_Final_Submission
# * For "Framework" put "Other"
# * For "License" put "Other"
# * Click "Upload another file" and upload your model_metadata####.json file as well.
# * Now, on your right in your "Models" section, you should see a new folder with your submission files.
# * Click on the "Copy File Path" buttons for the notebook and json file and paste them below.

# TODO: Linux system:
# * Simply upload or find the path of your notebook file and model_metadata###.json file, and paste them here.

NOTEBOOK_PATH = "/content/gdrive/MyDrive/Colab Notebooks/IDL-HW4/HW4P2_Student_Starter_Notebook.ipynb" # TODO: Put your HW4P2 notebook path here
MODEL_METADATA_JSON = "/content/gdrive/MyDrive/Colab Notebooks/IDL-HW4/model_metadata_2025-11-26_23-54.json" # TODO: Put your Model Metadata path json file here (see end of HW4P2 Code Notebook to get this file)
HW4LIB_PATH = "/content/gdrive/MyDrive/Colab Notebooks/IDL-HW4/hw4lib" # TODO: Put your hw4lib path here

####################################
#         Additional Files
####################################

ADDITIONAL_FILES = [ # TODO: Upload any files and add any paths to any additional files you would like to include in your submission, otherwise, leave this empty
]

####################################
#         SLACK SUBMISSION
####################################

ENABLE_SLACK_SUBMISSION = False # TODO: Set this to true if you are submitting to the Slack competition

####################################
#     Creating the Submission
####################################

# TODO: Once the README, wandb information, and file paths are filled in, run this cell,
# run the "Assignment Backend Functions" in the next cells, and generate the final zip file at the end.

SAFE_SUBMISSION = True # TODO: Set this to False if you want to generate a submission.zip even if you are missing files, otherwise it's recommended to keep this as True


# Assignment Backend Submission Functions (DO NOT MODIFY, just run these cells)

In [None]:
from datetime import datetime

######################################
#       Assignment Configs
######################################

WANDB_METRIC = "CER"
WANDB_DIRECTION = "descending"
WANDB_TOP_N = 10
WANDB_OUTPUT_PKL = "wandb_top_runs.pkl"

# Kaggle configuration
COMPETITION_NAME = "11-785-hw-4-p-2-automatic-speech-recognition-f-25"
SLACK_COMPETITION_NAME = "slack-hw-4-p-2-f-25"
FINAL_SUBMISSION_DATETIME = datetime.strptime("2025-12-06 00:00:00", "%Y-%m-%d %H:%M:%S")
SLACK_SUBMISSION_DATETIME = datetime.strptime("2025-12-11 00:00:00", "%Y-%m-%d %H:%M:%S")
GRADING_DIRECTION = "descending"
KAGGLE_OUTPUT_JSON = "kaggle_data.json"

SUBMISSION_OUTPUT = "HW4P2_final_submission.zip"

In [None]:
from datetime import datetime, timezone
import zoneinfo

eastern = zoneinfo.ZoneInfo("America/New_York")
FINAL_DEADLINE_UTC = (
    FINAL_SUBMISSION_DATETIME
    .replace(tzinfo=eastern)
    .astimezone(timezone.utc)
)

SLACK_DEADLINE_UTC = (
    SLACK_SUBMISSION_DATETIME
    .replace(tzinfo=eastern)
    .astimezone(timezone.utc)
)

ACKNOWLEDGEMENT_MESSAGE = """
Submission of this file and assignment indicate the student's agreement to the following Aknowledgement requirements:
Setting the ACNKOWLEDGED flag to True indicates full understanding and acceptance of the following:
1. Slack days may ONLY be used on P2 FINAL (not checkpoint) submission. I.e. you may use slack days to submit final P2 kaggle scores (such as this one) later on the **SLACK KAGGLE COMPETITION** at the expense of your Slack days.
2. The final autolab **code submission is due 48 hours after** the conclusion of the Kaggle Deadline (or, the same day as your final kaggle submission).
3. Course staff will require your kaggle username here, and then will pull your official PRIVATE kaggle leaderboard score. This submission may result in slight variance in scores/code, but we will check for acceptable discrepancies. Any discrepancies related to modifying the submission code (at the bottom of the notebook) will result in an AIV.
4. You are NOT allowed to use any code that will pre-load models (such as those from Hugging Face, etc.).
   You MAY use models described by papers or articles, but you MUST implement them yourself through fundamental PyTorch operations (i.e. Linear, Conv2d, etc.).
5. You are NOT allowed to use any external data/datasets at ANY point of this assignment.
6. You may work with teammates to run ablations/experiments, BUT you must submit your OWN code and your OWN results.
7. Failure to comply with the prior rules will be considered an Academic Integrity Violation (AIV).
8. Late submissions MUST be submitted through the Slack Kaggle (see writeup for details). Any submissions made to the regular Kaggle after the original deadline will NOT be considered, no matter how many slack days remain for the student.
"""
def save_acknowledgment_file():
    if ACKNOWLEDGED:
        with open("acknowledgement.txt", "w") as f:
            f.write(ACKNOWLEDGEMENT_MESSAGE.strip())
        print("Saved acknowledgement.txt")
        return True
    else:
        print("ERROR: Must set ACKNOWLEDGED = True.")
        return False
# Saves README
def save_readme(readme):
    try:
        with open("README.txt", "w") as f:
            f.write(readme.strip())

        print("Saved README.txt")
    except Exception as e:
        print(f"ERROR: Error occured while saving README.txt: {e}")
        return False

    return True

# Saves wandb logs
import wandb, json, pickle

def save_top_wandb_runs():
    wandb.login(key=WANDB_API_KEY)
    if not ACKNOWLEDGED:
        print("ERROR: Must set ACKNOWLEDGED = True.")
        return False

    api = wandb.Api()
    runs = api.runs(
        f"{WANDB_USERNAME_OR_TEAMNAME}/{WANDB_PROJECT}",
        order=f"{'-' if WANDB_DIRECTION == 'descending' else ''}summary_metrics.{WANDB_METRIC}"
    )
    selected_runs = runs[:min(WANDB_TOP_N, len(runs))]

    if not selected_runs:
        print(f"ERROR: No runs found for {WANDB_USERNAME_OR_TEAMNAME}/{WANDB_PROJECT}. Please check that your wandb credentials (Wandb Username/Team Name, API Key, and Project Name) are correct.")
        return False

    all_data = []
    for run in selected_runs:
        run_data = {
            "id": run.id,
            "name": run.name,
            "tags": run.tags,
            "state": run.state,
            "created_at": str(run.created_at),
            "config": run.config,
            "summary": dict(run.summary),
        }
        try:
            run_data["history"] = run.history(samples=1000)
        except Exception as e:
            run_data["history"] = f"Failed to fetch history: {str(e)}"
        all_data.append(run_data)
    with open(WANDB_OUTPUT_PKL, "wb") as f:
        pickle.dump(all_data, f)

    print(f"OK: Exported {len(all_data)} WandB runs to {WANDB_OUTPUT_PKL}")

    return True
# Saves kaggle information

# Install dependencies silently (only if running on Colab)
import sys

from datetime import datetime
import os, json, requests
def kaggle_login(username, key):
    os.makedirs(os.path.expanduser("~/.kaggle"), exist_ok=True)
    with open(os.path.expanduser("~/.kaggle/kaggle.json"), "w") as f:
        json.dump({"username": username, "key": key}, f)
    os.chmod(os.path.expanduser("~/.kaggle/kaggle.json"), 0o600)


def get_active_submission_config():
    if ENABLE_SLACK_SUBMISSION:
        return SLACK_COMPETITION_NAME, SLACK_DEADLINE_UTC
    return COMPETITION_NAME, FINAL_DEADLINE_UTC

def kaggle_user_exists(usernagbme):
    try:
        return requests.get(f"https://www.kaggle.com/{KAGGLE_USERNAME}").status_code == 200
    except Exception as e:
        print(f"ERROR: Error occured while checking Kaggle user: {e}")
        return False

DEFAULT_SCORE=0
if GRADING_DIRECTION == "ascending":
    DEFAULT_SCORE=0
else:
    DEFAULT_SCORE=1.0

def get_best_kaggle_score(subs):
    def extract_score(s): return float(s.private_score or s.public_score or DEFAULT_SCORE)
    if not subs:
        return None, None
    best = max(subs, key=lambda s: extract_score(s) if GRADING_DIRECTION == "ascending" else -extract_score(s))

    score_type = "private" if best.private_score not in [None, ""] else "public"
    return extract_score(best), score_type

def save_kaggle_json(kaggle_username, kaggle_key):

    kaggle_login(kaggle_username, kaggle_key)

    from kaggle.api.kaggle_api_extended import KaggleApi

    if not ACKNOWLEDGED:
        print("ERROR: Must set ACKNOWLEDGED = True.")
        return False

    if not kaggle_user_exists(KAGGLE_USERNAME):
        print(f"ERROR: User '{KAGGLE_USERNAME}' not found.")
        return False

    comp_name, deadline = get_active_submission_config()

    api = KaggleApi()
    api.authenticate()

    # Get competition submissions
    submissions = [s for s in api.competition_submissions(comp_name) if getattr(s, "_submitted_by", None) == KAGGLE_USERNAME]
    if not submissions:
        print(f"ERROR: No valid submissions found for user [{KAGGLE_USERNAME}] for this competition [{comp_name}]. Slack flag set to [{ENABLE_SLACK_SUBMISSION}]")
        print("Please double check your Kaggle username and ensure you've submitted at least once.")
        return False

    score, score_type = get_best_kaggle_score(submissions)
    result = {
        "kaggle_username": KAGGLE_USERNAME,
        "acknowledgement": ACKNOWLEDGED,
        "submitted_slack": ENABLE_SLACK_SUBMISSION,
        "competition_name": comp_name,
        "deadline": deadline.strftime("%Y-%m-%d %H:%M:%S"),
        "raw_score": score * 100.0,
        "score_type": score_type,
    }

    print(f"OK: Projected score (excluding bonuses) saved as {KAGGLE_OUTPUT_JSON}")
    if score:
        print(f"Best score {score}.")
        with open(KAGGLE_OUTPUT_JSON, "w") as f:
            json.dump(result, f, indent=2)
        return True
    return False

import os
import sys
import zipfile


def create_submission_zip(additional_files, safe_flag):
    if not "ACKNOWLEDGED" in globals() or not ACKNOWLEDGED:
        print("ERROR: Make sure to RUN the Acknowledgement cell (at the top of the notebook). Also, must set ACKNOWLEDGED = True.")
        return

    if (not save_acknowledgment_file()):
        print("ERROR: Make sure to RUN the Acknowledgement cell (at the top of the notebook). Also, must set ACKNOWLEDGED = True.")
        return


    if not "ENABLE_SLACK_SUBMISSION" in globals() or ENABLE_SLACK_SUBMISSION is None:
        print("ERROR: \"ENABLE_SLACK_SUBMISSION\" variable is not defined. \nTODO: Make sure to RUN the cell (A few cells up at the beginning of the submission section). \nMake sure to set the ENABLE_SLACK_SUBMISSION checkbox if you're on colab, or set the parameter correctly set on other platforms \n(if you are submitting through the SLACK submission).")
        return

    if not "README" in globals() or not README:
        print("ERROR: Make sure to RUN the README cell(above your credentials cell).")
        return

    if (not save_readme(README)):
        print("ERROR: Error while saving the README file. Make sure to complete and RUN the README cell(above your credentials cell).")
        return

    if (not save_top_wandb_runs()):
        return

    if not "KAGGLE_USERNAME" in globals() or not "KAGGLE_API_KEY" in globals() or not KAGGLE_USERNAME or not KAGGLE_API_KEY:
        print("ERROR: Make sure to set KAGGLE_USERNAME and KAGGLE_API_KEY for this code submission.")
        return

    if (not save_kaggle_json(KAGGLE_USERNAME, KAGGLE_API_KEY)):
        print(f"ERROR: An error occured while retrieve kaggle information from username [{KAGGLE_USERNAME}] from competition [{get_active_submission_config()[0]}] with slack flag set to [{ENABLE_SLACK_SUBMISSION}]. Please check your kaggle username, key, and submission.")
        return

    files_to_zip = [
        "acknowledgement.txt",
        "README.txt",
        KAGGLE_OUTPUT_JSON,
        WANDB_OUTPUT_PKL,
        MODEL_METADATA_JSON,
        NOTEBOOK_PATH,
        HW4LIB_PATH,
    ] + additional_files

    missing_files = False

    with zipfile.ZipFile(SUBMISSION_OUTPUT, "w") as zipf:
        for file_path in files_to_zip:
            if os.path.exists(file_path):
                arcname = os.path.basename(file_path)  # flatten path
                zipf.write(file_path, arcname=arcname)
                print(f"OK: Added {arcname}")
            else:
                missing_files = True
                print(f"ERROR: Missing file: {file_path}")

    if missing_files:
        if safe_flag:
            raise "ERROR: Missing files with safety flag set to True. Please upload any necessary files, ensure you have the correct paths and rerun all cells."
        else:
            print("WARNING: Missing files with safety flag set to False. Submission may be incomplete.")

    if "google.colab" in sys.modules:
        from google.colab import files
        files.download(SUBMISSION_OUTPUT)

    print("Final submission saved as:", SUBMISSION_OUTPUT)

# File Generation (TODO: Check file generation outputs for any errors)

### For Colab and PSC users:

In [None]:
create_submission_zip(ADDITIONAL_FILES, SAFE_SUBMISSION)

#TODO: If the HW4P2_final_submission.zip file does not
# automatically bring up a donwload pop-up
# Then make sure to entire the files and
#manually download the checkpoint_submission.json file.

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Saved acknowledgement.txt
Saved README.txt
OK: Exported 3 WandB runs to wandb_top_runs.pkl
OK: Projected score (excluding bonuses) saved as kaggle_data.json
Best score 9.5173.
OK: Added acknowledgement.txt
OK: Added README.txt
OK: Added kaggle_data.json
OK: Added wandb_top_runs.pkl
OK: Added model_metadata_2025-11-26_23-54.json
OK: Added HW4P2_Student_Starter_Notebook.ipynb
OK: Added hw4lib


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Final submission saved as: HW4P2_final_submission.zip
