## Finetune

In `analysis.py`, I implemented subject-level adaptation. The workflow is to
finetune on data from condition 1, and then test on data from the other
conditions.

There are several parameters involved:

subject = "S001"
config.runner.weight_min = 0.5
config.runner.weight_max = 0.5

Here, `subject` indicates which subject we want to finetune on.
`weight_min` and `weight_max` control how strongly the minimum-value and
maximum-value targets, respectively, contribute to the loss during training.

Now, could you help me write a new script that extends this analysis to all
subjects? Specifically, for each subject:

- Use the following 5 parameter settings to finetune:
  - `weight_min = [1e-6, 0.1, 0.5, 0.9, 1-1e-6]`
  - `weight_max = [1-1e-6, 0.9, 0.5, 0.1, 1e-6]`
  Each pair `(weight_min, weight_max)` defines one finetuning run, so we will
  obtain 5 models per subject.
- Use each model to predict data from this subject with `condition != 1`.
  This yields 5 sets of `(PredMinBP, PredMaxBP)`, together with the ground
  truth `(TrueMinBP, TrueMaxBP)`. Therefore, for each subject we will have
  6 sets of values, i.e. 12 columns in total.
- Repeat this procedure for every subject. Note that each subject is trained
  independently: for each subject and each parameter setting, the checkpoint
  must be reloaded and finetuning must be performed from scratch.

In addition, although I say “all subjects,” we only need to process subjects
with `split = 1` or `split = 2`. The `profile` file contains a column `split`.
Although `split` is assigned at the sample level, in practice all samples of
the same subject belong to the same split. I only care about subjects whose
split is 1 or 2.

After processing all subjects with split equal to 1 or 2, we will obtain a
profile file of shape `N × (K + 2 × 6)`, where `N` is the total number of
`condition != 1` samples from the processed subjects and `K` is the number of
original columns in `data.profile.copy()` (i.e., subject-level metadata).

If a subject does not have data with `condition = 1`, or does not have any
data with `condition != 1`, then skip this subject entirely. The final profile
file should not include samples from this subject.

Please save the final profile file to the directory
`data/presentation/ResultFintune` (create this directory if it does not
exist).

If possible, please write this script as `script/temp.py`. I will run this
file myself. When running, the script should display three nested `tqdm`
progress bars:
- The outer bar shows progress over subjects.
- The middle bar shows progress over parameter settings for the current
  subject.
- The inner bar shows training progress for the current finetuning run.

In [None]:
from __future__ import annotations

import os
from pathlib import Path
import warnings

import lightning
import numpy as np
import pandas as pd
import torch
import tqdm

import src
import config

# Name of the configuration to use; `main` looks this attribute up with
# `getattr` and calls it with no arguments to build the config object.
config_name = "Finetune"


def _device() -> str:
    """Return the name of the preferred torch device.

    CUDA is preferred when available, then the MPS backend, and the CPU is
    the fallback.
    """
    if torch.cuda.is_available():
        name = "cuda"
    elif torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return name


def _set_seed_and_precision() -> None:
    """Configure matmul precision, seed the RNGs and silence MPS fallback warnings.

    The seed is fixed at 42 and is also applied to dataloader workers
    (``workers=True``).
    """
    seed = 42
    torch.set_float32_matmul_precision("medium")
    lightning.seed_everything(seed, workers=True, verbose=False)
    # Hide only warnings whose message matches the MPS fallback pattern.
    warnings.filterwarnings(action="ignore", message=".*MPS.*fallback.*")


def _load_model(cfg: config.Config, device: str) -> src.Model:
    """Build a model from ``cfg.model``, load the checkpoint, and freeze it.

    Args:
        cfg: configuration; ``cfg.model`` supplies the constructor keyword
            arguments and ``cfg.trainer.ckpt_load_path`` the checkpoint path.
        device: device name the loaded model is moved to.

    Returns:
        The constructed model after ``freeze()`` has been called on it.

    Raises:
        ValueError: if ``cfg.trainer.ckpt_load_path`` is ``None``.
    """
    ckpt_path = cfg.trainer.ckpt_load_path
    if ckpt_path is None:
        raise ValueError("config.trainer.ckpt_load_path must be set")

    model = src.Model(**vars(cfg.model))
    # NOTE(review): `.to(device)` is applied to the loader's return value while
    # `model` is what gets returned; presumably the loader returns the same
    # object it was given — confirm.
    loaded = src.Trainer.ckptLoader_(model, ckpt_path)
    loaded.to(device)
    model.freeze()
    return model


