In [1]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from typing import Tuple, List, Optional
import math
import copy
import tqdm
import matplotlib.pyplot as plt

## The Basic Building Block: Conv → (optional BatchNorm) → LeakyReLU

In [2]:
class ConvBlock(nn.Module):
    """Conv2d (stride 1) -> optional BatchNorm2d -> LeakyReLU(0.05).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        kernel_size: square kernel size of the convolution.
        padding: zero-padding the convolution applies on every side.
        batch_norm: if True a BatchNorm2d follows the convolution; otherwise
            the normalisation slot holds an identity.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 padding: int,
                 batch_norm: bool = False) -> None:
        super().__init__()
        # The attribute names (conv / norm / relu) become the state_dict keys,
        # so they must stay stable for saved checkpoints to load.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=1, padding=padding)
        if batch_norm:
            self.norm = nn.BatchNorm2d(out_channels)
        else:
            self.norm = nn.Identity()
        self.relu = nn.LeakyReLU(negative_slope=0.05, inplace=True)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, normalisation and activation in that order."""
        features = self.conv(x)
        features = self.norm(features)
        return self.relu(features)

## Discriminator

In [3]:
class Discriminator(nn.Module):
    """Fully convolutional critic.

    Layout: a head ConvBlock, ``num_layer`` body ConvBlocks, and a final
    Conv2d that maps the features to a single-channel score map. Every
    convolution uses stride 1 and the same ``padding``.

    Args:
        in_channels: channels of the input image.
        out_channels: feature width of every hidden layer.
        kernel_size: square kernel size of every convolution.
        padding: zero-padding of every convolution.
        num_layer: number of ConvBlocks between head and tail.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 padding: int,
                 num_layer: int) -> None:
        super(Discriminator, self).__init__()
        self.head = ConvBlock(in_channels, out_channels, kernel_size, padding)
        self.body = nn.Sequential(
            *[ConvBlock(out_channels, out_channels, kernel_size, padding)
              for _ in range(num_layer)]
        )
        # Pass stride/padding by keyword: the 4th positional argument of
        # nn.Conv2d is `stride`, so passing `padding` positionally there
        # silently produced stride=padding with zero padding.
        self.tail = nn.Conv2d(out_channels, 1, kernel_size, stride=1, padding=padding)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        """Return the critic score map (one output channel) for ``x``."""
        x = self.head(x)
        x = self.body(x)
        return self.tail(x)

## Generator

In [4]:
class Generator(nn.Module):
    """Progressively grown generator.

    ``head`` lifts the coarsest input to ``out_channels`` features. ``body`` is
    a ModuleList holding one ConvBlock stack per stage; a new stage is added as
    a deep copy of the latest one by ``init_next_stage``. ``tail`` maps the
    features back to ``in_channels`` with a Conv2d followed by Tanh.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 padding: int,
                 num_layers: int,
                 batch_norm: bool = False) -> None:
        super().__init__()
        self.head = ConvBlock(in_channels, out_channels, kernel_size, padding, batch_norm)
        first_stage = nn.Sequential(*(
            ConvBlock(out_channels, out_channels, kernel_size, padding, batch_norm)
            for _ in range(num_layers)
        ))
        self.body = nn.ModuleList([first_stage])
        self.tail = nn.Sequential(
            nn.Conv2d(out_channels, in_channels, kernel_size, padding=padding),
            nn.Tanh(),
        )
    
    def init_next_stage(self):
        """Append a new stage initialised as a deep copy of the newest stage."""
        newest_stage = self.body[-1]
        self.body.append(copy.deepcopy(newest_stage))
    
    def forward(self,
                noise: List[torch.Tensor],
                real_shapes: List[torch.Size],
                noise_amp: List[float]) -> torch.Tensor:
        """Run every stage from coarsest to finest and decode with the tail.

        ``noise[0]`` feeds the head and the first stage. For each later stage
        ``i`` the running features are bilinearly resized to
        ``real_shapes[i]``, perturbed by ``noise[i] * noise_amp[i]`` and refined
        residually by ``body[i]``. ``noise_amp[0]`` is not read. ``noise[i]``
        (i >= 1) is added to the ``out_channels``-wide features, so it must
        broadcast to their shape.
        """
        features = self.body[0](self.head(noise[0]))
        for idx, stage in enumerate(self.body[1:], start=1):
            upsampled = F.interpolate(features, size=real_shapes[idx], mode='bilinear', align_corners=True)
            features = upsampled + stage(upsampled + noise[idx] * noise_amp[idx])
        return self.tail(features)

## Reading and Writing Images (pixel values mapped to [-1, 1])

