In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import sys
from torch.utils.data import DataLoader
import ChannelBasedTransformer_fft
from sklearn.metrics import accuracy_score,precision_score, recall_score, f1_score, average_precision_score, roc_auc_score, precision_recall_curve, auc
import os
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import datetime
from mne.io import concatenate_raws, read_raw_edf
from utils import fft
import seaborn as sns
from scipy.interpolate import interp1d
import patient_information
import ChannelBasedTransformer_fft_PE
from sklearn.model_selection import KFold
import re



class CustomDataset(Dataset):
    """Map-style dataset that pairs each sample with the target at the same position.

    Args:
        data: indexable collection of samples; ``len(data)`` defines the dataset size.
        targets: indexable collection of labels aligned position-by-position with ``data``.
    """

    def __init__(self, data, targets):
        self.data, self.targets = data, targets

    def __len__(self):
        # Size is taken from `data`; `targets` is assumed to have the same length.
        return len(self.data)

    def __getitem__(self, index):
        # Returns a (sample, target) tuple, which the default DataLoader collate expects.
        return self.data[index], self.targets[index]


def precision_at_k(y_true, k):
    """Precision@k: fraction of relevant items among the first ``k`` entries.

    Args:
        y_true: ranked sequence of 0/1 relevance flags (best-ranked first).
        k: cut-off rank.

    Returns:
        Number of relevant items in ``y_true[:k]`` divided by ``k``.
    """
    hits_in_top_k = np.sum(y_true[:k])
    return hits_in_top_k / k

def recall_at_k(y_true, k):
    """Recall@k: share of all relevant items that appear in the first ``k`` entries.

    Args:
        y_true: ranked sequence of 0/1 relevance flags (best-ranked first).
        k: cut-off rank.

    Returns:
        ``sum(y_true[:k]) / sum(y_true)`` (NaN/inf if ``y_true`` has no relevant items).
    """
    hits_in_top_k = np.sum(y_true[:k])
    total_relevant = np.sum(y_true)
    return hits_in_top_k / total_relevant

def f1_at_k(y_true, k):
    """F1@k: harmonic mean of precision@k and recall@k.

    Returns 0 when precision + recall is not positive (this includes the NaN case
    where ``y_true`` contains no relevant items).
    """
    p = precision_at_k(y_true, k)
    r = recall_at_k(y_true, k)
    denominator = p + r
    if denominator > 0:
        return 2 * (p * r) / denominator
    return 0

def averaged_precision_at_k(y_true, k):
    """Mean of precision@1 ... precision@k over a ranked 0/1 relevance list.

    Each term is ``sum(y_true[:rank]) / rank`` (the same value precision_at_k returns).
    """
    running = 0
    for rank in range(1, k + 1):
        running += np.sum(y_true[:rank]) / rank
    return running / k

In [None]:
# Patients to process: names of classes defined in `patient_information` (fill in before running).
patient_name_list = list()
# NOTE(review): data_agu_list is not referenced anywhere else in this notebook section.
data_agu_list = list()

In [None]:
# Print the channel names of every patient in patient_name_list, read from the EDF header
# (preload=False, so no signal data is loaded).
# NOTE: `patient_name`, `patient`, `raw`, ... leak out of this loop and the following cells use
# the LAST patient of the list, so this cell must run before them.
# The position-encoding file is no longer loaded here: it was unused in this cell, crashed for
# patients without such a file, and the evaluation cell reloads it when patient.PE == 1.
for patient_name in patient_name_list:
    print(patient_name)
    patient = getattr(patient_information, patient_name)()
    filename = next(iter(patient.seizure_start_dict.keys()))
    useless_chan = patient_information.useless_chan
    exclude = useless_chan + patient.exclude
    # These two patients' recordings are read from S1.edf instead of the first seizure-file name.
    edf_stem = 'S1' if patient_name in ('ZhangQian', 'XuJunwei') else filename
    raw = read_raw_edf(f'../{patient_name}/{edf_stem}.edf', preload=False, encoding='latin1', exclude=exclude)
    ch_names = raw.ch_names
    print(ch_names)

In [None]:
# 加载可解释数据集
# Load the interpretability setup for the current patient (the last `patient_name` left over
# from the loop above) and normalize its channel names.
patient = getattr(patient_information, patient_name)()
filename = next(iter(patient.seizure_start_dict.keys()))
useless_chan = patient_information.useless_chan
exclude = patient.exclude
exclude = useless_chan + exclude


