In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Dataset Preparation
# Resize(64) scales the shorter image side to 64 px (square inputs become
# 64x64). ToTensor yields values in [0, 1], and Normalize with mean=std=0.5
# maps each channel to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.Resize(size=64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# NOTE(review): 'path/to/dataset' is a placeholder root directory; set it
# before running (download=True fetches the data into it on first use).
train_dataset = datasets.CIFAR10(
    root='path/to/dataset',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# Model Architecture
class Generator(nn.Module):
    """Fully connected GAN generator that maps noise vectors to images.

    Args:
        latent_dim: Length of each input noise vector.
        img_channels: Channels per generated image (default 3).
        img_size: Height and width of each generated square image (default 64).

    The final ``Tanh`` keeps every output pixel in ``[-1, 1]``.
    """

    def __init__(self, latent_dim, img_channels=3, img_size=64):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        # Single source of truth for the output shape; the last Linear layer
        # and forward() both derive from it instead of repeating literals.
        self.img_shape = (img_channels, img_size, img_size)
        out_features = img_channels * img_size * img_size

        self.gen = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, out_features),
            nn.Tanh()
        )

    def forward(self, noise):
        """Generate images from ``noise`` of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, img_channels, img_size, img_size)`` with
            values in ``[-1, 1]``.
        """
        output = self.gen(noise)
        return output.view(-1, *self.img_shape)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(64 * 64 * 3, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, images):
        images = images.view(-1, 64 * 64 * 3)
        return self.disc(images)

# Initialize the models
latent_dim = 100  # length of each noise vector fed to the generator
generator = Generator(latent_dim)
discriminator = Discriminator()

# Loss functions and optimizers
# The discriminator ends in a Sigmoid, so plain BCELoss is applied to its
# outputs directly.
criterion = nn.BCELoss()
# lr=2e-4 is the DCGAN setting, and that recipe pairs it with beta1=0.5
# instead of Adam's default 0.9; the DCGAN authors report that 0.9 led to
# training oscillation and instability.
gen_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
disc_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Model Training
num_epochs = 100
# Train on whatever device the generator's parameters live on, so moving the
# models (e.g. generator.to("cuda")) before building the optimizers is enough
# to move training as well.
device = next(generator.parameters()).device
for epoch in range(num_epochs):
    # Running sums kept as tensors to avoid a host sync on every batch.
    gen_loss_sum = 0.0
    disc_loss_sum = 0.0
    num_batches = 0
    for real_images, _ in train_loader:
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        # Train the discriminator: push D(real) toward 1 and D(G(z)) toward 0.
        disc_optimizer.zero_grad()
        real_labels = torch.ones(batch_size, 1, device=device)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)

        noise = torch.randn(batch_size, latent_dim, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(batch_size, 1, device=device)
        # detach(): this step only updates the discriminator, so no gradient
        # needs to flow back into the generator.
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)

        disc_loss = real_loss + fake_loss
        disc_loss.backward()
        disc_optimizer.step()

        # Train the generator: label its fresh fakes as real (target 1).
        gen_optimizer.zero_grad()
        noise = torch.randn(batch_size, latent_dim, device=device)
        fake_images = generator(noise)
        fake_labels = torch.ones(batch_size, 1, device=device)
        fake_outputs = discriminator(fake_images)
        gen_loss = criterion(fake_outputs, fake_labels)
        # This backward also fills discriminator grads; they are cleared by
        # disc_optimizer.zero_grad() at the start of the next iteration.
        gen_loss.backward()
        gen_optimizer.step()

        gen_loss_sum += gen_loss.detach()
        disc_loss_sum += disc_loss.detach()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("train_loader produced no batches")

    # Evaluation and logging
    if (epoch + 1) % 10 == 0:
        # Report epoch means; the last batch alone can be much smaller than
        # the rest and gives a noisy reading.
        avg_gen_loss = float(gen_loss_sum) / num_batches
        avg_disc_loss = float(disc_loss_sum) / num_batches
        print(f"Epoch [{epoch + 1}/{num_epochs}], Generator Loss (epoch mean): {avg_gen_loss:.4f}, Discriminator Loss (epoch mean): {avg_disc_loss:.4f}")

        # Generate and visualize sample images (no autograd graph needed).
        with torch.no_grad():
            noise = torch.randn(16, latent_dim, device=device)
            samples = generator(noise)
        # The generator's Tanh output lies in [-1, 1], but imshow expects RGB
        # floats in [0, 1] and clips the rest, so rescale before plotting.
        samples = ((samples + 1) / 2).clamp(0, 1)
        sample_images = samples.cpu().permute(0, 2, 3, 1).numpy()
        fig, axes = plt.subplots(4, 4, figsize=(12, 12))
        for image, ax in zip(sample_images, axes.flat):
            ax.imshow(image)
            ax.axis('off')
        plt.show()
        # Release the figure so ten of them do not pile up over training.
        plt.close(fig)