In [5]:
def read_image(path: str) -> torch.Tensor:
    """Load an image file as a channels-first float tensor scaled to [-1, 1].

    ``plt.imread`` returns PNG files as floats already in [0, 1], but other
    formats (e.g. JPEG) as integer arrays. Integer data is therefore divided
    by its dtype's maximum (255 for uint8), while float data is used as-is
    instead of being divided by 255 a second time. An alpha channel, if
    present, is dropped so RGBA files yield 3 channels.

    Args:
        path: path of the image file.

    Returns:
        Float tensor of shape (C, H, W) with values in [-1, 1].
    """
    array = plt.imread(path)
    image = torch.from_numpy(array.copy())
    if image.is_floating_point():
        image = image.float()
    else:
        image = image.float() / torch.iinfo(image.dtype).max
    if image.ndim == 3 and image.shape[2] == 4:
        image = image[:, :, :3]  # drop alpha; the models expect RGB
    image = image.permute(2, 0, 1)
    return (image - 0.5) * 2

def write_image(path: str,
                image: torch.Tensor) -> None:
    """Save a (C, H, W) tensor with values in [-1, 1] as an image file.

    The tensor is first detached, so callers may pass tensors that still
    require grad (``.numpy()`` would otherwise raise). It is then mapped to
    [0, 1] and clamped, because matplotlib raises ValueError for float RGB
    data outside [0, 1], which small numerical overshoots would trigger.
    """
    rgb = ((image.detach() + 1) / 2).clamp(0, 1)
    plt.imsave(path, rgb.permute(1, 2, 0).cpu().numpy(), vmin=0, vmax=1)

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed the RNG so weight init, the fixed noise maps and the sampled noise are
# reproducible under Restart & Run All (torch.manual_seed also seeds CUDA).
SEED = 0
torch.manual_seed(SEED)
path = '../data/angkorwat.jpg'
image = read_image(path).to(device)
# max_size: target long side of the training image; min_size: target short side of the coarsest scale
max_size, min_size = 250, 25
stages = 8 # stop_scale = stages - 1 = 7

In [7]:
def adjust_scales2image(image: torch.Tensor,
                        stop_scale: int = 7) -> Tuple[torch.Tensor, float]:
    """Shrink ``image`` and derive the per-step pyramid scale factor.

    The long side is reduced to at most the global ``max_size`` (the image is
    never enlarged). The returned factor is chosen so that ``stop_scale``
    successive multiplications take the short side down to the global
    ``min_size``.

    Args:
        image: batched image tensor; dims 2 and 3 are height and width.
        stop_scale: number of downscaling steps between finest and coarsest.

    Returns:
        The resized image and the per-step scale factor.
    """
    height, width = image.shape[2], image.shape[3]
    shrink = min(max_size / max([height, width]), 1)
    resized = F.interpolate(image, scale_factor=shrink, mode='bilinear', align_corners=True)
    short_side = min(resized.shape[2], resized.shape[3])
    step_factor = math.pow(min_size / short_side, 1 / stop_scale)
    return resized, step_factor
# add a batch dimension, resize, and display the resulting shape and factor
real, scale_factor = adjust_scales2image(image[None], stop_scale=stages - 1)
real.shape, scale_factor

(torch.Size([1, 3, 166, 250]), 0.763040197031911)

In [8]:
def create_reals_pyramid(real: torch.Tensor,
                         scale_factor: float,
                         stop_scale: int = 7) -> List[torch.Tensor]:
    """Build the training image pyramid, coarsest level first.

    Level ``i`` (0 <= i < stop_scale) is ``real`` resized by
    ``scale_factor ** e_i`` with
    ``e_i = (stop_scale - 1) / ln(stop_scale) * ln(stop_scale - i) + 1``.
    The exponent falls from ``stop_scale`` at i = 0 to 1 at
    i = stop_scale - 1. The unresized ``real`` is appended as the last level.

    Args:
        real: batched image tensor (bilinear resize, so 4-D).
        scale_factor: per-step factor, e.g. from ``adjust_scales2image``.
        stop_scale: number of resized levels before ``real``.

    Returns:
        List of ``stop_scale + 1`` tensors ending with ``real`` itself.
    """
    if stop_scale > 1:
        coef = (stop_scale - 1) / math.log(stop_scale)
    else:
        # (s - 1) / ln(s) is 0/0 at s == 1 (ZeroDivisionError); its limit is 1,
        # and ln(s - i) == 0 for the only level, so the exponent is 1.
        coef = 1.0
    reals = []
    for i in range(stop_scale):
        scale = math.pow(scale_factor, coef * math.log(stop_scale - i) + 1)
        curr_real = F.interpolate(real, scale_factor=scale, mode='bilinear', align_corners=True)
        reals.append(curr_real)
    reals.append(real)
    return reals

