In [None]:
# `!export` runs in a throwaway subshell, so the variables never reach this
# kernel's environment. Set them in-process instead, before the TensorFlow
# import in the next cell, so the GPU/XLA runtime can see them when it initializes.
import os

os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/usr/local/cuda-11.8/"
os.environ["CUDA_DIR"] = "/usr/local/cuda-11.8/"
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

In [None]:
import tensorflow as tf
import tensorflow.keras as keras


# Report which GPUs TensorFlow can see (an empty list means it will run on CPU).
gpus = tf.config.list_physical_devices(device_type="GPU")
print(f"gpus={gpus}")

from tensorflow.keras.datasets import mnist
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Input, Flatten,\
                                    Reshape, LeakyReLU as LR,\
                                    Activation, Dropout
from tensorflow.keras.models import Model, Sequential
from matplotlib import pyplot as plt

import tensorflow_datasets as tfds
import os

In [None]:
from IPython import display # If using IPython, Colab or Jupyter
import numpy as np
import tensorflow_addons as tfa
import datetime
import random

In [None]:
!wget https://drive.google.com/uc?export=download&id=0B7EVK8r0v71pZjFTYXZWM3FlRnM

In [None]:
# COIL-100 from TensorFlow Datasets. `split` is given as a list, so `ds_img` is a
# list of datasets (iterated as such by `gen` below).
ds_img = tfds.load("coil100", split=["train"])

# Expected image geometry: height, width, colour channels.
IMG_H, IMG_W, IMG_CHANNELS = 128, 128, 3

# Tag embedded in model / log / output directory names.
EXPERIMENT_VERSION = "v2"

In [None]:
# NOTE(review): `images` is not used by any visible cell.
images = []


def gen():
    """Yield the raw `image` tensor of every example of every loaded split."""
    for split_ds in ds_img:
        yield from (example["image"] for example in split_ds)


# Each element is one uint8 image of shape (IMG_H, IMG_W, IMG_CHANNELS).
output_signature = tf.TensorSpec(shape=(IMG_H, IMG_W, IMG_CHANNELS), dtype=tf.uint8)
dataset = tf.data.Dataset.from_generator(gen, output_signature=output_signature)

In [None]:
def normalize(image):
    """Map pixel values in [0, 255] to float32 values in [-1, 1]."""
    return (tf.cast(image, tf.float32) - 127.5) / 127.5


def denormalize(image):
    """Inverse of `normalize`: map [-1, 1] back to float32 [0, 255], clipped to that range."""
    pixels = tf.cast(image, tf.float32) * 127.5 + 127.5
    return tf.clip_by_value(pixels, 0.0, 255.0)

In [None]:
# mnist = tf.keras.datasets.fashion_mnist


# (x_train, y_train), (x_test, y_test) = mnist.load_data()
# x_train = x_train/255.0
# x_test = x_test/255.0

TEST_SAMPLES = 25

# Hold out the first TEST_SAMPLES elements for validation; train on the rest.
x_test, x_train = dataset.take(TEST_SAMPLES), dataset.skip(TEST_SAMPLES)

In [None]:
# Fixed batch of 4 held-out images, used for the visual progress plots.
x_test_samples = np.stack([sample.numpy() for sample in x_test.take(4)])

In [None]:
def print_validation(fn=lambda x: x, save=False, path="./"):
    """Plot `x_test_samples` (top row) above `fn(x_test_samples)` (bottom row).

    Args:
        fn: maps the sample batch to a batch of images with the same shape whose
            values are pixel intensities in [0, 255] (e.g. model + `denormalize`).
        save: when True, also write the figure with `plt.savefig(path)`.
        path: destination handed to `plt.savefig`.
    """
    n_samples = len(x_test_samples)
    # Call fn exactly once: it may run a full model forward pass.
    outputs = np.asarray(fn(x_test_samples))
    grid = np.vstack((x_test_samples, outputs))

    n_rows = 2  # originals, then fn outputs
    fig = plt.figure(figsize=(n_samples * 2, n_rows * 2))
    for i, img in enumerate(grid):
        plt.subplot(n_rows, n_samples, i + 1)
        plt.imshow(np.clip(img, 0, 255).astype(np.uint8))
        plt.axis('off')
    plt.subplots_adjust(wspace=0, hspace=0.5)
    if save:
        plt.savefig(path)
    plt.show()
    # Release the figure; this is called repeatedly from a training callback.
    plt.close(fig)


