## README.md

In [None]:
%%writefile README.md
Implementation of [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661).

## input.py

In [None]:
%%writefile vanilla_gan_module/trainer/input.py
import tensorflow as tf


def preprocess_image(image, params):
    """Rescales raw pixel values into the range used by the networks.

    The transform is element-wise, so any input shape works; `decode_example`
    calls it on a single [height, width, depth] image.

    Args:
        image: tensor of raw pixel values in [0, 255].
        params: dict, user passed parameters (not used here).

    Returns:
        Float32 tensor with the same shape as `image`, with 0 mapped to -1.0
        and 255 mapped to 1.0.
    """
    float_image = tf.cast(x=image, dtype=tf.float32)

    # Linear map [0, 255] -> [0, 2], then shift down to [-1, 1].
    scaled_image = float_image * (2. / 255)

    return scaled_image - 1.0


def decode_example(protos, params):
    """Decodes one serialized tf.Example proto into image and label tensors.

    Args:
        protos: scalar string tensor, a single serialized tf.Example read
            from a TFRecord file. It must contain a bytes feature
            "image_raw" (raw uint8 pixels) and an int64 feature "label".
        params: dict, user passed parameters. Reads "height", "width" and
            "depth" to restore the image shape.

    Returns:
        Tuple of (features dict, label). The features dict is
        {"image": float32 tensor of shape [height, width, depth] with values
        in [-1.0, 1.0]}; the label is an int32 scalar tensor.
    """
    # Create feature schema map for protos.
    features = {
        "image_raw": tf.io.FixedLenFeature(shape=[], dtype=tf.string),
        "label": tf.io.FixedLenFeature(shape=[], dtype=tf.int64)
    }

    # Parse features from tf.Example.
    parsed_features = tf.io.parse_single_example(
        serialized=protos, features=features
    )

    # Convert from a scalar string tensor (whose single string has
    # length height * width * depth) to a uint8 tensor with shape
    # [height * width * depth].
    image = tf.io.decode_raw(
        input_bytes=parsed_features["image_raw"], out_type=tf.uint8
    )

    # Reshape flattened image back into normal dimensions.
    image = tf.reshape(
        tensor=image,
        shape=[params["height"], params["width"], params["depth"]]
    )

    # Preprocess image: uint8 [0, 255] -> float32 [-1.0, 1.0].
    image = preprocess_image(image=image, params=params)

    # Convert label from the parsed int64 scalar (see schema above) to an
    # int32 scalar.
    label = tf.cast(x=parsed_features["label"], dtype=tf.int32)

    return {"image": image}, label


def read_dataset(filename, batch_size, params, training):
    """Reads TF Record data using tf.data, doing necessary preprocessing.

    Given a file pattern, batch size, and other parameters, builds an input
    function that reads the TFRecord dataset with the tf.data API, decodes
    and preprocesses every example, and batches the results.

    Args:
        filename: str, file pattern to read into our tf.data dataset.
        batch_size: int, number of examples per batch.
        params: dict, dictionary of user passed parameters. Reads
            "input_fn_autotune" (truthy enables tf.data autotuning) plus the
            keys used by `decode_example`.
        training: bool, if training or not. When True the dataset is
            shuffled and repeated indefinitely.

    Returns:
        An input function that takes no arguments and returns a batched
        `tf.data.Dataset` of (features dict, label) pairs.
    """
    def _input_fn():
        """Wrapper input function used to get data tensors.

        Returns:
            Batched dataset object of dictionary of feature tensors and label
                tensor.
        """
        # `tf.contrib` no longer exists in TF 2.x, so the autotune sentinel
        # has to come from `tf.data.experimental`.
        autotune = params["input_fn_autotune"]
        parallelism = tf.data.experimental.AUTOTUNE if autotune else None

        # Create list of files that match pattern.
        file_list = tf.data.Dataset.list_files(file_pattern=filename)

        # Create dataset from file list.
        dataset = tf.data.TFRecordDataset(
            filenames=file_list, num_parallel_reads=parallelism
        )

        # Shuffle and repeat indefinitely if training. This is the documented
        # replacement for the deprecated fused
        # `tf.data.experimental.shuffle_and_repeat`.
        if training:
            dataset = dataset.shuffle(buffer_size=50 * batch_size)
            dataset = dataset.repeat(count=None)

        # Decode serialized tf.Example protos into a features dictionary of
        # tensors, then batch. This replaces the deprecated fused
        # `tf.data.experimental.map_and_batch`.
        dataset = dataset.map(
            map_func=lambda x: decode_example(protos=x, params=params),
            num_parallel_calls=parallelism
        )
        dataset = dataset.batch(batch_size=batch_size)

        # Prefetch data to improve latency.
        dataset = dataset.prefetch(
            buffer_size=tf.data.experimental.AUTOTUNE if autotune else 1
        )

        return dataset
    return _input_fn


## generators.py

In [None]:
%%writefile vanilla_gan_module/trainer/generators.py
import tensorflow as tf


