## README.md

In [None]:
%%writefile README.md
Implementation of [Wasserstein GAN](https://arxiv.org/abs/1701.07875).

## print_object.py

In [None]:
%%writefile wgan_module/trainer/print_object.py
def print_obj(function_name, object_name, object_value):
    """Writes a one-line debug record describing an object to stdout.

    The record has the form `<function_name>: <object_name> = <object_value>`.

    Args:
        function_name: str, name of the function the object lives in.
        object_name: str, name of the object being reported.
        object_value: object, value to display.
    """
    template = "{}: {} = {}"
    message = template.format(function_name, object_name, object_value)
    print(message)


## image_utils.py

In [None]:
%%writefile wgan_module/trainer/image_utils.py
import tensorflow as tf

from .print_object import print_obj


def preprocess_image(image, params):
    """Casts an image tensor to float32 and rescales its values.

    Each value `v` is mapped to `v * (2 / 255) - 1`, so inputs in [0, 255]
    land in [-1.0, 1.0].

    Args:
        image: tensor, input image; callers in this package pass raw pixel
            values.
        params: dict, user passed parameters (not read by this function).

    Returns:
        Float32 image tensor with the same shape as the input.
    """
    func_name = "preprocess_image"

    # Factor that stretches [0, 255] to [0.0, 2.0] before the shift by -1.
    scale = 2. / 255

    image_as_float = tf.cast(x=image, dtype=tf.float32)
    image = image_as_float * scale - 1.0
    print_obj(func_name, "image", image)

    return image


def resize_fake_images(fake_images, params):
    """Resizes fake images to match real image sizes.

    Args:
        fake_images: tensor, fake images from generator.
        params: dict, user passed parameters; `height` and `width` give the
            target spatial size.

    Returns:
        Resized image tensor.
    """
    # Debug output is labelled with this function's own name (it previously
    # carried the name of a non-existent `resize_real_image`).
    func_name = "resize_fake_images"
    print_obj("\n" + func_name, "fake_images", fake_images)

    # Resize fake images to the real images' height and width using the
    # "nearest" resize method.
    resized_fake_images = tf.image.resize(
        images=fake_images,
        size=[params["height"], params["width"]],
        method="nearest",
        name="resized_fake_images"
    )
    print_obj(func_name, "resized_fake_images", resized_fake_images)

    return resized_fake_images


## input.py

In [None]:
%%writefile wgan_module/trainer/input.py
import tensorflow as tf

from . import image_utils
from .print_object import print_obj


def decode_example(protos, params):
    """Parses one serialized tf.Example into an image dict and a label.

    Args:
        protos: protobufs from TFRecord file.
        params: dict, user passed parameters; `height`, `width` and `depth`
            define the shape the raw image bytes are reshaped to.

    Returns:
        Tuple of a features dictionary `{"image": image}` holding the
            preprocessed image tensor, and the label tensor cast to int32.
    """
    func_name = "decode_example"

    # Schema of the serialized example: raw image bytes and an int64 label.
    features = {
        "image_raw": tf.FixedLenFeature(shape=[], dtype=tf.string),
        "label": tf.FixedLenFeature(shape=[], dtype=tf.int64)
    }

    parsed_features = tf.parse_single_example(
        serialized=protos, features=features
    )
    print_obj("\n" + func_name, "features", features)

    # Reinterpret the scalar byte string as a flat uint8 tensor.
    flat_pixels = tf.decode_raw(
        input_bytes=parsed_features["image_raw"], out_type=tf.uint8
    )
    print_obj(func_name, "image", flat_pixels)

    # Restore the image's [height, width, depth] layout.
    image_shape = [params["height"], params["width"], params["depth"]]
    shaped_image = tf.reshape(tensor=flat_pixels, shape=image_shape)
    print_obj(func_name, "image", shaped_image)

    # Cast and rescale pixel values.
    image = image_utils.preprocess_image(image=shaped_image, params=params)
    print_obj(func_name, "image", image)

    # The label is parsed as int64; downstream code receives it as int32.
    label = tf.cast(x=parsed_features["label"], dtype=tf.int32)
    print_obj(func_name, "label", label)

    return {"image": image}, label


def read_dataset(filename, mode, batch_size, params):
    """Builds an input function that reads TFRecord data with tf.data.

    The returned function globs `filename`, reads the matching TFRecord
    files, shuffles and repeats when training, decodes and batches the
    examples, and prefetches the result.

    Args:
        filename: str, file pattern to read into the tf.data dataset.
        mode: The estimator ModeKeys. Can be TRAIN or EVAL.
        batch_size: int, number of examples per batch.
        params: dict, dictionary of user passed parameters; the
            `input_fn_autotune` flag switches parallelism and prefetch
            settings to `tf.contrib.data.AUTOTUNE`.

    Returns:
        An input function.
    """
    def _decode(serialized_example):
        """Decodes a single serialized example with the user parameters."""
        return decode_example(protos=serialized_example, params=params)

    def _input_fn():
        """Wrapper input function used by Estimator API to get data tensors.

        Returns:
            Next element of a one-shot iterator over the batched dataset:
                a dictionary of feature tensors and a label tensor.
        """
        # Expand the file pattern into concrete file names.
        matched_files = tf.gfile.Glob(filename=filename)

        use_autotune = params["input_fn_autotune"]

        # Read the TFRecord files, in parallel only when autotuning.
        reader_kwargs = {}
        if use_autotune:
            reader_kwargs["num_parallel_reads"] = tf.contrib.data.AUTOTUNE
        dataset = tf.data.TFRecordDataset(
            filenames=matched_files, **reader_kwargs
        )

        # Training data is shuffled and repeated indefinitely (fused op).
        if mode == tf.estimator.ModeKeys.TRAIN:
            shuffle_repeat = tf.contrib.data.shuffle_and_repeat(
                buffer_size=50 * batch_size,
                count=None  # indefinitely
            )
            dataset = dataset.apply(shuffle_repeat)

        # Decode each TFRecord example into tensors, then batch (fused op).
        batching_kwargs = {"map_func": _decode, "batch_size": batch_size}
        if use_autotune:
            batching_kwargs["num_parallel_calls"] = tf.contrib.data.AUTOTUNE
        dataset = dataset.apply(
            tf.contrib.data.map_and_batch(**batching_kwargs)
        )

        # Prefetch data to improve latency.
        prefetch_size = tf.contrib.data.AUTOTUNE if use_autotune else 1
        dataset = dataset.prefetch(buffer_size=prefetch_size)

        # Hand the Estimator the tensors for the next batch.
        return dataset.make_one_shot_iterator().get_next()

    return _input_fn


## generator.py

In [None]:
%%writefile wgan_module/trainer/generator.py
import tensorflow as tf

from .print_object import print_obj


class Generator(object):
    """Generator that maps latent vectors to images.

    Fields:
        name: str, name of `Generator`.
        kernel_regularizer: regularizer object passed as
            `kernel_regularizer` to every dense and conv transpose layer.
        bias_regularizer: regularizer object passed as `bias_regularizer`
            to every dense and conv transpose layer.
    """
    def __init__(self, kernel_regularizer, bias_regularizer, name):
        """Stores the generator's name and regularizers.

        No layers are created here; the network is built by
        `get_fake_images`.

        Args:
            kernel_regularizer: regularizer object for kernel variables.
            bias_regularizer: regularizer object for bias variables.
            name: str, name of generator.
        """
        self.name = name
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer

    def get_fake_images(self, Z, mode, params):
        """Builds the generator network and returns generated images.

        The latent vectors are projected with a dense layer, reshaped into
        a small "image", upsampled by a stack of conv transpose layers
        (each followed by leaky ReLU and batch normalization) and finally
        passed through an output conv transpose layer.

        Args:
            Z: tensor, latent vectors of shape [cur_batch_size, latent_size].
            mode: tf.estimator.ModeKeys with values of either TRAIN, EVAL, or
                PREDICT.
            params: dict, user passed parameters; reads the `generator_*`
                keys describing the projection, the per-layer filters,
                kernel sizes and strides, and the final layer.

        Returns:
            Generated image tensor produced by the output conv transpose
                layer.
        """
        func_name = "get_fake_images"
        print_obj("\n" + func_name, "Z", Z)

        # Activations selectable for the output layer; any other name
        # results in no activation.
        final_activation_dict = {
            "sigmoid": tf.nn.sigmoid, "relu": tf.nn.relu, "tanh": tf.nn.tanh
        }

        # Passed as `training` to every batch normalization layer.
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
            # Dimensions of the "image" the latent vectors are projected to.
            projection_dims = params["generator_projection_dims"]
            projection_height = projection_dims[0]
            projection_width = projection_dims[1]
            projection_depth = projection_dims[2]
            projection_units = (
                projection_height * projection_width * projection_depth
            )

            # Dense projection.
            # shape = (cur_batch_size, projection_units)
            projection = tf.layers.dense(
                inputs=Z,
                units=projection_units,
                activation=None,
                kernel_regularizer=self.kernel_regularizer,
                bias_regularizer=self.bias_regularizer,
                name="projection_dense_layer"
            )
            print_obj(func_name, "projection", projection)

            projection_leaky_relu = tf.nn.leaky_relu(
                features=projection,
                alpha=params["generator_leaky_relu_alpha"],
                name="projection_leaky_relu"
            )
            print_obj(
                func_name, "projection_leaky_relu", projection_leaky_relu
            )

            # Batch normalization keeps the activations from blowing up.
            # shape = (cur_batch_size, projection_units)
            projection_batch_norm = tf.layers.batch_normalization(
                inputs=projection_leaky_relu,
                training=is_training,
                name="projection_batch_norm"
            )
            print_obj(
                func_name, "projection_batch_norm", projection_batch_norm
            )

            # Fold the flat projection into an "image".
            # shape = (
            #     cur_batch_size,
            #     projection_height,
            #     projection_width,
            #     projection_depth
            # )
            image_like_shape = [
                -1, projection_height, projection_width, projection_depth
            ]
            network = tf.reshape(
                tensor=projection_batch_norm,
                shape=image_like_shape,
                name="projection_reshaped"
            )
            print_obj(func_name, "network", network)

            # One conv transpose -> leaky ReLU -> batch norm group per entry
            # of `generator_num_filters`.
            for i, num_filters in enumerate(params["generator_num_filters"]):
                # The layer name keeps its historical spelling because it
                # is part of the variable names.
                network = tf.layers.conv2d_transpose(
                    inputs=network,
                    filters=num_filters,
                    kernel_size=params["generator_kernel_sizes"][i],
                    strides=params["generator_strides"][i],
                    padding="same",
                    activation=None,
                    kernel_regularizer=self.kernel_regularizer,
                    bias_regularizer=self.bias_regularizer,
                    name="layers_conv2d_tranpose_{}".format(i)
                )
                print_obj(func_name, "network", network)

                network = tf.nn.leaky_relu(
                    features=network,
                    alpha=params["generator_leaky_relu_alpha"],
                    name="leaky_relu_{}".format(i)
                )
                print_obj(func_name, "network", network)

                network = tf.layers.batch_normalization(
                    inputs=network,
                    training=is_training,
                    name="layers_batch_norm_{}".format(i)
                )
                print_obj(func_name, "network", network)

            # Output layer; its activation is looked up by lower-cased name.
            final_activation = final_activation_dict.get(
                params["generator_final_activation"].lower()
            )
            fake_images = tf.layers.conv2d_transpose(
                inputs=network,
                filters=params["generator_final_num_filters"],
                kernel_size=params["generator_final_kernel_size"],
                strides=params["generator_final_stride"],
                padding="same",
                activation=final_activation,
                kernel_regularizer=self.kernel_regularizer,
                bias_regularizer=self.bias_regularizer,
                name="layers_conv2d_tranpose_fake_images"
            )
            print_obj(func_name, "fake_images", fake_images)

        return fake_images

    def get_generator_loss(self, fake_logits):
        """Computes the generator's total loss and records summaries.

        The base loss is the negated mean of the critic's logits for the
        fake images; the regularization loss of the "generator" scope is
        added on top.

        Args:
            fake_logits: tensor, shape of
                [cur_batch_size, 1].

        Returns:
            Tensor of generator's total loss of shape [].
        """
        func_name = "get_generator_loss"

        # Base loss: negative mean fake logit.
        mean_fake_logit = tf.reduce_mean(
            input_tensor=fake_logits,
            name="generator_loss"
        )
        generator_loss = -mean_fake_logit
        print_obj("\n" + func_name, "generator_loss", generator_loss)

        # Regularization losses collected under the generator scope.
        generator_reg_loss = tf.losses.get_regularization_loss(
            scope="generator",
            name="generator_reg_loss"
        )
        print_obj(func_name, "generator_reg_loss", generator_reg_loss)

        # Total loss = base loss + regularization loss.
        generator_total_loss = tf.math.add(
            x=generator_loss,
            y=generator_reg_loss,
            name="generator_total_loss"
        )
        print_obj(func_name, "generator_total_loss", generator_total_loss)

        # Scalar summaries for TensorBoard.
        scalar_summaries = (
            ("generator_loss", generator_loss, "losses"),
            ("generator_reg_loss", generator_reg_loss, "losses"),
            ("generator_total_loss", generator_total_loss, "total_losses"),
        )
        for summary_name, summary_tensor, summary_family in scalar_summaries:
            tf.summary.scalar(
                name=summary_name,
                tensor=summary_tensor,
                family=summary_family
            )

        return generator_total_loss


## critic.py

In [None]:
%%writefile wgan_module/trainer/critic.py
import tensorflow as tf

from .print_object import print_obj


class Critic(object):
    """Critic that takes image input and outputs logits.

    Fields:
        name: str, name of `Critic`.
        kernel_regularizer: regularizer object passed as
            `kernel_regularizer` to every conv and dense layer.
        bias_regularizer: regularizer object passed as `bias_regularizer`
            to every conv and dense layer.
    """
    def __init__(self, kernel_regularizer, bias_regularizer, name):
        """Stores the critic's name and regularizers.

        No layers are created here; the network is built by
        `get_critic_logits`.

        Args:
            kernel_regularizer: regularizer object for kernel variables.
            bias_regularizer: regularizer object for bias variables.
            name: str, name of critic.
        """
        # Set name of critic.
        self.name = name

        # Regularizer for kernel weights.
        self.kernel_regularizer = kernel_regularizer

        # Regularizer for bias weights.
        self.bias_regularizer = bias_regularizer

    def get_critic_logits(self, X, params, mode=None):
        """Creates critic network and returns logits.

        Args:
            X: tensor, image tensors of shape
                [cur_batch_size, height, width, depth].
            params: dict, user passed parameters; reads the `critic_*` keys
                describing per-layer filters, kernel sizes, strides,
                leaky ReLU alpha and dropout rates.
            mode: optional tf.estimator.ModeKeys value. Dropout is applied
                only when this equals TRAIN. With the default of None the
                dropout layers receive `training=False`, which matches the
                behavior callers got before this parameter existed.

        Returns:
            Logits tensor of shape [cur_batch_size, 1].
        """
        func_name = "get_critic_logits"
        # The images are fed to the first conv layer unchanged.
        # shape = (cur_batch_size, height, width, depth)
        network = X
        print_obj("\n" + func_name, "network", network)

        # `tf.layers.dropout` only drops units when `training` is true, so
        # the flag has to be forwarded explicitly for dropout to have any
        # effect.
        apply_dropout = mode == tf.estimator.ModeKeys.TRAIN

        with tf.variable_scope("critic", reuse=tf.AUTO_REUSE):
            # Iteratively build downsampling layers: one conv -> leaky ReLU
            # -> dropout group per entry of `critic_num_filters`.
            for i, num_filters in enumerate(params["critic_num_filters"]):
                # Add convolutional layer with given params per layer.
                network = tf.layers.conv2d(
                    inputs=network,
                    filters=num_filters,
                    kernel_size=params["critic_kernel_sizes"][i],
                    strides=params["critic_strides"][i],
                    padding="same",
                    activation=None,
                    kernel_regularizer=self.kernel_regularizer,
                    bias_regularizer=self.bias_regularizer,
                    name="layers_conv2d_{}".format(i)
                )
                print_obj(func_name, "network", network)

                network = tf.nn.leaky_relu(
                    features=network,
                    alpha=params["critic_leaky_relu_alpha"],
                    name="leaky_relu_{}".format(i)
                )
                print_obj(func_name, "network", network)

                # Add some dropout for better regularization and stability.
                network = tf.layers.dropout(
                    inputs=network,
                    rate=params["critic_dropout_rates"][i],
                    training=apply_dropout,
                    name="layers_dropout_{}".format(i)
                )
                print_obj(func_name, "network", network)

            # Flatten network output to one vector per example.
            network_flat = tf.layers.Flatten()(inputs=network)
            print_obj(func_name, "network_flat", network_flat)

            # Final linear layer for logits.
            # shape = (cur_batch_size, 1)
            logits = tf.layers.dense(
                inputs=network_flat,
                units=1,
                activation=None,
                kernel_regularizer=self.kernel_regularizer,
                bias_regularizer=self.bias_regularizer,
                name="layers_dense_logits"
            )
            print_obj(func_name, "logits", logits)

        return logits

    def get_critic_loss(self, fake_logits, real_logits):
        """Gets critic loss.

        The base loss is `mean(fake_logits) - mean(real_logits)`; the
        regularization loss of the "critic" scope is added on top.

        Args:
            fake_logits: tensor, shape of [cur_batch_size, 1].
            real_logits: tensor, shape of [cur_batch_size, 1].

        Returns:
            Tensor of critic's total loss of shape [].
        """
        func_name = "get_critic_loss"
        # Calculate base critic loss.
        critic_real_loss = tf.reduce_mean(
            input_tensor=real_logits, name="critic_real_loss"
        )
        print_obj("\n" + func_name, "critic_real_loss", critic_real_loss)

        critic_fake_loss = tf.reduce_mean(
            input_tensor=fake_logits, name="critic_fake_loss"
        )
        print_obj(
            func_name, "critic_fake_loss", critic_fake_loss
        )

        critic_loss = tf.subtract(
            x=critic_fake_loss, y=critic_real_loss, name="critic_loss"
        )
        print_obj(func_name, "critic_loss", critic_loss)

        # Get regularization losses.
        critic_reg_loss = tf.losses.get_regularization_loss(
            scope="critic", name="critic_reg_loss"
        )
        print_obj(func_name, "critic_reg_loss", critic_reg_loss)

        # Combine losses for total losses.
        critic_total_loss = tf.math.add(
            x=critic_loss, y=critic_reg_loss, name="critic_total_loss"
        )
        print_obj(func_name, "critic_total_loss", critic_total_loss)

        # Add summaries for TensorBoard.
        tf.summary.scalar(
            name="critic_real_loss", tensor=critic_real_loss, family="losses"
        )
        tf.summary.scalar(
            name="critic_fake_loss", tensor=critic_fake_loss, family="losses"
        )
        tf.summary.scalar(
            name="critic_loss", tensor=critic_loss, family="losses"
        )
        tf.summary.scalar(
            name="critic_reg_loss", tensor=critic_reg_loss, family="losses"
        )
        tf.summary.scalar(
            name="critic_total_loss",
            tensor=critic_total_loss,
            family="total_losses"
        )

        return critic_total_loss


## train_and_eval.py

In [None]:
%%writefile wgan_module/trainer/train_and_eval.py
import tensorflow as tf

from . import image_utils
from .print_object import print_obj


def get_logits_and_losses(features, generator, critic, mode, params):
    """Runs generator and critic once and returns their logits and losses.

    Random latent vectors are drawn (one per real image in the batch),
    turned into fake images by the generator and resized to the real image
    size. The critic then scores both the fake and the real images, and
    each network's total loss is computed from those logits.

    Args:
        features: dict, feature tensors; the real images are read from the
            `image` key.
        generator: instance of generator.`Generator`.
        critic: instance of critic.`Critic`.
        mode: tf.estimator.ModeKeys with values of either TRAIN or EVAL.
        params: dict, user passed parameters.

    Returns:
        Tuple of real logits, fake logits, generator total loss and critic
            total loss.
    """
    func_name = "get_logits_and_losses"

    real_images = features["image"]
    print_obj("\n" + func_name, "real_images", real_images)

    # Batch size is read from the tensor at run time so that a partial
    # final batch is handled.
    real_images_shape = tf.shape(
        input=real_images,
        out_type=tf.int32,
        name="{}_cur_batch_size".format(func_name)
    )
    cur_batch_size = real_images_shape[0]

    # One standard-normal latent vector per example in the batch.
    latent_shape = [cur_batch_size, params["latent_size"]]
    Z = tf.random.normal(
        shape=latent_shape,
        mean=0.0,
        stddev=1.0,
        dtype=tf.float32
    )
    print_obj(func_name, "Z", Z)

    # Generate images from the noise.
    print("\nCall generator with Z = {}.".format(Z))
    generated_images = generator.get_fake_images(
        Z=Z, mode=mode, params=params
    )

    # Bring the generated images to the real images' height and width.
    fake_images = image_utils.resize_fake_images(generated_images, params)
    print_obj(func_name, "fake_images", fake_images)

    # Image summary for TensorBoard.
    summary_images = tf.reshape(
        tensor=fake_images,
        shape=[-1, params["height"], params["width"], params["depth"]]
    )
    tf.summary.image(
        name="fake_images",
        tensor=summary_images,
        max_outputs=5,
    )

    # Score the generated images with the critic.
    print("\nCall critic with fake_images = {}.".format(fake_images))
    fake_logits = critic.get_critic_logits(X=fake_images, params=params)

    # Score the real images with the critic.
    print(
        "\nCall critic with real_images = {}.".format(real_images)
    )
    real_logits = critic.get_critic_logits(X=real_images, params=params)

    # Total losses of both networks.
    generator_total_loss = generator.get_generator_loss(
        fake_logits=fake_logits
    )
    critic_total_loss = critic.get_critic_loss(
        fake_logits=fake_logits, real_logits=real_logits
    )

    results = (
        real_logits, fake_logits, generator_total_loss, critic_total_loss
    )
    return results


## train.py

In [None]:
%%writefile wgan_module/trainer/train.py
import tensorflow as tf

from .print_object import print_obj


def get_variables_and_gradients(loss, scope):
    """Collects a scope's trainable variables and their gradients wrt. loss.

    Args:
        loss: tensor, shape of [].
        scope: str, the network's name to find its variables to train.

    Returns:
        Tuple of the list of variables and the list of their gradients.
            Gradient entries that are tensors are wrapped in an identity
            op named after their variable; other entries are passed
            through unchanged.
    """
    func_name = "get_variables_and_gradients"
    debug_label = "\n{}_{}".format(func_name, scope)

    # Trainable variables that belong to the scope.
    variables = tf.trainable_variables(scope=scope)
    print_obj(debug_label, "variables", variables)

    # Gradients of the loss wrt. each variable.
    raw_gradients = tf.gradients(
        ys=loss,
        xs=variables,
        name="{}_gradients".format(scope)
    )
    print_obj(debug_label, "gradients", raw_gradients)

    # Tag each gradient tensor with its variable's name; the last two
    # characters of the variable name are dropped.
    gradients = []
    for gradient, variable in zip(raw_gradients, variables):
        if tf.is_tensor(x=gradient):
            gradient = tf.identity(
                input=gradient,
                name="{}_{}_gradients".format(func_name, variable.name[:-2])
            )
        gradients.append(gradient)
    print_obj(debug_label, "gradients", gradients)

    return variables, gradients


def create_variable_and_gradient_histogram_summaries(loss_dict, params):
    """Adds TensorBoard histograms for each network's variables and gradients.

    Args:
        loss_dict: dict, keys are scopes and values are scalar loss tensors
            for each network kind.
        params: dict, user passed parameters (not read by this function).
    """
    for scope, loss in loss_dict.items():
        variables, gradients = get_variables_and_gradients(loss, scope)

        variables_family = "{}_variables".format(scope)
        gradients_family = "{}_gradients".format(scope)

        for gradient, variable in zip(gradients, variables):
            # Summary name is the variable name minus its last two
            # characters.
            summary_name = "{}".format(variable.name[:-2])

            tf.summary.histogram(
                name=summary_name,
                values=variable,
                family=variables_family
            )

            # Only gradient entries that are tensors get a histogram.
            if not tf.is_tensor(x=gradient):
                continue
            tf.summary.histogram(
                name=summary_name,
                values=gradient,
                family=gradients_family
            )


def train_network(loss, global_step, params, scope):
    """Trains network and returns loss and train op.

    Builds the optimizer selected by `params["<scope>_optimizer"]`,
    computes the gradients of `loss` wrt. the scope's trainable variables,
    optionally clips them by global norm, applies them (incrementing
    `global_step`) and optionally clips the scope's weights afterwards.

    Args:
        loss: tensor, shape of [].
        global_step: tensor, the current training step or batch in the
            training loop.
        params: dict, user passed parameters; all keys read here are
            prefixed with `scope`.
        scope: str, the variable scope whose variables are trained.

    Returns:
        Loss tensor and training op.

    Raises:
        KeyError: if the configured optimizer name is not supported.
    """
    func_name = "train_network"
    print_obj("\n" + func_name, "scope", scope)
    # Create optimizer map.
    optimizers = {
        "Adam": tf.train.AdamOptimizer,
        "Adadelta": tf.train.AdadeltaOptimizer,
        "AdagradDA": tf.train.AdagradDAOptimizer,
        "Adagrad": tf.train.AdagradOptimizer,
        "Ftrl": tf.train.FtrlOptimizer,
        "GradientDescent": tf.train.GradientDescentOptimizer,
        "Momentum": tf.train.MomentumOptimizer,
        "ProximalAdagrad": tf.train.ProximalAdagradOptimizer,
        "ProximalGradientDescent": tf.train.ProximalGradientDescentOptimizer,
        "RMSProp": tf.train.RMSPropOptimizer
    }

    optimizer_name = params["{}_optimizer".format(scope)]
    if optimizer_name not in optimizers:
        raise KeyError(
            "Unsupported {} optimizer {!r}; expected one of {}.".format(
                scope, optimizer_name, sorted(optimizers)
            )
        )

    # Arguments shared by every optimizer, plus the extra hyperparameters
    # of the optimizers that have dedicated params.
    # NOTE(review): "AdagradDA" and "Momentum" appear to need additional
    # required constructor arguments (`global_step` / `momentum`) that are
    # not supplied here - confirm before selecting them.
    optimizer_kwargs = {
        "learning_rate": params["{}_learning_rate".format(scope)],
        "name": "{}_{}_optimizer".format(scope, optimizer_name.lower())
    }
    if optimizer_name == "Adam":
        optimizer_kwargs.update(
            beta1=params["{}_adam_beta1".format(scope)],
            beta2=params["{}_adam_beta2".format(scope)],
            epsilon=params["{}_adam_epsilon".format(scope)]
        )
    elif optimizer_name == "RMSProp":
        optimizer_kwargs.update(
            decay=params["{}_rmsprop_decay".format(scope)],
            momentum=params["{}_rmsprop_momentum".format(scope)],
            epsilon=params["{}_rmsprop_epsilon".format(scope)]
        )

    # Get optimizer and instantiate it.
    optimizer = optimizers[optimizer_name](**optimizer_kwargs)
    print_obj("{}_{}".format(func_name, scope), "optimizer", optimizer)

    # Look the scope's trainable variables up once and reuse the list for
    # gradients, the update and weight clipping.
    variables = tf.trainable_variables(scope=scope)

    # Get gradients.
    gradients = tf.gradients(
        ys=loss,
        xs=variables,
        name="{}_gradients".format(scope)
    )
    print_obj("\n{}_{}".format(func_name, scope), "gradients", gradients)

    # Clip gradients by global norm when a clip norm is configured.
    clip_norm = params["{}_clip_gradients".format(scope)]
    if clip_norm:
        gradients, _ = tf.clip_by_global_norm(
            t_list=gradients,
            clip_norm=clip_norm,
            name="{}_clip_by_global_norm_gradients".format(scope)
        )
        print_obj("\n{}_{}".format(func_name, scope), "gradients", gradients)

    # Pair gradients with variables. A list (rather than a one-shot `zip`
    # iterator) prints the actual pairs and can be iterated repeatedly.
    grads_and_vars = list(zip(gradients, variables))
    print_obj(
        "{}_{}".format(func_name, scope), "grads_and_vars", grads_and_vars
    )

    # Create train op by applying gradients to variables and incrementing
    # global step.
    train_op = optimizer.apply_gradients(
        grads_and_vars=grads_and_vars,
        global_step=global_step,
        name="{}_apply_gradients".format(scope)
    )

    # Clip weights to [min, max] when a range is configured. The control
    # dependency makes the clipping run after the gradient update.
    clip_weights = params["{}_clip_weights".format(scope)]
    if clip_weights:
        with tf.control_dependencies(control_inputs=[train_op]):
            clip_value_min = clip_weights[0]
            clip_value_max = clip_weights[1]

            train_op = tf.group(
                [
                    tf.assign(
                        ref=v,
                        value=tf.clip_by_value(
                            t=v,
                            clip_value_min=clip_value_min,
                            clip_value_max=clip_value_max
                        )
                    )
                    for v in variables
                ],
                name="{}_clip_by_value_weights".format(scope)
            )

    return loss, train_op


def get_loss_and_train_op(
        generator_total_loss, critic_total_loss, params):
    """Chooses which network to train this step and builds its train op.

    Training alternates in cycles of
    `critic_train_steps + generator_train_steps` global steps: while the
    position within the cycle is below `critic_train_steps` the critic is
    trained, otherwise the generator is.

    Args:
        generator_total_loss: tensor, scalar total loss of generator.
        critic_total_loss: tensor, scalar total loss of critic.
        params: dict, user passed parameters.

    Returns:
        Loss scalar tensor and train_op to be used by the EstimatorSpec.
    """
    func_name = "get_loss_and_train_op"

    global_step = tf.train.get_or_create_global_step()

    # Length of one critic + generator training cycle, as int64 to match
    # the global step in the modulo below.
    steps_per_cycle = tf.cast(
        x=tf.add(
            x=params["critic_train_steps"],
            y=params["generator_train_steps"]
        ),
        dtype=tf.int64
    )

    # Position of the current step within its cycle.
    cycle_step = tf.mod(
        x=global_step,
        y=steps_per_cycle,
        name="{}_cycle_step".format(func_name)
    )

    # True while the critic's share of the cycle is running.
    is_critic_turn = tf.less(
        x=cycle_step, y=params["critic_train_steps"]
    )

    def _train_critic():
        """Builds the critic's training subgraph."""
        return train_network(
            loss=critic_total_loss,
            global_step=global_step,
            params=params,
            scope="critic"
        )

    def _train_generator():
        """Builds the generator's training subgraph."""
        return train_network(
            loss=generator_total_loss,
            global_step=global_step,
            params=params,
            scope="generator"
        )

    # Update ops (needed for batch normalization) must run with training.
    update_ops = tf.get_collection(key=tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(control_inputs=update_ops):
        # Conditionally train either the critic or the generator subgraph.
        loss, train_op = tf.cond(
            pred=is_critic_turn,
            true_fn=_train_critic,
            false_fn=_train_generator
        )

    return loss, train_op


## eval_metrics.py

In [None]:
%%writefile wgan_module/trainer/eval_metrics.py
import tensorflow as tf

from .print_object import print_obj


def get_eval_metric_ops(fake_logits, real_logits, params):
    """Gets eval metric ops.

    Real images are labelled 1 and fake images 0. Critic logits are squashed
    with a sigmoid; the resulting scores feed the AUC metrics directly, while
    accuracy, precision and recall use the scores thresholded at 0.5.

    Args:
        fake_logits: tensor, shape of [cur_batch_size, 1] that came from
            critic having processed generator's output image.
        real_logits: tensor, shape of [cur_batch_size, 1] that came from
            critic having processed real image.
        params: dict, user passed parameters (currently unused here).

    Returns:
        Dictionary of eval metric ops.
    """
    func_name = "get_eval_metric_ops"
    # Concatenate critic logits and labels.
    critic_logits = tf.concat(
        values=[real_logits, fake_logits],
        axis=0,
        name="critic_concat_logits"
    )
    print_obj("\n" + func_name, "critic_logits", critic_logits)

    critic_labels = tf.concat(
        values=[
            tf.ones_like(tensor=real_logits),
            tf.zeros_like(tensor=fake_logits)
        ],
        axis=0,
        name="critic_concat_labels"
    )
    print_obj(func_name, "critic_labels", critic_labels)

    # Calculate critic probabilities.
    critic_probabilities = tf.nn.sigmoid(
        x=critic_logits, name="critic_probabilities"
    )
    print_obj(
        func_name, "critic_probabilities", critic_probabilities
    )

    # Threshold probabilities into hard 0/1 class predictions.
    # tf.metrics.accuracy tests exact equality with the labels and
    # tf.metrics.precision/recall cast predictions to bool, so feeding raw
    # probabilities would give near-zero accuracy and treat every example
    # as predicted-positive.
    critic_predictions = tf.cast(
        x=tf.greater_equal(x=critic_probabilities, y=0.5),
        dtype=critic_labels.dtype,
        name="critic_predictions"
    )
    print_obj(func_name, "critic_predictions", critic_predictions)

    # Create eval metric ops dictionary.
    eval_metric_ops = {
        "accuracy": tf.metrics.accuracy(
            labels=critic_labels,
            predictions=critic_predictions,
            name="critic_accuracy"
        ),
        "precision": tf.metrics.precision(
            labels=critic_labels,
            predictions=critic_predictions,
            name="critic_precision"
        ),
        "recall": tf.metrics.recall(
            labels=critic_labels,
            predictions=critic_predictions,
            name="critic_recall"
        ),
        # AUC sweeps its own thresholds, so it takes the raw probabilities.
        "auc_roc": tf.metrics.auc(
            labels=critic_labels,
            predictions=critic_probabilities,
            num_thresholds=200,
            curve="ROC",
            name="critic_auc_roc"
        ),
        "auc_pr": tf.metrics.auc(
            labels=critic_labels,
            predictions=critic_probabilities,
            num_thresholds=200,
            curve="PR",
            name="critic_auc_pr"
        )
    }
    print_obj(func_name, "eval_metric_ops", eval_metric_ops)

    return eval_metric_ops


## predict.py

In [None]:
%%writefile wgan_module/trainer/predict.py
import tensorflow as tf

from . import image_utils
from .print_object import print_obj


def get_predictions_and_export_outputs(features, generator, params):
    """Builds the prediction tensors and serving export outputs.

    The latent vectors under the "Z" feature key are passed through the
    generator in PREDICT mode and the resulting images are resized via
    `image_utils.resize_fake_images`.

    Args:
        features: dict, feature tensors from serving input function.
        generator: instance of `Generator`.
        params: dict, user passed parameters.

    Returns:
        Predictions dictionary and export outputs dictionary.
    """
    func_name = "get_predictions_and_export_outputs"

    # Latent vectors supplied by the caller.
    latent_vectors = features["Z"]
    print_obj("\n" + func_name, "Z", latent_vectors)

    # Run the generator on the latent vectors.
    raw_images = generator.get_fake_images(
        Z=latent_vectors,
        mode=tf.estimator.ModeKeys.PREDICT,
        params=params
    )
    print_obj(func_name, "generated_images", raw_images)

    # Bring generator output to the configured real image size.
    resized_images = image_utils.resize_fake_images(
        fake_images=raw_images, params=params
    )
    print_obj(func_name, "generated_images", resized_images)

    # The same dictionary serves as predictions and as the export payload.
    predictions_dict = {"generated_images": resized_images}
    print_obj(func_name, "predictions_dict", predictions_dict)

    predict_output = tf.estimator.export.PredictOutput(
        outputs=predictions_dict
    )
    export_outputs = {"predict_export_outputs": predict_output}
    print_obj(func_name, "export_outputs", export_outputs)

    return predictions_dict, export_outputs


## wgan.py

In [None]:
%%writefile wgan_module/trainer/wgan.py
import tensorflow as tf

from . import critic
from . import eval_metrics
from . import generator
from . import predict
from . import train
from . import train_and_eval
from .print_object import print_obj


def wgan_model(features, labels, mode, params):
    """Wasserstein GAN custom Estimator model function.

    Instantiates the generator and critic, then fills in the EstimatorSpec
    fields relevant to the requested mode:
        PREDICT: predictions and export outputs.
        TRAIN: loss and train op (plus histogram summaries).
        EVAL: critic total loss as the loss, and eval metric ops.

    Args:
        features: dict, keys are feature names and values are feature tensors.
        labels: tensor, label data.
        mode: tf.estimator.ModeKeys with values of either TRAIN, EVAL, or
            PREDICT.
        params: dict, user passed parameters.

    Returns:
        Instance of `tf.estimator.EstimatorSpec` class.
    """
    func_name = "wgan_model"
    print_obj("\n" + func_name, "features", features)
    print_obj(func_name, "labels", labels)
    print_obj(func_name, "mode", mode)
    print_obj(func_name, "params", params)

    # Every EstimatorSpec field starts unset; each mode fills in its own.
    spec_fields = {
        "predictions": None,
        "loss": None,
        "train_op": None,
        "eval_metric_ops": None,
        "export_outputs": None
    }

    # Instantiate generator.
    generator_regularizer = tf.contrib.layers.l1_l2_regularizer(
        scale_l1=params["generator_l1_regularization_scale"],
        scale_l2=params["generator_l2_regularization_scale"]
    )
    wgan_generator = generator.Generator(
        kernel_regularizer=generator_regularizer,
        bias_regularizer=None,
        name="generator"
    )

    # Instantiate critic.
    critic_regularizer = tf.contrib.layers.l1_l2_regularizer(
        scale_l1=params["critic_l1_regularization_scale"],
        scale_l2=params["critic_l2_regularization_scale"]
    )
    wgan_critic = critic.Critic(
        kernel_regularizer=critic_regularizer,
        bias_regularizer=None,
        name="critic"
    )

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Only the generator is needed to serve predictions.
        predictions_dict, export_outputs = (
            predict.get_predictions_and_export_outputs(
                features=features, generator=wgan_generator, params=params
            )
        )
        spec_fields["predictions"] = predictions_dict
        spec_fields["export_outputs"] = export_outputs
    else:
        # Train and eval share the logits and losses subgraph.
        logits_and_losses = train_and_eval.get_logits_and_losses(
            features=features,
            generator=wgan_generator,
            critic=wgan_critic,
            mode=mode,
            params=params
        )
        (real_logits,
         fake_logits,
         generator_total_loss,
         critic_total_loss) = logits_and_losses

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Create variable and gradient histogram summaries.
            train.create_variable_and_gradient_histogram_summaries(
                loss_dict={
                    "generator": generator_total_loss,
                    "critic": critic_total_loss
                },
                params=params
            )

            # Get loss and train op for EstimatorSpec.
            loss, train_op = train.get_loss_and_train_op(
                generator_total_loss=generator_total_loss,
                critic_total_loss=critic_total_loss,
                params=params
            )
            spec_fields["loss"] = loss
            spec_fields["train_op"] = train_op
        else:
            # Eval reports the critic's total loss.
            spec_fields["loss"] = critic_total_loss
            spec_fields["eval_metric_ops"] = eval_metrics.get_eval_metric_ops(
                real_logits=real_logits,
                fake_logits=fake_logits,
                params=params
            )

    return tf.estimator.EstimatorSpec(mode=mode, **spec_fields)


## serving.py

In [None]:
%%writefile wgan_module/trainer/serving.py
import tensorflow as tf

from .print_object import print_obj


def serving_input_fn(params):
    """Serving input function.

    Exposes a single float32 input "Z" with shape
    [None, params["latent_size"]].

    Args:
        params: dict, user passed parameters.

    Returns:
        ServingInputReceiver object containing features and receiver tensors.
    """
    func_name = "serving_input_fn"

    # Placeholder that accepts the latent vectors sent at serving time.
    latent_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, params["latent_size"]],
        name="serving_input_placeholder_Z"
    )
    feature_placeholders = {"Z": latent_placeholder}
    print_obj("\n" + func_name, "feature_placeholders", feature_placeholders)

    # Create clones of the feature placeholder tensors so that the SavedModel
    # SignatureDef will point to the placeholder.
    features = {}
    for key, placeholder in feature_placeholders.items():
        identity_name = "{}_identity_placeholder_{}".format(func_name, key)
        features[key] = tf.identity(input=placeholder, name=identity_name)
    print_obj(func_name, "features", features)

    return tf.estimator.export.ServingInputReceiver(
        features=features, receiver_tensors=feature_placeholders
    )