def _loss_shape(x: torch.Tensor, y: torch.Tensor, cfg) -> torch.Tensor:
    """Shift-tolerant smooth-L1 loss between ``x`` and ``y``.

    ``x`` is circularly shifted along its last axis by every offset in
    ``[-K, K]`` (``K = cfg.runner.K``). For each sample the smooth-L1 loss is
    averaged over the last axis for every shift, the smallest value across
    shifts is kept, and the result is averaged over the batch.

    Args:
        x: predictions; the indexing below expects shape ``(B, T)``.
        y: targets with the same shape as ``x``.
        cfg: object exposing ``cfg.runner.K`` (maximum shift, ``>= 0``).

    Returns:
        Scalar loss tensor.
    """
    k = cfg.runner.K
    # Crop K samples from each end (presumably to drop values that wrapped
    # around in `torch.roll`). `K:-K` would be the empty slice `0:0` for
    # K == 0 and turn the loss into NaN, so the end is left open in that case.
    window = slice(k, -k if k > 0 else None)

    shifted = torch.stack(
        [torch.roll(x, shifts=s, dims=-1) for s in range(-k, k + 1)],
        dim=1,
    )[:, :, window]                                   # (B, 2K+1, T-2K)
    target = y.unsqueeze(1).expand((-1, 2 * k + 1, -1))[:, :, window]

    per_element = torch.nn.functional.smooth_l1_loss(
        input=shifted, target=target, reduction="none"
    )
    # mean over time -> best shift per sample -> mean over batch
    return per_element.mean(dim=-1).min(dim=-1).values.mean()


