In [9]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn import metrics
import torch.nn.functional as F
import numpy as np
import math

# FocalLoss_v2, copied inline from focalLoss.py (not imported)
class FocalLoss_v2(nn.Module):
    """Multi-class focal loss with optional per-class weights.

    For each position i with true class y_i and softmax probability p_i of that
    class, the loss is ``-alpha[y_i] * (1 - p_i) ** gamma * log(p_i)``; the
    returned value is the mean over all positions.

    Args:
        num_class: number of classes; only used to build the default
            all-ones ``alpha``.
        gamma: focusing exponent; ``gamma=0`` reduces to alpha-weighted
            cross-entropy.
        alpha: per-class weights indexed by target label. Any tensor or
            sequence accepted by ``torch.as_tensor`` (e.g. ``[0.6, 0.4]``);
            tensors are kept as-is. Defaults to ``torch.ones(num_class)``.
    """

    def __init__(self, num_class=2, gamma=2, alpha=None):
        super(FocalLoss_v2, self).__init__()
        self.gamma = gamma
        self.num_class = num_class
        if alpha is None:
            self.alpha = torch.ones(num_class)
        else:
            # Accept plain lists/tuples as well as tensors (a list could not be
            # indexed by a label tensor in forward()).
            self.alpha = torch.as_tensor(alpha)

    def forward(self, logit, target):
        """Return the mean focal loss.

        Args:
            logit: unnormalized class scores, as accepted by ``F.cross_entropy``
                together with the flattened ``target`` (typically (N, C)).
            target: integer class labels; flattened to 1-D with ``view(-1)``.
        """
        target = target.view(-1)
        # Gather the per-position weights on the logits' device rather than
        # copying the targets to the CPU every step (which forces a sync).
        alpha = self.alpha.to(logit.device)[target.long()]
        logpt = -F.cross_entropy(logit, target, reduction='none')
        pt = torch.exp(logpt)
        focal_loss = -(alpha * (1 - pt) ** self.gamma) * logpt
        return focal_loss.mean()

# masked_softmax and its helper, copied inline from utility.py (not imported)
def create_src_lengths_mask(batch_size: int, src_lengths, max_src_len=None):
    """Return a (batch_size, max_src_len) int32 mask of valid positions.

    Entry ``[b, t]`` is 1 when ``t < src_lengths[b]`` and 0 otherwise.

    Args:
        batch_size: number of rows in the mask; must match ``len(src_lengths)``.
        src_lengths: 1-D tensor of per-row valid lengths.
        max_src_len: mask width; defaults to the largest entry of ``src_lengths``.
    """
    width = int(src_lengths.max()) if max_src_len is None else max_src_len
    positions = (
        torch.arange(0, width)
        .unsqueeze(0)
        .type_as(src_lengths)
        .expand(batch_size, width)
    )
    limits = src_lengths.unsqueeze(dim=1).expand(batch_size, width)
    return (positions < limits).int()

def masked_softmax(scores, src_lengths, src_length_masking=True):
    """Softmax over the last axis of ``scores``, ignoring padded key positions.

    Args:
        scores: 3-D tensor, unpacked as (batch, query_len, key_len) when masking.
        src_lengths: per-batch-item number of valid keys. Key positions at or
            beyond that count are filled with -inf, so they get zero weight.
            A length of 0 leaves a row of all -inf, which softmaxes to NaN.
        src_length_masking: when False, the scores are not masked.

    Returns:
        float32 weights with the same shape as ``scores``.
    """
    if not src_length_masking:
        return F.softmax(scores.float(), dim=-1)
    batch_size, _, key_len = scores.size()
    # (batch, key_len) -> (batch, 1, key_len) so it broadcasts over every query row.
    key_mask = create_src_lengths_mask(batch_size, src_lengths, key_len).to(scores.device).unsqueeze(1)
    masked_scores = scores.masked_fill(key_mask == 0, -np.inf)
    return F.softmax(masked_scores.float(), dim=-1)

