In [1]:
import os
import math
import torch
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import utils as vutils
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder

In [2]:
# os.cpu_count() returns None when the CPU count cannot be determined;
# DataLoader needs an int, so fall back to 0 (load in the main process).
num_workers = os.cpu_count() or 0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Dataset root (read through ImageFolder) and output directory for
# sample grids and checkpoints.
data_dir = "/home/pervinco/Datasets/celeba_hq_256"
save_dir = "./runs/PGGAN"

# Adam hyper-parameters shared by the generator and the critic.
lr = 0.0001
beta1 = 0.0
beta2 = 0.99
batch_sizes = [256, 256, 128, 64, 16, 4] ## img_size : 4, 8, 16, 32, 64, 128

nc = 3     # image channels
nz = 256   # latent vector size
ndf = 256  # base channel width of the discriminator
ngf = 256  # base channel width of the generator

min_img_size = 4
max_img_size = 128
gp_coeff = 10  # weight of the gradient penalty term in the critic loss

# One entry per resolution stage: number of epochs trained at that stage.
progressive_epochs = [100] * len(batch_sizes)

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(save_dir, exist_ok=True)

In [3]:
## Equalized learning-rate Conv2d
class EqualConv2d(nn.Module):
    """Conv2d with a fixed runtime scaling factor (equalized learning rate).

    The weights are drawn from a standard normal distribution and the input
    is multiplied by ``sqrt(gain / (in_channels * kernel_size ** 2))`` on
    every forward pass. The factor is a variant of He initialization applied
    at runtime instead of at initialization time, so that the effective
    learning rate stays comparable across layers.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: side length of the (square) kernel; used as an integer
            in the scale computation.
        stride: convolution stride.
        padding: zero padding applied on each side.
        gain: numerator of the scaling factor.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        fan_in = in_channels * (kernel_size ** 2)

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (gain / fan_in) ** 0.5

        # Move the bias parameter from the inner conv onto this module so it
        # is added after the convolution and is not affected by the scale.
        self.bias = self.conv.bias
        self.conv.bias = None

        # Standard-normal weights, zero bias.
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Convolve the scaled input and add the per-channel bias."""
        scaled_input = x * self.scale
        channel_bias = self.bias.view(1, self.bias.shape[0], 1, 1)
        return self.conv(scaled_input) + channel_bias

In [4]:
## Pixel-wise feature normalization
class PixelNorm(nn.Module):
    """Divide each pixel's feature vector by its root-mean-square over dim 1."""

    def __init__(self):
        super().__init__()
        # Added to the mean square before the square root to avoid dividing by zero.
        self.eps = 1e-8

    def forward(self, x):
        """Return ``x`` normalized along dim 1 (keepdim, so shapes broadcast)."""
        mean_square = torch.mean(x ** 2, dim=1, keepdim=True)
        return x / torch.sqrt(mean_square + self.eps)

