In [4]:
import numpy as np
import mat73
import matplotlib.pyplot as plt
import pandas as pd

def plot_signal(signal_data, sampling_rate=400, title="单导联信号",save_path="/11.png"):
    """
    Plot a single-lead signal against time and save the figure to disk.

    Parameters:
    signal_data: 1-D signal samples (e.g. one 4000-sample record)
    sampling_rate: sampling rate in Hz used to build the time axis (default 400 Hz)
    title: figure title
    save_path: image path passed to savefig. NOTE: the default "/11.png" points at
        the filesystem root, which is usually not writable; pass an explicit path.

    The figure is saved but not closed, so an inline notebook backend may also
    render it at the end of the cell.
    """
    signal_data = np.asarray(signal_data)

    # Time axis in seconds
    time_axis = np.arange(len(signal_data)) / sampling_rate

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(time_axis, signal_data, 'b-', linewidth=0.8)

    ax.set_title(title, fontsize=14)
    ax.set_xlabel('时间 (秒)')
    ax.set_ylabel('幅值')
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path)

# Load the raw training / test matrices
data_dict_train = mat73.loadmat('./data/traindata.mat')
data_dict_test = mat73.loadmat('./data/testdata.mat')

train_data = data_dict_train['traindata']  # per the original note: (20000, 4000); rows 0-499 AF, 500-999 NAF
test_data = data_dict_test['testdata']  # per the original note: (10000, 4000); ordering claim (first 5000 NAF) unverified

# plot_signal(test_data[5004], sampling_rate=400, title="模拟心电信号")


# 1. Training labels: 1 = AF, 0 = NAF, -1 = unlabelled
labels_train = np.full(train_data.shape[0], -1)
labels_train[0:500] = 1      # rows 0-499: AF
labels_train[500:1000] = 0   # rows 500-999: NAF
# rows 1000 onward stay -1 (unlabelled)

print(f"训练集标签分布:")
print(f"AF (房颤, 标签=1): {np.sum(labels_train == 1)} 个样本")
print(f"NAF (非房颤, 标签=0): {np.sum(labels_train == 0)} 个样本") 
print(f"未标记 (标签=-1): {np.sum(labels_train == -1)} 个样本")

# 2. Test labels. The positional split below was replaced by an external label file:
# labels_test[0:5000] = 0      # rows 0-4999: NAF
# labels_test[5000:10000] = 1   # rows 5000-9999: AF

# Reference labels provided by labmate Tang.
# NOTE(review): the label column is named 'Prediction'; confirm it really holds ground truth.
df = pd.read_csv("./2025cspc/true_label.csv")
labels_test = df['Prediction'].to_numpy()
# Guard against a label file that does not line up row-for-row with test_data
if labels_test.shape[0] != test_data.shape[0]:
    raise ValueError(f"测试集标签数量({labels_test.shape[0]})与测试数据数量({test_data.shape[0]})不一致")

print(f"测试集标签分布:")
print(f"AF (房颤, 标签=1): {np.sum(labels_test == 1)} 个样本")
print(f"NAF (非房颤, 标签=0): {np.sum(labels_test == 0)} 个样本") 
print(f"未标记 (标签=-1): {np.sum(labels_test == -1)} 个样本")

# Outputs of this cell: train_data / labels_train, test_data / labels_test

训练集标签分布:
AF (房颤, 标签=1): 500 个样本
NAF (非房颤, 标签=0): 500 个样本
未标记 (标签=-1): 19000 个样本
测试集标签分布:
AF (房颤, 标签=1): 4797 个样本
NAF (非房颤, 标签=0): 5203 个样本
未标记 (标签=-1): 0 个样本


In [5]:
# Signal filtering (baseline-drift removal, denoising)
import warnings
from scipy import signal

# NOTE: this silences *every* warning for the rest of the kernel session,
# including deprecation warnings raised later by scipy / sklearn / torch.
warnings.simplefilter('ignore')

def remove_baseline_drift(ecg_signal, sampling_rate=400, cutoff=0.5):
    """
    Suppress slow baseline wander with a zero-phase high-pass Butterworth filter.

    Parameters:
    ecg_signal: input ECG samples (1-D)
    sampling_rate: sampling frequency in Hz
    cutoff: high-pass corner frequency in Hz

    Returns:
    The filtered signal. The 4th-order filter is run forward and backward
    (filtfilt), so no phase shift is introduced.
    """
    # Passing fs lets butter normalise the corner frequency by the Nyquist rate itself
    numer, denom = signal.butter(4, cutoff, btype='high', analog=False, fs=sampling_rate)
    return signal.filtfilt(numer, denom, ecg_signal)

def bandpass_filter(ecg_signal, sampling_rate=400, low_cutoff=0.5, high_cutoff=40):
    """
    Keep only the [low_cutoff, high_cutoff] Hz band with a zero-phase Butterworth band-pass.

    Parameters:
    ecg_signal: input signal
    sampling_rate: sampling frequency in Hz
    low_cutoff: lower band edge in Hz
    high_cutoff: upper band edge in Hz

    Returns:
    The filtered signal (4th-order design, applied forward and backward with filtfilt).
    """
    passband = [low_cutoff, high_cutoff]
    numer, denom = signal.butter(4, passband, btype='band', fs=sampling_rate)
    return signal.filtfilt(numer, denom, ecg_signal)

def notch_filter(ecg_signal, sampling_rate=400, freq=50, quality=30):
    """
    Remove narrow-band power-line interference with a zero-phase IIR notch.

    Parameters:
    ecg_signal: input signal
    sampling_rate: sampling frequency in Hz
    freq: frequency to reject in Hz (50 Hz mains by default)
    quality: quality factor Q of the notch (Q = centre frequency / -3 dB bandwidth)

    Returns:
    The filtered signal (applied forward and backward with filtfilt).
    """
    notch_b, notch_a = signal.iirnotch(w0=freq, Q=quality, fs=sampling_rate)
    return signal.filtfilt(notch_b, notch_a, ecg_signal)

