In [1]:
""" import """

import sys, os
# Work from the project root (parent of the notebook directory) so that
# relative data/ckpt paths resolve and `import src` works.
# The root is resolved only on the first run and remembered in `project_root`:
# re-running this cell must not climb one more directory each time.
if "project_root" not in globals():
    project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
os.chdir(project_root)
if project_root not in sys.path: sys.path.append(project_root)

import torch
import lightning
import numpy as np
import pandas as pd
import scipy.io
import scipy.stats

import tqdm
import warnings
import dataclasses
import plotly.express
import plotly.subplots
import plotly.graph_objects
import matplotlib.pyplot as plt
import ipywidgets
from IPython.display import display, clear_output

import src

# pick the fastest available backend: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available(): device = "cuda"
elif torch.backends.mps.is_available(): device = "mps"
else: device = "cpu"

# Silence PyTorch's MPS CPU-fallback UserWarning, e.g.
#   "The operator 'aten::col2im' is not currently supported on the MPS backend
#    and will fall back to run on the CPU. This may have performance implications."
# The message says "fall back" (two words), so the pattern must not require
# the single word "fallback". `message` is a regex matched case-insensitively
# against the start of the warning text, hence the leading ".*".
MPS_FALLBACK_WARNING_RE = r".*MPS backend.*fall ?back.*"
warnings.filterwarnings("ignore", message=MPS_FALLBACK_WARNING_RE)

In [2]:
""" config """
# config.data (data paths, batch size) and config.model (SCOST kwargs) are
# consumed by the prediction cell
config = src.config.ConfigD04()
profile_load_path = "data/wave2wave/sample.csv"
# checkpoint and prediction outputs are namespaced by the config class name
ckpt_load_path = "ckpt/{}/last.ckpt".format(type(config).__name__)
# for prediction save & load
result_fold = "data/{}/".format(type(config).__name__)
result_path, sample_path = (
    os.path.join(result_fold, fname) for fname in ("result.pt", "sample.csv")
)

In [3]:
""" prediction """
# ---- load data ----
x = torch.load(config.data.x_load_path, weights_only=True)
y = torch.load(config.data.y_load_path, weights_only=True)
# normalize y with one global mean / std taken over rows without NaN; the same
# scalars map the predictions back to the original scale further below
valid = ~torch.isnan(y).any(dim=1)
std = y[valid].std()
mean = y[valid].mean()
y = (y - mean) / std
# append y as an additional (last) channel to x
x = torch.cat([x, y.unsqueeze(1)], dim=1)
# ---- sample profile ----
sample = pd.read_csv(profile_load_path)
sample = sample.drop(columns=["split02"])
sample = sample.rename(columns={"split01": "split"})
sample["split"] = sample["split"].map({0: "train", 1: "test", 2: "test"})
sample["system"] = sample["system"].map({False: "old", True: "new"})
# profile rows are matched to tensor rows purely by position below
assert len(sample) == x.shape[0], (
    f"profile has {len(sample)} rows but data has {x.shape[0]} samples"
)
# ---- keep only samples without NaN in any channel / time step ----
valid = torch.where(~torch.isnan(x).any(dim=(1, 2)))[0]
x = x[valid]
y = y[valid]
sample = sample.iloc[valid.numpy()].reset_index(drop=True)
# ---- data ----
dataloader = torch.utils.data.DataLoader(
    src.data.Dataset(x, y), batch_size=config.data.batch_size, shuffle=False,
)
# ---- model ----
model = src.model.SCOST(**dataclasses.asdict(config.model))
ckpt = torch.load(
    ckpt_load_path, weights_only=True, map_location=torch.device(device)
)
# network weights are assumed to live under the "model." prefix of the
# checkpoint's state_dict (presumably the LightningModule attribute name).
# Strip only that leading prefix: str.replace would also rewrite any nested
# key containing "model.", and strict=False would then drop it silently.
prefix = "model."
state_dict = {
    k[len(prefix):]: v for k, v in ckpt["state_dict"].items()
    if k.startswith(prefix)
}
incompatible = model.load_state_dict(state_dict, strict=False)
# strict=False is kept (the checkpoint may legitimately not cover every
# module), but mismatches are reported instead of silently running with
# partly uninitialised weights
if incompatible.missing_keys or incompatible.unexpected_keys:
    warnings.warn(
        "checkpoint does not match model: "
        f"missing={incompatible.missing_keys}, "
        f"unexpected={incompatible.unexpected_keys}"
    )
