In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import umap
import os
from sklearn.decomposition import PCA
import random

# ========================
# Config
# ========================
# Seed every RNG this notebook touches (Python, NumPy, PyTorch CPU and CUDA)
# and force deterministic cuDNN kernels so that runs are repeatable.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seq_len = 300      # positions per sample
embed_dim = 640    # features per position (an earlier comment claimed 960; the code uses 640)
noise_dim = 128    # size of the generator's input noise vector
batch_size = 16    # training mini-batch size
n_critic = 5       # discriminator updates per generator update
lambda_gp = 10     # weight of the gradient-penalty term in the discriminator loss
epochs = 300       # number of training epochs

# Directory for generator checkpoints.
checkpoint_dir = "/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/checkpoints/"
os.makedirs(checkpoint_dir, exist_ok=True)

# Directory for input / output .npy data.
data_dir = "/exp_data/sjx/star/experiments/qianruxuanze/esm2_150M/data/"
os.makedirs(data_dir, exist_ok=True)

# ========== Has training already been completed? ==========
# Training counts as done when both the final-epoch checkpoint and the
# "best" checkpoint exist. The final checkpoint name is derived from `epochs`
# so this check stays correct if the epoch count is changed.
epoch300_path = os.path.join(checkpoint_dir, f"generator_epoch{epochs}.pt")
best_path = os.path.join(checkpoint_dir, "best_generator.pt")

skip_training = os.path.exists(epoch300_path) and os.path.exists(best_path)
if skip_training:
    print("[✓] 已检测到已完成训练的模型，将跳过训练阶段")

