In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
import numpy as np

# Path to the per-sample results CSV (passed to pd.read_csv below). Placeholder value;
# despite its name it must point to a CSV file, not a folder.
RESULTS_CSV_PATH = 'Your/result/folder'

# Ground-truth stenosis-grade columns, one per vessel (Chinese headers: right coronary
# artery / left main / left anterior descending / left circumflex, "_检查结果" = exam result).
# Entry i of LABEL_COLUMN_NAMES, PROB_COLUMN_NAMES and CLASS_NAMES_ENGLISH must describe
# the same vessel: the ROC loop pairs them by index.
# NOTE(review): later cells read '右冠状动脉主干_prob' and '*_诊断结果' columns, while this
# cell uses '右冠状动脉_prob' and '*_检查结果'. The ROC loop skips any vessel whose columns
# are missing, so confirm these names match the CSV at RESULTS_CSV_PATH.
LABEL_COLUMN_NAMES = [
    '右冠状动脉_检查结果', 
    '左冠状动脉主干_检查结果', 
    '左前降支_检查结果', 
    '左回旋支_检查结果'
]

# Model-predicted probability columns, same vessel order as LABEL_COLUMN_NAMES.
PROB_COLUMN_NAMES = [
    '右冠状动脉_prob', 
    '左冠状动脉主干_prob', 
    '左前降支_prob', 
    '左回旋支_prob'
]

# English display names used in printed output and the ROC legend, same vessel order.
CLASS_NAMES_ENGLISH = [
    'Right Coronary Artery',
    'Left Main Artery',
    'Left Anterior Descending',
    'Left Circumflex'
]

# Grades treated as the negative class (0): "no obvious stenosis", "mild stenosis",
# "moderate stenosis". Every other label value is treated as positive (1).
NEGATIVE_CLASS_STRINGS_LIST = ["未见明显狭窄","轻度狭窄","中度狭窄"]


# Bootstrap settings for the AUC confidence interval.
N_BOOTSTRAPS = 1000  # number of bootstrap resamples
CI_ALPHA = 0.95  # confidence *level* (coverage), not a significance level: 0.95 -> 95% CI
SEED = 42  # seed for the bootstrap RandomState


def calculate_auc_ci(y_true, y_scores, n_bootstraps=1000, alpha=0.95, rng_seed=42):
    """Percentile-bootstrap confidence interval for the ROC AUC.

    Resamples (label, score) pairs with replacement ``n_bootstraps`` times and
    returns the ``(1 - alpha) / 2`` and ``(1 + alpha) / 2`` percentiles of the
    resampled AUCs. Resamples that contain a single class are skipped because
    AUC is undefined for them.

    Args:
        y_true: Binary labels (pandas Series or array-like), aligned with ``y_scores``.
        y_scores: Predicted scores/probabilities, same length as ``y_true``.
        n_bootstraps: Number of bootstrap resamples to draw.
        alpha: Confidence *level* (coverage), e.g. 0.95 for a 95% interval.
        rng_seed: Seed for the ``numpy.random.RandomState`` used for resampling.

    Returns:
        ``(lower, upper)``. Both are NaN when the input is empty or when no
        resample contained both classes. In those cases the previous version
        raised from ``np.percentile`` on an empty array.
    """
    # Positional arrays work for Series with any index and for plain arrays.
    # For Series this selects exactly what `.iloc[indices]` selected before.
    labels = np.asarray(y_true)
    scores = np.asarray(y_scores)
    n = len(scores)
    if n == 0:
        return float("nan"), float("nan")

    rng = np.random.RandomState(rng_seed)
    bootstrapped_scores = []
    for _ in range(n_bootstraps):
        indices = rng.randint(0, n, n)
        if len(np.unique(labels[indices])) < 2:
            continue
        fpr, tpr = roc_curve(labels[indices], scores[indices])[:2]
        bootstrapped_scores.append(auc(fpr, tpr))

    if not bootstrapped_scores:
        # No resample had both classes, so no interval can be estimated.
        return float("nan"), float("nan")

    lower_bound = np.percentile(bootstrapped_scores, (1 - alpha) / 2 * 100)
    upper_bound = np.percentile(bootstrapped_scores, (1 + alpha) / 2 * 100)
    return lower_bound, upper_bound


df = pd.read_csv(RESULTS_CSV_PATH)

# class_name -> {'fpr', 'tpr', 'auc', 'ci_low', 'ci_high'}; consumed by the plotting code below.
roc_results = {}
ci_pct = f"{CI_ALPHA:.0%}"  # e.g. "95%" for CI_ALPHA = 0.95; keeps text in sync with the setting

print(f"Calculating ROC, AUC, and {ci_pct} CI for each class...")

for label_col, prob_col, class_name in zip(LABEL_COLUMN_NAMES, PROB_COLUMN_NAMES, CLASS_NAMES_ENGLISH):
    missing_cols = [c for c in (label_col, prob_col) if c not in df.columns]
    if missing_cols:
        # Report schema mismatches instead of silently dropping the vessel from the figure.
        print(f"- {class_name}: skipped (missing column(s): {', '.join(missing_cols)})")
        continue

    temp_df = df[[label_col, prob_col]].copy()
    # Non-numeric probabilities become NaN and are dropped together with missing labels.
    temp_df[prob_col] = pd.to_numeric(temp_df[prob_col], errors="coerce")
    temp_df = temp_df.dropna()

    # Grades listed in NEGATIVE_CLASS_STRINGS_LIST -> 0, any other grade -> 1.
    # Stripping tolerates stray whitespace around the grade text.
    y_true = temp_df[label_col].apply(
        lambda x: 0 if str(x).strip() in NEGATIVE_CLASS_STRINGS_LIST else 1
    )
    y_scores = temp_df[prob_col]

    if y_true.nunique() < 2:
        # ROC/AUC is undefined with a single class (or no rows).
        print(f"- {class_name}: skipped (fewer than two classes present)")
        continue

    fpr, tpr, _ = roc_curve(y_true, y_scores)
    roc_auc = auc(fpr, tpr)

    ci_lower, ci_upper = calculate_auc_ci(
        y_true, y_scores, n_bootstraps=N_BOOTSTRAPS, alpha=CI_ALPHA, rng_seed=SEED
    )

    roc_results[class_name] = {
        'fpr': fpr,
        'tpr': tpr,
        'auc': roc_auc,
        'ci_low': ci_lower,
        'ci_high': ci_upper,
    }

    print(f"- {class_name}: AUC={roc_auc:.3f} ({ci_pct} CI: {ci_lower:.3f}-{ci_upper:.3f})")



if not roc_results:
    # Do not call exit() here: in a Jupyter kernel it asks the kernel to shut down,
    # discarding all state, instead of just ending this cell. Skip the figure instead.
    print("\nNo valid data to plot. Skipping ROC figure.")
else:
    ci_pct = f"{CI_ALPHA:.0%}"  # e.g. "95%"; keeps legend/title in sync with CI_ALPHA
    plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
    fig, ax = plt.subplots(1, 1, figsize=(8, 7), dpi=120)

    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

    for idx, (class_name, results) in enumerate(roc_results.items()):
        label_str = (
            f"{class_name}\nAUC = {results['auc']:.3f} "
            f"({ci_pct} CI: {results['ci_low']:.3f}–{results['ci_high']:.3f})"
        )
        ax.plot(
            results['fpr'],
            results['tpr'],
            lw=2.5,
            color=colors[idx % len(colors)],  # cycle if there are more classes than colours
            label=label_str,
        )

    # Chance-level reference diagonal.
    ax.plot([0, 1], [0, 1], 'k--', lw=1.5, alpha=0.6, label='No-Skill (AUC = 0.50)')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate (1 - Specificity)', fontsize=12, fontweight='bold')
    ax.set_ylabel('True Positive Rate (Sensitivity)', fontsize=12, fontweight='bold')
    ax.set_title(f'ROC Curves with {ci_pct} Confidence Intervals', fontsize=14, fontweight='bold', pad=15)

    ax.legend(loc="lower right", fontsize=13, frameon=True, framealpha=0.9, edgecolor='gray')

    ax.grid(True, linestyle='--', alpha=0.4)
    fig.tight_layout()
    fig.savefig('Your/save/path', dpi=900)
    plt.show()

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# Global figure styling for this cell's plots (applied in one call).
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'DejaVu Sans'],
    'axes.unicode_minus': False,
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 11,
    'axes.linewidth': 0.8,
})

# Bar colours for the two risk strata.
RISK_COLORS = {"Low": "#1f77b4", "High": "#d62728"}

# Input CSV and output directory (placeholders); the directory is created if missing.
CSV_PATH = 'Your/CSV_PATH'
OUT_DIR = "./Your/OUT_DIR"
os.makedirs(OUT_DIR, exist_ok=True)

df = pd.read_csv(CSV_PATH)

# Source column names (Chinese headers): discharge date, examination time,
# event flag, and the extra-days column added to event times.
# NOTE(review): KM_DAYS_COL reads as "new_MI_KM (filtered > 28)" — confirm its meaning.
DISCHARGE_COL = "出院日期"
EXAM_TIME_COL = "检查时间"
EVENT_COL = "new_MI"
KM_DAYS_COL = "new_MI_KM（筛选＞28）"

