In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
import random
import os

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)  # 你可以换成其它数字
# 数据集类
class ProteinNPYDataset(Dataset):
    def __init__(self, pos_path, neg_path):
        self.pos = np.load(pos_path, mmap_mode='r')
        self.neg = np.load(neg_path, mmap_mode='r')
        self.lengths = [len(self.pos), len(self.neg)]
        self.total_len = self.lengths[0] + self.lengths[1]

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        if idx < self.lengths[0]:
            x = self.pos[idx]
            y = 1
        else:
            x = self.neg[idx - self.lengths[0]]
            y = 0
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)

In [2]:
class MLPExperts(nn.Module):
    def __init__(self, d_model, d_ff, num_experts):
        super().__init__()
        self.num_experts = num_experts
        self.fc1 = nn.Linear(d_model, d_ff * num_experts, bias=True)
        self.fc2 = nn.Linear(d_ff, d_model, bias=True)
        self.d_ff = d_ff
    def forward(self, x, expert_idx):
        # x: [B*L, d_model], expert_idx: [B*L, k]
        all_hidden = self.fc1(x)  # [B*L, d_ff * num_experts]
        all_hidden = all_hidden.view(x.size(0), self.num_experts, self.d_ff)  # [B*L, num_experts, d_ff]
        out = []
        for i in range(expert_idx.size(1)):
            idx = expert_idx[:, i]  # [B*L]
            hidden = all_hidden[torch.arange(x.size(0)), idx]  # [B*L, d_ff]
            hidden = F.gelu(hidden)
            out_i = self.fc2(hidden)  # [B*L, d_model]
            out.append(out_i)
        out = torch.stack(out, dim=1)  # [B*L, k, d_model]
        return out
class NoisyTopKMoE(nn.Module):
    def __init__(self, d_model, d_ff, num_experts=30, k=2, noisy_std=1.0):
        super().__init__()
        self.num_experts = num_experts
        self.k = k
        self.noisy_std = noisy_std
        self.experts = MLPExperts(d_model, d_ff, num_experts)
        self.gate = nn.Linear(d_model, num_experts)
    def forward(self, x):
        # x: [B, L, d_model]
        B, L, D = x.shape
        x_flat = x.reshape(-1, D)  # [B*L, D]
        gate_logits = self.gate(x_flat)  # [B*L, num_experts]
        # Noisy gating
        if self.training and self.noisy_std > 0:
            noise = torch.randn_like(gate_logits) * self.noisy_std
            gate_logits = gate_logits + noise
        gate_scores = F.softmax(gate_logits, dim=-1)  # [B*L, num_experts]

          # 稀疏路由：只选top-k
        topk_val, topk_idx = torch.topk(gate_scores, self.k, dim=-1)  # [B*L, k]
        # 负载均衡损失（新版，防止爆炸）
        meangate = gate_scores.mean(dim=0)  # [num_experts]
        load_balance_loss = (meangate * meangate).sum() * (self.num_experts ** 2)
        # 专家并行输出
        expert_outs = self.experts(x_flat, topk_idx)  # [B*L, k, d_model]
        topk_val = topk_val / (topk_val.sum(dim=-1, keepdim=True) + 1e-9)
        moe_out = (expert_outs * topk_val.unsqueeze(-1)).sum(dim=1)  # [B*L, d_model]
        moe_out = moe_out.view(B, L, D)
        return moe_out, load_balance_loss

In [3]:
class TransformerMoEBlock(nn.Module):
    def __init__(self, d_model, nhead, d_ff, num_experts=30, k=3, dropout=0.1, noisy_std=1.0):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.moe = NoisyTopKMoE(d_model, d_ff, num_experts, k, noisy_std)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.self_attn(x, x, x)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)
        moe_out, load_balance_loss = self.moe(x)
        x = x + self.dropout(moe_out)
        x = self.norm2(x)
        return x, load_balance_loss

class TransformerMoE(nn.Module):
    def __init__(self, d_model=640, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=2, dropout=0.1, noisy_std=1.0, num_classes=2):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerMoEBlock(d_model, nhead, d_ff, num_experts, k, dropout, noisy_std)
            for _ in range(num_layers)
        ])
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, num_classes)
        )
    def forward(self, x):
        total_load_balance_loss = 0
        for layer in self.layers:
            x, lb_loss = layer(x)
            total_load_balance_loss += lb_loss
        x = x.mean(dim=1)  # 池化
        logits = self.classifier(x)
        return logits, total_load_balance_loss