print_validation(lambda x: denormalize(normalize(x)))

## Experiments

In [None]:
import json

# Experiment names are appended here by Experiment.save().
test_results_file = "./testresults.txt"

class Experiment:
    """One autoencoder configuration of the hyper-parameter sweep.

    Configured through chained setters, e.g.
    ``Experiment().set_auto_encoder_filters(enc, dec).with_denorm(True)``.
    """

    def __init__(self):
        self.encoder = ()            # encoder conv filter counts
        self.decoder = ()            # decoder transposed-conv filter counts
        self.use_denorm = False      # compute the training MSE on denormalized pixels
        self.all_attention = False   # add AttentionLayer blocks to encoder and decoder
        self.use_batch_norm = False  # BatchNormalization instead of SpectralNormalization

    def toJSON(self):
        """Serialize every attribute as indented JSON with sorted keys."""
        return json.dumps(self, default=lambda o: o.__dict__,
                          sort_keys=True, indent=4)

    def set_auto_encoder_filters(self, encoder, decoder):
        """Set encoder/decoder filter lists; returns self for chaining."""
        self.encoder = encoder
        self.decoder = decoder
        return self

    def with_denorm(self, use_denorm):
        """Set whether the loss is computed on denormalized pixels; returns self."""
        self.use_denorm = use_denorm
        return self

    def with_attention(self, all_a=True):
        """Enable/disable attention blocks; returns self."""
        self.all_attention = all_a
        return self

    def with_batch_norm(self, use_batch=False):
        """Enable/disable batch norm (note: the default argument disables it); returns self."""
        self.use_batch_norm = use_batch
        return self

    def get_name(self):
        """Identifier used in paths, e.g. ``16k32-8_idenorm_attn_spectral``."""
        enc = "k".join(map(str, self.encoder))
        dec = "k".join(map(str, self.decoder))
        return (enc + "-" + dec
                + ("_idenorm" if self.use_denorm else "_inorm")
                + ("_attn" if self.all_attention else "")
                + ("_batch_norm" if self.use_batch_norm else "_spectral"))

    def __repr__(self):
        return self.toJSON()

    def __str__(self):
        return self.toJSON()

    def create_model(self):
        """Instantiate the (unbuilt) AutoEncoder for this configuration."""
        return AutoEncoder(self.encoder, self.decoder,
                           self.all_attention, self.use_batch_norm)

    def save(self, path=None):
        """Append this experiment's name as one line to `path` (default: test_results_file)."""
        target = test_results_file if path is None else path
        # Context manager guarantees the file is closed even if the write fails.
        with open(target, "a") as f:
            f.write(self.get_name() + "\n")
  

# Sweep grid: each (encoder, decoder) filter layout crossed with every on/off
# combination of denormalized loss, attention and batch norm.
all_filters = [
    ([32, 64, 128, 256, 512], [256, 128, 64, 32]),
    ([16, 32, 64, 128, 512], [256, 64, 32, 16]),  # best
    ([4, 8, 32, 128, 512], [128, 32, 8, 4]),
]

# Start from an empty results file; Experiment.save appends to it after each run.
if os.path.exists(test_results_file):
    os.remove(test_results_file)

experiments = []
for encf, decf in all_filters:
    for use_den in (True, False):  # both equally good
        for with_att in (True, False):  # true best
            for with_batch_norm in (True, False):  # false best
                experiments.append(
                    Experiment()
                    .set_auto_encoder_filters(encf, decf)
                    .with_denorm(use_den)
                    .with_attention(with_att)
                    .with_batch_norm(with_batch_norm)
                )

