In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
import random
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix, roc_auc_score, average_precision_score
import random
import os

def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) and pin cuDNN to
    deterministic, non-autotuned kernels.

    NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so assigning it here
    presumably only influences subprocesses launched afterwards, not this kernel.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Reproducibility over speed: no cuDNN benchmark-driven algorithm selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(seed=42)  # any integer works; a fixed value keeps Restart & Run All reproducible

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Dataset class
class ProteinNPYIndexDataset(Dataset):
    """Labelled view over two ``.npy`` files: positives (label 1) first, then negatives (label 0).

    Dataset index ``i < len(pos_indices)`` returns row ``pos_indices[i]`` of the array at
    ``pos_path``; any larger ``i`` returns row ``neg_indices[i - len(pos_indices)]`` of the
    array at ``neg_path``. Samples are addressed through the index arrays, so repeated
    entries yield repeated samples.

    Both files are opened read-only with ``np.load(..., mmap_mode='r')`` on the first
    ``__getitem__`` call rather than in ``__init__``. This is presumably so that each
    DataLoader worker process opens its own memory map.
    """

    def __init__(self, pos_path, neg_path, pos_indices, neg_indices):
        self.pos_path, self.neg_path = pos_path, neg_path
        self.pos_indices, self.neg_indices = pos_indices, neg_indices
        self.pos_len, self.neg_len = len(pos_indices), len(neg_indices)
        self.total_len = self.pos_len + self.neg_len
        # Memory maps are created lazily by _ensure_open().
        self.pos_npy, self.neg_npy = None, None

    def __len__(self):
        return self.total_len

    def _ensure_open(self):
        """Memory-map whichever of the two arrays has not been opened yet."""
        if self.pos_npy is None:
            self.pos_npy = np.load(self.pos_path, mmap_mode='r')
        if self.neg_npy is None:
            self.neg_npy = np.load(self.neg_path, mmap_mode='r')

    def __getitem__(self, idx):
        """Return ``(row as a float32 tensor, label as an int64 scalar tensor)``."""
        self._ensure_open()
        if idx >= self.pos_len:
            row = self.neg_npy[self.neg_indices[idx - self.pos_len]]
            label = 0
        else:
            row = self.pos_npy[self.pos_indices[idx]]
            label = 1
        return torch.tensor(row, dtype=torch.float32), torch.tensor(label, dtype=torch.long)


class MLPExperts(nn.Module):
    """Bank of ``num_experts`` GELU MLPs, evaluated for the top-k selected experts of each token.

    ``fc1`` packs every expert's up-projection into a single
    ``nn.Linear(d_model, d_ff * num_experts)``. Its output is viewed as
    ``[N, num_experts, d_ff]`` and the k selected expert slices are gathered per token.

    NOTE(review): ``fc2`` is one ``nn.Linear(d_ff, d_model)`` shared by all experts, so the
    experts differ only in their first layer. ``fc1`` is also evaluated for every expert on
    every token, so compute is dense even though routing is sparse. Changing either would
    change the parameter layout of saved checkpoints, so both are kept as-is here.
    """

    def __init__(self, d_model, d_ff, num_experts):
        super().__init__()
        self.num_experts = num_experts
        self.fc1 = nn.Linear(d_model, d_ff * num_experts, bias=True)
        self.fc2 = nn.Linear(d_ff, d_model, bias=True)
        self.d_ff = d_ff

    def forward(self, x, expert_idx):
        """Run the selected experts for every token.

        Args:
            x: ``[N, d_model]`` flattened tokens.
            expert_idx: ``[N, k]`` int64 expert ids in ``[0, num_experts)``.

        Returns:
            ``[N, k, d_model]``; slice ``[:, j]`` is the output of each token's j-th expert.
        """
        n_tokens = x.size(0)
        hidden_all = self.fc1(x).view(n_tokens, self.num_experts, self.d_ff)  # [N, E, d_ff]
        # One gather replaces the per-k Python loop and the CPU-side torch.arange index
        # that previously had to be transferred to x's device on every call.
        gather_idx = expert_idx.unsqueeze(-1).expand(-1, -1, self.d_ff)  # [N, k, d_ff]
        hidden = F.gelu(hidden_all.gather(1, gather_idx))  # [N, k, d_ff]
        return self.fc2(hidden)  # [N, k, d_model]