# 定义 Contextual_Attention 类
class Contextual_Attention(nn.Module):
    """Scaled dot-product attention whose keys include local convolutional context.

    Queries are a linear projection of ``evo_local``. Keys are a linear
    projection of ``evo_local`` concatenated with its kernel-3 and kernel-5
    Conv1d features. Values are a linear projection of ``plm_embedding``. The
    attended values are added back to ``V`` as a residual.

    Args:
        q_input_dim: feature size (last dim) of ``evo_local``.
        v_input_dim: feature size (last dim) of ``plm_embedding``.
        qk_dim: size of the query/key space and of each conv branch's output.
        v_dim: size of the value projection, and therefore of the output.
    """

    def __init__(self, q_input_dim, v_input_dim=1024, qk_dim=1024, v_dim=1024):
        super(Contextual_Attention, self).__init__()
        self.cn3 = nn.Conv1d(q_input_dim, qk_dim, kernel_size=3, padding=1)
        self.cn5 = nn.Conv1d(q_input_dim, qk_dim, kernel_size=5, padding=2)
        # The key input is [evo_local | cn3 output | cn5 output], i.e.
        # q_input_dim + 2 * qk_dim features. It must be sized from qk_dim (the
        # conv output width), not v_dim; the two only coincide by default.
        self.k = nn.Linear(qk_dim * 2 + q_input_dim, qk_dim)
        self.q = nn.Linear(q_input_dim, qk_dim)
        self.v = nn.Linear(v_input_dim, v_dim)
        self._norm_fact = 1 / math.sqrt(qk_dim)

    def forward(self, plm_embedding, evo_local, seqlengths):
        """Attend from ``evo_local`` over itself and mix in ``plm_embedding`` values.

        Args:
            plm_embedding: (batch, seq, v_input_dim); its seq length must match
                ``evo_local`` because it supplies one value per key position.
            evo_local: (batch, seq, q_input_dim); permuted to channel-first
                for the Conv1d branches.
            seqlengths: per-batch-item valid lengths, used by ``masked_softmax``
                to zero out attention on padded key positions.

        Returns:
            (batch, seq, v_dim) tensor: attention output plus ``V`` (residual).
        """
        Q = self.q(evo_local)
        k3 = self.cn3(evo_local.permute(0, 2, 1))
        k5 = self.cn5(evo_local.permute(0, 2, 1))
        evo_local_concat = torch.cat((evo_local, k3.permute(0, 2, 1), k5.permute(0, 2, 1)), dim=2)
        K = self.k(evo_local_concat)
        V = self.v(plm_embedding)
        atten_scores = torch.bmm(Q, K.permute(0, 2, 1)) * self._norm_fact
        atten = masked_softmax(atten_scores, seqlengths)
        output = torch.bmm(atten, V)
        return output + V

# coll_paddding 函数
def coll_paddding(batch_traindata):
    """Collate (prot, agv, fusion, label) samples into zero-padded batch tensors.

    ``batch_traindata`` is sorted in place, longest ``prot`` first, so the
    padded outputs and the returned lengths are in descending-length order.

    Returns:
        ``(prot, agv, fusion, labels, lengths)``. The first four are padded with
        0 along the sequence axis (batch-first). ``lengths`` is a tensor holding
        each sample's original ``len(prot)``.
    """
    batch_traindata.sort(key=lambda sample: len(sample[0]), reverse=True)
    prots, agvs, fusions, labels = (
        [sample[field] for sample in batch_traindata] for field in range(4)
    )
    lengths = [len(p) for p in prots]
    pad = torch.nn.utils.rnn.pad_sequence
    return (
        pad(prots, batch_first=True, padding_value=0),
        pad(agvs, batch_first=True, padding_value=0),
        pad(fusions, batch_first=True, padding_value=0),
        pad(labels, batch_first=True, padding_value=0),
        torch.tensor(lengths),
    )