[experiment.get_name() for experiment in experiments]

In [None]:
# NOTE(review): not referenced by any visible cell; AutoEncoder.call flattens the
# encoder output and immediately reshapes it back, with no latent projection.
LATENT_SIZE: int = 32

In [None]:
class AttentionLayer(tf.keras.layers.Layer):
  """Self-attention computed on an 8x spatially downscaled copy of the input.

  `call` shrinks the input with a stride-4 then a stride-2 conv. It then builds
  `filters`-channel 1x1 projections (f, g, h) and takes softmax attention over
  the downscaled positions. The result is upscaled with stride-4 then stride-2
  transposed convs. Output is ``gamma * attention + xc(inputs)``, where `xc` is
  a 1x1 projection of the full-resolution input and `gamma` is a trainable
  variable initialised to 1.

  The downscaling convs use 'valid' padding, so H and W must be multiples of 8
  for the residual add to line up.

  Args:
    filters: channel count of the projections and of the layer output.
    use_batch_norm: BatchNormalization (instead of SpectralNormalization) in the
      upscaling blocks, plus BatchNormalization on the output.
    activation: if truthy, apply LeakyReLU to the output.
    **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
  """
  def __init__(self, filters, use_batch_norm=False, activation=False, **kwargs):
    super(AttentionLayer, self).__init__(**kwargs)
    self.filters = filters
    self.activation = activation
    self.use_batch_norm = use_batch_norm

  def get_conv_t(self, filters, strides, use_batch_norm=False):
    """3x3 transposed conv with `strides`, BN or spectral norm, then LeakyReLU."""
    conv = tf.keras.Sequential()

    if use_batch_norm:
        conv.add(layers.Conv2DTranspose(filters, 3, strides=strides, padding="same", use_bias=False))
        conv.add(layers.BatchNormalization(momentum=0.3))
    else:
        conv.add(tfa.layers.SpectralNormalization(
            layers.Conv2DTranspose(filters, 3, strides=strides, padding="same", use_bias=False)
        ))

    conv.add(layers.LeakyReLU())
    return conv

  def build(self, input_shape):
    # Sub-layers are created in the same order as before so layer names are unchanged.
    self.in_shape = input_shape
    _, _, _, c = input_shape
    # Keras expects a Python int for `filters` (the previous tf.cast produced a tensor).
    self.downscale1 = layers.Conv2D(int(c * 4), 4, strides=4)
    self.downscale2 = layers.Conv2D(int(c * 2), 2, strides=2)
    # 1x1 projections: f and g form the attention scores, h the values, xc the residual path.
    self.fc = self._conv(self.filters)
    self.gc = self._conv(self.filters)
    self.hc = self._conv(self.filters)
    self.xc = self._conv(self.filters)
    self.upscale1 = self.get_conv_t(self.filters, 4, self.use_batch_norm)
    self.upscale2 = self.get_conv_t(self.filters, 2, self.use_batch_norm)

    self.gamma = tf.Variable([1.], name="gamma")
    self.act = layers.LeakyReLU()
    if self.use_batch_norm:
        self.norm = layers.BatchNormalization(momentum=0.3)

  def _conv(self, filters, kernel=1, strides=1):
    """Bias-free, 'same'-padded Conv2D; `kernel` defaults to 1 (a 1x1 projection)."""
    # `kernel` used to be ignored (always 1); all existing callers use the default.
    return layers.Conv2D(filters, kernel, strides=strides, padding="same", use_bias=False)

  def call(self, inputs):
    ch = self.filters

    # Attend over the 8x-downscaled map so the N x N score matrix stays small.
    xd = self.downscale1(inputs)
    xd = self.downscale2(xd)

    xd_shape = tf.shape(xd)
    b, hi, wi = xd_shape[0], xd_shape[1], xd_shape[2]
    n = hi * wi

    f = self.fc(xd)  # [b, hi, wi, ch]
    g = self.gc(xd)  # [b, hi, wi, ch]
    h = self.hc(xd)  # [b, hi, wi, ch]
    inputs = self.xc(inputs)  # residual path, projected to ch channels

    f = tf.reshape(f, [-1, n, ch])
    g = tf.reshape(g, [-1, n, ch])
    h = tf.reshape(h, [-1, n, ch])

    s = tf.matmul(g, f, transpose_b=True)  # [b, N, N]
    beta = tf.nn.softmax(s)  # attention map
    o = tf.matmul(beta, h)  # [b, N, ch]

    o = tf.reshape(o, shape=[b, hi, wi, ch])
    o = self.upscale1(o)
    o = self.upscale2(o)

    x = self.gamma * o + inputs

    if self.use_batch_norm:
        x = self.norm(x)

    if self.activation:
        x = self.act(x)

    return x