reals = create_reals_pyramid(real, scale_factor, stop_scale=stages - 1)
# one line per level, coarsest first
for level in reals:
    print(level.shape)

torch.Size([1, 3, 25, 37])
torch.Size([1, 3, 28, 42])
torch.Size([1, 3, 33, 49])
torch.Size([1, 3, 39, 60])
torch.Size([1, 3, 50, 76])
torch.Size([1, 3, 71, 107])
torch.Size([1, 3, 126, 190])
torch.Size([1, 3, 166, 250])


In [9]:
def calc_gradient_penalty(discriminator: nn.Module,
                          real_data: torch.Tensor,
                          fake_data: torch.Tensor,
                          lamb: float = 0.1,
                          device: str = 'cpu') -> torch.Tensor:
    """WGAN-GP gradient penalty on a random real/fake interpolation.

    A single scalar alpha ~ U(0, 1) mixes the whole real and fake tensors.
    The critic's gradient with respect to that mix is taken with
    ``create_graph=True``, so the penalty can itself be back-propagated into
    the critic. The gradient norm is taken over dim 1 (channels) only, i.e.
    per spatial location; the squared deviation from 1 is averaged and
    scaled by ``lamb``.

    Args:
        discriminator: critic mapping an image batch to a score tensor.
        real_data: real images.
        fake_data: generated images, same shape as ``real_data``.
        lamb: penalty weight.
        device: kept for backward compatibility and ignored. ``grad_outputs``
            now follows the critic output's device, so GPU runs no longer
            fail when the caller leaves this at its 'cpu' default.

    Returns:
        Scalar penalty tensor.
    """
    alpha = torch.rand(1).item()
    # Detach so the interpolate is a leaf and no penalty gradient can reach
    # whatever produced fake_data.
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).detach()
    interpolates.requires_grad_(True)
    disc_interpolates = discriminator(interpolates)
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                    grad_outputs=torch.ones_like(disc_interpolates),
                                    create_graph=True, retain_graph=True)[0]
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lamb
    return gradient_penalty

## Train

In [17]:
# Feature width shared by generator and discriminator.
n_features = 64
generator = Generator(in_channels=3, out_channels=n_features, kernel_size=3,
                      padding=1, num_layers=3, batch_norm=False).to(device)
discriminator = Discriminator(in_channels=3, out_channels=n_features, kernel_size=3,
                              padding=1, num_layer=3).to(device)
# Per-stage state, appended to by the training loop.
fixed_noise = []
noise_amp = []

# Optimisation hyper-parameters.
d_iter = g_iter = 3   # critic updates per epoch / optimizer steps per generator update
max_epochs = 2000     # epochs per stage
lr = 0.0005
lr_scale = 0.1        # lr multiplier per stage of distance from the newest trained stage
beta1 = 0.5
gamma = 0.1           # MultiStepLR decay factor
lamb = 0.1            # gradient-penalty weight
alpha = 10            # reconstruction-loss weight
reals_shapes = [level.shape[2:] for level in reals]
train_depth = 3       # number of most recent stages trained jointly

In [15]:
output_dir = 'output'
# os.makedirs(..., exist_ok=True) is idempotent on re-runs, creates missing
# parent directories, and avoids the check-then-create race of
# os.path.exists followed by os.mkdir.
os.makedirs(output_dir, exist_ok=True)
for i in range(stages):
    os.makedirs(os.path.join(output_dir, f'stage_{i:02d}'), exist_ok=True)