class Generator(object):
    """Generator that takes latent vector input and outputs image.

    The network is a stack of Dense + LeakyReLU hidden layers followed by a
    Dense output layer that emits a flattened image of
    height * width * depth values.

    Fields:
        name: str, name of `Generator`.
        model: instance of generator `Model`.
    """
    def __init__(
            self,
            input_shape,
            kernel_regularizer,
            bias_regularizer,
            name,
            params):
        """Instantiates and builds generator network.

        Args:
            input_shape: tuple, shape of one latent vector excluding the batch
                dimension (it is passed to `tf.keras.Input(shape=...)`);
                presumably (latent_size,).
            kernel_regularizer: Keras regularizer (or None) applied to kernel
                variables.
            bias_regularizer: Keras regularizer (or None) applied to bias
                variables.
            name: str, name of generator.
            params: dict, user passed parameters.
        """
        # Set name of generator.
        self.name = name

        # Instantiate generator `Model`.
        self.model = self._define_generator(
            input_shape, kernel_regularizer, bias_regularizer, params
        )

    def _define_generator(
            self, input_shape, kernel_regularizer, bias_regularizer, params):
        """Defines generator network.

        Args:
            input_shape: tuple, shape of one latent vector excluding the batch
                dimension.
            kernel_regularizer: Keras regularizer (or None) applied to kernel
                variables.
            bias_regularizer: Keras regularizer (or None) applied to bias
                variables.
            params: dict, user passed parameters. Reads
                "generator_hidden_units", "generator_leaky_relu_alpha",
                "generator_final_activation", "height", "width" and "depth".

        Returns:
            Instance of `Model` object whose output has shape
            [batch_size, height * width * depth].
        """
        # Create the input layer to our DNN.
        # shape = (batch_size, latent_size)
        inputs = tf.keras.Input(
            shape=input_shape, name="{}_inputs".format(self.name)
        )
        network = inputs

        # Set of supported final activations; anything else means linear.
        final_activation_set = {"sigmoid", "relu", "tanh"}

        # Add hidden layers with given number of units/neurons per layer.
        for i, units in enumerate(params["generator_hidden_units"]):
            # shape = (batch_size, generator_hidden_units[i])
            network = tf.keras.layers.Dense(
                units=units,
                activation=None,
                kernel_regularizer=kernel_regularizer,
                bias_regularizer=bias_regularizer,
                name="{}_layers_dense_{}".format(self.name, i)
            )(inputs=network)

            network = tf.keras.layers.LeakyReLU(
                alpha=params["generator_leaky_relu_alpha"],
                name="{}_leaky_relu_{}".format(self.name, i)
            )(inputs=network)

        # Final layer for outputs; activation is applied only if supported.
        # shape = (batch_size, height * width * depth)
        generated_outputs = tf.keras.layers.Dense(
            units=params["height"] * params["width"] * params["depth"],
            activation=(
                params["generator_final_activation"].lower()
                if params["generator_final_activation"].lower()
                in final_activation_set
                else None
            ),
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            name="{}_layers_dense_generated_outputs".format(self.name)
        )(inputs=network)

        # Define model.
        model = tf.keras.Model(
            inputs=inputs, outputs=generated_outputs, name=self.name
        )

        return model

    def get_model(self):
        """Returns generator's `Model` object.

        Returns:
            Generator's `Model` object.
        """
        return self.model

    def get_generator_loss(
        self,
        global_batch_size,
        fake_logits,
        params,
        global_step,
        summary_file_writer
    ):
        """Gets generator loss.

        The base loss is sigmoid cross-entropy against an all-ones target.
        The generator is rewarded when the discriminator scores its fakes as
        real. Regularization losses are added on top.

        Args:
            global_batch_size: int, global batch size passed to
                `tf.nn.compute_average_loss` to average per-example losses.
            fake_logits: tensor, discriminator logits for generated images,
                shape of [batch_size, 1].
            params: dict, user passed parameters. Reads
                "save_summary_steps".
            global_step: int, current global step for training.
            summary_file_writer: summary file writer.

        Returns:
            Tensor of generator's total loss of shape [].
        """
        # Calculate base generator loss.
        generator_loss = tf.nn.compute_average_loss(
            per_example_loss=tf.keras.losses.BinaryCrossentropy(
                from_logits=True,
                reduction=tf.keras.losses.Reduction.NONE
            )(
                y_true=tf.ones_like(input=fake_logits), y_pred=fake_logits
            ),
            global_batch_size=global_batch_size
        )

        # Get regularization losses. With no regularizers attached,
        # `sum([])` is the Python int 0, which is not a float tensor and
        # breaks `scale_regularization_loss` / the float32 add below. In
        # that case start from a float zero matching the base loss.
        reg_losses = self.model.losses
        generator_reg_loss = tf.nn.scale_regularization_loss(
            regularization_loss=(
                sum(reg_losses)
                if reg_losses
                else tf.zeros_like(input=generator_loss)
            )
        )

        # Combine losses for total losses.
        generator_total_loss = tf.math.add(
            x=generator_loss,
            y=generator_reg_loss,
            name="generator_total_loss"
        )

        # Add summaries for TensorBoard.
        with summary_file_writer.as_default():
            with tf.summary.record_if(
                global_step % params["save_summary_steps"] == 0
            ):
                tf.summary.scalar(
                    name="losses/generator_loss",
                    data=generator_loss,
                    step=global_step
                )
                tf.summary.scalar(
                    name="losses/generator_reg_loss",
                    data=generator_reg_loss,
                    step=global_step
                )
                tf.summary.scalar(
                    name="optimized_losses/generator_total_loss",
                    data=generator_total_loss,
                    step=global_step
                )
                summary_file_writer.flush()

        return generator_total_loss


## discriminators.py

In [None]:
%%writefile vanilla_gan_module/trainer/discriminators.py
import tensorflow as tf


