In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively list every file under the Kaggle input directory, one full path per line.
for dirname, _, filenames in os.walk('/kaggle/input'):
    file_paths = (os.path.join(dirname, filename) for filename in filenames)
    for file_path in file_paths:
        print(file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# ------------------------------
# 定义 VAE 模型（编码器+解码器）
# ------------------------------
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps an input vector to the mean and log-variance of a
    diagonal Gaussian over the latent space; the decoder maps a latent
    sample back to input space and squashes it with ``Tanh``.

    Args:
        input_dim: Size of each input (and reconstructed) vector.
        hidden_dim: Width of the hidden layer used by encoder and decoder.
        latent_dim: Size of the latent vector.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.input_dim = input_dim

        # Encoder trunk shared by the two Gaussian-parameter heads.
        encoder_layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        self.encoder = nn.Sequential(*encoder_layers)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)      # mean head
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # log-variance head

        # Decoder: latent -> hidden -> input space. The final Tanh bounds the
        # output to [-1, 1], on the assumption that inputs are scaled likewise.
        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        hidden = self.encoder(x)
        return self.fc_mu(hidden), self.fc_logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map a latent vector ``z`` back to input space."""
        reconstruction = self.decoder(z)
        return reconstruction

    def forward(self, x):
        """Encode, sample and decode; return ``(recon, mu, logvar, z)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar, z

