In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from typing import Tuple, List, Optional
import math
import copy
import tqdm
import matplotlib.pyplot as plt

In [28]:
class ConvBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 padding: int,
                 batch_norm: bool = False) -> None:
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, 1, padding)
        self.norm = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()
        self.relu = nn.LeakyReLU(0.05, inplace=True)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.norm(self.conv(x)))

In [51]:
class Discriminator(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 padding: int,
                 num_layer: int) -> None:
        super(Discriminator, self).__init__()
        self.head = ConvBlock(in_channels, out_channels, kernel_size, padding)
        self.body = nn.Sequential(
            *[ConvBlock(out_channels, out_channels, kernel_size, padding)
            for _ in range(num_layer)]
        )
        self.tail = nn.Conv2d(out_channels, 1, kernel_size, padding)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        x = self.head(x)
        x = self.body(x)
        return self.tail(x)

In [None]:
class Generator(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 padding: int,
                 num_layers: int,
                 batch_norm: bool = False) -> None:
        super(Generator, self).__init__()
        self.head = ConvBlock(in_channels, out_channels, kernel_size, padding, batch_norm)
        self.body = nn.ModuleList()
        self.body.append(nn.Sequential(
            *[ConvBlock(out_channels, out_channels, kernel_size, padding, batch_norm)
            for _ in range(num_layers)]
        ))
        self.tail = nn.Sequential(
            nn.Conv2d(out_channels, in_channels, kernel_size, padding=padding),
            nn.Tanh()
        )
    
    def init_next_stage(self):
        self.body.append(copy.deepcopy(self.body[-1]))
    
    def forward(self,
                noise: List[torch.Tensor],
                real_shapes: List[torch.Size],
                noise_amp: List[float]) -> torch.Tensor:
        x = self.head(noise[0])
        x = self.body[0](x)
        for i in range(1, len(self.body)):
            x = F.interpolate(x, size=real_shapes[i], mode='bilinear', align_corners=True)
            x = x + self.body[i](x + noise[i] * noise_amp[i])
        return self.tail(x)

In [20]:
def read_image(path: str) -> torch.Tensor:
    image = plt.imread(path)
    image = torch.from_numpy(image).float() / 255.0
    image = image.permute(2, 0, 1)
    return (image - 0.5) * 2

def write_image(path: str,
                image: torch.Tensor) -> None:
    image = (image + 1) / 2
    plt.imsave(path, image.permute(1, 2, 0).cpu().numpy(), vmin=0, vmax=1)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
path = '../data/angkorwat.jpg'
image = read_image(path).to(device)
max_size, min_size = 250, 25
stages = 8 # stop_scale = stages - 1 = 7

In [33]:
def adjust_scales2image(image: torch.Tensor,
                        stop_scale: int = 7) -> Tuple[torch.Tensor, float]:
    # resize the long side of the image to max_size
    scale = min(max_size / max([image.shape[2], image.shape[3]]), 1)
    real = F.interpolate(image, scale_factor=scale, mode='bilinear', align_corners=True)
    # calculate the scaling factor for resizing the short side to min_size
    scale_factor = math.pow(min_size / (min(real.shape[2], real.shape[3])), 1 / stop_scale)
    return real, scale_factor
real, scale_factor = adjust_scales2image(image.unsqueeze(0))
real.shape, scale_factor

(torch.Size([1, 3, 166, 250]), 0.763040197031911)

In [34]:
def create_reals_pyramid(real, scale_factor: float, stop_scale: int = 7):
    reals = []
    for i in range(stop_scale):
        scale = math.pow(scale_factor,((stop_scale-1)/math.log(stop_scale))*math.log(stop_scale-i)+1)
        curr_real = F.interpolate(real, scale_factor=scale, mode='bilinear', align_corners=True)
        reals.append(curr_real)
    reals.append(real)
    return reals

reals = create_reals_pyramid(real, scale_factor)
for r in reals:
    print(r.shape)

