In [1]:
%matplotlib inline

In [12]:
import torch
import numpy as np
from torch import nn
from torch.nn.init import kaiming_normal,calculate_gain,xavier_normal
import torchvision

In [3]:
class PixelNormalization(nn.Module):
    """Divide each position by the root-mean-square of its values along dim 1.

    Computes ``input * rsqrt(mean(input ** 2, dim=1, keepdim=True) + epsilon)``.

    Args:
        epsilon: Constant added to the mean square before the reciprocal
            square root, so an all-zero vector does not divide by zero.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, input):
        # keepdim=True keeps dim 1 so the factor broadcasts against `input`.
        mean_square = input.pow(2).mean(dim=1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.epsilon)
        return input * inverse_rms

    def __repr__(self):
        return "{}(epsilon = {})".format(self.__class__.__name__, self.epsilon)

In [4]:
class MiniBatchNormalization(nn.Module):
    """Append a minibatch standard-deviation statistic as one extra channel.

    The population standard deviation is taken over the batch dimension
    (dim 0) for every (channel, row, column) position, averaged into a single
    scalar, and that scalar is broadcast into a constant ``(N, 1, H, W)`` map
    that is concatenated to the input along dim 1.

    Input:  ``(N, C, H, W)``
    Output: ``(N, C + 1, H, W)``
    """

    def __init__(self):
        super(MiniBatchNormalization, self).__init__()

    def forward(self, input):
        N, C, H, W = input.size()
        std = torch.std(input, unbiased=False, dim=0)
        val = torch.mean(std)
        # Broadcast the scalar instead of building a fresh tensor with
        # torch.full: the expanded view keeps the input's device and dtype and
        # stays attached to the autograd graph, so the statistic can
        # back-propagate into `input`.
        stat_map = val.reshape(1, 1, 1, 1).expand(N, 1, H, W)
        return torch.cat([input, stat_map], dim=1)

In [5]:
class EquilizedConv2d(nn.Module):
    """2-D convolution with a fixed runtime scaling constant.

    The input is multiplied by ``sqrt(gain / fan_in)`` before the convolution,
    where ``fan_in = c_in * kernel_height * kernel_width``. Scaling the input
    is equivalent to scaling the weights; the bias (if any) is left unscaled.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
        k_size: Kernel size, an int or an ``(h, w)`` tuple.
        stride: Convolution stride (default 1).
        pad: Zero padding (default 0, i.e. no padding).
        bias: Whether the wrapped convolution has a bias term.
        gain: Numerator of the scaling constant (2 gives the He constant).
        **kwargs: Forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=False, gain=2, **kwargs):
        super(EquilizedConv2d, self).__init__()
        if not isinstance(k_size, tuple):
            k_size = (k_size, k_size)
        self.conv = nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=bias, **kwargs)
        # Weights start as plain N(0, 1); the per-layer constant is applied at
        # run time in forward(). Using a He-scaled initialiser here as well
        # would apply the same factor twice.
        nn.init.normal_(self.conv.weight)
        # With bias=False the convolution has no bias tensor to initialise.
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)

        # Plain Python float: no device or dtype to keep in sync with the input.
        # NOTE(review): fan_in ignores a `groups` value passed through kwargs.
        fan_in = c_in * k_size[0] * k_size[1]
        self.scale = (gain / fan_in) ** 0.5

    def forward(self, input):
        return self.conv(input.mul(self.scale))