In [None]:
# NOTE(review): only referenced by commented-out training code below.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)


def get_conv(filters, use_attention=False, use_batch_norm=False):
    """Return a Sequential: one feature block followed by LeakyReLU.

    The feature block is an AttentionLayer when `use_attention` is set. Otherwise
    it is a bias-free 3x3 'same' Conv2D, followed by BatchNormalization when
    `use_batch_norm` is set, or wrapped in SpectralNormalization when it is not.
    """
    block = tf.keras.Sequential()
    if use_attention:
        block.add(AttentionLayer(filters, use_batch_norm=use_batch_norm))
    elif use_batch_norm:
        block.add(layers.Conv2D(filters, 3, padding="same", use_bias=False))
        block.add(layers.BatchNormalization(momentum=0.3))
    else:
        block.add(tfa.layers.SpectralNormalization(
            layers.Conv2D(filters, 3, padding="same", use_bias=False)))
    block.add(layers.LeakyReLU())
    return block

def get_conv_t(filters, use_batch_norm=False, strides=2):
    """Return a Sequential: 3x3 transposed conv, normalization, LeakyReLU.

    Args:
        filters: output channels of the transposed convolution.
        use_batch_norm: follow the conv with BatchNormalization; otherwise wrap it
            in SpectralNormalization.
        strides: upsampling factor (default 2, the value previously hard-coded).
    """
    conv = tf.keras.Sequential()
    deconv = layers.Conv2DTranspose(filters, 3, strides=strides, padding="same", use_bias=False)

    if use_batch_norm:
        conv.add(deconv)
        conv.add(layers.BatchNormalization(momentum=0.3))
    else:
        conv.add(tfa.layers.SpectralNormalization(deconv))

    conv.add(layers.LeakyReLU())
    return conv
    
def get_encoder(filters=(16, 64, 128, 256, 512), use_attention=False, use_batch_norm=False):
    """Build the encoder Sequential.

    Layout: a conv block with `filters[0]` channels, then for each later entry a
    conv block plus 2x2 max-pooling, so the spatial size shrinks by
    ``2 ** (len(filters) - 1)``. A final block with `filters[-1]` channels is an
    AttentionLayer when `use_attention` is set.
    """
    encoder = tf.keras.Sequential(name="encoder")
    # encoder.add(layers.GaussianDropout(0.2))
    # Pass use_batch_norm by keyword: positionally it bound to get_conv's
    # `use_attention`, turning the first block into an attention layer.
    encoder.add(get_conv(filters[0], use_batch_norm=use_batch_norm))
    for f in filters[1:]:
        encoder.add(get_conv(f, use_batch_norm=use_batch_norm))
        encoder.add(layers.MaxPooling2D(pool_size=(2, 2), padding='same'))

    encoder.add(get_conv(filters[-1], use_attention, use_batch_norm))
    return encoder

