In [16]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import random
from PIL import Image
import math
import os
import tensorflow_datasets as tfds

# Constants

In [17]:
IMG_SIZE = 256
LATENT_SIZE = 512
BATCH_SIZE = 12

LAYERS = int(math.log2(IMG_SIZE) - 1)
MIX_PROB = 0.9
CHA = 48

INITIALIZER = 'he_normal'

# Utility Functions

In [18]:
def noise(num):
    return np.random.normal(0.0, 1.0, size = [num, LATENT_SIZE]).astype('float32')

def get_noise(num):
    return [noise(num)] * LAYERS

def get_mixed_noise(num):
    rand = int(random.random() * LAYERS)
    p1 = [noise(num)] * rand
    p2 = [noise(num)] * (LAYERS - rand)
    return p1 + [] + p2

def img_dim(size):
    return np.random.uniform(0.0, 1.0, size = [size, IMG_SIZE, IMG_SIZE, 1]).astype('float32')

def pixel_norm(x, epsilon = 1e-7):
    mean = tf.keras.backend.mean(x, axis=[1, 2], keepdims=True)
    std = tf.keras.backend.std(x, axis=[1, 2], keepdims=True) + epsilon
    return (x - mean) / std

# Loss Functions

In [19]:
def gradient_loss(sample, output, weights):
    grad = tf.keras.backend.gradients(output, sample)[0]
    grad_sq = tf.keras.backend.square(grad)
    grad_loss = tf.keras.backend.sum(grad_sq, axis=np.arange(1, len(grad_sq.shape)))
    return tf.keras.backend.mean(grad_loss * weights)

def wasserstein_loss(y_true, y_pred):
    return tf.keras.backend.mean(y_true * y_pred)

# Custom Layers

In [20]:
def fade_in(alpha, a, b):
    alpha = tf.reshape(alpha, [-1, 1, 1, 1])
    alpha = tf.clip(alpha - b, 0, 1)
    return a * alpha + ((1.0 - alpha) * (1.0 - a))

def AdaIN(input_shapes):
    y = pixel_norm(input_shapes[0])
    #shape = [-1, 1, 1, y.shape[-1]]
    scale = tf.reshape(input_shapes[1], (-1, 1, 1, y.shape[-1])) + 1.0
    bias = tf.reshape(input_shapes[2], (-1, 1, 1, y.shape[-1]))
    return y * scale + bias 

def fit(x):
    h = x[1].shape[1]
    w = x[1].shape[2]
    return x[0][:, :h, :w, :]

# Blocks

In [21]:
def get_gen_block(input_tensor, style, inoise, filters, up_sample = True):
    if up_sample:
        block = tf.keras.layers.UpSampling2D()(input_tensor)
    else:
        block = tf.keras.layers.Activation('linear')(input_tensor)

    beta = tf.keras.layers.Dense(filters)(style)
    delta = tf.keras.layers.Lambda(fit)([inoise, block])
    delta = tf.keras.layers.Dense(filters, kernel_initializer='zeros')(delta)
    gamma = tf.keras.layers.Dense(filters)(style)

    block = tf.keras.layers.Conv2D(filters=filters, kernel_size=3, padding='same', \
        kernel_initializer='he_normal')(block)
    block = tf.keras.layers.add([block, delta])
    block = tf.keras.layers.Lambda(AdaIN)([block, gamma, beta])

    return tf.keras.layers.LeakyReLU(0.2)(block)

def get_desc_block(input_tensor, fil, pool = True):
    block = tf.keras.layers.Conv2D(filters=fil, kernel_size=3, \
        padding='same', kernel_initializer='he_normal')(input_tensor)
    block = tf.keras.layers.LeakyReLU(0.2)(block)

    if pool:
        block = tf.keras.layers.AveragePooling2D()(block)

    return block

# GAN

