In [None]:
import tensorflow.keras as keras
import tensorflow as tf

from tensorflow.keras.datasets import mnist
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Input, Flatten,\
                                    Reshape, LeakyReLU as LR,\
                                    Activation, Dropout
from tensorflow.keras.models import Model, Sequential
from matplotlib import pyplot as plt
from IPython import display # If using IPython, Colab or Jupyter
import numpy as np
import tensorflow_addons as tfa
import datetime
import random

In [None]:
# Load Fashion-MNIST. Named `fashion_mnist` so it no longer shadows the
# `mnist` module imported from tensorflow.keras.datasets in the import cell.
fashion_mnist = tf.keras.datasets.fashion_mnist


(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Scale pixels to [0, 1] as float32. Dividing the raw integer pixel arrays by
# a Python float would give float64, which does not match the tf.float32
# `output_signature` used by the tf.data generators and doubles memory use.
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0

In [None]:
def print_validation(fn, images=None, labels=None, rows=3, cols=3, seed=10):
    """Plot a rows x cols grid of `fn` applied to randomly chosen images.

    Args:
        fn: callable receiving a single-image batch of shape (1, H, W, 1) and
            returning something `imshow` can display (e.g. the first element
            of a model's prediction for that batch).
        images: image array of shape (N, H, W) or (N, H, W, 1); defaults to
            the global `x_test`.
        labels: array of N labels shown as subplot titles; defaults to the
            global `y_test`.
        rows, cols: grid size.
        seed: seed for a local random generator, so every call shows the same
            samples and successive calls (e.g. from a training callback) are
            directly comparable.
    """
    if images is None:
        images = x_test
    if labels is None:
        labels = y_test

    # Local seeded Generator: seeding Python's `random` module has no effect on
    # `np.random`, and a local generator also leaves global RNG state untouched.
    rng = np.random.default_rng(seed)
    indices = rng.integers(0, len(images), rows * cols)

    height, width = images.shape[1:3]
    samples = images[indices].reshape((rows, cols, 1, height, width, 1))
    sample_labels = labels[indices].reshape((rows, cols))

    # squeeze=False keeps `axs` 2-D even for a 1x1 or 1xN grid.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    for i in range(rows):
        for j in range(cols):
            predicted = fn(samples[i, j])
            ax = axs[i, j]
            ax.set_title(sample_labels[i, j])
            ax.imshow(predicted, cmap="gray")
            ax.axis("off")

    plt.subplots_adjust(wspace=0, hspace=0.5)
    plt.show()

# Identity "model": show the raw test images themselves.
print_validation(lambda x: x[0])

In [None]:
# Sanity check: display the first training image in grayscale.
fig, ax = plt.subplots()
ax.imshow(x_train[0], cmap="gray")
plt.show()

In [None]:
LATENT_SIZE: int = 32  # presumably the intended bottleneck width; NOTE(review): not referenced by the model code below

In [None]:
# Shared training objects: `opt` is the optimizer handed to `model.compile`,
# `cross_entropy` scores reconstructions in the `SkMetrics` callback.
opt = tf.keras.optimizers.Adam()
cross_entropy = tf.keras.losses.BinaryCrossentropy()

def get_conv(filters):
    """Return a conv block: spectrally-normalised 3x3 'same' Conv2D + LeakyReLU.

    The convolution is bias-free. The block is a Sequential so it can be
    stacked as a single layer inside the encoder and decoder.
    """
    normalized_conv = tfa.layers.SpectralNormalization(
        layers.Conv2D(filters, 3, padding="same", use_bias=False)
    )
    return tf.keras.Sequential([normalized_conv, layers.LeakyReLU()])
    
def get_encoder():
    """Build the convolutional encoder (a Sequential named "encoder").

    Conv blocks with 16, 64 and 128 filters; a 2x2 'same' max-pool after each
    of the first two halves the spatial resolution twice.
    """
    return tf.keras.Sequential(
        [
            get_conv(16),
            layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
            get_conv(64),
            layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
            get_conv(128),
        ],
        name="encoder",
    )

def get_decoder():
    """Build the convolutional decoder (a Sequential named "decoder").

    Two bilinear 2x upsamplings, each followed by a conv block (64, then 16
    filters), and a final single-channel 3x3 'same' convolution with tanh.
    """
    stack = [
        layers.UpSampling2D((2, 2), interpolation='bilinear'),
        get_conv(64),
        layers.UpSampling2D((2, 2), interpolation='bilinear'),
        get_conv(16),
        layers.Conv2D(1, 3, padding='same', activation='tanh'),
    ]
    return tf.keras.Sequential(stack, name="decoder")

class AutoEncoder(tf.keras.Model):
  """Convolutional autoencoder: input dropout -> encoder -> decoder -> output conv.

  Args:
    latent_size: optional width of a dense bottleneck inserted between the
      encoder and decoder. The default ``None`` keeps the previous behaviour:
      the encoder output is flattened and immediately reshaped back, so there
      is no compression beyond the encoder's feature map.
  """

  def __init__(self, latent_size=None):
    super(AutoEncoder, self).__init__()
    self.latent_size = latent_size

  def build(self, input_shape):
    self.encoder = get_encoder()
    self.decoder = get_decoder()

    # Build the encoder now so its output shape is known and the Reshape
    # layer can restore it after flattening.
    self.encoder.build(input_shape=input_shape)
    encoded_shape = self.encoder.output_shape  # (batch, height, width, channels)

    self.flatten = layers.Flatten()
    if self.latent_size is None:
      self.bottleneck = None
    else:
      # Compress to `latent_size` units, then expand back to the flattened
      # feature-map size so the decoder sees the same shape either way.
      self.bottleneck = tf.keras.Sequential(
          [
              layers.Dense(self.latent_size),
              layers.Dense(int(np.prod(encoded_shape[1:]))),
          ],
          name="bottleneck",
      )
    self.reshape = layers.Reshape(tuple(encoded_shape[1:]))

    # NOTE(review): the decoder already ends in a 1-filter tanh conv; this
    # second one is kept to preserve the existing architecture. Inputs are
    # scaled to [0, 1] while tanh outputs [-1, 1] — confirm tanh is intended.
    self.last = layers.Conv2D(1, 3, padding='same', activation='tanh')
    self.inputs_dropout = layers.Dropout(0.2)

  def call(self, inputs, training=None):
    # Forward `training` explicitly so dropout and nested models agree on mode.
    x = self.inputs_dropout(inputs, training=training)
    x = self.encoder(x, training=training)
    x = self.flatten(x)
    if self.bottleneck is not None:
      x = self.bottleneck(x, training=training)
    x = self.reshape(x)
    x = self.decoder(x, training=training)
    return self.last(x)

In [None]:
# Shape sanity check: turn one training image into a (1, H, W, 1) batch and
# push it through a throwaway encoder/decoder pair.
y = tf.expand_dims(x_train[0], axis=0)  # add batch axis
y = tf.expand_dims(y, axis=3)           # add channel axis

print(y.shape)

enc = get_encoder()(y)
print(enc.shape)

dec = get_decoder()(enc)

print(dec.shape)

# The training loss compares reconstructions with inputs element-wise, so the
# encoder/decoder round trip must reproduce the input shape exactly.
assert dec.shape == y.shape, f"decoder output {dec.shape} != input {y.shape}"

In [None]:
# Debug side channel: `train_get` appends every label it yields, so
# `len(generated_images)` shows how many training samples the pipeline has
# pulled. NOTE(review): despite the name it holds labels, not images, and it
# keeps growing across `.repeat()` passes over the dataset.
generated_images = []


def train_get():
    """Yield training images as (28, 28, 1) float32 tensors, recording labels."""
    for x, y in zip(x_train, y_train):
        generated_images.append(y)
        # Cast so each element matches the float32 `output_signature`
        # regardless of the dtype the arrays were loaded with.
        yield tf.expand_dims(tf.cast(x, tf.float32), axis=2)


def test_get():
    """Yield test images as (28, 28, 1) float32 tensors."""
    for x in x_test:
        yield tf.expand_dims(tf.cast(x, tf.float32), axis=2)


output_signature = tf.TensorSpec(shape=(28, 28, 1), dtype=tf.float32)

BATCH_SIZE = 25
# Number of test images used as validation data during `fit`.
# NOTE(review): `take` is applied before `batch`, so this is 10 images (one
# partial batch), not 10 batches — confirm that is intended.
VALIDATION_SAMPLES = 10

ds_x_train = (
    tf.data.Dataset.from_generator(train_get, output_signature=output_signature)
    .map(lambda x: (x, x))  # autoencoder target is the input itself
    .batch(BATCH_SIZE)
)

ds_x_test = (
    tf.data.Dataset.from_generator(test_get, output_signature=output_signature)
    .map(lambda x: (x, x))
    .take(VALIDATION_SAMPLES)
    .batch(BATCH_SIZE)
)

In [None]:
class MyModel(tf.keras.Model):
  """Thin Keras wrapper whose forward pass is delegated to `autoencoder`."""

  def __init__(self, autoencoder):
    super().__init__()
    self.autoencoder = autoencoder

  def call(self, inputs, training=False):
    # Pass the mode through so training-only layers in the wrapped model
    # (e.g. dropout) behave the same as this wrapper.
    outputs = self.autoencoder(inputs, training=training)
    return outputs

# One TensorBoard log directory per run, keyed by the wall-clock start time.
current_time = f"{datetime.datetime.now():%Y%m%d-%H%M%S}"
train_log_dir = 'logs/minst/' + current_time
tboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=train_log_dir,
    write_graph=True,
    histogram_freq=1,
    update_freq="batch",
)