In [6]:
class EquilizedDeconv2d(nn.Module):
    """2-D transposed convolution with a fixed runtime scaling constant.

    The input is multiplied by ``sqrt(gain / fan_in)`` before the transposed
    convolution, where ``fan_in = c_in * kernel_height * kernel_width``. The
    bias (if any) is left unscaled.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
        k_size: Kernel size, an int or an ``(h, w)`` tuple.
        stride: Stride of the transposed convolution (default 1).
        pad: Padding argument of ``nn.ConvTranspose2d`` (default 0).
        bias: Whether the wrapped layer has a bias term.
        gain: Numerator of the scaling constant (2 gives the He constant).
        **kwargs: Forwarded unchanged to ``nn.ConvTranspose2d``.
    """

    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=False, gain=2, **kwargs):
        super(EquilizedDeconv2d, self).__init__()
        if not isinstance(k_size, tuple):
            k_size = (k_size, k_size)
        self.deconv = nn.ConvTranspose2d(
            c_in, c_out, k_size, stride, pad, bias=bias, **kwargs
        )
        # Weights start as plain N(0, 1); the per-layer constant is applied at
        # run time in forward(). Using a He-scaled initialiser here as well
        # would apply the same factor twice.
        nn.init.normal_(self.deconv.weight)
        # With bias=False the layer has no bias tensor to initialise.
        if self.deconv.bias is not None:
            nn.init.zeros_(self.deconv.bias)

        # Plain Python float: no device or dtype to keep in sync with the input.
        fan_in = c_in * k_size[0] * k_size[1]
        self.scale = (gain / fan_in) ** 0.5

    def forward(self, input):
        return self.deconv(input.mul(self.scale))

In [7]:
class G_Block(nn.Module):
    """One generator stage: an entry layer followed by a shared tail.

    Entry layer (``self.init``):
        * ``initial_block=True``: a single 4x4 ``EquilizedDeconv2d``
          (stride 1, pad 1).
        * otherwise: nearest-neighbour 2x upsampling followed by a 3x3
          ``EquilizedConv2d`` (stride 1, pad 1).

    Tail (``self.main``): LeakyReLU -> PixelNormalization -> 3x3
    ``EquilizedConv2d`` (``c_out`` -> ``c_out``) -> LeakyReLU ->
    PixelNormalization.

    Args:
        c_in: Channels entering the block.
        c_out: Channels leaving the block.
        initial_block: Selects the transposed-convolution entry layer.
        relu_slope: Negative slope for both LeakyReLU layers.
    """

    def __init__(self, c_in, c_out, initial_block=False, relu_slope=0.2):
        super().__init__()
        if initial_block:
            # NOTE(review): kernel 4 / stride 1 / pad 1 maps a 1x1 input to
            # 2x2 (and 3x3 to 4x4) -- confirm the latent's spatial size.
            entry_layers = [
                EquilizedDeconv2d(
                    c_in, c_out, k_size=(4, 4), stride=(1, 1), pad=(1, 1)
                ),
            ]
        else:
            entry_layers = [
                nn.Upsample(scale_factor=2.0, mode="nearest"),
                EquilizedConv2d(c_in, c_out, k_size=(3, 3), stride=(1, 1), pad=(1, 1)),
            ]
        self.init = nn.Sequential(*entry_layers)

        tail_layers = [
            nn.LeakyReLU(relu_slope),
            PixelNormalization(),
            EquilizedConv2d(c_out, c_out, k_size=(3, 3), stride=(1, 1), pad=(1, 1)),
            nn.LeakyReLU(relu_slope),
            PixelNormalization(),
        ]
        self.main = nn.Sequential(*tail_layers)

    def forward(self, x):
        return self.main(self.init(x))

In [8]:
class D_Block(nn.Module):
    """One discriminator stage.

    * ``last_block=True``: MiniBatchNormalization (adds one channel) ->
      3x3 conv (``c_in + 1`` -> ``c_out``, pad 1) -> LeakyReLU ->
      unpadded 4x4 conv -> LeakyReLU -> unpadded 1x1 conv down to one channel.
    * otherwise: two 3x3 convs (pad 1), each followed by LeakyReLU, then 2x2
      average pooling with stride 2.

    Args:
        c_in: Channels entering the block.
        c_out: Channels produced by the block's convolutions.
        last_block: Selects the final scoring stage instead of a
            downsampling stage.
        relu_slope: Negative slope for the LeakyReLU layers.
    """

    def __init__(self, c_in, c_out, last_block=False, relu_slope=0.2):
        super(D_Block, self).__init__()
        # The flag is the `last_block` parameter; the bare name `last` is
        # undefined and raised NameError.
        if last_block:
            layers = [
                MiniBatchNormalization(),
                EquilizedConv2d(
                    c_in + 1, c_out, k_size=(3, 3), stride=(1, 1), pad=(1, 1)
                ),
                nn.LeakyReLU(relu_slope),
                # Padding is passed explicitly: the 4x4 and 1x1 convolutions
                # are unpadded.
                EquilizedConv2d(c_out, c_out, k_size=(4, 4), stride=(1, 1), pad=(0, 0)),
                nn.LeakyReLU(relu_slope),
                EquilizedConv2d(c_out, 1, k_size=(1, 1), stride=(1, 1), pad=(0, 0)),
            ]
        else:
            layers = [
                EquilizedConv2d(c_in, c_out, k_size=(3, 3), stride=(1, 1), pad=(1, 1)),
                nn.LeakyReLU(relu_slope),
                EquilizedConv2d(c_out, c_out, k_size=(3, 3), stride=(1, 1), pad=(1, 1)),
                nn.LeakyReLU(relu_slope),
                nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2)),
            ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        x = self.main(x)
        return x

