In [None]:
import os
import math
import torch
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import utils as vutils
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder

In [None]:
# --- Runtime resources -------------------------------------------------------
# os.cpu_count() is documented to return None when the CPU count cannot be
# determined; DataLoader requires an int, so fall back to 0 (load in-process).
num_workers = os.cpu_count() or 0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Paths -------------------------------------------------------------------
data_dir = "/home/pervinco/Datasets/CelebA"   # root passed to ImageFolder
save_dir = "./runs/StyleGAN"                  # where sample grids are written

# --- Optimisation ------------------------------------------------------------
lr = 0.0001
# One batch size per progressive step (index 0 -> 4x4, index 1 -> 8x8, ...).
batch_sizes = [256, 256, 128, 64, 32, 16]

# --- Model widths ------------------------------------------------------------
nc = 3      # image channels produced by G / consumed by D
nz = 512    # size of the latent vector fed to the mapping network
nw = 512    # size of the mapping network output
ndf = 512   # base channel width of the discriminator
ngf = 512   # base channel width of the generator

# --- Progressive-growing schedule --------------------------------------------
min_img_size = 4
max_img_size = 128
gp_coeff = 10   # multiplier applied to the gradient penalty in the critic loss
# Number of resolution stages from 4x4 up to max_img_size (inclusive).
num_steps = int(math.log2(max_img_size / 4)) + 1

# Epochs to train at each progressive step (same length as batch_sizes).
progressive_epochs = [30] * len(batch_sizes)

# exist_ok avoids the check-then-create race and is a no-op if the folder exists.
os.makedirs(save_dir, exist_ok=True)

In [None]:
def get_dataset(img_size):
    """Create the image dataset and a shuffled DataLoader for one resolution.

    Args:
        img_size: Target side length in pixels. Images are resized, centre
            cropped to ``img_size`` and normalised with mean/std 0.5 per channel.

    Returns:
        Tuple ``(dataset, dataloader)``. The loader's batch size is looked up in
        the module-level ``batch_sizes`` at index ``log2(img_size / 4)``.
    """
    preprocessing = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = datasets.ImageFolder(root=data_dir, transform=preprocessing)

    # 4 -> 0, 8 -> 1, 16 -> 2, ...
    step_index = int(math.log2(img_size / 4))
    dataloader = DataLoader(
        dataset,
        batch_size=batch_sizes[step_index],
        shuffle=True,
        num_workers=num_workers,
    )

    return dataset, dataloader

In [None]:
class PixelNorm(nn.Module):
    """Divide the input by its root-mean-square taken over dimension 1."""

    def __init__(self):
        super().__init__()
        # Small constant added under the square root to avoid division by zero.
        self.eps = 1e-8

    def forward(self, x):
        """Return ``x`` rescaled so its mean square along dim 1 is ~1."""
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x / (mean_square + self.eps).sqrt()
    

class EqualConv2d(nn.Module):
    """2-D convolution with a fixed input scale and a separately held bias.

    The wrapped ``nn.Conv2d`` keeps only its weight (drawn from N(0, 1)); the
    input is multiplied by ``sqrt(gain / (in_channels * kernel_size**2))``
    before the convolution and the bias is added afterwards, unscaled.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

        fan_in = in_channels * kernel_size ** 2
        self.scale = (gain / fan_in) ** 0.5

        # Move the bias out of the conv so it is not affected by the input scale.
        self.bias = self.conv.bias
        self.conv.bias = None

        # Weights ~ N(0, 1), bias = 0.
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Apply the scaled convolution and add the per-channel bias."""
        scaled_input = x * self.scale
        per_channel_bias = self.bias.view(1, -1, 1, 1)
        return self.conv(scaled_input) + per_channel_bias