torch.Size([1, 3, 25, 37])
torch.Size([1, 3, 28, 42])
torch.Size([1, 3, 33, 49])
torch.Size([1, 3, 39, 60])
torch.Size([1, 3, 50, 76])
torch.Size([1, 3, 71, 107])
torch.Size([1, 3, 126, 190])
torch.Size([1, 3, 166, 250])


In [36]:
torch.randn(1, 1), torch.randn(1), torch.randn(1).item()

(tensor([[0.6290]]), tensor([0.1069]), 0.14044876396656036)

In [63]:
def calc_gradient_penalty(discriminator: Discriminator,
                          real_data: torch.Tensor,
                          fake_data: torch.Tensor,
                          lamb: float = 0.1,
                          device: str = 'cpu') -> torch.Tensor:
    # alpha = torch.rand(1, 1)
    # alpha = alpha.expand(real_data.size())
    alpha = torch.rand(1).item()
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    interpolates.requires_grad_(True)
    disc_interpolates = discriminator(interpolates)
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                    grad_outputs=torch.ones_like(disc_interpolates, device=device),
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lamb
    return gradient_penalty

In [90]:
generator = Generator(3, 64, 3, 1, 3, False).to(device)
discriminator = Discriminator(3, 64, 3, 1, 3).to(device)
fixed_noise = []
noise_amp = []


d_iter = g_iter = 3
max_epochs = 2000
lr = 0.0005
lr_scale = 0.1
beta1 = 0.5
gamma = 0.1
lamb = 0.1
alpha = 10
reals_shapes = [real.shape[2:] for real in reals]
train_depth = 3

In [91]:
for stage in range(stages):
    if stage:
        # discriminator.load_state_dict(torch.load(f'output/discriminator/netD.pth'))
        generator.init_next_stage()
    
    real = reals[stage]
    z_opt = reals[0] if stage == 0 else torch.randn(1, 64, *reals_shapes[stage], device=device)
    fixed_noise.append(z_opt)
    
    optim_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

    for block in generator.body[:-train_depth]:
        for param in block.parameters():
            param.requires_grad = False
    
    parameter_list = [{"params": block.parameters(), "lr": lr * (lr_scale ** (len(generator.body[-train_depth:]) - 1 - idx))}
            for idx, block in enumerate(generator.body[-train_depth:])]

    # add parameters of head and tail to training
    if stage - train_depth < 0:
        parameter_list += [{"params": generator.head.parameters(), "lr": lr * (lr_scale ** stage)}]
    parameter_list += [{"params": generator.tail.parameters(), "lr": lr}]
    optim_g = torch.optim.Adam(parameter_list, lr=lr, betas=(beta1, 0.999))

    # define learning rate schedules
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optim_d, milestones=[0.8 * max_epochs], gamma=gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optim_g, milestones=[0.8 * max_epochs], gamma=gamma)

    # calculate noise_amp
    if stage == 0:
        noise_amp.append(1)
    else:
        noise_amp.append(0)
        with torch.no_grad():
            z_reconstruction = generator(fixed_noise, reals_shapes, noise_amp)
        rec_loss = F.mse_loss(z_reconstruction, real)
        noise_amp[-1] = 0.1 * torch.sqrt(rec_loss)

    for epoch in tqdm.tqdm(range(1, max_epochs + 1)):
        noise = [torch.randn(*z.shape, device=device) for z in fixed_noise]

        # update discriminator: minimize D(G(z)) - D(x)
        for j in range(d_iter):
            output = discriminator(real)
            errD_real = -output.mean()
            
            if j == d_iter - 1:
                fake = generator(noise, reals_shapes, noise_amp)
            else:
                with torch.no_grad():
                    fake = generator(noise, reals_shapes, noise_amp)
            
            output = discriminator(fake.detach())
            errD_fake = output.mean()
            
            gradient_penalty = calc_gradient_penalty(discriminator, real, fake.detach(), lamb, device)
            errD_total = errD_real + errD_fake + gradient_penalty
            optim_d.zero_grad()
            errD_total.backward()
            optim_d.step()
        
        # update generator: minimize -D(G(z))
        output = discriminator(fake)
        errG = -output.mean()
        
        rec = generator(fixed_noise, reals_shapes, noise_amp)
        rec_loss = alpha * F.mse_loss(rec, real)
        
        errG_total = errG + rec_loss
        optim_g.zero_grad()
        errG_total.backward()
        for _ in range(g_iter):
            optim_g.step()
        
        if epoch % 250 == 0:
            print(f'Epoch: {epoch: 04d}, D_real Loss: {-errD_real.item():.5f}, D_fake Loss: {errD_fake.item():.5f}, '
                  f'D_gradient_penalty: {gradient_penalty.item():.5f}, G_fake Loss: {errG.item():.5f}, G_rec Loss: {rec_loss.item():.5f}')
        if epoch % 500 == 0:
            write_image(f'output/fake_sample_{epoch:04d}.jpg', fake.detach().squeeze(0))
            write_image(f'output/reconstruction_{epoch: 04d}.jpg', rec.detach().squeeze(0))
        
        schedulerD.step()
        schedulerG.step()

 13%|█▎        | 255/2000 [00:09<01:02, 28.04it/s]