def _loss_min(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Smooth-L1 loss between the last-axis minima of ``x`` and ``y``."""
    pred_lowest = x.min(dim=-1).values
    true_lowest = y.min(dim=-1).values
    return torch.nn.functional.smooth_l1_loss(pred_lowest, true_lowest)


def _loss_max(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Smooth-L1 loss between the last-axis maxima of ``x`` and ``y``."""
    pred_highest = x.max(dim=-1).values
    true_highest = y.max(dim=-1).values
    return torch.nn.functional.smooth_l1_loss(pred_highest, true_highest)


def _weight_tag(w_min: float, w_max: float) -> str:
    """Build the column-name suffix identifying one ``(w_min, w_max)`` pair.

    Each weight is rendered with six decimals, trailing zeros (and a trailing
    dot) are stripped, and the decimal point is replaced by ``p``; for example
    ``(0.5, 0.5)`` becomes ``"wmin0p5_wmax0p5"``.
    """
    def _compact(value: float) -> str:
        fixed = format(value, ".6f")
        if "." in fixed:
            fixed = fixed.rstrip("0").rstrip(".")
        return fixed.replace(".", "p")

    low, high = _compact(w_min), _compact(w_max)
    return f"wmin{low}_wmax{high}"


def _backbone_features(model, loader, device: str):
    """Run ``model.forward`` over ``loader`` without gradients.

    Returns the concatenated forward outputs and the concatenated third
    element of every batch (used by the caller as the training target).
    """
    outputs, targets = [], []
    for batch in loader:
        batch = [b.to(device) for b in batch]
        with torch.no_grad():
            outputs.append(model.forward(batch[0], batch[1], pool_dim=1))
        targets.append(batch[2])
    return torch.cat(outputs, dim=0), torch.cat(targets, dim=0)


def _weighted_loss(pred: torch.Tensor, target: torch.Tensor, cfg) -> torch.Tensor:
    """Combine the shape, min and max losses with the weights in ``cfg.objective``."""
    return (
        cfg.objective.weight_shape * _loss_shape(pred, target, cfg)
        + cfg.objective.weight_min * _loss_min(pred, target)
        + cfg.objective.weight_max * _loss_max(pred, target)
    )


def main() -> None:
    """Finetune per subject and weight pair, then save the aggregated predictions.

    Subjects are taken from the profile rows whose ``split`` is 1 or 2. A
    subject is skipped when it has no ``condition == 1`` rows or no
    ``condition != 1`` rows. For every remaining subject and every weight
    pair the checkpoint is reloaded, the adapter is trained for
    ``cfg.trainer.max_epochs`` full-batch steps, and the subject's test
    loader is predicted. The per-sample min/max of the denormalized truth
    and prediction are stored for the ``condition != 1`` rows, and all
    subjects are written to ``data/presentation/ResultFintune/profile.csv``.
    """
    _set_seed_and_precision()
    device = _device()

    # Look the config class up on the `config` module. The previous
    # `getattr(cfg, ...)` read the local `cfg` before it was assigned and
    # raised UnboundLocalError.
    cfg: config.Config = getattr(config, config_name)()
    data = src.DataModule(**vars(cfg.data))
    data.setup()

    profile = data.profile.copy()
    split_mask = profile["split"].isin([1, 2])
    subjects = sorted(profile.loc[split_mask, "subject"].unique())

    weight_pairs: list[tuple[float, float]] = [
        (1e-6, 1 - 1e-6),
        (0.1, 0.9),
        (0.5, 0.5),
        (0.9, 0.1),
        (1 - 1e-6, 1e-6),
    ]

    out_rows: list[pd.DataFrame] = []
    subjects_bar = tqdm.tqdm(subjects, desc="Subjects", position=0)
    for subject in subjects_bar:
        subjects_bar.set_postfix_str(subject)
        subject_profile = profile.loc[profile["subject"] == subject]
        cond_eq1_mask = subject_profile["condition"] == 1
        cond_ne1_mask = subject_profile["condition"] != 1
        if (not cond_eq1_mask.any()) or (not cond_ne1_mask.any()):
            subjects_bar.write(
                f"Skip {subject}: missing condition 1 or condition != 1 data"
            )
            continue

        base_out = subject_profile.loc[cond_ne1_mask].copy().reset_index(drop=True)
        # NOTE(review): assumes the test loader yields this subject's samples
        # in the same order as the rows of `subject_profile` — confirm.
        cond_ne1_array = cond_ne1_mask.to_numpy()
        true_filled = False

        weight_bar = tqdm.tqdm(
            weight_pairs, desc="Weight sets", position=1, leave=False
        )
        for w_min, w_max in weight_bar:
            weight_bar.set_postfix({"w_min": w_min, "w_max": w_max})
            cfg.objective.weight_min = w_min
            cfg.objective.weight_max = w_max

            # Every run starts again from the stored checkpoint.
            model = _load_model(cfg, device)

            # backbone outputs for finetuning
            x_train, y_train = _backbone_features(
                model, data.train_dataloader(subject=subject), device
            )
            x_valid, y_valid = _backbone_features(
                model, data.val_dataloader(subject=subject), device
            )

            optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.objective.lr)

            epoch_bar = tqdm.trange(
                cfg.trainer.max_epochs, desc="Epochs", position=2, leave=False
            )
            for _ in epoch_bar:
                model.train()
                train_loss = _weighted_loss(
                    model.forwardAdapter(x_train), y_train, cfg
                )
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                model.eval()
                with torch.no_grad():
                    valid_loss = _weighted_loss(
                        model.forwardAdapter(x_valid), y_valid, cfg
                    )
                epoch_bar.set_postfix(
                    {
                        "train": float(train_loss),
                        "valid": float(valid_loss),
                    }
                )

            # prediction on this subject (all conditions), then keep condition != 1
            model.eval()
            true_batches, pred_batches = [], []
            for batch in data.test_dataloader(subject=subject):
                x, channel_idx, y = batch
                x, channel_idx = x.to(device), channel_idx.to(device)
                with torch.no_grad():
                    x_pred = model.forwardRegression(
                        x, channel_idx, adapter=True
                    )
                true_batches.append(y.detach().cpu())       # (B, T)
                pred_batches.append(x_pred.detach().cpu())  # (B, T)

            # Denormalize truth and prediction directly rather than picking
            # them by hard-coded channel index out of a tensor that also holds
            # the inputs, so the result no longer depends on the input's
            # channel count.
            y_true = data.denormalize(
                torch.cat(true_batches, dim=0)
            ).detach().cpu()                                # (N, T)
            y_pred = data.denormalize(
                torch.cat(pred_batches, dim=0)
            ).detach().cpu()                                # (N, T)

            true_min = y_true.min(dim=1).values.numpy()
            true_max = y_true.max(dim=1).values.numpy()
            pred_min = y_pred.min(dim=1).values.numpy()
            pred_max = y_pred.max(dim=1).values.numpy()

            if not true_filled:
                # Ground truth is the same for every weight pair; store it once.
                base_out["TrueMinBP"] = true_min[cond_ne1_array]
                base_out["TrueMaxBP"] = true_max[cond_ne1_array]
                true_filled = True

            tag = _weight_tag(w_min, w_max)
            base_out[f"PredMinBP_{tag}"] = pred_min[cond_ne1_array]
            base_out[f"PredMaxBP_{tag}"] = pred_max[cond_ne1_array]

            if device == "cuda":
                torch.cuda.empty_cache()

        out_rows.append(base_out)

    if not out_rows:
        print("No subjects processed; nothing saved.")
        return

    out_df = pd.concat(out_rows, axis=0).reset_index(drop=True)
    out_dir = Path("data/presentation/ResultFintune")
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "profile.csv"
    out_df.to_csv(out_path, index=False)
    print(f"Saved aggregated profile to {out_path} (rows={len(out_df)})")

# Switch to the repo root (two levels up from this file's path) before running,
# so that relative paths used by `main` resolve from there.
_repo_root = Path(__file__).resolve().parents[1]
os.chdir(_repo_root)
main()

## Analysis

While the code is running, please help me write a script to analyze this `profile`.

I am very interested in the final overall MAE. There are two evaluation strategies.

### Strategy 1: Single parameter set for all samples
In the first strategy, all samples use the same parameter set to compute the final MAE.  
Note that **min and max are treated separately**:  
- one parameter set is chosen to minimize the MAE of **min**,  
- another (possibly different) parameter set is chosen to minimize the MAE of **max**.  

