In [61]:
from __future__ import print_function
import tensorflow as tf
from tensorflow import keras
import numpy as np
import cv2
import random
import os
import matplotlib.pyplot as plt
from IPython import display
import time

In [44]:
# data augmention(preprocess)
class data_aug:
    def resize(self, input_image, real_image, height, width):
        input_image = tf.image.resize(input_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        real_image = tf.image.resize(real_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        return input_image, real_image

    def random_crop(self, input_image, real_image):
        stack_image = tf.stack([input_image, real_image], axis=0)
        cropped_image = tf.image.random_crop(stack_image, size=[2, 256, 256, 3])

        return cropped_image[0], cropped_image[1]

    @tf.function()
    def random_jitter(self, input_image, real_image):
        input_image, real_image = self.resize(input_image, real_image, 286, 286)
        input_image, real_image = self.random_crop(input_image, real_image)

        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_left_right(input_image)
            real_image = tf.image.flip_left_right(real_image)

        return input_image, real_image

In [101]:
# dataset
BUFFER_SIZE = 400
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256

def load(image_file):
    image = tf.io.read_file(image_file)
    image = tf.image.decode_png(image, channels=3)
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    return image

def load_image(image_file):
    image = load(image_file)

    return image

train_input_dataset_path = os.listdir('data/train/')
train_input_dataset_path = [os.path.join(os.path.abspath('data/train/'), p) for p in train_input_dataset_path]

test_input_dataset_path = os.listdir('data/test/')
test_input_dataset_path = [os.path.join(os.path.abspath('data/test/'), p) for p in test_input_dataset_path]

train_aug_dataset_path = os.listdir('data/aug/train')
train_aug_dataset_path = [os.path.join(os.path.abspath('data/aug/train'), p) for p in train_aug_dataset_path]

test_aug_dataset_path = os.listdir('data/aug/test')
test_aug_dataset_path = [os.path.join(os.path.abspath('data/aug/test'), p) for p in test_aug_dataset_path]

train_input_dataset = tf.data.Dataset.from_tensor_slices(train_input_dataset_path)
train_input_dataset = train_input_dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

test_input_dataset = tf.data.Dataset.from_tensor_slices(test_input_dataset_path)
test_input_dataset = test_input_dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

train_aug_dataset = tf.data.Dataset.from_tensor_slices(train_aug_dataset_path)
train_aug_dataset = train_aug_dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

test_aug_dataset = tf.data.Dataset.from_tensor_slices(test_aug_dataset_path)
test_aug_dataset = test_aug_dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

train_dataset = tf.data.Dataset.zip((train_input_dataset, train_aug_dataset))
test_dataset = tf.data.Dataset.zip((test_input_dataset, test_aug_dataset))

In [45]:
# network module
def downsample(filters, size, apply_batchnorm=True):
    initializer = tf.random_normal_initializer(0, 0.02)

    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2D(filters, size, strides=2, padding="same", kernel_initializer=initializer, use_bias=False))

    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())

    result.add(tf.keras.layers.LeakyReLU())

    return result

def upsample(filters, size, apply_dropout=False):
    initializer = tf.random_normal_initializer(0, 0.02)

    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding="same", kernel_initializer=initializer, use_bias=False))
    result.add(tf.keras.layers.BatchNormalization())

    if apply_dropout:
        result.add(tf.keras.layers.Dropout(0.5))

    result.add(tf.keras.layers.ReLU())

    return result