# Per-vessel predicted-probability columns and their short names (same order).
PROB_COLUMN_NAMES = [
    '右冠状动脉主干_prob',
    '左冠状动脉主干_prob',
    '左前降支_prob',
    '左回旋支_prob',
]
VESSEL_NAMES = ['RCA', 'LM', 'LAD', 'LCX']

# "High" risk means prob >= threshold; vessels not listed fall back to DEFAULT_ABS_THR.
ABS_THR_BY_VESSEL = {"LM": 0.01, "LAD": 0.15, "RCA": 0.15, "LCX": 0.15}
DEFAULT_ABS_THR = 0.15

# Horizons (days) at which cumulative rates are reported, plus their axis labels.
TIME_POINTS = [30, 100, 200, 300]
TIME_NAMES = [f"{days}d" for days in TIME_POINTS]

# Follow-up is administratively censored at this many days.
FIXED_FOLLOWUP_DAYS = 700.0


def to_numeric_series(s: pd.Series) -> pd.Series:
    """Coerce ``s`` to numbers; entries that cannot be parsed become NaN."""
    return s.pipe(pd.to_numeric, errors="coerce")

def to_datetime_series(s: pd.Series) -> pd.Series:
    """Parse ``s`` as datetimes; unparseable entries become NaT."""
    return s.pipe(pd.to_datetime, errors="coerce")

def days_between_abs(a: pd.Series, b: pd.Series) -> pd.Series:
    dt_days = (b - a).dt.total_seconds() / 86400.0
    return dt_days.abs()

def build_time_event_from_columns(df_sub: pd.DataFrame, event_col: str, km_days_col: str,
                                  discharge_col: str, exam_col: str):
    """Derive per-row follow-up time (days) and a 0/1 event indicator.

    The baseline time is the absolute gap between the discharge and examination
    dates. Rows flagged as events (``event_col == 1``) that also have a numeric
    ``km_days_col`` get those extra days added to the baseline. Events without
    ``km_days_col`` are demoted to non-events (0) and keep the baseline time.

    Returns:
        ``(time, event)``: a Series of days (NaN where a date is missing) and an
        int Series of 0/1, both indexed like ``df_sub``.
    """
    gap_days = days_between_abs(
        to_datetime_series(df_sub[discharge_col]),
        to_datetime_series(df_sub[exam_col]),
    )
    flagged = to_numeric_series(df_sub[event_col]).fillna(0).astype(int).eq(1)
    extra_days = to_numeric_series(df_sub[km_days_col])
    has_extra = extra_days.notna()

    # An event only counts, and only gets its offset, when the extra-day value is known.
    confirmed = flagged & has_extra

    time = gap_days.copy()
    time.loc[confirmed] = gap_days.loc[confirmed] + extra_days.loc[confirmed]

    event = flagged.where(has_extra, False).astype(int)
    return time, event

def apply_fixed_followup(time_raw: pd.Series, event_raw: pd.Series, x_max: float):
    """Administratively censor follow-up at ``x_max`` days.

    Times are clipped into ``[0, x_max]`` (NaN stays NaN). Rows whose raw time
    exceeded ``x_max`` are censored: their event flag is set to 0.

    Returns:
        ``(time, event)`` as new float and int Series; the inputs are not modified.
    """
    beyond_horizon = time_raw > x_max
    time = time_raw.clip(lower=0, upper=x_max).astype(float)
    event = event_raw.astype(int).mask(beyond_horizon, 0)
    return time, event

def risk_group_by_vessel_abs(df_all: pd.DataFrame, vessel: str, prob_col: str) -> pd.Series:
    """Label each row "High" or "Low" from one vessel's probability.

    A row is "High" when its probability is >= the vessel's threshold from
    ABS_THR_BY_VESSEL (DEFAULT_ABS_THR for unlisted vessels). Missing or
    non-numeric probabilities are labelled "Low".
    """
    cutoff = float(ABS_THR_BY_VESSEL.get(vessel, DEFAULT_ABS_THR))
    probs = to_numeric_series(df_all[prob_col])
    is_high = probs.ge(cutoff).fillna(False)
    return pd.Series(np.where(is_high, "High", "Low"), index=df_all.index)

def risk_group_any_vessel_abs(df_all: pd.DataFrame) -> pd.Series:
    """Label a row "High" if any available vessel probability reaches its threshold.

    Vessels whose probability column is absent from ``df_all`` are ignored.
    Missing or non-numeric probabilities never trigger "High".
    """
    flagged = pd.Series(False, index=df_all.index)
    for prob_col, vessel in zip(PROB_COLUMN_NAMES, VESSEL_NAMES):
        if prob_col in df_all.columns:
            cutoff = float(ABS_THR_BY_VESSEL.get(vessel, DEFAULT_ABS_THR))
            flagged = flagged | to_numeric_series(df_all[prob_col]).ge(cutoff)
    return pd.Series(np.where(flagged.fillna(False), "High", "Low"), index=df_all.index)

def prepare_time_event_df(df0: pd.DataFrame):
    """Build the analysis cohort with censored time-to-event data.

    Drops events with a negative time and rows whose time could not be computed,
    then censors follow-up at FIXED_FOLLOWUP_DAYS.

    Returns:
        ``(sub, time, event)``: the kept rows of ``df0`` (a copy) and aligned
        time/event Series.

    Raises:
        KeyError: if a required column is absent (the first missing one is reported).
    """
    required = (DISCHARGE_COL, EXAM_TIME_COL, EVENT_COL, KM_DAYS_COL)
    missing = [col for col in required if col not in df0.columns]
    if missing:
        raise KeyError(f"缺少必要列: {missing[0]}")

    time_raw, event_raw = build_time_event_from_columns(
        df0, EVENT_COL, KM_DAYS_COL, DISCHARGE_COL, EXAM_TIME_COL
    )

    negative_event = event_raw.eq(1) & time_raw.lt(0)
    keep = ~negative_event & time_raw.notna()

    sub = df0.loc[keep].copy()
    time, event = apply_fixed_followup(
        time_raw.loc[keep].copy(), event_raw.loc[keep].copy(), FIXED_FOLLOWUP_DAYS
    )
    return sub, time, event

def compute_rates_by_timewindow(time: pd.Series, event: pd.Series, risk_group: pd.Series,
                                time_points=None):
    """Crude cumulative event rate (%) per risk group at each horizon.

    For each group ("Low", "High") and horizon ``t``, the rate is
    ``100 * (#rows with event == 1 and time <= t) / (#rows in group)``.
    This is a naive proportion: rows censored before ``t`` still count in the
    denominator. It is not a Kaplan–Meier estimate.

    Args:
        time: Follow-up time per row (days), aligned with ``event``/``risk_group``.
        event: 0/1 event indicator per row.
        risk_group: "Low"/"High" label per row.
        time_points: Horizons to evaluate; defaults to the module-level TIME_POINTS.

    Returns:
        Dict ``{"Low": [rates...], "Low_n": n, "High": [rates...], "High_n": n}``.
        A group with no rows gets 0.0 for every horizon.
    """
    if time_points is None:
        time_points = TIME_POINTS

    is_event = (event == 1)  # loop-invariant, computed once
    out = {}
    for rg in ["Low", "High"]:
        in_group = (risk_group == rg)
        n_total = int(in_group.sum())
        out[rg] = [
            (int((is_event & (time <= t) & in_group).sum()) / n_total * 100.0) if n_total > 0 else 0.0
            for t in time_points
        ]
        out[f"{rg}_n"] = n_total
    return out

def plot_timewindow_bars(ax, rates_dict, title, ylim_max=None, ylabel="Cumulative Death Rate (%)"):
    """Draw grouped Low/High bars of cumulative rates at each TIME_POINTS horizon.

    Args:
        ax: Matplotlib Axes to draw on.
        rates_dict: Output of ``compute_rates_by_timewindow`` (uses keys "Low", "High").
        title: Panel title.
        ylim_max: Upper y-limit. When None it is derived from the data:
            1.35 * max + 0.5, bounded to [2, 100].
        ylabel: Y-axis label. The default keeps the original text.
            NOTE(review): this cell's event column is EVENT_COL = "new_MI",
            so "Death" may be a mislabel; pass e.g. "Cumulative Event Rate (%)"
            if so.
    """
    positions = np.arange(len(TIME_POINTS))
    width = 0.36

    low = np.asarray(rates_dict["Low"], dtype=float)
    high = np.asarray(rates_dict["High"], dtype=float)

    ax.bar(positions - width / 2, low, width=width, color=RISK_COLORS["Low"],
           edgecolor='black', linewidth=0.6, label='Low Risk')
    ax.bar(positions + width / 2, high, width=width, color=RISK_COLORS["High"],
           edgecolor='black', linewidth=0.6, label='High Risk')

    ax.set_xticks(positions)
    ax.set_xticklabels(TIME_NAMES, fontweight='bold')
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontsize=15, fontweight='bold', pad=10)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.grid(axis='y', linestyle='--', alpha=0.3)

    if ylim_max is None:
        values = np.r_[low, high]
        peak = float(np.nanmax(values)) if values.size and np.isfinite(values).any() else 0.0
        ylim_max = max(2, min(100, peak * 1.35 + 0.5))
    ax.set_ylim(0, ylim_max)