class NoisyTopKMoE(nn.Module):
    """Sparse mixture-of-experts layer with noisy top-k gating.

    How routing works:
    - In training mode (when ``noisy_std > 0``), Gaussian noise with std ``noisy_std`` is
      added to each token's gate logits.
    - The logits are softmaxed over all experts, and the ``k`` largest probabilities are kept.
    - Those k probabilities are renormalised to sum to 1 and used to weight the k expert
      outputs.

    ``forward`` returns ``(output, load_balance_loss)``. The loss is
    ``num_experts**2 * sum_e(mean_prob_e**2)``, where ``mean_prob`` is the batch-averaged
    soft gate distribution. It is smallest when ``mean_prob`` is uniform, and that minimum
    equals ``num_experts``, not 1.
    """

    def __init__(self, d_model, d_ff, num_experts=30, k=2, noisy_std=1.0):
        super().__init__()
        self.num_experts = num_experts
        self.k = k
        self.noisy_std = noisy_std
        # Creation order (experts, then gate) fixes init RNG draws and state_dict key order.
        self.experts = MLPExperts(d_model, d_ff, num_experts)
        self.gate = nn.Linear(d_model, num_experts)

    def forward(self, x):
        """Route every token of ``x`` to its top-k experts and mix their outputs."""
        # x: [B, L, d_model]; the three-way unpack requires a rank-3 input.
        batch, seq_len, width = x.shape
        tokens = x.reshape(-1, width)  # [B*L, D]

        logits = self.gate(tokens)  # [B*L, E]
        if self.training and self.noisy_std > 0:
            logits = logits + torch.randn_like(logits) * self.noisy_std
        probs = F.softmax(logits, dim=-1)  # [B*L, E]

        # Sparse routing: keep only the k most probable experts per token.
        top_probs, top_experts = torch.topk(probs, self.k, dim=-1)  # [B*L, k]

        # Penalise a non-uniform batch-averaged (soft) gate distribution.
        avg_prob = probs.mean(dim=0)  # [E]
        balance_loss = torch.sum(avg_prob * avg_prob) * (self.num_experts ** 2)

        expert_out = self.experts(tokens, top_experts)  # [B*L, k, D]
        mix_weights = top_probs / (top_probs.sum(dim=-1, keepdim=True) + 1e-9)
        combined = (expert_out * mix_weights.unsqueeze(-1)).sum(dim=1)  # [B*L, D]
        return combined.view(batch, seq_len, width), balance_loss


class TransformerMoEBlock(nn.Module):
    """Post-norm Transformer encoder layer whose feed-forward sublayer is a NoisyTopKMoE.

    h   = LayerNorm(x + Dropout(SelfAttention(x)))
    out = LayerNorm(h + Dropout(MoE(h)))

    Returns ``(out, moe_load_balance_loss)``.

    NOTE(review): no attention / key-padding mask is passed. If ``x`` contains padded
    positions, they take part in attention; confirm inputs are unpadded.
    """

    def __init__(self, d_model, nhead, d_ff, num_experts=30, k=2, dropout=0.1, noisy_std=1.0):
        super().__init__()
        # Keep this creation order: it determines parameter init and state_dict key order.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.moe = NoisyTopKMoE(d_model, d_ff, num_experts, k, noisy_std)
        # A single Dropout module is reused on both residual branches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attended, _attn_weights = self.self_attn(x, x, x)
        hidden = self.norm1(x + self.dropout(attended))
        expert_mix, balance_loss = self.moe(hidden)
        out = self.norm2(hidden + self.dropout(expert_mix))
        return out, balance_loss