class Discriminator(object):
    """Discriminator that takes image input and outputs logits.

    The network is a stack of Dense + LeakyReLU hidden layers followed by a
    single-unit linear Dense layer producing one logit per example.

    Fields:
        name: str, name of `Discriminator`.
        kernel_regularizer: regularizer applied to kernel variables.
        bias_regularizer: regularizer applied to bias variables.
        model: instance of discriminator `Model`.
    """
    def __init__(
            self,
            input_shape,
            kernel_regularizer,
            bias_regularizer,
            name,
            params):
        """Instantiates and builds discriminator network.

        Args:
            input_shape: tuple, shape of one flattened image excluding the
                batch dimension (it is passed to `tf.keras.Input(shape=...)`);
                presumably (height * width * depth,).
            kernel_regularizer: Keras regularizer (or None) applied to kernel
                variables.
            bias_regularizer: Keras regularizer (or None) applied to bias
                variables.
            name: str, name of discriminator.
            params: dict, user passed parameters.
        """
        # Set name of discriminator.
        self.name = name

        # Regularizer for kernel weights.
        self.kernel_regularizer = kernel_regularizer

        # Regularizer for bias weights.
        self.bias_regularizer = bias_regularizer

        # Instantiate discriminator `Model`.
        self.model = self._define_discriminator(
            input_shape, kernel_regularizer, bias_regularizer, params
        )

    def _define_discriminator(
            self, input_shape, kernel_regularizer, bias_regularizer, params):
        """Defines discriminator network.

        Args:
            input_shape: tuple, shape of one flattened image excluding the
                batch dimension.
            kernel_regularizer: Keras regularizer (or None) applied to kernel
                variables.
            bias_regularizer: Keras regularizer (or None) applied to bias
                variables.
            params: dict, user passed parameters. Reads
                "discriminator_hidden_units" and
                "discriminator_leaky_relu_alpha".

        Returns:
            Instance of `Model` object whose output has shape
            [batch_size, 1].
        """
        # Create the input layer to our DNN.
        # shape = (batch_size, height * width * depth)
        inputs = tf.keras.Input(
            shape=input_shape,
            name="{}_inputs".format(self.name)
        )
        network = inputs

        # Add hidden layers with given number of units/neurons per layer.
        for i, units in enumerate(params["discriminator_hidden_units"]):
            # shape = (batch_size, discriminator_hidden_units[i])
            network = tf.keras.layers.Dense(
                units=units,
                activation=None,
                kernel_regularizer=kernel_regularizer,
                bias_regularizer=bias_regularizer,
                name="{}_layers_dense_{}".format(self.name, i)
            )(inputs=network)

            network = tf.keras.layers.LeakyReLU(
                alpha=params["discriminator_leaky_relu_alpha"],
                name="{}_leaky_relu_{}".format(self.name, i)
            )(inputs=network)

        # Final linear layer for logits.
        # shape = (batch_size, 1)
        logits = tf.keras.layers.Dense(
            units=1,
            activation=None,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            name="{}_layers_dense_logits".format(self.name)
        )(inputs=network)

        # Define model.
        model = tf.keras.Model(
            inputs=inputs, outputs=logits, name=self.name
        )

        return model

    def get_model(self):
        """Returns discriminator's `Model` object.

        Returns:
            Discriminator's `Model` object.
        """
        return self.model

    def get_discriminator_loss(
        self,
        global_batch_size,
        fake_logits,
        real_logits,
        params,
        global_step,
        summary_file_writer
    ):
        """Gets discriminator loss.

        Real logits are scored against all-ones targets, with
        params["label_smoothing"] applied. Fake logits are scored against
        all-zeros targets. Regularization losses are added on top.

        Args:
            global_batch_size: int, global batch size passed to
                `tf.nn.compute_average_loss` to average per-example losses.
            fake_logits: tensor, discriminator logits for generated images,
                shape of [batch_size, 1].
            real_logits: tensor, discriminator logits for real images, shape
                of [batch_size, 1].
            params: dict, user passed parameters. Reads "label_smoothing" and
                "save_summary_steps".
            global_step: int, current global step for training.
            summary_file_writer: summary file writer.

        Returns:
            Tensor of discriminator's total loss of shape [].
        """
        # Calculate base discriminator loss.
        discriminator_real_loss = tf.nn.compute_average_loss(
            per_example_loss=tf.keras.losses.BinaryCrossentropy(
                from_logits=True,
                label_smoothing=params["label_smoothing"],
                reduction=tf.keras.losses.Reduction.NONE
            )(
                y_true=tf.ones_like(input=real_logits), y_pred=real_logits
            ),
            global_batch_size=global_batch_size
        )

        discriminator_fake_loss = tf.nn.compute_average_loss(
            per_example_loss=tf.keras.losses.BinaryCrossentropy(
                from_logits=True,
                reduction=tf.keras.losses.Reduction.NONE
            )(
                y_true=tf.zeros_like(input=fake_logits), y_pred=fake_logits
            ),
            global_batch_size=global_batch_size
        )

        discriminator_loss = tf.add(
            x=discriminator_real_loss,
            y=discriminator_fake_loss,
            name="discriminator_loss"
        )

        # Get regularization losses. With no regularizers attached,
        # `sum([])` is the Python int 0, which is not a float tensor and
        # breaks `scale_regularization_loss` / the float32 add below. In
        # that case start from a float zero matching the base loss.
        reg_losses = self.model.losses
        discriminator_reg_loss = tf.nn.scale_regularization_loss(
            regularization_loss=(
                sum(reg_losses)
                if reg_losses
                else tf.zeros_like(input=discriminator_loss)
            )
        )

        # Combine losses for total losses.
        discriminator_total_loss = tf.math.add(
            x=discriminator_loss,
            y=discriminator_reg_loss,
            name="discriminator_total_loss"
        )

        # Add summaries for TensorBoard.
        with summary_file_writer.as_default():
            with tf.summary.record_if(
                global_step % params["save_summary_steps"] == 0
            ):
                tf.summary.scalar(
                    name="losses/discriminator_real_loss",
                    data=discriminator_real_loss,
                    step=global_step
                )
                tf.summary.scalar(
                    name="losses/discriminator_fake_loss",
                    data=discriminator_fake_loss,
                    step=global_step
                )
                tf.summary.scalar(
                    name="losses/discriminator_loss",
                    data=discriminator_loss,
                    step=global_step
                )
                tf.summary.scalar(
                    name="losses/discriminator_reg_loss",
                    data=discriminator_reg_loss,
                    step=global_step
                )
                tf.summary.scalar(
                    name="optimized_losses/discriminator_total_loss",
                    data=discriminator_total_loss,
                    step=global_step
                )
                summary_file_writer.flush()

        return discriminator_total_loss


## train_and_eval.py

In [None]:
%%writefile vanilla_gan_module/trainer/train_and_eval.py
import tensorflow as tf


def generator_loss_phase(
    global_batch_size,
    generator,
    discriminator,
    params,
    global_step,
    summary_file_writer,
    mode,
    training
):
    """Runs the generator on fresh noise and scores its output.

    Samples standard-normal latent vectors and generates fake images. In
    "TRAIN" mode the images are logged to TensorBoard. The images are then
    passed through the discriminator (always with training=False), and the
    generator loss is computed from the resulting logits.

    Args:
        global_batch_size: int, global batch size for distribution.
        generator: instance of `Generator`.
        discriminator: instance of `Discriminator`.
        params: dict, user passed parameters.
        global_step: int, current global step for training.
        summary_file_writer: summary file writer.
        mode: str, "TRAIN" uses params["train_batch_size"] and enables the
            fake-image summary; any other value uses
            params["eval_batch_size"].
        training: bool, forwarded to the generator model call.

    Returns:
        Tuple of fake logits of shape [batch_size, 1] and the generator's
        total loss.
    """
    is_train_mode = mode == "TRAIN"
    if is_train_mode:
        batch_size = params["train_batch_size"]
    else:
        batch_size = params["eval_batch_size"]

    # One standard-normal latent vector per example in the batch.
    latent_vectors = tf.random.normal(
        shape=[batch_size, params["latent_size"]],
        mean=0.0,
        stddev=1.0,
        dtype=tf.float32
    )

    generator_model = generator.get_model()
    fake_images = generator_model(inputs=latent_vectors, training=training)

    if is_train_mode:
        with summary_file_writer.as_default():
            should_record = global_step % params["save_summary_steps"] == 0
            with tf.summary.record_if(should_record):
                # Generator output is flat; restore image layout to display.
                image_batch = tf.reshape(
                    tensor=fake_images,
                    shape=[
                        -1,
                        params["height"],
                        params["width"],
                        params["depth"]
                    ]
                )
                tf.summary.image(
                    name="fake_images",
                    data=image_batch,
                    step=global_step,
                    max_outputs=5,
                )
                summary_file_writer.flush()

    discriminator_model = discriminator.get_model()
    fake_logits = discriminator_model(inputs=fake_images, training=False)

    total_loss = generator.get_generator_loss(
        global_batch_size=global_batch_size,
        fake_logits=fake_logits,
        params=params,
        global_step=global_step,
        summary_file_writer=summary_file_writer
    )

    return fake_logits, total_loss


def discriminator_loss_phase(
    global_batch_size,
    discriminator,
    real_images,
    fake_logits,
    params,
    global_step,
    summary_file_writer,
    mode,
    training
):
    """Scores real images and computes the discriminator loss.

    Args:
        global_batch_size: int, global batch size for distribution.
        discriminator: instance of `Discriminator`.
        real_images: tensor, real images of shape
            [batch_size, height * width * depth].
        fake_logits: tensor, discriminator logits of fake images of shape
            [batch_size, 1].
        params: dict, user passed parameters.
        global_step: int, current global step for training.
        summary_file_writer: summary file writer.
        mode: str, not used by this function.
        training: bool, forwarded to the discriminator model call.

    Returns:
        Tuple of real logits and the discriminator's total loss.
    """
    discriminator_model = discriminator.get_model()
    real_logits = discriminator_model(inputs=real_images, training=training)

    total_loss = discriminator.get_discriminator_loss(
        global_batch_size=global_batch_size,
        fake_logits=fake_logits,
        real_logits=real_logits,
        params=params,
        global_step=global_step,
        summary_file_writer=summary_file_writer
    )

    return real_logits, total_loss


