In [7]:
import numpy as np
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, \
    roc_auc_score, f1_score,auc,roc_curve,multilabel_confusion_matrix
import os
from shutil import copyfile
import torch
import h5py
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import json
from util import save_checkpoint, save_reg_checkpoint, my_eval_with_dynamic_thresh
from finetune_model import ft_12lead_ECGFounder, ft_1lead_ECGFounder
from sklearn.model_selection import train_test_split
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

def evaluate_predictions(scores, th, labels):
    """
    Binarize scores at a threshold and report binary-classification metrics.

    Both inputs are flattened before they are compared, so ``[N]`` and
    ``[N, 1]`` shaped inputs give the same result. (Comparing an ``[N]``
    array against an ``[N, 1]`` array would broadcast to ``[N, N]`` and
    inflate every count.)

    Args:
        scores: Array-like of prediction scores, e.g. shape ``[N]`` or ``[N, 1]``.
        th: Decision threshold. Scores strictly greater than ``th`` are
            predicted as 1, all others as 0.
        labels: Array-like of ground-truth labels (1 = positive, 0 = negative)
            with the same number of elements as ``scores``.

    Returns:
        dict with the keys ``accuracy``, ``precision``, ``recall``,
        ``specificity``, ``f1_score`` and ``confusion_matrix``.
        ``confusion_matrix`` is a 2x2 array laid out as ``[[TP, FP], [FN, TN]]``.
        A metric whose denominator is zero is reported as 0.

    Raises:
        ValueError: If ``scores`` and ``labels`` contain a different number
            of elements.

    Side effects:
        Prints the four confusion counts and every scalar metric.
    """
    scores = np.asarray(scores).reshape(-1)
    labels = np.asarray(labels).reshape(-1)
    if scores.shape[0] != labels.shape[0]:
        raise ValueError(
            f"scores has {scores.shape[0]} elements but labels has {labels.shape[0]}"
        )

    predictions = (scores > th).astype(int)

    TP = np.sum((predictions == 1) & (labels == 1))
    FP = np.sum((predictions == 1) & (labels == 0))
    FN = np.sum((predictions == 0) & (labels == 1))
    TN = np.sum((predictions == 0) & (labels == 0))

    # Layout is [[TP, FP], [FN, TN]]; note this differs from sklearn's
    # confusion_matrix ordering.
    confusion_mat = np.array([[TP, FP], [FN, TN]])

    total = TP + FP + FN + TN
    accuracy = (TP + TN) / total if total != 0 else 0
    precision = TP / (TP + FP) if (TP + FP) != 0 else 0
    recall = TP / (TP + FN) if (TP + FN) != 0 else 0
    specificity = TN / (TN + FP) if (TN + FP) != 0 else 0
    # Local name deliberately differs from the imported sklearn `f1_score`.
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0

    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'f1_score': f1,
        'confusion_matrix': confusion_mat
    }
    print(f"TP is {TP}")
    print(f"FP is {FP}")
    print(f"FN is {FN}")
    print(f"TN is {TN}")

    for k, v in metrics.items():
        if k != 'confusion_matrix':
            print(f"{k}: {v:.4f}")

    return metrics

## Load data

In [11]:
from dataset import Chaoyang_Dataset

# ---- configuration ---------------------------------------------------------
num_lead = 12 # 12-lead ECG or 1-lead ECG

gpu_id = 4
batch_size = 1
test_batch_size = 16  # batch size used by the evaluation DataLoader below
lr = 1e-4
weight_decay = 1e-5
early_stop_lr = 1e-5
Epochs = 5

device = torch.device('cuda:{}'.format(gpu_id) if torch.cuda.is_available() else 'cpu')

tasks = ['class']
n_classes = 150

# Columns of interest in the 150-way label table:
# AF is the 6th entry (index 5), normal sinus rhythm is the 2nd entry (index 1).
AF_CLASS_IDX = 5
NORMAL_CLASS_IDX = 1

# ---- read Chaoyang data ----------------------------------------------------
# Two AF files followed by one non-AF file; the concatenation order below
# determines which rows are labelled AF.
with h5py.File('./data/chaoyan_AF_6x2.h5', 'r') as f:
    data1 = f['data'][:]
    label1 = f['labels'][:]

with h5py.File('./data/chaoyan_AF_12x1.h5', 'r') as f:
    data2 = f['data'][:]
    label2 = f['labels'][:]

with h5py.File('./data/chaoyan_NAF.h5', 'r') as f:
    data3 = f['data'][:]
    label3 = f['labels'][:]

data = np.concatenate((data1,data2,data3),axis=0)
# Swap axes 1 and 2; the printed shape after this step is (400, 12, 1000).
data = np.moveaxis(data, 1, 2)
# The labels stored in the files (label1/label2/label3) are not used; the
# label matrix is rebuilt below from the file each recording came from.
#label = np.concatenate((label1,label2,label3),axis=0)

# Derive the AF / non-AF split from the loaded arrays instead of hard-coding
# 200 / 400, so the labels stay aligned with the data if a file changes size.
n_af = data1.shape[0] + data2.shape[0]
label = np.zeros((data.shape[0], n_classes))
label[:n_af, AF_CLASS_IDX] = 1
label[n_af:, NORMAL_CLASS_IDX] = 1

# print(label[0])
# print(label[200])
tensor_label = torch.from_numpy(label)

print(data.shape)
print(label.shape)