def normalize_channel_name(name):
    """Normalize one channel label to the 'POL <electrode>' form.

    Drops an 'EEG ' marker, adds a 'POL ' prefix when missing, and removes everything from
    the first '-' on (e.g. a reference suffix such as 'A1-Ref' -> 'A1').
    """
    name = name.replace('EEG ', '')
    if not name.startswith('POL '):
        name = 'POL ' + name
    return re.sub('-.*', '', name)


# Build a new list rather than rewriting patient.ch_names in place through an alias.
# NOTE(review): later cells look patient.EZ names up in ch_names, so EZ is presumably already
# in this 'POL <name>' form — confirm.
ch_names = [normalize_channel_name(name) for name in patient.ch_names]


In [None]:
# Show the normalized channel names.
print(ch_names)

In [None]:
# Number of channels (shown as the cell output); later cells size score arrays with it.
len(ch_names)

In [None]:
# 数据准备 发作   ***评价时推荐使用数据增强之前的原始数据，即窗宽1s，步长1s***
# Seizure-period data for evaluation. For evaluation, prefer the original (non-augmented) data:
# 1 s windows with a 1 s step.
dataset_path = f'../{patient_name}/preprocessed_data_1s_1s(Z-Score)/'
# Reorder the stored axes with (2, 0, 1); later cells index the result as
# (windows, channels, samples).
data = np.load(os.path.join(dataset_path, 'seizure/all_data.npy')).transpose(2, 0, 1)
labels = np.load(os.path.join(dataset_path, 'seizure/label.npy'))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model hyper-parameters; these should match the ones the saved checkpoints were trained with.
batch_size = 32
num_heads = 4
embed_dim = 64
mlp_ratio = 2
dropout = 0
num_blocks = 2
num_classes = 2

# Signal / spectrum settings.
fs = 1000                                      # sampling rate passed to fft()
n_chans = data.shape[1]
fft_points = 2 * data.shape[2]                 # FFT length: twice the samples per window
freq_used = 100                                # keep only content below freq_used Hz
fft_dim = int(freq_used * fft_points / fs)     # number of FFT bins below freq_used

In [None]:
# Quantitative evaluation of the localization: run every fold's trained model on the seizure
# data, turn its attention into per-channel scores and average those over the folds.
# NOTE(review): original reminder — check whether high frequencies should be amplified, and
# mind the random seed.
if patient.PE == 1:
    position_encoding = np.load(f'../position_encoding_{patient_name}.npy')

# 'mid': count, per channel, the samples whose attention exceeds the global median;
# any other value: per-channel mean attention.
scoring_method = 'mid'

main_folder_path = f'../{patient_name}'
folds = ['fold1', 'fold2', 'fold3', 'fold4', 'fold5']
# Every fold's model is evaluated on the same full seizure dataset.
dataset_dict = {f'dataset{n}': CustomDataset(data, labels) for n in range(1, len(folds) + 1)}
frequency = np.zeros([len(folds), len(ch_names)])
average_score = np.zeros([len(folds), len(ch_names)])


def _min_max_normalize(values):
    """Scale an array to [0, 1]; a constant array maps to zeros instead of NaN (0/0)."""
    value_range = np.max(values) - np.min(values)
    if value_range == 0:
        return np.zeros_like(values, dtype=float)
    return (values - np.min(values)) / value_range