In [22]:
class GAN(object):
    def __init__(self, steps = 1, learn_rate = 1e-4, decay = 1e-5):
        self.desc = None
        self.gen = None
        self.style = None

        self.g_model = None

        self.L_Rate = learn_rate
        self.steps = steps
        self.beta = 0.99

        self.discriminator()
        self.generator()

        # self.g_model = tf.keras.models.model_from_json(self.gen.to_json())
        # self.g_model.set_weights(self.gen.get_weights())
        
        # self.s_model = tf.keras.models.model_from_json(self.style.to_json())
        # self.s_model.set_weights(self.style.get_weights())

    def discriminator(self):
        if self.desc:
            return self.desc
        
        input_tensor = tf.keras.layers.Input(shape = [IMG_SIZE, IMG_SIZE, 3])

        #x = tf.keras.layers.Input(shape = [IMG_SIZE, IMG_SIZE, 3])
        x = get_desc_block(input_tensor, 1*CHA)
        x = get_desc_block(x, 2*CHA)
        x = get_desc_block(x, 3*CHA)
        x = get_desc_block(x, 4*CHA)
        x = get_desc_block(x, 6*CHA)
        x = get_desc_block(x, 8*CHA)
        x = get_desc_block(x, 16*CHA, False)
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(16*CHA, kernel_initializer='he_normal')(x)
        x = tf.keras.layers.LeakyReLU(0.2)(x)
        x = tf.keras.layers.Dense(1, kernel_initializer='he_normal')(x)

        self.desc = tf.keras.models.Model(inputs=input_tensor, outputs=x)

        return self.desc

    def generator(self):
        if self.gen:
            return self.gen
        
        # Style Mapping
        self.style = tf.keras.Sequential()
        self.style.add(tf.keras.layers.Dense(512, input_shape=[LATENT_SIZE]))
        self.style.add(tf.keras.layers.LeakyReLU(0.2))
        self.style.add(tf.keras.layers.Dense(512))
        self.style.add(tf.keras.layers.LeakyReLU(0.2))
        self.style.add(tf.keras.layers.Dense(512))
        self.style.add(tf.keras.layers.LeakyReLU(0.2))
        self.style.add(tf.keras.layers.Dense(512))
        self.style.add(tf.keras.layers.LeakyReLU(0.2))  
            # tf.keras.layers.Dense(512, input_shape=[LATENT_SIZE]),
            # tf.keras.layers.LeakyReLU(0.2),
            # tf.keras.layers.Dense(512),
            # tf.keras.layers.LeakyReLU(0.2),
            # tf.keras.layers.Dense(512),
            # tf.keras.layers.LeakyReLU(0.2),
            # tf.keras.layers.Dense(512),
            # tf.keras.layers.LeakyReLU(0.2)

        # Actual Generator
        input_style = []

        for i in range(LAYERS):
            input_style.append(tf.keras.Input([LATENT_SIZE]))

        input_noise = tf.keras.layers.Input([IMG_SIZE, IMG_SIZE, 1])

        x = tf.keras.layers.Lambda(lambda x: x[:, :128])(input_style[0])
        x = tf.keras.layers.Dense(4*4*4*CHA, activation='relu', kernel_initializer='he_normal')(x)
        x = tf.keras.layers.Reshape([4, 4, 4*CHA])(x)
        x = get_gen_block(x, input_style[0], input_noise, 16*CHA, up_sample=False)
        x = get_gen_block(x, input_style[1], input_noise, 8*CHA)
        x = get_gen_block(x, input_style[2], input_noise, 6*CHA)
        x = get_gen_block(x, input_style[3], input_noise, 4*CHA)
        x = get_gen_block(x, input_style[4], input_noise, 3*CHA)
        x = get_gen_block(x, input_style[5], input_noise, 2*CHA)
        x = get_gen_block(x, input_style[6], input_noise, 1*CHA)
        x = tf.keras.layers.Conv2D(filters=3, kernel_size=1, padding='same', kernel_initializer='he_normal')(x)

        self.gen = tf.keras.models.Model(inputs = input_style + [input_noise], outputs = x)

        return self.gen
    
    def gen_model(self):
        input_style = []
        style = []

        for i in range(LAYERS):
            input_style.append(tf.keras.layers.Input([LATENT_SIZE]))
            style.append(self.style(input_style[-1]))

        input_noise = tf.keras.layers.Input([IMG_SIZE, IMG_SIZE, 1])

        x = self.gen(style+[input_noise])
        self.g_model = tf.keras.models.Model(inputs = input_style + [input_noise], outputs = x)

        return self.g_model


#  Optimisers

In [23]:
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0, beta_2=0.9)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=4*1e-4, beta_1=0, beta_2=0.9)

# Style GAN