# 数据集类
class BioinformaticsDataset(Dataset):
    """Dataset with one item per CSV file.

    Each CSV's column 0 holds per-row integer labels; the remaining columns hold
    per-row features. An item is ``(features, row_mean, features, labels)``.
    ``features`` is returned twice because no separate fusion features are loaded.
    """

    def __init__(self, X_prot, X_feature_fusion):
        # Paths (or buffers) passed to pd.read_csv, one per item.
        self.X_prot = X_prot
        # NOTE(review): stored but never read; __getitem__ reuses the prot features.
        self.X_feature_fusion = X_feature_fusion

    def __len__(self):
        return len(self.X_prot)

    def __getitem__(self, index):
        frame = pd.read_csv(self.X_prot[index])
        raw = frame.iloc[:, 1:].values
        # Mixed-type columns come back as object arrays; coerce them to float first.
        numeric = raw.astype(float) if raw.dtype == object else raw
        features = torch.tensor(numeric, dtype=torch.float)
        # Column-wise mean over all rows, broadcast back to one copy per row.
        row_mean = features.mean(dim=0).repeat(features.shape[0], 1)
        labels = torch.tensor(frame.iloc[:, 0].values, dtype=torch.long)
        return features, row_mean, features, labels

# 模型定义
class DeepAIPModule(nn.Module):
    """Per-position two-class classifier: three Conv1d layers, then an MLP head.

    Input ``prot0`` is channel-last with 1024 features per position. The output
    holds 2 logits per position.

    NOTE(review): ``self.ca`` is constructed (and its parameters are handed to
    the optimizer) but is not used in ``forward``. The BatchNorm layers also see
    padded positions, which would affect their statistics on padded batches.
    """

    def __init__(self):
        super(DeepAIPModule, self).__init__()
        # Registration order is kept as-is: it fixes parameter init order and
        # state_dict / optimizer parameter ordering.
        self.ca = Contextual_Attention(q_input_dim=1024, v_input_dim=1024)
        self.relu = nn.ReLU(True)
        self.protcnn1 = nn.Conv1d(1024, 512, kernel_size=3, padding=1)
        self.protcnn2 = nn.Conv1d(512, 256, kernel_size=3, padding=1)
        self.protcnn3 = nn.Conv1d(256, 128, kernel_size=3, padding=1)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 32)
        self.fc4 = nn.Linear(32, 2)
        self.drop = nn.Dropout(0.5)
        self.batch_norm1 = nn.BatchNorm1d(256)
        self.batch_norm2 = nn.BatchNorm1d(32)

    def forward(self, prot0, f0agv, evo, data_length):
        """Return (batch, length, 2) logits for ``prot0`` of shape (batch, length, 1024).

        ``f0agv``, ``evo`` and ``data_length`` are accepted for interface
        compatibility but are not used by the current architecture.
        """
        # Conv1d works on (batch, channels, length), so move features to dim 1.
        hidden = prot0.transpose(1, 2)
        for conv in (self.protcnn1, self.protcnn2, self.protcnn3):
            hidden = self.relu(conv(hidden))
        hidden = hidden.transpose(1, 2)  # back to (batch, length, 128)

        for linear, norm in ((self.fc2, self.batch_norm1), (self.fc3, self.batch_norm2)):
            hidden = linear(hidden)
            # BatchNorm1d normalizes dim 1, so put features there temporarily.
            hidden = self.relu(norm(hidden.transpose(1, 2))).transpose(1, 2)
            hidden = self.drop(hidden)
        return self.fc4(hidden)