def preprocess_single_signal(ecg_signal, sampling_rate=400):
    """
    Run the full denoising chain on one ECG record.

    Stages, in order:
    0.5 Hz high-pass (baseline drift) -> 0.5-40 Hz band-pass -> 50 Hz notch
    (power-line interference).

    NOTE: the band-pass also has a 0.5 Hz lower edge, so the separate high-pass
    stage overlaps with it (it only adds extra low-frequency roll-off).

    Parameters:
    ecg_signal: input ECG record
    sampling_rate: sampling frequency in Hz

    Returns:
    The filtered record.
    """
    detrended = remove_baseline_drift(ecg_signal, sampling_rate)
    band_limited = bandpass_filter(detrended, sampling_rate,
                                   low_cutoff=0.5, high_cutoff=40)
    return notch_filter(band_limited, sampling_rate, freq=50)

# plot_signal(test_data[5004], sampling_rate=400, title="row",save_path = "./test.png") # 降噪前的数据

# processed_signal = preprocess_single_signal(test_data[5004],400)

# plot_signal(processed_signal, sampling_rate=400, title="after_process",save_path = "./test_after_process.png") # 降噪后的数据

# 批量处理数据
def preprocess_batch(data_batch, sampling_rate=400):
    """
    Apply preprocess_single_signal to every row of a 2-D batch.

    Parameters:
    data_batch: array-like of shape (n_samples, signal_length)
    sampling_rate: sampling frequency in Hz

    Returns:
    Array of the same shape as data_batch. Floating-point inputs keep their dtype.
    Integer inputs are promoted to floating point, so filter output is not truncated.
    """
    data_batch = np.asarray(data_batch)
    print(f"开始批量预处理，数据形状: {data_batch.shape}")

    # np.zeros_like would inherit an integer dtype and silently truncate the
    # floating-point filter output; promote to at least float32 instead.
    out_dtype = np.result_type(data_batch.dtype, np.float32)
    processed_data = np.zeros(data_batch.shape, dtype=out_dtype)

    n_samples = data_batch.shape[0]
    for i in range(n_samples):
        if (i + 1) % 1000 == 0:
            print(f"已处理 {i + 1}/{n_samples} 个样本")

        processed_data[i] = preprocess_single_signal(data_batch[i], sampling_rate)

    print("批量预处理完成！")
    return processed_data

# Only the first 1000 (labelled) training rows are preprocessed; they pair with labels_train[:1000]
X_train = preprocess_batch(train_data[:1000], sampling_rate=400)
# Every test row is preprocessed; pairs with labels_test
X_test = preprocess_batch(test_data, sampling_rate=400)


开始批量预处理，数据形状: (1000, 4000)
已处理 1000/1000 个样本
批量预处理完成！
开始批量预处理，数据形状: (10000, 4000)
已处理 1000/10000 个样本
已处理 2000/10000 个样本
已处理 3000/10000 个样本
已处理 4000/10000 个样本
已处理 5000/10000 个样本
已处理 6000/10000 个样本
已处理 7000/10000 个样本
已处理 8000/10000 个样本
已处理 9000/10000 个样本
已处理 10000/10000 个样本
批量预处理完成！


# 1、尝试了一下计算R-R的变异性来判断AF

In [None]:
from scipy import signal
from scipy.ndimage import uniform_filter1d

def detect_r_peaks(ecg_signal, sampling_rate=400, min_distance=120):
    """
    Detect R-peak positions in a single-lead ECG.

    The detector uses an energy envelope:
    - band-pass the signal at 5-15 Hz;
    - square it;
    - smooth it with an 80 ms moving average;
    - pick envelope peaks above 0.3 x the envelope mean.
    Each envelope peak is then snapped to the largest raw sample within ±50 ms.
    This assumes R waves are positive deflections, because the snap uses argmax.

    Parameters:
    ecg_signal: ECG samples (1-D array-like)
    sampling_rate: sampling rate in Hz
    min_distance: minimum spacing between envelope peaks, in samples

    Returns:
    Sorted, de-duplicated integer array of R-peak sample indices (possibly empty).
    """
    ecg_signal = np.asarray(ecg_signal, dtype=float)

    # 1. Band-pass 5-15 Hz to emphasise the QRS complex
    nyquist = sampling_rate / 2
    b, a = signal.butter(4, [5 / nyquist, 15 / nyquist], btype='band')
    filtered_signal = signal.filtfilt(b, a, ecg_signal)

    # 2. Square to make everything positive and accentuate large slopes
    squared_signal = filtered_signal ** 2

    # 3. 80 ms moving average; at least 1 sample so low sampling rates don't break
    window_size = max(1, int(0.08 * sampling_rate))
    smoothed_signal = uniform_filter1d(squared_signal, size=window_size)

    # 4. Envelope peaks above a threshold relative to the mean envelope level
    threshold = np.mean(smoothed_signal) * 0.3
    peaks, _ = signal.find_peaks(smoothed_signal,
                                 height=threshold,
                                 distance=min_distance)

    # 5. Snap each envelope peak to the raw-signal maximum within ±50 ms.
    # At least 1 sample wide, so the slice below is never empty.
    search_window = max(1, int(0.05 * sampling_rate))
    r_peaks = []
    for peak in peaks:
        start = max(0, peak - search_window)
        end = min(len(ecg_signal), peak + search_window)
        r_peaks.append(start + int(np.argmax(ecg_signal[start:end])))

    # Two nearby envelope peaks can snap to the same sample when min_distance is
    # small; np.unique drops such duplicates (they would give zero R-R intervals)
    # and keeps the result sorted. dtype=int keeps an empty result usable as an index.
    return np.unique(np.asarray(r_peaks, dtype=int))