The requirement is that, within each case (min or max), the **same parameter set is applied to all subjects and all samples**.

Please report:
- which parameter set achieves the best MAE for **min**, and what that best MAE is;
- which parameter set achieves the best MAE for **max**, and what that best MAE is.

Then, please visualize:
- the **per-subject MAE**;
- all samples plotted together, with the x-axis being the prediction and the y-axis being the ground truth, colored by subject.

---

### Strategy 2: Best parameter set per subject
In the second strategy, for each subject, the prediction values for **min** and **max** are taken from the parameter set that performs best for that subject.

The remaining analysis is the same as in Strategy 1. Please also:
- report the best overall MAE;
- visualize the per-subject MAE;
- visualize all samples plotted together, with the x-axis being the predicted BP and the y-axis being the ground truth, colored by subject.

Note that in the visualizations, **min and max should be handled and plotted separately**.

---

Please write everything in `script/temp.ipynb` so that I can easily inspect all figures.

In [None]:
from __future__ import annotations

import dataclasses
import os
from pathlib import Path
import warnings

import lightning
import numpy as np
import pandas as pd
import torch
import tqdm

import src

# Name of the configuration to use; `main` looks this attribute up with
# `getattr` and calls it with no arguments to build the config object.
config_name = "Finetune"


def _device() -> str:
    """Return the name of the preferred torch device.

    CUDA is preferred when available, then the MPS backend, and the CPU is
    the fallback.
    """
    if torch.cuda.is_available():
        name = "cuda"
    elif torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return name


def _set_seed_and_precision() -> None:
    """Configure matmul precision, seed the RNGs and silence MPS fallback warnings.

    The seed is fixed at 42 and is also applied to dataloader workers
    (``workers=True``).
    """
    seed = 42
    torch.set_float32_matmul_precision("medium")
    lightning.seed_everything(seed, workers=True, verbose=False)
    # Hide only warnings whose message matches the MPS fallback pattern.
    warnings.filterwarnings(action="ignore", message=".*MPS.*fallback.*")


def _load_model(cfg: config.Config, device: str) -> src.Model:
    """Build a model from ``cfg.model``, load the checkpoint, and freeze it.

    Args:
        cfg: configuration; ``cfg.model`` supplies the constructor keyword
            arguments and ``cfg.trainer.ckpt_load_path`` the checkpoint path.
        device: device name the loaded model is moved to.

    Returns:
        The constructed model after ``freeze()`` has been called on it.

    Raises:
        ValueError: if ``cfg.trainer.ckpt_load_path`` is ``None``.
    """
    ckpt_path = cfg.trainer.ckpt_load_path
    if ckpt_path is None:
        raise ValueError("config.trainer.ckpt_load_path must be set")

    model = src.Model(**vars(cfg.model))
    # NOTE(review): `.to(device)` is applied to the loader's return value while
    # `model` is what gets returned; presumably the loader returns the same
    # object it was given — confirm.
    loaded = src.Trainer.ckptLoader_(model, ckpt_path)
    loaded.to(device)
    model.freeze()
    return model


def _loss_shape(x: torch.Tensor, y: torch.Tensor, cfg) -> torch.Tensor:
    """Shift-tolerant smooth-L1 loss between ``x`` and ``y``.

    ``x`` is circularly shifted along its last axis by every offset in
    ``[-K, K]`` (``K = cfg.runner.K``). For each sample the smooth-L1 loss is
    averaged over the last axis for every shift, the smallest value across
    shifts is kept, and the result is averaged over the batch.

    Args:
        x: predictions; the indexing below expects shape ``(B, T)``.
        y: targets with the same shape as ``x``.
        cfg: object exposing ``cfg.runner.K`` (maximum shift, ``>= 0``).

    Returns:
        Scalar loss tensor.
    """
    k = cfg.runner.K
    # Crop K samples from each end (presumably to drop values that wrapped
    # around in `torch.roll`). `K:-K` would be the empty slice `0:0` for
    # K == 0 and turn the loss into NaN, so the end is left open in that case.
    window = slice(k, -k if k > 0 else None)

    shifted = torch.stack(
        [torch.roll(x, shifts=s, dims=-1) for s in range(-k, k + 1)],
        dim=1,
    )[:, :, window]                                   # (B, 2K+1, T-2K)
    target = y.unsqueeze(1).expand((-1, 2 * k + 1, -1))[:, :, window]

    per_element = torch.nn.functional.smooth_l1_loss(
        input=shifted, target=target, reduction="none"
    )
    # mean over time -> best shift per sample -> mean over batch
    return per_element.mean(dim=-1).min(dim=-1).values.mean()