def get_decoder(filters=(256, 128, 64, 16), use_attention=False, use_batch_norm=False):
    """Build the decoder Sequential.

    One stride-2 transposed-conv block per entry of `filters`, so the spatial
    size grows by ``2 ** len(filters)``. If `use_attention` is set, an
    AttentionLayer with `filters[-1]` channels follows. The decoder ends with a
    3x3 tanh Conv2D to IMG_CHANNELS channels.
    """
    decoder = tf.keras.Sequential(name="decoder")

    for f in filters:
        decoder.add(get_conv_t(f, use_batch_norm=use_batch_norm))

    if use_attention:
        # Match the width of the last block; indexing by IMG_CHANNELS only
        # picked the last entry for 4-element filter lists.
        # NOTE(review): use_batch_norm is not forwarded here, unlike the
        # encoder's attention block — confirm this is intended.
        decoder.add(AttentionLayer(filters[-1]))
    decoder.add(layers.Conv2D(IMG_CHANNELS, 3, padding='same', activation='tanh'))
    return decoder

class AutoEncoder(tf.keras.Model):
  """Convolutional autoencoder assembled from get_encoder / get_decoder.

  Sub-layers are created lazily in `build`, once the input shape is known.
  """
  def __init__(self, encoder_conf, decoder_conf, use_attention, use_batch_norm=False):
    """Store the configuration only; no layers are created here.

    Args:
      encoder_conf: filter list passed to get_encoder.
      decoder_conf: filter list passed to get_decoder.
      use_attention: add attention blocks to encoder and decoder.
      use_batch_norm: BatchNormalization instead of SpectralNormalization.
    """
    super(AutoEncoder, self).__init__()
    self.encoder_conf = encoder_conf
    self.decoder_conf = decoder_conf
    self.use_attention = use_attention
    self.use_batch_norm = use_batch_norm
    
  def build(self, input_shape):
    """Create encoder, decoder and the head layers for `input_shape`."""
    self.encoder = get_encoder(self.encoder_conf, use_attention=self.use_attention, use_batch_norm=self.use_batch_norm)
    self.decoder = get_decoder(self.decoder_conf, use_attention=self.use_attention, use_batch_norm=self.use_batch_norm)
    
    # Build the encoder eagerly so its output shape is known for the Reshape below.
    self.encoder.build(input_shape=input_shape)
    
    # Encoder output shape, including the batch dimension at index 0.
    sh = self.encoder.output_shape
    
    # Flatten followed by Reshape(sh[1:]) restores the encoder output shape,
    # so together they are a shape round trip with no bottleneck.
    self.flatten = layers.Flatten()
    # NOTE(review): seq1 is created but never used in call().
    self.seq1 = layers.Dense(1024)
    self.reshape = layers.Reshape([*sh[1:]])
    
    # Output head. The decoder already ends with a tanh Conv2D, so this is a
    # second tanh conv on top of it.
    self.last = layers.Conv2D(IMG_CHANNELS, 3, padding='same', activation='tanh')
    # Dropout on the raw inputs; the training targets are the uncorrupted images.
    self.inputs_dropout = layers.Dropout(0.2)

    
  def call(self, inputs):
    """Reconstruct `inputs`: dropout -> encoder -> flatten/reshape -> decoder -> tanh conv."""
    x = self.inputs_dropout(inputs)
    x = self.encoder(x)
    x = self.flatten(x)
    x = self.reshape(x)
    x = self.decoder(x)
    
    # x_inputs = self.conv_input(inputs)
    
    # x = self.add([x * self.gamma, x_inputs])
    # x = self.attention(x)
    x = self.last(x)
    return x

In [None]:
# Shape sanity check: one normalized training image through an encoder/decoder
# pair must come back at the input resolution.
y = next(iter(x_train.map(normalize).batch(1)))
print(y.shape)

enc = get_encoder([16, 64, 128, 256, 512])(y)
print(enc.shape)

dec = get_decoder([256, 128, 64, 14])(enc)
print(dec.shape)

assert dec.shape == y.shape

In [None]:
# generated_images = []

# def train_get():
#    for x, y in zip(x_train, y_train):
#         x = tf.expand_dims(x, axis=2)
#         generated_images.append(y)
#         yield x

# def test_get():
#    for x in x_test:
#         x = tf.expand_dims(x, axis=2)
#         yield x

# output_signature=tf.TensorSpec(shape=(28, 28, 1), dtype=tf.float32)

BATCH_SIZE = 25