In [175]:
class Generator(tf.keras.Model):
    def __init__(self):
        super(Generator, self).__init__()
        self.down_stack = [
            downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
            downsample(128, 4), # (bs, 64, 64, 128)
            downsample(256, 4), # (bs, 32, 32, 256)
            downsample(512, 4), # (bs, 16, 16, 512)
            downsample(512, 4), # (bs, 8, 8, 512)
            downsample(512, 4), # (bs, 4, 4, 512)
            downsample(512, 4), # (bs, 2, 2, 512)
            downsample(512, 4), # (bs, 1, 1, 512)
        ]

        self.up_stack = [
            upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
            upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
            upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
            upsample(512, 4), # (bs, 16, 16, 1024
            upsample(256, 4), # (bs, 32, 32, 512)
            upsample(128, 4), # (bs, 64, 64, 256)
            upsample(64, 4), # (bs, 128, 128, 128)
        ]

        self.last = tf.keras.layers.Conv2DTranspose(3, 4, strides=2,padding='same', kernel_initializer=tf.random_normal_initializer(0, 0.02), activation='tanh') # (bs, 256, 256, 3)

    def call(self, inputs, training=True):
        x = inputs

        skips = []
        for down in self.down_stack:
            x = down(x)
            skips.append(x)
        skips = reversed(skips[:-1])

        for up, skip in zip(self.up_stack, skips):
            x = up(x)
            x = tf.keras.layers.Concatenate()([x, skip])

        output = self.last(x)

        return output

In [196]:
LAMDA = 100
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def generator_loss(disc_generated_output, gen_output, target):
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

    # mean absolute error
    l1_loss = tf.reduce_mean(tf.abs(target-gen_output))
    total_gen_loss = gan_loss + (LAMDA * l1_loss)

    return total_gen_loss, gan_loss, l1_loss

In [193]:
# Discriminator
class Discriminator(tf.keras.Model):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.down_stack = [downsample(64, 4, False), downsample(128, 4), downsample(256, 4)]
        self.zero_pad1 = tf.keras.layers.ZeroPadding2D()
        self.zero_pad2 = tf.keras.layers.ZeroPadding2D()
        self.conv = tf.keras.layers.Conv2D(512, 4, strides=1, kernel_initializer=tf.random_normal_initializer(0, 0.02), use_bias=False)
        self.batchnorm = tf.keras.layers.BatchNormalization()
        self.leaky_relu = tf.keras.layers.LeakyReLU()
        self.last = tf.keras.layers.Conv2D(1, 4, strides=1, kernel_initializer=tf.random_normal_initializer(0, 0.02))

    def call(self, inputs):
        # inputs = input + target(concatenated)
        inp, tar = inputs
        x = tf.concat([inp, tar], 0)
        # x = tf.reshape(x, [-1, 286, 286, 6])
        for down in self.down_stack:
            x = down(x)
        x = self.conv(self.zero_pad1(x))
        x = self.zero_pad2(self.leaky_relu(self.batchnorm(x)))
        output = self.last(x)

        return output

In [185]:
inputs = [0, 1]

0


In [199]:
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    total_disc_loss = real_loss + generated_loss

    return total_disc_loss

In [50]:
import datetime
log_dir="logs/"

summary_writer = tf.summary.create_file_writer(
    log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [206]:
# training
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

generator = Generator()
discriminator = Discriminator()

@tf.function
def train_step(input_image, target, epoch):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)
        print("OK")
        # disc_input1 = tf.keras.layers.Concatenate()[input_image, target]
        # disc_input1 = tf.keras.layers.Concatenate()[input_image, gen_output]

        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)
        print(disc_real_output.shape)
        print(disc_generated_output.shape)
        print("OK")
        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, generator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

    with summary_writer.as_default():
        tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)
        tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)
        tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)
        tf.summary.scalar('disc_loss', disc_loss, step=epoch)

In [167]:
def normalize(input_image, real_image):
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1

    return input_image, real_image