def calculate_rr_intervals(r_peaks, sampling_rate=400):
    """
    Convert R-peak sample indices into R-R intervals.

    Parameters:
    r_peaks: R-peak positions in samples
    sampling_rate: sampling rate in Hz

    Returns:
    Array of successive R-R intervals in seconds. It is empty when fewer than
    two peaks are given.
    """
    if len(r_peaks) >= 2:
        # Sample gaps between consecutive beats, scaled to seconds
        return np.diff(r_peaks) / sampling_rate
    return np.array([])

def analyze_rr_variability_v1(rr_intervals, af_threshold=0.2):
    """
    Decide AF vs. non-AF from R-R irregularity using five indicators.

    Indicators (each counts 1 when exceeded):
    - coefficient of variation > af_threshold
    - RMSSD > 80 ms
    - pNN50 > 20 %
    - max/min R-R ratio > 2
    - standard deviation of successive differences > 60 ms

    Parameters:
    rr_intervals: sequence of R-R intervals in seconds
    af_threshold: coefficient-of-variation cut-off for the first indicator

    Returns:
    Tuple (is_af, metrics). is_af is True when at least 2 indicators fire.
    metrics holds mean_rr, std_rr, cv, rmssd, pnn50, rr_range_ratio,
    successive_diff_std and af_indicators.
    With fewer than 3 intervals the result is (False, {"reason": ...}).
    """
    if len(rr_intervals) < 3:
        return False, {"reason": "R-R间期过少，无法判断"}

    avg = np.mean(rr_intervals)
    spread = np.std(rr_intervals)
    deltas = np.diff(rr_intervals)
    shortest = np.min(rr_intervals)

    metrics = {
        "mean_rr": avg,
        "std_rr": spread,
        "cv": spread / avg if avg > 0 else 0,
        "rmssd": np.sqrt(np.mean(deltas ** 2)),
        # Percentage of successive changes larger than 50 ms
        "pnn50": np.sum(np.abs(deltas) > 0.05) / len(deltas) * 100,
        "rr_range_ratio": np.max(rr_intervals) / shortest if shortest > 0 else 0,
        "successive_diff_std": np.std(deltas),
    }

    flags = (
        metrics["cv"] > af_threshold,
        metrics["rmssd"] > 0.08,
        metrics["pnn50"] > 20,
        metrics["rr_range_ratio"] > 2.0,
        metrics["successive_diff_std"] > 0.06,
    )
    indicator_count = sum(1 for flag in flags if flag)
    metrics["af_indicators"] = indicator_count

    return indicator_count >= 2, metrics

def analyze_rr_variability_v2(rr_intervals, cv_threshold=0.15):
    """
    Simplified AF rule: irregular when the coefficient of variation exceeds
    cv_threshold AND RMSSD exceeds 80 ms.

    Parameters:
    rr_intervals: sequence of R-R intervals in seconds
    cv_threshold: coefficient-of-variation cut-off

    Returns:
    Tuple (is_af, metrics). metrics holds mean_rr, cv and rmssd rounded to 3
    decimals; the decision itself uses the unrounded values.
    With fewer than 3 intervals the result is (False, {"reason": ...}).
    """
    if len(rr_intervals) < 3:
        return False, {"reason": "数据不足"}

    avg = np.mean(rr_intervals)
    cv = np.std(rr_intervals) / avg if avg > 0 else 0
    rmssd = np.sqrt(np.mean(np.diff(rr_intervals) ** 2))

    summary = {"mean_rr": round(avg, 3), "cv": round(cv, 3), "rmssd": round(rmssd, 3)}
    return (cv > cv_threshold and rmssd > 0.08), summary


