<a href="https://colab.research.google.com/github/purohik/notebooks/blob/main/GANs/Basic%20GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Initial setup

In [None]:
import tensorflow.compat.v1 as tf
!pip install tensorflow-gan
import tensorflow_gan as tfgan
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline
tf.logging.set_verbosity(tf.logging.ERROR)

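# Input pipeline

`input_fn` yields only batches of random Gaussian noise in `PREDICT` mode; in every other mode it yields `(noise, images)` pairs, with MNIST pixel values rescaled from $[0, 255]$ to $[-1, 1]$.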

In [None]:
def input_fn(mode, params):
    """Build the tf.data pipeline of noise (and MNIST images) for the GAN.

    Args:
        mode: a `tf.estimator.ModeKeys` value. TRAIN reads the shuffled
            'train' split; any other mode reads the unshuffled 'test' split.
            PREDICT returns only the noise dataset.
        params: dict that must contain 'batch_size' and 'noise_dims'.

    Returns:
        In PREDICT mode, an infinite dataset of noise batches of shape
        [batch_size, noise_dims]. Otherwise, an infinite dataset of
        `(noise_batch, image_batch)` pairs, where images are rescaled from
        [0, 255] to [-1, 1] and batches are always exactly `batch_size` long.

    Raises:
        AssertionError: if a required key is missing from `params`.
    """
    assert 'batch_size' in params, "params must contain 'batch_size'"
    assert 'noise_dims' in params, "params must contain 'noise_dims'"
    bs = params['batch_size']
    nd = params['noise_dims']
    split = 'train' if mode == tf.estimator.ModeKeys.TRAIN else 'test'
    shuffle = (mode == tf.estimator.ModeKeys.TRAIN)
    just_noise = (mode == tf.estimator.ModeKeys.PREDICT)

    # Endless stream of fresh Gaussian noise batches. The repeated dummy 0
    # tensor only gives map() something to iterate over; its value is ignored.
    noise_ds = (tf.data.Dataset.from_tensors(0).repeat().map(
        lambda _: tf.random.normal([bs, nd])
    ))

    if just_noise:
        return noise_ds

    def _preprocess(element):
        """Rescale pixels from the (assumed) [0, 255] range to float32 [-1, 1]."""
        images = (tf.cast(element['image'], tf.float32) - 127.5) / 127.5
        return images

    # Parallel map keeps element order (deterministic by default) but lets
    # tf.data decode/rescale several examples at once; cache() then stores the
    # preprocessed images so later epochs skip this work entirely.
    images_ds = (tfds.load('mnist:3.*.*', split=split)
                 .map(_preprocess,
                      num_parallel_calls=tf.data.experimental.AUTOTUNE)
                 .cache()
                 .repeat())

    if shuffle:
        # Shuffling after repeat() means a shuffle buffer may mix examples
        # from adjacent epochs; acceptable for an infinite training stream.
        images_ds = images_ds.shuffle(
            buffer_size=10000, reshuffle_each_iteration=True
        )
    # drop_remainder keeps a static batch dimension of exactly `bs`.
    images_ds = (images_ds.batch(bs, drop_remainder=True)
                 .prefetch(tf.data.experimental.AUTOTUNE))

    return tf.data.Dataset.zip((noise_ds, images_ds))

# Download the data and sanity-check the inputs

In [None]:
# Pull one training batch through the pipeline and show it as an image grid
# to confirm the inputs look like digits.
GRID_SHAPE = (10, 10)
# Derive the batch size from the grid: python_image_grid expects exactly
# rows * cols images, so the two must not drift apart.
params = {'batch_size': GRID_SHAPE[0] * GRID_SHAPE[1], 'noise_dims': 64}
with tf.Graph().as_default():
    ds = input_fn(tf.estimator.ModeKeys.TRAIN, params)
    # Index [1] selects the image batch from the (noise, images) pair.
    numpy_imgs = next(iter(tfds.as_numpy(ds)))[1]

img_grid = tfgan.eval.python_image_grid(numpy_imgs, grid_shape=GRID_SHAPE)
fig, ax = plt.subplots()
ax.axis('off')
ax.set_title('Sample of preprocessed MNIST training images')
# np.squeeze drops singleton axes (e.g. a trailing channel axis) so imshow
# receives a 2-D array; without cmap='gray' a 2-D array is drawn with the
# default (colour) colour map.
ax.imshow(np.squeeze(img_grid), cmap='gray')
plt.show()