EPOCHS = 60

# Wrap the autoencoder in MyModel for compile/fit.
autoencoder = AutoEncoder()
model = MyModel(autoencoder)

class SkMetrics(keras.callbacks.Callback):
    """Periodically plot sample reconstructions and report a reconstruction loss.

    Every `every_n_batches` batches (batch index within the epoch, skipping
    batch 0), this clears the cell output, draws a grid of test images passed
    through the attached model via `print_validation`, and computes binary
    cross-entropy between those inputs and their reconstructions. Each value is
    also appended to `self.validation_loss`.

    Args:
        every_n_batches: positive interval between checks; defaults to the
            previously hard-coded 299.
    """

    def __init__(self, every_n_batches=299):
        super().__init__()
        if every_n_batches < 1:
            raise ValueError("every_n_batches must be >= 1")
        self.every_n_batches = every_n_batches
        self.validation_loss = []

    def on_train_begin(self, logs=None):
        self.validation_loss = []

    def on_batch_end(self, batch, logs=None):
        if batch == 0 or batch % self.every_n_batches != 0:
            return

        originals = []
        predicted = []

        def reconstruct(x):
            # Use the model this callback is attached to rather than a
            # notebook global, so the callback works with any model.
            result = self.model(x, training=False)
            originals.append(x)
            predicted.append(result)
            return result[0]

        display.clear_output()
        print_validation(reconstruct)

        # NOTE(review): training uses MSE and the model ends in tanh, while
        # this is binary cross-entropy — confirm this is the intended metric.
        loss = cross_entropy(originals, predicted)
        self.validation_loss.append(float(loss))
        # NOTE(review): tf.summary.scalar only records when a default summary
        # writer is active; this presumably relies on the TensorBoard
        # callback's writer during fit — confirm it shows up in TensorBoard.
        tf.summary.scalar('validation_loss', loss)
        print(f"Validation Loss: {loss}")
        