ECGdataset = Chaoyang_Dataset(data=data,label=tensor_label)
pth = './checkpoint/12_lead_ECGFounder.pth'
model = ft_12lead_ECGFounder(device, pth, n_classes,linear_prob=False)

# Evaluation loader: keep the original order (shuffle=False) and keep every
# sample (drop_last=False) so row i of the predictions matches recording i.
testloader = DataLoader(ECGdataset, batch_size=test_batch_size, drop_last=False, shuffle=False)

(400, 12, 1000)
(400, 150)


  checkpoint = torch.load(pth, map_location=device)


## Test on the Chaoyang Hospital data

In [18]:
  
# Evaluate AF detection on every recording served by `testloader`
# (400 recordings in the run captured below).
AF_COL = 5            # label/prediction column that holds the AF class
AF_THRESHOLD = 0.62   # decision threshold passed to evaluate_predictions

model.eval()
prog_iter_test = tqdm(testloader, desc="Testing", leave=False)
gt_batches = []    # ground-truth label batches
prob_batches = []  # sigmoid score batches
with torch.no_grad():
    for batch in prog_iter_test:
        input_x, input_y = tuple(t.to(device) for t in batch)
        pred = torch.sigmoid(model(input_x))  # model output -> per-class probability
        prob_batches.append(pred.cpu().numpy())
        gt_batches.append(input_y.cpu().numpy())

all_pred_prob = np.concatenate(prob_batches)  # scores, one row per recording
all_gt = np.concatenate(gt_batches)           # labels, one row per recording

# ROC / AUC for the AF column
fpr, tpr, th = roc_curve(all_gt[:, AF_COL], all_pred_prob[:, AF_COL])
roc_auc = auc(fpr, tpr)
print("AF auc is :{}".format(roc_auc))

# Boolean AF decision at a fixed 0.5 cut-off, shape (N,). It is only consumed
# by the optional CSV export below; the metrics use AF_THRESHOLD instead.
af_pred = (all_pred_prob[:, AF_COL] >= 0.5)

# reshape(-1, 1) instead of a hard-coded row count, so the call keeps working
# when the number of evaluated recordings changes.
ecg_founder_400 = evaluate_predictions(
    all_pred_prob[:, AF_COL].reshape(-1, 1),
    th=AF_THRESHOLD,
    labels=all_gt[:, AF_COL].reshape(-1, 1),
)

# np.savetxt("./chaoyang.csv",af_pred, delimiter=",", fmt="%.6f")

  return (torch.tensor(signal, dtype=torch.float), torch.tensor(self.label[index], dtype=torch.float))
                                                        

AF auc is :0.6606000000000001
TP is 144
FP is 94
FN is 56
TN is 106
accuracy: 0.6250
precision: 0.6050
recall: 0.7200
specificity: 0.5300
f1_score: 0.6575




In [21]:
# Evaluate AF detection on the "12 x 1" subset: every recording except the
# first SUBSET_START rows (300 recordings in the run captured below).
# NOTE(review): presumably the skipped rows are the recordings loaded from
# the 6x2 AF file, which is concatenated first - confirm its size.
AF_COL = 5                  # label/prediction column that holds the AF class
SUBSET_START = 100          # number of leading recordings to exclude
AF_THRESHOLD_SUBSET = 0.6   # decision threshold passed to evaluate_predictions

# Inference is run over the whole `testloader`; the subset is taken afterwards.
model.eval()
prog_iter_test = tqdm(testloader, desc="Testing", leave=False)
gt_batches = []    # ground-truth label batches
prob_batches = []  # sigmoid score batches
with torch.no_grad():
    for batch in prog_iter_test:
        input_x, input_y = tuple(t.to(device) for t in batch)
        pred = torch.sigmoid(model(input_x))  # model output -> per-class probability
        prob_batches.append(pred.cpu().numpy())
        gt_batches.append(input_y.cpu().numpy())

all_pred_prob = np.concatenate(prob_batches)  # scores, one row per recording
all_gt = np.concatenate(gt_batches)           # labels, one row per recording
print(all_gt.shape)

# ROC / AUC for the AF column, restricted to the subset
fpr, tpr, th = roc_curve(all_gt[SUBSET_START:, AF_COL], all_pred_prob[SUBSET_START:, AF_COL])
roc_auc = auc(fpr, tpr)
print("AF auc is :{}".format(roc_auc))

# Boolean AF decision at a fixed 0.5 cut-off, shape (N - SUBSET_START,). It is
# only consumed by the optional CSV export below; the metrics use
# AF_THRESHOLD_SUBSET instead.
af_pred = (all_pred_prob[SUBSET_START:, AF_COL] >= 0.5)

# reshape(-1, 1) instead of a hard-coded row count, so the call keeps working
# when the number of evaluated recordings changes.
ecg_founder_300 = evaluate_predictions(
    all_pred_prob[SUBSET_START:, AF_COL].reshape(-1, 1),
    th=AF_THRESHOLD_SUBSET,
    labels=all_gt[SUBSET_START:, AF_COL].reshape(-1, 1),
)

# np.savetxt("./chaoyang.csv",af_pred, delimiter=",", fmt="%.6f")

  return (torch.tensor(signal, dtype=torch.float), torch.tensor(self.label[index], dtype=torch.float))
                                                        

(400, 150)
AF auc is :0.70285
TP is 86
FP is 114
FN is 14
TN is 86
accuracy: 0.5733
precision: 0.4300
recall: 0.8600
specificity: 0.4300
f1_score: 0.5733