class TransformerMoE(nn.Module):
    """Stack of ``num_layers`` TransformerMoEBlocks, then mean pooling over dim 1 and a
    LayerNorm + Linear classification head.

    ``forward(x)`` returns ``(logits [B, num_classes], load-balance loss summed over layers)``.
    Assumes x is ``[B, L, d_model]``; TODO confirm for the stored embeddings.
    """

    def __init__(self, d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=2, dropout=0.1, noisy_std=1.0,
                 num_classes=2):
        super().__init__()
        blocks = [
            TransformerMoEBlock(d_model, nhead, d_ff, num_experts, k, dropout, noisy_std)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        head = [nn.LayerNorm(d_model), nn.Linear(d_model, num_classes)]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        per_layer_losses = []
        for block in self.layers:
            x, block_loss = block(x)
            per_layer_losses.append(block_loss)
        # Pooling: average over dimension 1 (the sequence axis under the [B, L, D] assumption).
        pooled = x.mean(dim=1)
        return self.classifier(pooled), sum(per_layer_losses)


def eval_model(model, loader, device):
    """Evaluate a binary classifier on ``loader``, print the metrics and return them.

    ``model(x)`` must return ``(logits, aux)``. Column 1 of ``softmax(logits)`` is used as
    the positive-class score, which assumes ``num_classes == 2``.

    Returns:
        ``(acc, f1, sn, mcc, pre, sp, auc, auprc)``, where:
        - Sn is the recall of class 1;
        - Sp is TN / (TN + FP), or 0 when there are no negatives.

    NOTE(review): ``roc_auc_score`` is expected to raise if the labels contain only one
    class, because AUC is undefined there.
    """
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    with torch.no_grad():
        for x, y in loader:
            # Labels are only needed by sklearn, so they stay on the CPU.
            logits, _ = model(x.to(device))
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())
    acc = accuracy_score(all_labels, all_preds)
    # zero_division=0: same value as sklearn's default, but no warning when the model
    # predicts no positives.
    pre = precision_score(all_labels, all_preds, zero_division=0)
    rec = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    mcc = matthews_corrcoef(all_labels, all_preds)
    # labels=[0, 1] forces a 2x2 matrix, so the 4-way unpack cannot fail when only one
    # class appears in labels and predictions.
    tn, fp, fn, tp = confusion_matrix(all_labels, all_preds, labels=[0, 1]).ravel()
    sn = rec
    sp = tn / (tn + fp) if (tn + fp) > 0 else 0
    auc = roc_auc_score(all_labels, all_probs)
    auprc = average_precision_score(all_labels, all_probs)
    print(f"ACC: {acc:.4f}, F1: {f1:.4f}, Recall(Sn): {sn:.4f}, MCC: {mcc:.4f}, Precision: {pre:.4f}, Sp: {sp:.4f}, AUC: {auc:.4f}, AUPRC: {auprc:.4f}")
    return acc, f1, sn, mcc, pre, sp, auc, auprc

In [2]:
# Precomputed ESM embedding arrays, one .npy file per split and class.
# Change DATA_DIR (in one place) to point the notebook at another copy of the data.
DATA_DIR = '/exp_data/sjx/star/first_data/ESM-embedding'
train_pos = os.path.join(DATA_DIR, 'positive_train_embedding.npy')
train_neg = os.path.join(DATA_DIR, 'negative_train_embedding.npy')
test_pos = os.path.join(DATA_DIR, 'positive_test_embedding.npy')
test_neg = os.path.join(DATA_DIR, 'negative_test_embedding.npy')


def _num_rows(path):
    """Return the length of the first axis of the .npy file at ``path``.

    The file is opened with ``mmap_mode='r'``, so its data is not loaded into memory.
    """
    return np.load(path, mmap_mode='r').shape[0]


pos_len = _num_rows(train_pos)
neg_len = _num_rows(train_neg)
test_pos_len = _num_rows(test_pos)
test_neg_len = _num_rows(test_neg)

print(f"训练集正样本: {pos_len}, 负样本: {neg_len}")
print(f"测试集正样本: {test_pos_len}, 负样本: {test_neg_len}")

### 上采样方法

In [3]:
# Balance the training set by giving the negatives exactly pos_len indices.
# - Negatives are the minority: keep every negative once and draw only the shortfall
#   with replacement. A plain bootstrap of pos_len draws would leave some negatives
#   never sampled.
# - Negatives are the majority: this is undersampling, so draw without replacement
#   to avoid duplicates.
pos_indices = np.arange(pos_len)
if neg_len <= pos_len:
    extra_neg = np.random.choice(neg_len, size=pos_len - neg_len, replace=True)
    neg_indices = np.concatenate([np.arange(neg_len), extra_neg])
else:
    neg_indices = np.random.choice(neg_len, size=pos_len, replace=False)

train_dataset = ProteinNPYIndexDataset(train_pos, train_neg, pos_indices, neg_indices)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)

### 权重采样方法

In [4]:
# 权重采样
all_indices = np.concatenate([np.arange(pos_len), np.arange(neg_len)])
all_labels = np.concatenate([np.ones(pos_len, dtype=int), np.zeros(neg_len, dtype=int)])
weights = np.zeros_like(all_labels, dtype=np.float32)
weights[all_labels == 1] = 1. / pos_len
weights[all_labels == 0] = 1. / neg_len

sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
train_dataset_w = ProteinNPYIndexDataset(train_pos, train_neg, np.arange(pos_len), np.arange(neg_len))
train_loader_w = DataLoader(train_dataset_w, batch_size=32, sampler=sampler, num_workers=2)