class CustomMSE(keras.losses.Loss):
    """Mean squared error that returns one loss value per sample.

    Keras `Loss` subclasses are expected to return per-sample losses from
    `call`; the base class then applies `sample_weight` and the configured
    reduction. Averaging over every non-batch axis (rather than over the whole
    batch) keeps the sum-over-batch-size result equal to the overall mean while
    letting sample weights take effect.
    """

    def __init__(self, name="custom_mse"):
        super().__init__(name=name)

    def call(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        # Align dtypes so e.g. float64 targets can be compared with float32 predictions.
        y_true = tf.cast(y_true, y_pred.dtype)
        squared_error = tf.square(y_true - y_pred)
        # Reduce over all axes except axis 0 (the batch axis); for rank <= 1
        # inputs the axis list is empty and values are returned unreduced.
        per_sample_axes = tf.range(1, tf.rank(squared_error))
        return tf.math.reduce_mean(squared_error, axis=per_sample_axes)


# Train on an infinite (repeated) stream; epoch boundaries come from
# STEPS_PER_EPOCH. Uses the EPOCHS config constant rather than a separate
# literal so changing EPOCHS actually changes training length.
# NOTE(review): this call previously hard-coded 55 epochs while EPOCHS was 60.
STEPS_PER_EPOCH = 500
model.compile(loss=keras.losses.MeanSquaredError(), optimizer=opt, metrics=["mse"])
history = model.fit(
    ds_x_train.repeat(),
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=ds_x_test,
    callbacks=[SkMetrics(), tboard_callback],
)


In [None]:
(len(generated_images), 5 * 5 * 2)  # debug: labels recorded by the training generator vs. 50