class EqualLinear(nn.Module):
    """Linear layer with a fixed input scale and a separately held bias.

    The wrapped ``nn.Linear`` keeps only its weight (drawn from N(0, 1)); the
    input is multiplied by ``sqrt(2 / in_features)`` before the matrix product
    and the bias is added afterwards, unscaled.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = (2 / in_features) ** 0.5

        # Move the bias out of the linear layer so the input scale does not touch it.
        self.bias = self.linear.bias
        self.linear.bias = None

        # Weights ~ N(0, 1), bias = 0.
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Apply the scaled linear map and add the bias."""
        scaled_input = x * self.scale
        return self.linear(scaled_input) + self.bias

In [None]:
## Mapping Network
class MappingNetwork(nn.Module):
    """Maps a latent vector of size ``nz`` to a style vector of size ``nw``.

    The stack is PixelNorm followed by eight EqualLinear layers with a ReLU
    between each consecutive pair (no activation after the last layer).
    """

    def __init__(self, nz, nw):
        super().__init__()
        # First layer changes width nz -> nw; the remaining seven keep nw.
        layers = [PixelNorm(), EqualLinear(nz, nw)]
        for _ in range(7):
            layers.append(nn.ReLU())
            layers.append(EqualLinear(nw, nw))
        self.mapping = nn.Sequential(*layers)

    def forward(self, x):
        """Return the style vector produced from latent ``x``."""
        return self.mapping(x)
    

## Adaptive Instance Normalization
class AdaIN(nn.Module):
    """Instance-normalise a feature map, then scale and shift it per channel.

    The per-channel scale and shift are each produced from the style vector
    ``w`` by their own EqualLinear layer.
    """

    def __init__(self, channels, nw):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = EqualLinear(nw, channels)
        self.style_bias = EqualLinear(nw, channels)

    def forward(self, x, w):
        """Return ``scale(w) * instance_norm(x) + bias(w)``."""
        normalized = self.instance_norm(x)
        # Append two singleton axes so the styles broadcast over the spatial dims.
        scale = self.style_scale(w)[:, :, None, None]
        shift = self.style_bias(w)[:, :, None, None]

        return scale * normalized + shift
    

