# Import Libraries

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.utils import plot_model

from IPython import display

# Load Data

In [None]:
# Fashion-MNIST is returned as ((x_train, y_train), (x_test, y_test)).
train_split, test_split = fashion_mnist.load_data()
train_data, train_labels = train_split
test_data, test_labels = test_split
# Drop the tuple handles so they do not keep the raw arrays alive later.
del train_split, test_split

In [None]:
# Binarize the training images: pixels below ~33% intensity (0.33 * 256)
# become 0.0 and the rest 1.0. A direct comparison produces a compact bool
# array instead of the full int64 intermediate that np.where(..., 0, 1)
# would allocate. A trailing channel axis is appended so the images have
# the (height, width, 1) layout expected by a single-channel image model.
train_data = (train_data >= 0.33 * 256).astype(np.float32)[..., np.newaxis]

In [None]:
# Binarize the test images with the same ~33% threshold as the training
# set (0.0 below, 1.0 at or above), avoiding an int64 intermediate, and
# append a trailing channel axis to get a (height, width, 1) layout.
test_data = (test_data >= 0.33 * 256).astype(np.float32)[..., np.newaxis]

# Model

In [None]:
# Optimisation settings.
BATCH_SIZE = 128
EPOCHS = 50
LEARNING_RATE = 5e-4

# Dataset description: Fashion-MNIST has 10 classes of 28x28 images.
# NUM_CLASSES is not referenced by the unconditional PixelCNN in this notebook.
NUM_CLASSES = 10
IMAGE_SIZE = 28

In [None]:
# Autoregressive PixelCNN distribution over single-channel images. The image
# shape follows IMAGE_SIZE so it matches the 28x28 Fashion-MNIST data (a
# hard-coded 32x32 shape would not fit the unpadded images). The training
# images are binarized to {0, 1}, so the pixel range is set to [0, 1] rather
# than the 8-bit default of [0, 255].
dist = tfp.distributions.PixelCNN(
    image_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
    num_resnet=1,
    num_hierarchies=2,
    num_filters=32,
    num_logistic_mix=5,
    dropout_p=0.3,
    high=1,
    low=0,
)

# Wrap the distribution's log-likelihood in a Keras model so it can be
# trained with `fit`: the model's output is one log-probability per image,
# and the loss is the negative mean log-likelihood attached via add_loss.
input_layer = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 1))
log_prob = dist.log_prob(input_layer)
model = models.Model(inputs=input_layer, outputs=log_prob)
model.add_loss(-tf.reduce_mean(log_prob))
# No compiled loss: the objective comes entirely from add_loss above.
model.compile(optimizer=optimizers.Adam(learning_rate=LEARNING_RATE))

In [None]:
# Print the model's layers, output shapes and parameter counts.
model.summary()

In [None]:
# Render the architecture graph with shapes and names, expanding nested
# layers. Left as the cell's last expression so the notebook displays it.
# NOTE: plot_model requires pydot and Graphviz to be installed.
plot_model(
    model,
    show_shapes=True,
    show_layer_names=True,
    expand_nested=True,
)

# Train

In [None]:
# Write training/validation metrics under ./logs for inspection in TensorBoard.
tensorboard_callback = callbacks.TensorBoard(
    log_dir="./logs",
)

In [None]:
# Train by maximum likelihood. The loss is attached with add_loss and no
# compiled loss is used, so the model needs inputs only; the images are not
# repeated as targets. validation_data is given in the documented tuple
# form, here a 1-tuple (x_val,) because there are no targets.
history = model.fit(
    train_data,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(test_data,),
    callbacks=[tensorboard_callback],
)

# Results

In [None]:
# Tabulate the per-epoch metrics recorded by `fit` and preview the first rows.
history_df = pd.DataFrame.from_dict(history.history)
history_df.head(5)

In [None]:
# Plot training and validation loss against epoch on a fresh figure.
fig, ax = plt.subplots()
ax.plot(history.history["loss"], label="train")
ax.plot(history.history["val_loss"], label="valid")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
ax.set_title("Loss Curve")
plt.show()

# Generate

In [None]:
def generate_and_plot_images(batch=8, distribution=None):
    """Sample images from the PixelCNN distribution and plot them in one row.

    The Keras model built from the distribution outputs a single
    log-likelihood per image, not per-pixel probabilities, so it cannot be
    used to sample pixel by pixel. Images are drawn with the distribution's
    own autoregressive ``sample`` method instead.

    Args:
        batch: Number of images to sample and plot; must be at least 1.
        distribution: PixelCNN distribution to sample from. Defaults to the
            notebook-level ``dist`` that the model was built from.

    Raises:
        ValueError: If ``batch`` is less than 1.
    """
    if batch < 1:
        raise ValueError(f"batch must be at least 1, got {batch}")
    if distribution is None:
        distribution = dist

    # One independent sample per image, shaped (batch, height, width, 1).
    samples = distribution.sample(batch).numpy()

    # squeeze=False keeps a 2-D axes array so indexing works when batch == 1.
    fig, axes = plt.subplots(1, batch, figsize=(12, 12), squeeze=False)
    for ax, image in zip(axes[0], samples):
        # Training images are binary, so render on a fixed [0, 1] scale and
        # clip so any sampled value >= 1 shows as white.
        pixels = np.clip(np.squeeze(image, axis=-1), 0, 1)
        ax.imshow(pixels, cmap="gray", vmin=0, vmax=1)
        ax.axis("off")

    plt.show()

In [None]:
# Draw and display a row of eight generated images.
generate_and_plot_images(8)