for i, fold in enumerate(folds, start=1):
    current_fold_path = os.path.join(main_folder_path, fold)
    if patient.PE == 1:
        model = ChannelBasedTransformer_fft_PE.ChannelBasedTransformer_fft(num_heads, fft_dim, embed_dim, mlp_ratio, dropout, n_chans, num_blocks, num_classes, position_encoding, pool = 'cls')
    else:
        model = ChannelBasedTransformer_fft.ChannelBasedTransformer_fft(num_heads, fft_dim, embed_dim, mlp_ratio, dropout, n_chans, num_blocks, num_classes, pool = 'cls')
    # map_location lets a checkpoint saved from CUDA load when `device` fell back to the CPU.
    model.load_state_dict(torch.load(os.path.join(current_fold_path, 'train_fft.pth'), map_location=device))
    model.eval()
    model.to(device)
    dataset = dataset_dict[f'dataset{i}']
    data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=False)

    all_predictions = []
    all_targets = []
    all_attention_scores_real_testset = []
    all_attention_scores_imag_testset = []
    all_attention_scores_abs_testset = []

    with torch.no_grad():
        # `batch_labels` (not `labels`): reusing the global name overwrote the label array that
        # dataset_dict is built from, which broke re-running this cell.
        for inputs, batch_labels in data_loader:
            _, x_real, x_imag, x_abs = fft(inputs, fs, fft_points, freq_used)
            x_real = torch.Tensor(x_real).to(device)
            x_imag = torch.Tensor(x_imag).to(device)
            x_abs = torch.Tensor(x_abs).to(device)
            outputs, attn_weights_real, attn_weights_imag, attn_weights_abs = model(x_real, x_imag, x_abs)
            _, predicted = torch.max(outputs, 1)
            all_predictions.extend(predicted.tolist())
            all_targets.extend(batch_labels.tolist())
            # Move attention maps to the CPU per batch so they do not pile up in GPU memory.
            all_attention_scores_real_testset.append(attn_weights_real.cpu())
            all_attention_scores_imag_testset.append(attn_weights_imag.cpu())
            all_attention_scores_abs_testset.append(attn_weights_abs.cpu())

    # Classification performance of this fold's model.
    accuracy = accuracy_score(all_targets, all_predictions)
    precision = precision_score(all_targets, all_predictions, average='binary')
    recall = recall_score(all_targets, all_predictions, average='binary')
    f1 = f1_score(all_targets, all_predictions, average='binary')
    print(f'{fold}: accuracy={accuracy:.4f} precision={precision:.4f} recall={recall:.4f} f1={f1:.4f}')

    # Sum of real- and imaginary-part attention, averaged over axis 1.
    combined_tensor = torch.cat(all_attention_scores_real_testset, dim=0) + torch.cat(all_attention_scores_imag_testset, dim=0)
    combined_tensor = combined_tensor.cpu().numpy()
    combined_tensor_mean = np.mean(combined_tensor, axis=1)
    attention_score = combined_tensor_mean
    # NOTE(review): min/max are taken BEFORE the cls row is dropped, so the cls token's
    # attention takes part in the scaling — kept as originally written; confirm it is intended.
    min_val = np.min(attention_score)
    max_val = np.max(attention_score)
    # Transpose and drop the cls token row: rows now follow ch_names, columns the dataset samples.
    attention_score = attention_score.T[1:, :]
    scaled_attention_score = (attention_score - min_val) / (max_val - min_val)
    if scoring_method == 'mid':
        # Median method: how often each channel is above the global median.
        q1 = np.percentile(scaled_attention_score, 50)
        frequency_temp = _min_max_normalize(np.sum(scaled_attention_score > q1, axis=1))
        frequency[i - 1, :] = frequency_temp
        result_dict_mid = dict(zip(ch_names, frequency_temp))
    else:
        # Mean method: average attention per channel.
        channel_total = _min_max_normalize(np.mean(scaled_attention_score, axis=1))
        average_score[i - 1, :] = channel_total
        result_dict_ave = dict(zip(ch_names, channel_total))

# Average the per-fold channel scores (stds keeps the spread across folds). Both scoring
# methods are aggregated; previously only 'mid' was, and the sort below raised NameError otherwise.
if scoring_method == 'mid':
    stds = np.std(frequency, axis=0)
    frequency = np.mean(frequency, axis=0)
    result_dict_mid = dict(zip(ch_names, frequency))
    fold_mean_scores = result_dict_mid
else:
    stds = np.std(average_score, axis=0)
    average_score = np.mean(average_score, axis=0)
    result_dict_ave = dict(zip(ch_names, average_score))
    fold_mean_scores = result_dict_ave
sorted_dict = dict(sorted(fold_mean_scores.items(), key=lambda item: item[1], reverse=True))
sorted_dict

In [None]:
# 计算PLV
# Per-channel scaled attention of the last evaluated fold; rows of scaled_attention_score follow ch_names.
# NOTE(review): despite its name, PLV_dict holds attention scores, not phase-locking values.
PLV_dict = {name: scores.tolist() for name, scores in zip(ch_names, scaled_attention_score)}

# Median threshold over all scaled attention values.
q1 = np.percentile(scaled_attention_score, 50)

# For each EZ channel present in ch_names: the column (1 s window) indices where its attention exceeds q1.
positions = {
    name: [t for t, value in enumerate(PLV_dict[name]) if value > q1]
    for name in patient.EZ
    if name in PLV_dict
}

# Union of those indices, sorted so the selection order is deterministic.
key_second = sorted(set().union(*positions.values()))