model = model.eval().to(device)
# ---- predict ----
# batch variables get their own names so the dataset tensors x / y are not
# overwritten by the last batch
result_b = []
with torch.no_grad():
    for x_b, channel_idx_b, _y_b in tqdm.tqdm(dataloader):
        x_b, channel_idx_b = x_b.to(device), channel_idx_b.to(device)
        # user_mask=3 presumably masks channel 3 (the appended BP) so that the
        # model must reconstruct it — TODO confirm against src.model.SCOST
        x_pred, _ = model(
            x_b, channel_idx_b, head_type="reconstruction", user_mask=3
        )
        result_b.append(torch.cat([
            x_b.cpu(),                              # (B, 4, T)
            x_pred[:, 3, :].cpu().unsqueeze(1),     # (B, 1, T)
        ], dim=1))
result = torch.cat(result_b, dim=0)                 # (N, 5, T)
# map the normalized BP in channels 3 (true) and 4 (pred) back to the
# original scale and store it as channels 5 (true) and 6 (pred)
result = torch.cat([
    result,
    (result[:, 3, :] * std + mean).unsqueeze(1),
    (result[:, 4, :] * std + mean).unsqueeze(1),
], dim=1).detach().cpu()                            # (N, 7, T)
# ---- store key features in the profile ----
sample["TrueMinBP"] = result[:, 5].min(dim=1).values.numpy()
sample["TrueMaxBP"] = result[:, 5].max(dim=1).values.numpy()
sample["PredMinBP"] = result[:, 6].min(dim=1).values.numpy()
sample["PredMaxBP"] = result[:, 6].max(dim=1).values.numpy()
sample["(P-T)MinBP"] = sample["PredMinBP"] - sample["TrueMinBP"]
sample["(P-T)MaxBP"] = sample["PredMaxBP"] - sample["TrueMaxBP"]
# ---- save result as .pt and sample as .csv in result_fold ----
# print shape and where saved
os.makedirs(result_fold, exist_ok=True)
torch.save(result, result_path)
sample.to_csv(sample_path, index=False)
print(f"sample: pd.DataFrame > {sample_path}\t{sample.shape}")
print(f"result: torch.Tensor > {result_path}\t\t{tuple(result.shape)}")

100%|██████████| 85/85 [00:29<00:00,  2.91it/s]


sample: pd.DataFrame > data/ConfigD04/sample.csv	(21592, 13)
result: torch.Tensor > data/ConfigD04/result.pt		(21592, 7, 1000)


In [4]:
""" load """
# Reload the cached inference outputs written by the prediction cell, so the
# calibration / visualization cells can run without repeating inference.
result, sample = (
    torch.load(result_path, weights_only=True),
    pd.read_csv(sample_path),
)

In [None]:
""" calibration 0: no calibration """
# Baseline: MAE of the raw prediction error on the evaluation rows of each
# split, i.e. every row except the calibration rows (condition == 1).
for s in ("train", "test"):
    held_out = sample.loc[(sample["split"] == s) & (sample["condition"] != 1)]
    mae_min, mae_max = (
        np.nanmean(np.abs(held_out[col])) for col in ("(P-T)MinBP", "(P-T)MaxBP")
    )
    print(s, "\tMAE of (min, max) = ({:5.2f}, {:5.2f})".format(mae_min, mae_max))

train 	MAE of (min, max) = ( 3.74,  4.68)
test 	MAE of (min, max) = (15.32, 22.36)


In [6]:
""" calibration 1: slope and bias per subject and min/max """
# For each subject, regress the prediction error (P - T) on the prediction P
# using only that subject's calibration rows (condition == 1):
#     (P - T) ≈ a * P + b
# and correct every row of the subject as P_adj = P - (a * P + b).
sample["Adj1PredMinBP"]  = np.nan
sample["Adj1PredMaxBP"]  = np.nan
sample["(Adj1P-T)MinBP"] = np.nan
sample["(Adj1P-T)MaxBP"] = np.nan
for subject, group in sample.groupby("subject"):
    # calibration subset: condition == 1
    cond1 = group[group["condition"] == 1]
    # indices of this subject in the original DataFrame
    idx = group.index
    for kind in ("Min", "Max"):
        pred_col, err_col = f"Pred{kind}BP", f"(P-T){kind}BP"
        # fit only on rows where both P and (P - T) are known
        fit_rows = cond1[[pred_col, err_col]].dropna()
        if len(fit_rows) >= 2 and fit_rows[pred_col].nunique() > 1:
            fit = scipy.stats.linregress(fit_rows[pred_col], fit_rows[err_col])
            a, b = fit.slope, fit.intercept
        else:
            # fewer than 2 usable points, or all P identical (linregress
            # raises ValueError then): no correction, error ≈ 0 → P_adj = P
            a, b = 0.0, 0.0
        # apply this subject's correction to all its rows
        pred = sample.loc[idx, pred_col]
        adj_pred = pred - (a * pred + b)
        # store
        sample.loc[idx, f"Adj1Pred{kind}BP"] = adj_pred
        sample.loc[idx, f"(Adj1P-T){kind}BP"] = adj_pred - sample.loc[idx, f"True{kind}BP"]