In [4]:
def eval_model(model, loader, device):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, matthews_corrcoef
    acc = accuracy_score(all_labels, all_preds)
    pre = precision_score(all_labels, all_preds)
    rec = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    mcc = matthews_corrcoef(all_labels, all_preds)
    print(f"Test ACC: {acc:.4f}, PRE: {pre:.4f}, REC: {rec:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}")
    return acc, pre, rec, f1, mcc

In [5]:
from torch.cuda.amp import autocast, GradScaler
# 数据路径
train_pos = '/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/data/positive_train.npy'
train_neg = '/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/data/negative_train_embedding_enhanced.npy'

train_dataset = ProteinNPYDataset(train_pos, train_neg)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerMoE(
    d_model=640, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=3, dropout=0.1, noisy_std=1.0, num_classes=2
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

scaler = GradScaler()  # 在训练前初始化

def train_one_epoch(model, loader, optimizer, criterion, device, moe_loss_weight=0.01, scaler=None):
    model.train()
    total_loss = 0
    for x, y in tqdm(loader, desc="Training", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        with autocast():  # 开启混合精度
            logits, lb_loss = model(x)
            loss = criterion(logits, y) + moe_loss_weight * lb_loss
        scaler.scale(loss).backward()      # 用scaler缩放loss反向传播
        scaler.step(optimizer)             # 用scaler.step更新参数
        scaler.update()                    # 更新scaler状态
        total_loss += loss.item()
    return total_loss / len(loader)

# 初始化scaler
scaler = GradScaler()

# 训练主循环
epochs = 10
best_acc = 0
best_state = None
best_path = "/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/checkpoints/transformer_moe_best.pth"
last_path = "/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/checkpoints/transformer_moe_last.pth"

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler=scaler)
    print(f"Train Loss: {train_loss:.4f}")
    # 保存最后一次模型权重
    torch.save(model.state_dict(), last_path)
    print(f"Last model saved at epoch {epoch+1} ({last_path})")

  scaler = GradScaler()  # 在训练前初始化
  scaler = GradScaler()



Epoch 1/10


  with autocast():  # 开启混合精度
                                                           

Train Loss: 1.6447
Last model saved at epoch 1 (/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/checkpoints/transformer_moe_last.pth)

Epoch 2/10


                                                           

Train Loss: 1.4228
Last model saved at epoch 2 (/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/checkpoints/transformer_moe_last.pth)

Epoch 3/10


                                                           

Train Loss: 1.3999
Last model saved at epoch 3 (/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/checkpoints/transformer_moe_last.pth)

Epoch 4/10


                                                           

Train Loss: 1.3914
Last model saved at epoch 4 (/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/checkpoints/transformer_moe_last.pth)

Epoch 5/10


                                                           

Train Loss: 1.3869
Last model saved at epoch 5 (/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/checkpoints/transformer_moe_last.pth)

Epoch 6/10


                                                           

Train Loss: 1.3860
Last model saved at epoch 6 (/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/checkpoints/transformer_moe_last.pth)

Epoch 7/10


                                                           

Train Loss: 1.3669
Last model saved at epoch 7 (/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/checkpoints/transformer_moe_last.pth)

Epoch 8/10


                                                           

Train Loss: 1.3699
Last model saved at epoch 8 (/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/checkpoints/transformer_moe_last.pth)

Epoch 9/10


                                                           

Train Loss: 1.3542
Last model saved at epoch 9 (/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/checkpoints/transformer_moe_last.pth)

Epoch 10/10


                                                           

Train Loss: 1.3421
Last model saved at epoch 10 (/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/checkpoints/transformer_moe_last.pth)


In [6]:
test_pos = '/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/data/positive_test.npy'
test_neg = '/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/data/negative_test.npy'
test_dataset = ProteinNPYDataset(test_pos, test_neg)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

def eval_model(model, loader, device):
    model.eval()
    all_preds, all_labels = [], []
    all_probs = []  # 存储正类概率用于AUC和AUPRC
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())  # 正类概率

    from sklearn.metrics import (
        accuracy_score, precision_score, recall_score, f1_score, matthews_corrcoef,
        confusion_matrix, roc_auc_score, average_precision_score
    )
    acc = accuracy_score(all_labels, all_preds)
    pre = precision_score(all_labels, all_preds)
    rec = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    mcc = matthews_corrcoef(all_labels, all_preds)

    # 混淆矩阵
    tn, fp, fn, tp = confusion_matrix(all_labels, all_preds).ravel()
    sn = rec  # 敏感性
    sp = tn / (tn + fp) if (tn + fp) > 0 else 0  # 特异性

    # AUC和AUPRC
    auc = roc_auc_score(all_labels, all_probs)
    auprc = average_precision_score(all_labels, all_probs)

    print(f"Test ACC: {acc:.4f}, PRE: {pre:.4f}, REC(Sn): {sn:.4f}, SP: {sp:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}, AUC: {auc:.4f}, AUPRC: {auprc:.4f}")
    return acc, pre, sn, sp, f1, mcc, auc, auprc

