In [None]:
from gan import GAN
import torch
from torch import nn

In [None]:
# Model dimensions and reproducibility settings.
noise_dim = 100  # Dimensionality of the noise vector fed to the generator
data_dim = 784  # Output/input dimensionality (e.g., 28x28 image flattened)
hidden_dim = 128  # Hidden size passed to GAN; how it is used is defined in gan.py
seed = 0  # Seed for torch's global RNG so noise and training runs are reproducible

# Seed before building the model so that any torch-RNG-based weight
# initialisation inside GAN is reproducible as well.
torch.manual_seed(seed)

# Create GAN instance
gan = GAN(noise_dim, data_dim, hidden_dim)

# Smoke test: push one batch of noise through the full GAN.
batch_size = 16
noise = torch.randn(batch_size, noise_dim)

# This is only a sanity check, so skip building an autograd graph.
with torch.no_grad():
    classification = gan(noise)

# Last expression -> rich notebook display instead of print().
classification

In [None]:
# Training loop hyperparameters.
num_epochs = 1000
lr = 0.0002
n_critic = 5  # Discriminator updates per generator update
log_every = 100  # Print losses every `log_every` epochs (and on the final epoch)

# Optimizers (Adam with betas=(0.5, 0.999), as in the original setup)
optimizer_g = torch.optim.Adam(gan.generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(gan.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Loss function. nn.BCELoss expects probabilities in [0, 1].
# NOTE(review): this assumes the discriminator ends in a sigmoid; confirm in
# gan.py. If it returns raw logits, switch to nn.BCEWithLogitsLoss.
criterion = nn.BCELoss()

# Target labels are constant, so allocate them once instead of every iteration.
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)

for epoch in range(num_epochs):
    # --- Discriminator: n_critic updates per epoch ---
    for _ in range(n_critic):
        # TODO: replace this Gaussian placeholder with a batch from the real dataset.
        real_data = torch.randn(batch_size, data_dim)

        noise = torch.randn(batch_size, noise_dim)
        # Generated samples are only an input to D here. no_grad avoids building
        # a graph through the generator that would be discarded immediately
        # (the original used .detach() after a full graph-building forward).
        with torch.no_grad():
            fake_data = gan.generator(noise)

        optimizer_d.zero_grad()

        output_real = gan.discriminator(real_data)
        loss_real = criterion(output_real, real_labels)

        output_fake = gan.discriminator(fake_data)
        loss_fake = criterion(output_fake, fake_labels)

        loss_d = loss_real + loss_fake
        loss_d.backward()
        optimizer_d.step()

    # --- Generator: one update, pushing D to label generated samples as real ---
    noise = torch.randn(batch_size, noise_dim)

    optimizer_g.zero_grad()

    fake_data = gan.generator(noise)
    output_fake = gan.discriminator(fake_data)
    loss_g = criterion(output_fake, real_labels)  # Generator tries to fool discriminator

    # backward() also fills the discriminator's .grad buffers; only optimizer_g
    # steps here, and optimizer_d.zero_grad() clears them before the next D update.
    loss_g.backward()
    optimizer_g.step()

    # Throttled progress logging; loss_d is from the last of the n_critic D steps.
    if (epoch + 1) % log_every == 0 or epoch + 1 == num_epochs:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}")