# 训练函数
def train(epochs=100, batch_size=256, lr=0.0001, num_workers=4):
    """Train a DeepAIPModule, print per-epoch training metrics, then call ``test``.

    Reads the module-level globals ``prot_train``, ``fusion_train`` and
    ``device`` (assigned in the ``__main__`` block). Loss and metrics cover only
    non-padded positions: packing each batch with its lengths drops the padding.

    Args:
        epochs: number of passes over the training set.
        batch_size: number of dataset items (CSV files) per batch.
        lr: Adam learning rate.
        num_workers: DataLoader worker processes.

    Returns:
        The trained model (already evaluated by ``test``).

    Raises:
        ValueError: if the training set is empty. Without this check the
            per-epoch loss average would divide by zero.
    """
    pack = torch.nn.utils.rnn.pack_padded_sequence
    train_set = BioinformaticsDataset(prot_train, fusion_train)
    if len(train_set) == 0:
        raise ValueError("prot_train is empty; nothing to train on")
    model = DeepAIPModule().to(device)
    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, collate_fn=coll_paddding)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Per-class focal-loss weights for labels 0 and 1.
    per_cls_weights = torch.FloatTensor([0.6, 0.4]).to(device)
    fcloss = FocalLoss_v2(alpha=per_cls_weights, gamma=2)
    for epoch in range(epochs):
        # Set training mode every epoch so any future in-loop evaluation
        # (which would call model.eval()) cannot leave it in eval mode.
        model.train()
        epoch_loss_train = 0.0
        nb_train = 0
        all_labels = []
        all_preds = []
        for prot_x, f0agv, evo_x, data_y, length in train_loader:
            # pack_padded_sequence needs the lengths on the CPU.
            lengths_cpu = length.cpu()
            prot_x = prot_x.to(device)
            f0agv = f0agv.to(device)
            evo_x = evo_x.to(device)
            data_y = data_y.to(device)
            length = length.to(device)
            optimizer.zero_grad()
            y_pred = model(prot_x, f0agv, evo_x, length)
            # Packing with the same lengths flattens predictions and labels to
            # the same sequence of real (non-padded) positions.
            y_pred_flat = pack(y_pred, lengths_cpu, batch_first=True, enforce_sorted=False).data
            data_y_flat = pack(data_y, lengths_cpu, batch_first=True, enforce_sorted=False).data
            loss = fcloss(y_pred_flat, data_y_flat)
            loss.backward()
            optimizer.step()
            epoch_loss_train += loss.item()
            nb_train += 1
            # Collect predictions and labels for the epoch-level metrics.
            all_preds.extend(torch.argmax(y_pred_flat, dim=1).cpu().numpy())
            all_labels.extend(data_y_flat.cpu().numpy())
        epoch_loss_avg = epoch_loss_train / nb_train
        acc = metrics.accuracy_score(all_labels, all_preds)
        mcc = metrics.matthews_corrcoef(all_labels, all_preds)
        tn, fp, fn, tp = metrics.confusion_matrix(all_labels, all_preds, labels=[0,1]).ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss_avg:.4f}, Acc: {acc:.4f}, MCC: {mcc:.4f}, Sensitivity: {sensitivity:.4f}, Specificity: {specificity:.4f}')
    # Evaluate on the held-out set once training is finished.
    test(model)
    return model

