In [None]:
import torch
from torch import nn
from discriminator import Discriminator 
from generator import Generator

import math
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

In [None]:
torch.manual_seed(111)

In [None]:
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

In [None]:
train_set = torchvision.datasets.MNIST(
    root=".", train=True, download=True, transform=transform
)

In [None]:
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)

In [None]:
real_samples, mnist_labels = next(iter(train_loader))
for i in range(16):
    ax = plt.subplot(4, 4, i + 1)
    plt.imshow(real_samples[i].reshape(28, 28), cmap="gray_r")
    plt.xticks([])
    plt.yticks([])

In [36]:
dsc = Discriminator()

In [37]:
gen = Generator()

In [None]:
lr = 0.0001
epochs = 50
loss_function = nn.BCELoss()
optimizer_discriminator = torch.optim.Adam(dsc.parameters(), lr=lr)
optimizer_gen = torch.optim.Adam(gen.parameters(), lr=lr)

In [None]:
for epoch in range(epochs):
    for n, (real_samples, mnist_labels) in enumerate(train_loader):
        # Data for training the discriminator
        real_samples = real_samples
        real_samples_labels = torch.ones((batch_size, 1))
        latent_space_samples = torch.randn((batch_size, 100))
        generated_samples = gen(latent_space_samples)
        generated_samples_labels = torch.zeros((batch_size, 1))
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat(
            (real_samples_labels, generated_samples_labels)
        )

        # Training the discriminator
        dsc.zero_grad()
        output_dsc = dsc(all_samples)
        loss_dsc = loss_function(
            output_dsc, all_samples_labels
        )
        loss_dsc.backward()
        optimizer_discriminator.step()

        # Data for training the generator
        latent_space_samples = torch.randn((batch_size, 100))

        # Training the generator
        gen.zero_grad()
        generated_samples = gen(latent_space_samples)
        output_dsc_generated = dsc(generated_samples)
        loss_gen = loss_function(
            output_dsc_generated, real_samples_labels
        )
        loss_gen.backward()
        optimizer_gen.step()

        # Show loss
        if n == batch_size - 1:
            print(f"Epoch: {epoch} Loss D.: {loss_dsc}")
            print(f"Epoch: {epoch} Loss G.: {loss_gen}")

latent_space_samples = torch.randn(batch_size, 100).to(device=device)
generated_samples = gen(latent_space_samples)
for i in range(16):
    ax = plt.subplot(4, 4, i + 1)
    plt.imshow(generated_samples[i].reshape(28, 28), cmap="gray_r")
    plt.xticks([])
    plt.yticks([])

Epoch: 0 Loss D.: 0.03395676612854004
Epoch: 0 Loss G.: 3.8259243965148926
Epoch: 1 Loss D.: 0.01941712014377117
Epoch: 1 Loss G.: 5.783373832702637
Epoch: 2 Loss D.: 0.06617913395166397
Epoch: 2 Loss G.: 6.18718957901001