Epoch:  250, D_real Loss: 4.30068, D_fake Loss: 1.59975, D_gradient_penalty: 0.53338, G_fake Loss: -2.47047, G_rec Loss: 0.10092


 25%|██▌       | 503/2000 [00:18<00:53, 28.05it/s]

Epoch:  500, D_real Loss: 3.07716, D_fake Loss: 0.60868, D_gradient_penalty: 0.51457, G_fake Loss: -1.20070, G_rec Loss: 0.07324


 38%|███▊      | 756/2000 [00:27<00:43, 28.73it/s]

Epoch:  750, D_real Loss: 2.29107, D_fake Loss: -0.12926, D_gradient_penalty: 1.10258, G_fake Loss: 0.10690, G_rec Loss: 0.03019


 50%|█████     | 1004/2000 [00:36<00:34, 28.79it/s]

Epoch:  1000, D_real Loss: 2.24582, D_fake Loss: -0.16952, D_gradient_penalty: 0.96145, G_fake Loss: 0.24900, G_rec Loss: 0.02169


 63%|██████▎   | 1253/2000 [00:45<00:26, 27.73it/s]

Epoch:  1250, D_real Loss: 2.22249, D_fake Loss: -1.27685, D_gradient_penalty: 0.96704, G_fake Loss: 1.55424, G_rec Loss: 0.02218


 75%|███████▌  | 1505/2000 [00:54<00:17, 28.42it/s]

Epoch:  1500, D_real Loss: 1.82246, D_fake Loss: -1.62922, D_gradient_penalty: 1.85099, G_fake Loss: 1.96205, G_rec Loss: 0.01462


 88%|████████▊ | 1755/2000 [01:03<00:08, 28.23it/s]

Epoch:  1750, D_real Loss: 0.79192, D_fake Loss: -1.99299, D_gradient_penalty: 1.04691, G_fake Loss: 1.98249, G_rec Loss: 0.00689


100%|██████████| 2000/2000 [01:11<00:00, 27.90it/s]
  x = F.upsample(x, size=real_shapes[i], mode='bilinear', align_corners=True)


Epoch:  2000, D_real Loss: 0.82904, D_fake Loss: -2.38501, D_gradient_penalty: 1.69481, G_fake Loss: 2.44950, G_rec Loss: 0.00591


 13%|█▎        | 252/2000 [00:10<01:13, 23.94it/s]

Epoch:  250, D_real Loss: 0.14297, D_fake Loss: -1.00803, D_gradient_penalty: 0.30931, G_fake Loss: 0.87406, G_rec Loss: 0.01556


 25%|██▌       | 504/2000 [00:21<01:01, 24.33it/s]

Epoch:  500, D_real Loss: 0.68886, D_fake Loss: -0.63819, D_gradient_penalty: 0.31413, G_fake Loss: 0.74943, G_rec Loss: 0.00819


 38%|███▊      | 754/2000 [00:31<00:50, 24.50it/s]

