# A Pytorch Implementation of SinGAN

In [None]:
import os
from typing import Optional, Tuple, List
import torch
from torch import nn
import torch.nn.functional as F
import math
from tqdm import tqdm
import matplotlib.pyplot as plt

## Basic Building Block

In [2]:
class ConvBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int,
                 padding: int) -> None:
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2, True)
    
    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.norm(self.conv(x)))

## Discriminator

In [11]:
class Discriminator(nn.Module):
    def __init__(self,
                 in_channels: int = 3,
                 out_channels: int = 32,
                 kernel_size: int = 3,
                 padding: int = 1,
                 num_layers: int = 3) -> None:
        super(Discriminator, self).__init__()
        self.head = ConvBlock(in_channels, out_channels, kernel_size, 1, padding)
        self.body = nn.Sequential(
            *[ConvBlock(out_channels, out_channels, kernel_size, 1, padding)
            for _ in range(num_layers)]
        )
        self.tail = ConvBlock(out_channels, 1, kernel_size, 1, padding)
        # self.pad = nn.ZeroPad2d(5)
    
    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        # x = self.pad(x)
        x = self.head(x)
        x = self.body(x)
        return self.tail(x)

## Generator

In [4]:
class Generator(nn.Module):
    def __init__(self,
                 in_channels: int = 3,
                 out_channels: int = 32,
                 kernel_size: int = 3,
                 padding: int = 1,
                 num_layers: int = 3) -> None:
        super(Generator, self).__init__()
        self.head = ConvBlock(in_channels, out_channels, kernel_size, 1, padding)
        self.body = nn.Sequential(
            *[ConvBlock(out_channels, out_channels, kernel_size, 1, padding)
            for _ in range(num_layers)]
        )
        self.tail = nn.Sequential(
            ConvBlock(out_channels, in_channels, kernel_size, 1, padding),
            nn.Tanh()
        )
        # self.pad = nn.ZeroPad2d(5)
    
    def forward(self,
                noise: torch.Tensor,
                x: torch.Tensor) -> torch.Tensor:
        # x = self.pad(x)
        return x + self.tail(self.body(self.head(x + noise)))

## Read & Write Image

In [42]:
def read_image(path: str) -> torch.Tensor:
    image = plt.imread(path)
    image = torch.from_numpy(image.copy()).float() / 255.0
    image = image.permute(2, 0, 1)
    return (image - 0.5) * 2

def write_image(path: str,
                image: torch.Tensor) -> None:
    image = (image + 1) / 2
    image = image.clamp(0, 1)
    plt.imsave(path, image.permute(1, 2, 0).cpu().numpy())

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
path = '../data/angkorwat.jpg'
image = read_image(path).to(device)
max_size, min_size = 250, 25

## Create Image Pyramid

In [None]:
scale_factor = 3 / 4
number_of_scales = int(math.log(min_size / max_size, scale_factor)) # 8
number_of_scales

8

In [8]:
def adjust_scales2image(image: torch.Tensor) -> torch.Tensor:
    # resize the long side of the image to max_size
    scale = min(max_size / max([image.shape[2], image.shape[3]]), 1)
    real = F.interpolate(image, scale_factor=scale, mode='bilinear', align_corners=True)
    return real
real = adjust_scales2image(image.unsqueeze(0))
print(real.shape)

torch.Size([1, 3, 166, 250])


In [9]:
def create_reals_pyramid(real: torch.Tensor,
                         scale_factor: float,
                         stop_scale: int = 7):
    reals = []
    for i in range(stop_scale):
        scale = math.pow(scale_factor, ((stop_scale - 1) / math.log(stop_scale)) * math.log(stop_scale - i) + 1)
        curr_real = F.interpolate(real, scale_factor=scale, mode='bilinear', align_corners=True)
        reals.append(curr_real)
    reals.append(real)
    return reals