# 可视化
def visualize_af_detection(ecg_signal, r_peaks, rr_intervals, is_af, metrics, sampling_rate=400):
    """
    Plot a 2x2 diagnostic figure for one AF decision and print the key metrics.

    Panels:
    (1) ECG with detected R peaks;
    (2) R-R interval series with its mean;
    (3) R-R interval histogram;
    (4) bar chart of the five variability indicators.

    Parameters:
    ecg_signal: 1-D ECG record
    r_peaks: R-peak sample indices
    rr_intervals: R-R intervals in seconds
    is_af: the decision to display (colours the title and bars)
    metrics: dict returned by analyze_rr_variability_v1. When that function
        bailed out early, the dict only contains a "reason" key. Panel 4 and
        the metric printout are then skipped instead of raising KeyError.
    sampling_rate: sampling rate in Hz

    NOTE: titles/labels are Chinese. matplotlib needs a CJK-capable font
    configured, otherwise the glyphs may render as empty boxes.
    """
    ecg_signal = np.asarray(ecg_signal)
    # Integer dtype keeps ecg_signal[r_peaks] a valid index operation even when empty
    r_peaks = np.asarray(r_peaks, dtype=int)
    rr_intervals = np.asarray(rr_intervals, dtype=float)
    has_rr = len(rr_intervals) > 0
    # With 3 peaks there are only 2 intervals and v1 returns just {"reason": ...}
    has_metrics = has_rr and 'af_indicators' in metrics

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    # 1. ECG signal with detected R peaks
    time_axis = np.arange(len(ecg_signal)) / sampling_rate
    ax1.plot(time_axis, ecg_signal, 'b-', linewidth=0.8, alpha=0.7)
    ax1.plot(r_peaks / sampling_rate, ecg_signal[r_peaks], 'ro', markersize=6)
    ax1.set_title(f'ECG信号和R峰检测 (检测到{len(r_peaks)}个R峰)')
    ax1.set_xlabel('时间 (秒)')
    ax1.set_ylabel('幅值')
    ax1.grid(True, alpha=0.3)

    if has_rr:
        rr_ms = rr_intervals * 1000

        # 2. R-R interval series with mean line
        ax2.plot(rr_ms, 'go-', linewidth=1, markersize=4)
        ax2.set_title('R-R间期序列')
        ax2.set_xlabel('R-R间期索引')
        ax2.set_ylabel('R-R间期 (ms)')
        ax2.grid(True, alpha=0.3)
        mean_rr = np.mean(rr_intervals) * 1000
        ax2.axhline(y=mean_rr, color='r', linestyle='--', alpha=0.7, label=f'均值: {mean_rr:.1f}ms')
        ax2.legend()

        # 3. R-R interval histogram
        ax3.hist(rr_ms, bins=min(20, len(rr_intervals)), alpha=0.7, color='skyblue')
        ax3.set_title('R-R间期分布')
        ax3.set_xlabel('R-R间期 (ms)')
        ax3.set_ylabel('频次')
        ax3.grid(True, alpha=0.3)

    # 4. Variability indicators (only present in the full v1 metrics dict)
    if has_metrics:
        metric_names = ['CV', 'RMSSD(ms)', 'pNN50(%)', 'RR比值', '连续差值STD(ms)']
        metric_values = [
            metrics['cv'],
            metrics['rmssd'] * 1000,
            metrics['pnn50'],
            metrics['rr_range_ratio'],
            metrics['successive_diff_std'] * 1000
        ]

        bar_color = 'red' if is_af else 'green'
        bars = ax4.bar(range(len(metric_names)), metric_values,
                       color=[bar_color] * len(metric_values), alpha=0.7)
        ax4.set_title(f'变异性指标 (异常指标数: {metrics["af_indicators"]})')
        ax4.set_xticks(range(len(metric_names)))
        ax4.set_xticklabels(metric_names, rotation=45)
        ax4.grid(True, alpha=0.3)

        # Value labels above each bar
        for bar, value in zip(bars, metric_values):
            ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                     f'{value:.2f}', ha='center', va='bottom', fontsize=8)
    else:
        ax4.axis('off')

    # Overall verdict as the figure title
    result_text = "房颤 (AF)" if is_af else "窦性心律 (NAF)"
    result_color = "red" if is_af else "green"
    fig.suptitle(f'房颤检测结果: {result_text}', fontsize=16, color=result_color, fontweight='bold')

    plt.tight_layout()
    plt.show()

    # Text summary
    print(f"\n=== 检测结果 ===")
    print(f"结果: {result_text}")
    print(f"R峰数量: {len(r_peaks)}")
    print(f"R-R间期数量: {len(rr_intervals)}")
    if has_metrics:
        print(f"平均R-R间期: {metrics['mean_rr']*1000:.1f} ms")
        print(f"变异系数: {metrics['cv']:.3f}")
        print(f"RMSSD: {metrics['rmssd']*1000:.1f} ms")
        print(f"pNN50: {metrics['pnn50']:.1f}%")
        print(f"异常指标数: {metrics['af_indicators']}/5")

def detect_af(ecg_signal, sampling_rate=400, af_threshold=0.25, visualize=False):
    """
    Rule-based AF detector for one ECG record.

    Pipeline: detect_r_peaks -> calculate_rr_intervals -> analyze_rr_variability_v1.

    Parameters:
    ecg_signal: one ECG record (the notebook passes 4000-sample rows of X_test)
    sampling_rate: sampling rate in Hz
    af_threshold: coefficient-of-variation threshold forwarded to
        analyze_rr_variability_v1 (note: 0.25 here vs. 0.2 as that function's default)
    visualize: if True, draw the diagnostic figure and print the metrics

    Returns:
    1 for AF, 0 for NAF. Falls back to 0 when fewer than 3 R peaks are found,
    or when any step raises (the error is printed, not re-raised).
    """
    try:
        peak_positions = detect_r_peaks(ecg_signal, sampling_rate)

        if len(peak_positions) < 3:
            # Too few beats to judge the rhythm; default to NAF
            if visualize:
                print("警告: 检测到的R峰过少，默认判断为NAF")
            return 0

        intervals = calculate_rr_intervals(peak_positions, sampling_rate)
        flagged, metrics = analyze_rr_variability_v1(intervals, af_threshold)

        if visualize:
            visualize_af_detection(ecg_signal, peak_positions, intervals,
                                   flagged, metrics, sampling_rate)

        return int(bool(flagged))

    except Exception as exc:
        # Deliberate best-effort: a failing record is reported and counted as NAF
        print(f"房颤检测出错: {exc}")
        return 0
    
# Single-record example (uncomment to inspect one record visually):
# result_af = detect_af(X_test[5003], sampling_rate=400, visualize=True)
# print(f"房颤检测结果: {result_af} (0=NAF, 1=AF)")

# Run the rule-based detector over every preprocessed test record
pred_test = [detect_af(record, sampling_rate=400, visualize=False) for record in X_test]

pred = np.array(pred_test)
print(pred.shape)


from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns

def calculate_metrics(test_label, pred):
    """
    Compute binary-classification metrics with AF (1) as the positive class.

    Parameters:
    test_label: ground-truth labels, 0 = NAF, 1 = AF
    pred: predicted labels, 0 = NAF, 1 = AF (boolean arrays are accepted)

    Returns:
    dict with the following keys:
    - 'accuracy', 'precision', 'recall', 'f1_score';
    - 'confusion_matrix': 2x2, rows = true, columns = predicted, order NAF, AF;
    - its four cells: 'true_negatives', 'false_positives', 'false_negatives',
      'true_positives';
    - 'manual_verification': the same scores recomputed from the matrix.

    Raises:
    ValueError: if the inputs differ in shape or contain labels other than 0/1.
    """
    test_label = np.array(test_label)
    pred = np.array(pred)

    if test_label.shape != pred.shape:
        raise ValueError("test_label和pred的维度必须相同")

    # Unlabelled entries (-1) would otherwise yield a 3x3 matrix and break the
    # four-way unpacking below with an opaque error.
    if not (np.isin(test_label, (0, 1)).all() and np.isin(pred, (0, 1)).all()):
        raise ValueError("test_label和pred只能包含0(NAF)和1(AF)")

    acc = accuracy_score(test_label, pred)
    precision = precision_score(test_label, pred, pos_label=1)
    recall = recall_score(test_label, pred, pos_label=1)
    f1 = f1_score(test_label, pred, pos_label=1)

    # labels=[0, 1] always gives a 2x2 matrix, even when only one class occurs
    # in both arrays (otherwise cm is 1x1 and the unpacking fails).
    cm = confusion_matrix(test_label, pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # Independent recomputation as a sanity check on the sklearn scores
    acc_manual = (tp + tn) / (tp + tn + fp + fn)
    precision_manual = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall_manual = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_manual = 2 * (precision_manual * recall_manual) / (precision_manual + recall_manual) if (precision_manual + recall_manual) > 0 else 0

    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'confusion_matrix': cm,
        'true_negatives': tn,
        'false_positives': fp,
        'false_negatives': fn,
        'true_positives': tp,
        'manual_verification': {
            'accuracy': acc_manual,
            'precision': precision_manual,
            'recall': recall_manual,
            'f1_score': f1_manual
        }
    }