for s in ["train", "test"]: print(
    s, "\tMAE of (min, max) = ({:5.2f}, {:5.2f})".format(
        np.nanmean(np.abs(sample[(sample["split"] == s) & (sample["condition"] != 1)]["(Adj1P-T)MinBP"])),
        np.nanmean(np.abs(sample[(sample["split"] == s) & (sample["condition"] != 1)]["(Adj1P-T)MaxBP"])),
    )
)

train 	MAE of (min, max) = ( 7.63,  6.76)
test 	MAE of (min, max) = (11.94, 20.47)


In [7]:
""" calibration 2: slope per split and min/max, bias per subject and min/max """
# Model of the prediction error:  (P - T) ≈ a_split * P + b_subject
#   - the slope a is shared by all subjects of a split and is fit on that
#     split's calibration rows (condition == 1);
#   - the bias b is estimated per subject & split from its own calibration rows;
# every row is then corrected as P_adj = P - (a * P + b).

# ---- 0. Prepare columns for second calibration ----
sample["Adj2PredMinBP"]  = np.nan
sample["Adj2PredMaxBP"]  = np.nan
sample["(Adj2P-T)MinBP"] = np.nan
sample["(Adj2P-T)MaxBP"] = np.nan

# ---- 1. Compute global slopes for each split (train / test) ----
global_slopes = {}  # keys: (split, "min") / (split, "max")
for split_name in ["train", "test"]:
    df_split = sample[
        (sample["split"] == split_name) & (sample["condition"] == 1)
    ]
    for key, kind in (("min", "Min"), ("max", "Max")):
        pred_col, err_col = f"Pred{kind}BP", f"(P-T){kind}BP"
        # fit on rows where BOTH columns are known: guarding only the
        # prediction column lets NaN errors through, and the resulting NaN
        # slope would poison every corrected value of the split
        fit_rows = df_split[[pred_col, err_col]].dropna()
        if len(fit_rows) >= 2 and fit_rows[pred_col].nunique() > 1:
            global_slopes[(split_name, key)] = scipy.stats.linregress(
                fit_rows[pred_col], fit_rows[err_col]
            ).slope
        else:
            # no slope info (too few points, or all P identical, where
            # linregress raises); treat as pure bias
            global_slopes[(split_name, key)] = 0.0

# ---- 2. Per-subject bias estimation using fixed global slope ----
for subject, group in sample.groupby("subject"):
    # process each split separately, because the slope depends on the split
    for split_name in ["train", "test"]:
        sub = group[group["split"] == split_name]
        if sub.empty:
            continue  # this subject has no samples in this split
        # calibration subset for bias: condition == 1 within this subject & split
        calib = sub[sub["condition"] == 1]
        idx = sub.index
        for key, kind in (("min", "Min"), ("max", "Max")):
            a = global_slopes[(split_name, key)]
            pred_col, err_col = f"Pred{kind}BP", f"(P-T){kind}BP"
            # (P-T)_i ≈ a * Pred_i + b_subject  =>  b_subject = mean((P-T)_i - a * Pred_i)
            resid = (calib[err_col] - a * calib[pred_col]).dropna().to_numpy()
            b = resid.mean() if resid.size > 0 else 0.0  # no info → fallback to 0
            # apply calibration to all rows of this subject & split
            pred = sample.loc[idx, pred_col]
            adj_pred = pred - (a * pred + b)
            sample.loc[idx, f"Adj2Pred{kind}BP"] = adj_pred
            sample.loc[idx, f"(Adj2P-T){kind}BP"] = adj_pred - sample.loc[idx, f"True{kind}BP"]

# ---- 3. Print MAE for train / test after calibration 2 ----
for s in ["train", "test"]: print(
    s, "\tMAE of (min, max) = ({:5.2f}, {:5.2f})".format(
        np.nanmean(np.abs(sample[(sample["split"] == s) & (sample["condition"] != 1)]["(Adj2P-T)MinBP"])),
        np.nanmean(np.abs(sample[(sample["split"] == s) & (sample["condition"] != 1)]["(Adj2P-T)MaxBP"])),
    )
)

train 	MAE of (min, max) = ( 3.44,  4.45)
test 	MAE of (min, max) = (11.67, 18.94)


In [None]:
""" visualization """

# waveforms from result.pt, (N, 7, T): ch0-2 model inputs, ch3/ch4 normalized
# BP true/pred, ch5/ch6 BP true/pred before normalization (prediction cell)
waveforms = result.detach().cpu().numpy()  # shape: (N, 7, T)