# 测试
eval_model(model, test_loader, device)

Test ACC: 0.8903, PRE: 0.9336, REC(Sn): 0.9108, SP: 0.8399, F1: 0.9220, MCC: 0.7380, AUC: 0.9528, AUPRC: 0.9795


(0.8903394255874674,
 0.9335839598997494,
 0.910757946210269,
 0.8398791540785498,
 0.9220297029702971,
 0.7380158184030507,
 0.9528065652723096,
 0.9795031746111332)

In [7]:
# 假设模型结构和ProteinNPYDataset已定义，device已设置
import torch

# 1. 加载模型
model = TransformerMoE(
    d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=3, dropout=0.1, noisy_std=1.0, num_classes=2
).to(device)
model.load_state_dict(torch.load('/exp_data/sjx/star/main_transformer_moe_weight/best_transformer_moe_last.pth', map_location=device))
model.eval()

# 2. 加载测试集
test_pos = '/exp_data/sjx/star/first_data/ESM-embedding/positive_test_embedding.npy'
test_neg = '/exp_data/sjx/star/first_data/ESM-embedding/negative_test_embedding.npy'
test_dataset = ProteinNPYDataset(test_pos, test_neg)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

# 3. 定义评估函数
def eval_model(model, loader, device):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, matthews_corrcoef
    acc = accuracy_score(all_labels, all_preds)
    pre = precision_score(all_labels, all_preds)
    rec = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    mcc = matthews_corrcoef(all_labels, all_preds)
    print(f"Test ACC: {acc:.4f}, PRE: {pre:.4f}, REC: {rec:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}")
    return acc, pre, rec, f1, mcc

# 4. 测试
eval_model(model, test_loader, device)

  model.load_state_dict(torch.load('/exp_data/sjx/star/main_transformer_moe_weight/best_transformer_moe_last.pth', map_location=device))


Test ACC: 0.9225, PRE: 0.9483, REC: 0.9425, F1: 0.9454, MCC: 0.8120


(0.9225413402959095,
 0.948339483394834,
 0.9425427872860636,
 0.9454322501532803,
 0.8120485793877618)

# 记录
第一次参数设置 d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=2, dropout=0.1, noisy_std=1.0, num_classes=2 epoch=10

结果 Test ACC: 0.9173, PRE: 0.9199, REC: 0.9682, F1: 0.9434, MCC: 0.7939

第二次参数设置 d_model=1152, nhead=8, d_ff=2048, num_layers=5, num_experts=30, k=2, dropout=0.2, noisy_std=1.0, num_classes=2 epoch=20

结果 Test ACC: 0.8973, PRE: 0.9464, REC: 0.9071, F1: 0.9263, MCC: 0.7589

第三次参数设置 d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=2, dropout=0.1, noisy_std=1.0, num_classes=2 epoch=15
结果Test ACC: 0.9121, PRE: 0.9544, REC: 0.9205, F1: 0.9371, MCC: 0.7926

第四次参数设置d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=2, dropout=0.1, noisy_std=1.0, num_classes=2 epoch=10
数据集变成不加gan

结果：Test ACC: 0.8808, PRE: 0.9570, REC: 0.8716, F1: 0.9123, MCC: 0.7350