Epoch:  750, D_real Loss: 1.16750, D_fake Loss: -0.29832, D_gradient_penalty: 0.36517, G_fake Loss: 0.26229, G_rec Loss: 0.00906


 50%|█████     | 1003/2000 [00:42<00:41, 24.16it/s]

Epoch:  1000, D_real Loss: 1.16454, D_fake Loss: -0.25319, D_gradient_penalty: 0.82521, G_fake Loss: 0.34075, G_rec Loss: 0.00822


 63%|██████▎   | 1255/2000 [00:52<00:29, 24.87it/s]

Epoch:  1250, D_real Loss: 0.28956, D_fake Loss: -1.07511, D_gradient_penalty: 0.45249, G_fake Loss: 1.44491, G_rec Loss: 0.00506


 75%|███████▌  | 1504/2000 [01:02<00:21, 23.25it/s]

Epoch:  1500, D_real Loss: 0.06413, D_fake Loss: -1.18476, D_gradient_penalty: 0.69806, G_fake Loss: 1.28617, G_rec Loss: 0.00405


 88%|████████▊ | 1753/2000 [01:12<00:10, 23.74it/s]

Epoch:  1750, D_real Loss: -0.00922, D_fake Loss: -1.31877, D_gradient_penalty: 0.26815, G_fake Loss: 1.29585, G_rec Loss: 0.00155


100%|██████████| 2000/2000 [01:23<00:00, 23.94it/s]


Epoch:  2000, D_real Loss: -0.17788, D_fake Loss: -1.45792, D_gradient_penalty: 0.51880, G_fake Loss: 1.50247, G_rec Loss: 0.00138


 13%|█▎        | 253/2000 [00:12<01:17, 22.65it/s]

Epoch:  250, D_real Loss: -0.92732, D_fake Loss: -1.74971, D_gradient_penalty: 0.37569, G_fake Loss: 2.09932, G_rec Loss: 0.01524


 25%|██▌       | 502/2000 [00:23<01:07, 22.08it/s]

Epoch:  500, D_real Loss: -0.89243, D_fake Loss: -1.65159, D_gradient_penalty: 0.36462, G_fake Loss: 1.69437, G_rec Loss: 0.00791


 38%|███▊      | 752/2000 [00:34<00:55, 22.69it/s]

Epoch:  750, D_real Loss: -1.54297, D_fake Loss: -2.38167, D_gradient_penalty: 0.46503, G_fake Loss: 2.22033, G_rec Loss: 0.00690


 50%|█████     | 1004/2000 [00:46<00:45, 21.84it/s]

Epoch:  1000, D_real Loss: -0.92100, D_fake Loss: -1.70708, D_gradient_penalty: 0.14511, G_fake Loss: 1.54187, G_rec Loss: 0.01222


 63%|██████▎   | 1254/2000 [00:58<00:32, 23.15it/s]

Epoch:  1250, D_real Loss: -0.74416, D_fake Loss: -1.53110, D_gradient_penalty: 0.51536, G_fake Loss: 1.73433, G_rec Loss: 0.00543


 75%|███████▌  | 1503/2000 [01:09<00:21, 22.73it/s]

Epoch:  1500, D_real Loss: -1.06551, D_fake Loss: -1.90034, D_gradient_penalty: 0.28693, G_fake Loss: 2.15633, G_rec Loss: 0.00719


 88%|████████▊ | 1754/2000 [01:20<00:10, 23.47it/s]

Epoch:  1750, D_real Loss: -1.17838, D_fake Loss: -1.90666, D_gradient_penalty: 0.14707, G_fake Loss: 1.88710, G_rec Loss: 0.00157


100%|██████████| 2000/2000 [01:31<00:00, 21.97it/s]


Epoch:  2000, D_real Loss: -1.31049, D_fake Loss: -2.04711, D_gradient_penalty: 0.25747, G_fake Loss: 2.02764, G_rec Loss: 0.00147


 13%|█▎        | 254/2000 [00:11<01:19, 21.91it/s]

