In [17]:
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.layers import Layer

from glob import glob

import random
from functools import partial

import numpy as np
import matplotlib.pyplot as plt

In [2]:
PATH = 'c:/Users/admin/Documents/Datasets/celeba_hq/'
BUFFER_SIZE = 200
IMAGE_RESOLUTION = 256

In [3]:
def load(res, image_file):
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image)
    image = tf.image.resize(image, [res, res],
                            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    image = tf.cast(image, tf.float32)
    image = (image/127.5)-1
    return image

In [6]:
train_datasets = {}
train_images = glob(PATH + 'train/**/*.jpg')
random.shuffle(train_images)
train_dataset_list = tf.data.Dataset.from_tensor_slices(train_images)

In [8]:
n_workers = tf.data.experimental.AUTOTUNE

In [9]:
BATCH_SIZE = {2: 16, 3: 16, 4: 16, 5: 16, 6: 16, 7: 8, 8: 4, 9: 4, 10: 4}
TRAIN_STEP_RATIO = {k: BATCH_SIZE[2]/v for k, v in BATCH_SIZE.items()}

In [11]:
TRAIN_STEP_RATIO

{2: 1.0, 3: 1.0, 4: 1.0, 5: 1.0, 6: 1.0, 7: 2.0, 8: 4.0, 9: 4.0, 10: 4.0}

In [13]:
for log2_res in range(2, int(np.log2(IMAGE_RESOLUTION)) + 1):
    res = 2**log2_res
    temp = train_dataset_list.map(partial(load, res), num_parallel_calls=n_workers)
    
    temp = temp.shuffle(BUFFER_SIZE).batch(BATCH_SIZE[log2_res], drop_remainder=True).repeat()
    train_datasets[log2_res] = temp

In [14]:
train_datasets

{2: <RepeatDataset element_spec=TensorSpec(shape=(16, 4, 4, None), dtype=tf.float32, name=None)>,
 3: <RepeatDataset element_spec=TensorSpec(shape=(16, 8, 8, None), dtype=tf.float32, name=None)>,
 4: <RepeatDataset element_spec=TensorSpec(shape=(16, 16, 16, None), dtype=tf.float32, name=None)>,
 5: <RepeatDataset element_spec=TensorSpec(shape=(16, 32, 32, None), dtype=tf.float32, name=None)>,
 6: <RepeatDataset element_spec=TensorSpec(shape=(16, 64, 64, None), dtype=tf.float32, name=None)>,
 7: <RepeatDataset element_spec=TensorSpec(shape=(8, 128, 128, None), dtype=tf.float32, name=None)>,
 8: <RepeatDataset element_spec=TensorSpec(shape=(4, 256, 256, None), dtype=tf.float32, name=None)>}

In [16]:
def plot_images(images, log2_res, fname=''):
    scales = {2:0.5, 
              3:1, 
              4:2, 
              5:3,
              6:4,
              7:5,
              8:6, 
              9:7,
              10:8}
    scale = scales[log2_res]
    
    grid_col = min(12, int(12//scale))
    grid_row = images.shape[0]//grid_col
    grid_row = min(2, grid_row)
    
    f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*scale, grid_row*scale))
    
    for row in range(grid_row):
        ax = axarr if grid_row==1 else axarr[row]
        for col in range(grid_col):
            ax[col].imshow(images[row*grid_col+col])
            ax[col].axis('off')
    plt.show()
    if fname:
        print('image name', fname)
        f.savefig(fname)

In [18]:
class PixelNorm(Layer):
    def __init__(self, epsilon=1e-8):
        super(PixelNorm, self).__init__()
        self.epsilon = epsilon
        
    def call(self, input_tensor):
        return input_tensor / tf.math.sqrt(tf.reduce_mean(input_tensor**2, axis=-1, keepdims=True) + self.epsilon)

In [20]:
class MinibatchStd(Layer):
    def __init__(self, group_size=4, epsilon=1e-8):
        super(MinibatchStd, self).__init__()
        self.epsilon = epsilon
        self.group_size = group_size
        
    def call(self, input_tensor):
        n, h, w, c = input_tensor.shape
        x = tf.reshape(input_tensor, [self.group_size, -1, h, w, c])
        group_mean, group_var = tf.nn.moments(x, axes=(0), keepdims=False)
        group_std = tf.sqrt(group_var + self.epsilon)
        avg_std = tf.reduce_mean(group_std, axis=[1,2,3], keepdims=True)
        x = tf.tile(avg_std, [self.group_size, h, w, 1])
        
        return tf.concat([input_tensor, x], axis=-1)

In [21]:
class FadeIn(Layer):
    @tf.function
    def call(self, input_alpha, a, b):
        alpha = tf.reduce_mean(input_alpha)
        y = alpha * a + ( 1. - alpha) * b
        return y

In [22]:
def wasserstein_loss(y_true, y_pred):
    return -tf.reduce_mean(y_true * y_pred)

In [None]:
class Conv2D(Layer):
    def __init__(self, out_channels, kernel=3, gain=2, **kwargs):
        super(Conv2D, self).__init__(kwargs)
        self.kernel = kernel
        self.out_channels = out_channels
        self.gain = gain
        self.pad = kernel!=1
        
    def build(self, input_shape):
        self.in_channels = input_shape[-1]
        initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1.)
        self.w = self.add_weight(shape=[self.kernel,
                                        self.kernel,
                                        self.in_channels,
                                        self.out_channels],
                                 initializer=initializer,
                                 trainable=True, name='kernel')
        self.b = self.add_weight(shape=(self.out_channels,),
                                 initializer='zeros',
                                 trainable=True, name='bias')
        fan_in = self.kernel*self.kernel*self.in_channels
        self.scale = tf.sqrt(self.gain/fan_in)
        
    def call(self, inputs):
        if self.pad:
            x = tf.pad(inputs, [[0,0],[1,1],[1,1],[0,0]], mode='REFLECT')
        else:
            x = inputs
        output = tf.nn.conv2d(x, self.scale*self.w, stri