In [206]:
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
import tqdm
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

from model import Discriminator, Generator, initialize_weights

In [207]:
# --- Run configuration -------------------------------------------------------

# Image geometry: side length used by the resize transform and the number of
# colour channels handed to the normalisation step and both networks.
image_size, channels = 64, 3

# Length of the latent vector fed to the generator.
z_dim = 100

# Feature-width arguments passed to the Discriminator / Generator constructors.
features_disc = features_gen = 64

# Optimisation schedule: Adam learning rate, mini-batch size, epoch upper bound.
lr, batch_size, epochs = 2e-4, 64, 50

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [208]:
# Preprocessing applied to every dataset image:
#   1. resize to a square of image_size x image_size,
#   2. convert to a tensor,
#   3. normalise each channel with mean 0.5 and std 0.5.
transforms = T.Compose([
    T.Resize((image_size, image_size)),
    T.ToTensor(),
    T.Normalize(mean=[0.5] * channels, std=[0.5] * channels),
])

In [209]:
# Images are read from ./dataset/ and run through the preprocessing pipeline.
dataset = ImageFolder('./dataset/', transform=transforms)

# Shuffled mini-batches; drop_last=True discards a trailing partial batch so
# every batch delivered to the training loop holds exactly batch_size images.
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [210]:
# Build both networks on the target device. The generator is constructed first
# and the discriminator is initialised first, matching the original order so
# the random number stream is consumed identically.
gen = Generator(z_dim, channels, features_gen).to(device)
disc = Discriminator(channels, features_disc).to(device)
initialize_weights(disc)
initialize_weights(gen)

# One Adam optimiser per network, both with the same learning rate and betas.
optimizer_gen = torch.optim.Adam(
    gen.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)
optimizer_disc = torch.optim.Adam(
    disc.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)

# Binary cross-entropy between discriminator outputs and real/fake targets.
criterion = nn.BCELoss()

# A fixed batch of 64 latent vectors, reused whenever progress samples are
# rendered so successive snapshots are directly comparable.
fixed_noise = torch.randn((64, z_dim, 1, 1)).to(device)

In [211]:
def _batch_to_uint8_tiles(batch):
    """Turn a (N, C, H, W) tensor normalised with mean/std 0.5 into (N, H, W, C) uint8 pixels."""
    array = batch.detach().cpu().numpy()
    # Undo Normalize(mean=0.5, std=0.5) and scale to 0..255. float64 matches
    # the precision of the original per-channel list arithmetic.
    pixels = (array.astype(np.float64) * 0.5 + 0.5) * 255
    # Clip before the cast so slightly out-of-range values cannot wrap around
    # in uint8; in-range values are truncated exactly as before.
    tiles = np.clip(pixels, 0, 255).astype(np.uint8).transpose(0, 2, 3, 1)
    if tiles.shape[-1] == 1:
        # Single-channel input: replicate to three planes so it can be pasted
        # onto the RGB canvas.
        tiles = np.repeat(tiles, 3, axis=-1)
    return tiles


