In [2]:
import tensorflow as tf
import numpy as np
import matplotlib as plt
import random
from PIL import Image
import tensorflow_datasets as tfds

# Constants

In [3]:
# Image / latent geometry
IMG_SIZE = 256       # generator output resolution and discriminator input size (square)
LATENT_SIZE = 512    # length of each latent vector fed to the style-mapping network
BATCH_SIZE = 12

# One style input per generator block; for IMG_SIZE = 256 this is 8 - 1 = 7.
LAYERS = int(np.log2(IMG_SIZE)) - 1
MIX_PROB = 0.9       # probability that a training step uses mixed styles
CHA = 48             # base channel multiplier for conv layer widths

INITIALIZER = tf.keras.initializers.he_normal

# Utility Functions

In [3]:
def noise(num):
    """Draw ``num`` latent vectors from a standard normal distribution.

    Returns:
        float32 array of shape (num, LATENT_SIZE).
    """
    samples = np.random.normal(loc=0.0, scale=1.0, size=(num, LATENT_SIZE))
    return samples.astype('float32')

def get_noise(num):
    """Return the same latent batch for every generator layer (no style mixing).

    Returns:
        list of LAYERS references to one float32 array of shape
        (num, LATENT_SIZE), one entry per generator style input.
    """
    # Wrap the batch in a list before repeating; ``list(noise(num))`` would split
    # it into ``num`` separate 1-D rows and yield num * LAYERS inputs.
    return [noise(num)] * LAYERS

def get_mixed_noise(num):
    """Return per-layer latents that switch from one batch to another at a random layer.

    The first ``split`` layers receive one latent batch and the remaining
    ``LAYERS - split`` layers receive a second, independent batch, where
    ``split`` is drawn uniformly from 0 .. LAYERS - 1.

    Returns:
        list of LAYERS float32 arrays, each of shape (num, LATENT_SIZE).
    """
    # ``random`` is the module; the original called the module itself and
    # subtracted LAYERS instead of scaling by it.
    split = int(random.random() * LAYERS)
    first = [noise(num)] * split
    second = [noise(num)] * (LAYERS - split)
    return first + second

def img_dim(size):
    """Sample ``size`` single-channel noise images of uniform [0, 1) values.

    Returns:
        float32 array of shape (size, IMG_SIZE, IMG_SIZE, 1).
    """
    shape = (size, IMG_SIZE, IMG_SIZE, 1)
    return np.random.uniform(low=0.0, high=1.0, size=shape).astype('float32')

def pixel_norm(x, epsilon = 1e-8):
    """Standardise ``x`` to zero mean and unit std over axes 1 and 2.

    Statistics are computed separately for every index of the remaining axes.
    Assuming NHWC input (as in the AdaIN Lambda layer), that is per sample and
    per channel, i.e. instance normalisation despite the function's name.
    ``epsilon`` is added to the std to avoid division by zero.
    """
    spatial_axes = [1, 2]
    centred = x - tf.reduce_mean(x, axis=spatial_axes, keepdims=True)
    spread = tf.keras.backend.std(x, axis=spatial_axes, keepdims=True)
    return centred / (spread + epsilon)

# Loss Functions

In [4]:
def gradient_loss(sample, output, weights):
    """Weighted squared-gradient penalty of ``output`` with respect to ``sample``.

    For each sample, the squared gradient is summed over every non-batch axis;
    the result is the mean of those per-sample sums multiplied by ``weights``.

    Args:
        sample: tensor the gradient is taken with respect to.
        output: tensor computed from ``sample`` (e.g. discriminator scores).
        weights: per-sample penalty weights, broadcast along the batch axis.

    Returns:
        Scalar tensor.

    NOTE: ``tf.keras.backend.gradients`` wraps ``tf.gradients``, which is only
    supported in graph mode (e.g. inside ``tf.function``). Under eager
    execution, compute the gradient with a ``tf.GradientTape`` instead.
    """
    # gradients(loss, variables) returns a list with one entry per variable.
    grad = tf.keras.backend.gradients(output, sample)[0]
    grad_sq = tf.keras.backend.square(grad)
    per_sample = tf.keras.backend.sum(grad_sq, axis=list(range(1, len(grad_sq.shape))))
    return tf.keras.backend.mean(per_sample * weights)

def wasserstein_loss(y_true, y_pred):
    """Mean of the element-wise product of targets and predictions.

    Not referenced elsewhere in the visible notebook.
    """
    weighted = y_true * y_pred
    return tf.reduce_mean(weighted)

# Custom Layers

In [5]:
def fade_in(alpha, a, b):
    """Linearly blend two image batches: ``alpha * a + (1 - alpha) * b``.

    Args:
        alpha: blend factor(s), reshaped to (-1, 1, 1, 1) so it broadcasts
            over 4-D tensors and clipped to [0, 1].
        a: tensor weighted by ``alpha``.
        b: tensor weighted by ``1 - alpha``; same shape as ``a``.

    Not referenced elsewhere in the visible notebook.
    """
    alpha = tf.reshape(alpha, [-1, 1, 1, 1])
    # ``tf.clip`` does not exist; clip the blend factor itself to a valid range.
    alpha = tf.clip_by_value(alpha, 0.0, 1.0)
    return alpha * a + (1.0 - alpha) * b

def AdaIN(input_shapes):
    """Adaptive instance normalisation, used as a Lambda layer body.

    Args:
        input_shapes: sequence ``[features, scale, bias]``. ``features`` is
            normalised with ``pixel_norm``; ``scale`` and ``bias`` are
            reshaped to (-1, 1, 1, C), where C is the feature channel count.

    Returns:
        ``norm(features) * (scale + 1) + bias``. The +1 means a zero scale
        projection leaves the normalised features unscaled.
    """
    features, scale_in, bias_in = input_shapes[0], input_shapes[1], input_shapes[2]
    normalised = pixel_norm(features)
    broadcast_shape = (-1, 1, 1, normalised.shape[-1])
    scale = tf.reshape(scale_in, broadcast_shape) + 1.0
    bias = tf.reshape(bias_in, broadcast_shape)
    return normalised * scale + bias

def fit(x):
    """Crop ``x[0]`` along axes 1 and 2 to the size of ``x[1]`` on those axes.

    Used to cut the full-resolution noise image down to a block's feature-map
    size. Axes 0 and 3 of ``x[0]`` are kept whole.
    """
    source, reference = x[0], x[1]
    height, width = reference.shape[1], reference.shape[2]
    return source[:, :height, :width, :]

# Blocks

In [None]:
def get_gen_block(input_tensor, style, inoise, filters, up_sample = True):
    """Build one synthesis block of the generator.

    Optionally upsamples ``input_tensor`` (UpSampling2D, default 2x), applies
    a 3x3 conv, adds a learned per-channel projection of the cropped noise
    image, modulates the result with AdaIN using scale/bias projected from
    ``style``, and finishes with LeakyReLU(0.2).

    Args:
        input_tensor: feature map from the previous block (Keras tensor).
        style: style vector for this block (Keras tensor).
        inoise: full-resolution noise image, cropped to this block's spatial
            size via ``fit``.
        filters: number of output channels.
        up_sample: if False, the spatial size is kept (identity passthrough).

    Returns:
        Keras tensor with ``filters`` channels.
    """
    if up_sample:
        block = tf.keras.layers.UpSampling2D()(input_tensor)
    else:
        block = tf.keras.layers.Activation('linear')(input_tensor)

    beta = tf.keras.layers.Dense(filters)(style)
    delta = tf.keras.layers.Lambda(fit)([inoise, block])
    # Zero kernel init: the noise contribution starts at zero and is learned.
    delta = tf.keras.layers.Dense(filters, kernel_initializer='zeros')(delta)
    gamma = tf.keras.layers.Dense(filters)(style)

    # ``filters`` (the channel count); the builtin ``filter`` is not an int.
    block = tf.keras.layers.Conv2D(filters=filters, kernel_size=3, padding='same',
        kernel_initializer=INITIALIZER)(block)
    block = tf.keras.layers.add([block, delta])
    # gamma -> per-channel scale, beta -> per-channel bias.
    block = tf.keras.layers.Lambda(AdaIN)([block, gamma, beta])

    return tf.keras.layers.LeakyReLU(0.2)(block)

def get_desc_block(input_tensor, filters, pool = True):
    """Build one discriminator block: 3x3 conv -> LeakyReLU(0.2) -> optional pooling.

    Args:
        input_tensor: Keras tensor from the previous block.
        filters: number of conv output channels.
        pool: if True, finish with AveragePooling2D (default 2x2 window).

    Returns:
        Keras tensor with ``filters`` channels.
    """
    conv = tf.keras.layers.Conv2D(
        filters=filters,
        kernel_size=3,
        padding='same',
        kernel_initializer=INITIALIZER,
    )
    activated = conv(input_tensor)
    activated = tf.keras.layers.LeakyReLU(0.2)(activated)

    if not pool:
        return activated
    return tf.keras.layers.AveragePooling2D()(activated)

# GAN

In [2]:
class GAN(object):
    """Container for the StyleGAN networks.

    On construction it builds, once each:

    * the discriminator (``self.desc``),
    * the style-mapping MLP (``self.style``),
    * the synthesis network (``self.gen``).

    It also creates weight-copied clones ``self.g_model`` and ``self.s_model``.

    Args:
        steps: initial value of the training-step counter ``self.steps``.
        learn_rate: stored as ``self.L_Rate``; not used inside this class.
        decay: accepted for interface compatibility; currently unused.
    """

    def __init__(self, steps = 1, learn_rate = 1e-4, decay = 1e-5):
        self.desc = None
        self.gen = None
        self.style = None

        self.L_Rate = learn_rate
        self.steps = steps
        # NOTE(review): not read anywhere in the visible code; presumably
        # intended as an EMA factor for the clones below.
        self.beta = 0.99

        self.discriminator()
        self.generator()

        # Independent copies (same architecture, current weights) of the
        # synthesis and style networks.
        self.g_model = tf.keras.models.model_from_json(self.gen.to_json())
        self.g_model.set_weights(self.gen.get_weights())

        # JSON rather than YAML: Model.to_yaml / model_from_yaml were removed
        # from tf.keras (TF >= 2.6) and raise when called.
        self.s_model = tf.keras.models.model_from_json(self.style.to_json())
        self.s_model.set_weights(self.style.get_weights())

    def discriminator(self):
        """Build (once) and return the discriminator ``self.desc``.

        Input: (IMG_SIZE, IMG_SIZE, 3) images. The network has:

        * seven conv blocks, each followed by 2x average pooling;
        * one unpooled 16*CHA block;
        * a Dense(16*CHA) + LeakyReLU head;
        * a final Dense(1) with no activation (one unbounded score per image).
        """
        if self.desc is not None:
            return self.desc

        input_tensor = tf.keras.layers.Input(shape = [IMG_SIZE, IMG_SIZE, 3])

        x = input_tensor
        x = get_desc_block(x, 1*CHA)
        x = get_desc_block(x, 2*CHA)
        x = get_desc_block(x, 3*CHA)
        x = get_desc_block(x, 4*CHA)
        x = get_desc_block(x, 4*CHA)
        x = get_desc_block(x, 6*CHA)
        x = get_desc_block(x, 8*CHA)
        x = get_desc_block(x, 16*CHA, False)
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(16*CHA, kernel_initializer=INITIALIZER)(x)
        x = tf.keras.layers.LeakyReLU(0.2)(x)
        x = tf.keras.layers.Dense(1, kernel_initializer=INITIALIZER)(x)

        self.desc = tf.keras.models.Model(inputs=input_tensor, outputs=x)

        return self.desc

    def generator(self):
        """Build (once) and return the synthesis network ``self.gen``.

        Also builds the style-mapping MLP ``self.style``: four
        Dense(512) + LeakyReLU(0.2) layers on a LATENT_SIZE input.

        ``self.gen`` takes LAYERS style vectors of size 512 followed by one
        (IMG_SIZE, IMG_SIZE, 1) noise image. It returns a 3-channel image,
        starting from a 4x4 map and doubling resolution in each of the six
        upsampling blocks.
        """
        if self.gen is not None:
            return self.gen

        # Style Mapping
        self.style = tf.keras.models.Sequential(
            [
                tf.keras.layers.Dense(512, input_shape=[LATENT_SIZE]),
                tf.keras.layers.LeakyReLU(0.2),
                tf.keras.layers.Dense(512),
                tf.keras.layers.LeakyReLU(0.2),
                tf.keras.layers.Dense(512),
                tf.keras.layers.LeakyReLU(0.2),
                tf.keras.layers.Dense(512),
                tf.keras.layers.LeakyReLU(0.2)
            ]
        )

        # Actual Generator
        input_style = [tf.keras.layers.Input([512]) for _ in range(LAYERS)]
        input_noise = tf.keras.layers.Input([IMG_SIZE, IMG_SIZE, 1])

        # Seed 4x4 feature map from the first 128 dims of the first style.
        x = tf.keras.layers.Lambda(lambda x: x[:, :128])(input_style[0])
        x = tf.keras.layers.Dense(4*4*4*CHA, activation=tf.nn.relu, kernel_initializer=INITIALIZER)(x)
        x = tf.keras.layers.Reshape([4, 4, 4*CHA])(x)
        x = get_gen_block(x, input_style[0], input_noise, 16*CHA, up_sample=False)
        x = get_gen_block(x, input_style[1], input_noise, 8*CHA)
        x = get_gen_block(x, input_style[2], input_noise, 6*CHA)
        x = get_gen_block(x, input_style[3], input_noise, 4*CHA)
        x = get_gen_block(x, input_style[4], input_noise, 3*CHA)
        x = get_gen_block(x, input_style[5], input_noise, 2*CHA)
        x = get_gen_block(x, input_style[6], input_noise, 1*CHA)
        x = tf.keras.layers.Conv2D(filters=3, kernel_size=1, padding='same', kernel_initializer=INITIALIZER)(x)

        # The noise image is a real model input (used by every block); it must
        # be listed, and in the same order gen_model() feeds it.
        self.gen = tf.keras.models.Model(inputs = input_style + [input_noise], outputs = x)

        return self.gen

    def gen_model(self):
        """Return a model chaining the style-mapping MLP into the synthesis network.

        Inputs: LAYERS latent vectors of size LATENT_SIZE, each passed through
        the shared ``self.style``, followed by the (IMG_SIZE, IMG_SIZE, 1)
        noise image. Output: the generated image. Layers and weights are shared
        with ``self.style`` and ``self.gen`` (generator() must have run first).
        """
        input_style = []
        style = []

        for _ in range(LAYERS):
            input_style.append(tf.keras.layers.Input([LATENT_SIZE]))
            style.append(self.style(input_style[-1]))

        input_noise = tf.keras.layers.Input([IMG_SIZE, IMG_SIZE, 1])

        x = self.gen(style + [input_noise])

        return tf.keras.models.Model(inputs = input_style + [input_noise], outputs = x)


# Optimisers

In [None]:
# Adam with beta_1=0, beta_2=0.9 for both networks; the discriminator uses a
# 4x larger learning rate than the generator.
# (``tf.optimizers`` is the TF alias; ``tf.optimisers`` does not exist.)
generator_optimizer = tf.optimizers.Adam(learning_rate=1e-4, beta_1=0, beta_2=0.9)
discriminator_optimizer = tf.optimizers.Adam(learning_rate=4e-4, beta_1=0, beta_2=0.9)

# Style GAN

In [None]:
"!!!!!!!!!!!!!! # NEED WORK # !!!!!!!!!!!!!!!!!!"
class StyleGAN(object):
    """Training driver around ``GAN``.

    Owns the full generator (style MLP + synthesis network), the
    discriminator, the real-image source and the adaptive per-sample
    gradient-penalty weights.

    TODO(review): marked "NEED WORK" in the notebook. Real-image loading is
    not implemented: ``data_dir`` is used directly as the image batch.

    Args:
        data_dir: passed to ``train_step`` as the real images on every step.
        steps, learn_rate, decay: forwarded to ``GAN``.
    """

    def __init__(self, data_dir, steps = 1, learn_rate = 1e-4, decay = 1e-5):
        self.GAN = GAN(steps = steps, learn_rate = learn_rate, decay = decay)

        # Full generator: LAYERS latent inputs + noise image -> RGB image.
        self.generator = self.GAN.gen_model()
        self.discriminiator = self.GAN.discriminator()

        # NOTE(review): despite the name, this is fed straight to the
        # discriminator, so it presumably must already be an image batch of
        # shape (BATCH_SIZE, IMG_SIZE, IMG_SIZE, 3) — TODO confirm / add loading.
        self.data = data_dir

        # Per-sample gradient-penalty weights (start at 10), adapted in train().
        self.weight = np.array([10] * BATCH_SIZE).astype('float32')

    def train(self):
        """Run one training step, adapt the penalty weight, and log/save periodically."""
        # With probability MIX_PROB, train on mixed styles.
        if random.random() < MIX_PROB:
            style = get_mixed_noise(BATCH_SIZE)
        else:
            style = get_noise(BATCH_SIZE)

        img = self.data

        d_loss, g_loss, div = self.train_step(img, style, img_dim(BATCH_SIZE), self.weight)

        # Move the penalty weight 10% toward 5 / divergence, then broadcast
        # back to one clipped weight per sample.
        new_weight = 5 / (np.array(div) + 1e-7)
        self.weight = self.weight[0] * 0.9 + 0.1 * new_weight
        self.weight = np.clip([self.weight] * BATCH_SIZE, 0.01, 10000.0).astype('float32')

        # Print progress every 100 steps
        if self.GAN.steps % 100 == 0:
            print("\n==============================")
            print("Epoch: ", self.GAN.steps)
            print("Discriminator Loss: ", d_loss)
            print("Generator Loss: ", g_loss)
            print("==============================\n")

            # Save an image grid every 500 steps; integer index so file names
            # match the initial "img-0.png" rather than "img-1.0.png".
            if self.GAN.steps % 500 == 0:
                self.save_image(self.GAN.steps // 500)

        self.GAN.steps += 1

    def train_step(self, images, style, noise, weight):
        """One simultaneous generator/discriminator update.

        Losses:

        * Generator: mean of the discriminator score on generated images.
        * Discriminator: hinge divergence
          ``mean(relu(1 + D(real)) + relu(1 - D(fake)))`` plus a weighted
          squared-gradient penalty of ``D(real)`` with respect to the real
          images.

        Returns:
            (discriminator_loss, generator_loss, divergence) as scalar tensors.
        """
        # The penalty needs d(real_output)/d(images): make images a float
        # tensor that the inner tape can watch.
        images = tf.cast(images, tf.float32)

        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            generated_img = self.generator(style + [noise], training=True)

            # Inner tape for the input gradient. It is evaluated inside the
            # outer tapes so d_tape can differentiate through it.
            with tf.GradientTape() as gp_tape:
                gp_tape.watch(images)
                real_output = self.discriminiator(images, training=True)
            generated_output = self.discriminiator(generated_img, training=True)

            generator_loss = tf.reduce_mean(generated_output)
            divergence = tf.reduce_mean(tf.nn.relu(1 + real_output)
                + tf.nn.relu(1 - generated_output))

            real_grads = gp_tape.gradient(real_output, images)
            grad_sq_sum = tf.reduce_sum(tf.square(real_grads),
                axis=list(range(1, len(real_grads.shape))))
            gradient_penalty = tf.reduce_mean(grad_sq_sum * weight)

            discriminator_loss = divergence + gradient_penalty

        gradients_of_generator = g_tape.gradient(generator_loss, self.generator.trainable_variables)
        gradients_of_discriminator = d_tape.gradient(discriminator_loss,
            self.discriminiator.trainable_variables)

        generator_optimizer.apply_gradients(zip(gradients_of_generator,
            self.generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator,
            self.discriminiator.trainable_variables))

        return discriminator_loss, generator_loss, divergence

    def save_image(self, image_num):
        """Generate 64 images and save them as an 8x8 grid PNG.

        The file is written to ``Generated_img/img-<image_num>.png``; the
        directory is created if needed. Pixel values are clipped to [0, 1]
        and scaled to 0..255.
        """
        import os

        n_images, per_row = 64, 8
        latents = get_noise(n_images)
        noise_img = img_dim(n_images)

        generated_images = self.generator.predict(latents + [noise_img])

        rows = [np.concatenate(generated_images[i:i + per_row], axis=1)
                for i in range(0, n_images, per_row)]
        grid = np.clip(np.concatenate(rows, axis=0), 0.0, 1.0)

        os.makedirs("Generated_img", exist_ok=True)
        image = Image.fromarray((grid * 255).astype(np.uint8))
        image.save("Generated_img/img-" + str(image_num) + ".png")
             

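# Data Processing

TODO: this section is empty. The *Run Training* cell passes a variable `data` to `StyleGAN`, which uses it directly as the batch of real images, so `data` must be created here.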

# Run Training

In [None]:
# NOTE(review): `data` is not defined anywhere in the visible notebook (the
# "Data Processing" section is empty); it must exist before this cell runs.
model = StyleGAN(data)
model.save_image(0)  # grid from the untrained generator -> img-0.png

while True:
    if model.GAN.steps > 1000001:
        break
    model.train()