## model.py

In [None]:
%%writefile wgan_module/trainer/model.py
import tensorflow as tf

from . import input
from . import serving
from . import wgan
from .print_object import print_obj


def train_and_evaluate(args):
    """Trains and evaluates custom Estimator model.

    Builds the RunConfig, Estimator, train/eval specs and exporter from the
    passed arguments, then runs `tf.estimator.train_and_evaluate`.

    Args:
        args: dict, user passed parameters.

    Returns:
        `Estimator` object.
    """
    func_name = "train_and_evaluate"
    print_obj("\n" + func_name, "args", args)

    # Ensure filewriter cache is clear for TensorBoard events file.
    tf.summary.FileWriterCache.clear()

    # Set logging to be level of INFO.
    tf.logging.set_verbosity(tf.logging.INFO)

    output_dir = args["output_dir"]

    # Checkpoint and summary cadence for the Estimator.
    run_config = tf.estimator.RunConfig(
        model_dir=output_dir,
        save_summary_steps=args["save_summary_steps"],
        save_checkpoints_steps=args["save_checkpoints_steps"],
        keep_checkpoint_max=args["keep_checkpoint_max"]
    )

    # Custom estimator driven by the WGAN model function.
    estimator = tf.estimator.Estimator(
        model_fn=wgan.wgan_model,
        model_dir=output_dir,
        config=run_config,
        params=args
    )

    # Training data and step budget.
    train_input_fn = input.read_dataset(
        filename=args["train_file_pattern"],
        mode=tf.estimator.ModeKeys.TRAIN,
        batch_size=args["train_batch_size"],
        params=args
    )
    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn, max_steps=args["train_steps"]
    )

    def _serving_input_receiver_fn():
        """Binds the user arguments to the serving input function."""
        return serving.serving_input_fn(args)

    # Exporter to save out the complete model to disk.
    exporter = tf.estimator.LatestExporter(
        name="exporter",
        serving_input_receiver_fn=_serving_input_receiver_fn
    )

    # Validation data, evaluation schedule, and model export.
    eval_input_fn = input.read_dataset(
        filename=args["eval_file_pattern"],
        mode=tf.estimator.ModeKeys.EVAL,
        batch_size=args["eval_batch_size"],
        params=args
    )
    eval_spec = tf.estimator.EvalSpec(
        input_fn=eval_input_fn,
        steps=args["eval_steps"],
        start_delay_secs=args["start_delay_secs"],
        throttle_secs=args["throttle_secs"],
        exporters=exporter
    )

    # Run the train and evaluate loop.
    tf.estimator.train_and_evaluate(
        estimator=estimator, train_spec=train_spec, eval_spec=eval_spec
    )

    return estimator