# ---------- global y range for waveform ----------
# nanmin / nanmax so that a single NaN value cannot make an axis range NaN
# upper subplot: ch5–6 (BP before normalization)
wave_min_row1 = float(np.nanmin(waveforms[:, 5:7, :]))
wave_max_row1 = float(np.nanmax(waveforms[:, 5:7, :]))

# lower subplot: ch0–4 (after normalization)
wave_min_row2 = float(np.nanmin(waveforms[:, 0:5, :]))
wave_max_row2 = float(np.nanmax(waveforms[:, 0:5, :]))

# 5% padding; 1e-8 keeps the range non-degenerate for a constant signal
pad1 = 0.05 * (wave_max_row1 - wave_min_row1 + 1e-8)
pad2 = 0.05 * (wave_max_row2 - wave_min_row2 + 1e-8)

wave_range_row1 = [wave_min_row1 - pad1, wave_max_row1 + pad1]
wave_range_row2 = [wave_min_row2 - pad2, wave_max_row2 + pad2]

# ==================== scatter configuration ====================

numeric_cols = sample.select_dtypes(include=["float64"]).columns.tolist()

filter_cols = [
    "subject",
    "health",
    "system",
    "repeat",
    "condition",
    "split",
]

hover_cols = ["subject", "health", "system", "repeat", "condition", "split"]

multi_filter_cols = ["subject", "condition"]                # many values: SelectMultiple
bool_like_cols   = ["health", "system", "repeat", "split"]  # few values: Dropdown

# ======== global constants / mappings =========
COMMON_FONT = 12
BOOL_STR_MAP = {
    False: "false", True: "true",
    "False": "false", "True": "true",
    0: "false", 1: "true"
}
COLOR_SEQ = plotly.express.colors.qualitative.Plotly
SYMBOL_SEQ = ["circle", "square", "diamond", "x", "cross", "triangle-up"]

# time axis shared by all waveform traces
N_T = waveforms.shape[2]
GLOBAL_T = np.arange(N_T)

# ========= controls =========
UNIFIED_WIDTH = "220px"
CONTROL_STYLE = {"description_width": "70px"}
# ---- xy ----
x1_dropdown = ipywidgets.Dropdown(
    options=numeric_cols, value="PredMinBP", description="x1",
    layout=ipywidgets.Layout(width=UNIFIED_WIDTH), style=CONTROL_STYLE,
)
y1_dropdown = ipywidgets.Dropdown(
    options=numeric_cols, value="(P-T)MinBP", description="y1",
    layout=ipywidgets.Layout(width=UNIFIED_WIDTH), style=CONTROL_STYLE,
)
x2_dropdown = ipywidgets.Dropdown(
    options=[None] + numeric_cols, value=None, description="x2",
    layout=ipywidgets.Layout(width=UNIFIED_WIDTH), style=CONTROL_STYLE,
)
y2_dropdown = ipywidgets.Dropdown(
    options=[None] + numeric_cols, value=None, description="y2",
    layout=ipywidgets.Layout(width=UNIFIED_WIDTH), style=CONTROL_STYLE,
)
# ---- color ----
color_dropdown = ipywidgets.Dropdown(
    options=filter_cols, value="condition", description="color",
    layout=ipywidgets.Layout(width=UNIFIED_WIDTH), style=CONTROL_STYLE,
)
# ---- point size and opacity ----
size_input = ipywidgets.BoundedIntText(
    value=4, min=1, max=15, step=1, description="size",
    layout=ipywidgets.Layout(width=UNIFIED_WIDTH), style=CONTROL_STYLE,
)
opacity_input = ipywidgets.BoundedFloatText(
    value=0.7, min=0.1, max=1.0, step=0.05, description="opacity",
    layout=ipywidgets.Layout(width=UNIFIED_WIDTH), style=CONTROL_STYLE
)
# ---- multi-valued filters (subject / condition): SelectMultiple, all selected ----
multi_filter_widgets = {}
for col in multi_filter_cols:
    unique_vals = sorted(sample[col].unique())
    options = [(str(v), v) for v in unique_vals]
    sm = ipywidgets.SelectMultiple(
        options=options,
        value=tuple(v for _, v in options),
        description=col,
        layout=ipywidgets.Layout(width=UNIFIED_WIDTH, height="125px"),
        style=CONTROL_STYLE,
    )
    multi_filter_widgets[col] = sm
# ---- bool-like filters: Dropdown with an extra "all" option ----
bool_filter_radios = {}
for col in bool_like_cols:
    # NaN is dropped: `df[col] == nan` never matches, so a "nan" option could
    # only ever produce an empty plot (and mixed NaN / str values break sorting)
    unique_vals = sorted(sample[col].dropna().unique())
    if col in ["health", "repeat"]:
        options = [("all", "__ALL__")] + [(str(v).lower(), v) for v in unique_vals]
    else:
        options = [("all", "__ALL__")] + [(str(v), v) for v in unique_vals]
    # default the split filter to "test" only when that split exists;
    # a Dropdown rejects a value that is not among its options
    default_value = "test" if (col == "split" and "test" in unique_vals) else "__ALL__"
    radio = ipywidgets.Dropdown(
        options=options,
        value=default_value,
        description=col,
        layout=ipywidgets.Layout(width=UNIFIED_WIDTH),
        style=CONTROL_STYLE,
    )
    bool_filter_radios[col] = radio

