In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
def train_gan(batch_size=32, num_epochs=100, device='cpu',
              lr=0.0002, betas=(0.5, 0.999), data_root='./data'):
    """Train a Generator/Discriminator pair on MNIST with alternating updates.

    Each iteration does one discriminator step (real vs. detached fake images)
    followed by one generator step, both optimised with Adam and ``nn.BCELoss``.

    NOTE(review): this function relies on ``Generator``, ``Discriminator`` and
    ``latent_dim`` being defined at notebook level in a cell not shown here;
    ``latent_dim`` presumably has to match the generator's expected input
    channels — TODO confirm.

    Args:
        batch_size: Number of MNIST images per training step.
        num_epochs: Number of full passes over the MNIST training set.
        device: Device string or ``torch.device`` to train on.
        lr: Adam learning rate used for both networks.
        betas: Adam ``(beta1, beta2)`` used for both networks.
        data_root: Directory MNIST is downloaded to / loaded from.

    Returns:
        Tuple ``(generator, discriminator)`` holding the trained models.
    """
    device = torch.device(device)

    # ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) rescales them to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    train_dataset = datasets.MNIST(root=data_root, train=True, transform=transform, download=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    # BCELoss expects probabilities, so Discriminator presumably ends in a sigmoid — TODO confirm.
    criterion = nn.BCELoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=betas)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

    # Initialised so the progress print cannot hit an unbound name on an empty loader.
    d_loss = g_loss = None
    for epoch in range(num_epochs):
        for real_images, _ in train_loader:
            real_images = real_images.to(device)
            # Size noise by the actual batch: the final batch is smaller whenever
            # len(train_dataset) is not a multiple of batch_size.
            current_batch = real_images.size(0)

            # ---- Discriminator step ----
            optimizer_D.zero_grad()
            z = torch.randn(current_batch, latent_dim, 1, 1, device=device)
            fake_images = generator(z)

            d_real_out = discriminator(real_images)
            # detach() keeps the discriminator update from back-propagating into the generator.
            d_fake_out = discriminator(fake_images.detach())
            # Targets shaped like the discriminator output so BCELoss sizes always agree.
            d_real_loss = criterion(d_real_out, torch.ones_like(d_real_out))
            d_fake_loss = criterion(d_fake_out, torch.zeros_like(d_fake_out))

            d_loss = d_real_loss + d_fake_loss
            d_loss.backward()
            optimizer_D.step()

            # ---- Generator step ----
            optimizer_G.zero_grad()
            z = torch.randn(current_batch, latent_dim, 1, 1, device=device)
            fake_images = generator(z)
            g_out = discriminator(fake_images)
            # The generator is scored against "real" targets: it improves when D is fooled.
            g_loss = criterion(g_out, torch.ones_like(g_out))
            g_loss.backward()
            optimizer_G.step()

        if d_loss is not None:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}')

    return generator, discriminator