In [None]:
for stage in range(stages):
    print(f'Stage {stage + 1}: ')
    if stage:
        # grow the generator: the new stage starts as a copy of the previous one
        generator.init_next_stage()

    real = reals[stage]
    # stage 0 reconstructs from the coarsest real image; later stages use a
    # fixed random feature map (n_features channels) at this stage's size
    z_opt = reals[0] if stage == 0 else torch.randn(1, n_features, *reals_shapes[stage], device=device)
    fixed_noise.append(z_opt)

    optim_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

    # freeze every stage except the train_depth most recent ones
    for block in generator.body[:-train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    # the newest stage trains at lr; each older trainable stage gets a further factor of lr_scale
    parameter_list = [{"params": block.parameters(), "lr": lr * (lr_scale ** (len(generator.body[-train_depth:]) - 1 - idx))}
            for idx, block in enumerate(generator.body[-train_depth:])]

    # add parameters of head and tail to training
    if stage - train_depth < 0:
        parameter_list += [{"params": generator.head.parameters(), "lr": lr * (lr_scale ** stage)}]
    parameter_list += [{"params": generator.tail.parameters(), "lr": lr}]
    optim_g = torch.optim.Adam(parameter_list, lr=lr, betas=(beta1, 0.999))

    # define learning rate schedules: decay by gamma after 80% of the epochs.
    # MultiStepLR matches milestones against its integer epoch counter, so a
    # non-integer milestone (e.g. 0.8 * 2001 = 1600.8) would never fire.
    lr_milestone = int(0.8 * max_epochs)
    scheduler_d = torch.optim.lr_scheduler.MultiStepLR(optimizer=optim_d, milestones=[lr_milestone], gamma=gamma)
    scheduler_g = torch.optim.lr_scheduler.MultiStepLR(optimizer=optim_g, milestones=[lr_milestone], gamma=gamma)

    # calculate noise_amp
    if stage == 0:
        noise_amp.append(1)
    else:
        # placeholder 0: the new stage injects no noise while we measure how
        # well the fixed noise already reconstructs this stage's real image
        noise_amp.append(0)
        with torch.no_grad():
            z_reconstruction = generator(fixed_noise, reals_shapes, noise_amp)
        rec_loss = F.mse_loss(z_reconstruction, real)
        # noise amplitude = 0.1 * RMSE of that reconstruction
        noise_amp[-1] = 0.1 * torch.sqrt(rec_loss)

    for epoch in tqdm.tqdm(range(1, max_epochs + 1)):
        # fresh Gaussian noise with the same shapes as the fixed noise
        noise = [torch.randn(*z.shape, device=device) for z in fixed_noise]

        # update discriminator: minimize D(G(z)) - D(x)
        for j in range(d_iter):
            output = discriminator(real)
            errD_real = -output.mean()

            # only the last critic iteration keeps the generator graph; that
            # `fake` is reused for the generator update below
            if j == d_iter - 1:
                fake = generator(noise, reals_shapes, noise_amp)
            else:
                with torch.no_grad():
                    fake = generator(noise, reals_shapes, noise_amp)

            output = discriminator(fake.detach())
            errD_fake = output.mean()

            gradient_penalty = calc_gradient_penalty(discriminator, real, fake.detach(), lamb, device)
            errD_total = errD_real + errD_fake + gradient_penalty
            optim_d.zero_grad()
            errD_total.backward()
            optim_d.step()

        # update generator: minimize -D(G(z)) plus the weighted reconstruction loss
        output = discriminator(fake)
        errG = -output.mean()

        rec = generator(fixed_noise, reals_shapes, noise_amp)
        rec_loss = alpha * F.mse_loss(rec, real)

        errG_total = errG + rec_loss
        optim_g.zero_grad()
        errG_total.backward()
        # NOTE(review): all g_iter Adam steps reuse the gradient of the single
        # backward pass above; confirm this is intended rather than
        # recomputing the loss before each step.
        for _ in range(g_iter):
            optim_g.step()

        scheduler_d.step()
        scheduler_g.step()

        if epoch % 250 == 0:
            print(f'Epoch: {epoch: 04d}, D_real Loss: {errD_real.item():.5f}, D_fake Loss: {errD_fake.item():.5f}, '
                  f'D_gradient_penalty: {gradient_penalty.item():.5f}, G_fake Loss: {errG.item():.5f}, G_rec Loss: {rec_loss.item():.5f}')
        if epoch % 500 == 0:
            # both files use a zero-padded epoch ({epoch:04d}); ' 04d' put a space in the filename
            write_image(f'{output_dir}/stage_{stage:02d}/fake_sample_{epoch:04d}.jpg', fake.detach().squeeze(0))
            write_image(f'{output_dir}/stage_{stage:02d}/reconstruction_{epoch:04d}.jpg', rec.detach().squeeze(0))

Stage 1: 


 13%|█▎        | 256/2000 [00:08<00:54, 31.99it/s]

Epoch:  250, D_real Loss: 2.44567, D_fake Loss: 0.46688, D_gradient_penalty: 0.85142, G_fake Loss: -0.56948, G_rec Loss: 0.08614


 25%|██▌       | 504/2000 [00:16<00:48, 30.97it/s]

Epoch:  500, D_real Loss: 2.75209, D_fake Loss: 0.02575, D_gradient_penalty: 1.26078, G_fake Loss: 0.00961, G_rec Loss: 0.06195


 38%|███▊      | 754/2000 [00:24<00:41, 29.90it/s]

Epoch:  750, D_real Loss: 2.86804, D_fake Loss: 0.61244, D_gradient_penalty: 0.73418, G_fake Loss: -0.81290, G_rec Loss: 0.08331


 50%|█████     | 1004/2000 [00:33<00:35, 28.42it/s]

Epoch:  1000, D_real Loss: 1.59142, D_fake Loss: -1.31235, D_gradient_penalty: 1.58353, G_fake Loss: 1.79641, G_rec Loss: 0.03709


 63%|██████▎   | 1253/2000 [00:42<00:26, 27.70it/s]

Epoch:  1250, D_real Loss: 4.15951, D_fake Loss: 0.32691, D_gradient_penalty: 1.64002, G_fake Loss: -0.45096, G_rec Loss: 0.01530


 75%|███████▌  | 1504/2000 [00:51<00:18, 27.50it/s]

Epoch:  1500, D_real Loss: 2.81833, D_fake Loss: -1.24190, D_gradient_penalty: 1.38232, G_fake Loss: 0.77452, G_rec Loss: 0.02890


 88%|████████▊ | 1754/2000 [01:00<00:08, 28.09it/s]

Epoch:  1750, D_real Loss: 1.90450, D_fake Loss: -1.06204, D_gradient_penalty: 1.59214, G_fake Loss: 1.10019, G_rec Loss: 0.00815


100%|██████████| 2000/2000 [01:09<00:00, 28.71it/s]


Epoch:  2000, D_real Loss: 2.19303, D_fake Loss: -0.71114, D_gradient_penalty: 1.23367, G_fake Loss: 0.77997, G_rec Loss: 0.00664
Stage 2: 


 13%|█▎        | 254/2000 [00:10<01:13, 23.77it/s]

Epoch:  250, D_real Loss: -0.05511, D_fake Loss: -1.23056, D_gradient_penalty: 0.29881, G_fake Loss: 1.04575, G_rec Loss: 0.01736


 25%|██▌       | 503/2000 [00:20<01:02, 24.02it/s]

Epoch:  500, D_real Loss: 1.14827, D_fake Loss: -0.62732, D_gradient_penalty: 1.37147, G_fake Loss: 1.22368, G_rec Loss: 0.01113


 38%|███▊      | 755/2000 [00:31<00:51, 24.20it/s]

Epoch:  750, D_real Loss: 0.32888, D_fake Loss: -1.20341, D_gradient_penalty: 0.33360, G_fake Loss: 1.17005, G_rec Loss: 0.00818


 50%|█████     | 1004/2000 [00:41<00:40, 24.42it/s]

Epoch:  1000, D_real Loss: 0.10090, D_fake Loss: -1.41096, D_gradient_penalty: 0.67366, G_fake Loss: 1.43826, G_rec Loss: 0.00678


 63%|██████▎   | 1253/2000 [00:51<00:29, 24.93it/s]

Epoch:  1250, D_real Loss: -0.34205, D_fake Loss: -1.66898, D_gradient_penalty: 0.54977, G_fake Loss: 1.95535, G_rec Loss: 0.00788


 75%|███████▌  | 1505/2000 [01:02<00:20, 24.58it/s]

Epoch:  1500, D_real Loss: -0.05581, D_fake Loss: -1.48170, D_gradient_penalty: 0.32130, G_fake Loss: 1.46038, G_rec Loss: 0.00684


 88%|████████▊ | 1754/2000 [01:12<00:10, 24.30it/s]

Epoch:  1750, D_real Loss: -0.30878, D_fake Loss: -1.64869, D_gradient_penalty: 0.33847, G_fake Loss: 1.62825, G_rec Loss: 0.00159


100%|██████████| 2000/2000 [01:22<00:00, 24.10it/s]


Epoch:  2000, D_real Loss: -0.23347, D_fake Loss: -1.58599, D_gradient_penalty: 0.90047, G_fake Loss: 1.62404, G_rec Loss: 0.00154
Stage 3: 


 13%|█▎        | 252/2000 [00:11<01:24, 20.66it/s]

Epoch:  250, D_real Loss: -1.40649, D_fake Loss: -2.15758, D_gradient_penalty: 0.37187, G_fake Loss: 2.46801, G_rec Loss: 0.01127


 25%|██▌       | 503/2000 [00:23<01:09, 21.68it/s]

Epoch:  500, D_real Loss: -1.05309, D_fake Loss: -2.01897, D_gradient_penalty: 0.27398, G_fake Loss: 2.45107, G_rec Loss: 0.01339


 38%|███▊      | 752/2000 [00:35<00:58, 21.24it/s]

Epoch:  750, D_real Loss: -0.59272, D_fake Loss: -1.47411, D_gradient_penalty: 0.26643, G_fake Loss: 1.60294, G_rec Loss: 0.01314


 50%|█████     | 1004/2000 [00:46<00:44, 22.36it/s]

Epoch:  1000, D_real Loss: -1.21572, D_fake Loss: -2.07044, D_gradient_penalty: 0.22169, G_fake Loss: 1.80216, G_rec Loss: 0.00593


 63%|██████▎   | 1254/2000 [00:58<00:36, 20.62it/s]

Epoch:  1250, D_real Loss: -1.08078, D_fake Loss: -1.89109, D_gradient_penalty: 0.35677, G_fake Loss: 2.01704, G_rec Loss: 0.01242


 75%|███████▌  | 1504/2000 [01:10<00:22, 22.40it/s]

Epoch:  1500, D_real Loss: -0.79370, D_fake Loss: -1.56003, D_gradient_penalty: 0.11377, G_fake Loss: 1.42615, G_rec Loss: 0.00616


 88%|████████▊ | 1754/2000 [01:21<00:10, 22.47it/s]

Epoch:  1750, D_real Loss: -1.44992, D_fake Loss: -2.28231, D_gradient_penalty: 0.38001, G_fake Loss: 2.30963, G_rec Loss: 0.00174


100%|██████████| 2000/2000 [01:33<00:00, 21.42it/s]


Epoch:  2000, D_real Loss: -1.34912, D_fake Loss: -2.05222, D_gradient_penalty: 0.33126, G_fake Loss: 2.07791, G_rec Loss: 0.00144
Stage 4: 


 13%|█▎        | 253/2000 [00:12<01:22, 21.13it/s]

Epoch:  250, D_real Loss: -1.48609, D_fake Loss: -1.91113, D_gradient_penalty: 0.09517, G_fake Loss: 1.74799, G_rec Loss: 0.01569


 25%|██▌       | 502/2000 [00:26<01:38, 15.17it/s]

Epoch:  500, D_real Loss: -0.95946, D_fake Loss: -1.46425, D_gradient_penalty: 0.13991, G_fake Loss: 1.56166, G_rec Loss: 0.00976


 38%|███▊      | 752/2000 [00:40<01:12, 17.24it/s]

Epoch:  750, D_real Loss: -1.20979, D_fake Loss: -1.68394, D_gradient_penalty: 0.11328, G_fake Loss: 2.00337, G_rec Loss: 0.00940


 50%|█████     | 1003/2000 [00:56<00:58, 17.09it/s]

Epoch:  1000, D_real Loss: -1.09386, D_fake Loss: -1.61909, D_gradient_penalty: 0.28515, G_fake Loss: 2.02674, G_rec Loss: 0.01358


 63%|██████▎   | 1251/2000 [01:10<00:38, 19.40it/s]

Epoch:  1250, D_real Loss: -0.94832, D_fake Loss: -1.44289, D_gradient_penalty: 0.24103, G_fake Loss: 1.45445, G_rec Loss: 0.00560


 75%|███████▌  | 1503/2000 [01:24<00:25, 19.70it/s]

Epoch:  1500, D_real Loss: -1.09367, D_fake Loss: -1.57583, D_gradient_penalty: 0.23874, G_fake Loss: 1.67756, G_rec Loss: 0.00552


 88%|████████▊ | 1754/2000 [01:39<00:12, 19.42it/s]

Epoch:  1750, D_real Loss: -1.25229, D_fake Loss: -1.78719, D_gradient_penalty: 0.14328, G_fake Loss: 1.82131, G_rec Loss: 0.00165


100%|██████████| 2000/2000 [01:53<00:00, 17.62it/s]


Epoch:  2000, D_real Loss: -1.22323, D_fake Loss: -1.80057, D_gradient_penalty: 0.15170, G_fake Loss: 1.80609, G_rec Loss: 0.00152
Stage 5: 


 13%|█▎        | 252/2000 [00:23<02:50, 10.27it/s]

Epoch:  250, D_real Loss: -0.67684, D_fake Loss: -1.03667, D_gradient_penalty: 0.18146, G_fake Loss: 1.41466, G_rec Loss: 0.01381


 25%|██▌       | 501/2000 [00:48<02:26, 10.23it/s]

Epoch:  500, D_real Loss: -0.79434, D_fake Loss: -1.10427, D_gradient_penalty: 0.11103, G_fake Loss: 1.12742, G_rec Loss: 0.01182


 38%|███▊      | 752/2000 [01:12<01:59, 10.48it/s]

Epoch:  750, D_real Loss: -1.14750, D_fake Loss: -1.44815, D_gradient_penalty: 0.11696, G_fake Loss: 1.57776, G_rec Loss: 0.01144


 50%|█████     | 1001/2000 [01:37<01:37, 10.21it/s]

Epoch:  1000, D_real Loss: -0.96119, D_fake Loss: -1.27049, D_gradient_penalty: 0.13658, G_fake Loss: 1.21695, G_rec Loss: 0.00916


 63%|██████▎   | 1251/2000 [02:01<01:13, 10.15it/s]

Epoch:  1250, D_real Loss: -1.09957, D_fake Loss: -1.41732, D_gradient_penalty: 0.15502, G_fake Loss: 1.53140, G_rec Loss: 0.00879


 75%|███████▌  | 1501/2000 [02:26<00:48, 10.31it/s]

Epoch:  1500, D_real Loss: -1.43687, D_fake Loss: -1.78272, D_gradient_penalty: 0.21507, G_fake Loss: 1.93118, G_rec Loss: 0.00662


 88%|████████▊ | 1751/2000 [02:50<00:24, 10.25it/s]

Epoch:  1750, D_real Loss: -1.27595, D_fake Loss: -1.57334, D_gradient_penalty: 0.09225, G_fake Loss: 1.56579, G_rec Loss: 0.00219


100%|██████████| 2000/2000 [03:14<00:00, 10.27it/s]


Epoch:  2000, D_real Loss: -1.42027, D_fake Loss: -1.75672, D_gradient_penalty: 0.17511, G_fake Loss: 1.75872, G_rec Loss: 0.00214
Stage 6: 


 13%|█▎        | 252/2000 [00:27<03:03,  9.51it/s]

Epoch:  250, D_real Loss: -0.65974, D_fake Loss: -0.82493, D_gradient_penalty: 0.03876, G_fake Loss: 0.92181, G_rec Loss: 0.02755


 25%|██▌       | 502/2000 [00:55<02:38,  9.48it/s]

Epoch:  500, D_real Loss: -0.75458, D_fake Loss: -0.99343, D_gradient_penalty: 0.08772, G_fake Loss: 1.05342, G_rec Loss: 0.01333


 38%|███▊      | 750/2000 [01:27<03:56,  5.28it/s]

Epoch:  750, D_real Loss: -0.94803, D_fake Loss: -1.15053, D_gradient_penalty: 0.05750, G_fake Loss: 1.01569, G_rec Loss: 0.01493


 50%|█████     | 1000/2000 [02:07<03:09,  5.27it/s]

Epoch:  1000, D_real Loss: -0.76108, D_fake Loss: -0.94716, D_gradient_penalty: 0.05530, G_fake Loss: 0.93049, G_rec Loss: 0.01088


 62%|██████▎   | 1250/2000 [02:47<02:20,  5.33it/s]

Epoch:  1250, D_real Loss: -1.05848, D_fake Loss: -1.26707, D_gradient_penalty: 0.06313, G_fake Loss: 1.15767, G_rec Loss: 0.00743


 75%|███████▌  | 1500/2000 [03:27<01:35,  5.26it/s]

Epoch:  1500, D_real Loss: -0.73408, D_fake Loss: -0.94272, D_gradient_penalty: 0.06070, G_fake Loss: 0.99504, G_rec Loss: 0.00824


 88%|████████▊ | 1750/2000 [04:07<00:47,  5.31it/s]

Epoch:  1750, D_real Loss: -1.09379, D_fake Loss: -1.28863, D_gradient_penalty: 0.10569, G_fake Loss: 1.30025, G_rec Loss: 0.00289


100%|██████████| 2000/2000 [04:47<00:00,  6.96it/s]


Epoch:  2000, D_real Loss: -1.12294, D_fake Loss: -1.32191, D_gradient_penalty: 0.06054, G_fake Loss: 1.33551, G_rec Loss: 0.00278
Stage 7: 


 12%|█▎        | 250/2000 [02:15<17:35,  1.66it/s]

Epoch:  250, D_real Loss: -0.54732, D_fake Loss: -0.61189, D_gradient_penalty: 0.02131, G_fake Loss: 0.58320, G_rec Loss: 0.02213


 25%|██▌       | 500/2000 [04:31<15:08,  1.65it/s]

Epoch:  500, D_real Loss: -0.44545, D_fake Loss: -0.54005, D_gradient_penalty: 0.02609, G_fake Loss: 0.59339, G_rec Loss: 0.01569


 38%|███▊      | 750/2000 [06:27<08:44,  2.38it/s]

Epoch:  750, D_real Loss: -0.61420, D_fake Loss: -0.71358, D_gradient_penalty: 0.03070, G_fake Loss: 0.54102, G_rec Loss: 0.01379


 50%|█████     | 1000/2000 [08:12<07:01,  2.37it/s]

Epoch:  1000, D_real Loss: -0.64543, D_fake Loss: -0.75620, D_gradient_penalty: 0.03348, G_fake Loss: 0.67703, G_rec Loss: 0.01654


 62%|██████▎   | 1250/2000 [09:47<05:17,  2.36it/s]

Epoch:  1250, D_real Loss: -0.29921, D_fake Loss: -0.39663, D_gradient_penalty: 0.02792, G_fake Loss: 0.37323, G_rec Loss: 0.01067


 75%|███████▌  | 1500/2000 [11:49<05:03,  1.65it/s]

Epoch:  1500, D_real Loss: -0.54265, D_fake Loss: -0.66608, D_gradient_penalty: 0.03810, G_fake Loss: 0.87646, G_rec Loss: 0.01243


 88%|████████▊ | 1750/2000 [13:53<01:44,  2.38it/s]

Epoch:  1750, D_real Loss: -0.51955, D_fake Loss: -0.63011, D_gradient_penalty: 0.05240, G_fake Loss: 0.64116, G_rec Loss: 0.00547


100%|██████████| 2000/2000 [15:40<00:00,  2.13it/s]


Epoch:  2000, D_real Loss: -0.50154, D_fake Loss: -0.62625, D_gradient_penalty: 0.03503, G_fake Loss: 0.60726, G_rec Loss: 0.00537
Stage 8: 


 12%|█▎        | 250/2000 [02:31<21:01,  1.39it/s]

Epoch:  250, D_real Loss: -0.47773, D_fake Loss: -0.54092, D_gradient_penalty: 0.02207, G_fake Loss: 0.50163, G_rec Loss: 0.03117


 25%|██▌       | 500/2000 [05:12<18:28,  1.35it/s]

Epoch:  500, D_real Loss: -0.55819, D_fake Loss: -0.66361, D_gradient_penalty: 0.03191, G_fake Loss: 0.65489, G_rec Loss: 0.01896


 38%|███▊      | 750/2000 [07:54<15:04,  1.38it/s]

Epoch:  750, D_real Loss: -0.57348, D_fake Loss: -0.68891, D_gradient_penalty: 0.03146, G_fake Loss: 0.62952, G_rec Loss: 0.01778


 50%|█████     | 1000/2000 [10:06<08:19,  2.00it/s]

Epoch:  1000, D_real Loss: -0.47685, D_fake Loss: -0.59481, D_gradient_penalty: 0.03636, G_fake Loss: 0.54993, G_rec Loss: 0.01750


 62%|██████▎   | 1250/2000 [12:04<06:16,  1.99it/s]

Epoch:  1250, D_real Loss: -0.45301, D_fake Loss: -0.58700, D_gradient_penalty: 0.03989, G_fake Loss: 0.64609, G_rec Loss: 0.01500


 75%|███████▌  | 1500/2000 [13:58<04:14,  1.96it/s]

Epoch:  1500, D_real Loss: -0.75755, D_fake Loss: -0.87975, D_gradient_penalty: 0.05053, G_fake Loss: 1.07052, G_rec Loss: 0.01689


 88%|████████▊ | 1750/2000 [15:52<02:06,  1.97it/s]

Epoch:  1750, D_real Loss: -0.80386, D_fake Loss: -0.93552, D_gradient_penalty: 0.03609, G_fake Loss: 0.93714, G_rec Loss: 0.00825


100%|██████████| 2000/2000 [17:47<00:00,  1.87it/s]

Epoch:  2000, D_real Loss: -0.91262, D_fake Loss: -1.03618, D_gradient_penalty: 0.03719, G_fake Loss: 1.04677, G_rec Loss: 0.00810





In [19]:
torch.save(generator.state_dict(), f"{output_dir}/generator.pth")
torch.save(discriminator.state_dict(), f"{output_dir}/discriminator.pth")
# The generator weights alone cannot produce samples. Generator.forward also
# needs the per-stage fixed noise, noise amplitudes and target shapes. A fresh
# Generator must also be grown (init_next_stage) to the same number of stages
# before load_state_dict. Persist everything sampling needs.
torch.save({'fixed_noise': fixed_noise,
            'noise_amp': noise_amp,
            'reals_shapes': reals_shapes,
            'num_stages': len(generator.body)},
           f"{output_dir}/generation_state.pth")

In [None]:
# Call form `no_grad()` works as a decorator on every PyTorch version; older
# releases reject the bare `@torch.no_grad` form.
@torch.no_grad()
def generate_samples(path: str,
                     n: int = 25) -> None:
    """Write ``n`` random samples from the trained generator into ``path``.

    Each sample draws fresh Gaussian noise shaped like the per-stage fixed
    noise and runs the generator with the learned noise amplitudes. Files are
    named gen_sample_01.jpg, gen_sample_02.jpg, ...

    Reads the notebook globals ``generator``, ``fixed_noise``,
    ``reals_shapes``, ``noise_amp`` and ``device`` set by the training cells.

    Args:
        path: output directory; created (with any missing parents) if absent.
        n: number of samples to write.
    """
    os.makedirs(path, exist_ok=True)
    for idx in range(n):
        noise = [torch.randn(*z.shape, device=device) for z in fixed_noise]
        sample = generator(noise, reals_shapes, noise_amp)
        write_image(f'{path}/gen_sample_{idx + 1:02d}.jpg', sample.squeeze(0))

In [21]:
# write the default batch of samples next to the training outputs
generate_samples(f'{output_dir}/gen', n=25)