## task.py

In [None]:
%%writefile wgan_module/trainer/task.py
import argparse
import json
import os

from . import model


def convert_string_to_bool(string):
    """Converts string to bool.

    Only the word "false" (in any letter case) maps to False; every other
    string, including the empty string, maps to True.

    Args:
        string: str, string to convert.

    Returns:
        Boolean conversion of string.
    """
    return string.lower() != "false"


def convert_string_to_none_or_float(string):
    """Converts string to None or float.

    Args:
        string: str, string to convert; "none" in any letter case means None.

    Returns:
        None or float conversion of string.
    """
    if string.lower() == "none":
        return None
    return float(string)


def convert_string_to_none_or_int(string):
    """Converts string to None or int.

    Args:
        string: str, string to convert; "none" in any letter case means None.

    Returns:
        None or int conversion of string.
    """
    if string.lower() == "none":
        return None
    return int(string)


def convert_string_to_list_of_ints(string, sep):
    """Converts string to list of ints.

    Args:
        string: str, string to convert.
        sep: str, separator string.

    Returns:
        List of ints conversion of string; empty list for a falsy string.
    """
    pieces = string.split(sep) if string else []
    return list(map(int, pieces))


def convert_string_to_list_of_floats(string, sep):
    """Converts string to list of floats.

    Args:
        string: str, string to convert. An empty string or the word "none"
            (any letter case, surrounding whitespace ignored) means no values.
        sep: str, separator string.

    Returns:
        List of floats conversion of string; empty list when there are no
        values.
    """
    # The command line defaults for the clip_weights flags are the literal
    # string "None"; without this guard float("None") raises ValueError.
    if not string or string.strip().lower() == "none":
        return []
    return [float(x) for x in string.split(sep)]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # File arguments: (flag, help). All of these are required.
    required_flags = (
        ("--train_file_pattern", "GCS location to read training data."),
        ("--eval_file_pattern", "GCS location to read evaluation data."),
        (
            "--output_dir",
            "GCS location to write checkpoints and export models."
        ),
    )
    for flag, flag_help in required_flags:
        parser.add_argument(flag, help=flag_help, required=True)

    parser.add_argument(
        "--job-dir",
        help="This model ignores this field, but it is required by gcloud.",
        default="junk"
    )

    # Optional arguments: (flag, help, type, default).
    # Flags typed as str that hold None/bool/list values are converted after
    # parsing, further below.
    optional_flags = (
        # Training parameters.
        ("--train_batch_size", "Number of examples in training batch.", int, 32),
        ("--train_steps", "Number of steps to train for.", int, 100),
        ("--save_summary_steps", "How many steps to train before saving a summary.", int, 100),
        ("--save_checkpoints_steps", "How many steps to train before saving a checkpoint.", int, 100),
        ("--keep_checkpoint_max", "Max number of checkpoints to keep.", int, 100),
        ("--input_fn_autotune", "Whether to autotune input function performance.", str, "True"),

        # Eval parameters.
        ("--eval_batch_size", "Number of examples in evaluation batch.", int, 32),
        ("--eval_steps", "Number of steps to evaluate for.", str, "None"),
        ("--start_delay_secs", "Number of seconds to wait before first evaluation.", int, 60),
        ("--throttle_secs", "Number of seconds to wait between evaluations.", int, 120),

        # Image parameters.
        ("--height", "Height of image.", int, 32),
        ("--width", "Width of image.", int, 32),
        ("--depth", "Depth of image.", int, 3),

        # Generator parameters.
        ("--latent_size", "The latent size of the noise vector.", int, 3),
        ("--generator_projection_dims", "The 3D dimensions to project latent noise vector into.", str, "8,8,256"),
        ("--generator_num_filters", "Number of filters for generator conv layers.", str, "128, 64"),
        ("--generator_kernel_sizes", "Kernel sizes for generator conv layers.", str, "5,5"),
        ("--generator_strides", "Strides for generator conv layers.", str, "1,2"),
        ("--generator_final_num_filters", "Number of filters for final generator conv layer.", int, 3),
        ("--generator_final_kernel_size", "Kernel sizes for final generator conv layer.", int, 5),
        ("--generator_final_stride", "Strides for final generator conv layer.", int, 2),
        ("--generator_leaky_relu_alpha", "The amount of leakyness of generator's leaky relus.", float, 0.2),
        ("--generator_final_activation", "The final activation function of generator.", str, "None"),
        ("--generator_l1_regularization_scale", "Scale factor for L1 regularization for generator.", float, 0.0),
        ("--generator_l2_regularization_scale", "Scale factor for L2 regularization for generator.", float, 0.0),
        ("--generator_optimizer", "Name of optimizer to use for generator.", str, "Adam"),
        ("--generator_learning_rate", "How quickly we train our model by scaling the gradient for generator.", float, 0.1),
        ("--generator_adam_beta1", "Adam optimizer's beta1 hyperparameter for first moment.", float, 0.9),
        ("--generator_adam_beta2", "Adam optimizer's beta2 hyperparameter for second moment.", float, 0.999),
        ("--generator_adam_epsilon", "Adam optimizer's epsilon hyperparameter for numerical stability.", float, 1e-8),
        ("--generator_rmsprop_decay", "RMSProp optimizer's decay hyperparameter for discounting factor for the history/coming gradient.", float, 0.9),
        ("--generator_rmsprop_momentum", "RMSProp optimizer's momentum hyperparameter for first moment.", float, 0.999),
        ("--generator_rmsprop_epsilon", "RMSProp optimizer's epsilon hyperparameter for numerical stability.", float, 1e-8),
        ("--generator_clip_gradients", "Global clipping to prevent gradient norm to exceed this value for generator.", str, "None"),
        ("--generator_clip_weights", "Clip weights within this range for generator.", str, "None"),
        ("--generator_train_steps", "Number of steps to train generator for per cycle.", int, 100),

        # Critic parameters.
        ("--critic_num_filters", "Number of filters for critic conv layers.", str, "64, 128"),
        ("--critic_kernel_sizes", "Kernel sizes for critic conv layers.", str, "5,5"),
        ("--critic_strides", "Strides for critic conv layers.", str, "1,2"),
        ("--critic_dropout_rates", "Dropout rates for critic dropout layers.", str, "0.3,0.3"),
        ("--critic_leaky_relu_alpha", "The amount of leakyness of critic's leaky relus.", float, 0.2),
        ("--critic_l1_regularization_scale", "Scale factor for L1 regularization for critic.", float, 0.0),
        ("--critic_l2_regularization_scale", "Scale factor for L2 regularization for critic.", float, 0.0),
        ("--critic_optimizer", "Name of optimizer to use for critic.", str, "Adam"),
        ("--critic_learning_rate", "How quickly we train our model by scaling the gradient for critic.", float, 0.1),
        ("--critic_adam_beta1", "Adam optimizer's beta1 hyperparameter for first moment.", float, 0.9),
        ("--critic_adam_beta2", "Adam optimizer's beta2 hyperparameter for second moment.", float, 0.999),
        ("--critic_adam_epsilon", "Adam optimizer's epsilon hyperparameter for numerical stability.", float, 1e-8),
        ("--critic_rmsprop_decay", "RMSProp optimizer's decay hyperparameter for discounting factor for the history/coming gradient.", float, 0.9),
        ("--critic_rmsprop_momentum", "RMSProp optimizer's momentum hyperparameter for first moment.", float, 0.999),
        ("--critic_rmsprop_epsilon", "RMSProp optimizer's epsilon hyperparameter for numerical stability.", float, 1e-8),
        ("--critic_clip_gradients", "Global clipping to prevent gradient norm to exceed this value for critic.", str, "None"),
        ("--critic_clip_weights", "Clip weights within this range for critic.", str, "None"),
        ("--critic_train_steps", "Number of steps to train critic for per cycle.", int, 100),
    )
    for flag, flag_help, flag_type, flag_default in optional_flags:
        parser.add_argument(
            flag, help=flag_help, type=flag_type, default=flag_default
        )

    # Parse all arguments into a plain dictionary.
    args = parser.parse_args()
    arguments = args.__dict__

    # Unused args provided by service.
    for unused_key in ("job_dir", "job-dir"):
        arguments.pop(unused_key, None)

    # String flag -> bool.
    arguments["input_fn_autotune"] = convert_string_to_bool(
        string=arguments["input_fn_autotune"]
    )

    # String flag -> None or int.
    arguments["eval_steps"] = convert_string_to_none_or_int(
        string=arguments["eval_steps"]
    )

    # Comma separated string flags -> lists of ints.
    int_list_keys = (
        "generator_projection_dims",
        "generator_num_filters",
        "critic_num_filters",
        "generator_kernel_sizes",
        "critic_kernel_sizes",
        "generator_strides",
        "critic_strides",
    )
    for key in int_list_keys:
        arguments[key] = convert_string_to_list_of_ints(
            string=arguments[key], sep=","
        )

    # Comma separated string flag -> list of floats.
    arguments["critic_dropout_rates"] = convert_string_to_list_of_floats(
        string=arguments["critic_dropout_rates"], sep=","
    )

    # String flags -> None or float.
    for key in ("generator_clip_gradients", "critic_clip_gradients"):
        arguments[key] = convert_string_to_none_or_float(
            string=arguments[key]
        )

    # Comma separated string flags -> lists of floats.
    for key in ("generator_clip_weights", "critic_clip_weights"):
        arguments[key] = convert_string_to_list_of_floats(
            string=arguments[key], sep=","
        )

    # Append trial_id to path if we are doing hptuning.
    # This code can be removed if you are not using hyperparameter tuning.
    tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
    trial_id = tf_config.get("task", {}).get("trial", "")
    arguments["output_dir"] = os.path.join(arguments["output_dir"], trial_id)

    # Run the training job.
    model.train_and_evaluate(arguments)