## train.py

In [None]:
%%writefile vanilla_gan_module/trainer/train.py
import tensorflow as tf

from . import train_and_eval


def get_variables_and_gradients(
    loss,
    network,
    gradient_tape,
    params,
    scope
):
    """Differentiates `loss` w.r.t. a network's trainable variables.

    Gradients are optionally clipped by global norm when
    params["<scope>_clip_gradients"] is truthy. Each tensor gradient is then
    wrapped in a named `tf.identity` for identification. Non-tensor entries
    (e.g. None) are passed through unchanged.

    Args:
        loss: tensor, shape of [].
        network: instance of network; either `Generator` or `Discriminator`.
        gradient_tape: instance of `GradientTape` that recorded `loss`.
        params: dict, user passed parameters.
        scope: str, the name of the network of interest.

    Returns:
        Tuple of (list of trainable variables, list of gradients) in
        matching order.
    """
    trainable_vars = network.get_model().trainable_variables

    grads = gradient_tape.gradient(target=loss, sources=trainable_vars)

    clip_key = "{}_clip_gradients".format(scope)
    if params[clip_key]:
        grads, _ = tf.clip_by_global_norm(
            t_list=grads,
            clip_norm=params[clip_key],
            name="{}_clip_by_global_norm_gradients".format(scope)
        )

    named_grads = []
    for grad, var in zip(grads, trainable_vars):
        if tf.is_tensor(x=grad):
            # `var.name[:-2]` drops the ":0" output suffix.
            grad = tf.identity(
                input=grad,
                name="{}_{}_gradients".format(scope, var.name[:-2])
            )
        named_grads.append(grad)

    return trainable_vars, named_grads


def get_generator_loss_variables_and_gradients(
    global_batch_size,
    generator,
    discriminator,
    global_step,
    summary_file_writer,
    params
):
    """Computes the generator's loss, variables, and gradients.

    Only the generator forward/loss pass is recorded, which makes this the
    cheaper path used on non-summary steps.

    Args:
        global_batch_size: int, global batch size for distribution.
        generator: instance of `Generator`.
        discriminator: instance of `Discriminator`.
        global_step: int, current global step for training.
        summary_file_writer: summary file writer.
        params: dict, user passed parameters.

    Returns:
        Tuple of generator loss, generator variables, and gradients.
    """
    with tf.GradientTape() as tape:
        phase_outputs = train_and_eval.generator_loss_phase(
            global_batch_size,
            generator,
            discriminator,
            params,
            global_step,
            summary_file_writer,
            mode="TRAIN",
            training=True
        )
        loss = phase_outputs[1]

    variables, gradients = get_variables_and_gradients(
        loss=loss,
        network=generator,
        gradient_tape=tape,
        params=params,
        scope="generator"
    )

    return loss, variables, gradients


def get_discriminator_loss_variables_and_gradients(
    global_batch_size,
    real_images,
    generator,
    discriminator,
    global_step,
    summary_file_writer,
    params
):
    """Computes the discriminator's loss, variables, and gradients.

    The generator runs with training=False to produce fake logits. The
    discriminator then scores the real images with training=True. Only the
    discriminator's variables are differentiated.

    Args:
        global_batch_size: int, global batch size for distribution.
        real_images: tensor, real images of shape
            [batch_size, height * width * depth].
        generator: instance of `Generator`.
        discriminator: instance of `Discriminator`.
        global_step: int, current global step for training.
        summary_file_writer: summary file writer.
        params: dict, user passed parameters.

    Returns:
        Tuple of discriminator loss, discriminator variables, and gradients.
    """
    with tf.GradientTape() as tape:
        gen_outputs = train_and_eval.generator_loss_phase(
            global_batch_size,
            generator,
            discriminator,
            params,
            global_step,
            summary_file_writer,
            mode="TRAIN",
            training=False
        )
        fake_logits = gen_outputs[0]

        dis_outputs = train_and_eval.discriminator_loss_phase(
            global_batch_size,
            discriminator,
            real_images,
            fake_logits,
            params,
            global_step,
            summary_file_writer,
            mode="TRAIN",
            training=True
        )
        loss = dis_outputs[1]

    variables, gradients = get_variables_and_gradients(
        loss=loss,
        network=discriminator,
        gradient_tape=tape,
        params=params,
        scope="discriminator"
    )

    return loss, variables, gradients


def create_variable_and_gradient_histogram_summaries(
    variables,
    gradients,
    params,
    global_step,
    summary_file_writer,
    scope
):
    """Writes TensorBoard histograms for variables and their gradients.

    Histograms are only recorded when global_step is a multiple of
    params["save_summary_steps"]. Gradient entries that are not tensors
    (e.g. None) are skipped.

    Args:
        variables: list, network's trainable variables.
        gradients: list, gradients of networks trainable variables wrt. loss.
        params: dict, user passed parameters.
        global_step: int, current global step for training.
        summary_file_writer: summary file writer.
        scope: str, the name of the network of interest.
    """
    with summary_file_writer.as_default():
        should_record = global_step % params["save_summary_steps"] == 0
        with tf.summary.record_if(should_record):
            for variable, gradient in zip(variables, gradients):
                # Drop the ":0" output suffix for readable tag names.
                short_name = variable.name[:-2]

                tf.summary.histogram(
                    name="{}_variables/{}".format(scope, short_name),
                    data=variable,
                    step=global_step
                )

                if not tf.is_tensor(x=gradient):
                    continue

                tf.summary.histogram(
                    name="{}_gradients/{}".format(scope, short_name),
                    data=gradient,
                    step=global_step
                )
            summary_file_writer.flush()