# Analysis cohort with censored follow-up, shared by both figures below.
sub_all, time_all, event_all = prepare_time_event_df(df)

# Figure 1: a patient is "High" risk if any vessel's probability reaches its threshold.
risk_any = risk_group_any_vessel_abs(sub_all)
rates_any = compute_rates_by_timewindow(time_all, event_all, risk_any)

fig1, ax1 = plt.subplots(1, 1, figsize=(8.8, 5.8), dpi=300)
plot_timewindow_bars(ax1, rates_any, title="Combined Risk (Any Vessel ≥ Threshold)")
ax1.legend(frameon=False, loc="upper left")
fig1.tight_layout()

out_png1 = os.path.join(OUT_DIR, "results.png")
fig1.savefig(out_png1, dpi=600)
plt.show()
print("Saved:", out_png1)

# Per-vessel rates, computed once and reused for the shared y-limit and for the panels
# (previously computed twice per vessel).
vessel_rates = {}
for prob_col, vessel in zip(PROB_COLUMN_NAMES, VESSEL_NAMES):
    if prob_col in sub_all.columns:
        rg = risk_group_by_vessel_abs(sub_all, vessel, prob_col)
        vessel_rates[vessel] = compute_rates_by_timewindow(time_all, event_all, rg)

global_max = max(
    (float(np.nanmax(np.r_[rd["Low"], rd["High"]])) for rd in vessel_rates.values()),
    default=0.0,
)
# Same padding rule as plot_timewindow_bars, so all panels share one y-scale.
global_ylim = max(2, min(100, global_max * 1.35 + 0.5))

fig2, axes = plt.subplots(2, 2, figsize=(16, 10), dpi=300)
axes = axes.flatten()

legend_drawn = False
for ax, prob_col, vessel in zip(axes, PROB_COLUMN_NAMES, VESSEL_NAMES):
    if vessel not in vessel_rates:
        ax.axis("off")
        ax.set_title(f"{vessel} (missing prob col)")
        continue

    thr = float(ABS_THR_BY_VESSEL.get(vessel, DEFAULT_ABS_THR))
    plot_timewindow_bars(ax, vessel_rates[vessel], title=f"{vessel} (High = prob ≥ {thr:.2f})",
                         ylim_max=global_ylim)

    # One legend for the figure, on the first panel that has bars. It was tied to
    # panel 0 before, so it disappeared when that vessel's column was missing.
    if not legend_drawn:
        ax.legend(frameon=False, loc="upper left")
        legend_drawn = True

fig2.tight_layout()
out_png2 = os.path.join(OUT_DIR, "results2.png")
fig2.savefig(out_png2, dpi=600)
plt.show()
print("Saved:", out_png2)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr
import numpy as np

# Figure styling for the boxplots (applied in one call).
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'DejaVu Sans'],
    'axes.unicode_minus': False,
    'font.size': 11,
})

CSV_PATH = "Your/result.csv"  # placeholder path to the results CSV

# Per-vessel stenosis-grade ("诊断结果" = diagnosis result) and probability columns;
# index i of LABELS, PROBS and VESSELS refers to the same vessel.
LABELS = [
    '右冠状动脉主干_诊断结果',
    '左冠状动脉主干_诊断结果',
    '左前降支_诊断结果',
    '左回旋支_诊断结果',
]

PROBS = [
    '右冠状动脉主干_prob',
    '左冠状动脉主干_prob',
    '左前降支_prob',
    '左回旋支_prob',
]

VESSELS = ['RCA', 'LM', 'LAD', 'LCx']

# Chinese grade label -> (English grade, ordinal level). Total occlusion ("完全堵塞")
# is pooled with severe stenosis ("重度狭窄").
GRADE_MAP = {
    "未见明显狭窄": ("Normal", 0),
    "轻度狭窄": ("Mild", 1),
    "中度狭窄": ("Moderate", 2),
    "重度狭窄": ("Severe", 3),
    "完全堵塞": ("Severe", 3),
}

# X-axis order of the boxplot categories.
ORDER_EN = ["Normal", "Mild", "Moderate", "Severe"]

df = pd.read_csv(CSV_PATH)

fig, axes = plt.subplots(2, 2, figsize=(16, 11))
axes = axes.flatten()

print("\n===== Spearman Trend Test =====\n")

for ax, label_col, prob_col, vessel in zip(axes, LABELS, PROBS, VESSELS):
    temp = df[[label_col, prob_col]].copy()
    # Coerce probabilities so stray text entries are dropped instead of breaking the plot/test.
    temp[prob_col] = pd.to_numeric(temp[prob_col], errors="coerce")
    temp = temp.dropna()
    # Keep only recognised grades; .copy() so the column assignments below do not
    # trigger SettingWithCopyWarning on a filtered view.
    temp = temp.loc[temp[label_col].isin(list(GRADE_MAP))].copy()

    if temp.empty:
        print(f"Skipping {vessel} - No valid labels found.")
        # Hide the unused panel rather than leaving blank axes with default ticks.
        ax.axis("off")
        continue

    temp["Grade_EN"] = temp[label_col].map(lambda x: GRADE_MAP[x][0])
    temp["Grade_Level"] = temp[label_col].map(lambda x: GRADE_MAP[x][1])

    sns.boxplot(
        x="Grade_EN",
        y=prob_col,
        data=temp,
        order=ORDER_EN,
        ax=ax,
        width=0.5,
        linewidth=1.5,
        fliersize=3,
        palette="Blues"
    )

    ax.set_title(vessel, fontsize=15, fontweight='bold')
    ax.set_xlabel("Stenosis Grade")
    ax.set_ylabel("Predicted Probability")
    ax.grid(axis='y', alpha=0.3, linestyle='--')

    data_min = temp[prob_col].min()
    data_max = temp[prob_col].max()
    padding = (data_max - data_min) * 0.1
    if padding == 0:
        # All probabilities identical: equal limits would collapse the axis (and make
        # matplotlib warn), so fall back to a small symmetric margin.
        padding = max(abs(data_max) * 0.1, 0.01)
    ax.set_ylim(data_min - padding, data_max + padding)

    # Spearman's rho is undefined (NaN) when either variable is constant.
    if len(temp) > 2 and temp["Grade_Level"].nunique() > 1 and temp[prob_col].nunique() > 1:
        rho, p_value = spearmanr(temp["Grade_Level"], temp[prob_col])

        p_part = "p < 0.001" if p_value < 0.001 else f"p = {p_value:.3f}"
        trend_text = f"Spearman trend: ρ = {rho:.3f}, {p_part}"

        print(f"{vessel} -> {trend_text}")

        ax.text(
            0.5, -0.20,
            trend_text,
            transform=ax.transAxes,
            ha='center',
            va='top',
            fontsize=11,
            fontweight='medium',
            bbox=dict(facecolor='white', alpha=0.8, edgecolor='lightgray', boxstyle='round,pad=0.3')
        )
    else:
        print(f"{vessel} -> Not enough data for Spearman test")

fig.subplots_adjust(
    left=0.08,
    right=0.96,
    top=0.94,
    bottom=0.15,
    hspace=0.4,
    wspace=0.25
)

plt.show()

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import brier_score_loss
from statsmodels.nonparametric.smoothers_lowess import lowess
from matplotlib.lines import Line2D

# Placeholder path to the results CSV.
CSV_PATH = "Your/result.csv"

# Per-vessel stenosis-grade columns ("诊断结果" = diagnosis result); index i of the
# three lists below refers to the same vessel.
LABEL_COLUMN_NAMES = [
    '右冠状动脉主干_诊断结果',
    '左冠状动脉主干_诊断结果',
    '左前降支_诊断结果',
    '左回旋支_诊断结果'
]

# Per-vessel predicted-probability columns.
PROB_COLUMN_NAMES = [
    '右冠状动脉主干_prob',
    '左冠状动脉主干_prob',
    '左前降支_prob',
    '左回旋支_prob'
]

VESSEL_NAMES = ['RCA', 'LM', 'LAD', 'LCX']

# Grades counted as non-events (0): no obvious / mild / moderate stenosis.
# Any other label value is counted as an event (1).
NON_EVENT_LABELS = ["未见明显狭窄", "轻度狭窄", "中度狭窄"]

GROUP_SIZE = 20  # samples per calibration group (grouped_calibration)
BOOTSTRAP = 500  # bootstrap replicates for the calibration bands

# Marker colours for calibration groups below / at-or-above the vessel threshold.
RISK_COLORS = {
    "Low": "#1f77b4",
    "High": "#d62728"
}

# Per-vessel probability threshold separating "Low" and "High" risk groups.
VESSEL_THRESHOLDS = {
    "RCA": 0.15,
    "LM":  0.01,
    "LAD": 0.15,
    "LCX": 0.15
}