In [11]:
class ToRGB(nn.Module):
    """Project a ``c_in``-channel feature map to 3 channels with a 1x1 conv.

    Args:
        c_in: Number of channels in the incoming feature map.
    """

    def __init__(self, c_in):
        super(ToRGB, self).__init__()
        # Padding is passed explicitly; a 1x1 convolution needs none.
        self.conv = EquilizedConv2d(c_in, 3, k_size=(1, 1), stride=(1, 1), pad=(0, 0))

    def forward(self, x):
        x = self.conv(x)
        return x


class FromRGB(nn.Module):
    """Lift a 3-channel input to ``c_out`` feature channels with a 1x1 conv.

    Args:
        c_out: Number of feature channels to produce.
    """

    def __init__(self, c_out):
        super(FromRGB, self).__init__()
        # Padding is passed explicitly; a 1x1 convolution needs none.
        self.conv = EquilizedConv2d(3, c_out, k_size=(1, 1), stride=(1, 1), pad=(0, 0))

    def forward(self, x):
        x = self.conv(x)
        return x

In [16]:
class Generator(nn.Module):
    """Progressively grown generator.

    All blocks up to ``out_size`` are built up front; only the first
    ``self.depth`` of them are used by ``forward``. ``grow_net`` enables the
    next block and starts a fade-in during which the new block's RGB output is
    blended with the upsampled RGB output of the previous depth.

    Args:
        latent_in: Channel count of the latent input and of the first block.
        out_size: Final output resolution; ``int(log2(out_size)) - 1`` blocks
            are built in total.

    NOTE(review): every block after the first is hard-wired to start from 512
    channels, so the blocks only chain when ``latent_in == 512`` -- confirm.
    """

    def __init__(self, latent_in, out_size):
        super(Generator, self).__init__()
        self.depth = 1  # number of blocks currently in use
        self.alpha = 1  # blend weight of the newest block while growing
        self.fade = 0  # amount added to alpha on every blended forward pass
        self.upsample = nn.Upsample(scale_factor=2.0, mode="nearest")
        self.current_net = nn.ModuleList(
            [G_Block(latent_in, latent_in, initial_block=True)]
        )
        self.toRGBs = nn.ModuleList([ToRGB(latent_in)])
        # Pre-build the blocks that will be enabled as the network grows.
        for d in range(2, int(np.log2(out_size))):
            # (4x4), 8x8, 16x16 and 32x32 are generated with 512 channels.
            # https://research.nvidia.com/publication/2018-04_progressive-growing-gans-improved-quality-stability-and-all
            if d < 5:
                c_in, c_out = 512, 512
            else:
                c_in, c_out = int(512 / 2 ** (d - 5)), int(512 / 2 ** (d - 4))
            self.current_net.append(G_Block(c_in, c_out))
            self.toRGBs.append(ToRGB(c_out))

    def forward(self, x):
        """Generate an RGB image at the current depth.

        While a fade-in is active (``alpha < 1``) each call also advances
        ``alpha`` by ``fade``.
        """
        # Index of the newest active block; depth is 1-based.
        newest = self.depth - 1
        # Run every block below the newest one.
        for block in self.current_net[:newest]:
            x = block(x)
        out = self.current_net[newest](x)
        x_rgb = self.toRGBs[newest](out)
        # While growing, blend with the previous depth's upsampled RGB output.
        # The depth guard keeps index -1 from silently picking the last toRGB.
        if self.alpha < 1 and newest > 0:
            x_old = self.upsample(x)
            old_rgb = self.toRGBs[newest - 1](x_old)
            x_rgb = (1 - self.alpha) * old_rgb + self.alpha * x_rgb

            self.alpha += self.fade
        return x_rgb

    def grow_net(self, num_iters):
        """Enable the next block and fade it in over ``num_iters`` forward passes.

        The caller must not grow beyond the number of pre-built blocks;
        ``forward`` would then index past the end of ``current_net``.
        """
        self.fade = 1 / num_iters
        self.alpha = 1 / num_iters

        self.depth += 1