# Channel indices of the EZ. EZ names missing from ch_names are skipped here too (consistent
# with `positions` above) instead of raising ValueError, and are reported.
missing_ez = [name for name in patient.EZ if name not in ch_names]
if missing_ez:
    print('EZ channels not found in ch_names (ignored):', missing_ez)
key_channel_position = [ch_names.index(name) for name in patient.EZ if name in ch_names]

In [None]:
# 计算PLV数据准备
# Seizure and pre-seizure windows for the PLV analysis, with the same (2, 0, 1) axis reordering
# as `data` above. The patient folder was missing from this path; the f-string had no
# placeholder, unlike the seizure-data path used for evaluation.
dataset_path = f'../{patient_name}/preprocessed_data_1s_1s(Z-Score)/'
seizure_data = np.transpose(np.load(os.path.join(dataset_path, 'seizure/all_data.npy')), (2, 0, 1))
preseizure_data = np.transpose(np.load(os.path.join(dataset_path, 'preseizure/all_data.npy')), (2, 0, 1))

In [None]:
# 计算PLV
from scipy.signal import hilbert

# 相位锁定值（Phase Locking Value, PLV）单个值 无向
def phase_locking_value(x, y):
    """Undirected phase-locking value (PLV) between two 1-D signals.

    The instantaneous phase of each signal is the angle of its analytic signal
    (Hilbert transform). PLV is the magnitude of the mean unit phasor of the phase
    difference: 1 for a constant phase lag, near 0 for unrelated phases.
    """
    phase_x = np.angle(hilbert(x))
    phase_y = np.angle(hilbert(y))
    unit_phasors = np.exp(1j * (phase_x - phase_y))
    return np.abs(unit_phasors.sum()) / len(phase_x)

def plv_matrix(window):
    """Pairwise PLV between all channels of one window.

    Args:
        window: 2-D array (channels, samples).

    Returns:
        Symmetric (channels, channels) array whose [j, k] entry equals
        phase_locking_value(window[j], window[k]) up to floating-point rounding.
    """
    phasors = np.exp(1j * np.angle(hilbert(window, axis=-1)))
    # |sum_t exp(i(phi_j - phi_k))| / T for every channel pair in one product.
    return np.abs(phasors @ phasors.conj().T) / window.shape[-1]


def plv_over_windows(windows, n_channels):
    """One PLV matrix per window of a (windows, channels, samples) array."""
    if windows.shape[1] != n_channels:
        raise ValueError(f'data has {windows.shape[1]} channels but ch_names has {n_channels}')
    plv = np.zeros([windows.shape[0], n_channels, n_channels])
    for w in range(windows.shape[0]):
        plv[w] = plv_matrix(windows[w])
    return plv


# Previously the loop ran over the channel axis (shape[1]) while filling the window axis of the
# result, and took PLV over seizure_data[:, i, j] (every window at one sample) instead of
# seizure_data[i, j, :]. The bare `.astype(np.float16)` lines discarded their result and are dropped.
sei_PLV = plv_over_windows(seizure_data, len(ch_names))
pre_PLV = plv_over_windows(preseizure_data, len(ch_names))


In [None]:
# Compare PLV within the EZ, within the non-EZ channels and between the two groups, during the
# key seizure windows and during the pre-seizure period.
# NOTE(review): the `intra_*` arrays hold EZ x non-EZ (between-group) pairs.
noez_channel = [c for c in range(sei_PLV.shape[1]) if c not in key_channel_position]
ez_plv_seizure = sei_PLV[np.ix_(key_second, key_channel_position, key_channel_position)]
noez_plv_seizure = sei_PLV[np.ix_(key_second, noez_channel, noez_channel)]
intra_seizure = sei_PLV[np.ix_(key_second, key_channel_position, noez_channel)]
ez_plv_pre = pre_PLV[:, key_channel_position][:, :, key_channel_position]
noez_plv_pre = pre_PLV[:, noez_channel][:, :, noez_channel]
intra_pre = pre_PLV[:, key_channel_position][:, :, noez_channel]


def mean_without_self_pairs(plv_block):
    """Mean of a (windows, n, n) within-group PLV stack over its off-diagonal entries.

    The diagonal is each channel's PLV with itself, which is always 1; including it biases
    within-group means upward, most strongly for small groups such as the EZ.
    """
    off_diagonal = ~np.eye(plv_block.shape[-1], dtype=bool)
    return np.mean(plv_block[:, off_diagonal])