第五次参数设置d_model=1152, nhead=8, d_ff=2048, num_layers=3, num_experts=30, k=2, dropout=0.1, noisy_std=1.0, num_classes=2 epoch=10

结果：Test ACC: 0.9017, PRE: 0.9094, REC: 0.9572, F1: 0.9327, MCC: 0.7540

第六次参数设置 d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=2, dropout=0.1, noisy_std=1.0, num_classes=2 epoch=10 加上了一个kaming初始化

结果：Test ACC: 0.9130, PRE: 0.9367, REC: 0.9413, F1: 0.9390, MCC: 0.7871

第七次参数设置 d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=40, k=2, dropout=0.1, noisy_std=1.0, num_classes=2

结果：Test ACC: 0.8782, PRE: 0.9508, REC: 0.8741, F1: 0.9108, MCC: 0.7260

第八次参数设置 d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=35, k=2, dropout=0.1, noisy_std=1.0, num_classes=2

结果Test ACC: 0.9112, PRE: 0.9409, REC: 0.9340, F1: 0.9374, MCC: 0.7848

第九次参数设置d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=25, k=2, dropout=0.1, noisy_std=1.0, num_classes=2

结果Test ACC: 0.8982, PRE: 0.8899, REC: 0.9780, F1: 0.9319, MCC: 0.7452

第十次参数设置 d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=1, dropout=0.1, noisy_std=1.0, num_classes=2 epoch=10

结果：Test ACC: 0.9017, PRE: 0.9657, REC: 0.8936, F1: 0.9283, MCC: 0.7786

第十一次参数设置d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=3, dropout=0.1, noisy_std=1.0, num_classes=2

结果Test ACC: 0.9225, PRE: 0.9483, REC: 0.9425, F1: 0.9454, MCC: 0.8120


### 十折验证集
参数均使用前边的最佳模型的参数

结果：========== 10-Fold CV Results ==========
Mean ACC: 0.9401 ± 0.0058
Mean PRE: 0.9264
Mean REC: 0.9260
Mean F1:  0.9252
Mean MCC: 0.8523

### 十折测试集
结果：========== 10-Fold Test Results ==========
Mean ACC: 0.9155 ± 0.0048
Mean PRE: 0.9338
Mean REC: 0.9488
Mean F1:  0.9411
Mean MCC: 0.7925

### 十折交叉验证

In [8]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from torch.cuda.amp import autocast, GradScaler
import random
import os
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [9]:
def train_one_epoch(model, loader, optimizer, criterion, device, moe_loss_weight=0.01, scaler=None):
    model.train()
    total_loss = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        with autocast():
            logits, lb_loss = model(x)
            loss = criterion(logits, y) + moe_loss_weight * lb_loss
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    return total_loss / len(loader)

def eval_model(model, loader, device):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, matthews_corrcoef
    acc = accuracy_score(all_labels, all_preds)
    pre = precision_score(all_labels, all_preds)
    rec = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    mcc = matthews_corrcoef(all_labels, all_preds)
    print(f"Val ACC: {acc:.4f}, PRE: {pre:.4f}, REC: {rec:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}")
    return acc, pre, rec, f1, mcc

In [None]:
train_pos = '/exp_data/sjx/star/first_data/ESM-embedding/positive_train_embedding.npy'
train_neg = '/exp_data/sjx/star/gan_data/negative_train_all_combined.npy'

# 构建全体索引和标签
pos_len = np.load(train_pos, mmap_mode='r').shape[0]
neg_len = np.load(train_neg, mmap_mode='r').shape[0]
all_indices = np.concatenate([np.arange(pos_len), np.arange(neg_len) + pos_len])
all_labels = np.concatenate([np.ones(pos_len, dtype=int), np.zeros(neg_len, dtype=int)])

# 数据集
full_dataset = ProteinNPYDataset(train_pos, train_neg)

# K折分层
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
all_metrics = []