def print_results(results):
    """
    Print a text report of a calculate_metrics result dict.

    The report contains the four scores (4 decimals), the four confusion-matrix
    counts, and a small table with rows = true label and columns = predicted label.
    """
    banner = "=" * 50
    print(banner)
    print("二分类评估指标结果")
    print(banner)

    score_rows = (
        ("准确率 (Accuracy):     ", 'accuracy'),
        ("精确率 (Precision):    ", 'precision'),
        ("召回率 (Recall):       ", 'recall'),
        ("F1分数 (F1-Score):     ", 'f1_score'),
    )
    for label, key in score_rows:
        print(f"{label}{results[key]:.4f}")

    print("\n混淆矩阵:")
    count_rows = (
        ("真阴性 (TN): ", 'true_negatives'),
        ("假阳性 (FP): ", 'false_positives'),
        ("假阴性 (FN): ", 'false_negatives'),
        ("真阳性 (TP): ", 'true_positives'),
    )
    for label, key in count_rows:
        print(f"{label}{results[key]}")

    print("\n混淆矩阵 (行为真实标签，列为预测标签):")
    print("        预测")
    print("      NAF  AF")
    matrix = results['confusion_matrix']
    print(f"真实 NAF {matrix[0,0]:3d} {matrix[0,1]:3d}")
    print(f"    AF  {matrix[1,0]:3d} {matrix[1,1]:3d}")