print('ez_plv_seizure:', mean_without_self_pairs(ez_plv_seizure))
print('noez_plv_seizure:', mean_without_self_pairs(noez_plv_seizure))
print('intra_seizure:', np.mean(intra_seizure))   # disjoint groups: no self-pairs to drop
print('ez_plv_pre:', mean_without_self_pairs(ez_plv_pre))
print('noez_plv_pre:', mean_without_self_pairs(noez_plv_pre))
print('intra_pre:', np.mean(intra_pre))

In [None]:
# 热力值显示
def heatmap_visualization(scaled_attention_score, data, start_time, end_time, start_channel, end_channel, ch_names, fs = 1000, downsample_scale = 5, save_path = None):
    """Overlay the per-channel attention heat map on the raw traces for a range of windows.

    Args:
        scaled_attention_score: 2-D array (channels, windows); the colour scale is fixed to [0, 1].
        data: 3-D array (windows, channels, samples per window).
        start_time, end_time: window range [start_time, end_time) to draw; the x-axis is
            labelled in seconds starting at start_time.
        start_channel, end_channel: channel range to draw, end_channel INCLUSIVE.
        ch_names: channel labels, indexed like the channel axis.
        fs: sampling rate of `data` in Hz (used for the x-axis ticks).
        downsample_scale: plot every downsample_scale-th sample (previously overridden by a
            hard-coded 5, so the argument had no effect).
        save_path: if given, save the figure there (creating the folder) and close it instead
            of showing it.
    """
    samples_per_window = data.shape[-1]
    rows_per_channel = 10                 # heat-map rows allotted to each channel
    half_row = rows_per_channel // 2      # traces are drawn at the centre of each band

    window_scores = scaled_attention_score[:, start_time:end_time]

    # Linear interpolation along time: one heat value per raw sample.
    x_old = np.linspace(0, 1, window_scores.shape[1])
    x_new = np.linspace(0, 1, window_scores.shape[1] * samples_per_window)
    expanded_cols = np.array([interp1d(x_old, row, kind='linear')(x_new) for row in window_scores])

    # Linear interpolation along channels: rows_per_channel rows per channel.
    y_old = np.linspace(0, 1, window_scores.shape[0])
    y_new = np.linspace(0, 1, window_scores.shape[0] * rows_per_channel)
    expanded_full = np.array([interp1d(y_old, expanded_cols[:, c], kind='linear')(y_new) for c in range(expanded_cols.shape[1])]).T
    # Mirror-pad the top half band so channel traces and heat-map bands line up.
    expanded_full = np.vstack((np.flipud(expanded_full[:half_row, :]), expanded_full))

    # Raw traces of the selected windows, concatenated in time: (channels, windows * samples).
    segment = data[start_time:end_time, :, :]
    traces_full = np.transpose(segment, (1, 0, 2)).reshape(segment.shape[1], -1)

    plt.figure(figsize=(20*(end_time-start_time)/20, 8*(end_channel-start_channel)/30))
    traces = traces_full[start_channel:end_channel+1, ::downsample_scale]
    plot_fs = fs / downsample_scale
    heat = expanded_full[start_channel*rows_per_channel:(end_channel+1)*rows_per_channel, ::downsample_scale]
    ax = sns.heatmap(heat, cmap="OrRd", shading='gouraud', vmin=0, vmax=1, cbar_kws={'label': 'Attention Score', 'orientation': 'horizontal', "pad":0.02, 'shrink':0.5})
    cbar = ax.collections[0].colorbar
    cbar.set_ticks([0, 1])
    cbar.set_ticklabels(['Low', 'High'])

    # One trace per channel, offset vertically to the centre of its band.
    for row in range(traces.shape[0]):
        plt.plot(traces[row, :]*1.25 + half_row + row*rows_per_channel, color='black', linewidth=0.5)
    plt.margins(x=0, y=0)
    plt.xlabel("time(s)")
    plt.title("Time-Attention Heatmap")
    y_ticks = [half_row + row*rows_per_channel for row in range(traces.shape[0])]
    plt.yticks(y_ticks, labels=ch_names[start_channel:end_channel+1])
    n_seconds = int(traces.shape[1] / plot_fs)
    x_ticks = np.linspace(0, int(traces.shape[1]), n_seconds + 1)
    # Label ticks from start_time (previously always from 0), matching the batch cell.
    x_labels = [str(t) for t in range(start_time, start_time + n_seconds + 1)]
    plt.xticks(x_ticks, labels=x_labels, rotation=0)
    plt.tick_params(direction='out')
    if save_path is None:
        plt.show()
    else:
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        plt.savefig(save_path, bbox_inches='tight')
        plt.close()