In [None]:
class Discriminator(nn.Module):
    """Progressively grown discriminator.

    ``current_net[0]`` is the final scoring block; each appended block handles
    the next higher resolution. ``forward`` enters at block ``depth - 1`` via
    the matching ``fromRGBs`` layer and then walks down to block 0.
    ``grow_net`` enables the next block and starts a fade-in during which the
    new block's output is blended with the previous depth's ``FromRGB`` of the
    downsampled image.

    Args:
        latent_in: Channel width of the final block and its ``FromRGB`` layer.
        out_size: Highest input resolution; ``int(log2(out_size)) - 1`` blocks
            are built in total.

    NOTE(review): blocks above the first are hard-wired to end at 512
    channels, so they only chain into block 0 when ``latent_in == 512`` --
    confirm.
    """

    def __init__(self, latent_in, out_size):
        # nn.Module must be initialised before submodules can be assigned.
        super(Discriminator, self).__init__()
        self.depth = 1  # number of blocks currently in use
        self.alpha = 1  # blend weight of the newest block while growing
        self.fade = 0  # amount added to alpha on every blended forward pass
        self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.current_net = nn.ModuleList(
            [D_Block(latent_in, latent_in, last_block=True)]
        )
        self.fromRGBs = nn.ModuleList([FromRGB(latent_in)])
        for d in range(2, int(np.log2(out_size))):
            # 32x32, 16x16, 8x8 and (4x4) run at 512 channels.
            # https://research.nvidia.com/publication/2018-04_progressive-growing-gans-improved-quality-stability-and-all
            if d < 5:
                c_in, c_out = 512, 512
            else:
                # Higher-resolution blocks are narrower: each block takes the
                # smaller width in and hands the larger width to the block
                # below it, so the widths chain back up to 512.
                c_in, c_out = int(512 / 2 ** (d - 4)), int(512 / 2 ** (d - 5))
            self.current_net.append(D_Block(c_in, c_out))
            self.fromRGBs.append(FromRGB(c_in))

    def forward(self, x_rgb):
        """Score an RGB image at the current depth.

        While a fade-in is active (``alpha < 1``) each call also advances
        ``alpha`` by ``fade``.
        """
        # Index of the newest active block; depth is 1-based.
        newest = self.depth - 1
        # Run the newest block on the full-resolution input.
        x = self.fromRGBs[newest](x_rgb)
        x = self.current_net[newest](x)
        # While growing, blend with the previous depth's view of the
        # downsampled image. The depth guard avoids index -1.
        if self.alpha < 1 and newest > 0:
            x_rgb = self.downsample(x_rgb)
            x_old = self.fromRGBs[newest - 1](x_rgb)
            x = (1 - self.alpha) * x_old + self.alpha * x

            self.alpha += self.fade
        # Remaining blocks, from the highest resolution down to block 0.
        for block in reversed(self.current_net[:newest]):
            x = block(x)
        return x

    def grow_net(self, num_iters):
        """Enable the next block and fade it in over ``num_iters`` forward passes.

        The caller must not grow beyond the number of pre-built blocks;
        ``forward`` would then index past the end of ``current_net``.
        """
        self.fade = 1 / num_iters
        self.alpha = 1 / num_iters

        self.depth += 1