Epoch:  250, D_real Loss: -1.07494, D_fake Loss: -1.57599, D_gradient_penalty: 0.19962, G_fake Loss: 1.76766, G_rec Loss: 0.01013


 25%|██▌       | 503/2000 [00:23<01:08, 21.93it/s]

Epoch:  500, D_real Loss: -1.74285, D_fake Loss: -2.25621, D_gradient_penalty: 0.14770, G_fake Loss: 1.83469, G_rec Loss: 0.00941


 38%|███▊      | 754/2000 [00:34<00:55, 22.43it/s]

Epoch:  750, D_real Loss: -1.26367, D_fake Loss: -1.83781, D_gradient_penalty: 0.22544, G_fake Loss: 2.08677, G_rec Loss: 0.01341


 50%|█████     | 1003/2000 [00:46<00:47, 21.02it/s]

Epoch:  1000, D_real Loss: -1.78488, D_fake Loss: -2.24749, D_gradient_penalty: 0.12941, G_fake Loss: 2.02706, G_rec Loss: 0.00814


 63%|██████▎   | 1252/2000 [00:57<00:35, 20.85it/s]

Epoch:  1250, D_real Loss: -1.34330, D_fake Loss: -1.92231, D_gradient_penalty: 0.31016, G_fake Loss: 2.22601, G_rec Loss: 0.01329


 75%|███████▌  | 1501/2000 [01:09<00:24, 20.68it/s]

Epoch:  1500, D_real Loss: -1.21301, D_fake Loss: -1.86882, D_gradient_penalty: 0.16213, G_fake Loss: 1.88064, G_rec Loss: 0.00841


 88%|████████▊ | 1753/2000 [01:22<00:11, 20.73it/s]

Epoch:  1750, D_real Loss: -1.52167, D_fake Loss: -2.00709, D_gradient_penalty: 0.25926, G_fake Loss: 2.00587, G_rec Loss: 0.00191


100%|██████████| 2000/2000 [01:34<00:00, 21.27it/s]


Epoch:  2000, D_real Loss: -1.65972, D_fake Loss: -2.25191, D_gradient_penalty: 0.11451, G_fake Loss: 2.23133, G_rec Loss: 0.00177


 12%|█▎        | 250/2000 [00:36<04:57,  5.87it/s]

Epoch:  250, D_real Loss: -1.31871, D_fake Loss: -1.69631, D_gradient_penalty: 0.09456, G_fake Loss: 1.54201, G_rec Loss: 0.01617


 25%|██▌       | 502/2000 [01:12<03:27,  7.24it/s]

Epoch:  500, D_real Loss: -1.44710, D_fake Loss: -1.79144, D_gradient_penalty: 0.07967, G_fake Loss: 1.66353, G_rec Loss: 0.01037


 38%|███▊      | 752/2000 [01:48<02:51,  7.27it/s]

Epoch:  750, D_real Loss: -1.35011, D_fake Loss: -1.69682, D_gradient_penalty: 0.11735, G_fake Loss: 1.58980, G_rec Loss: 0.00872


 50%|█████     | 1000/2000 [02:24<02:48,  5.93it/s]

Epoch:  1000, D_real Loss: -1.50733, D_fake Loss: -1.80243, D_gradient_penalty: 0.14229, G_fake Loss: 2.06407, G_rec Loss: 0.00878


 62%|██████▎   | 1250/2000 [03:00<02:06,  5.91it/s]

Epoch:  1250, D_real Loss: -1.59128, D_fake Loss: -1.90309, D_gradient_penalty: 0.12444, G_fake Loss: 1.53958, G_rec Loss: 0.00843


 75%|███████▌  | 1500/2000 [03:36<01:26,  5.78it/s]

Epoch:  1500, D_real Loss: -1.25129, D_fake Loss: -1.59892, D_gradient_penalty: 0.11444, G_fake Loss: 1.61913, G_rec Loss: 0.00749


 88%|████████▊ | 1750/2000 [04:12<00:41,  5.98it/s]

