In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

# Model configuration
noise_dim = 100  # width of the random noise vector fed to the generator
embed_dim = 128  # width of the embedding vectors (generator output / discriminator input)
label_dim = 5  # number of classes; also the width of the one-hot label vector
batch_size = 64  # rows taken from the dataset per training step
num_epochs = 50  # full passes over the dataset
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generator
class Generator(nn.Module):
    """Conditional generator: maps (noise, one-hot label) to an embedding.

    The network is an MLP stored in ``self.fc``: one ``Linear`` + ``ReLU``
    pair per hidden width, followed by a final ``Linear`` with no activation.

    Args:
        z_dim: Width of the noise input. ``None`` (default) uses the
            module-level ``noise_dim``.
        cond_dim: Width of the one-hot label input. ``None`` (default) uses
            the module-level ``label_dim``.
        out_dim: Width of the generated embedding. ``None`` (default) uses
            the module-level ``embed_dim``.
        hidden_dims: Widths of the hidden layers, in order.
    """

    def __init__(self, z_dim=None, cond_dim=None, out_dim=None,
                 hidden_dims=(256, 256)):
        super().__init__()
        # Fall back to the module-level configuration only when a size is not
        # given, so Generator() keeps building exactly the original network.
        z_dim = noise_dim if z_dim is None else z_dim
        cond_dim = label_dim if cond_dim is None else cond_dim
        out_dim = embed_dim if out_dim is None else out_dim

        layers = []
        width = z_dim + cond_dim  # noise and label are concatenated in forward()
        for hidden in hidden_dims:
            layers += [nn.Linear(width, hidden), nn.ReLU()]
            width = hidden
        layers.append(nn.Linear(width, out_dim))
        self.fc = nn.Sequential(*layers)

    def forward(self, z, label_onehot):
        """Return generated embeddings for a batch.

        Args:
            z: Noise tensor; its second dimension must be the noise width.
            label_onehot: One-hot label tensor with the same batch size as
                ``z``; its second dimension must be the label width.

        Returns:
            The output of ``self.fc`` applied to ``z`` and ``label_onehot``
            concatenated along dimension 1.
        """
        return self.fc(torch.cat([z, label_onehot], dim=1))

# Discriminator
class Discriminator(nn.Module):
    """Conditional discriminator: scores an (embedding, one-hot label) pair.

    The network is an MLP stored in ``self.fc``: one ``Linear`` +
    ``LeakyReLU`` pair per hidden width, followed by ``Linear(…, 1)`` and a
    ``Sigmoid``, so each output is a single value in (0, 1).

    Args:
        in_dim: Width of the embedding input. ``None`` (default) uses the
            module-level ``embed_dim``.
        cond_dim: Width of the one-hot label input. ``None`` (default) uses
            the module-level ``label_dim``.
        hidden_dims: Widths of the hidden layers, in order.
        negative_slope: Slope of the ``LeakyReLU`` activations.
    """

    def __init__(self, in_dim=None, cond_dim=None, hidden_dims=(256, 128),
                 negative_slope=0.2):
        super().__init__()
        # Fall back to the module-level configuration only when a size is not
        # given, so Discriminator() keeps building exactly the original network.
        in_dim = embed_dim if in_dim is None else in_dim
        cond_dim = label_dim if cond_dim is None else cond_dim

        layers = []
        width = in_dim + cond_dim  # embedding and label are concatenated in forward()
        for hidden in hidden_dims:
            layers += [nn.Linear(width, hidden), nn.LeakyReLU(negative_slope)]
            width = hidden
        layers += [nn.Linear(width, 1), nn.Sigmoid()]
        self.fc = nn.Sequential(*layers)

    def forward(self, embed, label_onehot):
        """Return one score per row of the batch.

        Args:
            embed: Embedding tensor; its second dimension must be the
                embedding width.
            label_onehot: One-hot label tensor with the same batch size as
                ``embed``; its second dimension must be the label width.

        Returns:
            The output of ``self.fc`` applied to ``embed`` and
            ``label_onehot`` concatenated along dimension 1.
        """
        return self.fc(torch.cat([embed, label_onehot], dim=1))