class NoiseInjection(nn.Module):
    """Add single-channel Gaussian noise scaled by a learned per-channel weight.

    ``weight`` has shape ``(1, channels, 1, 1)`` and starts at zero, so a freshly
    constructed layer is the identity; the network learns how much noise each
    channel receives.
    """

    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        """Return ``x + weight * noise`` with fresh noise on every call.

        Args:
            x: 4-D tensor; dims 0, 2 and 3 define the shape of the noise map.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # One noise map per sample, shared across channels via broadcasting.
        noise = torch.randn(
            (x.shape[0], 1, x.shape[2], x.shape[3]),
            device=x.device,
            dtype=x.dtype,
        )
        # The weight must multiply the noise: adding it (as before) injected
        # unit-variance noise unconditionally and turned the weight into a bias.
        return x + self.weight * noise

In [None]:
class StyleGeneratorBlock(nn.Module):
    """Two conv -> noise -> LeakyReLU -> AdaIN stages applied in sequence.

    The first stage changes the channel count from ``in_channels`` to
    ``out_channels``; the second keeps ``out_channels``. Both AdaIN layers are
    driven by the same style vector ``w``.
    """

    def __init__(self, in_channels, out_channels, nw):
        super().__init__()
        self.conv1 = EqualConv2d(in_channels, out_channels)
        self.conv2 = EqualConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        self.adain1 = AdaIN(out_channels, nw)
        self.adain2 = AdaIN(out_channels, nw)

        self.noise_inject1 = NoiseInjection(out_channels)
        self.noise_inject2 = NoiseInjection(out_channels)

    def forward(self, x, w):
        """Run both stages on feature map ``x`` using style vector ``w``."""
        stages = (
            (self.conv1, self.noise_inject1, self.adain1),
            (self.conv2, self.noise_inject2, self.adain2),
        )
        for conv, inject_noise, adain in stages:
            x = conv(x)
            x = inject_noise(x)
            x = self.leaky(x)
            x = adain(x, w)

        return x
    

class ConvBlock(nn.Module):
    """Two EqualConv2d layers, each followed by LeakyReLU(0.2).

    The first conv maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = EqualConv2d(in_channels, out_channels)
        self.conv2 = EqualConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply conv1 -> leaky -> conv2 -> leaky."""
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))

        return x

In [None]:
# Channel-width multipliers (relative to ngf / ndf). Consecutive entries give the
# input/output widths of each progressive block, so 8 entries -> 7 blocks.
factors = [1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32]

class Generator(nn.Module):
    """Style-based generator grown progressively from a learned 4x4 constant.

    Args:
        nz: Size of the latent vector fed to the mapping network.
        nw: Size of the style vector produced by the mapping network.
        ngf: Channel width of the 4x4 stage; later stages use ``ngf * factors[i]``.
        nc: Number of channels in the generated image.
    """

    def __init__(self, nz, nw, ngf, nc):
        super().__init__()
        # Learned constant that replaces a latent-dependent first feature map.
        self.const_input = nn.Parameter(torch.ones(1, ngf, 4, 4))
        self.map = MappingNetwork(nz, nw)

        self.initial_adain1 = AdaIN(ngf, nw)
        self.initial_adain2 = AdaIN(ngf, nw)

        self.initial_noise1 = NoiseInjection(ngf)
        self.initial_noise2 = NoiseInjection(ngf)
        # NOTE(review): plain nn.Conv2d, unlike the EqualConv2d used elsewhere;
        # left as-is so existing checkpoints keep their parameter names.
        self.initial_conv = nn.Conv2d(ngf, ngf, kernel_size=3, stride=1, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        # toRGBs[k] converts the features of stage k to an image; index 0 is the
        # 4x4 stage and shares its module with self.initial_rgb.
        self.initial_rgb = EqualConv2d(ngf, nc, kernel_size = 1, stride=1, padding=0)
        self.prog_blocks, self.toRGBs = (nn.ModuleList([]), nn.ModuleList([self.initial_rgb]))

        for i in range(len(factors)-1):
            conv_in_c  = int(ngf * factors[i])
            conv_out_c = int(ngf * factors[i+1])
            self.prog_blocks.append(StyleGeneratorBlock(conv_in_c, conv_out_c, nw))
            self.toRGBs.append(EqualConv2d(conv_out_c, nc, kernel_size = 1, stride=1, padding=0))


    def fade_in(self, alpha, upscaled, generated):
        """Blend the new stage's image with the upscaled previous one, then tanh."""
        return torch.tanh(alpha * generated + (1-alpha ) * upscaled)


    def forward(self, noise, alpha, steps):
        """Generate images at resolution ``4 * 2**steps``.

        Args:
            noise: Latent batch passed to the mapping network.
            alpha: Fade-in weight of the newest stage (1 means fully faded in).
            steps: Number of progressive blocks to apply; 0 yields 4x4 images.

        Returns:
            Image batch with ``nc`` channels. For ``steps > 0`` the result is
            passed through tanh by ``fade_in``; the ``steps == 0`` path returns
            the raw toRGB output.
        """
        w = self.map(noise)
        # const_input has batch size 1; AdaIN's per-sample styles broadcast it
        # up to the batch size of `w`.
        x = self.initial_adain1(self.initial_noise1(self.const_input),w)
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            # Convert the fully processed 4x4 features. Using the raw conv output
            # `x` here (as before) skipped initial_noise2 / initial_adain2, leaving
            # them untrained while later steps feed toRGBs[0] from `out`.
            return self.initial_rgb(out)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode = 'bilinear')
            out = self.prog_blocks[step](upscaled,w)

        # Previous stage's toRGB on the upscaled features vs. the new stage's toRGB.
        final_upscaled = self.toRGBs[steps-1](upscaled)
        final_out = self.toRGBs[steps](out)

        return self.fade_in(alpha, final_upscaled, final_out)