# ------------------------------
# 定义隐变量判别器（Latent Discriminator）
# ------------------------------
class LatentDiscriminator(nn.Module):
    """Small MLP that scores latent vectors with a value in (0, 1).

    Args:
        latent_dim: Size of the latent vectors fed to the network.
        hidden_dim: Width of the single hidden layer (default 64).
    """

    def __init__(self, latent_dim, hidden_dim=64):
        super().__init__()
        layers = (
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # squash the single logit into a probability
        )
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Return one sigmoid score per row of ``z``."""
        scores = self.net(z)
        return scores

# ------------------------------
# 定义各项损失函数
# ------------------------------
def reconstruction_loss(recon_x, x):
    """Return the mean squared error between ``recon_x`` and ``x``.

    The squared error is averaged over all elements, giving a scalar tensor.
    """
    mse = nn.MSELoss()
    return mse(recon_x, x)

def kl_divergence(mu, logvar):
    """Return the KL divergence of N(mu, exp(logvar)) from N(0, I).

    Uses the closed form ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``,
    summed over every element and divided by the size of the first
    dimension of ``mu``.
    """
    batch = mu.size(0)
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * per_element.sum() / batch

def adversarial_loss(discriminator, z, target_label):
    """Return the binary cross-entropy of ``discriminator(z)`` against a constant label.

    Args:
        discriminator: Callable mapping ``z`` to probabilities in [0, 1].
        z: Batch of latent vectors to score.
        target_label: Label (e.g. 0 or 1) used as the target for every score.
    """
    scores = discriminator(z)
    # Constant target with the same shape, dtype and device as the scores.
    labels = torch.full(
        scores.shape, target_label, dtype=scores.dtype, device=scores.device
    )
    bce = nn.BCELoss()
    return bce(scores, labels)

# ------------------------------
# 主训练流程
# ------------------------------
def main(
    input_dim=384,
    hidden_dim=128,
    latent_dim=32,
    batch_size=64,
    num_epochs=10,
    learning_rate=0.001,
    beta_kl=1.0,
    gamma_adv=1.0,
    n_samples=30000,
    seed=42,
    save_path="vae_with_latent_disc.pth",
):
    """Train a VAE together with a latent-space discriminator on synthetic data.

    Each batch runs two steps: (1) the discriminator learns to separate
    standard-normal latent samples (label 1) from latents produced by the VAE
    (label 0); (2) the VAE minimises reconstruction error plus weighted KL
    and adversarial terms, the latter rewarding latents the discriminator
    labels as 1. Per-epoch averages are printed and the VAE weights are saved.

    Args:
        input_dim: Input vector size (e.g. a Sentence-BERT embedding size).
        hidden_dim: Hidden layer width of the VAE.
        latent_dim: Latent vector size.
        batch_size: Mini-batch size.
        num_epochs: Number of passes over the data.
        learning_rate: Adam learning rate for both optimizers.
        beta_kl: Weight of the KL term in the VAE loss.
        gamma_adv: Weight of the adversarial term in the VAE loss.
        n_samples: Number of synthetic samples to generate.
        seed: Seed for NumPy and PyTorch RNGs; ``None`` leaves them unseeded.
        save_path: File the trained VAE ``state_dict`` is written to.

    Raises:
        ValueError: If ``n_samples`` or ``batch_size`` is not positive.
    """
    if n_samples <= 0 or batch_size <= 0:
        raise ValueError("n_samples and batch_size must be positive")

    # Seed every source of randomness (synthetic data, weight init, shuffling,
    # reparameterization noise) so that runs are repeatable.
    if seed is not None:
        torch.manual_seed(seed)
    rng = np.random.default_rng(seed)

    # Demo only: random data scaled to [-1, 1] stands in for real inputs.
    # Replace this with your own dataset.
    data = rng.uniform(-1, 1, size=(n_samples, input_dim)).astype(np.float32)
    dataset = TensorDataset(torch.from_numpy(data))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Models
    vae = VAE(input_dim, hidden_dim, latent_dim).to(device)
    latent_disc = LatentDiscriminator(latent_dim).to(device)

    # Optimizers
    optimizer_vae = optim.Adam(vae.parameters(), lr=learning_rate)
    optimizer_disc = optim.Adam(latent_disc.parameters(), lr=learning_rate)

    # Training loop
    for epoch in range(num_epochs):
        total_loss = 0.0
        total_recon_loss = 0.0
        total_kl_loss = 0.0
        total_adv_loss = 0.0

        for batch in dataloader:
            x = batch[0].to(device)
            n_batch = x.size(0)

            # ----- Step 1: train the discriminator -----
            vae.eval()
            latent_disc.train()
            latent_disc.requires_grad_(True)
            optimizer_disc.zero_grad()

            # "Real" latents are drawn from the standard normal prior.
            z_real = torch.randn(n_batch, latent_dim, device=device)
            # "Fake" latents come from the VAE; no_grad keeps the VAE out of
            # this step's autograd graph.
            with torch.no_grad():
                _, _, _, z_fake = vae(x)

            loss_disc_real = adversarial_loss(latent_disc, z_real, 1)  # real -> 1
            loss_disc_fake = adversarial_loss(latent_disc, z_fake, 0)  # fake -> 0
            loss_disc = loss_disc_real + loss_disc_fake
            loss_disc.backward()
            optimizer_disc.step()

            # ----- Step 2: train the VAE (encoder + decoder) -----
            vae.train()
            latent_disc.eval()
            # eval() only switches layer modes; disabling requires_grad is what
            # actually freezes the discriminator, so backward() does not compute
            # or accumulate gradients for its parameters. Gradients still flow
            # through it to z.
            latent_disc.requires_grad_(False)
            optimizer_vae.zero_grad()

            recon, mu, logvar, z = vae(x)
            loss_recon = reconstruction_loss(recon, x)
            loss_kl = kl_divergence(mu, logvar)
            # Adversarial term: push the discriminator to label encoder
            # latents as "real" (label 1).
            loss_adv = adversarial_loss(latent_disc, z, 1)

            loss = loss_recon + beta_kl * loss_kl + gamma_adv * loss_adv
            loss.backward()
            optimizer_vae.step()

            # Weight batch means by batch size so the epoch average is exact
            # even when the last batch is smaller.
            total_loss += loss.item() * n_batch
            total_recon_loss += loss_recon.item() * n_batch
            total_kl_loss += loss_kl.item() * n_batch
            total_adv_loss += loss_adv.item() * n_batch

        avg_loss = total_loss / n_samples
        avg_recon_loss = total_recon_loss / n_samples
        avg_kl_loss = total_kl_loss / n_samples
        avg_adv_loss = total_adv_loss / n_samples
        print(f"Epoch [{epoch+1}/{num_epochs}] Total Loss: {avg_loss:.4f}, Recon: {avg_recon_loss:.4f}, KL: {avg_kl_loss:.4f}, Adv: {avg_adv_loss:.4f}")

    # Persist the trained VAE weights.
    torch.save(vae.state_dict(), save_path)
    print("Training complete, model saved.")

# Run training only when this code executes as the top-level program;
# the call is skipped if the code is imported as a module.
if __name__ == "__main__":
    main()


Epoch [1/10] Total Loss: 1.0624, Recon: 0.3391, KL: 0.0302, Adv: 0.6931
Epoch [2/10] Total Loss: 1.0287, Recon: 0.3341, KL: 0.0012, Adv: 0.6933
Epoch [3/10] Total Loss: 1.0275, Recon: 0.3337, KL: 0.0004, Adv: 0.6933
Epoch [4/10] Total Loss: 1.0271, Recon: 0.3336, KL: 0.0003, Adv: 0.6932
Epoch [5/10] Total Loss: 1.0271, Recon: 0.3336, KL: 0.0001, Adv: 0.6934
Epoch [6/10] Total Loss: 1.0270, Recon: 0.3336, KL: 0.0001, Adv: 0.6933
Epoch [7/10] Total Loss: 1.0270, Recon: 0.3336, KL: 0.0002, Adv: 0.6933
Epoch [8/10] Total Loss: 1.0266, Recon: 0.3336, KL: 0.0002, Adv: 0.6929
Epoch [9/10] Total Loss: 1.0269, Recon: 0.3336, KL: 0.0001, Adv: 0.6932
Epoch [10/10] Total Loss: 1.0269, Recon: 0.3336, KL: 0.0001, Adv: 0.6932
Training complete, model saved.