def plot_confusion_matrix(cm, class_names=['NAF', 'AF']):
    """
    Draw a confusion matrix as an annotated heatmap.

    Rows are true labels and columns are predicted labels. class_names labels
    both axes; the default list is only read, never mutated.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title('混淆矩阵')
    ax.set_xlabel('预测标签')
    ax.set_ylabel('真实标签')
    plt.show()

# Toy example for sanity-checking the metric helpers:
# test_label = [0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0]
# pred =       [0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0]

# Score the rule-based R-R detector against the test labels, then report and plot
results = calculate_metrics(test_label=labels_test, pred=pred)
print_results(results)
plot_confusion_matrix(cm=results['confusion_matrix'])

# 2、尝试深度学习的方法

## 2.1 Masksub

In [2]:
# 使用模型
from dataset import AF_1lead_cspc_Dataset
from torch.utils.data import Dataset, DataLoader
from finetune_model import ft_12lead_ECGFounder, ft_1lead_ECGFounder
import torch
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, \
    roc_auc_score, f1_score,auc,roc_curve,multilabel_confusion_matrix
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn

def my_eval_with_dynamic_thresh(gt, pred):
    """
    Evaluate multi-label predictions per task, using per-task decision thresholds.

    Args:
        gt: ground-truth labels, 2-D array (n_samples, n_tasks); NaNs are treated as 0
        pred: prediction probabilities, same shape as ``gt``; NaNs are treated as 0

    Returns:
        Tuple ``(mean_rocauc, rocaucs, sensitivities, specificities, f1, auprcs,
        optimal_thresholds)``:
        - the mean ROC-AUC over tasks;
        - per-task arrays of ROC-AUC, sensitivity, specificity, F1 and AUPRC;
        - the thresholds returned by ``find_optimal_thresholds``.
        ROC-AUC and AUPRC are recorded as 0.0 when sklearn raises ValueError
        (e.g. a task whose ground truth contains a single class).
    """
    # average_precision_score is not imported anywhere else in this notebook.
    # Before, the bare ``except`` turned the NameError into AUPRC = 0.0 for every task.
    from sklearn.metrics import average_precision_score

    # NOTE(review): find_optimal_thresholds is not defined in the visible notebook; confirm it is in scope.
    optimal_thresholds = find_optimal_thresholds(gt, pred)
    rocaucs, sensitivities, specificities, f1, auprcs = [], [], [], [], []

    for i in range(gt.shape[1]):
        tmp_gt = np.nan_to_num(gt[:, i], nan=0)
        tmp_pred = np.nan_to_num(pred[:, i], nan=0)

        # Score metrics are undefined (ValueError) when only one class is present
        try:
            rocaucs.append(roc_auc_score(tmp_gt, tmp_pred))
        except ValueError:
            rocaucs.append(0.0)

        try:
            auprcs.append(average_precision_score(tmp_gt, tmp_pred))
        except ValueError:
            auprcs.append(0.0)

        pred_labels = (tmp_pred > optimal_thresholds[i]).astype(int)
        # labels=[0, 1] always yields a 2x2 matrix, so the single-class case
        # needs no special handling.
        tn, fp, fn, tp = confusion_matrix(tmp_gt, pred_labels, labels=[0, 1]).ravel()

        # Sensitivity (true-positive rate) and specificity (true-negative rate)
        sensitivities.append(tp / (tp + fn) if (tp + fn) > 0 else 0)
        specificities.append(tn / (tn + fp) if (tn + fp) > 0 else 0)

        f1.append(f1_score(tmp_gt, pred_labels))

    rocaucs = np.array(rocaucs)
    return (np.mean(rocaucs), rocaucs, np.array(sensitivities), np.array(specificities),
            np.array(f1), np.array(auprcs), optimal_thresholds)

def calculate_metrics_v1(test_label, pred):
    """
    Compute binary-classification metrics with AF (1) as the positive class.

    Parameters:
    test_label: ground-truth labels, 0 = NAF, 1 = AF
    pred: predicted labels, 0 = NAF, 1 = AF (boolean arrays are accepted)

    Returns:
    dict with the following keys:
    - 'accuracy', 'precision', 'recall', 'f1_score';
    - 'confusion_matrix': 2x2, rows = true, columns = predicted, order NAF, AF;
    - its four cells: 'true_negatives', 'false_positives', 'false_negatives',
      'true_positives';
    - 'manual_verification': the same scores recomputed from the matrix.

    Raises:
    ValueError: if the inputs differ in shape or contain labels other than 0/1.
    """
    test_label = np.array(test_label)
    pred = np.array(pred)

    if test_label.shape != pred.shape:
        raise ValueError("test_label和pred的维度必须相同")

    # Unlabelled entries (-1) would otherwise yield a 3x3 matrix and break the
    # four-way unpacking below with an opaque error.
    if not (np.isin(test_label, (0, 1)).all() and np.isin(pred, (0, 1)).all()):
        raise ValueError("test_label和pred只能包含0(NAF)和1(AF)")

    acc = accuracy_score(test_label, pred)
    precision = precision_score(test_label, pred, pos_label=1)
    recall = recall_score(test_label, pred, pos_label=1)
    f1 = f1_score(test_label, pred, pos_label=1)

    # labels=[0, 1] always gives a 2x2 matrix, even when only one class occurs
    # in both arrays (otherwise cm is 1x1 and the unpacking fails).
    cm = confusion_matrix(test_label, pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # Independent recomputation as a sanity check on the sklearn scores
    acc_manual = (tp + tn) / (tp + tn + fp + fn)
    precision_manual = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall_manual = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_manual = 2 * (precision_manual * recall_manual) / (precision_manual + recall_manual) if (precision_manual + recall_manual) > 0 else 0

    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'confusion_matrix': cm,
        'true_negatives': tn,
        'false_positives': fp,
        'false_negatives': fn,
        'true_positives': tp,
        'manual_verification': {
            'accuracy': acc_manual,
            'precision': precision_manual,
            'recall': recall_manual,
            'f1_score': f1_manual
        }
    }

def print_results(results):
    """
    Print a text report of a calculate_metrics result dict.

    The report contains the four scores (4 decimals), the four confusion-matrix
    counts, and a small table with rows = true label and columns = predicted label.
    """
    banner = "=" * 50
    print(banner)
    print("二分类评估指标结果")
    print(banner)

    score_rows = (
        ("准确率 (Accuracy):     ", 'accuracy'),
        ("精确率 (Precision):    ", 'precision'),
        ("召回率 (Recall):       ", 'recall'),
        ("F1分数 (F1-Score):     ", 'f1_score'),
    )
    for label, key in score_rows:
        print(f"{label}{results[key]:.4f}")

    print("\n混淆矩阵:")
    count_rows = (
        ("真阴性 (TN): ", 'true_negatives'),
        ("假阳性 (FP): ", 'false_positives'),
        ("假阴性 (FN): ", 'false_negatives'),
        ("真阳性 (TP): ", 'true_positives'),
    )
    for label, key in count_rows:
        print(f"{label}{results[key]}")

    print("\n混淆矩阵 (行为真实标签，列为预测标签):")
    print("        预测")
    print("      NAF  AF")
    matrix = results['confusion_matrix']
    print(f"真实 NAF {matrix[0,0]:3d} {matrix[0,1]:3d}")
    print(f"    AF  {matrix[1,0]:3d} {matrix[1,1]:3d}")

# 1. Datasets and loaders (X_train holds only the 1000 labelled training rows)
Train_dataset = AF_1lead_cspc_Dataset(ecg_data=X_train, labels=labels_train[0:1000])
Test_dataset = AF_1lead_cspc_Dataset(ecg_data=X_test, labels=labels_test)

# drop_last=False: evaluation must cover every test record. drop_last=True would
# silently discard a final partial batch whenever len(Test_dataset) % 16 != 0.
testloader = DataLoader(Test_dataset, batch_size=16, drop_last=False, shuffle=False)
# NOTE(review): the labelled training rows are ordered (AF block, then NAF block),
# so shuffle=False yields single-class batches; enable shuffling before re-enabling training.
trainloader = DataLoader(Train_dataset, batch_size=16, drop_last=True, shuffle=False)

# 2. Hyper-parameters
n_classes = 150  # NOTE(review): presumably the output size of the pretrained ECGFounder head; confirm
gpu_id = 4
lr = 1e-4
weight_decay = 1e-5
early_stop_lr = 1e-5
Epochs = 5

device = torch.device('cuda:{}'.format(gpu_id) if torch.cuda.is_available() else 'cpu')


# 3. Pretrained single-lead model, optimiser, scheduler and loss
pth = './checkpoint/1_lead_ECGFounder.pth'
model = ft_1lead_ECGFounder(device, pth, n_classes, linear_prob=False)

optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# `verbose` is deprecated in recent PyTorch releases and has been dropped.
# The scheduler is never stepped here anyway.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.1, mode='max')
criterion = nn.BCEWithLogitsLoss()



# 4. Training state (only referenced by the commented-out training loop that follows)
model.train()
prog_iter_train = tqdm(trainloader, desc="Training", leave=False)

best_val_auroc = 0.
step = 0
current_lr = lr
all_res = []
pos_neg_counts = {}
total_steps_per_epoch = len(trainloader)
eval_steps = total_steps_per_epoch

# Fine-tuning loop (currently disabled):
# for epoch in range(Epochs):
#     ### train
#     for batch in tqdm(trainloader,desc='Training'):
#         input_x, input_y = tuple(t.to(device) for t in batch)
#         outputs = model(input_x)
#         loss = criterion(outputs, input_y)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         step += 1
#
#     # test
#     model.eval()
#
#     prog_iter_val = tqdm(testloader, desc="Validation", leave=False)
#     all_gt = []
#     all_pred_prob = []
#     with torch.no_grad():
#         for batch_idx, batch in enumerate(prog_iter_val):
#             input_x, input_y = tuple(t.to(device) for t in batch)
#             logits = model(input_x)
#             pred = torch.sigmoid(logits)
#             all_pred_prob.append(pred.cpu().data.numpy())
#             all_gt.append(input_y.cpu().data.numpy())
#     all_pred_prob = np.concatenate(all_pred_prob)
#     all_gt = np.concatenate(all_gt)
#     all_gt = np.array(all_gt)
#
#     af_pred_test = (all_pred_prob >= 0.5)
#     result_test = calculate_metrics_v1(all_gt,af_pred_test)
#     print_results(result_test)
#
#     # print(all_pred_prob.shape)
#     # print(all_gt.shape)
#     # print('Epoch {} step {}, val: {:.4f}'.format(epoch, step, res_val))


# Evaluate the pretrained model on the test set
model.eval()
prog_iter_test = tqdm(testloader, desc="Testing", leave=False)
all_gt = []         # ground-truth labels, one array per batch
all_pred_prob = []  # sigmoid scores, one array per batch
with torch.no_grad():
    for batch in prog_iter_test:
        input_x, input_y = tuple(t.to(device) for t in batch)
        # Named batch_prob (not `pred`) so the rule-based predictions held in the
        # global `pred` are not overwritten by the last batch's tensor.
        batch_prob = torch.sigmoid(model(input_x))
        all_pred_prob.append(batch_prob.cpu().numpy())
        all_gt.append(input_y.cpu().numpy())

all_pred_prob = np.concatenate(all_pred_prob)  # 2-D: (n_test, n_model_outputs)
all_gt = np.concatenate(all_gt)

# Optional ROC curve for the AF output:
# fpr, tpr, th = roc_curve(all_gt, all_pred_prob[:, 5])
# roc_auc = auc(fpr, tpr)
# print("AF auc is :{}".format(roc_auc))
# plt.figure()
# plt.plot(fpr ,tpr, label=f'AVB (AUC = {roc_auc:.2f})')

# NOTE(review): output column 5 is assumed to be the AF class of the pretrained head; confirm against its label map.
AF_CLASS_IDX = 5
AF_THRESHOLD = 0.45
af_pred = (all_pred_prob[:, AF_CLASS_IDX] >= AF_THRESHOLD).astype(int)  # shape (n_test,)



NameError: name 'X_train' is not defined

In [None]:
# 评估
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns

def calculate_metrics(test_label, pred):
    """
    Compute binary-classification metrics with AF (1) as the positive class.

    Parameters:
    test_label: ground-truth labels, 0 = NAF, 1 = AF
    pred: predicted labels, 0 = NAF, 1 = AF (boolean arrays are accepted)

    Returns:
    dict with the following keys:
    - 'accuracy', 'precision', 'recall', 'f1_score';
    - 'confusion_matrix': 2x2, rows = true, columns = predicted, order NAF, AF;
    - its four cells: 'true_negatives', 'false_positives', 'false_negatives',
      'true_positives';
    - 'manual_verification': the same scores recomputed from the matrix.

    Raises:
    ValueError: if the inputs differ in shape or contain labels other than 0/1.
    """
    test_label = np.array(test_label)
    pred = np.array(pred)

    if test_label.shape != pred.shape:
        raise ValueError("test_label和pred的维度必须相同")

    # Unlabelled entries (-1) would otherwise yield a 3x3 matrix and break the
    # four-way unpacking below with an opaque error.
    if not (np.isin(test_label, (0, 1)).all() and np.isin(pred, (0, 1)).all()):
        raise ValueError("test_label和pred只能包含0(NAF)和1(AF)")

    acc = accuracy_score(test_label, pred)
    precision = precision_score(test_label, pred, pos_label=1)
    recall = recall_score(test_label, pred, pos_label=1)
    f1 = f1_score(test_label, pred, pos_label=1)

    # labels=[0, 1] always gives a 2x2 matrix, even when only one class occurs
    # in both arrays (otherwise cm is 1x1 and the unpacking fails).
    cm = confusion_matrix(test_label, pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # Independent recomputation as a sanity check on the sklearn scores
    acc_manual = (tp + tn) / (tp + tn + fp + fn)
    precision_manual = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall_manual = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_manual = 2 * (precision_manual * recall_manual) / (precision_manual + recall_manual) if (precision_manual + recall_manual) > 0 else 0

    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'confusion_matrix': cm,
        'true_negatives': tn,
        'false_positives': fp,
        'false_negatives': fn,
        'true_positives': tp,
        'manual_verification': {
            'accuracy': acc_manual,
            'precision': precision_manual,
            'recall': recall_manual,
            'f1_score': f1_manual
        }
    }

def print_results(results):
    """
    Print a text report of a calculate_metrics result dict.

    The report contains the four scores (4 decimals), the four confusion-matrix
    counts, and a small table with rows = true label and columns = predicted label.
    """
    banner = "=" * 50
    print(banner)
    print("二分类评估指标结果")
    print(banner)

    score_rows = (
        ("准确率 (Accuracy):     ", 'accuracy'),
        ("精确率 (Precision):    ", 'precision'),
        ("召回率 (Recall):       ", 'recall'),
        ("F1分数 (F1-Score):     ", 'f1_score'),
    )
    for label, key in score_rows:
        print(f"{label}{results[key]:.4f}")

    print("\n混淆矩阵:")
    count_rows = (
        ("真阴性 (TN): ", 'true_negatives'),
        ("假阳性 (FP): ", 'false_positives'),
        ("假阴性 (FN): ", 'false_negatives'),
        ("真阳性 (TP): ", 'true_positives'),
    )
    for label, key in count_rows:
        print(f"{label}{results[key]}")

    print("\n混淆矩阵 (行为真实标签，列为预测标签):")
    print("        预测")
    print("      NAF  AF")
    matrix = results['confusion_matrix']
    print(f"真实 NAF {matrix[0,0]:3d} {matrix[0,1]:3d}")
    print(f"    AF  {matrix[1,0]:3d} {matrix[1,1]:3d}")

def plot_confusion_matrix(cm, class_names=['NAF', 'AF']):
    """
    Draw a confusion matrix as an annotated heatmap.

    Rows are true labels and columns are predicted labels. class_names labels
    both axes; the default list is only read, never mutated.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title('混淆矩阵')
    ax.set_xlabel('预测标签')
    ax.set_ylabel('真实标签')
    plt.show()