def grouped_calibration(y, p, group_size):
    """Equal-size binning of predictions for a calibration plot.

    Sorts samples by predicted probability and cuts them into consecutive
    groups of ``group_size`` samples (the last group may be smaller).

    Args:
        y: Observed binary outcomes (array-like), aligned with ``p``.
        p: Predicted probabilities (array-like).
        group_size: Samples per group; must be positive.

    Returns:
        DataFrame with columns ``mean_p`` (mean prediction), ``obs`` (observed
        event proportion) and ``n`` (group size), one row per group in ascending
        ``p`` order. Empty input yields an empty frame that still has these
        columns; previously it had no columns, so ``df["mean_p"]`` raised KeyError.

    Raises:
        ValueError: if ``group_size`` is not positive.
    """
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}")

    df_gp = pd.DataFrame({"y": y, "p": p}).sort_values("p")
    chunks = (df_gp.iloc[start:start + group_size] for start in range(0, len(df_gp), group_size))
    rows = [
        {"mean_p": chunk["p"].mean(), "obs": chunk["y"].mean(), "n": len(chunk)}
        for chunk in chunks
    ]
    return pd.DataFrame(rows, columns=["mean_p", "obs", "n"])

def bootstrap_ci_levels(y, p, x_grid, group_size=20, n_boot=500, ci=0.95, seed=42):
    """Pointwise bootstrap band for a LOWESS-smoothed grouped calibration curve.

    Each replicate resamples (y, p) with replacement, bins it with
    ``grouped_calibration``, fits LOWESS (frac=0.6) of observed vs. mean
    predicted, and interpolates the fit onto ``x_grid``. ``np.interp`` holds the
    end values constant outside the fitted range. Replicates with fewer than 3
    groups, or where LOWESS/interpolation fails, are skipped.

    Args:
        y, p: Outcomes and predicted probabilities (array-like, aligned).
        x_grid: Points at which the band is evaluated.
        group_size: Samples per calibration group.
        n_boot: Number of bootstrap replicates.
        ci: Coverage of the band, e.g. 0.95.
        seed: Seed for the ``numpy.random.RandomState`` used to resample.

    Returns:
        ``(lower, upper)`` arrays shaped like ``x_grid``. If no replicate yields a
        usable curve, both are all-NaN so plotting draws nothing. Previously they
        were all-zero, which drew a spurious band hugging y = 0.
    """
    rng = np.random.RandomState(seed)
    y = np.asarray(y)
    p = np.asarray(p)
    x_grid = np.asarray(x_grid, dtype=float)
    n = len(y)

    curves = []
    for _ in range(n_boot):
        idx = rng.choice(n, n, replace=True)
        groups = grouped_calibration(y[idx], p[idx], group_size)
        if len(groups) < 3:
            # Too few points for a meaningful LOWESS fit.
            continue
        try:
            smoothed = lowess(groups["obs"], groups["mean_p"], frac=0.6, return_sorted=True)
            curves.append(np.interp(x_grid, smoothed[:, 0], smoothed[:, 1]))
        except Exception:
            # Best-effort: skip replicates that cannot be fitted rather than abort the band.
            continue

    if not curves:
        return np.full(x_grid.shape, np.nan), np.full(x_grid.shape, np.nan)

    curves = np.array(curves)
    lower = np.percentile(curves, (1 - ci) / 2 * 100, axis=0)
    upper = np.percentile(curves, (1 + ci) / 2 * 100, axis=0)
    return lower, upper

plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
plt.rcParams['font.size'] = 11
plt.rcParams['axes.linewidth'] = 1.0

df = pd.read_csv(CSV_PATH)

fig, axes = plt.subplots(2, 2, figsize=(14, 14))
axes = axes.flatten()

print("========= Calibration Plot Generation =========")
print("Label rule: Severe stenosis / occlusion = 1; (Normal/Mild/Moderate) = 0")
print("------------------------------------------------")

summary_rows = []
# Common evaluation grid for the bootstrap confidence bands.
x_grid_full = np.linspace(0, 1, 100)

for ax, label_col, prob_col, vessel in zip(axes, LABEL_COLUMN_NAMES, PROB_COLUMN_NAMES, VESSEL_NAMES):
    # Keep only rows with both a label and a numeric probability. Previously a
    # missing label was mapped to 1 (counted as severe stenosis) and a missing
    # probability made brier_score_loss raise.
    vessel_df = df[[label_col, prob_col]].copy()
    vessel_df[prob_col] = pd.to_numeric(vessel_df[prob_col], errors="coerce")
    vessel_df = vessel_df.dropna()

    if vessel_df.empty:
        print(f"[{vessel}] No rows with both label and probability - skipped.")
        ax.axis("off")
        continue

    # Non-event grades (normal/mild/moderate) -> 0; any other grade -> 1.
    y_true = (
        vessel_df[label_col]
        .apply(lambda x: 0 if str(x).strip() in NON_EVENT_LABELS else 1)
        .astype(int)
        .to_numpy()
    )
    y_prob = vessel_df[prob_col].to_numpy(dtype=float)

    n_total = len(y_true)
    n_pos = int(y_true.sum())
    n_neg = n_total - n_pos
    pos_rate = n_pos / n_total
    neg_rate = n_neg / n_total

    calib_df = grouped_calibration(y_true, y_prob, GROUP_SIZE)

    cutoff = VESSEL_THRESHOLDS[vessel]
    calib_df["risk"] = np.where(calib_df["mean_p"] >= cutoff, "High", "Low")

    # Same seed for both levels, so the 95% and 99% bands come from identical resamples.
    ci95_low, ci95_high = bootstrap_ci_levels(
        y_true, y_prob, x_grid_full, group_size=GROUP_SIZE, n_boot=BOOTSTRAP, ci=0.95, seed=42
    )
    ci99_low, ci99_high = bootstrap_ci_levels(
        y_true, y_prob, x_grid_full, group_size=GROUP_SIZE, n_boot=BOOTSTRAP, ci=0.99, seed=42
    )

    brier = brier_score_loss(y_true, y_prob)

    print(f"[{vessel}] Threshold: {cutoff:.1%} | N={n_total}, Positive={n_pos} ({pos_rate:.2%}), Negative={n_neg} ({neg_rate:.2%})")

    summary_rows.append({
        "Vessel": vessel,
        "N": n_total,
        "Positive_n": n_pos,
        "Positive_rate": pos_rate,
        "Negative_n": n_neg,
        "Negative_rate": neg_rate,
        "Threshold": cutoff,
        "Brier": brier
    })

    if vessel == 'LM':
        # LM prevalence is very low, so zoom both axes to the populated region.
        display_limit = 0.35
        ax.set_xlim(0, display_limit)
        ax.set_ylim(0, display_limit)

        ax.plot([0, display_limit], [0, display_limit],
                linestyle='--', color='black', linewidth=1, alpha=0.7)

        ax.text(0.95, 0.05, "Axes Zoomed-in\n(0 - 0.35)", transform=ax.transAxes,
                ha='right', va='bottom', fontsize=10, style='italic', color='gray')
    else:
        ax.plot([0, 1], [0, 1], linestyle='--', color='black', linewidth=1, alpha=0.7)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

    ax.fill_between(x_grid_full, ci99_low, ci99_high, color="gray", alpha=0.15, label='_nolegend_')
    ax.fill_between(x_grid_full, ci95_low, ci95_high, color="gray", alpha=0.35, label='_nolegend_')

    for risk in ["Low", "High"]:
        sub = calib_df[calib_df["risk"] == risk]
        if len(sub) > 0:
            ax.scatter(
                sub["mean_p"], sub["obs"],
                s=50,
                color=RISK_COLORS[risk],
                alpha=0.85,
                edgecolor="black",
                linewidth=0.6,
                zorder=10
            )

    ax.axvline(x=cutoff, color='black', linestyle=':', linewidth=1.2, alpha=0.5)

    ax.set_title(vessel, fontsize=16, fontweight='bold')
    ax.set_xlabel("Predicted Probability")
    ax.set_ylabel("Observed Proportion")
    ax.grid(alpha=0.2, linestyle='--')

    stats_text = (
        f"Brier Score = {brier:.3f}\n"
        f"Positive Rate = {pos_rate:.2%}\n"
        f"Negative Rate = {neg_rate:.2%}\n"
        f"Threshold = {cutoff:.1%}"
    )
    ax.text(0.05, 0.92, stats_text,
            transform=ax.transAxes,
            fontsize=11, fontweight='bold',
            va='top',
            bbox=dict(facecolor='white', alpha=0.9, edgecolor='lightgray', boxstyle='round,pad=0.5'))

print("------------------------------------------------")

summary_df = pd.DataFrame(summary_rows)
print("===== Summary Table (copy-ready) =====")
print(summary_df.to_string(index=False))