def _as_autoencoder_pairs(images, batch_size):
    """Normalize `images`, pair each with itself as (input, target), then batch and prefetch."""
    normalized = images.map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
    pairs = normalized.map(lambda x: (x, x))
    return pairs.batch(batch_size).prefetch(tf.data.AUTOTUNE)


def get_train_ds():
    """Training pipeline over `x_train`, in batches of BATCH_SIZE."""
    return _as_autoencoder_pairs(x_train, BATCH_SIZE)


def get_test_ds():
    """Validation pipeline over `x_test`, in batches of 5."""
    return _as_autoencoder_pairs(x_test, 5)

In [None]:
# %tensorboard --logdir logs/minst/

## Metrics

In [None]:
class SSIM(tf.keras.metrics.Metric):
  """SSIM of denormalized predictions vs. targets, averaged over batches.

  Each update computes
  ``(eps + sum SSIM(y_true, y_pred)) / (eps + sum SSIM(y_true, y_true))`` over
  the batch, with a 3x3 filter and max_val 255. This is roughly the batch-mean
  SSIM. `result()` is the weighted average of these per-batch values since the
  last reset.
  """

  def __init__(self, name='ssim', **kwargs):
    super(SSIM, self).__init__(name=name, **kwargs)
    self.ssim = self.add_weight(name='ssim', initializer='zeros')
    # Sum of batch weights, so result() averages instead of returning the last batch.
    self.total_weight = self.add_weight(name='total_weight', initializer='zeros')
    self.ep = 0.0000001

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = denormalize(y_true)
    y_pred = denormalize(y_pred)
    same = tf.math.reduce_sum(tf.image.ssim(y_true, y_true, 255.0, filter_size=3)) + self.ep
    values = (self.ep + tf.math.reduce_sum(tf.image.ssim(y_true, y_pred, 255.0, filter_size=3))) / same
    values = tf.cast(values, self.dtype)

    weight = tf.constant(1.0, dtype=self.dtype)
    if sample_weight is not None:
        # Collapse per-sample weights into one weight for this batch value.
        weight = tf.reduce_mean(tf.cast(sample_weight, self.dtype))
    # Accumulate (Keras metrics must aggregate across update_state calls).
    self.ssim.assign_add(values * weight)
    self.total_weight.assign_add(weight)

  def result(self):
    return tf.math.divide_no_nan(self.ssim, self.total_weight)

class SSIM_Multiscale(tf.keras.metrics.Metric):
  """Multiscale SSIM of denormalized predictions vs. targets, averaged over batches.

  Each update computes
  ``(eps + sum MS-SSIM(y_true, y_pred)) / (eps + sum MS-SSIM(y_true, y_true))``
  with max_val 255. The same `filter_size` is used for both terms; previously
  the reference used 3 and the prediction 4. `result()` is the weighted average
  of the per-batch values since the last reset.

  Args:
    filter_size: Gaussian filter size passed to tf.image.ssim_multiscale.
  """

  def __init__(self, name='ssim_ms', filter_size=3, **kwargs):
    super(SSIM_Multiscale, self).__init__(name=name, **kwargs)
    self.ssim_ms = self.add_weight(name='ssim_ms', initializer='zeros')
    # Weighted sum of the self-comparison denominators (diagnostic only).
    self.self_ssim_ms = self.add_weight(name='self_ssim_ms', initializer='zeros')
    self.total_weight = self.add_weight(name='total_weight', initializer='zeros')
    self.ep = 0.0000001
    self.filter_size = filter_size

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = denormalize(y_true)
    y_pred = denormalize(y_pred)
    same = tf.math.reduce_sum(
        tf.image.ssim_multiscale(y_true, y_true, 255.0, filter_size=self.filter_size)) + self.ep
    values = (self.ep + tf.math.reduce_sum(
        tf.image.ssim_multiscale(y_true, y_pred, 255.0, filter_size=self.filter_size))) / same
    values = tf.cast(values, self.dtype)

    weight = tf.constant(1.0, dtype=self.dtype)
    if sample_weight is not None:
        # Collapse per-sample weights into one weight for this batch value.
        weight = tf.reduce_mean(tf.cast(sample_weight, self.dtype))
    # Accumulate instead of overwriting, so result() covers the whole epoch.
    self.ssim_ms.assign_add(values * weight)
    self.self_ssim_ms.assign_add(tf.cast(same, self.dtype) * weight)
    self.total_weight.assign_add(weight)

  def result(self):
    return tf.math.divide_no_nan(self.ssim_ms, self.total_weight)