# Draw the first 20 windows (seconds) for every channel.
start_time = 0
end_time = 20
start_channel = 0
end_channel = len(ch_names) - 1   # inclusive index of the last channel (was hard-coded to 167)
heatmap_visualization(scaled_attention_score, data, start_time, end_time, start_channel, end_channel, ch_names, fs = fs, downsample_scale = 5)

In [None]:
# 批量绘制热力图并保存
# Render the time-attention heat map for every full 20-window segment and save each to disk.
# ':' is not allowed in Windows file names, so the timestamp avoids it.
formatted_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
segment_length = 20                      # windows (seconds) per figure
samples_per_window = data.shape[-1]      # raw samples per window (previously hard-coded 1000)
start_channel = 0
end_channel = len(ch_names) - 1          # inclusive index of the last channel (was hard-coded)
output_dir = f'../Time_attention_map/{patient_name}'
os.makedirs(output_dir, exist_ok=True)   # savefig fails when the folder does not exist
# NOTE(review): a trailing partial segment (data.shape[0] % segment_length windows) is not drawn.

for i in range(data.shape[0] // segment_length):
    start_time = i * segment_length
    end_time = (i + 1) * segment_length
    original_data = scaled_attention_score[:, start_time:end_time]

    # Linear interpolation along time: one heat value per raw sample.
    x_old = np.linspace(0, 1, original_data.shape[1])
    x_new = np.linspace(0, 1, original_data.shape[1] * samples_per_window)
    expanded_data_cols = np.array([interp1d(x_old, row, kind='linear')(x_new) for row in original_data])

    # Linear interpolation along channels: 10 rows per channel.
    y_old = np.linspace(0, 1, original_data.shape[0])
    y_new = np.linspace(0, 1, original_data.shape[0] * 10)
    expanded_data1 = np.array([interp1d(y_old, expanded_data_cols[:, c], kind='linear')(y_new) for c in range(expanded_data_cols.shape[1])]).T
    # Mirror the first 5 rows on top so each trace (drawn at 5 + 10*k) is centred in its band.
    expanded_data1 = np.vstack((np.flipud(expanded_data1[:5, :]), expanded_data1))

    # Raw traces of the segment, concatenated in time: (channels, windows * samples).
    data_temp = data[start_time:end_time, :, :]
    reshaped_data1 = np.transpose(data_temp, (1, 0, 2)).reshape(data_temp.shape[1], -1)

    fig = plt.figure(figsize=(20*(end_time-start_time)/20, 8*(end_channel-start_channel)/30))
    downsample_scale = 5
    reshaped_data = reshaped_data1[start_channel:end_channel+1, ::downsample_scale]
    fs1 = fs / downsample_scale
    expanded_data = expanded_data1[start_channel*10:end_channel*10+10, ::downsample_scale]
    ax = sns.heatmap(expanded_data, cmap="OrRd", shading='gouraud', vmin=0, vmax=1, cbar_kws={'label': 'Attention Score', 'orientation': 'horizontal', "pad":0.02, 'shrink':0.5})
    cbar = ax.collections[0].colorbar
    cbar.set_ticks([0, 1])
    cbar.set_ticklabels(['Low', 'High'])
    # One trace per channel, offset vertically to the centre of its band.
    for k in range(reshaped_data.shape[0]):
        plt.plot(reshaped_data[k, :]*1.25 + 5 + k*10, color='black', linewidth=0.5)
    plt.margins(x=0, y=0)
    plt.xlabel("time(s)")
    plt.title("Time-Attention Heatmap")
    y_ticks = [5 + k*10 for k in range(reshaped_data.shape[0])]
    plt.yticks(y_ticks, labels=ch_names[start_channel:end_channel+1])
    n_seconds = int(reshaped_data.shape[1] / fs1)
    x_ticks = np.linspace(0, int(reshaped_data.shape[1]), n_seconds + 1)
    x_labels = [str(k) for k in range(start_time, start_time + n_seconds + 1)]
    plt.xticks(x_ticks, labels=x_labels, rotation=0)
    plt.tick_params(direction='out')
    plt.tight_layout()

    plt.savefig(os.path.join(output_dir, f'{start_time}~{end_time}s {formatted_time}.png'), bbox_inches='tight')
    plt.close(fig)