legend_handles = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=RISK_COLORS['Low'],
           markersize=10, markeredgecolor='k', label='Low Risk (< Threshold)'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=RISK_COLORS['High'],
           markersize=10, markeredgecolor='k', label='High Risk (≥ Threshold)'),
    Line2D([0], [0], color='gray', linewidth=4, alpha=0.35, label='95% CI'),
    Line2D([0], [0], color='gray', linewidth=4, alpha=0.15, label='99% CI'),
]

fig.legend(handles=legend_handles, loc="lower center", ncol=4, frameon=False,
           fontsize=12, bbox_to_anchor=(0.5, 0.03))

plt.figtext(
    0.5, 0.01,
    "Note: Axes for LM are zoomed in (0-0.35) due to the extremely low prevalence.\n"
    "Risk Thresholds: 1.0% for LM; 15% (ESC Guidelines) for others. "
    "Positive/Negative rates are computed with: severe stenosis/occlusion=1; normal/mild/moderate=0.",
    ha="center",
    fontsize=10,
    style='italic',
    color='#444'
)

plt.tight_layout()
plt.subplots_adjust(bottom=0.14)
plt.show()


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from sklearn.metrics import confusion_matrix


# Input CSV and output directory (placeholders); the directory is created if missing.
CSV_PATH = "Your/result.csv"
OUT_DIR = "Your/output_dir"
os.makedirs(OUT_DIR, exist_ok=True)

# Per-vessel stenosis-grade and probability columns; index i refers to the same vessel.
LABEL_COLUMN_NAMES = [
    "右冠状动脉主干_诊断结果",
    "左冠状动脉主干_诊断结果",
    "左前降支_诊断结果",
    "左回旋支_诊断结果",
]
PROB_COLUMN_NAMES = [
    "右冠状动脉主干_prob",
    "左冠状动脉主干_prob",
    "左前降支_prob",
    "左回旋支_prob",
]
VESSEL_NAMES = ["RCA", "LM", "LAD", "LCX"]

# Substrings marking a non-event label (no obvious / mild / moderate stenosis).
NON_EVENT_STRINGS = ["未见明显狭窄", "轻度狭窄", "中度狭窄"]

# Base decision threshold per vessel.
VESSEL_THRESHOLDS = {"RCA": 0.15, "LM": 0.01, "LAD": 0.15, "LCX": 0.15}

# Threshold-probability ranges swept/plotted: LM gets a narrower, lower range.
RANGE_DEFAULT = (0.01, 0.60)
RANGE_LM = (0.001, 0.10)

N_TH = 200  # grid size for the sensitivity/specificity sweep
N_PT = 300  # grid size for the decision curves

# Figure styling (applied in one call).
plt.rcParams.update({
    "font.sans-serif": ["Arial", "DejaVu Sans"],
    "font.size": 11,
    "axes.linewidth": 1.0,
})

# Line colours for the DCA curves.
COLOR_MODEL = "#1f77b4"
COLOR_ALL = "#7f7f7f"
COLOR_ZERO = "black"

# Style and label of the "opportunistic screening" operating-point line.
COLOR_OPPORT = "black"
OPPORT_LABEL = "Opportunistic screening"

# Specificity each vessel's operating point must reach (LM uses a lower target).
SPEC_TARGET_DEFAULT = 0.95
SPEC_TARGET_LM = 0.87

def _safe_to_numeric(s: pd.Series) -> pd.Series:
    return pd.to_numeric(s, errors="coerce")

def _label_to_binary(label_series: pd.Series) -> np.ndarray:
    """Map stenosis-grade labels to 0 (non-event) / 1 (event).

    A label is a non-event when its stripped string form *contains* any entry of
    NON_EVENT_STRINGS as a substring. Everything else, including NaN (treated as
    the empty string), is an event (1).
    """
    def _as_flag(raw) -> int:
        text = "" if pd.isna(raw) else str(raw).strip()
        for marker in NON_EVENT_STRINGS:
            if marker in text:
                return 0
        return 1

    return label_series.apply(_as_flag).to_numpy(dtype=int)

def _confusion(y_true: np.ndarray, y_pred: np.ndarray):
    """Return ``(tn, fp, fn, tp)`` for binary labels.

    Labels are fixed to ``[0, 1]`` so the matrix stays 2x2 even when a class is
    absent from the inputs.
    """
    matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
    (tn, fp), (fn, tp) = matrix
    return tn, fp, fn, tp

def _sens_spec(y_true: np.ndarray, y_prob: np.ndarray, thr: float):
    """Sensitivity and specificity when predicting positive for ``y_prob >= thr``.

    Each metric is NaN when its denominator class (positives / negatives) is empty.
    """
    tn, fp, fn, tp = _confusion(y_true, (y_prob >= thr).astype(int))
    positives = tp + fn
    negatives = tn + fp
    sens = np.nan if positives == 0 else tp / positives
    spec = np.nan if negatives == 0 else tn / negatives
    return sens, spec

def threshold_sweep_sens_spec(vessel: str, y_true: np.ndarray, y_prob: np.ndarray):
    """Sensitivity/specificity over N_TH evenly spaced thresholds.

    The range is RANGE_LM for "LM" and RANGE_DEFAULT otherwise. The table is also
    written to ``OUT_DIR/<vessel>_threshold_sens_spec.csv``.

    Returns:
        DataFrame with columns ``threshold``, ``sensitivity``, ``specificity``.
    """
    lo, hi = RANGE_LM if vessel == "LM" else RANGE_DEFAULT
    thresholds = np.linspace(lo, hi, N_TH)
    pairs = [_sens_spec(y_true, y_prob, float(t)) for t in thresholds]

    table = pd.DataFrame({
        "threshold": thresholds,
        "sensitivity": np.array([sens for sens, _ in pairs], dtype=float),
        "specificity": np.array([spec for _, spec in pairs], dtype=float),
    })
    table.to_csv(os.path.join(OUT_DIR, f"{vessel}_threshold_sens_spec.csv"),
                 index=False, encoding="utf-8-sig")
    return table

def _spec_target_for_vessel(vessel: str) -> float:
    """Specificity target for the operating point: SPEC_TARGET_LM for LM, else SPEC_TARGET_DEFAULT."""
    if vessel == "LM":
        return SPEC_TARGET_LM
    return SPEC_TARGET_DEFAULT

def _find_first_threshold_meeting_spec(dfm: pd.DataFrame, base_thr: float, target_spec: float):
    sub = dfm[dfm["threshold"] >= base_thr].copy()
    sub = sub.dropna(subset=["specificity"])
    hit = sub[sub["specificity"] >= target_spec]
    if len(hit) == 0:
        return np.nan
    return float(hit.iloc[0]["threshold"])

def _select_6_thresholds_from_df(dfm: pd.DataFrame, base_thr: float, t_end: float):
    """Pick 6 rows of a threshold sweep spanning ``[base_thr, t_end]`` for tabulation.

    Six evenly spaced target thresholds are snapped to the first sweep row at or
    above each target (``np.searchsorted``, side="left"), clipped to the last row.
    Duplicate indices are removed. If fewer than 6 remain, the following rows are
    added; once the last row is reached it is repeated, so the result can contain
    duplicate rows but always has exactly 6.

    Args:
        dfm: Sweep table with ``threshold``, ``sensitivity``, ``specificity``.
            Assumes ``threshold`` is sorted ascending and ``dfm`` is non-empty
            (true for ``threshold_sweep_sens_spec`` output) — required by searchsorted.
        base_thr: Lower end of the span.
        t_end: Upper end; replaced by the last sweep threshold when it is not
            finite or is below ``base_thr``.

    Returns:
        DataFrame with columns ``Thr``, ``Sensitivity``, ``Specificity`` (6 rows).
    """
    th = dfm["threshold"].values
    sens = dfm["sensitivity"].values
    spec = dfm["specificity"].values

    # Fall back to the end of the sweep when no usable upper bound was given.
    if not np.isfinite(t_end) or t_end < base_thr:
        t_end = th[-1]

    targets = np.linspace(base_thr, t_end, 6)
    # Snap each target to the first sweep row with threshold >= target.
    idx = np.searchsorted(th, targets, side="left")
    idx = np.clip(idx, 0, len(th) - 1)

    # De-duplicate while preserving order (several targets may snap to one row).
    seen = set()
    idx_unique = []
    for j in idx:
        j = int(j)
        if j not in seen:
            idx_unique.append(j)
            seen.add(j)

    # Pad to 6 with the rows after the last picked index; at the final row the
    # last index is appended repeatedly, which guarantees the loop terminates.
    k = idx_unique[-1] if len(idx_unique) > 0 else 0
    while len(idx_unique) < 6:
        k = min(k + 1, len(th) - 1)
        if k not in seen:
            idx_unique.append(k)
            seen.add(k)
        if k == len(th) - 1 and len(idx_unique) < 6:
            idx_unique.append(len(th) - 1)

    # The padding step can overshoot by one; keep exactly 6.
    idx_unique = idx_unique[:6]

    out = pd.DataFrame({
        "Thr": th[idx_unique],
        "Sensitivity": sens[idx_unique],
        "Specificity": spec[idx_unique],
    })
    return out