In [24]:
class StyleGAN(object):
    def __init__(self, steps = 1, learn_rate = 1e-4, decay = 1e-5):
        self.GAN = GAN(steps = steps, learn_rate = learn_rate, decay = decay)
        
        self.generator = self.GAN.gen_model()
        self.discriminiator = self.GAN.discriminator()
        
        self.weight = np.array([10] * BATCH_SIZE).astype('float32')

    def train(self, train_set):
        # Randomly train alternating styles
        if random.random() < MIX_PROB:
            style = get_mixed_noise(BATCH_SIZE)
        else:
            style = get_noise(BATCH_SIZE)

        
        d_loss, g_loss, div = self.train_step(train_set.astype('float32'), style, img_dim(12), self.weight)

        new_weight = 5/(np.array(div) + 1e-7)
        self.weight = self.weight[0] * 0.9 + 0.1 * new_weight
        self.weight = np.clip([self.weight] * BATCH_SIZE, 0.01, 10000.0).astype('float32')

        # Print progress after models after 100 steps
        if self.GAN.steps%1 == 0:
            print("\n==============================")
            print("Epoch: ", self.GAN.steps)
            print("Discriminator Loss: ", d_loss)
            print("Generator Loss: ", g_loss)
            print("==============================\n")

            #Save images in /Generated-img after every 500 epochs
            if self.GAN.steps%500 == 0:
                self.save_image(self.GAN.steps/500)

        self.GAN.steps += 1
        if self.GAN.steps < 2:
            print(self.GAN.steps)

    @tf.function
    def train_step(self, images, style, noise, weight):
        # =======================DEBUG===============
        print(images.shape)
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            generated_img = self.GAN.g_model(style + [noise], training=True)
            real_output = self.GAN.desc(images, training=True)
            generated_output = self.GAN.desc(generated_img, training=True)

            generator_loss = tf.keras.backend.mean(generated_output)
            divergence = tf.keras.backend.mean(tf.keras.backend.relu(1+real_output) \
                + tf.keras.backend.relu(1-generated_output))
            discriminator_loss = divergence + gradient_loss(images, real_output, weight)

        gradients_of_generator = g_tape.gradient(generator_loss, self.GAN.g_model.trainable_variables)
        gradients_of_discriminator = d_tape.gradient(discriminator_loss, \
            self.GAN.desc.trainable_variables)

        generator_optimizer.apply_gradients(zip(gradients_of_generator, \
            self.GAN.g_model.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, \
            self.GAN.desc.trainable_variables))
        
        return discriminator_loss, generator_loss, divergence

    def save_image(self, image_num):
        noise1 = get_noise(64)
        noise2 = img_dim(64)

        # g_model = self.GAN.gen_model()
        generated_images = self.GAN.g_model.predict(noise1 + [noise2], batch_size = BATCH_SIZE)

        result = []

        # for i in range(0, 64, 8):
        #     result.append(np.concatenate(generated_images[i:i+8], axis=1))
        result.append(np.concatenate(generated_images[0:1], axis = 1))
        x = np.concatenate(result, axis = 0)
        x = np.clip(x, 0.0, 1.0)

        images = Image.fromarray(np.uint8(x*255))
        #images = Image.fromarray(np.array(x))

        #np.save("Generated_img/img-"+str(image_num)+".png", images)
        images.save("Generated_img/img-"+str(image_num)+".png")
        if image_num == 0:
            plt.imshow(images)
             

# Data Processing

## Convert OASIS Brain data to .npy
Converts the OASIS Brain .png images to .npy arrays for training efficiency 

In [25]:
# Data directory for OASIS Brain 
def convert_to_npy(dir_path):
    segment_length = (1024 ** 3) // (IMG_SIZE*IMG_SIZE*3)

    file_names = []

    for dirpath, dirnames, filenames in os.walk(dir_path):
        for filename in filenames:
            file_names.append(os.path.join(dirpath, filename))

    np.random.shuffle(file_names)
    
    segment = []
    ctr = 0
    kn = 0
    for fname in file_names:
        img = Image.open(fname).convert("RGB").resize((IMG_SIZE, IMG_SIZE), Image.BILINEAR)
        img = np.array(img, dtype='uint8')
        segment.append(img)
        kn += 1

        if kn >= segment_length:
            np.save("OASIS-Brain-npy/image-" + str(ctr) + ".npy", np.array(segment))
            segment = []
            kn = 0
            ctr += 1

    np.save("OASIS-Brain-npy/image-" + str(ctr) + ".npy", np.array(segment))

If data nor loaded as .npy run the next block

In [26]:

# Uncomment this block to run code to convert image folder to .npy folders
# path to original image directory will go to 'data_dir'

# data_dir = "keras_png_slices_data/keras_png_slices_train"
# convert_to_npy(data_dir)

If folder with .npy exists run next

In [27]:
def load_data(data_path):
    segment = []
    images = []
    for dirpath, dirnames, filenames in os.walk(data_path):
        for filename in filenames:
            segment.append(os.path.join(dirpath, filename))
    
    index = random.randint(0, len(segment) - 1)
    images = np.load(segment[index])
    return images

images = load_data('OASIS-Brain-npy')
print(str(len(images)) + " images.")

1184 images.


# Get Training Images
### Training to be done on number of images = BATCH_SIZE 

In [28]:
def get_training_batch(images, update):
    if update > images.shape[0]:
        images = load_data('OASIS-Brain-npy')

    #randomly select #BATCH_SIZE numbers
    print(images.shape)
    indeces = np.random.randint(0, images.shape[0] - 1, BATCH_SIZE)
    train_set = []
    for i in indeces:
        train_set.append(images[i])

    return np.array(train_set).astype('float32') / 255.0
# image_indices = np.random.randint(0, train.shape[0] - 1, [4])
# real_images = train[image_indices]

# for i in range(4):
#     plt.figure(i)
#     plt.imshow(real_images[i])

# plt.show()

# Train Model

In [29]:
model = StyleGAN()
model.save_image(0)
update = 0
while model.GAN.steps <= 1000001:
    train_set = get_training_batch(images, update)
    # ========= DEBUG ==========
    #print(train_set.shape)
    update += BATCH_SIZE
    model.train(train_set)

ResourceExhaustedError:  OOM when allocating tensor with shape[12,144,128,128] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node model_5/model_4/conv2d_27/Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
 [Op:__inference_predict_function_6283]

Function call stack:
predict_function