# 修改后的测试函数
def test(model, batch_size=256, num_workers=4):
    """Evaluate ``model`` on the test CSVs and print binary-classification metrics.

    Reads the module-level globals ``prot_test``, ``fusion_test`` and
    ``device``. Only non-padded positions are scored. Class 1 is the positive
    class: its softmax probability feeds AUC/AP, and ``tp``/``fn`` count
    label-1 positions.

    Args:
        model: trained model; switched to eval mode here.
        batch_size: number of dataset items per evaluation batch.
        num_workers: DataLoader worker processes.

    Returns:
        ``(accuracy, MCC)`` over all scored positions. AUC is reported as NaN
        when the labels contain a single class, because ROC AUC is undefined then.
    """
    pack = torch.nn.utils.rnn.pack_padded_sequence
    test_set = BioinformaticsDataset(prot_test, fusion_test)
    test_loader = DataLoader(dataset=test_set, batch_size=batch_size, num_workers=num_workers,
                             collate_fn=coll_paddding)
    model.eval()
    arr_probs = []
    arr_labels = []
    arr_labels_hyps = []
    with torch.no_grad():
        for prot_x, f0agv, evo_x, data_y, length in test_loader:
            lengths_cpu = length.cpu()
            y_pred = model(prot_x.to(device), f0agv.to(device), evo_x.to(device), length.to(device))
            # Keep only real positions. enforce_sorted=False means this does not
            # rely on coll_paddding having sorted the batch by length.
            y_pred_data = pack(y_pred, lengths_cpu, batch_first=True, enforce_sorted=False).data
            data_y_data = pack(data_y, lengths_cpu, batch_first=True, enforce_sorted=False).data
            y_pred_softmax = torch.nn.functional.softmax(y_pred_data, dim=1)
            arr_probs.extend(y_pred_softmax[:, 1].cpu().numpy())
            # Store plain numpy values (not one 0-d tensor per position), as train() does.
            arr_labels.extend(data_y_data.cpu().numpy())
            arr_labels_hyps.extend(torch.argmax(y_pred_softmax, dim=1).cpu().numpy())
    if len(set(arr_labels)) > 1:
        auc = metrics.roc_auc_score(arr_labels, arr_probs)
    else:
        # roc_auc_score raises ValueError when only one class is present.
        auc = float('nan')
    acc = metrics.accuracy_score(arr_labels, arr_labels_hyps)
    mcc = metrics.matthews_corrcoef(arr_labels, arr_labels_hyps)
    tn, fp, fn, tp = metrics.confusion_matrix(arr_labels, arr_labels_hyps, labels=[0,1]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1score = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0
    youden = sensitivity + specificity - 1
    metrics_dict = {
        'accuracy': acc,
        'balanced_accuracy': metrics.balanced_accuracy_score(arr_labels, arr_labels_hyps),
        'MCC': mcc,
        'AUC': auc,
        'AP': metrics.average_precision_score(arr_labels, arr_probs),
        'TN': tn,
        'FP': fp,
        'FN': fn,
        'TP': tp,
        'Sensitivity': sensitivity,
        'Specificity': specificity,
        'Precision': precision,
        'Recall': recall,
        'F1 Score': f1score,
        'Youden Index': youden
    }
    for key, value in metrics_dict.items():
        print(f'{key}: {value}')

    print('<----------------Testing Complete---------------->')
    return acc, mcc

if __name__ == "__main__":
    # Use GPU 0 when CUDA is available; otherwise run on the CPU.
    cuda = torch.cuda.is_available()
    if cuda:
        torch.cuda.set_device(0)
    device = torch.device("cuda") if cuda else torch.device("cpu")
    print("Use device:", device)

    # Input CSV lists, read as globals by train() and test(). The same file is
    # used for both the "prot" and "fusion" inputs of each split.
    # NOTE(review): each list holds a single path, so each dataset has length 1
    # and every epoch is one optimizer step over one "sequence" made of all CSV
    # rows. If the rows are independent samples, the Conv1d layers will mix
    # neighbouring rows — confirm this is intended.
    prot_train = ['/root/ACE/train_data.csv']
    fusion_train = ['/root/ACE/train_data.csv']
    prot_test = ['/root/ACE/test_data.csv']
    fusion_test = ['/root/ACE/test_data.csv']

    train()


Use device: cuda
Epoch 1/100, Loss: 0.1424, Acc: 0.5331, MCC: 0.0015, Sensitivity: 0.6427, Specificity: 0.3587
Epoch 2/100, Loss: 0.1200, Acc: 0.5539, MCC: 0.0567, Sensitivity: 0.6407, Specificity: 0.4159
Epoch 3/100, Loss: 0.1140, Acc: 0.6091, MCC: 0.1662, Sensitivity: 0.7006, Specificity: 0.4635
Epoch 4/100, Loss: 0.1018, Acc: 0.6091, MCC: 0.1671, Sensitivity: 0.6986, Specificity: 0.4667
Epoch 5/100, Loss: 0.0995, Acc: 0.6360, MCC: 0.2255, Sensitivity: 0.7186, Specificity: 0.5048
Epoch 6/100, Loss: 0.0946, Acc: 0.6507, MCC: 0.2628, Sensitivity: 0.7166, Specificity: 0.5460
Epoch 7/100, Loss: 0.0960, Acc: 0.6434, MCC: 0.2464, Sensitivity: 0.7126, Specificity: 0.5333
Epoch 8/100, Loss: 0.0937, Acc: 0.6593, MCC: 0.2763, Sensitivity: 0.7345, Specificity: 0.5397
Epoch 9/100, Loss: 0.0904, Acc: 0.6618, MCC: 0.2881, Sensitivity: 0.7206, Specificity: 0.5683
Epoch 10/100, Loss: 0.0854, Acc: 0.6777, MCC: 0.3205, Sensitivity: 0.7365, Specificity: 0.5841
Epoch 11/100, Loss: 0.0779, Acc: 0.6912, M