In [None]:
class Discriminator(nn.Module):
    """Progressive critic that mirrors the generator's resolution stages.

    Args:
        nc: Number of channels in the input images.
        ndf: Channel width at the 4x4 stage; earlier (higher-resolution) stages
            use ``ndf * factors[i]``.
    """

    def __init__(self, nc, ndf):
        super().__init__()
        self.prog_blocks, self.fromRGBs = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # Blocks are stored from the highest resolution (index 0) down to 8x8,
        # i.e. in the reverse order of the generator's blocks.
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(ndf * factors[i])
            conv_out = int(ndf * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.fromRGBs.append(EqualConv2d(nc, conv_in, kernel_size=1, stride=1, padding=0))


        # fromRGB for the 4x4 stage goes last, so fromRGBs has one more entry
        # than prog_blocks.
        self.initial_rgb = EqualConv2d(nc, ndf, kernel_size=1, stride=1, padding=0)
        self.fromRGBs.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # ndf + 1 input channels: the extra one is the minibatch-std statistic.
        self.final_block = nn.Sequential(EqualConv2d(ndf + 1, ndf, kernel_size=3, padding=1),
                                         nn.LeakyReLU(0.2),
                                         EqualConv2d(ndf, ndf, kernel_size=4, padding=0, stride=1),
                                         nn.LeakyReLU(0.2),
                                         EqualConv2d(ndf, 1, kernel_size=1, padding=0, stride=1))

    def fade_in(self, alpha, downscaled, out):
        """Linearly blend the new block's output with the downscaled shortcut."""
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        """Append one channel holding the mean across-batch standard deviation.

        For a batch of a single sample the (unbiased) standard deviation is
        undefined and ``torch.std`` returns NaN, which would poison the loss;
        in that case the statistic is set to zero instead.
        """
        if x.shape[0] > 1:
            statistic = torch.std(x, dim=0).mean()
        else:
            statistic = x.new_zeros(())
        batch_statistics = statistic.repeat(x.shape[0], 1, x.shape[2], x.shape[3])

        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        """Score a batch of images of resolution ``4 * 2**steps``.

        Args:
            x: Image batch.
            alpha: Fade-in weight of the newest (highest-resolution) block.
            steps: Progressive step the images were generated at; 0 means 4x4.

        Returns:
            Tensor of shape ``(batch, -1)`` holding the critic scores.
        """
        # Index of the block matching the current resolution (blocks are stored
        # highest resolution first).
        cur_step = len(self.prog_blocks) - steps
        out = self.leaky(self.fromRGBs[cur_step](x))

        if steps == 0:
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # Shortcut path: downsample the image and use the previous stage's fromRGB.
        downscaled = self.leaky(self.fromRGBs[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)

        return self.final_block(out).view(out.shape[0], -1)

In [None]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    """Compute the gradient penalty on random real/fake interpolations.

    Args:
        critic: Callable invoked as ``critic(images, alpha, train_step)``.
        real: Batch of real images, shape ``(bs, c, h, w)``.
        fake: Batch of generated images; detached before mixing.
        alpha: Fade-in weight forwarded to the critic.
        train_step: Progressive step forwarded to the critic.
        device: Device the random mixing coefficients are moved to.

    Returns:
        Scalar tensor: mean over the batch of ``(||grad||_2 - 1) ** 2``, where
        the gradient is that of the critic score w.r.t. the mixed images.
    """
    batch_size, channels, height, width = real.shape

    # One mixing coefficient per sample, replicated over every pixel.
    mix = torch.rand((batch_size, 1, 1, 1)).repeat(1, channels, height, width).to(device)
    mixed_images = real * mix + fake.detach() * (1 - mix)
    mixed_images.requires_grad_(True)

    scores = critic(mixed_images, alpha, train_step)

    # create_graph=True keeps the penalty differentiable w.r.t. critic weights.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=mixed_images,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )

    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return torch.mean((per_sample_norm - 1) ** 2)

In [None]:
def save_fake_images(epoch, G, fixed_noise, alpha, step, save_dir, num_images=64):
    """Generate images from ``fixed_noise`` and save them as one PNG grid.

    Args:
        epoch: Number used in the figure title and the output file name.
        G: Generator, called as ``G(fixed_noise, alpha, step)``.
        fixed_noise: Latent batch to generate from.
        alpha: Fade-in weight forwarded to the generator.
        step: Progressive step forwarded to the generator.
        save_dir: Directory the PNG is written into.
        num_images: Maximum number of generated images placed in the grid.
    """
    # No gradients are needed for visualisation.
    with torch.no_grad():
        generated = G(fixed_noise, alpha, step).detach().cpu()
        # Undo the (0.5, 0.5) normalisation: [-1, 1] -> [0, 1].
        generated = (generated * 0.5) + 0.5

    # Tile the first `num_images` samples into a single image.
    grid = vutils.make_grid(generated[:num_images], padding=2, normalize=False)

    fig = plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title(f"Fake Images at Epoch {epoch}")
    # make_grid is channel-first; imshow expects channel-last.
    plt.imshow(np.transpose(grid, (1, 2, 0)))

    # Write the figure to disk and release it.
    plt.savefig(f"{save_dir}/Epoch_{epoch}_Fake.png")
    plt.close(fig)

In [None]:
G = Generator(nz, nw, ngf, nc).to(device)
D = Discriminator(nc, ndf).to(device)

# The mapping network gets its own, smaller learning rate (1e-5); every other
# generator parameter uses `lr`.
g_optimizer = torch.optim.Adam([{'params': [param for name, param in G.named_parameters() if 'map' not in name]},
                                {'params': G.map.parameters(), 'lr': 1e-5}], lr=lr, betas =(0.5, 0.99))
d_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.99))