In [5]:
# Evaluation data: every test example exactly once, in file order (positives, then negatives).
test_dataset = ProteinNPYIndexDataset(
    test_pos, test_neg,
    pos_indices=np.arange(test_pos_len),
    neg_indices=np.arange(test_neg_len),
)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

In [6]:
def train_one_epoch(model, loader, optimizer, criterion, device, aux_loss_weight=0.0):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module whose forward returns ``(logits, aux_loss)``.
        loader: iterable of ``(x, y)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: classification loss applied to ``(logits, y)``.
        device: device the batches are moved to.
        aux_loss_weight: weight of the model's auxiliary loss (e.g. MoE load balancing)
            in the optimised objective. The default 0.0 ignores it, as before.

    Returns:
        Mean of the optimised loss over the batches seen, or NaN if ``loader`` yields
        no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits, aux_loss = model(x)
        loss = criterion(logits, y)
        if aux_loss_weight:
            loss = loss + aux_loss_weight * aux_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    # Count batches instead of using len(loader): works for any iterable and avoids
    # dividing by zero on an empty loader.
    return total_loss / n_batches if n_batches else float("nan")

In [8]:
# 4 MoE Transformer layers over 1152-dim inputs; each token is routed to 3 of 30 experts.
model = TransformerMoE(
    d_model=1152,
    nhead=8,
    d_ff=2048,
    num_layers=4,
    num_experts=30,
    k=3,
    dropout=0.1,
    noisy_std=1.0,
    num_classes=2,
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

### 使用 train_loader 训练（上采样方法）

In [None]:
save_dir = '/exp_data/sjx/star/experiments/gan_anlysis/gan_sampling/weight_points/'
os.makedirs(save_dir, exist_ok=True)
last_path = os.path.join(save_dir, 'moe_last.pth')

# Train a freshly seeded model and optimizer so this cell is idempotent. Re-running it,
# or running it after another experiment, must not continue from already-updated
# weights or AdamW state. Hyperparameters mirror the model-construction cell above.
model = optimizer = None  # release any previous model/optimizer before allocating new ones
set_seed(42)
model = TransformerMoE(
    d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=3, dropout=0.1, noisy_std=1.0, num_classes=2
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)

epochs = 10
for epoch in range(epochs):
    loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}")
    # Overwrite the checkpoint after every epoch so the latest weights survive an interruption.
    torch.save(model.state_dict(), last_path)

In [10]:
print("测试集评估：")
save_dir = '/exp_data/sjx/star/experiments/gan_anlysis/gan_sampling/weight_points/'
# map_location restores the checkpoint onto the current `device`, even if it was
# written from a different device (e.g. saved on GPU, evaluated on CPU).
model.load_state_dict(
    torch.load(os.path.join(save_dir, 'moe_last.pth'), map_location=device)
)
eval_model(model, test_loader, device)

### 使用 train_loader_w 训练（权重采样方法）

In [None]:
save_dir = '/exp_data/sjx/star/experiments/gan_anlysis/gan_sampling/weight_points/'
os.makedirs(save_dir, exist_ok=True)
last_path = os.path.join(save_dir, 'moe_last_w.pth')

# Start from a fresh model and optimizer with the same seed as the upsampling run.
# Continuing from the model already trained with upsampling would confound the
# comparison between the two sampling strategies. Hyperparameters mirror the
# model-construction cell above.
model = optimizer = None  # release any previous model/optimizer before allocating new ones
set_seed(42)
model = TransformerMoE(
    d_model=1152, nhead=8, d_ff=2048, num_layers=4, num_experts=30, k=3, dropout=0.1, noisy_std=1.0, num_classes=2
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)

epochs = 10
for epoch in range(epochs):
    loss = train_one_epoch(model, train_loader_w, optimizer, criterion, device)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}")
    # Overwrite the checkpoint after every epoch so the latest weights survive an interruption.
    torch.save(model.state_dict(), last_path)

In [11]:
print("测试集评估：")
save_dir = '/exp_data/sjx/star/experiments/gan_anlysis/gan_sampling/weight_points/'
# map_location restores the checkpoint onto the current `device`, even if it was
# written from a different device (e.g. saved on GPU, evaluated on CPU).
model.load_state_dict(
    torch.load(os.path.join(save_dir, 'moe_last_w.pth'), map_location=device)
)
eval_model(model, test_loader, device)