def resize(input_image, real_image, height, width):
    input_image = tf.image.resize(input_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    real_image = tf.image.resize(real_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    return input_image, real_image

def random_crop(input_image, real_image):
    stack_image = tf.stack([input_image, real_image], axis=0)
    print(stack_image)
    cropped_image = tf.image.random_crop(stack_image, size=[2, -1, 256, 256, 3])

    return cropped_image[0], cropped_image[1]

@tf.function
def random_jitter(input_image, real_image):
    input_image, real_image = resize(input_image, real_image, 286, 286)
    input_image, real_image = random_crop(input_image, real_image)

    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)

    input_image, real_image = normalize(input_image, real_image)
    return input_image, real_image

@tf.function
def aug_image(input_image, real_image):
    input_image, real_image = random_jitter(input_image, real_image)
    input_image, real_image = normalize(input_image, real_image)

    return input_image, real_image

In [182]:
def fit(train_ds, epochs, test_ds):
    for epoch in range(epochs):
        start = time.time()

        display.clear_output(wait=True)

        # for example_input, example_target in test_ds.take(1)

        for n, (input_image, target) in enumerate(train_ds.batch(8).shuffle(8)):
            print('.', end='')
            if (n+1) % 100 == 0:
                print()
            input_image, target = aug_image(input_image, target)
            train_step(input_image, target, epoch)
        print()

        # saving (checkpoint) the model every 20 epochs
        if (epoch + 1) % 20 == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)

        print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                        time.time()-start))
    checkpoint.save(file_prefix = checkpoint_prefix)

In [141]:
%load_ext tensorboard
%tensorboard --logdir {log_dir}
EPOCHS = 100

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Launching TensorBoard...

In [207]:
fit(train_dataset, EPOCHS, test_dataset)

.

OK
(8, 30, 30, 1)
(8, 30, 30, 1)
OK


ValueError: in converted code:

    <ipython-input-206-5c1d50a496b4>:28 train_step  *
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:444 apply_gradients
        kwargs={"name": name})
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py:1949 merge_call
        return self._merge_call(merge_fn, args, kwargs)
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py:1956 _merge_call
        return merge_fn(self._strategy, *args, **kwargs)
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:488 _distributed_apply
        var, apply_grad_to_update_var, args=(grad,), group=False))
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py:1543 update
        return self._update(var, fn, args, kwargs, group)
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py:2174 _update
        return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py:2180 _update_non_slot
        result = fn(*args, **kwargs)
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:470 apply_grad_to_update_var
        update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/adam.py:207 _resource_apply_dense
        use_locking=self._use_locking)
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/training/gen_training_ops.py:1436 resource_apply_adam
        use_nesterov=use_nesterov, name=name)
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/framework/op_def_library.py:742 _apply_op_helper
        attrs=attr_protos, op_def=op_def)
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py:595 _create_op_internal
        compute_device)
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:3322 _create_op_internal
        op_def=op_def)
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:1786 __init__
        control_input_ops)
    /home/naoki/anaconda3/envs/tensorflow-new/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:1622 _create_c_op
        raise ValueError(str(e))

    ValueError: Dimension 3 in both shapes must be equal, but are 1 and 512. Shapes are [4,4,512,1] and [4,4,512,512]. for 'Adam_1/Adam/update_10/ResourceApplyAdam' (op: 'ResourceApplyAdam') with input shapes: [], [], [], [], [], [], [], [], [], [4,4,512,512].


In [41]:
%matplotlib inline

plt.figure(figsize=(10, 8))
# for input, real in train_dataset.batch(1):
#     input_image = input.numpy().reshape((286, 286, 3))
#     cv2.imwrite("/home/naoki/input.png", input_image)
#     input_image = plt.imread("/home/naoki/input.png")
#     plt.subplot(1, 2, 1)
#     plt.imshow(input_image)

#     real_image = real.numpy().reshape((286, 286, 3))
#     cv2.imwrite("/home/naoki/real.png", real_image)
#     real_image = plt.imread("/home/naoki/real.png")
#     plt.subplot(1, 2, 2)
#     plt.imshow(real_image)
#     first_input_batch = next(iter(train_dataset))


<Figure size 720x576 with 0 Axes>

<Figure size 720x576 with 0 Axes>