Epoch:  1750, D_real Loss: -1.56559, D_fake Loss: -1.96301, D_gradient_penalty: 0.08582, G_fake Loss: 1.94526, G_rec Loss: 0.00236


100%|██████████| 2000/2000 [04:49<00:00,  6.92it/s]


Epoch:  2000, D_real Loss: -1.70752, D_fake Loss: -2.02349, D_gradient_penalty: 0.10173, G_fake Loss: 2.02005, G_rec Loss: 0.00222


 12%|█▎        | 250/2000 [00:41<05:39,  5.16it/s]

Epoch:  250, D_real Loss: -0.68473, D_fake Loss: -0.86191, D_gradient_penalty: 0.03913, G_fake Loss: 0.78401, G_rec Loss: 0.02434


 25%|██▌       | 500/2000 [01:22<04:57,  5.05it/s]

Epoch:  500, D_real Loss: -1.31948, D_fake Loss: -1.50996, D_gradient_penalty: 0.05763, G_fake Loss: 1.33810, G_rec Loss: 0.01222


 38%|███▊      | 752/2000 [01:55<02:09,  9.65it/s]

Epoch:  750, D_real Loss: -1.13813, D_fake Loss: -1.31609, D_gradient_penalty: 0.06968, G_fake Loss: 1.21373, G_rec Loss: 0.01476


 50%|█████     | 1002/2000 [02:22<01:46,  9.36it/s]

Epoch:  1000, D_real Loss: -1.27105, D_fake Loss: -1.48525, D_gradient_penalty: 0.04881, G_fake Loss: 1.44137, G_rec Loss: 0.00899


 63%|██████▎   | 1252/2000 [02:49<01:17,  9.64it/s]

Epoch:  1250, D_real Loss: -1.17490, D_fake Loss: -1.34729, D_gradient_penalty: 0.07490, G_fake Loss: 1.27939, G_rec Loss: 0.01294


 75%|███████▌  | 1502/2000 [03:17<00:51,  9.61it/s]

Epoch:  1500, D_real Loss: -1.20657, D_fake Loss: -1.41316, D_gradient_penalty: 0.06459, G_fake Loss: 1.40172, G_rec Loss: 0.00867


 88%|████████▊ | 1752/2000 [03:44<00:26,  9.41it/s]

Epoch:  1750, D_real Loss: -1.05113, D_fake Loss: -1.25689, D_gradient_penalty: 0.05925, G_fake Loss: 1.25493, G_rec Loss: 0.00314


100%|██████████| 2000/2000 [04:11<00:00,  7.95it/s]


Epoch:  2000, D_real Loss: -1.16550, D_fake Loss: -1.36689, D_gradient_penalty: 0.10838, G_fake Loss: 1.39329, G_rec Loss: 0.00305


 12%|█▎        | 250/2000 [01:36<12:15,  2.38it/s]

Epoch:  250, D_real Loss: -0.45376, D_fake Loss: -0.52796, D_gradient_penalty: 0.02185, G_fake Loss: 0.60786, G_rec Loss: 0.02508


 25%|██▌       | 500/2000 [03:11<10:32,  2.37it/s]

Epoch:  500, D_real Loss: -0.98195, D_fake Loss: -1.08206, D_gradient_penalty: 0.02786, G_fake Loss: 1.13394, G_rec Loss: 0.01516


 38%|███▊      | 750/2000 [04:47<08:47,  2.37it/s]

Epoch:  750, D_real Loss: -0.83811, D_fake Loss: -0.94755, D_gradient_penalty: 0.03686, G_fake Loss: 1.02432, G_rec Loss: 0.02484


 50%|█████     | 1000/2000 [06:22<07:03,  2.36it/s]

Epoch:  1000, D_real Loss: -1.03377, D_fake Loss: -1.14765, D_gradient_penalty: 0.03285, G_fake Loss: 1.14775, G_rec Loss: 0.01300


 62%|██████▎   | 1250/2000 [07:58<05:17,  2.36it/s]