# Toy example for sanity-checking the metric helpers:
# test_label = [0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0]
# pred =       [0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0]

# Score the global `pred` (the rule-based R-R predictions) against the test labels,
# then report and plot.
results = calculate_metrics(test_label=labels_test, pred=pred)
print_results(results)
plot_confusion_matrix(cm=results['confusion_matrix'])



In [None]:
def calculate_metrics(test_label, pred):
    """
    Compute binary-classification metrics with AF (1) as the positive class.

    Parameters:
    test_label: ground-truth labels, 0 = NAF, 1 = AF
    pred: predicted labels, 0 = NAF, 1 = AF (boolean arrays are accepted)

    Returns:
    dict with the following keys:
    - 'accuracy', 'precision', 'recall', 'f1_score';
    - 'confusion_matrix': 2x2, rows = true, columns = predicted, order NAF, AF;
    - its four cells: 'true_negatives', 'false_positives', 'false_negatives',
      'true_positives';
    - 'manual_verification': the same scores recomputed from the matrix.

    Raises:
    ValueError: if the inputs differ in shape or contain labels other than 0/1.
    """
    test_label = np.array(test_label)
    pred = np.array(pred)

    if test_label.shape != pred.shape:
        raise ValueError("test_label和pred的维度必须相同")

    # Unlabelled entries (-1) would otherwise yield a 3x3 matrix and break the
    # four-way unpacking below with an opaque error.
    if not (np.isin(test_label, (0, 1)).all() and np.isin(pred, (0, 1)).all()):
        raise ValueError("test_label和pred只能包含0(NAF)和1(AF)")

    acc = accuracy_score(test_label, pred)
    precision = precision_score(test_label, pred, pos_label=1)
    recall = recall_score(test_label, pred, pos_label=1)
    f1 = f1_score(test_label, pred, pos_label=1)

    # labels=[0, 1] always gives a 2x2 matrix, even when only one class occurs
    # in both arrays (otherwise cm is 1x1 and the unpacking fails).
    cm = confusion_matrix(test_label, pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # Independent recomputation as a sanity check on the sklearn scores
    acc_manual = (tp + tn) / (tp + tn + fp + fn)
    precision_manual = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall_manual = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_manual = 2 * (precision_manual * recall_manual) / (precision_manual + recall_manual) if (precision_manual + recall_manual) > 0 else 0

    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'confusion_matrix': cm,
        'true_negatives': tn,
        'false_positives': fp,
        'false_negatives': fn,
        'true_positives': tp,
        'manual_verification': {
            'accuracy': acc_manual,
            'precision': precision_manual,
            'recall': recall_manual,
            'f1_score': f1_manual
        }
    }