# One-hot encode
def one_hot(labels, num_classes=5):
    """Return float one-hot rows for ``labels``, placed on the global ``device``.

    Args:
        labels: Integer class indices (a tensor, or anything
            ``torch.as_tensor`` accepts, such as a list of ints).
        num_classes: Number of classes, i.e. the width of each one-hot row.

    Returns:
        A tensor with one extra trailing dimension of size ``num_classes``
        holding the selected rows of the identity matrix, moved to ``device``.
    """
    labels = torch.as_tensor(labels)
    # Build the identity matrix on the same device as the indices: indexing a
    # CPU tensor with CUDA indices raises a RuntimeError.
    identity = torch.eye(num_classes, device=labels.device)
    return identity[labels].to(device)

# Initialize the models, the loss and the optimizers
G = Generator().to(device)
D = Discriminator().to(device)
# Binary cross-entropy on probabilities; D ends in a Sigmoid, so its outputs
# are already in (0, 1).
criterion = nn.BCELoss()
# One Adam optimizer per network so each can be stepped independently.
g_opt = optim.Adam(G.parameters(), lr=0.0002)
d_opt = optim.Adam(D.parameters(), lr=0.0002)

# Placeholder data for testing before switching to the real Token2Vec vectors
real_token_vecs = torch.randn(1000, embed_dim).to(device)  # 1000 random "real" embeddings
real_labels = torch.randint(0, label_dim, (1000,)).to(device)  # random class ids in [0, label_dim)

# Train the GAN
for epoch in range(num_epochs):
    # Walk the dataset in order, batch_size rows at a time. The last batch can
    # be smaller, so every per-batch tensor is sized from bsz.
    for i in range(0, len(real_token_vecs), batch_size):
        real_embed = real_token_vecs[i:i+batch_size]
        labels = real_labels[i:i+batch_size]
        bsz = real_embed.size(0)

        # BCE targets: 1 for "real", 0 for "fake". Created directly on the
        # target device instead of on the CPU followed by a copy.
        real = torch.ones(bsz, 1, device=device)
        fake = torch.zeros(bsz, 1, device=device)
        label_onehot = one_hot(labels)

        # Train Discriminator
        # The generator is not updated in this step, so its samples are drawn
        # without recording an autograd graph that would only be discarded.
        with torch.no_grad():
            z = torch.randn(bsz, noise_dim, device=device)
            fake_embed = G(z, label_onehot)

        d_real = D(real_embed, label_onehot)
        d_fake = D(fake_embed, label_onehot)
        d_loss = criterion(d_real, real) + criterion(d_fake, fake)
        d_opt.zero_grad()
        d_loss.backward()
        d_opt.step()

        # Train Generator
        # Fresh noise, this time with gradients flowing through D into G.
        # The target is "real": G is rewarded when D is fooled.
        z = torch.randn(bsz, noise_dim, device=device)
        fake_embed = G(z, label_onehot)
        d_pred = D(fake_embed, label_onehot)
        g_loss = criterion(d_pred, real)
        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()

    # Report on the first epoch and every 10th epoch. The values shown are the
    # losses of the last batch of the epoch, not an average over the epoch.
    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f"Epoch {epoch+1}/{num_epochs} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")


  cpu = _conversion_method_template(device=torch.device("cpu"))


Epoch 1/50 | D Loss: 1.2696 | G Loss: 0.6184
Epoch 10/50 | D Loss: 0.9492 | G Loss: 1.3521
Epoch 20/50 | D Loss: 1.1142 | G Loss: 1.9039
Epoch 30/50 | D Loss: 1.1614 | G Loss: 0.6390
Epoch 40/50 | D Loss: 0.6237 | G Loss: 1.1354
Epoch 50/50 | D Loss: 0.4800 | G Loss: 1.3073