reals = create_reals_pyramid(real, scale_factor, number_of_scales)
for r in reals:
    print(r.shape)

torch.Size([1, 3, 16, 25])
torch.Size([1, 3, 18, 28])
torch.Size([1, 3, 21, 33])
torch.Size([1, 3, 26, 39])
torch.Size([1, 3, 32, 48])
torch.Size([1, 3, 42, 64])
torch.Size([1, 3, 63, 95])
torch.Size([1, 3, 124, 187])
torch.Size([1, 3, 166, 250])


## Train

In [43]:
n_channels = image.size(0)
n_features = 32
generators = [Generator(in_channels=n_channels, out_channels=n_features * pow(2, i // 4)).to(device) for i in range(number_of_scales + 1)]
discriminators = [Discriminator(in_channels=n_channels, out_channels=n_features * pow(2, i // 4)).to(device) for i in range(number_of_scales + 1)]

lr = 5e-4
beta1 = 0.5
gamma = 0
max_epochs = 2000
d_iter = g_iter = 3
alpha = 10.0
lamb = 0.1

reals_shapes = [r.shape[2:] for r in reals]
fixed_noise = []
noise_amp = []

In [16]:
output_dir = 'output'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
for i in range(number_of_scales + 1):
    p = f'{output_dir}/stage_{i:02d}'
    if not os.path.exists(p):
        os.mkdir(p)

In [None]:
def generate(generators: List[Generator],
             noise: torch.Tensor,
             reals_shapes: List[torch.Size],
             noise_amp: List[float]) -> torch.Tensor:
    x = generators[0](noise[0], 0)
    for i in range(1, len(noise)):
        x = F.interpolate(x, size=reals_shapes[i], mode='bilinear', align_corners=True)
        x = generators[i](noise[i] * noise_amp[i], x)
    return x

In [15]:
def calc_gradient_penalty(discriminator: Discriminator,
                          real_data: torch.Tensor,
                          fake_data: torch.Tensor,
                          lamb: float = 0.1,
                          device: str = 'cpu') -> torch.Tensor:
    alpha = torch.rand(1).item()
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    interpolates.requires_grad_(True)
    disc_interpolates = discriminator(interpolates)
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                    grad_outputs=torch.ones_like(disc_interpolates, device=device),
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lamb
    return gradient_penalty

In [44]:
for stage, (generator, discriminator, real) in enumerate(zip(generators, discriminators, reals)):
    print(f'Stage: {stage}')
    if stage % 4:
        generator.load_state_dict(torch.load(f'{output_dir}/stage_{stage - 1:02d}/generator.pth'))
        discriminator.load_state_dict(torch.load(f'{output_dir}/stage_{stage - 1:02d}/discriminator.pth'))
    
    z_opt = torch.randn(1, n_channels, *reals_shapes[stage], device=device) if stage == 0 else torch.zeros_like(real, device=device)
    fixed_noise.append(z_opt)
    
    optim_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    optim_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
    scheduler_g = torch.optim.lr_scheduler.MultiStepLR(optim_g, milestones=[1600], gamma=gamma)
    scheduler_d = torch.optim.lr_scheduler.MultiStepLR(optim_d, milestones=[1600], gamma=gamma)
    
    # calculate noise_amp
    if stage == 0:
        noise_amp.append(1)
    else:
        noise_amp.append(0)
        with torch.no_grad():
            z_resconstruction = generate(generators, fixed_noise, reals_shapes, noise_amp)
        rec_loss = F.mse_loss(z_resconstruction, real)
        noise_amp[-1] = 0.1 * torch.sqrt(rec_loss)
    
    for epoch in tqdm(range(1, max_epochs + 1)):
        noise = [torch.randn(*z.shape, device=device) for z in fixed_noise]
        # update Discriminator: minimize D(G(z)) - D(x)
        for j in range(d_iter):
            # train with real
            output = discriminator(real)
            errD_real = -output.mean()
            
            # train with fake
            if j == d_iter - 1:
                fake = generate(generators, noise, reals_shapes, noise_amp)
            else:
                with torch.no_grad():
                    fake = generate(generators, noise, reals_shapes, noise_amp)
            
            output = discriminator(fake.detach())
            errD_fake = output.mean()

            gradient_penalty = calc_gradient_penalty(discriminator, real, fake.detach(), lamb, device)
            errD_total = errD_real + errD_fake + gradient_penalty
            optim_d.zero_grad()
            errD_total.backward()
            optim_d.step()

        # update Generator: minimize -D(G(z))
        output = discriminator(fake)
        errG = -output.mean()
        
        rec = generate(generators, fixed_noise, reals_shapes, noise_amp)
        rec_loss = alpha * F.mse_loss(rec, real)
        
        errG_total = errG + rec_loss
        optim_g.zero_grad()
        errG_total.backward()
        for _ in range(g_iter):
            optim_g.step()
    
        scheduler_d.step()
        scheduler_g.step()
        
        if epoch % 250 == 0:
            print(f'Epoch: {epoch: 04d}, D_real Loss: {errD_real.item():.5f}, D_fake Loss: {errD_fake.item():.5f}, '
                  f'D_gradient_penalty: {gradient_penalty.item():.5f}, G_fake Loss: {errG.item():.5f}, G_rec Loss: {rec_loss.item():.5f}')
        if epoch % 500 == 0:
            write_image(f'{output_dir}/stage_{stage:02d}/fake_sample_{epoch:04d}.jpg', fake.detach().squeeze(0))
            write_image(f'{output_dir}/stage_{stage:02d}/reconstruction_{epoch: 04d}.jpg', rec.detach().squeeze(0))
        
    torch.save(generator.state_dict(), f'{output_dir}/stage_{stage:02d}/generator.pth') 
    torch.save(discriminator.state_dict(), f'{output_dir}/stage_{stage:02d}/discriminator.pth') 

Stage: 1


 13%|█▎        | 251/2000 [00:18<02:34, 11.33it/s]

Epoch:  250, D_real Loss: -0.35454, D_fake Loss: -0.00789, D_gradient_penalty: 0.03603, G_fake Loss: 0.01947, G_rec Loss: 0.75577


 25%|██▌       | 501/2000 [00:36<01:42, 14.69it/s]

Epoch:  500, D_real Loss: -0.46850, D_fake Loss: -0.05045, D_gradient_penalty: 0.02514, G_fake Loss: 0.05241, G_rec Loss: 0.41613


 38%|███▊      | 751/2000 [00:53<01:27, 14.27it/s]

Epoch:  750, D_real Loss: -0.58495, D_fake Loss: -0.08097, D_gradient_penalty: 0.02522, G_fake Loss: 0.08205, G_rec Loss: 0.28100


 50%|█████     | 1001/2000 [01:11<01:08, 14.61it/s]

Epoch:  1000, D_real Loss: -0.67685, D_fake Loss: -0.06833, D_gradient_penalty: 0.04871, G_fake Loss: 0.09109, G_rec Loss: 0.18328


 63%|██████▎   | 1251/2000 [01:28<00:53, 13.94it/s]

Epoch:  1250, D_real Loss: -0.74573, D_fake Loss: -0.09865, D_gradient_penalty: 0.06340, G_fake Loss: 0.09765, G_rec Loss: 0.11170


 75%|███████▌  | 1501/2000 [01:46<00:36, 13.71it/s]

Epoch:  1500, D_real Loss: -0.83251, D_fake Loss: -0.10217, D_gradient_penalty: 0.13576, G_fake Loss: 0.10436, G_rec Loss: 0.06854


 88%|████████▊ | 1753/2000 [02:04<00:17, 14.30it/s]

Epoch:  1750, D_real Loss: -0.80384, D_fake Loss: 0.05765, D_gradient_penalty: 0.13207, G_fake Loss: -0.05765, G_rec Loss: 0.05366


Epoch:  2000, D_real Loss: -0.80384, D_fake Loss: 0.03466, D_gradient_penalty: 0.03333, G_fake Loss: -0.03466, G_rec Loss: 0.05366
Stage: 2


 13%|█▎        | 251/2000 [00:20<02:20, 12.42it/s]

Epoch:  250, D_real Loss: -0.74305, D_fake Loss: 0.02438, D_gradient_penalty: 0.09634, G_fake Loss: 0.00428, G_rec Loss: 0.24348


 25%|██▌       | 501/2000 [00:41<02:01, 12.35it/s]

Epoch:  500, D_real Loss: -0.97770, D_fake Loss: 0.05356, D_gradient_penalty: 0.28576, G_fake Loss: -0.07769, G_rec Loss: 0.15981


 38%|███▊      | 752/2000 [01:05<01:48, 11.51it/s]

Epoch:  750, D_real Loss: -1.15930, D_fake Loss: 0.45711, D_gradient_penalty: 0.23302, G_fake Loss: -0.49199, G_rec Loss: 0.09253


 50%|█████     | 1002/2000 [01:26<01:24, 11.85it/s]

Epoch:  1000, D_real Loss: -1.34448, D_fake Loss: 0.45296, D_gradient_penalty: 0.25581, G_fake Loss: -0.43211, G_rec Loss: 0.05826


 63%|██████▎   | 1252/2000 [01:50<01:08, 10.92it/s]

Epoch:  1250, D_real Loss: -1.49675, D_fake Loss: 0.49318, D_gradient_penalty: 0.21609, G_fake Loss: -0.43017, G_rec Loss: 0.05259


 75%|███████▌  | 1502/2000 [02:10<00:43, 11.54it/s]

Epoch:  1500, D_real Loss: -1.35601, D_fake Loss: 0.41611, D_gradient_penalty: 0.35138, G_fake Loss: -0.44818, G_rec Loss: 0.02535


 88%|████████▊ | 1752/2000 [02:33<00:22, 10.85it/s]

Epoch:  1750, D_real Loss: -1.32077, D_fake Loss: 0.41017, D_gradient_penalty: 0.03765, G_fake Loss: -0.41017, G_rec Loss: 0.02652


100%|██████████| 2000/2000 [02:54<00:00, 11.49it/s]


Epoch:  2000, D_real Loss: -1.32077, D_fake Loss: 0.38950, D_gradient_penalty: 0.28682, G_fake Loss: -0.38950, G_rec Loss: 0.02652
Stage: 3


 13%|█▎        | 252/2000 [00:23<02:42, 10.77it/s]

Epoch:  250, D_real Loss: -1.61934, D_fake Loss: 0.91130, D_gradient_penalty: 0.19254, G_fake Loss: -0.85243, G_rec Loss: 0.06991


 25%|██▌       | 502/2000 [00:46<02:29, 10.03it/s]

Epoch:  500, D_real Loss: -1.74720, D_fake Loss: 0.74702, D_gradient_penalty: 0.41534, G_fake Loss: -0.93724, G_rec Loss: 0.04864


 38%|███▊      | 752/2000 [01:09<01:55, 10.83it/s]

Epoch:  750, D_real Loss: -1.82018, D_fake Loss: 0.79794, D_gradient_penalty: 0.46054, G_fake Loss: -0.93563, G_rec Loss: 0.03873


 50%|█████     | 1001/2000 [01:34<01:48,  9.18it/s]

Epoch:  1000, D_real Loss: -1.96369, D_fake Loss: 1.04347, D_gradient_penalty: 0.43946, G_fake Loss: -0.99450, G_rec Loss: 0.03182


 63%|██████▎   | 1252/2000 [01:58<01:10, 10.61it/s]

Epoch:  1250, D_real Loss: -1.88122, D_fake Loss: 1.02467, D_gradient_penalty: 0.18763, G_fake Loss: -0.99340, G_rec Loss: 0.03202


 75%|███████▌  | 1501/2000 [02:21<00:44, 11.33it/s]

Epoch:  1500, D_real Loss: -2.19472, D_fake Loss: 1.13258, D_gradient_penalty: 0.57804, G_fake Loss: -1.06039, G_rec Loss: 0.02985


 88%|████████▊ | 1750/2000 [02:43<00:23, 10.85it/s]

Epoch:  1750, D_real Loss: -1.92807, D_fake Loss: 1.02481, D_gradient_penalty: 0.13701, G_fake Loss: -1.02481, G_rec Loss: 0.02614


100%|██████████| 2000/2000 [03:06<00:00, 10.72it/s]


Epoch:  2000, D_real Loss: -1.92807, D_fake Loss: 1.03414, D_gradient_penalty: 0.10341, G_fake Loss: -1.03414, G_rec Loss: 0.02614
Stage: 4


 13%|█▎        | 251/2000 [00:26<02:58,  9.77it/s]

Epoch:  250, D_real Loss: -1.80329, D_fake Loss: 1.14909, D_gradient_penalty: 0.13410, G_fake Loss: -1.12343, G_rec Loss: 0.06016


 25%|██▌       | 501/2000 [00:50<02:24, 10.39it/s]

Epoch:  500, D_real Loss: -1.92572, D_fake Loss: 1.38973, D_gradient_penalty: 0.11890, G_fake Loss: -1.35065, G_rec Loss: 0.03052


 38%|███▊      | 750/2000 [01:14<01:54, 10.92it/s]

Epoch:  750, D_real Loss: -2.07318, D_fake Loss: 1.25755, D_gradient_penalty: 0.18300, G_fake Loss: -1.26813, G_rec Loss: 0.03016


 50%|█████     | 1002/2000 [01:38<01:34, 10.52it/s]

Epoch:  1000, D_real Loss: -2.29094, D_fake Loss: 1.62480, D_gradient_penalty: 0.47350, G_fake Loss: -1.67360, G_rec Loss: 0.02721


 63%|██████▎   | 1251/2000 [02:02<01:09, 10.77it/s]

Epoch:  1250, D_real Loss: -2.20676, D_fake Loss: 1.47601, D_gradient_penalty: 0.25413, G_fake Loss: -1.44509, G_rec Loss: 0.03298


 75%|███████▌  | 1501/2000 [02:26<00:46, 10.72it/s]

Epoch:  1500, D_real Loss: -2.32524, D_fake Loss: 1.55970, D_gradient_penalty: 0.23557, G_fake Loss: -1.51258, G_rec Loss: 0.02301


 88%|████████▊ | 1751/2000 [02:50<00:23, 10.49it/s]

Epoch:  1750, D_real Loss: -2.34039, D_fake Loss: 1.69292, D_gradient_penalty: 0.21651, G_fake Loss: -1.69292, G_rec Loss: 0.01903


100%|██████████| 2000/2000 [03:14<00:00, 10.31it/s]


Epoch:  2000, D_real Loss: -2.34039, D_fake Loss: 1.69789, D_gradient_penalty: 0.21445, G_fake Loss: -1.69789, G_rec Loss: 0.01903
Stage: 5


 13%|█▎        | 251/2000 [00:29<03:38,  8.00it/s]

Epoch:  250, D_real Loss: -0.37958, D_fake Loss: 0.11301, D_gradient_penalty: 0.03307, G_fake Loss: -0.05891, G_rec Loss: 0.10790


 25%|██▌       | 501/2000 [00:57<02:46,  8.98it/s]

Epoch:  500, D_real Loss: -0.35355, D_fake Loss: 0.02508, D_gradient_penalty: 0.10204, G_fake Loss: -0.04857, G_rec Loss: 0.07885


 38%|███▊      | 750/2000 [01:25<02:21,  8.85it/s]

Epoch:  750, D_real Loss: -0.46630, D_fake Loss: 0.01146, D_gradient_penalty: 0.02680, G_fake Loss: -0.01346, G_rec Loss: 0.07359


 50%|█████     | 1001/2000 [01:53<01:47,  9.28it/s]

Epoch:  1000, D_real Loss: -0.57009, D_fake Loss: 0.01436, D_gradient_penalty: 0.51541, G_fake Loss: -0.36278, G_rec Loss: 0.05833


 63%|██████▎   | 1251/2000 [02:22<01:24,  8.89it/s]

Epoch:  1250, D_real Loss: -0.58732, D_fake Loss: 0.08173, D_gradient_penalty: 0.04215, G_fake Loss: -0.03305, G_rec Loss: 0.02891


 75%|███████▌  | 1500/2000 [02:50<00:56,  8.78it/s]

Epoch:  1500, D_real Loss: -0.60001, D_fake Loss: 0.22703, D_gradient_penalty: 0.14011, G_fake Loss: -0.20054, G_rec Loss: 0.02973


 88%|████████▊ | 1751/2000 [03:18<00:29,  8.44it/s]

Epoch:  1750, D_real Loss: -0.62135, D_fake Loss: 0.15581, D_gradient_penalty: 0.25599, G_fake Loss: -0.15581, G_rec Loss: 0.02733


100%|██████████| 2000/2000 [03:46<00:00,  8.81it/s]


Epoch:  2000, D_real Loss: -0.62135, D_fake Loss: 0.09652, D_gradient_penalty: 0.05036, G_fake Loss: -0.09652, G_rec Loss: 0.02733
Stage: 6


 13%|█▎        | 251/2000 [00:32<03:40,  7.92it/s]

Epoch:  250, D_real Loss: -0.66984, D_fake Loss: 0.19255, D_gradient_penalty: 0.21450, G_fake Loss: -0.09560, G_rec Loss: 0.03351


 25%|██▌       | 501/2000 [01:03<03:08,  7.94it/s]

Epoch:  500, D_real Loss: -0.67049, D_fake Loss: 0.34283, D_gradient_penalty: 0.13448, G_fake Loss: -0.29948, G_rec Loss: 0.03146


 38%|███▊      | 751/2000 [01:36<02:49,  7.35it/s]

Epoch:  750, D_real Loss: -0.64064, D_fake Loss: 0.15110, D_gradient_penalty: 0.28162, G_fake Loss: -0.16520, G_rec Loss: 0.02516


 50%|█████     | 1001/2000 [02:08<02:08,  7.77it/s]

Epoch:  1000, D_real Loss: -0.59545, D_fake Loss: 0.14166, D_gradient_penalty: 0.10594, G_fake Loss: -0.12615, G_rec Loss: 0.02223


 63%|██████▎   | 1251/2000 [02:40<01:33,  8.02it/s]

Epoch:  1250, D_real Loss: -0.70046, D_fake Loss: 0.17158, D_gradient_penalty: 0.09426, G_fake Loss: -0.13598, G_rec Loss: 0.02452


 75%|███████▌  | 1501/2000 [03:14<01:15,  6.59it/s]

Epoch:  1500, D_real Loss: -0.68671, D_fake Loss: 0.19141, D_gradient_penalty: 0.12063, G_fake Loss: -0.19356, G_rec Loss: 0.02602


 88%|████████▊ | 1751/2000 [03:51<00:33,  7.39it/s]

Epoch:  1750, D_real Loss: -0.65566, D_fake Loss: 0.14357, D_gradient_penalty: 0.28112, G_fake Loss: -0.14357, G_rec Loss: 0.01977


100%|██████████| 2000/2000 [04:27<00:00,  7.46it/s]


Epoch:  2000, D_real Loss: -0.65566, D_fake Loss: 0.25836, D_gradient_penalty: 0.17660, G_fake Loss: -0.25836, G_rec Loss: 0.01977
Stage: 7


 13%|█▎        | 251/2000 [00:44<05:53,  4.95it/s]

Epoch:  250, D_real Loss: -0.57367, D_fake Loss: 0.32762, D_gradient_penalty: 0.11325, G_fake Loss: -0.31711, G_rec Loss: 0.04325


 25%|██▌       | 501/2000 [01:35<05:06,  4.89it/s]

Epoch:  500, D_real Loss: -0.54818, D_fake Loss: 0.25007, D_gradient_penalty: 0.08716, G_fake Loss: -0.23657, G_rec Loss: 0.04133


 38%|███▊      | 751/2000 [02:25<04:07,  5.05it/s]

Epoch:  750, D_real Loss: -0.57083, D_fake Loss: 0.31313, D_gradient_penalty: 0.08302, G_fake Loss: -0.32294, G_rec Loss: 0.02531


 50%|█████     | 1001/2000 [03:15<03:20,  4.97it/s]

Epoch:  1000, D_real Loss: -0.56506, D_fake Loss: 0.26697, D_gradient_penalty: 0.05219, G_fake Loss: -0.28587, G_rec Loss: 0.02490


 63%|██████▎   | 1251/2000 [04:05<02:28,  5.06it/s]

Epoch:  1250, D_real Loss: -0.57869, D_fake Loss: 0.19598, D_gradient_penalty: 0.05473, G_fake Loss: -0.21939, G_rec Loss: 0.02177


 75%|███████▌  | 1501/2000 [04:56<01:37,  5.14it/s]

Epoch:  1500, D_real Loss: -0.68354, D_fake Loss: 0.28075, D_gradient_penalty: 0.40066, G_fake Loss: -0.29244, G_rec Loss: 0.02149


 88%|████████▊ | 1751/2000 [05:46<00:49,  5.05it/s]

Epoch:  1750, D_real Loss: -0.60589, D_fake Loss: 0.28790, D_gradient_penalty: 0.11084, G_fake Loss: -0.28790, G_rec Loss: 0.01874


100%|██████████| 2000/2000 [06:36<00:00,  5.04it/s]


Epoch:  2000, D_real Loss: -0.60589, D_fake Loss: 0.32613, D_gradient_penalty: 0.12847, G_fake Loss: -0.32613, G_rec Loss: 0.01874
Stage: 8


 12%|█▎        | 250/2000 [02:33<18:56,  1.54it/s]

Epoch:  250, D_real Loss: -0.67396, D_fake Loss: 0.58432, D_gradient_penalty: 0.02934, G_fake Loss: -0.55786, G_rec Loss: 0.07427


 25%|██▌       | 500/2000 [05:13<16:07,  1.55it/s]

Epoch:  500, D_real Loss: -0.70938, D_fake Loss: 0.64394, D_gradient_penalty: 0.02676, G_fake Loss: -0.54856, G_rec Loss: 0.06233


 38%|███▊      | 750/2000 [07:52<13:27,  1.55it/s]

Epoch:  750, D_real Loss: -0.65437, D_fake Loss: 0.46471, D_gradient_penalty: 0.04832, G_fake Loss: -0.43551, G_rec Loss: 0.04626


 50%|█████     | 1000/2000 [09:44<07:27,  2.23it/s]

Epoch:  1000, D_real Loss: -0.48524, D_fake Loss: 0.35449, D_gradient_penalty: 0.03561, G_fake Loss: -0.39312, G_rec Loss: 0.04899


 62%|██████▎   | 1250/2000 [11:36<05:36,  2.23it/s]

Epoch:  1250, D_real Loss: -0.44606, D_fake Loss: 0.27870, D_gradient_penalty: 0.04456, G_fake Loss: -0.29611, G_rec Loss: 0.03450


 75%|███████▌  | 1500/2000 [13:29<03:50,  2.17it/s]

Epoch:  1500, D_real Loss: -0.43263, D_fake Loss: 0.28255, D_gradient_penalty: 0.02926, G_fake Loss: -0.25685, G_rec Loss: 0.03686


 88%|████████▊ | 1750/2000 [15:22<01:53,  2.21it/s]

Epoch:  1750, D_real Loss: -0.41622, D_fake Loss: 0.23536, D_gradient_penalty: 0.08515, G_fake Loss: -0.23536, G_rec Loss: 0.03992


100%|██████████| 2000/2000 [17:14<00:00,  1.93it/s]


Epoch:  2000, D_real Loss: -0.41622, D_fake Loss: 0.21590, D_gradient_penalty: 0.08326, G_fake Loss: -0.21590, G_rec Loss: 0.03992
Stage: 9


 12%|█▎        | 250/2000 [02:39<18:57,  1.54it/s]

Epoch:  250, D_real Loss: -0.27900, D_fake Loss: 0.15933, D_gradient_penalty: 0.02075, G_fake Loss: -0.15119, G_rec Loss: 0.53077


 25%|██▌       | 500/2000 [05:19<16:19,  1.53it/s]

Epoch:  500, D_real Loss: -0.19396, D_fake Loss: 0.12852, D_gradient_penalty: 0.05486, G_fake Loss: -0.10647, G_rec Loss: 0.23686


 38%|███▊      | 750/2000 [07:59<13:30,  1.54it/s]

Epoch:  750, D_real Loss: -0.12511, D_fake Loss: 0.06457, D_gradient_penalty: 0.02396, G_fake Loss: -0.06784, G_rec Loss: 0.16524


 50%|█████     | 1000/2000 [11:10<15:33,  1.07it/s]

Epoch:  1000, D_real Loss: -0.20677, D_fake Loss: 0.18214, D_gradient_penalty: 0.03651, G_fake Loss: -0.06822, G_rec Loss: 0.11628


 62%|██████▎   | 1250/2000 [14:57<11:38,  1.07it/s]

Epoch:  1250, D_real Loss: -0.13226, D_fake Loss: 0.07994, D_gradient_penalty: 0.02349, G_fake Loss: -0.03083, G_rec Loss: 0.09855


 75%|███████▌  | 1500/2000 [18:43<07:41,  1.08it/s]

Epoch:  1500, D_real Loss: -0.06785, D_fake Loss: 0.02166, D_gradient_penalty: 0.03237, G_fake Loss: -0.03312, G_rec Loss: 0.09761


 88%|████████▊ | 1750/2000 [22:29<03:50,  1.08it/s]

Epoch:  1750, D_real Loss: -0.07582, D_fake Loss: 0.01724, D_gradient_penalty: 0.01702, G_fake Loss: -0.01724, G_rec Loss: 0.09935


100%|██████████| 2000/2000 [25:38<00:00,  1.30it/s]

Epoch:  2000, D_real Loss: -0.07582, D_fake Loss: 0.02769, D_gradient_penalty: 0.02390, G_fake Loss: -0.02769, G_rec Loss: 0.09935





In [45]:
@torch.no_grad
def generate_samples(path: str,
                     n: int = 25) -> None:
    if not os.path.exists(path):
        os.mkdir(path)
    for idx in range(n):
        noise = [torch.randn(*z.shape, device=device) for z in fixed_noise]
        sample = generate(generators, noise, reals_shapes, noise_amp)
        write_image(f'{path}/gen_sample_{idx + 1:02d}.jpg', sample.squeeze(0))

In [46]:
generate_samples(f'{output_dir}/gen')

## Reference
1. [Official implementation](https://github.com/tamarott/SinGAN)
2. https://github.com/tohinz/ConSinGAN