def print_results(results):
    """
    Print a text report of a calculate_metrics result dict.

    The report contains the four scores (4 decimals), the four confusion-matrix
    counts, and a small table with rows = true label and columns = predicted label.
    """
    banner = "=" * 50
    print(banner)
    print("二分类评估指标结果")
    print(banner)

    score_rows = (
        ("准确率 (Accuracy):     ", 'accuracy'),
        ("精确率 (Precision):    ", 'precision'),
        ("召回率 (Recall):       ", 'recall'),
        ("F1分数 (F1-Score):     ", 'f1_score'),
    )
    for label, key in score_rows:
        print(f"{label}{results[key]:.4f}")

    print("\n混淆矩阵:")
    count_rows = (
        ("真阴性 (TN): ", 'true_negatives'),
        ("假阳性 (FP): ", 'false_positives'),
        ("假阴性 (FN): ", 'false_negatives'),
        ("真阳性 (TP): ", 'true_positives'),
    )
    for label, key in count_rows:
        print(f"{label}{results[key]}")

    print("\n混淆矩阵 (行为真实标签，列为预测标签):")
    print("        预测")
    print("      NAF  AF")
    matrix = results['confusion_matrix']
    print(f"真实 NAF {matrix[0,0]:3d} {matrix[0,1]:3d}")
    print(f"    AF  {matrix[1,0]:3d} {matrix[1,1]:3d}")

def plot_confusion_matrix(cm, class_names=['NAF', 'AF']):
    """
    Draw a confusion matrix as an annotated heatmap.

    Rows are true labels and columns are predicted labels. class_names labels
    both axes; the default list is only read, never mutated.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title('混淆矩阵')
    ax.set_xlabel('预测标签')
    ax.set_ylabel('真实标签')
    plt.show()


# Score the deep-model AF predictions against the test labels
print(all_gt.shape)
print(af_pred.shape)
# reshape(-1) flattens the label array without hard-coding the test-set size
# (previously reshape(10000)); presumably all_gt is (n_test, 1), so confirm.
# A separate name keeps the rule-based `results` from the earlier cell intact.
results_dl = calculate_metrics(all_gt.reshape(-1), af_pred)
print_results(results_dl)