In [None]:
import torch
from torch import nn


class LayerNorm(nn.Module):
    """Per-sample normalization over every non-batch element, with optional per-channel affine.

    Unlike ``torch.nn.LayerNorm``, the statistics are taken over all elements of
    each sample at once (channels and spatial dims together). The standard
    deviation is ``Tensor.std``'s default unbiased estimate, and ``eps`` is
    added to the std (not the variance).

    Args:
        num_features: number of channels (dim 1) for the affine parameters.
        eps: small constant added to the std to avoid division by zero.
        affine: if True, learn per-channel scale ``gamma`` and shift ``beta``.
    """

    def __init__(self, num_features, eps=1e-12, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = nn.Parameter(torch.ones(num_features))
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # (N,) statistics reshaped to (N, 1, ..., 1) so they broadcast over x.
        stat_shape = [-1] + [1] * (x.dim() - 1)
        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, work too.
        flat = x.reshape(x.size(0), -1)
        mean = flat.mean(1).view(*stat_shape)
        std = flat.std(1).view(*stat_shape)

        y = (x - mean) / (std + self.eps)
        if self.affine:
            # gamma/beta are per-channel: shape them as (1, C, 1, ..., 1).
            param_shape = [1, -1] + [1] * (x.dim() - 2)
            y = self.gamma.view(*param_shape) * y + self.beta.view(*param_shape)

        return y

In [None]:
import torch.nn as nn
from torchvision import models


# VGG19 feature extractor, used for the perceptual loss with a pretrained VGG network
class VGG19(nn.Module):
    """Pretrained torchvision VGG19 ``features`` split into five chained slices.

    ``slice1`` ... ``slice5`` hold the layers of ``vgg19(pretrained=True).features``
    with indices [0, 2), [2, 7), [7, 12), [12, 21) and [21, 30). Each slice
    consumes the output of the previous one.

    Args:
        requires_grad: if False (default), every parameter is frozen.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        backbone = models.vgg19(pretrained=True).features
        # Half-open layer index ranges of `backbone` assigned to slice1 ... slice5.
        bounds = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))
        for number, (start, stop) in enumerate(bounds, start=1):
            stage = nn.Sequential()
            for idx in range(start, stop):
                # keep the backbone's layer index as the submodule name
                stage.add_module(str(idx), backbone[idx])
            setattr(self, 'slice{}'.format(number), stage)

        if not requires_grad:
            for weight in self.parameters():
                weight.requires_grad = False

    def forward(self, X):
        """Return ``[X, out1, out2, out3, out4, out5]`` where ``outk`` is slice k's output."""
        outputs = [X]
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            outputs.append(stage(outputs[-1]))
        return outputs

In [None]:
import torch
from torch import nn

# Normalization layer used by Conv2DLReLU, Conv2DTransposeLReLU and SwishMod below
# (each calls Norm(outc)); rebinding this alias changes it for all of them.
Norm = LayerNorm

class ResBlock(nn.Module):
    """Pre-activation residual block that halves the spatial resolution.

    Main path: (InstanceNorm -> LeakyReLU(0.2) -> 3x3 conv) twice, then 2x2
    average pooling. Shortcut: 2x2 average pooling followed by a 1x1 conv that
    maps ``in_dim`` to ``out_dim`` channels. The two paths are summed.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()

        def norm_act(channels):
            # affine instance norm followed by an in-place leaky ReLU
            return [nn.InstanceNorm2d(channels, affine=True),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True)]

        main_path = norm_act(in_dim)
        main_path.append(nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1))
        main_path += norm_act(in_dim)
        main_path.append(nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1))
        main_path.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
        self.conv = nn.Sequential(*main_path)

        self.short_cut = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        residual_branch = self.conv(x)
        return residual_branch + self.short_cut(x)


class Conv2DLReLU(nn.Module):
    """Conv2d -> ``Norm`` -> in-place LeakyReLU.

    Args:
        inc, outc: input / output channels.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        negative_slope: LeakyReLU slope.
    """

    def __init__(self, inc, outc, kernel_size=3, stride=1, padding=0, negative_slope=0.2):
        super().__init__()
        self.conv = nn.Conv2d(inc, outc, kernel_size, stride, padding)
        self.ln = Norm(outc)
        self.llr = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)

    def forward(self, x):
        h = self.conv(x)
        h = self.ln(h)
        return self.llr(h)


class Conv2DInstLReLU(nn.Module):
    """Conv2d -> (optional affine InstanceNorm2d) -> in-place LeakyReLU.

    ``self.inst`` is always created (so state_dict keys do not depend on
    ``is_inst``), but it is only applied when ``is_inst`` is True.
    """

    def __init__(self, inc, outc, kernel_size=3, stride=1, padding=0, negative_slope=0.2, is_inst=True):
        super().__init__()
        self.is_inst = is_inst
        self.conv = nn.Conv2d(inc, outc, kernel_size, stride, padding)
        self.inst = nn.InstanceNorm2d(outc, affine=True)
        self.llr = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)

    def forward(self, x):
        h = self.conv(x)
        if self.is_inst:
            h = self.inst(h)
        return self.llr(h)


class Conv2DTransposeLReLU(nn.Module):
    """2x upsampling -> ``Norm`` -> in-place LeakyReLU.

    Args:
        inc: input channels.
        outc: output channels.
        bilinear: if True, upsample with parameter-free bilinear interpolation.
            That cannot change the channel count, so ``inc`` must equal
            ``outc``. Otherwise a learned 2x2 stride-2 ``ConvTranspose2d``
            maps ``inc`` to ``outc``.
        negative_slope: LeakyReLU slope; 0.01 is ``nn.LeakyReLU``'s own default.

    Raises:
        ValueError: if ``bilinear`` is True and ``inc != outc``. Without this
            check ``Norm(outc)`` would see ``inc`` channels and either fail in
            forward or silently broadcast to the wrong channel count.
    """

    def __init__(self, inc, outc, bilinear=True, negative_slope=0.01):
        super().__init__()
        if bilinear:
            if inc != outc:
                raise ValueError(
                    "bilinear upsampling keeps the channel count; got inc={} != outc={} "
                    "(use bilinear=False for a learned projection)".format(inc, outc))
            self.deconv = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.deconv = nn.ConvTranspose2d(inc, outc, kernel_size=2, stride=2, padding=0)
        self.ln = Norm(outc)
        self.llr = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)

    def forward(self, x):
        return self.llr(self.ln(self.deconv(x)))


class SwishMod(nn.Module):
    """Self-gating: ``x * sigmoid(Norm(conv3x3(x)))``.

    The gate has ``outc`` channels and is multiplied elementwise with ``x``,
    so ``outc`` must match (or broadcast against) the channels of ``x``.
    """

    def __init__(self, inc, outc):
        super().__init__()
        self.conv = nn.Conv2d(inc, outc, 3, 1, 1)
        self.ln = Norm(outc)

    def forward(self, x):
        gate = self.ln(self.conv(x)).sigmoid()
        return x * gate


class SwishGatedBlock(nn.Module):
    """Swish-gated conv block used on both the down (encoder) and up (decoder) paths.

    Main branch: optional ``conv0`` (inc -> outc), then ``conv1`` and ``conv2``
    (both 3x3, padding 1). Side branch: the block input (after ``conv0`` if
    present) gated by ``SwishMod``.

    * Down mode (``forward(..., cat=None)``): both branches are max-pooled by 2
      and concatenated.
    * Up mode (``forward(..., cat=skip)``): both branches are upsampled through
      ``Conv2DTransposeLReLU`` and concatenated with ``skip``.

    ``forward`` returns ``(concatenation, conv2 output)``; the second item
    serves as a skip feature.

    Args:
        inc: channels of the block input.
        outc: channels produced by the conv layers.
        cat: build the up-sampling variant (``deconv1``/``deconv2``).
        conv1x1: prepend ``conv0``. Despite the name it is a 3x3 conv
            (Conv2DLReLU's default kernel_size with padding=1).
        dropout: accepted but not used.
    """

    def __init__(self, inc, outc, cat=False, conv1x1=True, dropout=False):
        super().__init__()
        self.conv1x1 = conv1x1

        # Channel count the rest of the block sees once conv0 (if any) has run.
        body_in = outc if conv1x1 else inc
        if conv1x1:
            self.conv0 = Conv2DLReLU(inc, outc, padding=1)
        self.conv1 = Conv2DLReLU(body_in, outc, padding=1)
        self.conv2 = Conv2DLReLU(outc, outc, padding=1)

        self.pooling = nn.MaxPool2d(2)
        if not cat:
            self.swish_mod = SwishMod(body_in, body_in)
        else:
            self.deconv1 = Conv2DTransposeLReLU(outc, outc)
            self.deconv2 = Conv2DTransposeLReLU(body_in, outc)
            self.swish_mod = SwishMod(outc, outc)

    def forward(self, inputs, cat=None):
        if self.conv1x1:
            inputs = self.conv0(inputs)
        features = self.conv2(self.conv1(inputs))

        if cat is not None:
            # up path: upsample both branches and append the skip tensor
            upsampled = self.deconv1(features)
            gated = self.swish_mod(self.deconv2(inputs))
            pieces = [upsampled, gated, cat]
        else:
            # down path: halve the resolution of both branches
            pooled = self.pooling(features)
            gated = self.swish_mod(self.pooling(inputs))
            pieces = [pooled, gated]

        return torch.cat(pieces, dim=1), features

In [None]:
import random
import torch
import numpy as np

from torch.autograd import Variable
import torch.backends.cudnn as cudnn


# Source from "https://github.com/ultralytics/yolov5/blob/master/utils/torch_utils.py"
def init_torch_seeds(seed: int = 0):
    r"""Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs and pick the cuDNN mode.

    ``seed == 0`` selects the reproducible configuration (deterministic cuDNN
    kernels, benchmark autotuning off). Any other seed selects the faster,
    non-deterministic configuration (benchmark on, deterministic off).

    Args:
        seed (int): The desired seed.
    """
    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    reproducible = seed == 0
    cudnn.deterministic = reproducible
    cudnn.benchmark = not reproducible

    print("Initialize random seed.")
    seeders = (np.random.seed, random.seed, torch.manual_seed,
               torch.cuda.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)


# custom weights initialization called on netG and netD
def weights_init(m):
    """DCGAN-style initialization, intended for ``nn.Module.apply``.

    * Modules whose class name contains "Conv2d": weight ~ N(0, 0.02); bias untouched.
    * Modules whose class name contains "BatchNorm": weight ~ N(1, 0.02), bias = 0.

    Every other module is left unchanged. A missing or ``None`` weight/bias
    (e.g. ``BatchNorm2d(affine=False)``) is skipped instead of raising.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv2d" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.zeros_(bias)


def var(tensor, requires_grad=True):
    """Convert ``tensor`` to a float32 autograd Variable.

    The tensor is cast to ``torch.cuda.FloatTensor`` when CUDA is available,
    otherwise to ``torch.FloatTensor``.

    Args:
        tensor: input tensor of any dtype.
        requires_grad: gradient flag of the returned Variable.
    """
    target_type = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
    return Variable(tensor.type(target_type), requires_grad=requires_grad)

def make_img(dloader, G, z, img_num=5, img_size=128):
    """Generate a sample grid from the first batch of ``dloader``.

    For every input image of the batch, the original image is written first,
    followed by ``img_num`` images produced by ``G(image, z[i, j])``. All
    values are then mapped by ``x / 2 + 0.5`` (i.e. [-1, 1] -> [0, 1]).

    Args:
        dloader: iterable yielding ``(images, targets)`` batches; only the
            first batch is used.
        G: generator, called as ``G(img_, z_)`` with a (1, C, H, W) image and a
            (1, z_dim) latent code; expected to return a (1, 3, img_size, img_size) image.
        z: latent codes of shape (N, img_num, z_dim) with N >= the batch size
            (N = test img number, z_dim usually 8).
        img_num: number of generated images per input image.
        img_size: spatial size (H == W) of inputs and generated images.

    Returns:
        Tensor of shape (N * (img_num + 1), 3, img_size, img_size).
    """
    if torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # next(iter(...)) works for any iterable; recent PyTorch DataLoader
    # iterators no longer provide a Python-2 style `.next()` method.
    img, _ = next(iter(dloader))

    N = img.size(0)
    img = img.type(dtype)
    result_img = torch.empty(N * (img_num + 1), 3, img_size, img_size).type(dtype)

    # Only pixel values are needed, so do not build autograd graphs.
    with torch.no_grad():
        for i in range(N):
            row = i * (img_num + 1)
            img_ = img[i].unsqueeze(dim=0)
            # original image to the leftmost
            result_img[row] = img[i]

            # generated images next to the original image
            for j in range(img_num):
                z_ = z[i, j, :].unsqueeze(dim=0)
                out_img = G(img_, z_)
                result_img[row + j + 1] = out_img.detach()

    # [-1, 1] -> [0, 1]
    return result_img / 2 + 0.5


def make_interpolation(n=200, img_num=9, z_dim=8):
    """Make linearly interpolated latent codes.

    For each of the ``n`` inputs, two codes ``first_z`` and ``last_z`` are drawn
    from N(0, I). ``img_num`` codes are then placed evenly on the segment
    between them, with both endpoints included exactly. With ``img_num == 1``
    the single code is ``first_z``.

    Args:
        n: number of input images.
        img_num: number of generated images (codes) per input image.
        z_dim: dimension of the latent code (8 by default).

    Returns:
        Float tensor of shape (n, img_num, z_dim), on CUDA when available.
    """
    if torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # Interpolation weights 0, 1/(img_num-1), ..., 1. linspace hits both ends
    # exactly and, unlike a step of 1/(img_num-1), does not divide by zero for img_num == 1.
    alpha = torch.linspace(0.0, 1.0, steps=img_num).view(img_num, 1)
    interpolated_z = torch.empty(n, img_num, z_dim).type(dtype)

    for i in range(n):
        first_z = torch.randn(1, z_dim)
        last_z = torch.randn(1, z_dim)
        # (img_num, 1) weights broadcast against (1, z_dim) codes -> (img_num, z_dim)
        interpolated_z[i] = (1 - alpha) * first_z + alpha * last_z

    return interpolated_z


def make_z(n, img_num, z_dim=8, sample_type='random'):
    """Make latent codes of shape (n, img_num, z_dim), wrapped with ``var``.

    Args:
        n: number of input images.
        img_num: number of generated images per input image.
        z_dim: dimension of the latent code (8 by default).
        sample_type: 'random' for i.i.d. N(0, 1) codes, or 'interpolation' for
            codes linearly interpolated between two random endpoints per image.

    Raises:
        ValueError: for an unknown ``sample_type``.
    """
    if sample_type == 'random':
        # honour z_dim (this used to be hard-coded to 8)
        z = var(torch.randn(n, img_num, z_dim))
    elif sample_type == 'interpolation':
        z = var(make_interpolation(n=n, img_num=img_num, z_dim=z_dim))
    else:
        raise ValueError("sample_type must be 'random' or 'interpolation', got {!r}".format(sample_type))

    return z

In [None]:
import torch
from torch import nn

class Latent_Discriminator(nn.Module):
    """MLP discriminator on latent codes: (N, z_dim) -> (N, 1) probability.

    Four Linear -> LayerNorm -> LeakyReLU stages of width ``n_filters`` are
    followed by a Linear layer to one unit and a Sigmoid.
    """

    def __init__(self, z_dim=8, n_filters=64, negative_slope=0.2):
        super().__init__()
        layers = []
        width_in = z_dim
        for _ in range(4):
            layers += [
                nn.Linear(width_in, n_filters, bias=True),
                nn.LayerNorm(n_filters),
                nn.LeakyReLU(negative_slope, inplace=True),
            ]
            width_in = n_filters
        layers += [nn.Linear(n_filters, 1, bias=True), nn.Sigmoid()]
        self.dz = nn.Sequential(*layers)

    def forward(self, x):
        return self.dz(x)


class Encoder(nn.Module):
    """Image encoder producing a diagonal-Gaussian latent code.

    Args:
        in_nc: input image channels.
        dim: base channel width; the pooled feature width is ``dim * 4``.
        z_dim: latent code size.

    ``forward(x)`` returns ``(mu, log_var, z)``, each of shape (N, z_dim). ``z``
    is sampled with the reparameterization trick.
    """

    def __init__(self, in_nc=3, dim=64, z_dim=8):
        super(Encoder, self).__init__()
        self.z_dim = z_dim
        # n, in_nc, H, W -> n, dim*4, 1, 1   (e.g. n, 3, 128, 128 -> n, 256, 1, 1)
        self.encode = nn.Sequential(
            nn.Conv2d(in_nc, dim // 2, kernel_size=7, stride=2, padding=3),
            ResBlock(dim // 2, dim),
            ResBlock(dim, dim * 2),
            ResBlock(dim * 2, dim * 4),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.AdaptiveAvgPool2d(output_size=1)
        )
        # n, dim*4 -> n, z_dim
        self.fc_mu = nn.Linear(dim * 4, z_dim)
        # n, dim*4 -> n, z_dim
        self.fc_logvar = nn.Linear(dim * 4, z_dim)

    def forward(self, x):
        # n, in_nc, H, W -> n, dim*4, 1, 1
        encode = self.encode(x)
        # n, dim*4, 1, 1 -> n, dim*4
        encode = torch.flatten(encode, start_dim=1)
        # n, dim*4 -> n, z_dim
        mu = self.fc_mu(encode)
        log_var = self.fc_logvar(encode)
        # reparameterization trick
        encoded_z = self.reparemeterize(mu, log_var, self.z_dim)
        return (mu, log_var, encoded_z)

    def reparemeterize(self, mu, log_variance, z_dim):
        """Sample ``z = mu + exp(log_variance / 2) * eps`` with ``eps ~ N(0, I)``.

        The noise has the same shape, dtype and device as ``mu``, so every
        sample in the batch gets its own draw. ``z_dim`` is kept for backward
        compatibility; the shape is taken from the inputs.
        """
        std = torch.exp(log_variance / 2)
        # Per-element noise. A single (1, z_dim) draw would be shared by the whole batch.
        eps = torch.randn_like(std)
        return eps * std + mu


class Generator(nn.Module):
    """U-Net style generator built from SwishGatedBlocks, conditioned on a latent code.

    ``z`` (N, z_dim) is tiled over the image grid and concatenated to ``x``.
    Six down blocks halve the resolution. Six up blocks restore it while
    merging SwishMod-gated skip features. A conv head maps to 3 channels with
    ``tanh``. H and W are assumed divisible by 64 (six 2x poolings).

    Args:
        z_dim: latent code size; must match the codes passed to ``forward``.
    """

    def __init__(self, z_dim=8):
        super().__init__()

        down_out_channels = [64, 128, 192, 256, 320, 384]
        # down0 sees the image (3 channels) plus the tiled z. Every down block emits
        # its pooled conv features (outc) concatenated with its gated, pooled input,
        # so down1 sees 64 + 3 + z_dim channels and later blocks 2 * outc.
        # (down1's entry was previously hard-coded to 75, valid only for z_dim == 8.)
        down_in_channels = [3 + z_dim, down_out_channels[0] + 3 + z_dim, 256, 384, 512, 640]
        # Each up block emits 3 * outc: upsampled features, upsampled gated input, skip.
        up_in_channels = [768, 1152, 960, 768, 576, 384]
        up_out_channels = [384, 320, 256, 192, 128, 64]

        self.down0 = SwishGatedBlock(down_in_channels[0], down_out_channels[0], conv1x1=False)
        self.down1 = SwishGatedBlock(down_in_channels[1], down_out_channels[1])
        self.down2 = SwishGatedBlock(down_in_channels[2], down_out_channels[2])
        self.down3 = SwishGatedBlock(down_in_channels[3], down_out_channels[3])
        self.down4 = SwishGatedBlock(down_in_channels[4], down_out_channels[4])
        self.down5 = SwishGatedBlock(down_in_channels[5], down_out_channels[5])

        self.swishmod0 = SwishMod(down_out_channels[0], down_out_channels[0])
        self.swishmod1 = SwishMod(down_out_channels[1], down_out_channels[1])
        self.swishmod2 = SwishMod(down_out_channels[2], down_out_channels[2])
        self.swishmod3 = SwishMod(down_out_channels[3], down_out_channels[3])
        self.swishmod4 = SwishMod(down_out_channels[4], down_out_channels[4])
        self.swishmod5 = SwishMod(down_out_channels[5], down_out_channels[5])

        self.up0 = SwishGatedBlock(up_in_channels[0], up_out_channels[0], cat=True)
        self.up1 = SwishGatedBlock(up_in_channels[1], up_out_channels[1], cat=True)
        self.up2 = SwishGatedBlock(up_in_channels[2], up_out_channels[2], cat=True)
        self.up3 = SwishGatedBlock(up_in_channels[3], up_out_channels[3], cat=True)
        self.up4 = SwishGatedBlock(up_in_channels[4], up_out_channels[4], cat=True)
        self.up5 = SwishGatedBlock(up_in_channels[5], up_out_channels[5], cat=True)

        self.out = nn.Sequential(
            Conv2DLReLU(down_out_channels[0] * 3, down_out_channels[0], kernel_size=1),
            Conv2DLReLU(down_out_channels[0], down_out_channels[0], kernel_size=3, padding=1),
            Conv2DLReLU(down_out_channels[0], down_out_channels[0], kernel_size=3, padding=1),
            nn.Conv2d(down_out_channels[0], 3, kernel_size=1),
            nn.Tanh()
        )

    def forward(self, x, z):
        # Shapes below are for a 128x128 input; C0 = 3 + z_dim.
        # z : (B, z_dim) -> (B, z_dim, 1, 1) -> (B, z_dim, H, W)
        z = z.unsqueeze(dim=2).unsqueeze(dim=3)
        z = z.expand(z.size(0), z.size(1), x.size(2), x.size(3))
        # x_with_z : (B, C0, H, W)
        x_with_z = torch.cat([x, z], dim=1)

        # Encoder: returns (next input, skip features)
        # [B, C0, 128, 128] -> [B, 64 + C0, 64, 64] + skip [B, 64, 128, 128]
        inputs, conv0 = self.down0(x_with_z)
        # -> [B, 256, 32, 32] + skip [B, 128, 64, 64]
        inputs, conv1 = self.down1(inputs)
        # -> [B, 384, 16, 16] + skip [B, 192, 32, 32]
        inputs, conv2 = self.down2(inputs)
        # -> [B, 512, 8, 8] + skip [B, 256, 16, 16]
        inputs, conv3 = self.down3(inputs)
        # -> [B, 640, 4, 4] + skip [B, 320, 8, 8]
        inputs, conv4 = self.down4(inputs)
        # -> [B, 768, 2, 2] + skip [B, 384, 4, 4]
        inputs, conv5 = self.down5(inputs)

        # Gate the skip features (shapes unchanged)
        conv0 = self.swishmod0(conv0)
        conv1 = self.swishmod1(conv1)
        conv2 = self.swishmod2(conv2)
        conv3 = self.swishmod3(conv3)
        conv4 = self.swishmod4(conv4)
        conv5 = self.swishmod5(conv5)

        # Decoder
        # [B, 768, 2, 2] -> [B, 1152, 4, 4]
        inputs, _ = self.up0(inputs, cat=conv5)
        # -> [B, 960, 8, 8]
        inputs, _ = self.up1(inputs, cat=conv4)
        # -> [B, 768, 16, 16]
        inputs, _ = self.up2(inputs, cat=conv3)
        # -> [B, 576, 32, 32]
        inputs, _ = self.up3(inputs, cat=conv2)
        # -> [B, 384, 64, 64]
        inputs, _ = self.up4(inputs, cat=conv1)
        # -> [B, 192, 128, 128]
        inputs, _ = self.up5(inputs, cat=conv0)
        # [B, 192, 128, 128] -> [B, 3, 128, 128]
        out = self.out(inputs)
        return out


class Discriminator(nn.Module):
    """Patch discriminator over channel-concatenated (input, output) image pairs.

    Expects 6 input channels (two RGB images). Four stride-2 4x4 conv blocks
    (the first without instance norm) are followed by a 4x4 conv to a single
    channel. For a 128x128 pair the output is (N, 1, 5, 5).
    """

    def __init__(self, n_filters=64):
        super().__init__()
        widths = [n_filters, n_filters * 2, n_filters * 4, n_filters * 8]

        blocks = [Conv2DInstLReLU(inc=6, outc=widths[0], kernel_size=4, stride=2, padding=1, is_inst=False)]
        for cin, cout in zip(widths[:-1], widths[1:]):
            blocks.append(Conv2DInstLReLU(inc=cin, outc=cout, kernel_size=4, stride=2, padding=1))
        blocks.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=0))

        self.out = nn.Sequential(*blocks)

    def forward(self, x):
        return self.out(x)


In [None]:
import torch
from torch import nn


# MSELoss for LSGAN
def MSELoss(score, target=1):
    """Least-squares GAN loss against smoothed constant labels.

    Args:
        score: discriminator output (any shape).
        target: 1 for "real" (label 0.95) or 0 for "fake" (label 0.05).

    Returns:
        Scalar mean squared error between ``score`` and the label.

    Raises:
        ValueError: if ``target`` is neither 1 nor 0.
    """
    if target == 1:
        fill = 0.95
    elif target == 0:
        fill = 0.05
    else:
        raise ValueError("target must be 1 (real) or 0 (fake), got {!r}".format(target))

    # Built on score's own device/dtype; full_like does not require grad.
    label = torch.full_like(score, fill)
    criterion = nn.MSELoss()
    return criterion(score, label)


# Binary cross-entropy loss for the latent-code discriminator
def BCELoss(score, target=1):
    """Binary cross-entropy against smoothed constant labels.

    The previous body was a copy of MSELoss and computed a squared error,
    despite this function's name and purpose.

    Args:
        score: probabilities in [0, 1] (e.g. Latent_Discriminator's sigmoid output).
        target: 1 for "real" (label 0.95) or 0 for "fake" (label 0.05).

    Returns:
        Scalar mean binary cross-entropy.

    Raises:
        ValueError: if ``target`` is neither 1 nor 0.
    """
    if target == 1:
        fill = 0.95
    elif target == 0:
        fill = 0.05
    else:
        raise ValueError("target must be 1 (real) or 0 (fake), got {!r}".format(target))

    label = torch.full_like(score, fill)
    return nn.functional.binary_cross_entropy(score, label)


# Perceptual loss that uses a pretrained VGG network
class VGGLoss(nn.Module):
    """Weighted sum of L1 distances between VGG19 features of ``x`` and ``y``.

    ``y`` is treated as a constant target: no gradient flows into it.
    """

    def __init__(self):
        super(VGGLoss, self).__init__()
        # Choose the device at runtime so the loss also works on CPU-only machines.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.vgg = VGG19().to(device)
        self.criterion = nn.L1Loss()
        # One weight per VGG19.forward output: [input, slice1, ..., slice5].
        self.weights = [0.88, 0.79, 0.63, 0.51, 0.39, 1.07]

    def forward(self, x, y):
        x_vgg = self.vgg(x)
        # Target features are constants, so skip building their autograd graph.
        with torch.no_grad():
            y_vgg = self.vgg(y)
        loss = 0
        for weight, feat_x, feat_y in zip(self.weights, x_vgg, y_vgg):
            # detach still needed: y_vgg[0] is `y` itself, which may require grad.
            loss += weight * self.criterion(feat_x, feat_y.detach())
        return loss

# KL Divergence loss used in VAE with an image encoder
class KLDLoss(nn.Module):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed (not averaged) over all elements."""

    def forward(self, mu, logvar):
        per_element = 1 + logvar - mu.pow(2) - logvar.exp()
        return -0.5 * per_element.sum()


# wgan_loss
def WGANLoss(pred, real_or_not=True):
    """Wasserstein critic loss: ``-mean(pred)`` for real samples, ``mean(pred)`` otherwise."""
    mean_score = torch.mean(pred)
    return -mean_score if real_or_not else mean_score


def Calculate_gradient_penalty(model, real_images, fake_images, device, constant=1.0, lamb=10.0):
    """Calculates the gradient penalty loss for WGAN GP.

    ``model`` is evaluated at random points on the segment between each
    real/fake pair. The per-sample gradient L2 norm is pushed towards
    ``constant``: ``lamb * mean((||grad||_2 - constant) ** 2)``.

    Args:
        model: critic, called on a tensor shaped like ``real_images``.
        real_images, fake_images: tensors of identical shape (N, C, H, W).
        device: device on which the interpolation weights are created.
        constant: target gradient norm.
        lamb: penalty weight.

    Returns:
        Scalar penalty. It is differentiable because ``create_graph=True``.
    """
    # One interpolation weight per sample, uniform in [0, 1]. A normal draw
    # (torch.randn) would put many points outside the real-fake segment.
    alpha = torch.rand((real_images.size(0), 1, 1, 1), device=device, dtype=real_images.dtype)
    # Get random interpolation between real and fake data
    interpolates = (alpha * real_images + ((1 - alpha) * fake_images)).requires_grad_(True)

    model_interpolates = model(interpolates)
    grad_outputs = torch.ones_like(model_interpolates)

    # Get gradient w.r.t. interpolates
    gradients = torch.autograd.grad(
        outputs=model_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lamb
    return gradient_penalty

In [None]:
import torch
from torch.utils.data import Dataset
import torchvision.transforms as Transforms

import os
from PIL import Image

class BlackWhite2Color(Dataset):
    """(grayscale-as-RGB input, RGB ground truth) pairs built from single images.

    Files are listed from ``os.path.join(root, mode)``. Each item converts the
    image to luminance ('L'), replicates it into 3 channels, and returns it
    together with the RGB original. Both are passed through ``transform``.
    """

    def __init__(self, root, transform, mode='train'):
        self.root = root
        self.transform = transform
        self.mode = mode

        data_dir = os.path.join(root, mode)
        # os.listdir order is arbitrary; sort so an index always maps to the same file.
        self.file_list = sorted(os.listdir(data_dir))

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, self.mode, self.file_list[idx])
        # `with` closes the file; convert() returns independent in-memory copies.
        with Image.open(img_path) as img:
            data_l = img.convert('L')  # 1 channel
            # Convert to RGB so palette/RGBA/grayscale files also yield 3 channels.
            ground_truth = img.convert('RGB')

        data = Image.merge("RGB", [data_l, data_l, data_l])  # 3 channels

        data = self.transform(data)
        ground_truth = self.transform(ground_truth)

        return (data, ground_truth)


class Sketch2Color(Dataset):
    """Side-by-side image pairs: the right half is the input, the left half the target.

    Files are listed from ``os.path.join(root, mode)``. Each image is
    converted to RGB, split at ``W // 2``, and both halves are passed through
    ``transform``.
    """

    def __init__(self, root, transform, mode='train'):
        self.root = root
        self.transform = transform
        self.mode = mode

        data_dir = os.path.join(root, mode)
        # os.listdir order is arbitrary; sort so an index always maps to the same file.
        self.file_list = sorted(os.listdir(data_dir))

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, self.mode, self.file_list[idx])
        # `with` closes the file; convert() returns an independent in-memory copy.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        W, H = img.size[0], img.size[1]

        data = img.crop((W // 2, 0, W, H))
        ground_truth = img.crop((0, 0, W // 2, H))

        data = self.transform(data)
        ground_truth = self.transform(ground_truth)

        return (data, ground_truth)


def data_loader(root, batch_size=1, shuffle=True, img_size=128, mode='train', dstname='sketch'):
    """Build a DataLoader over ``root/mode`` and return ``(loader, dataset_length)``.

    Images are resized to ``img_size`` x ``img_size`` and normalized from
    [0, 1] to [-1, 1]. ``dstname == 'sketch'`` selects Sketch2Color; any other
    value selects BlackWhite2Color. ``batch_size='all'`` uses the whole dataset
    as one batch. ``drop_last=True`` discards an incomplete final batch, so a
    batch size larger than the dataset yields no batches.
    """
    transform = Transforms.Compose([
        Transforms.Resize((img_size, img_size)),
        Transforms.ToTensor(),
        Transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    dataset_cls = Sketch2Color if dstname == 'sketch' else BlackWhite2Color
    dset = dataset_cls(root, transform, mode=mode)

    if batch_size == 'all':
        batch_size = len(dset)

    loader = torch.utils.data.DataLoader(dset,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=0,
                                         drop_last=True)
    return loader, len(dset)


In [None]:
import torch
from torch import nn
from tqdm import tqdm
import torch.optim as optim
import torchvision
from tensorboardX import SummaryWriter
import os
import time
import datetime

# Seed Python / NumPy / PyTorch RNGs for reproducibility.
# NOTE(review): init_torch_seeds only puts cuDNN in deterministic mode when seed == 0;
# with seed=1 it sets cudnn.benchmark=True and cudnn.deterministic=False, so GPU runs
# are not bit-for-bit reproducible.
init_torch_seeds(seed=1)

class Solver():
    """Trainer for the image discriminator (Di), generator (G) and encoder (E).

    Each training iteration uses the first sample of the batch and performs:

    1. a Di update with least-squares (MSE) losses on one real (input, ground
       truth) pair and two fake pairs. ``x_tilde`` is generated from the code
       E encodes from the ground truth; ``x_hat`` is generated from a random code.
    2. a joint G/E update combining an adversarial term, a VGG reconstruction of
       ``x_tilde`` vs the ground truth, an L1 latent reconstruction (random z vs
       E's mean for ``x_hat``) and a KL term.

    Notable args:
        root, dstname: dataset location / type, forwarded to ``data_loader``.
        result_dir, weight_dir: where sample grids and checkpoints are written.
        load_weight: if True, ``train`` first resumes from ``checkpoint_path``.
        save_every: iteration interval for console logging and sample grids.
        z_dim: latent size shared by G, E, the random codes and ``fixed_z``.
        logdir: passed to tensorboardX ``SummaryWriter``.
        checkpoint_path: checkpoint file loaded when ``load_weight`` is True.
    """

    def __init__(self, root='dataset/anime_faces', dstname='sketch', result_dir='result', weight_dir='weight', load_weight=False,
                 batch_size=1, test_size=10, test_img_num=5, img_size=128, num_epoch=100, save_every=1000,
                 g_lr=0.0002, d_lr=0.0001, beta_1=0.5, beta_2=0.999, lambda_kl=0.01, lambda_img=10, lambda_z=0.5,
                 z_dim=8, logdir=None, checkpoint_path='../input/weightnolatent/checkpoint_49_epoch.pkl'):

        # Data type (GPU if available)
        self.dtype = torch.cuda.FloatTensor
        if torch.cuda.is_available() is False:
            self.dtype = torch.FloatTensor

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Data loader for training
        self.dloader, dlen = data_loader(root=root, batch_size=batch_size, shuffle=True,
                                         img_size=img_size, mode='train', dstname=dstname)
        print('training dataset length:', dlen)
        # Data loader for test
        self.t_dloader, _ = data_loader(root=root, batch_size=test_size, shuffle=False,
                                        img_size=img_size, mode='test', dstname=dstname)

        # Models
        # Di: discriminator on (input, output) image pairs
        # G: generator conditioned on the input image and a latent code z
        # E: encoder mapping a ground-truth image to a latent code
        self.Di = Discriminator().type(self.dtype)
        self.Di.apply(weights_init)
        # G and E must use the same z_dim as the codes sampled in train() and fixed_z.
        self.G = Generator(z_dim=z_dim).type(self.dtype)
        self.G.apply(weights_init)
        self.E = Encoder(z_dim=z_dim).type(self.dtype)
        self.E.apply(weights_init)

        # Optimizers
        self.optim_Di = optim.Adam(self.Di.parameters(), lr=d_lr, betas=(beta_1, beta_2))
        self.optim_G = optim.Adam(self.G.parameters(), lr=g_lr, betas=(beta_1, beta_2))
        self.optim_E = optim.Adam(self.E.parameters(), lr=g_lr, betas=(beta_1, beta_2))

        # Fixed random z for test samples; only used for sampling, so no gradient.
        self.fixed_z = var(torch.randn(test_size, test_img_num, z_dim), requires_grad=False)

        # Losses
        self.bce_loss = BCELoss
        self.recon_x_loss = VGGLoss()
        self.recon_z_loss = nn.L1Loss()
        self.mse_loss = MSELoss
        self.kl_loss = KLDLoss()

        # Some hyperparameters
        self.z_dim = z_dim
        self.lambda_img = lambda_img
        self.lambda_kl = lambda_kl
        self.lambda_z = lambda_z

        self.writer = SummaryWriter(logdir)

        # Extra things
        self.result_dir = result_dir
        self.weight_dir = weight_dir
        self.load_weight = load_weight
        self.checkpoint_path = checkpoint_path
        self.test_img_num = test_img_num
        self.img_size = img_size
        self.start_epoch = 0
        self.num_epoch = num_epoch
        self.save_every = save_every

    def show_model(self):
        """Print the model architectures."""
        print('================================ Discriminator for image =====================================')
        print(self.Di)
        print('==========================================================================================\n\n')
        print('================================= Generator ==================================================')
        print(self.G)
        print('==========================================================================================\n\n')
        print('================================= Encoder ==================================================')
        print(self.E)
        print('==========================================================================================\n\n')

    def set_train_phase(self):
        """Put Di, G and E in training mode."""
        self.Di.train()
        self.G.train()
        self.E.train()

    def load_checkpoint(self, checkpoint):
        """Restore models and optimizers from ``checkpoint`` and resume after its epoch."""
        print('Load model')
        self.Di.load_state_dict(checkpoint['discriminator_image_state_dict'])
        self.G.load_state_dict(checkpoint['generator_state_dict'])
        self.E.load_state_dict(checkpoint['encoder_state_dict'])
        self.optim_Di.load_state_dict(checkpoint['optim_di'])
        self.optim_G.load_state_dict(checkpoint['optim_g'])
        self.optim_E.load_state_dict(checkpoint['optim_e'])
        # Checkpoints are written after `epoch` has finished, so continue with the next one.
        self.start_epoch = checkpoint['epoch'] + 1

    def save_checkpoint(self, state, file_name):
        """Save ``state`` to ``file_name`` with ``torch.save``."""
        print('saving check_point')
        torch.save(state, file_name)

    def all_zero_grad(self):
        """Set all optimizers' gradients to zero."""
        self.optim_Di.zero_grad()
        self.optim_G.zero_grad()
        self.optim_E.zero_grad()

    def train(self):
        """Train Di, G and E from ``self.start_epoch`` up to ``self.num_epoch`` (exclusive)."""
        if self.load_weight is True:
            # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
            checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
            self.load_checkpoint(checkpoint)

        self.set_train_phase()
        self.show_model()

        iters_per_epoch = len(self.dloader)

        print('====================     Training    Start... =====================')
        for epoch in range(self.start_epoch, self.num_epoch):
            start_time = time.time()

            # Record the epoch being trained; `with` closes the file immediately.
            with open('log.txt', 'w') as log_file:
                log_file.write(str(epoch))

            for iters, (img, ground_truth) in tqdm(enumerate(self.dloader), total=iters_per_epoch):
                # TensorBoard step that increases across iterations and epochs.
                global_step = epoch * iters_per_epoch + iters

                # img : input batch / ground_truth : target batch. These are data,
                # not parameters, so no gradient is needed w.r.t. them.
                img = var(img, requires_grad=False)
                ground_truth = var(ground_truth, requires_grad=False)

                # Only the first sample of each batch is used.
                data = {'img': img[0].unsqueeze(dim=0), 'ground_truth': ground_truth[0].unsqueeze(dim=0)}

                # ----------------------------- 1. Train D -----------------------------
                # The fakes enter the D update as constants, so build them without graphs.
                with torch.no_grad():
                    # encoded latent vector -> fake image
                    _, _, z_hat = self.E(data['ground_truth'])
                    x_tilde = self.G(data['img'], z_hat)
                    # random latent vector -> fake image
                    z = var(torch.randn(1, self.z_dim), requires_grad=False)
                    x_hat = self.G(data['img'], z)

                real_pair = torch.cat([data['img'], data['ground_truth']], dim=1)
                fake_pair_tilde = torch.cat([data['img'], x_tilde], dim=1)
                fake_pair_hat = torch.cat([data['img'], x_hat], dim=1)

                real_d = self.Di(real_pair)
                fake_d_tilde = self.Di(fake_pair_tilde)
                fake_d_hat = self.Di(fake_pair_hat)

                # The real term is weighted 2x so real and fake halves contribute equally.
                d_loss = (self.mse_loss(real_d, target=1) * 2 + self.mse_loss(fake_d_tilde, target=0) +
                          self.mse_loss(fake_d_hat, target=0)) / 4

                self.writer.add_scalars('d_losses', {'images_loss': d_loss.item()}, global_step)

                # Update D
                self.all_zero_grad()
                d_loss.backward()
                self.optim_Di.step()

                # ----------------------------- 2. Train G & E -----------------------------
                # encoded latent vector -> fake image
                mu, log_variance, z_hat = self.E(data['ground_truth'])
                x_tilde = self.G(data['img'], z_hat)

                # random latent vector -> fake image -> E's mean, which should recover z
                z = var(torch.randn(1, self.z_dim), requires_grad=False)
                x_hat = self.G(data['img'], z)
                z_tilde, _, _ = self.E(x_hat)

                fake_pair_tilde = torch.cat([data['img'], x_tilde], dim=1)
                fake_pair_hat = torch.cat([data['img'], x_hat], dim=1)

                fake_d_tilde = self.Di(fake_pair_tilde)
                fake_d_hat = self.Di(fake_pair_hat)

                g_loss = (self.mse_loss(fake_d_tilde, target=1) + self.mse_loss(fake_d_hat, target=1)) / 2

                loss_x_recon = self.recon_x_loss(x_tilde, data['ground_truth']) * self.lambda_img
                loss_z_recon = self.recon_z_loss(z, z_tilde) * self.lambda_z
                loss_kl = self.lambda_kl * self.kl_loss(mu, log_variance)

                eg_loss = g_loss + loss_x_recon + loss_z_recon + loss_kl

                self.all_zero_grad()
                eg_loss.backward()
                self.optim_E.step()
                self.optim_G.step()

                self.writer.add_scalars('eg_losses', {'images_loss': g_loss.item(),
                                                      'x_recon_loss': loss_x_recon.item(),
                                                      'z_recon_loss': loss_z_recon.item(),
                                                      'kl_loss': loss_kl.item()}, global_step)

                # Print losses and save an intermediate sample grid
                if iters % self.save_every == 0:
                    # whole seconds, formatted H:MM:SS
                    et = str(datetime.timedelta(seconds=int(time.time() - start_time)))
                    print('[Elapsed : %s / Epoch : %d / Iters : %d] => D_loss : %f / G_loss : %f / KL_div : %f / img_recon_loss : %f / z_recon_loss : %f'\
                          %(et, epoch, iters, d_loss.item(), g_loss.item(), loss_kl.item(), loss_x_recon.item(), loss_z_recon.item()))

                    os.makedirs(self.result_dir, exist_ok=True)

                    self.G.eval()
                    with torch.no_grad():
                        result_img = make_img(self.t_dloader, self.G, self.fixed_z,
                                              img_num=self.test_img_num, img_size=self.img_size)
                    # Back to training mode for the remaining iterations.
                    self.G.train()

                    img_name = '{epoch}_{iters}.png'.format(epoch=epoch, iters=iters)
                    img_path = os.path.join(self.result_dir, img_name)

                    torchvision.utils.save_image(result_img, img_path, nrow=self.test_img_num + 1)

            # Save a checkpoint every 5 epochs
            if (epoch + 1) % 5 == 0:
                checkpoint = {
                    "generator_state_dict": self.G.state_dict(),
                    "discriminator_image_state_dict": self.Di.state_dict(),
                    "encoder_state_dict": self.E.state_dict(),
                    "optim_g": self.optim_G.state_dict(),
                    "optim_di": self.optim_Di.state_dict(),
                    "optim_e": self.optim_E.state_dict(),
                    "epoch": epoch
                    }
                os.makedirs(self.weight_dir, exist_ok=True)
                path_checkpoint = os.path.join(self.weight_dir, "checkpoint_{}_epoch.pkl".format(epoch))
                self.save_checkpoint(checkpoint, path_checkpoint)

In [None]:
import torch
import argparse
import os

def main(args):
    """Train the model using the hyper-parameters carried by ``args``.

    Each ``Solver`` keyword argument is taken from the attribute of the same
    name on ``args`` (as built by the argument parser of this cell), after
    which ``Solver.train()`` is run.

    NOTE(review): a later cell defines another ``main`` (inference), which
    shadows this one once that cell has been executed.
    """
    # Keyword names shared by the CLI namespace and the Solver constructor.
    solver_option_names = (
        'root', 'dstname', 'result_dir', 'weight_dir', 'load_weight',
        'batch_size', 'test_size', 'test_img_num', 'img_size',
        'num_epoch', 'save_every', 'g_lr', 'd_lr', 'beta_1', 'beta_2',
        'lambda_kl', 'lambda_img', 'lambda_z', 'z_dim',
    )
    solver_kwargs = {name: getattr(args, name) for name in solver_option_names}

    trainer = Solver(**solver_kwargs)
    trainer.train()

def str2bool(value):
    """Convert a command-line string into a bool for use as an argparse ``type``.

    ``type=bool`` must not be used for such flags: ``bool('False')`` is
    ``True`` because every non-empty string is truthy.

    Args:
        value: raw command-line string (an actual bool is passed through).

    Returns:
        True for 'true'/'t'/'yes'/'y'/'1' and False for
        'false'/'f'/'no'/'n'/'0' (case-insensitive, whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: for any other value, so argparse reports
            a clean usage error instead of silently guessing.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


if __name__ == '__main__':
    # Training options for the Solver built in main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default='../input/animeface', 
                        help='Data location')
    parser.add_argument('--result_dir', type=str, default='./', 
                        help='Result images location')
    parser.add_argument('--dstname', type=str, default='anime', 
                        help='Choosed dataset name(sketch2color/black2color)')
    parser.add_argument('--weight_dir', type=str, default='./', 
                        help='Weight location')
    parser.add_argument('--batch_size', type=int, default=1, 
                        help='Training batch size')
    parser.add_argument('--test_size', type=int, default=8, 
                        help='Test batch size')
    parser.add_argument('--test_img_num', type=int, default=8, 
                        help='How many images do you want to generate?')
    parser.add_argument('--img_size', type=int, default=128, 
                        help='Image size')
    parser.add_argument('--g_lr', type=float, default=0.0001,
                        help='Learning rate')
    parser.add_argument('--d_lr', type=float, default=0.0002,
                        help='Discriminator Learning rate')
    parser.add_argument('--beta_1', type=float, default=0.5, 
                        help='Beta1 for Adam')
    parser.add_argument('--beta_2', type=float, default=0.999, 
                        help='Beta2 for Adam')
    parser.add_argument('--lambda_kl', type=float, default=1e-2, 
                        help='Lambda for KL Divergence')
    parser.add_argument('--lambda_img', type=float, default=10, 
                        help='Lambda for image reconstruction')
    parser.add_argument('--lambda_z', type=float, default=0.5, 
                        help='Lambda for z reconstruction')
    parser.add_argument('--z_dim', type=int, default=8, 
                        help='Dimension of z')
    parser.add_argument('--num_epoch', type=int, default=100, 
                        help='Number of epoch')
    parser.add_argument('--save_every', type=int, default=1000, 
                        help='How often do you want to see the result?')
    # str2bool instead of bool: with type=bool, '--load_weight False' would be True.
    parser.add_argument('--load_weight', type=str2bool, default=True,
                        help='Load weight or not')

    # parse_args([]) ignores sys.argv (a notebook kernel passes its own flags),
    # so the defaults above are the values actually used here.
    args = parser.parse_args([])
    main(args)

In [None]:
import torch
import torchvision

import os
import numpy as np
import argparse


def make_img_split(dloader, G, z, img_num=5, result_dir=None):
    """Save ground-truth and generated images as individual PNG files.

    Only the first batch produced by ``dloader`` is used. For each sample ``i``
    of that batch, the second tensor of the pair is written to
    ``<result_dir>/ground_truth/<i + 1>.png``. Then ``img_num`` generator
    outputs, one per latent code ``z[i, j, :]``, are written to
    ``<result_dir>/generated/<i * img_num + j + 1>.png``.

    Args:
        dloader: iterable whose first item is an ``(img, img_real)`` pair of
            batched tensors (presumably generator input and target image;
            confirm against ``data_loader``).
        G: generator, called as ``G(img, z)`` on single-sample batches.
        z: latent codes indexed as ``z[i, j, :]``; it must cover at least the
            first batch's size along dim 0 and ``img_num`` along dim 1.
        img_num: number of generated images per input sample.
        result_dir: output root directory. When None, falls back to the
            notebook-level ``args.result_dir`` (the original behavior).
    """
    if result_dir is None:
        # Backward-compatible fallback to the global CLI namespace.
        result_dir = args.result_dir

    dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    # ``iterator.next()`` is a Python 2 idiom that current PyTorch DataLoader
    # iterators no longer provide; the builtin ``next`` works everywhere.
    img, img_real = next(iter(dloader))

    N = img.size(0)
    img = var(img.type(dtype))
    img_real = var(img_real.type(dtype))

    # Create the output directories once rather than on every iteration.
    gt_dir = os.path.join(result_dir, "ground_truth")
    gen_dir = os.path.join(result_dir, "generated")
    os.makedirs(gt_dir, exist_ok=True)
    os.makedirs(gen_dir, exist_ok=True)

    # Inference only: no autograd graph is needed for the generator passes.
    with torch.no_grad():
        for i in range(N):
            # `/ 2 + 0.5` maps [-1, 1] to [0, 1] (assumes tanh-range tensors).
            real_img = img_real[i].data / 2 + 0.5
            gt_name = '{idx}.png'.format(idx=str(i + 1))
            torchvision.utils.save_image(real_img, os.path.join(gt_dir, gt_name))

            # Insert generated images next to the original image.
            img_ = img[i].unsqueeze(dim=0)
            for j in range(img_num):
                z_ = z[i, j, :].unsqueeze(dim=0)
                out_img = G(img_, z_).data / 2 + 0.5
                gen_name = '{idx}.png'.format(idx=str(i * img_num + j + 1))
                torchvision.utils.save_image(out_img, os.path.join(gen_dir, gen_name))


def _latest_checkpoint_name(weight_dir):
    """Return the ``checkpoint_<N>_epoch.pkl`` file name in ``weight_dir`` with the largest N."""
    import re

    pattern = re.compile(r'^checkpoint_(\d+)_epoch\.pkl$')
    candidates = []
    for name in os.listdir(weight_dir):
        match = pattern.match(name)
        if match:
            candidates.append((int(match.group(1)), name))
    if not candidates:
        raise FileNotFoundError(
            'No checkpoint_<epoch>_epoch.pkl file found in %r' % (weight_dir,))
    return max(candidates)[1]


def main(args):
    """Generate sample images from a saved generator checkpoint.

    Builds the test data loader, loads the generator weights from
    ``args.weight_dir``, samples latent codes with ``make_z`` and writes the
    ground-truth / generated PNGs via ``make_img_split``.

    Args:
        args: namespace with ``root``, ``weight_dir``, ``result_dir``,
            ``img_num``, ``sample_type`` and ``epochs``. When ``epochs`` is
            None, the highest-epoch ``checkpoint_<N>_epoch.pkl`` in
            ``weight_dir`` is loaded, as the ``--epochs`` help text promises.

    Raises:
        FileNotFoundError: if ``epochs`` is None and ``weight_dir`` holds no
            matching checkpoint, or if the requested checkpoint is missing.
    """
    dloader, dlen = data_loader(root=args.root, batch_size=100, shuffle=False,
                                img_size=128, mode='test', dstname='anime')

    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

    if args.epochs is not None:
        weight_name = 'checkpoint_{epoch}_epoch.pkl'.format(epoch=args.epochs)
    else:
        # Previously hard-coded to epoch 1. The training loop only saves
        # when (epoch + 1) % 5 == 0, so that file is never produced.
        weight_name = _latest_checkpoint_name(args.weight_dir)

    # map_location lets a checkpoint saved with GPU tensors load on a CPU-only
    # machine. NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(os.path.join(args.weight_dir, weight_name),
                            map_location=None if use_cuda else 'cpu')
    G = Generator(z_dim=8).type(dtype)
    G.load_state_dict(checkpoint['generator_state_dict'])
    G.eval()

    os.makedirs(args.result_dir, exist_ok=True)

    # Make latent code and images
    z = make_z(n=dlen, img_num=args.img_num, z_dim=8, sample_type=args.sample_type)

    make_img_split(dloader, G, z, img_num=args.img_num)

    

if __name__ == '__main__':
    # Inference options, declared as (flag, type, default, help) rows and
    # registered in order so the --help listing is unchanged.
    parser = argparse.ArgumentParser()
    option_specs = [
        ('--sample_type', str, 'random',
         "Type of sampling : 'random' or 'interpolation'"),
        ('--root', str, '../input/anime-test',
         'Data location'),
        ('--result_dir', str, './',
         'Ouput images location'),
        ('--weight_dir', str, '../input/weightlatent',
         'Trained weight location of generator. pkl file location'),
        ('--img_num', int, 8,
         'Generated images number per one input image'),
        ('--epochs', int, 64,
         'Epoch that you want to see the result. If it is None, the most recent epoch'),
    ]
    for flag, converter, default_value, help_text in option_specs:
        parser.add_argument(flag, type=converter, default=default_value, help=help_text)

    # parse_args([]) ignores sys.argv, so the defaults above are what is used
    # inside the notebook; `args` stays a global read by make_img_split.
    args = parser.parse_args([])
    main(args)

In [None]:
# Bundle the ./generated directory (where make_img_split writes generated
# images under the default result_dir './') into file.zip in the current
# working directory.
!zip -r file.zip ./generated

In [None]:
# Switch the kernel's working directory to /kaggle/working (presumably the
# Kaggle output directory) so relative paths in later cells resolve there.
%cd /kaggle/working

In [None]:
from IPython.display import FileLink

# FileLink renders a download link for a single file and raises ValueError for
# a directory (FileLinks is the directory variant), so link the archive
# produced by the zip cell rather than ./generated itself.
FileLink(r'file.zip')