class TOP_SSIM_Multiscale(tf.keras.metrics.Metric):
  """Per-batch sum of MS-SSIM(y_true, y_true) on denormalized targets, averaged over batches.

  This is the upper bound the other SSIM metrics are normalized against. It
  uses a 3x3 filter and max_val 255.
  """

  def __init__(self, name='self_ssim_ms', **kwargs):
    super(TOP_SSIM_Multiscale, self).__init__(name=name, **kwargs)
    self.self_ssim_ms = self.add_weight(name='self_ssim_ms', initializer='zeros')
    self.total_weight = self.add_weight(name='total_weight', initializer='zeros')

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = denormalize(y_true)
    values = tf.math.reduce_sum(tf.image.ssim_multiscale(y_true, y_true, 255.0, filter_size=3))
    values = tf.cast(values, self.dtype)

    weight = tf.constant(1.0, dtype=self.dtype)
    if sample_weight is not None:
        # Collapse per-sample weights into one weight for this batch value.
        weight = tf.reduce_mean(tf.cast(sample_weight, self.dtype))
    # Accumulate instead of overwriting with the latest batch.
    self.self_ssim_ms.assign_add(values * weight)
    self.total_weight.assign_add(weight)

  def result(self):
    return tf.math.divide_no_nan(self.self_ssim_ms, self.total_weight)


class DenormalizedMSE(tf.keras.metrics.Metric):
  """MSE computed in [0, 255] pixel space (after `denormalize`), averaged over batches."""

  def __init__(self, name='denormalized_mse', **kwargs):
    super(DenormalizedMSE, self).__init__(name=name, **kwargs)
    self.denormalized_mse = self.add_weight(name='denormalized_mse', initializer='zeros')
    self.total_batches = self.add_weight(name='total_batches', initializer='zeros')
    # Create the loss once rather than on every update_state call.
    self.mse = keras.losses.MeanSquaredError()

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = denormalize(y_true)
    y_pred = denormalize(y_pred)
    batch_mse = tf.cast(self.mse(y_true, y_pred, sample_weight=sample_weight), self.dtype)
    # Accumulate instead of overwriting with the latest batch.
    self.denormalized_mse.assign_add(batch_mse)
    self.total_batches.assign_add(1.0)

  def result(self):
    return tf.math.divide_no_nan(self.denormalized_mse, self.total_batches)

In [None]:
import shutil

# Start every experiment's TensorBoard run from an empty log directory.
for conf in experiments:
    model_name = f'coli100_{EXPERIMENT_VERSION}_{conf.get_name()}'
    train_log_dir = f'logs/minst/{model_name}'
    # ignore_errors=True also makes a missing directory a no-op, so no exists() check is needed.
    shutil.rmtree(train_log_dir, ignore_errors=True)


## Train

In [None]:
# class MyModel(tf.keras.Model):
#   def __init__(self, autoencoder):
#     super(MyModel, self).__init__()
#     self.autoencoder = autoencoder

#   def call(self, inputs, training=False):
#     return self.autoencoder(inputs, training=training)

  # def validation_step(self, images):
  #   pass

  # def train_step(self, images):
  #   with tf.GradientTape() as auto_tape:
  #     generated = self.autoencoder(images)
  #     loss = cross_entropy(images, generated)
  #   gradients = auto_tape.gradient(loss, self.autoencoder.trainable_variables)
  #   opt.apply_gradients(zip(gradients, self.autoencoder.trainable_variables))

  #   return {"loss": loss}
    