In [6]:
# ========================
# Generator with Conv1d (architecture unchanged)
# ========================
class Generator(nn.Module):
    """Map noise vectors to sequences of shape (B, sequence_len, embedding_dim).

    A linear layer expands the noise to 256 channels per position, then two
    Conv1d layers (256 -> 512 -> embedding_dim) refine it. The final Tanh
    bounds every output value to [-1, 1].

    Args:
        latent_dim: Size of the input noise vector. Defaults to the
            notebook-level ``noise_dim``.
        sequence_len: Number of output positions. Defaults to the
            notebook-level ``seq_len``.
        embedding_dim: Features per output position. Defaults to the
            notebook-level ``embed_dim``.

    Layer names and order (``fc``, ``conv.0``, ``conv.2``) are the same as
    before, so existing checkpoints still load.
    """

    def __init__(self, latent_dim=None, sequence_len=None, embedding_dim=None):
        super().__init__()
        # Capture the sizes on the instance so that forward() does not depend
        # on notebook globals that may be edited after construction.
        self.latent_dim = noise_dim if latent_dim is None else latent_dim
        self.sequence_len = seq_len if sequence_len is None else sequence_len
        self.embedding_dim = embed_dim if embedding_dim is None else embedding_dim

        self.fc = nn.Linear(self.latent_dim, self.sequence_len * 256)
        self.conv = nn.Sequential(
            nn.Conv1d(256, 512, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv1d(512, self.embedding_dim, 3, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        """Return generated sequences for noise ``z`` of shape (B, latent_dim)."""
        x = self.fc(z).view(-1, 256, self.sequence_len)  # (B, 256, L)
        x = self.conv(x)                                 # (B, embedding_dim, L)
        return x.transpose(1, 2)                         # (B, L, embedding_dim)

# ========================
# Discriminator with Conv1d (architecture unchanged)
# ========================
class Discriminator(nn.Module):
    """Score sequences of shape (B, sequence_len, embedding_dim) with one scalar each.

    Two Conv1d layers (embedding_dim -> 256 -> 128) run along the sequence
    axis; the result is flattened and passed through a two-layer MLP that
    outputs an unbounded score (no sigmoid).

    Args:
        sequence_len: Number of positions per input sample. Defaults to the
            notebook-level ``seq_len``.
        embedding_dim: Features per position. Defaults to the notebook-level
            ``embed_dim``.

    Layer names and order (``conv.0``, ``conv.2``, ``fc.0``, ``fc.2``) are the
    same as before, so existing checkpoints still load.
    """

    def __init__(self, sequence_len=None, embedding_dim=None):
        super().__init__()
        self.sequence_len = seq_len if sequence_len is None else sequence_len
        self.embedding_dim = embed_dim if embedding_dim is None else embedding_dim

        self.conv = nn.Sequential(
            nn.Conv1d(self.embedding_dim, 256, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv1d(256, 128, 3, padding=1),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * self.sequence_len, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        """Return scores of shape (B, 1) for input ``x`` of shape (B, L, embedding_dim)."""
        x = x.transpose(1, 2)           # (B, embedding_dim, L): Conv1d wants channels first
        x = self.conv(x)                # (B, 128, L)
        x = x.reshape(x.size(0), -1)    # (B, 128 * L)
        return self.fc(x)

# =======================
# Gradient Penalty
# ========================
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Return the WGAN-GP gradient penalty for a batch.

    Real and fake samples are mixed with one random coefficient per sample;
    the penalty is the batch mean of ``(||dD/dx||_2 - 1) ** 2`` evaluated at
    those mixed points.

    Args:
        D: Callable mapping a batch of samples to per-sample scores.
        real_samples: Tensor with the batch on dim 0.
        fake_samples: Tensor with the same shape as ``real_samples``.

    Returns:
        Scalar tensor. It is built with ``create_graph=True``, so it can be
        back-propagated through as part of the discriminator loss.
    """
    # One mixing coefficient per sample, broadcast over every other dim.
    # Allocate on the samples' own device rather than a notebook global, so
    # the function works wherever the inputs live.
    alpha_shape = (real_samples.size(0),) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)

    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,   # the penalty itself must be differentiable
        retain_graph=True,
    )[0]

    # Per-sample L2 norm over all non-batch dims.
    grad_norm = gradients.reshape(gradients.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()

In [7]:
# =======================
# Load Real Data (train and test negatives are loaded and concatenated)
# ========================
# NOTE(review): the test split is included in the data the GAN is trained on.
# If the generated samples later augment a training set that is evaluated on
# this same test split, that is a form of leakage. Confirm this is intended.
print("正在加载训练数据...")
train_data = np.load(os.path.join(data_dir, "negative_train.npy"))
print("训练数据 shape:", train_data.shape)

print("正在加载测试数据...")
test_data = np.load(os.path.join(data_dir, "negative_test.npy"))
print("测试数据 shape:", test_data.shape)

# Concatenate, then drop the two source arrays to release their memory.
print("正在拼接数据...")
real_data = np.concatenate([train_data, test_data], axis=0)
del train_data, test_data

print("拼接后数据 shape:", real_data.shape)
# NOTE(review): the Generator ends in Tanh, so it can only produce values in
# [-1, 1]. Check that the range printed below is compatible with that.
print("真实数据值域：", real_data.min(), real_data.max())

# Fail early with a clear message instead of a shape error deep inside the model.
if real_data.shape[1:] != (seq_len, embed_dim):
    raise ValueError(
        f"Expected samples of shape ({seq_len}, {embed_dim}), got {real_data.shape[1:]}"
    )

# as_tensor reuses the NumPy buffer when the array is already float32, which
# avoids a second full copy of the dataset. It only copies when a dtype
# conversion is needed. The tensor may therefore share memory with real_data.
real_tensor = torch.as_tensor(real_data, dtype=torch.float32)
dataloader = DataLoader(TensorDataset(real_tensor), batch_size=batch_size, shuffle=True)

In [8]:
# ========================
# Models (optimizer hyper-parameters kept exactly as in the original setup)
# ========================
G = Generator().to(device)
D = Discriminator().to(device)
optimizer_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.9))  # learning rate and betas kept as in the original setup
optimizer_D = optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.9))  # learning rate and betas kept as in the original setup

# =======================
# Training (training logic kept exactly as in the original setup)
# ========================
# Discriminator loss: -mean(D(real)) + mean(D(fake)) + lambda_gp * gradient penalty.
# Generator loss:     -mean(D(G(z))).
# The whole loop is skipped when skip_training is set; G and D then keep their
# freshly initialised weights until a checkpoint is loaded elsewhere.
best_g_loss = float("inf")  # lowest generator loss seen so far
if not skip_training:
    for epoch in range(1, epochs + 1):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}")
        for i, (real,) in enumerate(pbar):
            real = real.to(device)

            # === Train D ===
            # n_critic discriminator updates per batch. Every update reuses the
            # same real batch but draws fresh noise. The fake batch is detached
            # so that no gradient reaches G during these steps.
            for _ in range(n_critic):
                z = torch.randn(real.size(0), noise_dim).to(device)
                fake = G(z).detach()
                real_score = D(real)
                fake_score = D(fake)
                gp = compute_gradient_penalty(D, real, fake)
                d_loss = -torch.mean(real_score) + torch.mean(fake_score) + lambda_gp * gp
                optimizer_D.zero_grad()
                d_loss.backward()
                optimizer_D.step()

            # === Train G ===
            # One generator update per batch. This backward pass also fills the
            # .grad of D's parameters; they are cleared by optimizer_D.zero_grad()
            # before the next discriminator step.
            z = torch.randn(real.size(0), noise_dim).to(device)
            fake = G(z)
            g_loss = -torch.mean(D(fake))
            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

            # d_loss shown here is from the last of the n_critic discriminator updates.
            pbar.set_postfix({
                "D loss": f"{d_loss.item():.2f}",
                "G loss": f"{g_loss.item():.2f}"
            })

        # === Save a checkpoint every 100 epochs and at the final epoch ===
        # (The previous comment said "every 10 epochs"; the condition below uses 100.)
        if epoch % 100 == 0 or epoch == epochs:
            save_path = os.path.join(checkpoint_dir, f"generator_epoch{epoch}.pt")
            torch.save(G.state_dict(), save_path)
            print(f"[Checkpoint] Saved generator to {save_path}")

        # === Save the generator with the lowest G loss so far ===
        # NOTE(review): g_loss here is the value from the LAST mini-batch of the
        # epoch only, not an epoch average, so "best" is chosen on a noisy
        # single-batch signal. If the dataloader yields no batches, g_loss is
        # undefined at this point.
        if g_loss.item() < best_g_loss:
            best_g_loss = g_loss.item()
            best_path = os.path.join(checkpoint_dir, "best_generator.pt")
            torch.save(G.state_dict(), best_path)
            print(f"[BEST] Saved best generator with G loss = {best_g_loss:.4f}")

In [9]:
# =======================
# Load the trained generator, sample synthetic data, build the augmented train set
# ========================
# The weights loaded here are the FINAL-EPOCH checkpoint, not best_generator.pt.
# The earlier comments and log message said "best generator", which did not
# match the file actually loaded. The message now reports the real path.
# NOTE(review): if best_generator.pt was the intended source, change the
# filename below instead.
final_ckpt_path = os.path.join(checkpoint_dir, f"generator_epoch{epochs}.pt")
print(f"\nLoading generator checkpoint for data generation: {final_ckpt_path}")

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine too.
G.load_state_dict(torch.load(final_ckpt_path, map_location=device))
G.eval()

# Sample 2435 sequences in chunks of 256.
# NOTE(review): this rebinds the notebook-level `batch_size` (16 in the config
# cell). Re-running the data-loading cell after this one would then build the
# training DataLoader with 256. The rebinding is kept as it was in case later
# cells rely on it.
gen_total = 2435
batch_size = 256
generated = []

with torch.no_grad():
    total = 0
    while total < gen_total:
        current_batch = min(batch_size, gen_total - total)
        # Noise is drawn on the CPU and then moved, exactly as before, so the
        # seeded random stream (and therefore the generated data) is unchanged.
        z = torch.randn(current_batch, noise_dim).to(device)
        fake = G(z).cpu().numpy()
        generated.append(fake)
        total += current_batch

generated = np.concatenate(generated, axis=0)
print(f"生成了 {generated.shape[0]} 条数据，维度: {generated.shape[1:]}")

# Keep the first 1948 generated samples for the training-set augmentation.
generated_1948 = generated[:1948]
print(f"取前1948条数据，shape: {generated_1948.shape}")

# Reload the original training negatives to concatenate with the synthetic ones.
print("加载原始训练数据用于拼接...")
original_train_data = np.load(os.path.join(data_dir, "negative_train.npy"))
print(f"原始训练数据 shape: {original_train_data.shape}")

# Fail with a clear message if the per-sample shapes do not line up.
if original_train_data.shape[1:] != generated_1948.shape[1:]:
    raise ValueError(
        f"Shape mismatch: original samples are {original_train_data.shape[1:]}, "
        f"generated samples are {generated_1948.shape[1:]}"
    )

# Original samples first, synthetic samples appended after them.
final_train_data = np.concatenate([original_train_data, generated_1948], axis=0)
print(f"最终训练数据 shape: {final_train_data.shape}")

# Save the augmented training-set negatives.
save_data_path = os.path.join(data_dir, "negative_train_embedding_enhanced.npy")
np.save(save_data_path, final_train_data)
print(f"已保存增强的训练集负样本到: {save_data_path}")
print(f"数据维度: {final_train_data.shape}")