def net_benefit_model(y_true: np.ndarray, y_prob: np.ndarray, pts: np.ndarray) -> np.ndarray:
    """Decision-curve net benefit of the model at each threshold probability.

    At threshold ``pt`` a sample is "treated" when ``y_prob >= pt``, and
    ``NB = TP/n - FP/n * pt / (1 - pt)``.

    Args:
        y_true: 0/1 outcome array.
        y_prob: Predicted probabilities, aligned with ``y_true``.
        pts: Threshold probabilities; each should be < 1 (pt == 1 gives an infinite weight).

    Returns:
        Float array shaped like ``pts``. It is all-NaN when ``y_true`` is empty;
        previously that case divided 0 by 0.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob, dtype=float)
    nb = np.zeros_like(pts, dtype=float)
    n = len(y_true)
    if n == 0:
        nb[:] = np.nan
        return nb

    # The class masks do not depend on the threshold, so build them once instead of
    # running a full sklearn confusion_matrix for every one of the N_PT thresholds.
    is_event = y_true == 1
    is_nonevent = y_true == 0
    for i, pt in enumerate(pts):
        treated = y_prob >= pt
        tp = np.count_nonzero(treated & is_event)
        fp = np.count_nonzero(treated & is_nonevent)
        w = pt / (1.0 - pt)  # odds of the threshold: harm weight of a false positive
        nb[i] = (tp / n) - (fp / n) * w
    return nb

def net_benefit_treat_all(y_true: np.ndarray, pts: np.ndarray) -> np.ndarray:
    """Net benefit of the "treat everyone" strategy at each threshold probability.

    ``NB_all = prevalence - (1 - prevalence) * pt / (1 - pt)``, where prevalence
    is the fraction of ``y_true == 1``. Computed element-wise over ``pts``.
    """
    total = len(y_true)
    n_event = np.sum(y_true == 1)
    n_nonevent = total - n_event
    thresholds = np.asarray(pts, dtype=float)
    odds = thresholds / (1.0 - thresholds)
    return (n_event / total) - (n_nonevent / total) * odds

def dca_curve(vessel: str, y_true: np.ndarray, y_prob: np.ndarray):
    """Decision-curve table for one vessel over N_PT threshold probabilities.

    The range is RANGE_LM for "LM" and RANGE_DEFAULT otherwise. The table is also
    written to ``OUT_DIR/<vessel>_dca.csv``.

    Returns:
        DataFrame with columns ``pt``, ``NB_model``, ``NB_all``, ``NB_none`` (zeros).
    """
    lo, hi = RANGE_LM if vessel == "LM" else RANGE_DEFAULT
    pts = np.linspace(lo, hi, N_PT)
    curves = pd.DataFrame({
        "pt": pts,
        "NB_model": net_benefit_model(y_true, y_prob, pts),
        "NB_all": net_benefit_treat_all(y_true, pts),
        "NB_none": np.zeros_like(pts, dtype=float),
    })
    curves.to_csv(os.path.join(OUT_DIR, f"{vessel}_dca.csv"),
                  index=False, encoding="utf-8-sig")
    return curves

def _panel_title(ax, vessel: str):
    ax.set_title(vessel, fontsize=14, fontweight="bold", pad=8)

def plot_threshold_sensspec_2x2(vessel_to_df: dict, vessel_thresholds: dict):
    """Plot sensitivity/specificity versus threshold per vessel and export a 6-point table.

    Each panel shows, for one vessel in VESSEL_NAMES:
      - the sensitivity (red) and specificity (blue) curves;
      - a dotted line at the base threshold;
      - a dashed line at the first threshold >= base whose specificity reaches
        the vessel's target (``_spec_target_for_vessel``), when such a threshold
        exists.
    Six thresholds between the base and that operating point (or the upper end
    of the plotted range) are tabulated.

    Args:
        vessel_to_df: vessel name -> DataFrame with ``threshold``, ``sensitivity``,
            ``specificity`` columns (as produced by ``threshold_sweep_sens_spec``).
        vessel_thresholds: vessel name -> base threshold probability.

    Side effects:
        Writes ``Figure_Threshold_SensSpec_2x2.png`` and
        ``Table_Threshold_SensSpec_6points.csv`` to OUT_DIR, closes the figure and
        prints both paths.
    """
    sens_color, spec_color = "#d62728", "#1f77b4"

    fig, axes = plt.subplots(2, 2, figsize=(10.5, 8.5), dpi=300)
    panels = axes.flatten()
    table_rows = []

    for panel_idx, vessel in enumerate(VESSEL_NAMES):
        ax = panels[panel_idx]
        sweep = vessel_to_df[vessel]
        base_thr = float(vessel_thresholds[vessel])
        x_lo, x_hi = RANGE_LM if vessel == "LM" else RANGE_DEFAULT

        ax.plot(sweep["threshold"], sweep["sensitivity"], lw=2.3, color=sens_color)
        ax.plot(sweep["threshold"], sweep["specificity"], lw=2.3, color=spec_color)
        ax.axvline(base_thr, color="black", lw=1.2, ls=":")

        spec_goal = _spec_target_for_vessel(vessel)
        t_op = _find_first_threshold_meeting_spec(sweep, base_thr=base_thr, target_spec=spec_goal)
        has_op = bool(np.isfinite(t_op))
        op_value = float(t_op) if has_op else np.nan

        # Tabulate six thresholds from the base up to the operating point (or the range end).
        picks = _select_6_thresholds_from_df(sweep, base_thr=base_thr, t_end=t_op if has_op else x_hi)
        for pick in picks.itertuples(index=False):
            table_rows.append({
                "Vessel": vessel,
                "Base_thr": base_thr,
                "Spec_target": spec_goal,
                "Thr": float(pick.Thr),
                "Sensitivity": float(pick.Sensitivity),
                "Specificity": float(pick.Specificity),
                "Opportunistic_thr": op_value,
            })

        if has_op:
            ax.axvline(t_op, color=COLOR_OPPORT, lw=2.0, ls="--", alpha=0.95)

        ax.set_xlim(x_lo, x_hi)
        ax.set_ylim(0, 1.0)

        _panel_title(ax, vessel)
        ax.set_xlabel("Threshold probability")
        ax.set_ylabel("Metric value")
        ax.grid(False)

        ax.text(
            0.02, 0.06,
            f"Base = {base_thr:.2f}",
            transform=ax.transAxes,
            fontsize=10,
            color="black",
            bbox=dict(facecolor="white", edgecolor="none", alpha=0.75, pad=2.5),
        )

    legend_handles = [
        Line2D([0], [0], color=sens_color, lw=2.3, label="Sensitivity"),
        Line2D([0], [0], color=spec_color, lw=2.3, label="Specificity"),
        Line2D([0], [0], color="black", lw=1.2, ls=":", label="Base threshold"),
        Line2D([0], [0], color=COLOR_OPPORT, lw=2.0, ls="--", label=OPPORT_LABEL),
    ]
    fig.legend(handles=legend_handles, loc="lower center", ncol=2, frameon=False,
               bbox_to_anchor=(0.5, 0.01), fontsize=11)

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.14)

    out_png = os.path.join(OUT_DIR, "Figure_Threshold_SensSpec_2x2.png")
    fig.savefig(out_png, dpi=600, bbox_inches="tight")
    plt.close(fig)
    print("Saved:", out_png)

    out_csv = os.path.join(OUT_DIR, "Table_Threshold_SensSpec_6points.csv")
    pd.DataFrame(table_rows).to_csv(out_csv, index=False, encoding="utf-8-sig")
    print("Saved:", out_csv)

def plot_dca_2x2(vessel_to_dca: dict, vessel_thresholds: dict):
    """Draw a 2x2 grid of decision curves (model vs. treat-all vs. treat-none).

    Args:
        vessel_to_dca: vessel name -> DataFrame from ``dca_curve``
            (uses columns ``pt``, ``NB_model``, ``NB_all``).
        vessel_thresholds: vessel name -> base threshold probability. It is drawn
            as a vertical reference line, matching the threshold sens/spec figure.
            Previously the value was read but never used.

    Side effects:
        Writes ``Figure_DCA_2x2.png`` to OUT_DIR, closes the figure and prints the path.
    """
    fig, axes = plt.subplots(2, 2, figsize=(10.5, 8.5), dpi=300)
    axes = axes.flatten()

    for i, vessel in enumerate(VESSEL_NAMES):
        ax = axes[i]
        dfd = vessel_to_dca[vessel]
        base_thr = float(vessel_thresholds[vessel])

        ax.plot(dfd["pt"], dfd["NB_model"], lw=2.5, color=COLOR_MODEL)
        ax.plot(dfd["pt"], dfd["NB_all"], lw=2.0, color=COLOR_ALL, ls="--")
        ax.axhline(0, lw=1.2, color=COLOR_ZERO, ls=":")
        # Mark the base threshold; dash-dot so it is not confused with the dotted NB = 0 line.
        ax.axvline(base_thr, color="black", lw=1.0, ls="-.", alpha=0.6)

        lo, hi = RANGE_LM if vessel == "LM" else RANGE_DEFAULT
        ax.set_xlim(lo, hi)

        _panel_title(ax, vessel)
        ax.set_xlabel("Threshold probability")
        ax.set_ylabel("Net benefit")
        ax.grid(False)

    handles = [
        Line2D([0], [0], color=COLOR_MODEL, lw=2.5, label="Model"),
        Line2D([0], [0], color=COLOR_ALL, lw=2.0, ls="--", label="Treat-all"),
        Line2D([0], [0], color=COLOR_ZERO, lw=1.2, ls=":", label="Treat-none (NB=0)"),
        Line2D([0], [0], color="black", lw=1.0, ls="-.", alpha=0.6, label="Base threshold"),
    ]
    fig.legend(handles=handles, loc="lower center", ncol=2, frameon=False,
               bbox_to_anchor=(0.5, 0.01), fontsize=11)

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.14)

    out_png = os.path.join(OUT_DIR, "Figure_DCA_2x2.png")
    fig.savefig(out_png, dpi=600, bbox_inches="tight")
    plt.close(fig)
    print("Saved:", out_png)


def load_and_prepare(df: pd.DataFrame, label_col: str, prob_col: str):
    """Extract aligned binary labels and float probabilities for one vessel.

    Probabilities are coerced to numeric. Rows missing either value are dropped.
    Labels are binarised with ``_label_to_binary``. A warning is printed when
    only one class remains.

    Raises:
        KeyError: if either column is missing from ``df``.
    """
    if any(col not in df.columns for col in (label_col, prob_col)):
        raise KeyError(f"Missing columns: {label_col} or {prob_col}")

    subset = df[[label_col, prob_col]].copy()
    subset[prob_col] = _safe_to_numeric(subset[prob_col])
    subset = subset.dropna(subset=[label_col, prob_col])

    y_true = _label_to_binary(subset[label_col])
    y_prob = subset[prob_col].to_numpy(dtype=float)

    if np.unique(y_true).size < 2:
        print(f"Warning: only one class in {label_col}. Curves may be meaningless.")
    return y_true, y_prob

def main():
    """Run the threshold sweep and DCA for every vessel, then render both 2x2 figures."""
    df = pd.read_csv(CSV_PATH)

    sweeps = {}
    dca_tables = {}
    for label_col, prob_col, vessel in zip(LABEL_COLUMN_NAMES, PROB_COLUMN_NAMES, VESSEL_NAMES):
        y_true, y_prob = load_and_prepare(df, label_col, prob_col)
        sweeps[vessel] = threshold_sweep_sens_spec(vessel, y_true, y_prob)
        dca_tables[vessel] = dca_curve(vessel, y_true, y_prob)

    plot_threshold_sensspec_2x2(sweeps, VESSEL_THRESHOLDS)
    plot_dca_2x2(dca_tables, VESSEL_THRESHOLDS)


if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is "__main__", so this also runs when the cell executes.
    main()


In [None]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import sys
from scipy.signal import find_peaks

# Settings for the risk-stratified beat-morphology figure.
CONFIG = {
    # Directory appended to sys.path so the project's ECGdataset module can be imported.
    'CODE_DIR': './',
    # Model-output CSV holding one predicted-probability column per vessel (placeholder path).
    'CSV_PATH': 'Your/result.csv',

    # One entry per vessel: panel title, probability column, and risk cut-off
    # (prob >= threshold -> high-risk group). Per the figure footnote: 1% (10x
    # prevalence) for LM, 15% (ESC guidelines) for RCA/LAD/LCX.
    # Column names follow the '<vessel label>_prob' convention of the model output;
    # the RCA column is '右冠状动脉_prob' (the name '右冠状动脉主干_prob' does not
    # exist and made df[prob_col] raise KeyError).
    'TARGETS': [
        {'name': 'RCA', 'prob_col': '右冠状动脉_prob',     'threshold': 0.15},
        {'name': 'LM',  'prob_col': '左冠状动脉主干_prob', 'threshold': 0.01},
        {'name': 'LAD', 'prob_col': '左前降支_prob',       'threshold': 0.15},
        {'name': 'LCX', 'prob_col': '左回旋支_prob',       'threshold': 0.15}
    ],

    # Maximum number of patients sampled (seeded) per risk group.
    'SAMPLE_LIMIT': 500,
    # ECG sampling rate, in samples per second.
    'FS': 500,
    # Beat window around each R peak, in seconds before / after the peak.
    'BEFORE_R': 0.25,
    'AFTER_R': 0.65,
}

# Make the project code directory importable. Guarded so that re-running this
# cell does not keep appending duplicate entries to sys.path.
if CONFIG['CODE_DIR'] not in sys.path:
    sys.path.append(CONFIG['CODE_DIR'])
try:
    from ECGdataset import ECGDataset
except ImportError:
    # Non-fatal: get_risk_group_stats catches the resulting NameError and
    # reports the group as empty (returns None).
    print("无法导入 ECGdataset")  # "Unable to import ECGdataset"

def extract_aligned_beats(signal, fs=500, before_sec=0.1, after_sec=0.4):
    """Cut R-peak-aligned, per-lead z-scored beats out of a multi-lead ECG.

    R peaks are detected on row 1 of ``signal`` (labelled lead 'II' in the plot
    layout) as local maxima above 30% of that row's maximum and at least
    ``int(0.4 * fs)`` samples apart. For every peak with a complete window,
    samples ``[r - int(before_sec*fs), r + int(after_sec*fs))`` of all rows are
    taken and each row is standardised within the beat (mean 0, std 1; 1e-8
    guards division by zero on flat leads).

    Parameters
    ----------
    signal : np.ndarray
        2-D array indexed as (lead, sample).
    fs : int
        Sampling rate in samples per second.
    before_sec, after_sec : float
        Window extent before / after the R peak, in seconds.

    Returns
    -------
    np.ndarray or None
        Shape (n_beats, n_leads, int(before_sec*fs) + int(after_sec*fs)), or
        None when no complete beat is found.
    """
    reference_lead = signal[1]
    r_peaks, _ = find_peaks(
        reference_lead,
        height=np.max(reference_lead) * 0.3,
        distance=int(fs * 0.4),
    )

    n_before = int(before_sec * fs)
    n_after = int(after_sec * fs)

    # Keep only peaks whose full window lies inside the recording.
    windows = [
        signal[:, r - n_before: r + n_after]
        for r in r_peaks
        if r - n_before >= 0 and r + n_after <= signal.shape[1]
    ]
    if not windows:
        return None

    normalised = [
        (w - np.mean(w, axis=1, keepdims=True)) / (np.std(w, axis=1, keepdims=True) + 1e-8)
        for w in windows
    ]
    return np.stack(normalised)

def get_risk_group_stats(df, prob_col, threshold):
    """Average the R-aligned beats of low- and high-risk patients for one vessel.

    Rows are split at ``threshold`` on ``df[prob_col]``: low risk is
    ``prob < threshold`` and high risk is ``prob >= threshold``. Rows with a NaN
    probability fall in neither group. Each group is down-sampled (seeded) to at
    most ``CONFIG['SAMPLE_LIMIT']`` rows and loaded through ``ECGDataset``.
    Every complete beat is then extracted with ``extract_aligned_beats``.

    Parameters
    ----------
    df : pd.DataFrame
        Results table passed row-wise to ``ECGDataset``.
    prob_col : str
        Predicted-probability column used for the split.
    threshold : float
        Risk cut-off.

    Returns
    -------
    dict or None
        ``{'low_mean', 'low_std', 'high_mean', 'high_std'}``, the per-position
        mean / std over all beats of each group. Arrays are shaped like one beat,
        i.e. (lead, sample). Returns None when ``prob_col`` is missing, when
        ``ECGDataset`` is unavailable, or when either group yields no beats.
    """
    if prob_col not in df.columns:
        print(f"Warning: probability column '{prob_col}' not found; skipping.")
        return None

    df_low = df[df[prob_col] < threshold]
    df_high = df[df[prob_col] >= threshold]

    print(f"Threshold: {threshold:.2%} | Low Risk: {len(df_low)} | High Risk: {len(df_high)}")

    limit = CONFIG['SAMPLE_LIMIT']
    if len(df_low) > limit:
        df_low = df_low.sample(n=limit, random_state=42)
    if len(df_high) > limit:
        df_high = df_high.sample(n=limit, random_state=42)

    def collect_beats(sub_df):
        """Extract and concatenate all beats of ``sub_df``; None if none were found."""
        try:
            dataset = ECGDataset(sub_df, use_augment=False)
        except NameError:
            # ECGDataset was not importable (see the import cell above).
            return None

        all_beats = []
        n_failed = 0
        for i in range(len(dataset)):
            try:
                # Assumes items are 5-tuples whose first element is the signal,
                # presumably (n_leads, n_samples) -- TODO confirm against ECGDataset.
                signal, _, _, _, _ = dataset[i]
                signal_np = signal.numpy() if isinstance(signal, torch.Tensor) else np.asarray(signal)
                beats = extract_aligned_beats(signal_np, CONFIG['FS'], CONFIG['BEFORE_R'], CONFIG['AFTER_R'])
            except Exception:
                # Best effort: skip unreadable records, but count them instead of
                # hiding the failure (a bare except also swallowed KeyboardInterrupt).
                n_failed += 1
                continue
            if beats is not None:
                all_beats.append(beats)

        if n_failed:
            print(f"Warning: {n_failed}/{len(dataset)} records failed to load or segment and were skipped.")
        return np.concatenate(all_beats, axis=0) if all_beats else None

    beats_low = collect_beats(df_low)
    beats_high = collect_beats(df_high)

    if beats_low is None or beats_high is None:
        print("警告: 某一组样本不足，无法提取心拍")  # "Warning: a group has too few samples to extract beats"
        return None

    return {
        'low_mean': beats_low.mean(axis=0),
        'low_std': beats_low.std(axis=0),
        'high_mean': beats_high.mean(axis=0),
        'high_std': beats_high.std(axis=0)
    }

def plot_risk_stratified_grid(stats_dict):
    """Plot low- vs high-risk mean 12-lead beats for every vessel in one 6x11 grid.

    Each vessel in ``CONFIG['TARGETS']`` occupies two adjacent columns: limb
    leads I-aVF on the left and V1-V6 on the right. Vessel blocks are separated
    by narrow blank spacer columns (2, 5, 8), so the layout assumes at most four
    targets. Every panel shows the low-risk mean beat (dark grey) and the
    high-risk mean beat (red), each with a +/-0.5 std band. All panels share one
    y-range so amplitudes are comparable.

    Parameters
    ----------
    stats_dict : dict
        Maps vessel name to the dict returned by ``get_risk_group_stats`` (keys
        'low_mean', 'low_std', 'high_mean', 'high_std'; arrays indexed
        (lead, sample)), or to None when that vessel could not be computed.
        None entries are skipped. If every entry is None, nothing is drawn.
    """
    from matplotlib.lines import Line2D

    # get_risk_group_stats returns None on failure; previously the y-range code
    # below dereferenced those entries and raised TypeError.
    valid_stats = [s for s in stats_dict.values() if s is not None]
    if not valid_stats:
        print("Warning: no risk-group statistics available; nothing to plot.")
        return

    # Shared y-limits: 0.5th-99.5th percentile of mean +/- 1 std over all vessels,
    # groups and leads. This is wider than the +/-0.5 std bands drawn, which
    # leaves headroom. The span is at least 4 units, padded by 0.2 on each side.
    all_y = np.concatenate([
        np.concatenate([
            s['low_mean'].flatten() + s['low_std'].flatten(),
            s['low_mean'].flatten() - s['low_std'].flatten(),
            s['high_mean'].flatten() + s['high_std'].flatten(),
            s['high_mean'].flatten() - s['high_std'].flatten(),
        ])
        for s in valid_stats
    ])
    g_min, g_max = np.percentile(all_y, 0.5), np.percentile(all_y, 99.5)
    center = (g_max + g_min) / 2
    span = max(g_max - g_min, 4.0)
    ylim_range = (center - span / 2 - 0.2, center + span / 2 + 0.2)

    # 11 columns: [limb, precordial, spacer] per vessel; the last vessel has no spacer.
    width_ratios = [10, 10, 2, 10, 10, 2, 10, 10, 2, 10, 10]
    fig, axes = plt.subplots(6, 11, figsize=(36, 18), dpi=120,
                             gridspec_kw={'width_ratios': width_ratios})
    fig.subplots_adjust(wspace=0.1, hspace=0.15, top=0.92, bottom=0.08, left=0.03, right=0.97)

    spacer_indices = [2, 5, 8]
    for row in range(6):
        for col in spacer_indices:
            axes[row, col].axis('off')

    # (row, column offset within the vessel block, label) for lead indices 0..11.
    lead_indices_map = [
        (0, 0, 'I'),   (1, 0, 'II'),  (2, 0, 'III'), (3, 0, 'aVR'), (4, 0, 'aVL'), (5, 0, 'aVF'),
        (0, 1, 'V1'),  (1, 1, 'V2'),  (2, 1, 'V3'),  (3, 1, 'V4'),  (4, 1, 'V5'),  (5, 1, 'V6')
    ]

    for v_idx, vessel_cfg in enumerate(CONFIG['TARGETS']):
        stats = stats_dict.get(vessel_cfg['name'])
        if stats is None:
            continue

        start_col = v_idx * 3
        # Derive the time axis from the actual beat length. The extractor uses
        # int(BEFORE_R*FS) + int(AFTER_R*FS) samples, which can differ by one from
        # int((BEFORE_R + AFTER_R)*FS) and would break fill_between.
        t = np.linspace(-CONFIG['BEFORE_R'], CONFIG['AFTER_R'], stats['low_mean'].shape[-1])

        for lead_real_idx, (row_offset, col_offset, label) in enumerate(lead_indices_map):
            ax = axes[row_offset, start_col + col_offset]

            low_mean = stats['low_mean'][lead_real_idx]
            low_std = stats['low_std'][lead_real_idx]
            high_mean = stats['high_mean'][lead_real_idx]
            high_std = stats['high_std'][lead_real_idx]

            ax.fill_between(t, low_mean - 0.5 * low_std, low_mean + 0.5 * low_std,
                            alpha=0.1, color='gray')
            ax.plot(t, low_mean, color='#404040', linewidth=2.0, alpha=0.9)

            ax.fill_between(t, high_mean - 0.5 * high_std, high_mean + 0.5 * high_std,
                            alpha=0.1, color='#d62728')
            ax.plot(t, high_mean, color='#d62728', linewidth=2.5, alpha=0.9)

            ax.set_ylim(ylim_range)
            ax.set_xlim(-CONFIG['BEFORE_R'], CONFIG['AFTER_R'])
            # ECG-paper style grid: 0.2 s by 1 (z-score) unit, without tick labels.
            ax.xaxis.set_major_locator(ticker.MultipleLocator(0.2))
            ax.yaxis.set_major_locator(ticker.MultipleLocator(1.0))
            ax.grid(True, which='major', color='#d0d0d0', linewidth=0.8, linestyle='-')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.tick_params(left=False, bottom=False)

            for spine in ax.spines.values():
                spine.set_linewidth(1.0)
                spine.set_color('black')

            ax.text(0.04, 0.90, label, transform=ax.transAxes,
                    fontsize=14, fontweight='bold', va='top', ha='left',
                    bbox=dict(boxstyle="square,pad=0.1", fc="white", ec="none", alpha=0.7))

    # Draw once so axes positions are final before placing the centred vessel titles.
    fig.canvas.draw()
    header_y = 0.95

    for i, cfg in enumerate(CONFIG['TARGETS']):
        start_col = i * 3
        bbox0 = axes[0, start_col].get_position()
        bbox1 = axes[0, start_col + 1].get_position()
        center_x = (bbox0.x0 + bbox1.x1) / 2

        fig.text(center_x, header_y, f"{cfg['name']}", ha='center', va='bottom',
                 fontsize=28, fontweight='bold', color='black')

    legend_elements = [
        Line2D([0], [0], color='#404040', lw=3.0, label='Low Risk (< Threshold)'),
        Line2D([0], [0], color='#d62728', lw=3.0, label='High Risk (≥ Threshold)')
    ]
    fig.legend(handles=legend_elements, loc='lower center',
               bbox_to_anchor=(0.5, 0.02), ncol=2, frameon=False, fontsize=18)

    # NOTE: hard-coded footnote; keep in sync with CONFIG['TARGETS'] thresholds.
    fig.text(0.5, 0.01,
             "Risk Thresholds: 1.0% (10x Prevalence) for LM; 15% (ESC Guidelines) for RCA/LAD/LCX.",
             ha='center', fontsize=12, style='italic', color='gray')

    plt.show()

def main():
    """Load the results CSV, compute risk-stratified beat statistics per target
    vessel, and plot the ones that could be computed."""
    import os

    csv_path = CONFIG['CSV_PATH']
    # os.path.isfile is the public equivalent of the internal
    # pandas.io.common.file_exists helper. It also rejects directories, which
    # pd.read_csv could not read anyway.
    if not os.path.isfile(csv_path):
        print(f"找不到文件: {csv_path}")  # "File not found: ..."
        return

    df = pd.read_csv(csv_path)

    all_stats = {}
    print("正在计算基于风险分层的波形统计...")  # "Computing risk-stratified waveform statistics..."
    for target in CONFIG['TARGETS']:
        print(f"... Processing {target['name']} (Threshold: {target['threshold']:.1%})")
        all_stats[target['name']] = get_risk_group_stats(df, target['prob_col'], target['threshold'])

    # get_risk_group_stats returns None for vessels it could not compute. Pass
    # only real results so the plot never dereferences None.
    available_stats = {name: stats for name, stats in all_stats.items() if stats is not None}
    if not available_stats:
        print("Warning: no vessel produced beat statistics; skipping plot.")
        return

    print("正在绘图...")  # "Plotting..."
    plot_risk_stratified_grid(available_stats)

if __name__ == "__main__":
    main()