def get_select_loss_variables_and_gradients(
    global_batch_size,
    real_images,
    generator,
    discriminator,
    global_step,
    summary_file_writer,
    params,
    scope
):
    """Gets selected network's loss, variables, and gradients.

    Runs one combined forward pass through both networks. It computes
    gradients and histogram summaries for BOTH networks, so every histogram
    is written on summary steps. It then returns the results for the network
    named by `scope`.

    Args:
        global_batch_size: int, global batch size for distribution.
        real_images: tensor, real images of shape
            [batch_size, height * width * depth].
        generator: instance of `Generator`.
        discriminator: instance of `Discriminator`.
        global_step: int, current global step for training.
        summary_file_writer: summary file writer.
        params: dict, user passed parameters.
        scope: str, network whose results are returned; "generator" or
            "discriminator".

    Returns:
        Selected network's loss, variables, and gradients.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        # Get fake logits from generator.
        fake_logits, generator_loss = train_and_eval.generator_loss_phase(
            global_batch_size,
            generator,
            discriminator,
            params,
            global_step,
            summary_file_writer,
            mode="TRAIN",
            training=(scope == "generator")
        )

        # Get discriminator loss.
        _, discriminator_loss = train_and_eval.discriminator_loss_phase(
            global_batch_size,
            discriminator,
            real_images,
            fake_logits,
            params,
            global_step,
            summary_file_writer,
            mode="TRAIN",
            training=(scope == "discriminator")
        )

    # Create empty dicts to hold loss, variables, gradients.
    loss_dict = {}
    vars_dict = {}
    grads_dict = {}

    # Loop over generator and discriminator. The loop variables must NOT
    # reuse the name `scope`. Rebinding it would overwrite the caller's
    # selection, and this function would always return the discriminator's
    # results (the last loop item), even to `train_generator`.
    for (net_loss, network, gradient_tape, net_scope) in zip(
        [generator_loss, discriminator_loss],
        [generator, discriminator],
        [gen_tape, dis_tape],
        ["generator", "discriminator"]
    ):
        # Get variables and gradients from this network wrt. its loss.
        variables, gradients = get_variables_and_gradients(
            net_loss, network, gradient_tape, params, net_scope
        )

        # Add loss, variables, and gradients to dictionaries.
        loss_dict[net_scope] = net_loss
        vars_dict[net_scope] = variables
        grads_dict[net_scope] = gradients

        # Create variable and gradient histogram summaries.
        create_variable_and_gradient_histogram_summaries(
            variables,
            gradients,
            params,
            global_step,
            summary_file_writer,
            net_scope
        )

    return loss_dict[scope], vars_dict[scope], grads_dict[scope]


def train_network(variables, gradients, optimizer):
    """Applies one optimizer update to a network's variables.

    Args:
        variables: list, network's trainable variables.
        gradients: list, gradients of networks trainable variables wrt. loss,
            in the same order as `variables`.
        optimizer: instance of `Optimizer`.
    """
    # `apply_gradients` expects (gradient, variable) pairs.
    optimizer.apply_gradients(grads_and_vars=zip(gradients, variables))


def train_discriminator(
    global_batch_size,
    features,
    generator,
    discriminator,
    discriminator_optimizer,
    params,
    global_step,
    summary_file_writer
):
    """Runs one optimization step for the discriminator.

    On summary steps (global_step a multiple of params["save_summary_steps"])
    the combined pass is used so histograms for both networks are written.
    Otherwise only the discriminator's gradients are computed.

    Args:
        global_batch_size: int, global batch size for distribution.
        features: dict, feature tensors from input function.
        generator: instance of `Generator`.
        discriminator: instance of `Discriminator`.
        discriminator_optimizer: instance of `Optimizer`, discriminator's
            optimizer.
        params: dict, user passed parameters.
        global_step: int, current global step for training.
        summary_file_writer: summary file writer.

    Returns:
        Discriminator loss tensor.
    """
    # Flatten real images to match the discriminator's vector input.
    flat_size = params["height"] * params["width"] * params["depth"]
    real_images = tf.reshape(tensor=features["image"], shape=[-1, flat_size])

    if global_step % params["save_summary_steps"] != 0:
        # Cheaper path: only the discriminator's gradients are needed.
        loss, variables, gradients = (
            get_discriminator_loss_variables_and_gradients(
                global_batch_size,
                real_images,
                generator,
                discriminator,
                global_step,
                summary_file_writer,
                params
            )
        )
    else:
        # Summary step: compute both networks' gradients for histograms.
        loss, variables, gradients = (
            get_select_loss_variables_and_gradients(
                global_batch_size,
                real_images,
                generator,
                discriminator,
                global_step,
                summary_file_writer,
                params,
                scope="discriminator"
            )
        )

    train_network(variables, gradients, optimizer=discriminator_optimizer)

    return loss


def train_generator(
    global_batch_size,
    features,
    generator,
    discriminator,
    generator_optimizer,
    params,
    global_step,
    summary_file_writer
):
    """Runs a single optimization step on the generator network.

    On summary steps (``global_step`` a multiple of
    ``params["save_summary_steps"]``) the heavier combined helper is used,
    since it is needed for all histogram summaries; on every other step only
    the generator-specific helper is called and real images are not needed.

    Args:
        global_batch_size: int, global batch size for distribution.
        features: dict, feature tensors from input function.
        generator: instance of `Generator`.
        discriminator: instance of `Discriminator`.
        generator_optimizer: instance of `Optimizer`, generator's
            optimizer.
        params: dict, user passed parameters.
        global_step: int, current global step for training.
        summary_file_writer: summary file writer.

    Returns:
        Generator loss tensor.
    """
    is_summary_step = global_step % params["save_summary_steps"] == 0

    if not is_summary_step:
        # Cheap path: the generator loss does not need the real images.
        gen_loss, gen_variables, gen_gradients = (
            get_generator_loss_variables_and_gradients(
                global_batch_size,
                generator,
                discriminator,
                global_step,
                summary_file_writer,
                params
            )
        )
    else:
        # Flatten real images to [batch, height * width * depth] for the
        # combined helper.
        flattened_pixels = params["height"] * params["width"] * params["depth"]
        flat_real_images = tf.reshape(
            tensor=features["image"], shape=[-1, flattened_pixels]
        )

        gen_loss, gen_variables, gen_gradients = (
            get_select_loss_variables_and_gradients(
                global_batch_size,
                flat_real_images,
                generator,
                discriminator,
                global_step,
                summary_file_writer,
                params,
                scope="generator"
            )
        )

    # Apply the gradients to the generator's variables.
    train_network(gen_variables, gen_gradients, optimizer=generator_optimizer)

    return gen_loss


def train_step(
    global_batch_size,
    features,
    network_dict,
    optimizer_dict,
    params,
    global_step,
    summary_file_writer
):
    """Runs one alternating GAN training step.

    Training proceeds in fixed-length cycles. The first
    ``params["discriminator_train_steps"]`` steps of each cycle update the
    discriminator. The following ``params["generator_train_steps"]`` steps
    update the generator.

    Args:
        global_batch_size: int, global batch size for distribution.
        features: dict, feature tensors from input function.
        network_dict: dict, dictionary of network objects.
        optimizer_dict: dict, dictionary of optimizer objects.
        params: dict, user passed parameters.
        global_step: int, current global step for training.
        summary_file_writer: summary file writer.

    Returns:
        Loss tensor of whichever network was trained this step.
    """
    discriminator_steps = params["discriminator_train_steps"]
    steps_per_cycle = discriminator_steps + params["generator_train_steps"]
    position_in_cycle = global_step % steps_per_cycle

    # Arguments both training branches receive unchanged.
    shared_kwargs = {
        "global_batch_size": global_batch_size,
        "features": features,
        "generator": network_dict["generator"],
        "discriminator": network_dict["discriminator"],
        "params": params,
        "global_step": global_step,
        "summary_file_writer": summary_file_writer,
    }

    if position_in_cycle >= discriminator_steps:
        return train_generator(
            generator_optimizer=optimizer_dict["generator"], **shared_kwargs
        )

    return train_discriminator(
        discriminator_optimizer=optimizer_dict["discriminator"],
        **shared_kwargs
    )


## vanilla_gan.py

In [None]:
%%writefile vanilla_gan_module/trainer/vanilla_gan.py
import tensorflow as tf

from . import discriminators
from . import generators


def instantiate_network_objects(params):
    """Builds the generator and discriminator networks from `params`.

    Both networks get an L1/L2 kernel regularizer whose scales come from the
    ``<network>_l1_regularization_scale`` / ``<network>_l2_regularization_scale``
    entries of `params`; neither gets a bias regularizer.

    Args:
        params: dict, user passed parameters.

    Returns:
        Dict with keys "generator" (instance of `Generator`) and
        "discriminator" (instance of `Discriminator`).
    """
    def _kernel_regularizer(network_name):
        # Combined L1/L2 penalty configured per network.
        return tf.keras.regularizers.l1_l2(
            l1=params["{}_l1_regularization_scale".format(network_name)],
            l2=params["{}_l2_regularization_scale".format(network_name)]
        )

    # The discriminator consumes images flattened to a single vector.
    flattened_image_size = params["height"] * params["width"] * params["depth"]

    # Dict literal values are evaluated in order, so the generator is still
    # built before the discriminator.
    return {
        "generator": generators.Generator(
            # NOTE: passed as a plain int (the original parenthesised form
            # was not a tuple either).
            input_shape=params["latent_size"],
            kernel_regularizer=_kernel_regularizer("generator"),
            bias_regularizer=None,
            name="generator",
            params=params
        ),
        "discriminator": discriminators.Discriminator(
            input_shape=flattened_image_size,
            kernel_regularizer=_kernel_regularizer("discriminator"),
            bias_regularizer=None,
            name="discriminator",
            params=params
        ),
    }


def instantiate_optimizer(params, scope):
    """Creates the Keras optimizer configured for one network.

    The optimizer class is chosen by ``params["<scope>_optimizer"]``. Every
    optimizer gets ``params["<scope>_learning_rate"]``. Adam additionally
    receives its beta/epsilon hyperparameters from `params`.

    Args:
        params: dict, user passed parameters.
        scope: str, the name of the network of interest.

    Returns:
        Instance of `Optimizer`.

    Raises:
        KeyError: if the requested optimizer name is not supported or a
            required parameter is missing.
    """
    optimizer_classes = {
        "Adadelta": tf.keras.optimizers.Adadelta,
        "Adagrad": tf.keras.optimizers.Adagrad,
        "Adam": tf.keras.optimizers.Adam,
        "Adamax": tf.keras.optimizers.Adamax,
        "Ftrl": tf.keras.optimizers.Ftrl,
        "Nadam": tf.keras.optimizers.Nadam,
        "RMSprop": tf.keras.optimizers.RMSprop,
        "SGD": tf.keras.optimizers.SGD
    }

    optimizer_name = params["{}_optimizer".format(scope)]
    # Resolve the class before reading other params.
    optimizer_class = optimizer_classes[optimizer_name]

    optimizer_kwargs = {
        "learning_rate": params["{}_learning_rate".format(scope)]
    }
    if optimizer_name == "Adam":
        optimizer_kwargs["beta_1"] = params["{}_adam_beta1".format(scope)]
        optimizer_kwargs["beta_2"] = params["{}_adam_beta2".format(scope)]
        optimizer_kwargs["epsilon"] = params["{}_adam_epsilon".format(scope)]
    optimizer_kwargs["name"] = "{}_{}_optimizer".format(
        scope, optimizer_name.lower()
    )

    return optimizer_class(**optimizer_kwargs)


def vanilla_gan_model(params):
    """Builds every object the vanilla GAN training loop needs.

    Args:
        params: dict, user passed parameters.

    Returns:
        Tuple `(network_dict, optimizer_dict)`. Both dicts are keyed by
        "generator" and "discriminator". `network_dict` holds the network
        objects from `instantiate_network_objects`. `optimizer_dict` holds
        the optimizer for each network from `instantiate_optimizer`.
    """
    networks = instantiate_network_objects(params)

    # Generator optimizer is created first, then the discriminator's.
    optimizers_by_network = {
        network_name: instantiate_optimizer(params, scope=network_name)
        for network_name in ("generator", "discriminator")
    }

    return networks, optimizers_by_network


## model.py

In [None]:
%%writefile vanilla_gan_module/trainer/model.py
import datetime
import os
import tensorflow as tf

from . import input
from . import vanilla_gan
from . import train


def distributed_train_step(
    strategy,
    global_batch_size,
    features,
    network_dict,
    optimizer_dict,
    params,
    global_step,
    summary_file_writer
):
    """Perform one distributed train step.

    Runs `train.train_step` on every replica of `strategy` and sums the
    resulting per-replica losses.

    Args:
        strategy: instance of `tf.distribute.Strategy`.
        global_batch_size: int, global batch size for distribution.
        features: dict, feature tensors from input function.
        network_dict: dict, dictionary of network objects.
        optimizer_dict: dict, dictionary of optimizer objects.
        params: dict, user passed parameters. `params["tf_version"]` is no
            longer consulted here; the replica-run method is detected on
            `strategy` itself.
        global_step: int, current global step for training.
        summary_file_writer: summary file writer.

    Returns:
        Reduced loss tensor for chosen network across replicas.
    """
    step_kwargs = {
        "global_batch_size": global_batch_size,
        "features": features,
        "network_dict": network_dict,
        "optimizer_dict": optimizer_dict,
        "params": params,
        "global_step": global_step,
        "summary_file_writer": summary_file_writer
    }

    # `Strategy.run` replaced `Strategy.experimental_run_v2` in TF 2.2.
    # Feature-detect instead of comparing `params["tf_version"]` as a float:
    # e.g. "2.10" parses as 2.1 and would wrongly pick the old, removed API.
    run_fn = getattr(strategy, "run", None)
    if run_fn is None:
        run_fn = strategy.experimental_run_v2

    per_replica_losses = run_fn(fn=train.train_step, kwargs=step_kwargs)

    # SUM across replicas; presumably each replica's loss is already scaled
    # by `global_batch_size` inside train.py — TODO confirm.
    return strategy.reduce(
        reduce_op=tf.distribute.ReduceOp.SUM,
        value=per_replica_losses,
        axis=None
    )


def train_and_evaluate(args):
    """Trains the vanilla GAN under `tf.distribute.MirroredStrategy`.

    Steps:
    1. Builds the train/eval input pipelines.
    2. Instantiates the networks and optimizers.
    3. Restores the latest checkpoint, if any.
    4. Runs the training loop, checkpointing every
       `args["save_checkpoints_steps"]` steps plus once at the end.
    5. Exports the generator as a SavedModel.

    NOTE(review): the eval dataset/iterator is built but no evaluation is
    run yet. `global_step` is not part of the checkpoint, so a resumed run
    counts from 0 again.

    Args:
        args: dict, user passed parameters.
    """
    # If the list of devices is not specified in the
    # `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
    strategy = tf.distribute.MirroredStrategy()
    print("Number of devices = {}".format(strategy.num_replicas_in_sync))

    # Each dataset batch is a *global* batch that is split evenly between
    # replicas.
    global_train_batch_size = (
        args["train_batch_size"] * strategy.num_replicas_in_sync
    )
    global_eval_batch_size = (
        args["eval_batch_size"] * strategy.num_replicas_in_sync
    )

    # Get input datasets.
    train_dataset = input.read_dataset(
        filename=args["train_file_pattern"],
        batch_size=global_train_batch_size,
        params=args,
        training=True
    )()

    eval_dataset = input.read_dataset(
        filename=args["eval_file_pattern"],
        batch_size=global_eval_batch_size,
        params=args,
        training=False
    )()
    if args["eval_steps"]:
        eval_dataset = eval_dataset.take(count=args["eval_steps"])

    with strategy.scope():
        # Create distributed datasets.
        train_dist_dataset = strategy.experimental_distribute_dataset(
            dataset=train_dataset
        )
        eval_dist_dataset = strategy.experimental_distribute_dataset(
            dataset=eval_dataset
        )

        # Create iterators of distributed datasets.
        train_dist_iter = iter(train_dist_dataset)
        eval_dist_iter = iter(eval_dist_dataset)  # Unused until eval exists.

        # Every step consumes one global batch, so an epoch is the dataset
        # length divided by the global (not per-replica) batch size.
        steps_per_epoch = (
            args["train_dataset_length"] // global_train_batch_size
        )

        # Instantiate model objects.
        network_dict, optimizer_dict = vanilla_gan.vanilla_gan_model(
            params=args
        )

        # Create checkpoint instance.
        checkpoint_dir = os.path.join(args["output_dir"], "checkpoints")
        checkpoint = tf.train.Checkpoint(
            generator_model=network_dict["generator"].get_model(),
            discriminator_model=network_dict["discriminator"].get_model(),
            generator_optimizer=optimizer_dict["generator"],
            discriminator_optimizer=optimizer_dict["discriminator"]
        )

        # Create checkpoint manager.
        checkpoint_manager = tf.train.CheckpointManager(
            checkpoint=checkpoint,
            directory=checkpoint_dir,
            max_to_keep=args["keep_checkpoint_max"]
        )

        # Restore any prior checkpoints (a no-op when none exist).
        checkpoint.restore(save_path=checkpoint_manager.latest_checkpoint)

        # Create summary file writer.
        summary_file_writer = tf.summary.create_file_writer(
            logdir=os.path.join(args["output_dir"], "summaries"),
            name="summary_file_writer"
        )

        # Loop over datasets to perform training.
        global_step = 0
        last_saved_step = None
        for epoch in range(args["num_epochs"]):
            for epoch_step in range(steps_per_epoch):
                features, _ = next(train_dist_iter)

                loss = distributed_train_step(
                    strategy=strategy,
                    global_batch_size=global_train_batch_size,
                    features=features,
                    network_dict=network_dict,
                    optimizer_dict=optimizer_dict,
                    params=args,
                    global_step=global_step,
                    summary_file_writer=summary_file_writer
                )

                if global_step % args["log_step_count_steps"] == 0:
                    print(
                        "epoch = {}, global_step = {}, epoch_step = {}, loss = {}".format(
                            epoch, global_step, epoch_step, loss
                        )
                    )
                global_step += 1

                # Checkpoint model every so many steps. This check must be
                # inside the step loop; at epoch level it only fired when an
                # epoch boundary happened to land on a multiple.
                if global_step % args["save_checkpoints_steps"] == 0:
                    checkpoint_manager.save(checkpoint_number=global_step)
                    last_saved_step = global_step

        # Write final checkpoint, unless the loop just wrote this step.
        if last_saved_step != global_step:
            checkpoint_manager.save(checkpoint_number=global_step)

        # Export SavedModel for serving.
        export_path = os.path.join(
            args["output_dir"],
            "export",
            datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        )

        # Signature will be serving_default.
        tf.saved_model.save(
            obj=network_dict["generator"].get_model(), export_dir=export_path
        )


## task.py

In [None]:
%%writefile vanilla_gan_module/trainer/task.py
import argparse
import json
import os

from . import model


def convert_string_to_bool(string):
    """Converts string to bool.

    Common "false" spellings map to False:
    - "false", "f", "no", "n", "off", "0"
    - matched case-insensitively, with surrounding whitespace ignored.

    Every other string maps to True, so the function stays as permissive as
    before for anything not explicitly false. Previously only "false" was
    recognised, so e.g. "0" or "no" silently became True.

    Args:
        string: str, string to convert.

    Returns:
        Boolean conversion of string.
    """
    false_spellings = frozenset(("false", "f", "no", "n", "off", "0"))
    return string.strip().lower() not in false_spellings


def convert_string_to_none_or_float(string):
    """Parses a command-line string as either None or a float.

    Args:
        string: str, string to convert; "none" (any case) means None.

    Returns:
        None when `string` spells "none", otherwise `float(string)`.
    """
    if string.lower() == "none":
        return None
    return float(string)


def convert_string_to_none_or_int(string):
    """Parses a command-line string as either None or an int.

    Args:
        string: str, string to convert; "none" (any case) means None.

    Returns:
        None when `string` spells "none", otherwise `int(string)`.
    """
    if string.lower() == "none":
        return None
    return int(string)


def convert_string_to_list_of_ints(string, sep):
    """Splits a delimited string and parses every piece as an int.

    Args:
        string: str, string to convert, e.g. "2,4,8".
        sep: str, separator string, e.g. ",".

    Returns:
        List of ints, one per `sep`-separated piece of `string`.
    """
    pieces = string.split(sep)
    return list(map(int, pieces))


if __name__ == "__main__":
    # Command-line flags, in registration order: (flag, add_argument kwargs).
    argument_specs = (
        # File arguments.
        ("--train_file_pattern", {
            "help": "GCS location to read training data.",
            "required": True}),
        ("--eval_file_pattern", {
            "help": "GCS location to read evaluation data.",
            "required": True}),
        ("--output_dir", {
            "help": "GCS location to write checkpoints and export models.",
            "required": True}),
        ("--job-dir", {
            "help": "This model ignores this field, but it is required by gcloud.",
            "default": "junk"}),

        # Training parameters.
        ("--tf_version", {
            "help": "Version of TensorFlow",
            "type": float, "default": 2.2}),
        ("--num_epochs", {
            "help": "Number of epochs to train for.",
            "type": int, "default": 100}),
        ("--train_dataset_length", {
            "help": "Number of examples in one epoch of training set",
            "type": int, "default": 100}),
        ("--train_batch_size", {
            "help": "Number of examples in training batch.",
            "type": int, "default": 32}),
        ("--log_step_count_steps", {
            "help": "How many steps to train before writing steps and loss to log.",
            "type": int, "default": 100}),
        ("--save_summary_steps", {
            "help": "How many steps to train before saving a summary.",
            "type": int, "default": 100}),
        ("--save_checkpoints_steps", {
            "help": "How many steps to train before saving a checkpoint.",
            "type": int, "default": 100}),
        ("--keep_checkpoint_max", {
            "help": "Max number of checkpoints to keep.",
            "type": int, "default": 100}),
        ("--input_fn_autotune", {
            "help": "Whether to autotune input function performance.",
            "type": str, "default": "True"}),

        # Eval parameters.
        ("--eval_batch_size", {
            "help": "Number of examples in evaluation batch.",
            "type": int, "default": 32}),
        ("--eval_steps", {
            "help": "Number of steps to evaluate for.",
            "type": str, "default": "None"}),

        # Image parameters.
        ("--height", {
            "help": "Height of image.",
            "type": int, "default": 32}),
        ("--width", {
            "help": "Width of image.",
            "type": int, "default": 32}),
        ("--depth", {
            "help": "Depth of image.",
            "type": int, "default": 3}),

        # Generator parameters.
        ("--latent_size", {
            "help": "The latent size of the noise vector.",
            "type": int, "default": 3}),
        ("--generator_hidden_units", {
            "help": "Hidden layer sizes to use for generator.",
            "type": str, "default": "2,4,8"}),
        ("--generator_leaky_relu_alpha", {
            "help": "The amount of leakyness of generator's leaky relus.",
            "type": float, "default": 0.2}),
        ("--generator_final_activation", {
            "help": "The final activation function of generator.",
            "type": str, "default": "None"}),
        ("--generator_l1_regularization_scale", {
            "help": "Scale factor for L1 regularization for generator.",
            "type": float, "default": 0.0}),
        ("--generator_l2_regularization_scale", {
            "help": "Scale factor for L2 regularization for generator.",
            "type": float, "default": 0.0}),
        ("--generator_optimizer", {
            "help": "Name of optimizer to use for generator.",
            "type": str, "default": "Adam"}),
        ("--generator_learning_rate", {
            "help": "How quickly we train our model by scaling the gradient for generator.",
            "type": float, "default": 0.001}),
        ("--generator_adam_beta1", {
            "help": "Adam optimizer's beta1 hyperparameter for first moment.",
            "type": float, "default": 0.9}),
        ("--generator_adam_beta2", {
            "help": "Adam optimizer's beta2 hyperparameter for second moment.",
            "type": float, "default": 0.999}),
        ("--generator_adam_epsilon", {
            "help": "Adam optimizer's epsilon hyperparameter for numerical stability.",
            "type": float, "default": 1e-8}),
        ("--generator_clip_gradients", {
            "help": "Global clipping to prevent gradient norm to exceed this value for generator.",
            "type": str, "default": "None"}),
        ("--generator_train_steps", {
            "help": "Number of steps to train generator for per cycle.",
            "type": int, "default": 100}),

        # Discriminator parameters.
        ("--discriminator_hidden_units", {
            "help": "Hidden layer sizes to use for discriminator.",
            "type": str, "default": "2,4,8"}),
        ("--discriminator_leaky_relu_alpha", {
            "help": "The amount of leakyness of discriminator's leaky relus.",
            "type": float, "default": 0.2}),
        ("--discriminator_l1_regularization_scale", {
            "help": "Scale factor for L1 regularization for discriminator.",
            "type": float, "default": 0.0}),
        ("--discriminator_l2_regularization_scale", {
            "help": "Scale factor for L2 regularization for discriminator.",
            "type": float, "default": 0.0}),
        ("--discriminator_optimizer", {
            "help": "Name of optimizer to use for discriminator.",
            "type": str, "default": "Adam"}),
        ("--discriminator_learning_rate", {
            "help": "How quickly we train our model by scaling the gradient for discriminator.",
            "type": float, "default": 0.001}),
        ("--discriminator_adam_beta1", {
            "help": "Adam optimizer's beta1 hyperparameter for first moment.",
            "type": float, "default": 0.9}),
        ("--discriminator_adam_beta2", {
            "help": "Adam optimizer's beta2 hyperparameter for second moment.",
            "type": float, "default": 0.999}),
        ("--discriminator_adam_epsilon", {
            "help": "Adam optimizer's epsilon hyperparameter for numerical stability.",
            "type": float, "default": 1e-8}),
        ("--discriminator_clip_gradients", {
            "help": "Global clipping to prevent gradient norm to exceed this value for discriminator.",
            "type": str, "default": "None"}),
        ("--discriminator_train_steps", {
            "help": "Number of steps to train discriminator for per cycle.",
            "type": int, "default": 100}),
        ("--label_smoothing", {
            "help": "Multiplier when making real labels instead of all ones.",
            "type": float, "default": 0.9}),
    )

    parser = argparse.ArgumentParser()
    for flag, flag_kwargs in argument_specs:
        parser.add_argument(flag, **flag_kwargs)

    # Parse all arguments; `vars` returns the namespace's own __dict__.
    args = parser.parse_args()
    arguments = vars(args)

    # Drop args provided by the service that this model does not use.
    for unused_key in ("job_dir", "job-dir"):
        arguments.pop(unused_key, None)

    # Convert string-typed flags into their real types.
    arguments["input_fn_autotune"] = convert_string_to_bool(
        string=arguments["input_fn_autotune"]
    )
    arguments["eval_steps"] = convert_string_to_none_or_int(
        string=arguments["eval_steps"]
    )
    for network_name in ("generator", "discriminator"):
        units_key = "{}_hidden_units".format(network_name)
        arguments[units_key] = convert_string_to_list_of_ints(
            string=arguments[units_key], sep=","
        )
    for network_name in ("generator", "discriminator"):
        clip_key = "{}_clip_gradients".format(network_name)
        arguments[clip_key] = convert_string_to_none_or_float(
            string=arguments[clip_key]
        )

    # Append trial_id to path if we are doing hptuning.
    # This code can be removed if you are not using hyperparameter tuning.
    tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
    trial_id = tf_config.get("task", {}).get("trial", "")
    arguments["output_dir"] = os.path.join(arguments["output_dir"], trial_id)

    # Run the training job.
    model.train_and_evaluate(arguments)