global_epochs = 0
step = int(math.log2(min_img_size / 4))
# Fixed latents so the saved sample grids are comparable across epochs.
fixed_noise = torch.randn(64, nz).to(device)

# Loop over the progressive-growing stages.
for n_epochs in progressive_epochs[step:]:
    # Start each stage with the new block almost fully faded out.
    alpha = 0.00001
    dataset, dataloader = get_dataset(img_size=4*2**step)
    print(f"Image size:{4*2**step} | Current step:{step}")
    for epoch in range(n_epochs):
        for images, _ in tqdm(dataloader, desc="Train", leave=False):
            bs = images.size(0)
            real_images = images.to(device)
            z = torch.randn(bs, nz).to(device)

            # --- Discriminator (critic) update ---
            # The critic step never backpropagates into G, so skip building
            # the generator's autograd graph here.
            with torch.no_grad():
                fake_images = G(z, alpha, step)
            d_real = D(real_images, alpha, step)
            d_real_loss = d_real.mean()
            d_fake_loss = D(fake_images, alpha, step).mean()
            gp = gradient_penalty(D, real_images, fake_images, alpha, step, device=device)
            # Drift term: mean of the squared real scores (not the square of
            # their mean), which keeps individual scores from drifting away
            # from zero.
            drift = 0.001 * (d_real ** 2).mean()
            d_loss = -(d_real_loss - d_fake_loss) + gp_coeff * gp + drift

            D.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # --- Generator update ---
            fake_images = G(z, alpha, step)
            g_loss = -D(fake_images, alpha, step).mean()

            G.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            # Ramp alpha so it reaches 1 after half of this stage's epochs.
            alpha += (bs / len(dataset)) * (1 / progressive_epochs[step]) * 2
            alpha = min(alpha, 1)

        # Save a grid of generated images at the end of every epoch.
        # save_dir is a required argument of save_fake_images.
        save_fake_images(global_epochs, G, fixed_noise, alpha, step, save_dir)
        print(f"Epoch [{epoch+1}/{n_epochs}] Global Epoch:{global_epochs} D Loss : {d_loss.item():.4f} G Loss : {g_loss.item():.4f}")
        global_epochs += 1  # running epoch count across all stages

    step += 1  # move on to the next progressive stage

print("Training finished")