for fold, (train_idx, val_idx) in enumerate(skf.split(all_indices, all_labels), 1):
    print(f"\n========== Fold {fold}/10 ==========")
    train_loader = DataLoader(Subset(full_dataset, train_idx), batch_size=64, shuffle=True, num_workers=2)
    val_loader = DataLoader(Subset(full_dataset, val_idx), batch_size=64, shuffle=False, num_workers=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TransformerMoE(
        d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=3, dropout=0.1, noisy_std=1.0, num_classes=2
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler()

    best_acc = 0
    best_state = None
    epochs = 10

    for epoch in range(epochs):
        print(f"\n[Fold {fold}] Epoch {epoch+1}/{epochs}")
        model.train()
        total_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            with autocast():
                logits, lb_loss = model(x)
                loss = criterion(logits, y) + 0.01 * lb_loss
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
        print(f"Train Loss: {total_loss / len(train_loader):.4f}")

        # 验证
        model.eval()
        all_preds, all_labels_fold = [], []
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits, _ = model(x)
                preds = torch.argmax(logits, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels_fold.extend(y.cpu().numpy())
        from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, matthews_corrcoef
        acc = accuracy_score(all_labels_fold, all_preds)
        pre = precision_score(all_labels_fold, all_preds)
        rec = recall_score(all_labels_fold, all_preds)
        f1 = f1_score(all_labels_fold, all_preds)
        mcc = matthews_corrcoef(all_labels_fold, all_preds)
        print(f"Val ACC: {acc:.4f}, PRE: {pre:.4f}, REC: {rec:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}")

        if acc > best_acc:
            best_acc = acc
            best_state = model.state_dict()
            torch.save(best_state, f"/exp_data/sjx/star/main_transformer_moe_weight/cv_point/best_fold{fold}.pth")
            print(f"Best model saved for fold {fold} at epoch {epoch+1}")

    all_metrics.append((best_acc, pre, rec, f1, mcc))
    print(f"[Fold {fold}] Best ACC: {best_acc:.4f}")

# 汇总结果
all_metrics = np.array(all_metrics)
print("\n========== 10-Fold CV Results ==========")
print(f"Mean ACC: {all_metrics[:,0].mean():.4f} ± {all_metrics[:,0].std():.4f}")
print(f"Mean PRE: {all_metrics[:,1].mean():.4f}")
print(f"Mean REC: {all_metrics[:,2].mean():.4f}")
print(f"Mean F1:  {all_metrics[:,3].mean():.4f}")
print(f"Mean MCC: {all_metrics[:,4].mean():.4f}")



[Fold 1] Epoch 1/10


  scaler = GradScaler()
  with autocast():


### 十折测试结果

In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, matthews_corrcoef

# 测试集
test_pos = '/exp_data/sjx/star/first_data/ESM-embedding/positive_test_embedding.npy'
test_neg = '/exp_data/sjx/star/first_data/ESM-embedding/negative_test_embedding.npy'
test_dataset = ProteinNPYDataset(test_pos, test_neg)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def eval_model(model, loader, device):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
    acc = accuracy_score(all_labels, all_preds)
    pre = precision_score(all_labels, all_preds)
    rec = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    mcc = matthews_corrcoef(all_labels, all_preds)
    return acc, pre, rec, f1, mcc

# 评估每一折
all_metrics = []
for fold in range(1, 11):
    print(f"\n========== Test Fold {fold}/10 ==========")
    model = TransformerMoE(
        d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=3, dropout=0.1, noisy_std=1.0, num_classes=2
    ).to(device)
    model.load_state_dict(torch.load(f"/exp_data/sjx/star/main_transformer_moe_weight/cv_point/best_fold{fold}.pth", map_location=device))
    acc, pre, rec, f1, mcc = eval_model(model, test_loader, device)
    print(f"Test ACC: {acc:.4f}, PRE: {pre:.4f}, REC: {rec:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}")
    all_metrics.append((acc, pre, rec, f1, mcc))

# 汇总平均
all_metrics = np.array(all_metrics)
print("\n========== 10-Fold Test Results ==========")
print(f"Mean ACC: {all_metrics[:,0].mean():.4f} ± {all_metrics[:,0].std():.4f}")
print(f"Mean PRE: {all_metrics[:,1].mean():.4f}")
print(f"Mean REC: {all_metrics[:,2].mean():.4f}")
print(f"Mean F1:  {all_metrics[:,3].mean():.4f}")
print(f"Mean MCC: {all_metrics[:,4].mean():.4f}")