Epoch:  1250, D_real Loss: -0.97679, D_fake Loss: -1.08590, D_gradient_penalty: 0.03033, G_fake Loss: 1.11043, G_rec Loss: 0.01693


 75%|███████▌  | 1500/2000 [09:34<03:34,  2.33it/s]

Epoch:  1500, D_real Loss: -0.88732, D_fake Loss: -1.01657, D_gradient_penalty: 0.04183, G_fake Loss: 1.14144, G_rec Loss: 0.01578


 88%|████████▊ | 1750/2000 [11:10<01:45,  2.37it/s]

Epoch:  1750, D_real Loss: -0.87843, D_fake Loss: -1.00853, D_gradient_penalty: 0.03616, G_fake Loss: 1.01423, G_rec Loss: 0.00581


100%|██████████| 2000/2000 [12:46<00:00,  2.61it/s]


Epoch:  2000, D_real Loss: -1.09222, D_fake Loss: -1.21551, D_gradient_penalty: 0.06237, G_fake Loss: 1.23614, G_rec Loss: 0.00574


 12%|█▎        | 250/2000 [02:38<20:50,  1.40it/s]

Epoch:  250, D_real Loss: -0.60581, D_fake Loss: -0.69395, D_gradient_penalty: 0.03352, G_fake Loss: 1.01242, G_rec Loss: 0.02757


 25%|██▌       | 500/2000 [05:18<17:58,  1.39it/s]

Epoch:  500, D_real Loss: -0.91214, D_fake Loss: -1.01908, D_gradient_penalty: 0.03314, G_fake Loss: 1.13364, G_rec Loss: 0.02359


 38%|███▊      | 750/2000 [07:59<14:53,  1.40it/s]

Epoch:  750, D_real Loss: -0.96116, D_fake Loss: -1.09413, D_gradient_penalty: 0.04179, G_fake Loss: 1.10706, G_rec Loss: 0.02115


 50%|█████     | 1000/2000 [10:39<11:55,  1.40it/s]

Epoch:  1000, D_real Loss: -1.05566, D_fake Loss: -1.17584, D_gradient_penalty: 0.02910, G_fake Loss: 1.13134, G_rec Loss: 0.01832


 62%|██████▎   | 1250/2000 [13:20<08:56,  1.40it/s]

Epoch:  1250, D_real Loss: -1.22309, D_fake Loss: -1.34485, D_gradient_penalty: 0.03416, G_fake Loss: 1.32614, G_rec Loss: 0.01612


 75%|███████▌  | 1500/2000 [15:31<04:10,  2.00it/s]

Epoch:  1500, D_real Loss: -1.36419, D_fake Loss: -1.50761, D_gradient_penalty: 0.04406, G_fake Loss: 1.57757, G_rec Loss: 0.01631


 88%|████████▊ | 1750/2000 [17:25<02:06,  1.97it/s]

Epoch:  1750, D_real Loss: -1.39273, D_fake Loss: -1.52153, D_gradient_penalty: 0.05710, G_fake Loss: 1.53567, G_rec Loss: 0.00886


100%|██████████| 2000/2000 [19:19<00:00,  1.73it/s]

Epoch:  2000, D_real Loss: -1.41666, D_fake Loss: -1.55512, D_gradient_penalty: 0.06552, G_fake Loss: 1.57962, G_rec Loss: 0.00875





In [86]:
F.interpolate(torch.randn(1, 3, 25, 37), size=(28, 42), mode='bilinear', align_corners=True).shape

torch.Size([1, 3, 28, 42])

In [78]:
len(noise)

3

In [79]:
noise[0].shape, noise[1].shape, noise[2].shape

(torch.Size([1, 3, 25, 37]),
 torch.Size([1, 3, 28, 42]),
 torch.Size([1, 3, 25, 37]))

In [75]:
len(fixed_noise)

2

In [76]:
fixed_noise[0].shape, fixed_noise[1].shape

(torch.Size([1, 3, 25, 37]), torch.Size([1, 3, 28, 42]))

In [71]:
noise_amp

[1]