In [5]:
class ConvBlock(nn.Module):
    """Two 3x3 equalized convolutions, each followed by LeakyReLU(0.2).

    Args:
        in_channels: channels entering the first convolution.
        out_channels: channels produced by both convolutions.
        use_pixelnorm: when True, PixelNorm is applied after each activation.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.use_pixelnorm = use_pixelnorm
        self.conv1 = EqualConv2d(in_channels, out_channels)
        self.conv2 = EqualConv2d(out_channels, out_channels)
        self.lrelu = nn.LeakyReLU(0.2)
        self.pixel_norm = PixelNorm()

    def forward(self, x):
        """Run conv -> LeakyReLU -> (optional PixelNorm) twice."""
        for conv in (self.conv1, self.conv2):
            x = self.lrelu(conv(x))
            if self.use_pixelnorm:
                x = self.pixel_norm(x)
        return x

In [6]:
# Channel multipliers per resolution stage; stage i uses int(base_width * factors[i]).
factors = [1, 1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32]

class Generator(nn.Module):
    """Progressively grown generator.

    Args:
        nz: number of channels of the latent input fed to the first block.
        ngf: base channel width; per-stage widths are ``int(ngf * factors[i])``.
        nc: number of channels of the generated image.
    """

    def __init__(self, nz, ngf, nc):
        super().__init__()

        # Initial block: the transposed conv (kernel 4, stride 1, no padding)
        # expands a 1x1 latent to 4x4, followed by one equalized 3x3 conv.
        self.block1 = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(nz, ngf, kernel_size=4, stride=1, padding=0),
            nn.LeakyReLU(0.2),
            EqualConv2d(ngf, ngf, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )
        # 1x1 conv that maps the feature map to image (RGB) channels.
        self.toRGB1 = EqualConv2d(ngf, nc, kernel_size=1, stride=1, padding=0)

        # Blocks added as the model grows, and one toRGB layer per resolution
        # (toRGBs[0] is toRGB1, toRGBs[i + 1] belongs to prog_blocks[i]).
        self.prog_blocks = nn.ModuleList([])
        self.toRGBs = nn.ModuleList([self.toRGB1])

        # Walk consecutive pairs of multipliers: (factors[i], factors[i + 1]).
        for in_factor, out_factor in zip(factors[:-1], factors[1:]):
            in_channels = int(ngf * in_factor)
            out_channels = int(ngf * out_factor)
            self.prog_blocks.append(ConvBlock(in_channels, out_channels))
            self.toRGBs.append(EqualConv2d(out_channels, nc, kernel_size=1, stride=1, padding=0))

    def fade_in(self, alpha, upscaled, generated):
        """Blend the previous-stage image (upscaled) with the current-stage
        image (generated) linearly by alpha, then squash with tanh."""
        blended = alpha * generated + (1 - alpha) * upscaled
        return torch.tanh(blended)

    def forward(self, x, alpha, steps):
        """Generate an image after ``steps`` upsampling stages.

        With ``steps == 0`` only the initial block is used and the toRGB1
        output is returned directly (no tanh on this path). Otherwise each
        stage doubles the resolution with nearest-neighbour interpolation and
        applies the matching progressive block; the result is faded in
        against the RGB projection of the last upscaled feature map.
        """
        out = self.block1(x)

        if steps == 0:
            return self.toRGB1(out)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            block = self.prog_blocks[step]
            out = block(upscaled)

        # Previous-resolution path: input of the last block, through the
        # previous stage's toRGB. Current path: output of the last block.
        previous_rgb = self.toRGBs[steps - 1](upscaled)
        current_rgb = self.toRGBs[steps](out)

        return self.fade_in(alpha, previous_rgb, current_rgb)
    
    

class Discriminator(nn.Module):
    """Progressively grown discriminator (critic).

    Args:
        nc: number of channels of the input image.
        ndf: base channel width; per-stage widths are ``int(ndf * factors[i])``.
    """

    def __init__(self, nc, ndf):
        super().__init__()
        self.lrelu = nn.LeakyReLU(0.2)
        self.prog_blocks, self.fromRGBs = nn.ModuleList([]), nn.ModuleList([])

        # Build blocks from the highest resolution down to the lowest, so
        # index 0 is the block for the largest supported image size.
        # NOTE(review): use_pixelnorm=True is passed explicitly here; the
        # PGGAN paper describes pixel norm for the generator — confirm that
        # enabling it in the critic is intentional.
        for i in range(len(factors) - 1, 0, -1):
            in_channels = int(ndf * factors[i])
            out_channels = int(ndf * factors[i-1])
            self.prog_blocks.append(ConvBlock(in_channels, out_channels, use_pixelnorm=True))
            self.fromRGBs.append(EqualConv2d(nc, in_channels, kernel_size=1, stride=1, padding=0))

        # fromRGB for the lowest resolution; stored last, at index len(prog_blocks).
        self.fromRGB1 = EqualConv2d(nc, ndf, kernel_size=1, stride=1, padding=0)
        self.fromRGBs.append(self.fromRGB1)

        # Average pooling with stride 2 halves the spatial size.
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # Input has ndf + 1 channels because minibatch_discrimination appends
        # one statistics channel; the 4x4 unpadded conv collapses a 4x4 map.
        self.final_block = nn.Sequential(EqualConv2d(ndf + 1, ndf, kernel_size=3, padding=1),
                                         nn.LeakyReLU(0.2),
                                         EqualConv2d(ndf, ndf, kernel_size=4, padding=0, stride=1),
                                         nn.LeakyReLU(0.2),
                                         EqualConv2d(ndf, 1, kernel_size=1, padding=0, stride=1))
    
    def fade_in(self, alpha, downscaled, out):
        """Blend the previous-resolution features (downscaled) with the
        current-stage features (out) linearly by alpha."""
        return alpha * out + (1-alpha) * downscaled
    
    def minibatch_discrimination(self, x):
        """Append one channel holding the mean per-feature std over the batch.

        Returns a tensor with ``x.shape[1] + 1`` channels. For a batch of a
        single sample the sample standard deviation is undefined (torch.std
        yields NaN), so the statistic is set to 0 instead; this keeps a
        trailing size-1 batch from turning the critic output into NaN.
        """
        if x.shape[0] > 1:
            batch_std = torch.std(x, dim=0).mean()
        else:
            batch_std = x.new_zeros(())
        batch_stat = batch_std.repeat(x.shape[0], 1, x.shape[2], x.shape[3])

        return torch.cat([x, batch_stat], dim=1)
    
    def forward(self, x, alpha, steps):
        """Score images produced/read at progressive stage ``steps``.

        Returns a tensor reshaped to ``(batch, -1)``.
        """
        # Index of the block/fromRGB matching the current resolution
        # (blocks are stored highest resolution first).
        current_step = len(self.prog_blocks) - steps

        out = self.fromRGBs[current_step](x)
        out = self.lrelu(out)

        # Lowest resolution: no progressive block, go straight to the head.
        if steps == 0:
            out = self.minibatch_discrimination(out)
            return self.final_block(out).view(out.shape[0], -1)
        
        # Previous-resolution path: pool the image, then the next fromRGB.
        downscaled = self.lrelu(self.fromRGBs[current_step+1](self.avg_pool(x)))
        # Current-resolution path: newest block, then pool to the same size.
        out = self.avg_pool(self.prog_blocks[current_step](out))

        out = self.fade_in(alpha, downscaled, out)

        # Remaining (lower-resolution) blocks, each followed by pooling.
        for step in range(current_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_discrimination(out)

        return self.final_block(out).view(out.shape[0], -1)

In [7]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    """Gradient penalty on random interpolations between real and fake images.

    Args:
        critic: callable invoked as ``critic(images, alpha, train_step)``.
        real: 4-D batch of real images (its shape is unpacked as b, c, h, w).
        fake: batch of generated images; detached before mixing.
        alpha: fade-in coefficient forwarded to the critic.
        train_step: progressive stage forwarded to the critic.
        device: device the random mixing coefficients are moved to.

    Returns:
        Mean over the batch of ``(||grad||_2 - 1) ** 2``, where the gradient
        is that of the critic score with respect to the interpolated images.
    """
    batch, channels, height, width = real.shape

    # One mixing coefficient per sample, drawn on the CPU and shared across
    # all channels and pixels of that sample.
    mix = torch.rand((batch, 1, 1, 1)).expand(batch, channels, height, width).to(device)

    interpolated = fake.detach() * (1 - mix) + real * mix
    interpolated.requires_grad_(True)

    scores = critic(interpolated, alpha, train_step)

    # create_graph=True keeps the gradient differentiable so the penalty
    # itself can be back-propagated through.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )

    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return torch.mean((per_sample_norm - 1) ** 2)

In [8]:
def get_dataset(img_size, batch_size):
    """Build the ImageFolder dataset and a shuffling DataLoader for one stage.

    Images under the module-level ``data_dir`` are resized and center-cropped
    to ``img_size``, converted to tensors, and normalized with mean 0.5 and
    std 0.5 per channel. The loader uses the module-level ``num_workers``.

    Returns:
        ``(dataset, dataloader)``.
    """
    preprocessing = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = datasets.ImageFolder(root=data_dir, transform=preprocessing)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    return dataset, dataloader


def save_fake_images(epoch, G, fixed_noise, alpha, step, num_images=64):
    """Render a grid of generated samples and save it as a PNG in ``save_dir``.

    Args:
        epoch: value used in the figure title and the file name.
        G: generator, called as ``G(fixed_noise, alpha, step)``.
        fixed_noise: latent batch passed to the generator.
        alpha: fade-in coefficient forwarded to the generator.
        step: progressive stage forwarded to the generator.
        num_images: at most this many samples are placed in the grid.
    """
    # No gradients are needed for sampling.
    with torch.no_grad():
        samples = G(fixed_noise, alpha, step) * 0.5 + 0.5
        samples = samples.detach().cpu()
        img_size = samples.size(-1)

    grid = vutils.make_grid(samples[:num_images], padding=2, normalize=True)

    fig = plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title(f"Fake Images at Epoch {epoch}")
    # make_grid returns channels-first; imshow expects channels-last.
    plt.imshow(np.transpose(grid, (1, 2, 0)))
    # Write the figure to disk, then release it.
    plt.savefig(f"{save_dir}/Gep_{epoch}_{img_size}x{img_size}.png")
    plt.close(fig)

In [9]:
G = Generator(nz, ngf, nc).to(device)
D = Discriminator(nc, ndf).to(device)

g_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2))
d_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, beta2))

global_epochs = 0
step = int(math.log2(min_img_size / 4))  # index of the first resolution stage
fixed_noise = torch.randn(64, nz, 1, 1).to(device)  # fixed latents for the sample grids

# The checkpoints are written in the `finally` clause so the weights survive
# an interrupted run (e.g. KeyboardInterrupt) instead of being lost. The saved
# weights then reflect wherever training stopped.
try:
    # One iteration per progressive stage.
    for n_epochs in progressive_epochs[step:]:
        alpha = 0.00001  # fade-in coefficient restarts near zero at every stage
        img_size = 4*2**step
        batch_size = batch_sizes[step]
        dataset, dataloader = get_dataset(img_size=img_size, batch_size=batch_size)
        print(f"Image size:{img_size} | Current batch size : {batch_size} | Current step:{step}", end="")
        for epoch in range(n_epochs):
            for images, _ in tqdm(dataloader, desc="Train", leave=False):
                bs = images.size(0)
                real_images = images.to(device)
                z = torch.randn(bs, nz, 1, 1).to(device)

                # Critic update. The fake batch is detached for the critic
                # terms, so its generator graph stays intact for the
                # generator update below.
                fake_images = G(z, alpha, step)
                d_real_loss = D(real_images, alpha, step).mean()
                d_fake_loss = D(fake_images.detach(), alpha, step).mean()
                gp = gradient_penalty(D, real_images, fake_images, alpha, step, device=device)
                # Wasserstein critic loss + gradient penalty + small penalty on
                # the squared mean real score.
                d_loss = -(d_real_loss - d_fake_loss) + gp_coeff * gp + 0.001 * (d_real_loss ** 2)

                D.zero_grad()
                d_loss.backward()
                d_optimizer.step()

                # Generator update. G has not changed since fake_images was
                # produced, so the batch is reused instead of running a second
                # identical forward pass; it is scored by the updated critic.
                g_loss = -D(fake_images, alpha, step).mean()

                G.zero_grad()
                g_loss.backward()
                g_optimizer.step()

                # Linear ramp that reaches 1 after half of this stage's epochs
                # (n_epochs is progressive_epochs[step]).
                alpha += (bs / len(dataset)) * (1 / n_epochs) * 2
                alpha = min(alpha, 1)

            if epoch > 0 and epoch % 10 == 0:
                print(f"\tEpoch [{epoch}/{n_epochs}] Global Epoch:{global_epochs} D Loss : {d_loss.item():.4f} G Loss : {g_loss.item():.4f}")
                save_fake_images(global_epochs, G, fixed_noise, alpha, step)

            global_epochs += 1  # epoch counter across all stages
        step += 1  # move on to the next progressive stage
finally:
    torch.save(G.state_dict(), f'{save_dir}/G.ckpt')
    torch.save(D.state_dict(), f'{save_dir}/D.ckpt')

Image size:4 | Current batch size : 256 | Current step:0

                                                         

	Epoch [10/100] Global Epoch:10 D Loss : 0.2203 G Loss : -0.3637


                                                         

	Epoch [20/100] Global Epoch:20 D Loss : 0.1675 G Loss : -0.1965


                                                         

	Epoch [30/100] Global Epoch:30 D Loss : 0.1559 G Loss : -0.1437


                                                         

	Epoch [40/100] Global Epoch:40 D Loss : 0.1535 G Loss : -0.1075


                                                         

	Epoch [50/100] Global Epoch:50 D Loss : 0.1687 G Loss : 0.0538


                                                         

	Epoch [60/100] Global Epoch:60 D Loss : 0.1040 G Loss : -0.0514


                                                         

	Epoch [70/100] Global Epoch:70 D Loss : 0.0807 G Loss : -0.0011


                                                         

	Epoch [80/100] Global Epoch:80 D Loss : 0.0728 G Loss : 0.0352


                                                         

	Epoch [90/100] Global Epoch:90 D Loss : 0.0868 G Loss : 0.0326


                                                         

Image size:8 | Current batch size : 256 | Current step:1

                                                        

	Epoch [10/100] Global Epoch:110 D Loss : 0.0833 G Loss : 0.1911


                                                        

	Epoch [20/100] Global Epoch:120 D Loss : 0.0899 G Loss : 0.0300


                                                        

	Epoch [30/100] Global Epoch:130 D Loss : 0.0158 G Loss : -0.2955


                                                        

	Epoch [40/100] Global Epoch:140 D Loss : 0.0334 G Loss : -0.0191


                                                        

	Epoch [50/100] Global Epoch:150 D Loss : -0.0590 G Loss : -0.0517


                                                        

	Epoch [60/100] Global Epoch:160 D Loss : 0.0716 G Loss : -0.0058


                                                        

	Epoch [70/100] Global Epoch:170 D Loss : 0.0409 G Loss : -0.0309


                                                        

	Epoch [80/100] Global Epoch:180 D Loss : 0.0238 G Loss : -0.0771


                                                        

	Epoch [90/100] Global Epoch:190 D Loss : 0.0231 G Loss : 0.0581


                                                        

Image size:16 | Current batch size : 128 | Current step:2

                                                          

	Epoch [10/100] Global Epoch:210 D Loss : -0.3068 G Loss : 1.5067


                                                          

	Epoch [20/100] Global Epoch:220 D Loss : -0.2104 G Loss : 0.2007


                                                          

	Epoch [30/100] Global Epoch:230 D Loss : -0.1912 G Loss : 0.4754


                                                          

	Epoch [40/100] Global Epoch:240 D Loss : -0.2531 G Loss : -0.0912


                                                          

	Epoch [50/100] Global Epoch:250 D Loss : -0.2930 G Loss : 0.4004


                                                          

	Epoch [60/100] Global Epoch:260 D Loss : -0.1231 G Loss : 0.4452


                                                          

	Epoch [70/100] Global Epoch:270 D Loss : -0.1049 G Loss : 0.3537


                                                          

	Epoch [80/100] Global Epoch:280 D Loss : -0.1768 G Loss : -0.2964


                                                          

	Epoch [90/100] Global Epoch:290 D Loss : -0.0682 G Loss : 0.2274


                                                          

Image size:32 | Current batch size : 64 | Current step:3

                                                          

	Epoch [10/100] Global Epoch:310 D Loss : 0.0921 G Loss : 1.3615


                                                          

	Epoch [20/100] Global Epoch:320 D Loss : -0.1076 G Loss : 0.2649


                                                          

KeyboardInterrupt: 