for conf in experiments:
    model_name = f'coli100_{EXPERIMENT_VERSION}_{conf.get_name()}'
    train_log_dir = f'logs/minst/{model_name}'

    tboard_callback = tf.keras.callbacks.TensorBoard(log_dir=train_log_dir,
      write_graph=True,
      histogram_freq=1,
      update_freq="batch"
      )

    EPOCHS = 60
    STEPS_PER_EPOCH = 100

    denorm = conf.use_denorm

    model = conf.create_model()


    class SkMetrics(keras.callbacks.Callback):
        """Logs the experiment config to TensorBoard and saves a reconstruction preview per epoch."""

        def on_train_begin(self, logs=None):
            # tf.summary ops are no-ops without a default writer, and text() needs
            # an explicit step, so write the description through a dedicated writer.
            writer = tf.summary.create_file_writer(train_log_dir)
            with writer.as_default():
                tf.summary.text("description", str(conf), step=0)
            writer.close()
            self.validation_loss = []
            self.epoch_n = 0

        def on_epoch_begin(self, epoch, logs=None):
            # Set at the start of the epoch: setting it in on_epoch_end labelled
            # previews with the previous epoch and overwrote the first one.
            self.epoch_n = epoch

        def on_batch_end(self, batch, logs=None):
            def expand_and_predict(x):
                result = model(x, training=False)
                return denormalize(result)
            # Preview on the last batch of every epoch.
            if batch == STEPS_PER_EPOCH - 1:
                display.clear_output(wait=True)
                p = f'./minst_output/{model_name}/'
                os.makedirs(p, exist_ok=True)
                full_path = f'./minst_output/{model_name}/{self.epoch_n}_{batch}'
                print_validation(expand_and_predict, save=True, path=full_path)
                print(conf.get_name())


    class CustomMSE(keras.losses.Loss):
        """MSE, optionally computed on denormalized [0, 255] pixels."""

        def __init__(self, denorm=True, name="custom_mse"):
            super().__init__(name=name)
            self.mse = keras.losses.MeanSquaredError()
            self.denorm = denorm

        def call(self, y_true, y_pred):
            # Use the instance flag, not the loop-level `denorm` global.
            if self.denorm:
                return self.mse(denormalize(y_true), denormalize(y_pred))
            return self.mse(y_true, y_pred)

    opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
    # For debugging, pass run_eagerly=True to compile().
    model.compile(loss=CustomMSE(denorm=denorm), optimizer=opt,
                  metrics=["mse", DenormalizedMSE(), SSIM_Multiscale(), SSIM()])
    print(conf.get_name())
    history = model.fit(get_train_ds().repeat(), epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH,
                        validation_data=get_test_ds(), callbacks=[SkMetrics(), tboard_callback])
    conf.save()


In [None]:
# `generated_images` was only populated by the commented-out `train_get` generator,
# so look it up defensively instead of raising NameError on Restart & Run All.
len(globals().get("generated_images", [])), 5*5*2

In [None]:
import imageio
import pathlib

def _frame_order(path):
    """Sort key for '<epoch>_<batch>.png' frames, so epoch 10 sorts after epoch 9.

    Stems that are not all '_'-separated integers sort after the numeric ones, by name.
    """
    try:
        return (0, tuple(int(part) for part in path.stem.split("_")), "")
    except ValueError:
        return (1, (), path.name)


# Stitch each experiment's saved preview frames into one GIF, in training order.
for conf in experiments:
    model_name = f'./minst_output/coli100_{EXPERIMENT_VERSION}_{conf.get_name()}'
    data_dir = pathlib.Path(model_name)
    # Plain sorted() is lexicographic and would put "10_99" before "1_99".
    pictures = sorted(data_dir.glob('*.png'), key=_frame_order)
    if not pictures:
        # No previews were saved for this experiment; don't write an empty GIF.
        continue
    with imageio.get_writer(f'{model_name}.gif', mode='I') as writer:
        for filename in pictures:
            writer.append_data(imageio.imread(filename))