<a href="https://colab.research.google.com/github/staerkjoe/ML_colab/blob/main/script_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils
import matplotlib.pyplot as plt

# Define a simple generator and discriminator for CIFAR-10
class Generator(nn.Module):
    """Fully connected generator that maps latent noise vectors to images.

    Args:
        latent_dim: Length of each input noise vector. The default of 100
            matches the ``torch.randn(batch_size, 100)`` noise sampled in the
            training loop.
        img_shape: ``(channels, height, width)`` of the generated images;
            defaults to CIFAR-10's ``(3, 32, 32)``.
    """

    def __init__(self, latent_dim=100, img_shape=(3, 32, 32)):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)
        channels, height, width = self.img_shape
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, channels * height * width),
            # Tanh bounds outputs to [-1, 1], the same range real images have
            # after Normalize(mean=0.5, std=0.5) in the data pipeline.
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate a batch of images.

        Args:
            x: Noise tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, *img_shape)`` with values in [-1, 1].
        """
        return self.fc(x).view(x.size(0), *self.img_shape)

class Discriminator(nn.Module):
    """Fully connected discriminator scoring images as real (1) or fake (0).

    Args:
        img_shape: ``(channels, height, width)`` of the input images;
            defaults to CIFAR-10's ``(3, 32, 32)``.
    """

    def __init__(self, img_shape=(3, 32, 32)):
        super(Discriminator, self).__init__()
        channels, height, width = img_shape
        self.fc = nn.Sequential(
            nn.Linear(channels * height * width, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            # Sigmoid yields a probability, which nn.BCELoss requires as input.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor of shape ``(batch, C, H, W)`` or already flattened to
                ``(batch, C*H*W)``; either way it is flattened here.

        Returns:
            Tensor of shape ``(batch, 1)`` with probabilities in [0, 1].
        """
        return self.fc(x.view(x.size(0), -1))


# Build both networks (the generator is constructed first, then the discriminator).
generator, discriminator = Generator(), Discriminator()

# nn.BCELoss expects the discriminator to output probabilities in [0, 1].
# Each network gets its own Adam optimizer with the same learning rate.
criterion = nn.BCELoss()
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=0.0002) for net in (generator, discriminator)
)

# Per-iteration loss histories; train_gan appends to and plots these.
g_losses, d_losses = [], []

# CIFAR-10 training split. ToTensor scales pixels to [0, 1], then a per-channel
# Normalize with mean 0.5 and std 0.5 maps them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)
dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=64)

# Training loop
def train_gan(generator, discriminator, dataloader, num_epochs, latent_dim=100):
    """Train ``generator`` and ``discriminator`` adversarially.

    Uses the module-level ``criterion``, ``optimizer_G`` and ``optimizer_D``.
    Appends one D loss and one G loss per batch to the module-level
    ``d_losses`` / ``g_losses`` lists. Every 10 epochs it saves 32 generated
    samples to ``fake_cifar_samples_epoch_<n>.png``. After every epoch it
    saves the loss curves to ``loss_plot_epoch_<n>.png`` and shows them.

    All tensors are created on the device of ``generator``'s parameters, so
    the models may be moved to a GPU before calling this. ``discriminator``
    is assumed to live on the same device.

    Args:
        generator: Module mapping ``(batch, latent_dim)`` noise to images.
        discriminator: Module mapping images to ``(batch, 1)`` probabilities.
        dataloader: Iterable of ``(images, labels)`` batches; labels are ignored.
        num_epochs: Number of full passes over ``dataloader``.
        latent_dim: Length of the noise vectors fed to ``generator``.
    """
    # Infer the device from the model instead of hard-coding CPU tensors,
    # which would cause a device mismatch once the models are moved to a GPU.
    device = next(generator.parameters()).device

    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.size(0)  # the last batch may be smaller
            real_images = real_images.to(device).view(batch_size, -1)
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # Discriminator step: push real -> 1 and fake -> 0.
            optimizer_D.zero_grad()
            outputs = discriminator(real_images)
            d_loss_real = criterion(outputs, real_labels)
            d_loss_real.backward()

            z = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(z)
            # detach() stops the D loss from backpropagating into the generator.
            outputs = discriminator(fake_images.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            d_loss_fake.backward()
            d_loss = d_loss_real + d_loss_fake
            optimizer_D.step()

            # Generator step: try to make D label the fakes as real.
            # This backward also writes gradients into D's parameters; they
            # are cleared by optimizer_D.zero_grad() before D's next step.
            optimizer_G.zero_grad()
            outputs = discriminator(fake_images)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optimizer_G.step()

            d_losses.append(d_loss.item())
            g_losses.append(g_loss.item())

            if (i + 1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(dataloader)}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}')

        # Generate and save a sample of fake images
        if (epoch + 1) % 10 == 0:
            with torch.no_grad():
                z = torch.randn(32, latent_dim, device=device)
                fake_samples = generator(z)
                vutils.save_image(fake_samples, f'fake_cifar_samples_epoch_{epoch+1}.png', normalize=True)

        _save_loss_plot(g_losses, d_losses, epoch + 1)


def _save_loss_plot(g_history, d_history, epoch_number):
    """Plot the G/D loss histories, save them to loss_plot_epoch_<n>.png, and show them.

    The figure is closed afterwards so that calling this once per epoch does
    not accumulate open figures.
    """
    fig = plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss")
    plt.plot(g_history, label="G Loss")
    plt.plot(d_history, label="D Loss")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(f'loss_plot_epoch_{epoch_number}.png')
    plt.show()
    plt.close(fig)

# Entry point: run 50 epochs of training on the module-level models and loader.
train_gan(
    generator=generator,
    discriminator=discriminator,
    dataloader=dataloader,
    num_epochs=50,
)