In [28]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
import torchvision.datasets as datasets
from model import Discriminator, Generator
import tqdm

from PIL import Image, ImageDraw, ImageFont
import numpy as np

In [29]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lr = 3e-4
z_dim = 64
image_dim = 28*28
batch_size = 32
epochs = 50

In [30]:
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)

In [31]:
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = T.Compose(
    [T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]
)

In [32]:
dataset = datasets.MNIST(root="./dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [33]:
optimizer_disc = torch.optim.Adam(disc.parameters(), lr=lr)
optimizer_gen = torch.optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()

In [36]:
def create_image_with_labels(fake_images, real_images):
    fake_images = fake_images.cpu().detach().numpy()
    real_images = real_images.cpu().detach().numpy()

    image_width = 640
    image_height = 256
    background_color = (0, 0, 0)
    image = Image.new('RGB', (image_width, image_height), background_color)

    draw = ImageDraw.Draw(image)

    margin = 10
    image_size = 28

    for i in range(len(fake_images)):
        generated_image = fake_images[i][0] * 255
        generated_image = Image.fromarray(np.uint8(generated_image), 'L')
        x = i % 8 * (image_size + margin)
        y = i // 8 * (image_size + margin)
        image.paste(generated_image, (x, y))

    for i in range(len(real_images)):
        real_image = real_images[i][0] * 255
        real_image = Image.fromarray(np.uint8(real_image), 'L')
        x = (i % 8 + 8) * (image_size + margin)
        y = i // 8 * (image_size + margin)
        image.paste(real_image, (x, y))

    return image

In [37]:
for epoch in tqdm.tqdm(range(epochs)):
    for (real, _) in loader:
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]

        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)

        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_real + lossD_fake)/2
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        optimizer_disc.step()


        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        optimizer_gen.step()

    print(f'Epoch: {epoch}, Loss Disc: {lossD:.4f}, Loss Gen: {lossG:.4f}')

    with torch.no_grad():
        fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
        real = real.reshape(-1, 1, 28, 28)

        image = create_image_with_labels(fake, real)
        image.show()


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch: 0, Loss Disc: 0.0244, Loss Gen: 4.8797


  2%|▏         | 1/50 [00:32<26:18, 32.21s/it]


KeyboardInterrupt: 