# ========= output area =========
plot_output = ipywidgets.Output()

# ========= global flags & state shared with update_plot / the click callback =========
suppress_update = False  # True while widgets are changed programmatically
sample_idx_pos = None    # column index of the sample index "s" inside customdata
wave_trace_indices = {}  # indices of the waveform traces in fig.data

# ==================== scatter + waveform (3 subplots) update function ====================

def update_plot(*args):
    """Re-render the scatter + waveform figure from the current widget values.

    Used as an ``observe`` callback on the controls, so ``*args`` (the change
    event) is ignored. Returns immediately while ``suppress_update`` is set,
    so the programmatic option/value changes made below do not recurse.

    Side effects: rewrites the options/values of the multi-select widgets,
    sets the globals ``sample_idx_pos`` and ``wave_trace_indices``, and
    displays a new FigureWidget inside ``plot_output`` (clearing the old one).
    """
    global suppress_update, sample_idx_pos, wave_trace_indices
    if suppress_update:
        return

    with plot_output:
        clear_output(wait=True)

        # ---------- first filter with the 4 bool-like dropdowns ----------
        df_bool = sample
        for col, w in bool_filter_radios.items():
            val = w.value
            if val == "__ALL__":
                continue
            df_bool = df_bool[df_bool[col] == val]
        if df_bool.empty:
            # nothing left: empty the multi-selects (with updates suppressed so
            # these assignments do not re-enter update_plot) and stop
            suppress_update = True
            try:
                for col, w in multi_filter_widgets.items():
                    w.options = []
                    w.value = ()
            finally:
                suppress_update = False
            print("No samples match the current filters.")
            return

        # ---------- refresh subject / condition options from the bool-filtered data ----------
        suppress_update = True
        try:
            for col, w in multi_filter_widgets.items():
                new_vals = sorted(df_bool[col].unique())
                new_options = [(str(v), v) for v in new_vals]

                # keep the still-available part of the current selection;
                # if none of it survives, select every available value
                old_selected = list(w.value)
                new_selected = tuple(v for v in old_selected if v in new_vals)

                if len(new_selected) == 0:
                    new_selected = tuple(new_vals)

                w.options = new_options
                w.value = new_selected
        finally:
            suppress_update = False

        # ---------- then apply the subject / condition filters on top of df_bool ----------
        sub_df = df_bool
        for col, w in multi_filter_widgets.items():
            selected = list(w.value)
            if len(selected) > 0:
                sub_df = sub_df[sub_df[col].isin(selected)]

        if sub_df.empty:
            print("No samples match the current filters.")
            return

        # remember each sample's row index in the profile (0-based); handle_click
        # uses it to index `waveforms`
        sub_df = sub_df.copy()
        sub_df["sample_idx"] = sub_df.index

        # ---------- draw scatter + placeholder waveform traces ----------

        # pair 1
        x1 = x1_dropdown.value
        y1 = y1_dropdown.value

        # pair 2
        x2 = x2_dropdown.value
        y2 = y2_dropdown.value
        has_x2 = (x2 is not None) and (y2 is not None)

        color_col = color_dropdown.value
        point_size = size_input.value
        point_opacity = opacity_input.value

        # long format: one copy of the filtered rows per (x, y) pair (pair1 / pair2)
        df_list = []

        df1 = sub_df.copy()
        df1["_x"] = df1[x1]
        df1["_y"] = df1[y1]
        df1["_pair"] = "pair1"
        df_list.append(df1)

        if has_x2:
            df2 = sub_df.copy()
            df2["_x"] = df2[x2]
            df2["_y"] = df2[y2]
            df2["_pair"] = "pair2"
            df_list.append(df2)

        df_long = pd.concat(df_list, ignore_index=True)

        # --- color legend mapping (handled uniformly) ---
        if color_col:
            if color_col in ["health", "repeat"]:
                df_long["_color_label"] = (
                    df_long[color_col].map(BOOL_STR_MAP).fillna(df_long[color_col]).astype(str)
                )
            else:
                df_long["_color_label"] = df_long[color_col].astype(str)

            color_vals = sorted(df_long["_color_label"].unique())
        else:
            df_long["_color_label"] = "all"
            color_vals = ["all"]

        color_map = {v: COLOR_SEQ[i % len(COLOR_SEQ)] for i, v in enumerate(color_vals)}

        # --- shape (marker symbol) mapping ---
        pair_vals = sorted(df_long["_pair"].unique())  # ["pair1", "pair2"...]
        symbol_map = {p: SYMBOL_SEQ[i] for i, p in enumerate(pair_vals)}
        if "pair1" not in symbol_map:
            symbol_map["pair1"] = SYMBOL_SEQ[0]
        if "pair2" not in symbol_map:
            symbol_map["pair2"] = SYMBOL_SEQ[1]

        # x1 vs y1 / x2 vs y2 labels
        pair_label_map = {
            "pair1": "x1 vs y1",
            "pair2": "x2 vs y2",
        }

        # ====== assemble hovertemplate (x1, y1, x2, y2, s, metadata) ======
        # NOTE(review): customdata mixes numeric and str columns, so
        # np.column_stack presumably promotes it to a string array; this is
        # why handle_click converts "s" back with int()
        sample_idx_series = df_long["sample_idx"].astype(int)
        sample_idx_arr = sample_idx_series.to_numpy()

        sub_indexed = sub_df.set_index("sample_idx")

        # x1 / y1
        x1_for_row = sub_indexed.loc[sample_idx_arr, x1].to_numpy()
        y1_for_row = sub_indexed.loc[sample_idx_arr, y1].to_numpy()

        custom_columns = []
        lines = []

        custom_columns.append(x1_for_row)
        custom_columns.append(y1_for_row)
        lines.append("x1 = %{customdata[0]}")
        lines.append("y1 = %{customdata[1]}")
        next_idx = 2

        # x2, y2 (if present)
        if has_x2:
            x2_for_row = sub_indexed.loc[sample_idx_arr, x2].to_numpy()
            y2_for_row = sub_indexed.loc[sample_idx_arr, y2].to_numpy()
            custom_columns.append(x2_for_row)
            custom_columns.append(y2_for_row)
            lines.append(f"x2 = %{{customdata[{next_idx}]}}")
            lines.append(f"y2 = %{{customdata[{next_idx+1}]}}")
            next_idx += 2

        # s = sample index
        sample_idx_pos = next_idx
        custom_columns.append(sample_idx_arr)
        lines.append(f"s = %{{customdata[{next_idx}]}}")
        next_idx += 1

        # metadata
        for col in hover_cols:
            if col in ["health", "repeat"]:
                vals = df_long[col].map(BOOL_STR_MAP).fillna(df_long[col]).astype(str).to_numpy()
            else:
                vals = df_long[col].astype(str).to_numpy()
            custom_columns.append(vals)
            lines.append(f"{col} = %{{customdata[{next_idx}]}}")
            next_idx += 1

        hover_template = "<br>".join(lines) + "<extra></extra>"
        customdata = np.column_stack(custom_columns)

        # ========= build 2×2 subplot grid: scatter on the left (spans both rows), waveforms top-right / bottom-right =========
        base_fig = plotly.subplots.make_subplots(
            rows=2, cols=2,
            specs=[
                [{"rowspan": 2}, {"type": "xy"}],
                [None, {"type": "xy"}]
            ],
            column_widths=[0.5, 0.5],
            row_heights=[0.5, 0.5],
            horizontal_spacing=0.035,
            vertical_spacing=0.035,
            shared_xaxes=True,   # top-right and bottom-right share the x axis
        )
        fig = plotly.graph_objects.FigureWidget(base_fig)

        # --- main scatter (data traces; kept out of the legend) ---
        if color_col:
            for c in color_vals:
                for p in pair_vals:
                    mask = (df_long["_color_label"] == c) & (df_long["_pair"] == p)
                    if not mask.any():
                        continue
                    sub = df_long[mask]
                    fig.add_trace(
                        plotly.graph_objects.Scatter(
                            x=sub["_x"],
                            y=sub["_y"],
                            mode="markers",
                            marker=dict(
                                color=color_map[c],
                                symbol=symbol_map.get(p, SYMBOL_SEQ[0]),
                                size=point_size,
                                opacity=point_opacity,
                            ),
                            customdata=customdata[mask.to_numpy()],
                            hovertemplate=hover_template,
                            showlegend=False,
                        ),
                        row=1, col=1,
                    )
        else:
            for p in pair_vals:
                mask = (df_long["_pair"] == p)
                if not mask.any():
                    continue
                sub = df_long[mask]
                fig.add_trace(
                    plotly.graph_objects.Scatter(
                        x=sub["_x"],
                        y=sub["_y"],
                        mode="markers",
                        marker=dict(
                            color="blue",
                            symbol=symbol_map.get(p, SYMBOL_SEQ[0]),
                            size=point_size,
                            opacity=point_opacity,
                        ),
                        customdata=customdata[mask.to_numpy()],
                        hovertemplate=hover_template,
                        showlegend=False,
                    ),
                    row=1, col=1,
                )

        # ========= add waveform line traces first (legend group: waveform) =========
        # they start empty (all-NaN y) and are filled in by handle_click
        t = GLOBAL_T
        color_bp_true = "#1f77b4"
        color_bp_pred = "#ff7f0e"

        # hovertemplate for top waveform subplot
        waveform_hover = (
            "T = %{x}<br>"
            "BP True = %{customdata[0]:.2f}<br>"
            "BP Pred = %{customdata[1]:.2f}<br>"
            "|Pred-True| = %{customdata[2]:.2f}<extra></extra>"
        )

        # placeholder customdata (real values filled in on click)
        empty_wave_custom = np.zeros((N_T, 3), dtype=float)

        # row1: before normalization
        start_idx = len(fig.data)
        idx_bp_before_true = start_idx
        fig.add_trace(
            plotly.graph_objects.Scatter(
                x=t,
                y=np.full(N_T, np.nan),
                mode="lines",
                name="BP True",
                line=dict(color=color_bp_true, width=2),
                legendgroup="waveform",
                legendgrouptitle_text="waveform",
                showlegend=True,
                customdata=empty_wave_custom,
                hovertemplate=waveform_hover,
            ),
            row=1, col=2,
        )
        idx_bp_before_pred = start_idx + 1
        fig.add_trace(
            plotly.graph_objects.Scatter(
                x=t,
                y=np.full(N_T, np.nan),
                mode="lines",
                name="BP Pred",
                line=dict(color=color_bp_pred, width=2),
                legendgroup="waveform",
                showlegend=True,
                customdata=empty_wave_custom,
                hovertemplate=waveform_hover,
            ),
            row=1, col=2,
        )

        # row2: after normalization, 3 + 2 traces (keep the default simple hover)
        idx_ch0 = len(fig.data)
        for ch in range(3):
            fig.add_trace(
                plotly.graph_objects.Scatter(
                    x=t,
                    y=np.full(N_T, np.nan),
                    mode="lines",
                    name=f"ch{ch}",
                    opacity=0.5,
                    line=dict(width=2.0),
                    legendgroup="waveform",
                    showlegend=True,
                ),
                row=2, col=2,
            )
        idx_bp_after_true = idx_ch0 + 3
        fig.add_trace(
            plotly.graph_objects.Scatter(
                x=t,
                y=np.full(N_T, np.nan),
                mode="lines",
                name="BP True",
                line=dict(color=color_bp_true, width=2),
                legendgroup="waveform",
                showlegend=False,
            ),
            row=2, col=2,
        )
        idx_bp_after_pred = idx_ch0 + 4
        fig.add_trace(
            plotly.graph_objects.Scatter(
                x=t,
                y=np.full(N_T, np.nan),
                mode="lines",
                name="BP Pred",
                line=dict(color=color_bp_pred, width=2),
                legendgroup="waveform",
                showlegend=False,
            ),
            row=2, col=2,
        )

        # record waveform trace indices for the click callback
        # (they depend on the exact add_trace order above)
        wave_trace_indices = {
            "bp_before_true": idx_bp_before_true,
            "bp_before_pred": idx_bp_before_pred,
            "ch0": idx_ch0,
            "ch1": idx_ch0 + 1,
            "ch2": idx_ch0 + 2,
            "bp_after_true": idx_bp_after_true,
            "bp_after_pred": idx_bp_after_pred,
        }

        # ========= then the scatter legend: shapes first, then colors (same "scatter" group) =========
        # 1) shape legend (both entries are always shown, even without pair2)
        for i, p in enumerate(["pair1", "pair2"]):
            fig.add_trace(
                plotly.graph_objects.Scatter(
                    x=[None], y=[None],
                    mode="markers",
                    marker=dict(
                        color="black",
                        size=COMMON_FONT,
                        symbol=symbol_map.get(p, SYMBOL_SEQ[i]),
                    ),
                    name=pair_label_map[p],
                    showlegend=True,
                    legendgroup="scatter",
                    legendgrouptitle_text="scatter" if i == 0 else None,
                )
            )

        # 2) color legend
        for c in color_vals:
            fig.add_trace(
                plotly.graph_objects.Scatter(
                    x=[None], y=[None],
                    mode="markers",
                    marker=dict(
                        color=color_map[c],
                        size=COMMON_FONT,
                    ),
                    name=f"{color_col} = {c}",
                    showlegend=True,
                    legendgroup="scatter",
                )
            )

        # ========= axes & layout =========
        fig.update_yaxes(row=1, col=2, range=wave_range_row1)
        fig.update_yaxes(row=2, col=2, range=wave_range_row2)

        fig.update_xaxes(
            showticklabels=False,
            row=1, col=2,
        )
        fig.update_xaxes(row=2, col=2,)

        fig.update_layout(
            height=550,
            width=1250,
            # autosize=True,                     # let plotly use the container width
            font=dict(size=COMMON_FONT),
            margin=dict(l=0, r=0, t=0, b=0),
            legend=dict(orientation="v"),
        )

        # ====== click event: clicking a scatter point loads its waveforms (batch_update, single refresh) ======
        def handle_click(trace, points, state_click):
            """On-click callback: show the clicked sample's waveforms.

            Reads the sample index from column ``sample_idx_pos`` of the
            clicked point's customdata, copies channels 0-6 of
            ``waveforms[s]`` into the placeholder waveform traces, and adds an
            annotation with the absolute error of the waveform min / max.
            Silently ignores clicks without points, out-of-range indices, or
            waveforms with fewer than 7 channels.
            """
            if not points.point_inds:
                return
            idx0 = points.point_inds[0]
            cd_row = trace.customdata[idx0]
            s_idx = int(cd_row[sample_idx_pos])
            if (s_idx < 0) or (s_idx >= waveforms.shape[0]):
                return

            w = waveforms[s_idx]
            if (w.ndim != 2) or (w.shape[0] < 7):
                return

            bp_true = w[5]
            bp_pred = w[6]
            bp_diff = np.abs(bp_true - bp_pred)

            # absolute error of the whole-waveform max / min for this one sample
            TrueMaxBP = float(bp_true.max())
            PredMaxBP = float(bp_pred.max())
            TrueMinBP = float(bp_true.min())
            PredMinBP = float(bp_pred.min())
            mae_max = abs(TrueMaxBP - PredMaxBP)
            mae_min = abs(TrueMinBP - PredMinBP)

            # per-x customdata: [bp_true, bp_pred, diff]
            wave_custom = np.stack([bp_true, bp_pred, bp_diff], axis=1)  # (N_T, 3)

            with fig.batch_update():
                # update before-normalization traces
                fig.data[wave_trace_indices["bp_before_true"]].y = bp_true
                fig.data[wave_trace_indices["bp_before_pred"]].y = bp_pred

                # update hover customdata
                fig.data[wave_trace_indices["bp_before_true"]].customdata = wave_custom
                fig.data[wave_trace_indices["bp_before_pred"]].customdata = wave_custom

                # update after-normalization traces: ch0-2 + BP True/Pred
                fig.data[wave_trace_indices["ch0"]].y = w[0]
                fig.data[wave_trace_indices["ch1"]].y = w[1]
                fig.data[wave_trace_indices["ch2"]].y = w[2]
                fig.data[wave_trace_indices["bp_after_true"]].y = w[3]
                fig.data[wave_trace_indices["bp_after_pred"]].y = w[4]

                # annotation in the upper waveform subplot (x2/y2 domain).
                # NOTE(review): (x=1.0, y=0.0) in domain coordinates is the
                # bottom-right corner, although it was described as top-right —
                # confirm the intended placement
                fig.update_layout(
                    annotations=[
                        dict(
                            xref="x2 domain",
                            yref="y2 domain",
                            x=1.0,
                            y=0.0,
                            text=f"MAE of (min, max) = ({mae_min:.2f}, {mae_max:.2f})",
                            showarrow=False,
                            bgcolor="rgba(255,255,255,0.7)",
                            borderwidth=1,
                            font=dict(size=COMMON_FONT),
                        )
                    ]
                )

        # bind the click handler only to marker traces that carry customdata
        # (the data scatter, not the legend dummy traces or waveform lines)
        for tr in fig.data:
            if (
                getattr(tr, "mode", None) == "markers"
                and getattr(tr, "customdata", None) is not None
            ):
                tr.on_click(handle_click)

        display(fig)

# ========= event binding =========
# Every control re-renders the figure whenever its value changes.
widgets_list = [
    x1_dropdown, y1_dropdown, x2_dropdown, y2_dropdown, color_dropdown,
    *multi_filter_widgets.values(),
    *bool_filter_radios.values(),
    size_input, opacity_input,
]
for w in widgets_list:
    w.observe(update_plot, names="value")

# ========= UI layout =========
# One row of controls (axis pickers | styling | subject | condition |
# bool-like filters) above the plot output area.
left_col = ipywidgets.HBox([
    ipywidgets.VBox([x1_dropdown, y1_dropdown, x2_dropdown, y2_dropdown]),
    ipywidgets.VBox([color_dropdown, size_input, opacity_input]),
    *(multi_filter_widgets[name] for name in ("subject", "condition")),
    ipywidgets.VBox(
        [bool_filter_radios[name] for name in ("health", "system", "repeat", "split")]
    ),
])

ui = ipywidgets.VBox([left_col, plot_output])

display(ui)
update_plot()