In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt

In [None]:
class GeneratorCNN(nn.Module):
    """DCGAN-style generator mapping latent vectors to 28x28 images.

    Args:
        input_size: length of the latent noise vector (channels of the 1x1 input map).
        hidden_size: base channel width; the layers use 4x, 2x and 1x this value.
        output_size: number of channels in the generated image (1 for MNIST).

    The output goes through tanh, so samples lie in [-1, 1] -- the same range as
    training images normalised with Normalize(mean=0.5, std=0.5).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # ConvTranspose2d output size = (in - 1) * stride - 2 * padding + kernel_size,
        # giving 1x1 -> 7x7 -> 14x14 -> 28x28 -> 28x28 (MNIST resolution).
        self.conv1 = nn.ConvTranspose2d(input_size, hidden_size * 4, kernel_size=7, stride=1, padding=0)
        self.conv2 = nn.ConvTranspose2d(hidden_size * 4, hidden_size * 2, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.ConvTranspose2d(hidden_size * 2, hidden_size, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.ConvTranspose2d(hidden_size, output_size, kernel_size=3, stride=1, padding=1)
        self.activation = nn.Tanh()

    def forward(self, x):
        """Generate a batch of images from latent vectors.

        Args:
            x: latent batch of shape (B, input_size) or (B, input_size, 1, 1).

        Returns:
            Tensor of shape (B, output_size, 28, 28) with values in [-1, 1].
        """
        # Treat each latent vector as a 1x1 feature map with input_size channels.
        x = x.view(x.size(0), x.size(1), 1, 1)
        x = self.activation(self.conv1(x))
        x = self.activation(self.conv2(x))
        x = self.activation(self.conv3(x))
        # tanh rather than sigmoid: real images are normalised to [-1, 1], so a
        # [0, 1] output would let the discriminator spot fakes by their missing
        # negative values.
        return torch.tanh(self.conv4(x))

In [None]:
class DiscriminatorCNN(nn.Module):
    """Convolutional discriminator scoring images as real (towards 1) or fake (towards 0).

    Args:
        input_size: number of image channels (1 for grayscale MNIST).
        hidden_size: base channel width; the layers use 1x, 2x and 4x this value.
        output_size: number of sigmoid scores produced per image.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Three stride-2 convs shrink the image to a 4x4 map:
        # 28 -> 14 -> 7 -> 4 (a 32x32 input also ends at 4x4).
        self.conv1 = nn.Conv2d(input_size, hidden_size, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(hidden_size, hidden_size * 2, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(hidden_size * 2, hidden_size * 4, kernel_size=3, stride=2, padding=1)
        # A 4x4 kernel without padding covers the whole remaining 4x4 map, so each
        # image gets exactly `output_size` scores (a 3x3 kernel would leave a 2x2 map).
        self.conv4 = nn.Conv2d(hidden_size * 4, output_size, kernel_size=4, stride=1, padding=0)
        self.activation = nn.LeakyReLU(0.2)
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: image batch of shape (B, input_size, H, W), H = W = 28 or 32.

        Returns:
            Tensor of shape (B, output_size) with sigmoid probabilities in (0, 1).
        """
        x = self.activation(self.conv1(x))
        x = self.activation(self.conv2(x))
        x = self.activation(self.conv3(x))
        x = self.flatten(self.conv4(x))
        # Probabilities (not logits) because the notebook trains with nn.BCELoss.
        return torch.sigmoid(x)

In [None]:
# Fix the seeds of torch's global RNG (weight init, torch.randn noise) and of
# numpy's legacy global RNG so a fresh "Run All" reproduces the same run.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# --- Model shape ---
input_size = 100   # length of the latent noise vector fed to the generator
hidden_size = 256  # base channel width; both conv stacks use multiples of it
output_size = 1    # generator image channels / discriminator output channels

# --- Optimisation ---
learning_rate = 2e-4
batch_size = 128
num_epochs = 10

In [None]:
# Tensor conversion followed by per-channel (x - 0.5) / 0.5; with ToTensor's
# [0, 1] pixel range (per torchvision docs) this puts pixels in [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [None]:
# MNIST training split, stored under ./data (downloaded if missing) and passed
# through `transform`; the loader reshuffles and batches it with 2 workers.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Both networks share `hidden_size`; `output_size` is the generator's image
# channel count and the discriminator's final conv output channel count.
generator = GeneratorCNN(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
)
discriminator = DiscriminatorCNN(
    input_size=1,  # MNIST images are single-channel (grayscale)
    hidden_size=hidden_size,
    output_size=output_size,
)

In [None]:
# BCE on the discriminator's sigmoid outputs; both networks use Adam with
# beta1 = 0.5 instead of PyTorch's default 0.9 (the DCGAN paper's setting).
adam_betas = (0.5, 0.999)
criterion = nn.BCELoss()
generator_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=adam_betas)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=adam_betas)

In [None]:
# log_dir=None lets SummaryWriter choose its own per-run directory (see the
# SummaryWriter docs); losses, graphs and sample images below are logged there.
writer = SummaryWriter(log_dir=None)

In [None]:
def generate_noise(n_samples, input_size):
    """Draw standard-normal latent vectors laid out as 1x1 feature maps.

    Args:
        n_samples: number of vectors (batch dimension).
        input_size: latent dimension (channel dimension).

    Returns:
        Tensor of shape (n_samples, input_size, 1, 1) sampled from N(0, 1).
    """
    latent_shape = (n_samples, input_size, 1, 1)
    return torch.randn(latent_shape)

In [None]:
# Trace each network once on an example input so TensorBoard can draw its graph.
# NOTE(review): TensorBoard may display only one graph per run directory, in
# which case the discriminator graph hides the generator's -- confirm, and log
# to separate writers if both graphs are needed.
images, _ = next(iter(train_loader))
images = images.view(-1, 1, 28, 28)
noise_for_graph = generate_noise(5, input_size)
for graph_model, graph_input in ((generator, noise_for_graph), (discriminator, images)):
    writer.add_graph(graph_model, graph_input)

In [None]:
# Adversarial training: every step first updates the discriminator on a real and
# a fake batch, then updates the generator to fool the just-updated discriminator.
n_samples = 10
# Re-use the same latent vectors every epoch so the logged samples show how the
# generator's output for a fixed input evolves, instead of new random draws.
fixed_noise = generate_noise(n_samples, input_size)
steps_per_epoch = len(train_loader)

for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(train_loader):
        real_images = real_images.view(-1, 1, 28, 28)
        # Local name on purpose: assigning to `batch_size` would overwrite the
        # config-cell hyper-parameter with the (smaller) size of the last batch,
        # silently changing any cell that is re-run afterwards.
        current_batch_size = real_images.size(0)
        global_step = epoch * steps_per_epoch + i

        # --- Discriminator step: real images -> 1, generated images -> 0 ---
        discriminator_optimizer.zero_grad()

        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, torch.ones_like(real_outputs))

        noise = generate_noise(current_batch_size, input_size)
        fake_images = generator(noise)
        # detach(): this loss must not back-propagate into the generator.
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, torch.zeros_like(fake_outputs))

        discriminator_loss = (real_loss + fake_loss) / 2
        discriminator_loss.backward()
        discriminator_optimizer.step()

        # --- Generator step: push D(G(z)) towards the "real" label 1 ---
        # backward() also fills discriminator grads; they are cleared by the
        # discriminator zero_grad() at the start of the next step.
        generator_optimizer.zero_grad()

        fake_outputs = discriminator(fake_images)
        generator_loss = criterion(fake_outputs, torch.ones_like(fake_outputs))

        generator_loss.backward()
        generator_optimizer.step()

        if (i+1) % 200 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Discriminator Loss: {discriminator_loss.item()}, Generator Loss: {generator_loss.item()}')

        writer.add_scalar('Discriminator Loss', discriminator_loss.item(), global_step)
        writer.add_scalar('Generator Loss', generator_loss.item(), global_step)

    # Sample without building an autograd graph. The generator already returns
    # (N, C, H, W), so no hard-coded reshape is needed.
    with torch.no_grad():
        samples = generator(fixed_noise)
    generated_images = samples.numpy()  # kept as numpy for the plotting cell
    # Training images were mapped to [-1, 1] by Normalize(mean=0.5, std=0.5);
    # invert that so add_images receives floats in its expected [0, 1] range.
    writer.add_images('Generated Images', (samples * 0.5 + 0.5).clamp(0, 1), global_step=epoch)

writer.close()

In [None]:
# Show the samples produced at the end of the last training epoch side by side.
# squeeze=False keeps `axes` a 2-D array even when n_samples == 1.
fig, axes = plt.subplots(1, n_samples, figsize=(20, 2), squeeze=False)
for ax, image in zip(axes[0], generated_images):
    # Each image is (C, H, W) with C == 1; squeeze the channel axis instead of
    # assuming a 28x28 resolution. imshow auto-scales the value range.
    ax.imshow(image.squeeze(), cmap='gray')
    ax.axis('off')
fig.suptitle('Generator samples after the final epoch')
plt.show()