In [167]:
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as T
import tqdm
from PIL import Image
from torch.utils.data import DataLoader

from model import Discriminator, Generator

In [168]:
# --- Configuration ---------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lr = 3e-4             # Adam learning rate, shared by both networks
z_dim = 64            # length of the generator's latent noise vector
image_dim = 28*28     # flattened MNIST image (1 x 28 x 28 = 784 values)
batch_size = 32       # also the number of fixed-noise samples per snapshot
epochs = 50

# Seed torch's RNGs (all devices) so weight init, latent noise and DataLoader
# shuffling repeat across runs. Bit-exact CUDA results may additionally need
# deterministic-algorithm settings.
seed = 0
torch.manual_seed(seed)

In [169]:
# Build both networks (discriminator first, then generator) on the chosen device.
# Both are sized by the flattened image length; the generator also takes z_dim.
disc, gen = (
    Discriminator(image_dim).to(device),
    Generator(z_dim, image_dim).to(device),
)

In [170]:
# Latent batch reused at every epoch so progress snapshots are comparable.
fixed_noise = torch.randn(batch_size, z_dim).to(device)

# ToTensor gives values in [0, 1]; Normalize(0.5, 0.5) rescales them to [-1, 1].
transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.5,), std=(0.5,)),
])

In [171]:
# MNIST training split, downloaded into ./dataset/ on first use,
# served in shuffled mini-batches of `batch_size`.
dataset = datasets.MNIST(
    root="./dataset/",
    train=True,
    transform=transforms,
    download=True,
)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [173]:
# Binary cross-entropy on probabilities; presumes the Discriminator's output is
# already in [0, 1] (e.g. a final sigmoid) — not visible here, confirm in model.py.
criterion = nn.BCELoss()

# One Adam optimizer per network, each holding only that network's parameters.
optimizer_disc = torch.optim.Adam(params=disc.parameters(), lr=lr)
optimizer_gen = torch.optim.Adam(params=gen.parameters(), lr=lr)

In [174]:
def create_image_with_labels(fake_images, real_images, margin=10, images_per_row=16,
                             value_range=(-1.0, 1.0)):
    """Tile generated and real images onto one white PIL canvas for comparison.

    Fake images fill the top rows. Real images start on a fresh row below them.
    Each row holds ``images_per_row`` tiles separated by ``margin`` pixels.
    Despite the name, no text labels are drawn.

    Args:
        fake_images: tensor of shape (N, C, H, W); only channel 0 is drawn.
        real_images: tensor of shape (M, C, H, W); only channel 0 is drawn.
            H and W are taken from ``fake_images``; both groups are assumed to
            share them.
        margin: gap in pixels between neighbouring tiles (both directions).
        images_per_row: tiles per row (16 reproduces the original layout).
        value_range: (low, high) pixel range of the inputs, mapped linearly to
            0..255. The default (-1, 1) is what ``T.Normalize((0.5,), (0.5,))``
            yields on ``ToTensor`` output. The generator is trained to match that
            range (presumably via tanh; confirm in model.py). Out-of-range values
            are clipped instead of wrapping around in the uint8 cast.

    Returns:
        RGB ``PIL.Image.Image``. Its width fits ``images_per_row`` tiles and its
        height fits both groups; for 32 fakes + 32 reals of 28x28 it is 598x142.
    """
    low, high = value_range

    def to_uint8_tiles(images):
        # detach() before cpu() so the autograd graph is not copied along.
        arr = images.detach().cpu().numpy()[:, 0]
        scaled = (arr - low) / (high - low) * 255.0
        # Clip before casting: np.uint8 of a negative or >255 float does not clamp.
        return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

    def rows_needed(count):
        # Ceiling division: a partially filled row still occupies a full row.
        return (count + images_per_row - 1) // images_per_row

    fake_tiles = to_uint8_tiles(fake_images)
    real_tiles = to_uint8_tiles(real_images)
    tile_h, tile_w = fake_tiles.shape[1:]

    # At least one row so the canvas size stays valid for empty inputs.
    n_rows = max(rows_needed(len(fake_tiles)) + rows_needed(len(real_tiles)), 1)
    width = images_per_row * tile_w + (images_per_row - 1) * margin
    height = n_rows * tile_h + (n_rows - 1) * margin
    canvas = Image.new('RGB', (width, height), (255, 255, 255))

    first_row = 0
    for tiles in (fake_tiles, real_tiles):
        for idx, tile in enumerate(tiles):
            row, col = divmod(idx, images_per_row)
            x = col * (tile_w + margin)
            y = (first_row + row) * (tile_h + margin)
            # A 2-D uint8 array becomes an 'L' image; paste converts it to RGB.
            canvas.paste(Image.fromarray(tile), (x, y))
        # The next group begins on a fresh row, as in the original layout.
        first_row += rows_needed(len(tiles))

    return canvas

In [None]:
# Snapshot folder; image.save() raises FileNotFoundError if it does not exist.
progress_dir = './training_progress'
os.makedirs(progress_dir, exist_ok=True)

for epoch in tqdm.tqdm(range(epochs)):
    for (real, _) in loader:
        # Flatten each image into a vector of length image_dim.
        real = real.view(-1, image_dim).to(device)
        # Size of *this* batch. A separate name keeps the global `batch_size`
        # config from being overwritten by a short final batch.
        cur_batch_size = real.shape[0]

        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)

        # Discriminator step: target 1 for real images, 0 for generated ones.
        # fake.detach() keeps this loss from back-propagating into the generator,
        # so its graph need not be retained for the generator step below.
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        optimizer_disc.step()

        # Generator step: re-score the same fakes with the just-updated
        # discriminator and train the generator to make them score as real (1).
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        optimizer_gen.step()

    # Losses of the epoch's last batch (not an epoch average). tqdm.write keeps
    # the progress bar intact.
    tqdm.tqdm.write(f'Epoch: {epoch}, Loss Disc: {lossD.item():.4f}, Loss Gen: {lossG.item():.4f}')

    with torch.no_grad():
        fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
        # `real` is still the last batch of the epoch.
        real = real.reshape(-1, 1, 28, 28)

        image = create_image_with_labels(fake, real)
        image.save(os.path.join(progress_dir, f'epoch_{epoch}.png'))