def _loss_min(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Smooth-L1 loss between the last-axis minima of ``x`` and ``y``."""
    pred_lowest = x.min(dim=-1).values
    true_lowest = y.min(dim=-1).values
    return torch.nn.functional.smooth_l1_loss(pred_lowest, true_lowest)


def _loss_max(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Smooth-L1 loss between the last-axis maxima of ``x`` and ``y``."""
    pred_highest = x.max(dim=-1).values
    true_highest = y.max(dim=-1).values
    return torch.nn.functional.smooth_l1_loss(pred_highest, true_highest)


def _weight_tag(w_min: float, w_max: float) -> str:
    """Build the column-name suffix identifying one ``(w_min, w_max)`` pair.

    Each weight is rendered with six decimals, trailing zeros (and a trailing
    dot) are stripped, and the decimal point is replaced by ``p``; for example
    ``(0.5, 0.5)`` becomes ``"wmin0p5_wmax0p5"``.
    """
    def _compact(value: float) -> str:
        fixed = format(value, ".6f")
        if "." in fixed:
            fixed = fixed.rstrip("0").rstrip(".")
        return fixed.replace(".", "p")

    low, high = _compact(w_min), _compact(w_max)
    return f"wmin{low}_wmax{high}"


def _backbone_features(model, loader, device: str):
    """Run ``model.forward`` over ``loader`` without gradients.

    Returns the concatenated forward outputs and the concatenated third
    element of every batch (used by the caller as the training target).
    """
    outputs, targets = [], []
    for batch in loader:
        batch = [b.to(device) for b in batch]
        with torch.no_grad():
            outputs.append(model.forward(batch[0], batch[1], pool_dim=1))
        targets.append(batch[2])
    return torch.cat(outputs, dim=0), torch.cat(targets, dim=0)


def _weighted_loss(pred: torch.Tensor, target: torch.Tensor, cfg) -> torch.Tensor:
    """Combine the shape, min and max losses with the weights in ``cfg.objective``."""
    return (
        cfg.objective.weight_shape * _loss_shape(pred, target, cfg)
        + cfg.objective.weight_min * _loss_min(pred, target)
        + cfg.objective.weight_max * _loss_max(pred, target)
    )


def main() -> None:
    """Finetune per subject and weight pair, then save the aggregated predictions.

    Subjects are taken from the profile rows whose ``split`` is 1 or 2. A
    subject is skipped when it has no ``condition == 1`` rows or no
    ``condition != 1`` rows. For every remaining subject and every weight
    pair the checkpoint is reloaded, the adapter is trained for
    ``config.trainer.max_epochs`` full-batch steps, and the subject's test
    loader is predicted. The per-sample min/max of the denormalized truth
    and prediction are stored for the ``condition != 1`` rows, and all
    subjects are written to ``data/presentation/ResultFintune/profile.csv``.
    """
    _set_seed_and_precision()
    device = _device()

    config: src.config.Config = getattr(src.config, config_name)()
    data = src.data.DataModule(**vars(config.data))
    data.setup()

    profile = data.profile.copy()
    split_mask = profile["split"].isin([1, 2])
    subjects = sorted(profile.loc[split_mask, "subject"].unique())

    weight_pairs: list[tuple[float, float]] = [
        (1e-6, 1 - 1e-6),
        (0.1, 0.9),
        (0.5, 0.5),
        (0.9, 0.1),
        (1 - 1e-6, 1e-6),
    ]

    out_rows: list[pd.DataFrame] = []
    subjects_bar = tqdm.tqdm(subjects, desc="Subjects", position=0)
    for subject in subjects_bar:
        subjects_bar.set_postfix_str(subject)
        subject_profile = profile.loc[profile["subject"] == subject]
        cond_eq1_mask = subject_profile["condition"] == 1
        cond_ne1_mask = subject_profile["condition"] != 1
        if (not cond_eq1_mask.any()) or (not cond_ne1_mask.any()):
            subjects_bar.write(
                f"Skip {subject}: missing condition 1 or condition != 1 data"
            )
            continue

        base_out = subject_profile.loc[cond_ne1_mask].copy().reset_index(drop=True)
        # NOTE(review): assumes the test loader yields this subject's samples
        # in the same order as the rows of `subject_profile` — confirm.
        cond_ne1_array = cond_ne1_mask.to_numpy()
        true_filled = False

        weight_bar = tqdm.tqdm(
            weight_pairs, desc="Weight sets", position=1, leave=False
        )
        for w_min, w_max in weight_bar:
            weight_bar.set_postfix({"w_min": w_min, "w_max": w_max})
            config.objective.weight_min = w_min
            config.objective.weight_max = w_max

            # Every run starts again from the stored checkpoint.
            model = _load_model(config, device)

            # backbone outputs for finetuning
            x_train, y_train = _backbone_features(
                model, data.train_dataloader(subject=subject), device
            )
            x_valid, y_valid = _backbone_features(
                model, data.val_dataloader(subject=subject), device
            )

            optimizer = torch.optim.AdamW(model.parameters(), lr=config.objective.lr)

            epoch_bar = tqdm.trange(
                config.trainer.max_epochs, desc="Epochs", position=2, leave=False
            )
            for _ in epoch_bar:
                model.train()
                train_loss = _weighted_loss(
                    model.forwardAdapter(x_train), y_train, config
                )
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                model.eval()
                with torch.no_grad():
                    valid_loss = _weighted_loss(
                        model.forwardAdapter(x_valid), y_valid, config
                    )
                epoch_bar.set_postfix(
                    {
                        "train": float(train_loss),
                        "valid": float(valid_loss),
                    }
                )

            # prediction on this subject (all conditions), then keep condition != 1
            model.eval()
            true_batches, pred_batches = [], []
            for batch in data.test_dataloader(subject=subject):
                x, channel_idx, y = batch
                x, channel_idx = x.to(device), channel_idx.to(device)
                with torch.no_grad():
                    x_pred = model.forwardRegressionAdapter(
                        x, channel_idx
                    )
                true_batches.append(y.detach().cpu())       # (B, T)
                pred_batches.append(x_pred.detach().cpu())  # (B, T)

            # Denormalize truth and prediction directly rather than picking
            # them by hard-coded channel index out of a tensor that also holds
            # the inputs, so the result no longer depends on the input's
            # channel count.
            y_true = data.denormalize(
                torch.cat(true_batches, dim=0)
            ).detach().cpu()                                # (N, T)
            y_pred = data.denormalize(
                torch.cat(pred_batches, dim=0)
            ).detach().cpu()                                # (N, T)

            true_min = y_true.min(dim=1).values.numpy()
            true_max = y_true.max(dim=1).values.numpy()
            pred_min = y_pred.min(dim=1).values.numpy()
            pred_max = y_pred.max(dim=1).values.numpy()

            if not true_filled:
                # Ground truth is the same for every weight pair; store it once.
                base_out["TrueMinBP"] = true_min[cond_ne1_array]
                base_out["TrueMaxBP"] = true_max[cond_ne1_array]
                true_filled = True

            tag = _weight_tag(w_min, w_max)
            base_out[f"PredMinBP_{tag}"] = pred_min[cond_ne1_array]
            base_out[f"PredMaxBP_{tag}"] = pred_max[cond_ne1_array]

            if device == "cuda":
                torch.cuda.empty_cache()

        out_rows.append(base_out)

    if not out_rows:
        print("No subjects processed; nothing saved.")
        return

    out_df = pd.concat(out_rows, axis=0).reset_index(drop=True)
    out_dir = Path("data/presentation/ResultFintune")
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "profile.csv"
    out_df.to_csv(out_path, index=False)
    print(f"Saved aggregated profile to {out_path} (rows={len(out_df)})")


# Switch to the repo root (two levels up from this file's path) before running,
# so that relative paths used by `main` resolve from there.
_repo_root = Path(__file__).resolve().parents[1]
os.chdir(_repo_root)
main()


In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import plotly.express as px

# Plotly Express defaults (theme and figure size in pixels) for figures
# created through `px` after this point.
px.defaults.template = "plotly_white"
px.defaults.width = 1250
px.defaults.height = 600

def mae(y_true, y_pred) -> float:
    """Return the mean absolute error, ignoring positions where the error is NaN."""
    truth = np.asarray(y_true)
    guess = np.asarray(y_pred)
    abs_err = np.abs(truth - guess)
    return float(np.nanmean(abs_err))

# Directory holding the aggregated profile; reports and figures are also
# written here by the cells below. Created if it does not exist yet.
output_dir = Path("data/presentation/ResultFintune")
output_dir.mkdir(parents=True, exist_ok=True)

# Aggregated per-sample table with TrueMin/MaxBP and PredMin/MaxBP_<tag> columns.
data_path = output_dir / "profile.csv"
profile = pd.read_csv(data_path)


In [None]:
# Detect available parameter sets (tags).
# Only suffixes that start with "wmin" (the tag format written by the finetuning
# script) count as parameter sets. This keeps derived columns such as
# "PredMinBP_best_subject", which a later cell adds to `profile`, from being
# mistaken for a parameter set when this cell is re-run.
pred_min_cols = [c for c in profile.columns if c.startswith("PredMinBP_wmin")]
pred_max_cols = [c for c in profile.columns if c.startswith("PredMaxBP_wmin")]
tags_min = [c.removeprefix("PredMinBP_") for c in pred_min_cols]
tags_max = [c.removeprefix("PredMaxBP_") for c in pred_max_cols]
assert tags_min, "No PredMinBP_* columns found"
assert tags_max, "No PredMaxBP_* columns found"
tags_min, tags_max

In [None]:
# --- Strategy 1: global best parameter set (separate best for min vs. max) ---
def best_global(true_col: str, prefix: str, tags: list[str]):
    """Score every tag on the whole ``profile`` and pick the lowest-MAE one.

    Args:
        true_col: name of the ground-truth column.
        prefix: column prefix; the prediction column is ``prefix + tag``.
        tags: candidate parameter-set tags.

    Returns:
        ``(best_tag, scores)`` where ``scores`` maps each tag to its MAE.
    """
    truth = profile[true_col]
    scores = {}
    for tag in tags:
        scores[tag] = mae(truth, profile[f"{prefix}{tag}"])
    winner = min(scores, key=lambda t: scores[t])
    return winner, scores

# Pick the single best tag for min and for max independently.
best_min_tag, min_scores = best_global("TrueMinBP", "PredMinBP_", tags_min)
best_min_mae = min_scores[best_min_tag]
best_max_tag, max_scores = best_global("TrueMaxBP", "PredMaxBP_", tags_max)
best_max_mae = max_scores[best_max_tag]

metrics_text = "\n".join([
    f"Global best (min): {best_min_tag} -> MAE = {best_min_mae:.4f}",
    f"Global best (max): {best_max_tag} -> MAE = {best_max_mae:.4f}",
])
metrics_path = output_dir / "01GlobalBestParas.txt"
metrics_path.write_text(f"{metrics_text}\n")
print(f"Wrote global metrics to {metrics_path}")

# Per-sample errors for box plots: one frame per (subject, metric), with the
# min error labelled "diastolic" and the max error labelled "systolic".
global_targets = (
    ("diastolic", "TrueMinBP", f"PredMinBP_{best_min_tag}"),
    ("systolic", "TrueMaxBP", f"PredMaxBP_{best_max_tag}"),
)
err_rows = [
    pd.DataFrame({
        "subject": subj,
        "metric": metric,
        "error": (df_sub[true_col] - df_sub[pred_col]).abs(),
    })
    for subj, df_sub in profile.groupby("subject")
    for metric, true_col, pred_col in global_targets
]
sample_err_global = pd.concat(err_rows, ignore_index=True)

fig_box_global = px.box(
    sample_err_global,
    x="subject",
    y="error",
    color="metric",
    points="suspectedoutliers",
    title="Per-sample error (global best tags)",
    labels={"error": "|Pred-True|", "subject": "Subject"},
)

fig_box_global.update_traces(marker={"size": 3}, whiskerwidth=0)
# Log-scaled y axis; the range is passed through unchanged as [0, 2].
fig_box_global.update_yaxes(type="log", range=[0, 2])
fig_box_global.write_image(output_dir / "01GlobalBestParas.png")


In [None]:
# --- Strategy 2: per-subject best parameter set (separate for min vs. max) ---
def best_tags_per_subject(true_col: str, prefix: str, tags: list[str]):
    """Return ``{subject: tag}`` with the lowest-MAE tag for every subject.

    Args:
        true_col: name of the ground-truth column.
        prefix: column prefix; the prediction column is ``prefix + tag``.
        tags: candidate parameter-set tags.
    """
    def _winner(df_sub):
        # MAE of each tag restricted to this subject's rows.
        scores = {
            tag: mae(df_sub[true_col], df_sub[f"{prefix}{tag}"])
            for tag in tags
        }
        return min(scores, key=scores.get)

    return {
        subj: _winner(df_sub)
        for subj, df_sub in profile.groupby("subject")
    }

best_min_by_subj = best_tags_per_subject("TrueMinBP", "PredMinBP_", tags_min)
best_max_by_subj = best_tags_per_subject("TrueMaxBP", "PredMaxBP_", tags_max)

# Build columns with per-subject best predictions: for every subject, copy the
# prediction column of that subject's winning tag into one combined column.
profile = profile.copy()
for best_col, col_prefix, winners in (
    ("PredMinBP_best_subject", "PredMinBP_", best_min_by_subj),
    ("PredMaxBP_best_subject", "PredMaxBP_", best_max_by_subj),
):
    profile[best_col] = np.nan
    for subj, tag in winners.items():
        mask = profile["subject"] == subj
        profile.loc[mask, best_col] = profile.loc[mask, f"{col_prefix}{tag}"]

overall_min_mae = mae(profile["TrueMinBP"], profile["PredMinBP_best_subject"])
overall_max_mae = mae(profile["TrueMaxBP"], profile["PredMaxBP_best_subject"])
metrics_text = "\n".join([
    f"Per-subject best (min): MAE = {overall_min_mae:.4f}",
    f"Per-subject best (max): MAE = {overall_max_mae:.4f}",
])
metrics_path = output_dir / "02SubjectBestParas.txt"
metrics_path.write_text(f"{metrics_text}\n")
print(f"Wrote subject-level metrics to {metrics_path}")

# Per-sample errors for box plots (subject-specific best)
subject_targets = (
    ("diastolic", "TrueMinBP", "PredMinBP_best_subject"),
    ("systolic", "TrueMaxBP", "PredMaxBP_best_subject"),
)
err_rows = [
    pd.DataFrame({
        "subject": subj,
        "metric": metric,
        "error": (df_sub[true_col] - df_sub[pred_col]).abs(),
    })
    for subj, df_sub in profile.groupby("subject")
    for metric, true_col, pred_col in subject_targets
]
sample_err_subject = pd.concat(err_rows, ignore_index=True)

fig_box_subj = px.box(
    sample_err_subject,
    x="subject",
    y="error",
    color="metric",
    points="suspectedoutliers",
    title="Per-sample error (subject-specific best tags)",
    labels={"error": "|Pred-True|", "subject": "Subject"},
)

fig_box_subj.update_traces(marker={"size": 3}, whiskerwidth=0)
# Log-scaled y axis; the range is passed through unchanged as [0, 2].
fig_box_subj.update_yaxes(type="log", range=[0, 2])
fig_box_subj.write_image(output_dir / "02SubjectBestParas.png")

# Quick look at which tag won per subject (optional table)
pd.DataFrame({
    "subject": list(best_min_by_subj),
    "best_min_tag": list(best_min_by_subj.values()),
    "best_max_tag": [best_max_by_subj[s] for s in best_min_by_subj],
}).sort_values("subject")

In [None]:
# Optional: remove extreme errors (possible bad ground truth) and recompute.
# The per-subject best prediction columns are used as the reference predictions.
pred_min_col = "PredMinBP_best_subject"
pred_max_col = "PredMaxBP_best_subject"

# Copy of `profile` with the absolute min/max errors appended as two columns.
df_err = profile.assign(
    err_diastolic=(profile["TrueMinBP"] - profile[pred_min_col]).abs(),
    err_systolic=(profile["TrueMaxBP"] - profile[pred_max_col]).abs(),
)

def filter_outliers_iqr(df: pd.DataFrame, cols: list[str], k: float = 1.5, max_pct: float = 0.99):
    """Drop rows whose value exceeds an upper cap in any of ``cols``.

    For each column the cap is ``min(Q3 + k * IQR, quantile(max_pct))``,
    computed on the non-NaN values. A column with no non-NaN values gets a
    NaN threshold and does not filter anything.

    Returns:
        ``(filtered_copy, thresholds, keep_mask)`` where ``thresholds`` maps
        each column to its cap and ``keep_mask`` is a boolean Series aligned
        with ``df.index``.
    """
    keep = pd.Series(True, index=df.index)
    limits: dict[str, float] = {}
    for name in cols:
        observed = df[name].dropna()
        if observed.empty:
            limits[name] = np.nan
            continue
        first_q = observed.quantile(0.25)
        third_q = observed.quantile(0.75)
        iqr_cap = third_q + k * (third_q - first_q)
        pct_cap = observed.quantile(max_pct)
        cap = float(min(iqr_cap, pct_cap))
        limits[name] = cap
        # NaN entries compare False here, so they are dropped as well.
        keep = keep & (df[name] <= cap)
    return df[keep].copy(), limits, keep

filtered, thresholds, keep_mask = filter_outliers_iqr(
    df_err,
    ["err_diastolic", "err_systolic"],
    k=1.5,
    max_pct=0.99,
)

# MAE recomputed on the rows that survived the outlier filter.
filtered_min_mae = mae(filtered["TrueMinBP"], filtered[pred_min_col])
filtered_max_mae = mae(filtered["TrueMaxBP"], filtered[pred_max_col])

n_total = len(df_err)
n_kept = len(filtered)
summary_lines = [
    f"Thresholds (abs error): {thresholds}",
    f"Kept {n_kept} of {n_total} samples ({keep_mask.mean()*100:.1f}%); removed {n_total - n_kept}",
    f"Filtered overall MAE (diastolic, systolic): ({filtered_min_mae:.4f}, {filtered_max_mae:.4f})",
]

metrics_path = output_dir / "03SubjectBestParasThresholds.txt"
metrics_path.write_text("\n".join(summary_lines) + "\n")
print(f"Wrote filtered metrics to {metrics_path}")

# Box plot on filtered data (box + suspected outliers on same line).
# Reshape to long form: one row per (sample, metric) with a shared "error" column.
err_long = pd.concat([
    filtered[["subject", err_col]]
    .rename(columns={err_col: "error"})
    .assign(metric=metric)
    for metric, err_col in (
        ("diastolic", "err_diastolic"),
        ("systolic", "err_systolic"),
    )
])
fig_box_filtered = px.box(
    err_long,
    x="subject",
    y="error",
    color="metric",
    points="suspectedoutliers",
    title="Per-sample error after outlier removal (subject-specific best)",
    labels={"error": "|Pred-True|", "subject": "Subject"},
)
fig_box_filtered.update_traces(marker={"size": 3}, whiskerwidth=0)
# Log-scaled y axis; the range is passed through unchanged as [0, 2].
fig_box_filtered.update_yaxes(type="log", range=[0, 2])
fig_box_filtered.write_image(output_dir / "03SubjectBestParasThresholds.png")