def create_image_with_labels(fake_images, real_images, columns=32, margin=10):
    """Tile a batch of generated images and a batch of real images into one RGB sheet.

    Generated images fill the top rows, real images the rows below; each group
    starts on a fresh row. Tiles are laid out left to right, ``columns`` per
    row, separated by ``margin`` white pixels.

    Unlike the earlier fixed 32x4 layout, rows wrap every ``columns`` images
    for any batch size, and the canvas height and tile size are derived from
    the inputs. For two batches of 64 images at 64x64 the result is identical
    to the previous output.

    Args:
        fake_images: torch tensor of shape (N, C, H, W), normalised per channel
            with mean 0.5 and std 0.5.
        real_images: torch tensor with the same layout and normalisation.
        columns: number of tiles per row (default 32).
        margin: gap in pixels between neighbouring tiles (default 10).

    Returns:
        PIL.Image.Image: an RGB image on a white background.

    Raises:
        ValueError: if ``columns`` is smaller than 1 or ``margin`` is negative.
    """
    if columns < 1:
        raise ValueError('columns must be at least 1')
    if margin < 0:
        raise ValueError('margin must be non-negative')

    groups = [_batch_to_uint8_tiles(fake_images), _batch_to_uint8_tiles(real_images)]

    # Tile size comes from the data rather than a hard-coded 64.
    tile_height = max(tiles.shape[1] for tiles in groups)
    tile_width = max(tiles.shape[2] for tiles in groups)

    # Ceil-divide each group into rows; keep at least one row so the canvas
    # has a valid size even when both batches are empty.
    rows_per_group = [-(-len(tiles) // columns) for tiles in groups]
    total_rows = max(1, sum(rows_per_group))

    sheet_width = columns * tile_width + (columns - 1) * margin
    sheet_height = total_rows * tile_height + (total_rows - 1) * margin
    background_color = (255, 255, 255)
    sheet = Image.new('RGB', (max(1, sheet_width), max(1, sheet_height)), background_color)

    first_row = 0
    for tiles, group_rows in zip(groups, rows_per_group):
        for index, tile in enumerate(tiles):
            row = first_row + index // columns
            col = index % columns
            # Mode is inferred from the (H, W, 3) uint8 array.
            tile_image = Image.fromarray(np.ascontiguousarray(tile))
            sheet.paste(tile_image, (col * (tile_width + margin), row * (tile_height + margin)))
        first_row += group_rows

    return sheet

In [213]:
# Adversarial training loop.
#
# Each iteration draws one batch of generated images and then performs
#   1. a discriminator update (real images labelled 1, generated images 0),
#   2. a generator update (generated images labelled 1).
# Every `sample_every` batches a contact sheet of fixed-noise samples next to
# the current real batch is written to `progress_dir`.

# NOTE(review): the loop starts at epoch 5, which suggests an earlier run that
# was resumed in the same kernel. On a fresh kernel this trains from scratch
# but still labels the first epoch "5" — confirm whether this should be 0 or
# paired with checkpoint loading.
start_epoch = 5
sample_every = 5
progress_dir = './training_progress'

# image.save() does not create missing directories.
os.makedirs(progress_dir, exist_ok=True)

for epoch in tqdm.tqdm(range(start_epoch, epochs)):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        # Match the noise batch to the real batch rather than the global
        # batch_size, so the loop also works with a partial final batch.
        noise = torch.randn(real.size(0), z_dim, 1, 1).to(device)
        fake = gen(noise)

        # --- Discriminator step ---
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        # Detach so this backward pass stops at the discriminator. Without it
        # gradients flowed into the generator (discarded by gen.zero_grad()
        # below) and the graph had to be retained.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        optimizer_disc.step()

        # --- Generator step ---
        # Score the same fakes with the freshly updated discriminator; this
        # time the graph through the generator is needed.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        optimizer_gen.step()

        if batch_idx % sample_every == 0:
            with torch.no_grad():
                samples = gen(fixed_noise)

            image = create_image_with_labels(samples, real)
            image.save(os.path.join(progress_dir, f'epoch_{epoch}_step{batch_idx}.png'))

    # Losses reported are those of the last batch of the epoch.
    # NOTE(review): in the recorded run the discriminator loss reached 0.0000
    # and the generator loss ~50 from epoch 27 onward, i.e. training collapsed;
    # consider checkpointing and stabilisation measures before rerunning.
    print(f'Epoch: {epoch}, Loss Disc: {lossD.item():.4f}, Loss Gen: {lossG.item():.4f}')

    


  0%|          | 0/45 [00:00<?, ?it/s]

  2%|▏         | 1/45 [42:13<30:57:46, 2533.32s/it]

Epoch: 5, Loss Disc: 0.2912, Loss Gen: 2.4877


  4%|▍         | 2/45 [1:14:14<25:57:33, 2173.34s/it]

Epoch: 6, Loss Disc: 0.1862, Loss Gen: 3.2658


  7%|▋         | 3/45 [1:43:51<23:14:35, 1992.27s/it]

Epoch: 7, Loss Disc: 0.2609, Loss Gen: 7.9467


  9%|▉         | 4/45 [2:13:48<21:48:39, 1915.11s/it]

Epoch: 8, Loss Disc: 0.1916, Loss Gen: 5.2240


 11%|█         | 5/45 [2:43:40<20:47:12, 1870.81s/it]

Epoch: 9, Loss Disc: 0.0991, Loss Gen: 3.9490


 13%|█▎        | 6/45 [3:13:36<19:59:32, 1845.45s/it]

Epoch: 10, Loss Disc: 0.1402, Loss Gen: 3.3888


 16%|█▌        | 7/45 [3:43:34<19:18:51, 1829.77s/it]

Epoch: 11, Loss Disc: 0.4456, Loss Gen: 4.7723


 18%|█▊        | 8/45 [4:13:24<18:40:39, 1817.27s/it]

Epoch: 12, Loss Disc: 0.3101, Loss Gen: 3.0354


 20%|██        | 9/45 [4:43:18<18:05:57, 1809.94s/it]

Epoch: 13, Loss Disc: 0.0807, Loss Gen: 4.8435


 22%|██▏       | 10/45 [5:15:22<17:56:22, 1845.22s/it]

Epoch: 14, Loss Disc: 0.0533, Loss Gen: 5.6265


 24%|██▍       | 11/45 [5:49:26<18:00:01, 1905.91s/it]

Epoch: 15, Loss Disc: 0.1155, Loss Gen: 2.7302


 27%|██▋       | 12/45 [6:19:26<17:10:37, 1873.85s/it]

Epoch: 16, Loss Disc: 0.0580, Loss Gen: 4.6782


 29%|██▉       | 13/45 [6:49:20<16:26:22, 1849.44s/it]

Epoch: 17, Loss Disc: 0.0623, Loss Gen: 4.3781


 31%|███       | 14/45 [7:19:07<15:45:51, 1830.70s/it]

Epoch: 18, Loss Disc: 0.1111, Loss Gen: 4.3342


 33%|███▎      | 15/45 [7:49:01<15:09:48, 1819.61s/it]

Epoch: 19, Loss Disc: 0.0817, Loss Gen: 7.6012


 36%|███▌      | 16/45 [8:18:56<14:35:58, 1812.35s/it]

Epoch: 20, Loss Disc: 0.1696, Loss Gen: 3.1513


 38%|███▊      | 17/45 [8:48:51<14:03:18, 1807.10s/it]

Epoch: 21, Loss Disc: 0.2136, Loss Gen: 2.8440


 40%|████      | 18/45 [9:18:49<13:31:53, 1804.21s/it]

Epoch: 22, Loss Disc: 0.0332, Loss Gen: 5.2779


 42%|████▏     | 19/45 [9:48:51<13:01:30, 1803.50s/it]

Epoch: 23, Loss Disc: 0.0343, Loss Gen: 5.5882


 44%|████▍     | 20/45 [10:18:54<12:31:29, 1803.56s/it]

Epoch: 24, Loss Disc: 0.1160, Loss Gen: 3.7822


 47%|████▋     | 21/45 [10:50:15<12:10:40, 1826.67s/it]

Epoch: 25, Loss Disc: 0.0655, Loss Gen: 3.7766


 49%|████▉     | 22/45 [11:25:23<12:12:33, 1911.00s/it]

Epoch: 26, Loss Disc: 0.0631, Loss Gen: 2.7536


 51%|█████     | 23/45 [11:55:38<11:30:10, 1882.30s/it]

Epoch: 27, Loss Disc: 0.0000, Loss Gen: 50.0591


 53%|█████▎    | 24/45 [12:32:03<11:30:38, 1973.27s/it]

Epoch: 28, Loss Disc: 0.0000, Loss Gen: 49.7825


 53%|█████▎    | 24/45 [12:35:20<11:00:55, 1888.36s/it]


KeyboardInterrupt: 