In [1]:
# Import libraries and modules
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import shutil
import tensorflow as tf
print(tf.__version__)
print(np.__version__)
# np.set_printoptions(threshold=np.inf)

1.15.3-dlenv_tfe
1.18.5


# Local Development

## Arguments

In [2]:
def calc_generator_discriminator_conv_layer_properties(
        conv_num_filters, conv_kernel_sizes, conv_strides, depth):
    """Builds conv layer property lists for generator and discriminator.

    Every conv layer is described by a flat list laid out as
    [kernel, kernel, in_channels, out_channels, stride, stride], and
    layers are grouped into one list per block.

    Args:
        conv_num_filters: list, nested list of ints of the number of
            filters for each conv layer, one inner list per block.
        conv_kernel_sizes: list, nested list of ints of the kernel sizes
            for each conv layer, one inner list per block.
        conv_strides: list, nested list of ints of the strides for each
            conv layer, one inner list per block.
        depth: int, depth dimension of images.

    Returns:
        Tuple of nested lists of conv layer properties, generator first
            and discriminator second.
    """
    def build_generator_blocks():
        """Assembles the generator's blocks from the enclosing arguments.

        Returns:
            Nested list of conv layer properties for generator, with a
                1x1 toRGB conv appended to the last block.
        """
        # Base block: every layer reuses the whole first filter list as
        # its in/out channel entries.
        base_in_out = conv_num_filters[0]
        base_block = []
        for layer_idx in range(len(conv_num_filters[0])):
            kernel = conv_kernel_sizes[0][layer_idx]
            stride = conv_strides[0][layer_idx]
            base_block.append(
                [kernel, kernel] + base_in_out + [stride, stride]
            )
        all_blocks = [base_block]

        # Growth blocks: each layer's input channels are the output
        # channels (third from last entry) of the layer before it.
        for block_idx in range(1, len(conv_num_filters)):
            in_channels = all_blocks[block_idx - 1][-1][-3]
            growth_block = []
            for layer_idx in range(len(conv_num_filters[block_idx])):
                out_channels = conv_num_filters[block_idx][layer_idx]
                kernel = conv_kernel_sizes[block_idx][layer_idx]
                stride = conv_strides[block_idx][layer_idx]
                growth_block.append(
                    [kernel, kernel, in_channels, out_channels,
                     stride, stride]
                )
                in_channels = growth_block[-1][-3]
            all_blocks.append(growth_block)

        # Final 1x1 toRGB conv maps the last layer's channels to depth.
        last_block = all_blocks[-1]
        last_block.append([1, 1, last_block[-1][-3], depth, 1, 1])

        return all_blocks

    def mirror_for_discriminator(generator_blocks):
        """Derives discriminator properties by mirroring the generator.

        Args:
            generator_blocks: list, nested list of conv layer properties
                for generator.

        Returns:
            Nested list of conv layer properties for discriminator.
        """
        mirrored_blocks = []
        # Walk blocks and layers back to front, swapping each layer's
        # in/out channel entries while keeping kernel and stride entries.
        for generator_block in reversed(generator_blocks):
            mirrored_blocks.append(
                [
                    layer[0:2] + layer[2:4][::-1] + layer[-2:]
                    for layer in reversed(generator_block)
                ]
            )

        return mirrored_blocks

    generator = build_generator_blocks()
    discriminator = mirror_for_discriminator(generator)

    return generator, discriminator


def split_up_generator_conv_layer_properties(
        generator, num_filters, strides, depth):
    """Splits up generator conv layer properties into lists.

    Args:
        generator: list, nested list of conv layer properties for
            generator. The last block is expected to end with the toRGB
            conv.
        num_filters: list, nested list of ints of the number of filters
            for each conv layer.
        strides: list, nested list of ints of the strides for each conv
            layer.
        depth: int, depth dimension of images.

    Returns:
        Tuple of nested lists: base conv blocks, growth conv blocks, and
            one toRGB layer list per block.
    """
    # Keep only the convs listed in num_filters[0]; this drops the trailing
    # toRGB conv when the base block is also the last block.
    generator_base_conv_blocks = [generator[0][0:len(num_filters[0])]]

    # Growth blocks are every block after the base, with the toRGB conv
    # stripped from the last one.
    generator_growth_conv_blocks = []
    if len(num_filters) > 1:
        generator_growth_conv_blocks = generator[1:-1] + [generator[-1][:-1]]

    # A toRGB conv consumes the output of its block's LAST conv, so its
    # input channel count is num_filters[i][-1]. This matches the toRGB
    # conv stored at the end of `generator` when filter counts differ
    # within a block.
    generator_to_rgb_layers = [
        [[1] * 2 + [num_filters[i][-1]] + [depth] + [strides[i][0]] * 2]
        for i in range(len(num_filters))
    ]

    return (generator_base_conv_blocks,
            generator_growth_conv_blocks,
            generator_to_rgb_layers)


def split_up_discriminator_conv_layer_properties(
        discriminator, num_filters, strides, depth):
    """Splits up discriminator conv layer properties into lists.

    Args:
        discriminator: list, nested list of conv layer properties for
            discriminator. The first block is expected to start with the
            fromRGB conv.
        num_filters: list, nested list of ints of the number of filters
            for each conv layer.
        strides: list, nested list of ints of the strides for each conv
            layer.
        depth: int, depth dimension of images.

    Returns:
        Tuple of nested lists: one fromRGB layer list per block, base conv
            blocks, and growth conv blocks.
    """
    # A fromRGB conv feeds the first conv of its discriminator block. That
    # conv mirrors the LAST generator conv of the block, so fromRGB must
    # output num_filters[i][-1] channels. This matches the fromRGB conv
    # stored at the front of `discriminator` when filter counts differ
    # within a block.
    discriminator_from_rgb_layers = [
        [[1] * 2 + [depth] + [num_filters[i][-1]] + [strides[i][0]] * 2]
        for i in range(len(num_filters))
    ]

    if len(num_filters) > 1:
        discriminator_base_conv_blocks = [discriminator[-1]]
    else:
        # With a single block the base block also holds the fromRGB conv
        # at its front, so strip it.
        discriminator_base_conv_blocks = [discriminator[-1][1:]]

    discriminator_growth_conv_blocks = []
    if len(num_filters) > 1:
        # Drop the fromRGB conv from the first block, exclude the base
        # block, then reverse the block order.
        discriminator_growth_conv_blocks = (
            [discriminator[0][1:]] + discriminator[1:-1]
        )[::-1]

    return (discriminator_from_rgb_layers,
            discriminator_base_conv_blocks,
            discriminator_growth_conv_blocks)


In [3]:
# Create arguments dictionary to hold all user passed parameters.
arguments = {}

# File arguments.
arguments["train_file_pattern"] = "gs://machine-learning-1234-bucket/gan/data/cifar10_car/train*.tfrecord"
arguments["eval_file_pattern"] = "gs://machine-learning-1234-bucket/gan/data/cifar10_car/test*.tfrecord"
arguments["output_dir"] = "gs://machine-learning-1234-bucket/gan/pgan/trained_model_local_cifar10_car"

# Training parameters.
arguments["dataset"] = "cifar10"
arguments["train_batch_size"] = 32
arguments["train_steps"] = 59500
arguments["use_tpu"] = False
arguments["use_estimator_train_and_evaluate"] = False
arguments["growth_idx"] = 0
arguments["previous_train_steps"] = 0
arguments["save_optimizer_metrics_to_checkpoint"] = True
arguments["save_summary_steps"] = 100
arguments["save_checkpoints_steps"] = 10000
arguments["keep_checkpoint_max"] = 100
arguments["input_fn_autotune"] = True

# Eval parameters.
arguments["eval_batch_size"] = 1
arguments["eval_steps"] = 1
arguments["start_delay_secs"] = 6000000
arguments["throttle_secs"] = 6000000
arguments["eval_on_tpu"] = True

# Serving parameters.
arguments["exports_to_keep"] = 20
arguments["export_to_tpu"] = False
arguments["export_to_cpu"] = True
arguments["predict_all_resolutions"] = True

# Image parameters.
arguments["height"] = 32
arguments["width"] = 32
arguments["depth"] = 3

# Shared parameters.
arguments["num_steps_until_growth"] = 8500
arguments["use_equalized_learning_rate"] = True

# Full lists for full 1024x1024 network growth.
full_conv_num_filters = [[512, 512], [512, 512], [512, 512], [512, 512], [256, 256], [128, 128], [64, 64], [32, 32], [16, 16]]
full_conv_kernel_sizes = [[4, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
full_conv_strides = [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]

# Set final image size as a power of 2, starting at 4.
image_size = 32

# One block per doubling from 4x4 up to image_size, clamped to between 1
# and the number of blocks the full lists define. math.log2 is exact for
# powers of two, unlike math.log(x, 2).
prop_list_len = max(
    min(int(math.log2(image_size) - 1), len(full_conv_num_filters)), 1
)

# Get slices of lists.
conv_num_filters = full_conv_num_filters[0:prop_list_len]
print("conv_num_filters = {}".format(conv_num_filters))
conv_kernel_sizes = full_conv_kernel_sizes[0:prop_list_len]
print("conv_kernel_sizes = {}".format(conv_kernel_sizes))
conv_strides = full_conv_strides[0:prop_list_len]
print("conv_strides = {}".format(conv_strides))

arguments["conv_num_filters"] = conv_num_filters
arguments["conv_kernel_sizes"] = conv_kernel_sizes
arguments["conv_strides"] = conv_strides

# Truncate lists if over the 1024x1024 current limit, i.e. more blocks
# than the full lists above define.
max_num_blocks = len(full_conv_num_filters)
if len(arguments["conv_num_filters"]) > max_num_blocks:
    arguments["conv_num_filters"] = arguments["conv_num_filters"][0:max_num_blocks]
    arguments["conv_kernel_sizes"] = arguments["conv_kernel_sizes"][0:max_num_blocks]
    arguments["conv_strides"] = arguments["conv_strides"][0:max_num_blocks]

# Get conv layer properties for generator and discriminator.
(generator,
 discriminator) = calc_generator_discriminator_conv_layer_properties(
    arguments["conv_num_filters"],
    arguments["conv_kernel_sizes"],
    arguments["conv_strides"],
    arguments["depth"]
)

# Split up generator properties into separate lists.
(generator_base_conv_blocks,
 generator_growth_conv_blocks,
 generator_to_rgb_layers) = split_up_generator_conv_layer_properties(
    generator,
    arguments["conv_num_filters"],
    arguments["conv_strides"],
    arguments["depth"]
)
arguments["generator_base_conv_blocks"] = generator_base_conv_blocks
arguments["generator_growth_conv_blocks"] = generator_growth_conv_blocks
arguments["generator_to_rgb_layers"] = generator_to_rgb_layers

# Split up discriminator properties into separate lists.
(discriminator_from_rgb_layers,
 discriminator_base_conv_blocks,
 discriminator_growth_conv_blocks) = split_up_discriminator_conv_layer_properties(
    discriminator,
    arguments["conv_num_filters"],
    arguments["conv_strides"],
    arguments["depth"]
)
arguments["discriminator_from_rgb_layers"] = discriminator_from_rgb_layers
arguments["discriminator_base_conv_blocks"] = discriminator_base_conv_blocks
arguments["discriminator_growth_conv_blocks"] = discriminator_growth_conv_blocks

# Generator parameters.
arguments["latent_size"] = 512
arguments["normalize_latent"] = True
arguments["use_pixel_norm"] = True
arguments["pixel_norm_epsilon"] = 1e-8
arguments["generator_projection_dims"] = [4, 4, 512]
arguments["generator_leaky_relu_alpha"] = 0.2
arguments["generator_to_rgb_activation"] = "None"
arguments["generator_l1_regularization_scale"] = 0.
arguments["generator_l2_regularization_scale"] = 0.
arguments["generator_optimizer"] = "Adam"
arguments["generator_learning_rate"] = 0.001
arguments["generator_adam_beta1"] = 0.
arguments["generator_adam_beta2"] = 0.99
arguments["generator_adam_epsilon"] = 1e-8
arguments["generator_clip_gradients"] = None
arguments["generator_train_steps"] = 1

# Discriminator hyperparameters.
arguments["use_minibatch_stddev"] = True
arguments["minibatch_stddev_group_size"] = 4
arguments["minibatch_stddev_averaging"] = True
arguments["discriminator_leaky_relu_alpha"] = 0.2
arguments["discriminator_l1_regularization_scale"] = 0.
arguments["discriminator_l2_regularization_scale"] = 0.
arguments["discriminator_optimizer"] = "Adam"
arguments["discriminator_learning_rate"] = 0.001
arguments["discriminator_adam_beta1"] = 0.
arguments["discriminator_adam_beta2"] = 0.99
arguments["discriminator_adam_epsilon"] = 1e-8
# Set to None to disable discriminator gradient clipping.
arguments["discriminator_clip_gradients"] = 2.0
arguments["discriminator_gradient_penalty_coefficient"] = 10.0
arguments["epsilon_drift"] = 0.001
arguments["discriminator_train_steps"] = 1


conv_num_filters = [[512, 512], [512, 512], [512, 512], [512, 512]]
conv_kernel_sizes = [[4, 3], [3, 3], [3, 3], [3, 3]]
conv_strides = [[1, 1], [1, 1], [1, 1], [1, 1]]


In [4]:
arguments

{'train_file_pattern': 'gs://machine-learning-1234-bucket/gan/data/cifar10_car/train*.tfrecord',
 'eval_file_pattern': 'gs://machine-learning-1234-bucket/gan/data/cifar10_car/test*.tfrecord',
 'output_dir': 'gs://machine-learning-1234-bucket/gan/pgan/trained_model_local_cifar10_car',
 'dataset': 'cifar10',
 'train_batch_size': 32,
 'train_steps': 59500,
 'use_tpu': False,
 'use_estimator_train_and_evaluate': False,
 'growth_idx': 0,
 'previous_train_steps': 0,
 'save_optimizer_metrics_to_checkpoint': True,
 'save_summary_steps': 100,
 'save_checkpoints_steps': 10000,
 'keep_checkpoint_max': 100,
 'input_fn_autotune': True,
 'eval_batch_size': 1,
 'eval_steps': 1,
 'start_delay_secs': 6000000,
 'throttle_secs': 6000000,
 'eval_on_tpu': True,
 'exports_to_keep': 20,
 'export_to_tpu': False,
 'export_to_cpu': True,
 'predict_all_resolutions': True,
 'height': 32,
 'width': 32,
 'depth': 3,
 'num_steps_until_growth': 8500,
 'use_equalized_learning_rate': True,
 'conv_num_filters': [[512, 5

## print_object.py

In [5]:
def print_obj(function_name, object_name, object_value):
    """Writes a "function: name = value" debug line to stdout.

    Args:
        function_name: str, name of the function doing the logging.
        object_name: str, label for the object being logged.
        object_value: object, value to show after the label.
    """
    message = "{fn}: {name} = {value}".format(
        fn=function_name, name=object_name, value=object_value
    )
    print(message)


## image_utils.py

In [6]:
def preprocess_image(image, params):
    """Casts an image tensor to float32 and rescales its values.

    Args:
        image: tensor, input image with shape
            [cur_batch_size, height, width, depth].
        params: dict, user passed parameters (not read here).

    Returns:
        Preprocessed float32 image tensor with shape
            [cur_batch_size, height, width, depth].
    """
    func_name = "preprocess_image"

    # Scale by 2 / 255 and shift by -1; this maps inputs in [0, 255] onto
    # [-1.0, 1.0].
    float_image = tf.cast(x=image, dtype=tf.float32)
    image = float_image * (2. / 255) - 1.0
    print_obj(func_name, "image", image)

    return image


def resize_real_image(image, params, block_idx):
    """Resizes real images to the resolution of the given block.

    Args:
        image: tensor, original image.
        params: dict, user passed parameters; reads
            "generator_projection_dims" for the base height and width.
        block_idx: int, index of current block.

    Returns:
        Resized image tensor.
    """
    func_name = "resize_real_image"
    print_obj("\n" + func_name, "block_idx", block_idx)
    print_obj(func_name, "image", image)

    # Target size is the base projection size doubled once per block index.
    target_size = [
        params["generator_projection_dims"][0] * (2 ** block_idx),
        params["generator_projection_dims"][1] * (2 ** block_idx)
    ]
    op_name = "{}_resized_image_{}".format(func_name, block_idx)

    resized_image = tf.image.resize(
        images=image, size=target_size, method="nearest", name=op_name
    )
    print_obj(func_name, "resized_image", resized_image)

    return resized_image


def resize_real_images(image, params):
    """Resizes real images to match the GAN's current size.

    Three cases are handled:
      * a single conv block: the network never grows, so resize for
        block 0;
      * params["growth_idx"] is set: the block index is computed in Python
        from that static growth index;
      * otherwise: the block index is derived in-graph from the global
        step and selected with `tf.switch_case`.

    Args:
        image: tensor, original image.
        params: dict, user passed parameters; reads "conv_num_filters",
            "growth_idx" and "num_steps_until_growth".

    Returns:
        Resized image tensor.
    """
    func_name = "resize_real_images"
    print_obj("\n" + func_name, "image", image)

    # Index of the last block the network can grow to.
    max_block_idx = len(params["conv_num_filters"]) - 1

    # Resize real image for each block.
    if len(params["conv_num_filters"]) == 1:
        print(
            "\n{}: NEVER GOING TO GROW, SKIP SWITCH CASE!".format(func_name)
        )
        # If we never are going to grow, no sense using the switch case.
        # 4x4
        resized_image = resize_real_image(
            image=image, params=params, block_idx=0
        )
    elif params["growth_idx"] is not None:
        # Map the static growth index to a block index: 0 -> 0, 1-2 -> 1,
        # 3-4 -> 2, and so on, capped at the last block.
        block_idx = min((params["growth_idx"] - 1) // 2 + 1, max_block_idx)
        resized_image = resize_real_image(
            image=image, params=params, block_idx=block_idx
        )
    else:
        # Find growth index based on global step and growth frequency.
        growth_index = tf.add(
            x=tf.floordiv(
                x=tf.minimum(
                    x=tf.cast(
                        x=tf.floordiv(
                            x=tf.train.get_or_create_global_step() - 1,
                            y=params["num_steps_until_growth"],
                            name="{}_global_step_floordiv".format(
                                func_name
                            )
                        ),
                        dtype=tf.int32
                    ),
                    y=(len(params["conv_num_filters"]) - 1) * 2
                ) - 1,
                y=2
            ),
            y=1,
            name="{}_growth_index".format(func_name)
        )

        def make_branch_fn(block_idx):
            """Returns a zero-arg callable resizing to `block_idx`."""
            return lambda: resize_real_image(
                image=image, params=params, block_idx=block_idx
            )

        # One branch per supported resolution: 4x4, 8x8, ..., 1024x1024.
        # Branches past the last configured block are clamped to it.
        num_branches = 9

        # Switch to case based on number of steps for resized image.
        resized_image = tf.switch_case(
            branch_index=growth_index,
            branch_fns=[
                make_branch_fn(min(idx, max_block_idx))
                for idx in range(num_branches)
            ],
            name="{}_switch_case_resized_image".format(func_name)
        )
    print_obj(func_name, "resized_image", resized_image)

    return resized_image


## input.py

In [7]:
def decode_example(protos, params):
    """Decodes TFRecord file into tensors.

    Given protobufs, decode into image and label tensors.

    Args:
        protos: protobufs from TFRecord file.
        params: dict, user passed parameters; reads "dataset", "height",
            "width" and "depth".

    Returns:
        Tuple of a features dictionary holding the preprocessed image
            under key "image", and an int32 scalar label tensor.

    Raises:
        ValueError: if params["dataset"] is not "cifar10" or "celeba_hq".
    """
    func_name = "decode_example"

    # Fail early with a clear message rather than a NameError on
    # `features` further down.
    dataset = params["dataset"]
    if dataset not in ("cifar10", "celeba_hq"):
        raise ValueError(
            "{}: unsupported dataset {!r}, expected 'cifar10' or "
            "'celeba_hq'".format(func_name, dataset)
        )

    # Create feature schema map for protos.
    if dataset == "cifar10":
        features = {
            "image_raw": tf.FixedLenFeature(shape=[], dtype=tf.string),
            "label": tf.FixedLenFeature(shape=[], dtype=tf.int64)
        }
    else:
        features = {
            "image_raw": tf.FixedLenFeature(shape=[], dtype=tf.string)
        }

    # Parse features from tf.Example.
    parsed_features = tf.parse_single_example(
        serialized=protos, features=features
    )
    print_obj("\n" + func_name, "features", features)

    if dataset == "cifar10":
        # Convert from a scalar string tensor (whose single string has
        # length height * width * depth) to a uint8 tensor with shape
        # [height * width * depth].
        image = tf.decode_raw(
            input_bytes=parsed_features["image_raw"], out_type=tf.uint8
        )
    else:
        # celeba_hq images are stored JPEG-encoded.
        image = tf.image.decode_jpeg(
            contents=parsed_features["image_raw"], channels=3
        )
    print_obj(func_name, "image", image)

    # Reshape flattened image back into normal dimensions.
    image = tf.reshape(
        tensor=image,
        shape=[params["height"], params["width"], params["depth"]]
    )
    print_obj(func_name, "image", image)

    # Preprocess image.
    image = preprocess_image(image=image, params=params)
    print_obj(func_name, "image", image)

    if dataset == "cifar10":
        # Convert label from an int64 scalar tensor to an int32 scalar.
        label = tf.cast(x=parsed_features["label"], dtype=tf.int32)
    else:
        # celeba_hq records carry no label feature; use a zero placeholder.
        label = tf.zeros(shape=[], dtype=tf.int32)
    print_obj(func_name, "label", label)

    return {"image": image}, label


def set_static_shape(features, labels, batch_size, params):
    """Pins the batch dimension of batched input tensors to a static size.

    The tensors are updated in place through `set_shape`, after merging
    their current shape with one whose leading dimension is `batch_size`.

    Args:
        features: dict, keys are feature names and values are feature
            tensors; only the "image" entry is touched.
        labels: tensor, label data.
        batch_size: int, number of examples per batch.
        params: dict, user passed parameters (not read here).

    Returns:
        Features tensor dictionary and labels tensor.
    """
    image = features["image"]

    # Image: fix the batch axis only, leave the other three axes open.
    image_shape = image.get_shape().merge_with(
        tf.TensorShape([batch_size, None, None, None])
    )
    image.set_shape(image_shape)

    # Labels: a vector with one entry per example.
    labels_shape = labels.get_shape().merge_with(
        tf.TensorShape([batch_size])
    )
    labels.set_shape(labels_shape)

    return features, labels


def read_dataset(filename, mode, batch_size, params):
    """Reads TF Record data using tf.data, doing necessary preprocessing.

    Given filename, mode, batch size, and other parameters, read TF Record
    dataset using Dataset API, apply necessary preprocessing, and return an
    input function to the Estimator API.

    Args:
        filename: str, file pattern to read into our tf.data dataset.
        mode: The estimator ModeKeys. Can be TRAIN or EVAL.
        batch_size: int, number of examples per batch. Used when the
            params dict handed to the input function carries no
            "batch_size" entry.
        params: dict, dictionary of user passed parameters.

    Returns:
        An input function.
    """
    def fetch_dataset(filename):
        """Fetches TFRecord Dataset from given filename.

        Args:
            filename: str, name of TFRecord file.
        Returns:
            Dataset containing TFRecord Examples.
        """
        buffer_size = 8 * 1024 * 1024  # 8 MiB per file
        dataset = tf.data.TFRecordDataset(
            filenames=filename, buffer_size=buffer_size
        )

        return dataset

    def _input_fn(params):
        """Wrapper input function used by Estimator API to get data tensors.

        Args:
            params: dict, parameters handed to the input function; may
                contain a "batch_size" entry (e.g. the per core batch size
                created by a TPU job). Shadows the enclosing `params`.
        Returns:
            Dictionary of batched feature tensors and batched label
                tensor, taken from a one-shot iterator over the dataset.
        """
        # Prefer the batch size supplied through params (per core batch
        # size on TPU). Fall back to read_dataset's batch_size argument so
        # a missing key does not raise KeyError.
        cur_batch_size = params.get("batch_size", batch_size)

        # Determine if we are in train or eval mode.
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        # Create dataset to contain list of files matching pattern.
        dataset = tf.data.Dataset.list_files(
            file_pattern=filename, shuffle=is_training
        )

        # Repeat dataset files indefinitely if in training.
        if is_training:
            dataset = dataset.repeat()

        # Parallel interleaves multiple files at once with map function.
        dataset = dataset.apply(
            tf.contrib.data.parallel_interleave(
                map_func=fetch_dataset, cycle_length=64, sloppy=True
            )
        )

        # Shuffle the Dataset TFRecord Examples if in training.
        if is_training:
            dataset = dataset.shuffle(buffer_size=1024)

        # Decode TFRecord Examples into a features dictionary of tensors
        # and a label tensor, then batch.
        dataset = dataset.apply(
            tf.contrib.data.map_and_batch(
                map_func=lambda x: decode_example(
                    protos=x,
                    params=params
                ),
                batch_size=cur_batch_size,
                num_parallel_batches=8,
                drop_remainder=True,
            )
        )

        # Assign static shape, namely make the batch size axis static.
        dataset = dataset.map(
            map_func=lambda x, y: set_static_shape(
                features=x,
                labels=y,
                batch_size=cur_batch_size,
                params=params
            )
        )

        # Prefetch data to improve latency.
        if params["input_fn_autotune"]:
            dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
        else:
            dataset = dataset.prefetch(buffer_size=1)

        # Create a iterator, then get batch of features from example queue.
        batched_dataset = dataset.make_one_shot_iterator().get_next()

        return batched_dataset
    return _input_fn


## equalized_learning_rate_layers.py

In [8]:
class WeightScaledDense(tf.layers.Dense):
    """Subclassing `Dense` layer to allow equalized learning rate scaling.

    Fields:
        equalized_learning_rate: bool, if want to scale layer weights to
            equalize learning rate each forward pass.
    """
    def __init__(
            self,
            units,
            activation=None,
            use_bias=True,
            kernel_initializer=None,
            bias_initializer=tf.zeros_initializer(),
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            trainable=True,
            equalized_learning_rate=False,
            name=None,
            **kwargs):
        """Initializes `WeightScaledDense` layer.

        Args:
            units: Integer or Long, dimensionality of the output space.
            activation: Activation function (callable). Set it to None to
                maintain a linear activation.
            use_bias: Boolean, whether the layer uses a bias.
            kernel_initializer: Initializer function for the weight matrix.
                If `None` (default), weights are initialized using the
                default initializer used by `tf.compat.v1.get_variable`.
            bias_initializer: Initializer function for the bias.
            kernel_regularizer: Regularizer function for the weight matrix.
            bias_regularizer: Regularizer function for the bias.
            activity_regularizer: Regularizer function for the output.
            kernel_constraint: An optional projection function to be
                applied to the kernel after being updated by an
                `Optimizer` (e.g. used to implement norm constraints or
                value constraints for layer weights). The function must
                take as input the unprojected variable and must return the
                projected variable (which must have the same shape).
                Constraints are not safe to use when doing asynchronous
                distributed training.
            bias_constraint: An optional projection function to be applied
                to the bias after being updated by an `Optimizer`.
            trainable: Boolean, if `True` also add variables to the graph
                collection `GraphKeys.TRAINABLE_VARIABLES` (see
                `tf.Variable`).
            equalized_learning_rate: bool, if want to scale layer weights
                to equalize learning rate each forward pass.
            name: String, the name of the layer. Layers with the same name
                will share weights, but to avoid mistakes we require
                reuse=True in such cases.
            **kwargs: Additional keyword arguments forwarded unchanged to
                `tf.layers.Dense`.
        """
        super().__init__(
            units=units,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            trainable=trainable,
            name=name,
            **kwargs
        )

        # Whether we will scale weights using He init every forward pass.
        self.equalized_learning_rate = equalized_learning_rate

    def call(self, inputs):
        """Calls layer and returns outputs.

        Args:
            inputs: tensor, input tensor of shape [batch_size, features].

        Returns:
            Output tensor: inputs multiplied by the (optionally scaled)
                kernel, plus bias and activation when configured.
        """
        if self.equalized_learning_rate:
            # Scale kernel weights by the He init constant sqrt(2 / fan_in),
            # where fan_in is the kernel's first dimension.
            kernel_shape = [x.value for x in self.kernel.shape]
            fan_in = kernel_shape[0]
            he_constant = tf.sqrt(x=2. / float(fan_in))
            kernel = self.kernel * he_constant
        else:
            kernel = self.kernel

        rank = len(inputs.shape)
        if rank > 2:
            # Broadcasting is required for the inputs.
            outputs = tf.tensordot(
                a=inputs, b=kernel, axes=[[rank - 1], [0]]
            )
            # Reshape the output back to the original ndim of the input.
            # Use the public tf.executing_eagerly(); no `context` module is
            # imported in this file.
            if not tf.executing_eagerly():
                shape = inputs.shape.as_list()
                output_shape = shape[:-1] + [self.units]
                outputs.set_shape(shape=output_shape)
        else:
            inputs = tf.cast(x=inputs, dtype=self._compute_dtype)
            if isinstance(inputs, tf.SparseTensor):
                outputs = tf.sparse_tensor_dense_matmul(sp_a=inputs, b=kernel)
            else:
                outputs = tf.matmul(a=inputs, b=kernel)
        if self.use_bias:
            outputs = tf.nn.bias_add(value=outputs, bias=self.bias)
        if self.activation is not None:
            return self.activation(outputs)  # pylint: disable=not-callable
        return outputs


class WeightScaledConv2D(tf.layers.Conv2D):
    """Subclassing `Conv2D` layer to allow equalized learning rate scaling.

    Fields:
        equalized_learning_rate: bool, if want to scale layer weights to
            equalize learning rate each forward pass.
    """
    def __init__(
            self,
            filters,
            kernel_size,
            strides=(1, 1),
            padding="valid",
            data_format="channels_last",
            dilation_rate=(1, 1),
            activation=None,
            use_bias=True,
            kernel_initializer=None,
            bias_initializer=tf.zeros_initializer(),
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            trainable=True,
            equalized_learning_rate=False,
            name=None,
            **kwargs):
        """Initializes `WeightScaledConv2D` layer.

        Every argument except `equalized_learning_rate` is forwarded
        unchanged to the parent `Conv2D` initializer.

        Args:
            filters: Integer, the dimensionality of the output space (i.e. the number
              of filters in the convolution).
            kernel_size: An integer or tuple/list of 2 integers, specifying the
              height and width of the 2D convolution window.
              Can be a single integer to specify the same value for
              all spatial dimensions.
            strides: An integer or tuple/list of 2 integers,
              specifying the strides of the convolution along the height and width.
              Can be a single integer to specify the same value for
              all spatial dimensions.
              Specifying any stride value != 1 is incompatible with specifying
              any `dilation_rate` value != 1.
            padding: One of `"valid"` or `"same"` (case-insensitive).
            data_format: A string, one of `channels_last` (default) or `channels_first`.
              The ordering of the dimensions in the inputs.
              `channels_last` corresponds to inputs with shape
              `(batch, height, width, channels)` while `channels_first` corresponds to
              inputs with shape `(batch, channels, height, width)`.
            dilation_rate: An integer or tuple/list of 2 integers, specifying
              the dilation rate to use for dilated convolution.
              Can be a single integer to specify the same value for
              all spatial dimensions.
              Currently, specifying any `dilation_rate` value != 1 is
              incompatible with specifying any stride value != 1.
            activation: Activation function. Set it to None to maintain a
              linear activation.
            use_bias: Boolean, whether the layer uses a bias.
            kernel_initializer: An initializer for the convolution kernel.
            bias_initializer: An initializer for the bias vector. If None, the default
              initializer will be used.
            kernel_regularizer: Optional regularizer for the convolution kernel.
            bias_regularizer: Optional regularizer for the bias vector.
            activity_regularizer: Optional regularizer function for the output.
            kernel_constraint: Optional projection function to be applied to the
                kernel after being updated by an `Optimizer` (e.g. used to implement
                norm constraints or value constraints for layer weights). The function
                must take as input the unprojected variable and must return the
                projected variable (which must have the same shape). Constraints are
                not safe to use when doing asynchronous distributed training.
            bias_constraint: Optional projection function to be applied to the
                bias after being updated by an `Optimizer`.
            trainable: Boolean, if `True` also add variables to the graph collection
              `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
            equalized_learning_rate: bool, if want to scale layer weights to
                equalize learning rate each forward pass.
            name: A string, the name of the layer.
            **kwargs: Additional keyword arguments passed through to the
                parent layer initializer.
        """
        super().__init__(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            trainable=trainable,
            name=name,
            **kwargs
        )

        # Whether we will scale weights using He init every forward pass.
        self.equalized_learning_rate = equalized_learning_rate

    def call(self, inputs):
        """Applies the convolution, optionally rescaling the kernel first.

        When `equalized_learning_rate` is set, the kernel is multiplied by
        sqrt(2 / fan_in) on every call before it is used; the stored kernel
        variable itself is left untouched.

        Args:
            inputs: tensor, input tensor of shape
                [batch_size, height, width, channels].

        Returns:
            Output tensor, with the bias added if `use_bias` is set and the
                activation applied if one was given.
        """
        if self.equalized_learning_rate:
            # Scale kernel weights by He init constant. Fan-in is the product
            # of the first three kernel dimensions; `as_list()` yields plain
            # ints, so no `Dimension.value` access is needed.
            kernel_shape = self.kernel.shape.as_list()
            fan_in = kernel_shape[0] * kernel_shape[1] * kernel_shape[2]
            he_constant = tf.sqrt(x=2. / float(fan_in))
            kernel = self.kernel * he_constant
        else:
            kernel = self.kernel

        outputs = self._convolution_op(inputs, kernel)

        if self.use_bias:
            # A 2D convolution never takes the rank-1 special case, so the
            # bias layout only depends on where the channel axis sits.
            if self.data_format == "channels_first":
                bias_data_format = "NCHW"
            else:
                bias_data_format = "NHWC"
            outputs = tf.nn.bias_add(
                value=outputs, bias=self.bias, data_format=bias_data_format
            )

        if self.activation is not None:
            return self.activation(outputs)
        return outputs


## vector_to_image.py

In [9]:
class VectorToImage(object):
    """Convolutional network takes latent vector input and outputs image.

    Fields:
        kind: str, kind of `VectorToImage` instance, used to build the
            kind-prefixed keys that are looked up in `params`.
        kernel_regularizer: `l1_l2_regularizer` object, regularizer for kernel
            variables.
        bias_regularizer: `l1_l2_regularizer` object, regularizer for bias
            variables.
        projection_layer: `WeightScaledDense` layer for projection of noise
            to image.
        conv_layer_blocks: list, lists of block layers for each block.
        to_rgb_conv_layers: list, toRGB 1x1 conv layers.
        build_vec_to_img_tensors: list, tensors used to build layer
            internals.
    """
    def __init__(
            self, kernel_regularizer, bias_regularizer, params, kind):
        """Instantiates and builds vec_to_img network.

        Args:
            kernel_regularizer: `l1_l2_regularizer` object, regularizar for
                kernel variables.
            bias_regularizer: `l1_l2_regularizer` object, regularizar for bias
                variables.
            params: dict, user passed parameters.
            kind: str, kind of `VectorToImage` instance.
        """
        # Set kind of vector to image network.
        self.kind = kind

        # Regularizer for kernel weights.
        self.kernel_regularizer = kernel_regularizer

        # Regularizer for bias weights.
        self.bias_regularizer = bias_regularizer

        # Instantiate vector to image layers.
        (self.projection_layer,
         self.conv_layer_blocks,
         self.to_rgb_conv_layers) = self.instantiate_vec_to_img_layers(params)

        # Build vector to image layer internals.
        self.build_vec_to_img_tensors = self.build_vec_to_img_layers(
            params
        )

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def instantiate_vec_to_img_projection_layer(self, params):
        """Creates the dense layer that projects latent vectors to an image.

        The layer gets one unit per element of the projected image, i.e.
        projection_height * projection_width * projection_depth units.

        Args:
            params: dict, user passed parameters.

        Returns:
            Latent vector projection `WeightScaledDense` layer.
        """
        func_name = "instantiate_{}_projection_layer".format(self.kind)

        # Projected image dimensions, read as [height, width, depth].
        projection_dims = params["{}_projection_dims".format(self.kind)]
        num_units = (
            projection_dims[0] * projection_dims[1] * projection_dims[2]
        )

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Unit normal kernel when the weights get rescaled on every
            # forward pass, otherwise the "he_normal" initializer.
            if params["use_equalized_learning_rate"]:
                kernel_initializer = tf.random_normal_initializer(
                    mean=0., stddev=1.0
                )
            else:
                kernel_initializer = "he_normal"

            # shape = (cur_batch_size, num_units)
            projection_layer = WeightScaledDense(
                units=num_units,
                activation=None,
                kernel_initializer=kernel_initializer,
                kernel_regularizer=self.kernel_regularizer,
                bias_regularizer=self.bias_regularizer,
                equalized_learning_rate=params["use_equalized_learning_rate"],
                name="{}_projection_layer".format(self.name)
            )

            print_obj("\n" + func_name, "projection_layer", projection_layer)

        return projection_layer

    def instantiate_vec_to_img_base_conv_layer_block(self, params):
        """Creates the conv layers that make up the base block.

        Each layer is described by a property list whose entries are used as
        [kernel_height, kernel_width, in_channels, filters, stride_height,
        stride_width].

        Args:
            params: dict, user passed parameters.

        Returns:
            List of base block conv layers.
        """
        func_name = "instantiate_{}_base_conv_layer_block".format(self.kind)

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # The base block's layer properties are the 0th entry.
            conv_block = params["{}_base_conv_blocks".format(self.kind)][0]

            base_conv_layers = []
            for layer_idx, props in enumerate(conv_block):
                # Unit normal kernel when using equalized learning rate.
                if params["use_equalized_learning_rate"]:
                    kernel_initializer = tf.random_normal_initializer(
                        mean=0., stddev=1.0
                    )
                else:
                    kernel_initializer = "he_normal"

                layer_name = "{}_base_layers_conv2d_{}_{}x{}_{}_{}".format(
                    self.name, layer_idx, props[0], props[1], props[2], props[3]
                )

                base_conv_layers.append(
                    WeightScaledConv2D(
                        filters=props[3],
                        kernel_size=props[0:2],
                        strides=props[4:6],
                        padding="same",
                        activation=None,
                        kernel_initializer=kernel_initializer,
                        kernel_regularizer=self.kernel_regularizer,
                        bias_regularizer=self.bias_regularizer,
                        equalized_learning_rate=params[
                            "use_equalized_learning_rate"
                        ],
                        name=layer_name
                    )
                )
            print_obj("\n" + func_name, "base_conv_layers", base_conv_layers)

        return base_conv_layers

    def instantiate_vec_to_img_growth_layer_block(self, params, block_idx):
        """Creates the conv layers of a single growth block.

        Each layer is described by a property list whose entries are used as
        [kernel_height, kernel_width, in_channels, filters, stride_height,
        stride_width].

        Args:
            params: dict, user passed parameters.
            block_idx: int, the current growth block's index.

        Returns:
            List of growth block conv layers.
        """
        func_name = "instantiate_{}_growth_layer_block".format(self.kind)

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Layer properties of the requested growth block.
            growth_blocks = params["{}_growth_conv_blocks".format(self.kind)]
            conv_block = growth_blocks[block_idx]

            conv_layers = []
            for layer_idx, props in enumerate(conv_block):
                # Unit normal kernel when using equalized learning rate.
                if params["use_equalized_learning_rate"]:
                    kernel_initializer = tf.random_normal_initializer(
                        mean=0., stddev=1.0
                    )
                else:
                    kernel_initializer = "he_normal"

                layer_name = (
                    "{}_growth_layers_conv2d_{}_{}_{}x{}_{}_{}".format(
                        self.name,
                        block_idx,
                        layer_idx,
                        props[0],
                        props[1],
                        props[2],
                        props[3]
                    )
                )

                conv_layers.append(
                    WeightScaledConv2D(
                        filters=props[3],
                        kernel_size=props[0:2],
                        strides=props[4:6],
                        padding="same",
                        activation=None,
                        kernel_initializer=kernel_initializer,
                        kernel_regularizer=self.kernel_regularizer,
                        bias_regularizer=self.bias_regularizer,
                        equalized_learning_rate=params[
                            "use_equalized_learning_rate"
                        ],
                        name=layer_name
                    )
                )
            print_obj("\n" + func_name, "conv_layers", conv_layers)

        return conv_layers

    def instantiate_vec_to_img_to_rgb_layers(self, params):
        """Instantiates vec_to_img toRGB layers of 1x1 convs.

        Args:
            params: dict, user passed parameters.

        Returns:
            List of toRGB 1x1 conv layers.
        """
        func_name = "instantiate_{}_to_rgb_layers".format(self.kind)
        
        activation_dict = {
            "sigmoid": tf.nn.sigmoid, "relu": tf.nn.relu, "tanh": tf.nn.tanh
        }

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Get toRGB layer properties.
            to_rgb = [
                params["{}_to_rgb_layers".format(self.kind)][i][0][:]
                for i in range(
                    len(params["{}_to_rgb_layers".format(self.kind)])
                )
            ]

            # Create list to hold toRGB 1x1 convs.
            to_rgb_conv_layers = [
                WeightScaledConv2D(
                    filters=to_rgb[i][3],
                    kernel_size=to_rgb[i][0:2],
                    strides=to_rgb[i][4:6],
                    padding="same",
                    activation=activation_dict.get(
                        params["generator_to_rgb_activation"].lower(), None
                    ),
                    kernel_initializer=(
                        tf.random_normal_initializer(mean=0., stddev=1.0)
                        if params["use_equalized_learning_rate"]
                        else "he_normal"
                    ),
                    kernel_regularizer=self.kernel_regularizer,
                    bias_regularizer=self.bias_regularizer,
                    equalized_learning_rate=params["use_equalized_learning_rate"],
                    name="{}_to_rgb_layers_conv2d_{}_{}x{}_{}_{}".format(
                        self.name,
                        i,
                        to_rgb[i][0],
                        to_rgb[i][1],
                        to_rgb[i][2],
                        to_rgb[i][3]
                    )
                )
                for i in range(len(to_rgb))
            ]
            print_obj(
                "\n" + func_name, "to_rgb_conv_layers", to_rgb_conv_layers
            )

        return to_rgb_conv_layers

    def instantiate_vec_to_img_layers(self, params):
        """Creates every layer of the vec_to_img network.

        Layers are created in order: the projection layer, the base conv
        block, one conv block per growth stage, then the toRGB layers.

        Args:
            params: dict, user passed parameters.

        Returns:
            projection_layer: `WeightScaledDense` layer for projection of
                noise to image.
            conv_layer_blocks: list, lists of block layers for each block.
            to_rgb_conv_layers: list, toRGB 1x1 conv layers.
        """
        func_name = "instantiate_{}_layers".format(self.kind)

        # Noise-to-image projection `WeightScaledDense` layer.
        projection_layer = self.instantiate_vec_to_img_projection_layer(
            params=params
        )
        print_obj("\n" + func_name, "projection_layer", projection_layer)

        # The base `WeightScaledConv2D` block always comes first ...
        conv_layer_blocks = [
            self.instantiate_vec_to_img_base_conv_layer_block(params=params)
        ]

        # ... followed by one `WeightScaledConv2D` block per growth stage.
        num_growth_blocks = len(
            params["{}_growth_conv_blocks".format(self.kind)]
        )
        for block_idx in range(num_growth_blocks):
            conv_layer_blocks.append(
                self.instantiate_vec_to_img_growth_layer_block(
                    params=params, block_idx=block_idx
                )
            )
        print_obj(func_name, "conv_layer_blocks", conv_layer_blocks)

        # toRGB 1x1 `WeightScaledConv2D` layers.
        to_rgb_conv_layers = self.instantiate_vec_to_img_to_rgb_layers(
            params=params
        )
        print_obj(func_name, "to_rgb_conv_layers", to_rgb_conv_layers)

        return projection_layer, conv_layer_blocks, to_rgb_conv_layers

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def build_vec_to_img_projection_layer(self, params):
        """Builds the projection layer's internals by calling it once.

        The layer is called on an all-zero batch of a single latent vector.

        Args:
            params: dict, user passed parameters.

        Returns:
            Latent vector projection tensor.
        """
        func_name = "build_{}_projection_layer".format(self.kind)

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Placeholder input of shape (1, latent_size).
            dummy_latent = tf.zeros(
                shape=[1, params["latent_size"]], dtype=tf.float32
            )

            # shape = (
            #     1,
            #     projection_height * projection_width * projection_depth
            # )
            projection_tensor = self.projection_layer(inputs=dummy_latent)
            print_obj(
                "\n" + func_name, "projection_tensor", projection_tensor
            )

        return projection_tensor

    def build_vec_to_img_base_conv_layer_block(self, params):
        """Builds the base conv block's internals by calling each layer once.

        Every layer is called on an all-zero input whose shape is [1]
        followed by the first three entries of that layer's property list.

        Args:
            params: dict, user passed parameters.

        Returns:
            List of base conv tensors.
        """
        func_name = "build_{}_base_conv_layer_block".format(self.kind)

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # The base block's layer properties are the 0th entry.
            conv_block = params["{}_base_conv_blocks".format(self.kind)][0]

            # The base conv block is always the 0th one.
            base_layers = self.conv_layer_blocks[0]

            base_conv_tensors = []
            for layer_idx, props in enumerate(conv_block):
                dummy_input = tf.zeros(
                    shape=[1] + props[0:3], dtype=tf.float32
                )
                base_conv_tensors.append(
                    base_layers[layer_idx](inputs=dummy_input)
                )
            print_obj(
                "\n" + func_name, "base_conv_tensors", base_conv_tensors
            )

        return base_conv_tensors

    def build_vec_to_img_growth_layer_block(
            self, params, growth_block_idx):
        """Builds one growth block's internals by calling each layer once.

        Every layer is called on an all-zero input whose shape is [1]
        followed by the first three entries of that layer's property list.

        Args:
            params: dict, user passed parameters.
            growth_block_idx: int, the current growth block's index.

        Returns:
            List of growth block tensors.
        """
        func_name = "build_{}_growth_layer_block".format(self.kind)

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Layer properties of the requested growth block.
            growth_blocks = params["{}_growth_conv_blocks".format(self.kind)]
            conv_block = growth_blocks[growth_block_idx]

            # Growth blocks are stored after the base block, hence the + 1.
            growth_layers = self.conv_layer_blocks[1 + growth_block_idx]

            conv_tensors = []
            for layer_idx, props in enumerate(conv_block):
                dummy_input = tf.zeros(
                    shape=[1] + props[0:3], dtype=tf.float32
                )
                conv_tensors.append(
                    growth_layers[layer_idx](inputs=dummy_input)
                )
            print_obj("\n" + func_name, "conv_tensors", conv_tensors)

        return conv_tensors

    def build_vec_to_img_to_rgb_layers(self, params):
        """Builds the toRGB 1x1 conv layers' internals by calling each once.

        Every layer is called on an all-zero input whose shape is [1]
        followed by the first three entries of that layer's property list.

        Args:
            params: dict, user passed parameters.

        Returns:
            List of toRGB 1x1 conv tensors.
        """
        func_name = "build_{}_to_rgb_layers".format(self.kind)

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # toRGB layer properties: a copy of the first property list of
            # every toRGB entry.
            to_rgb = [
                block[0][:]
                for block in params["{}_to_rgb_layers".format(self.kind)]
            ]

            to_rgb_conv_tensors = []
            for layer_idx, props in enumerate(to_rgb):
                dummy_input = tf.zeros(
                    shape=[1] + props[0:3], dtype=tf.float32
                )
                to_rgb_conv_tensors.append(
                    self.to_rgb_conv_layers[layer_idx](inputs=dummy_input)
                )
            print_obj(
                "\n" + func_name, "to_rgb_conv_tensors", to_rgb_conv_tensors
            )

        return to_rgb_conv_tensors

    def build_vec_to_img_layers(self, params):
        """Builds the internals of every vec_to_img layer.

        The projection layer is built first, the conv blocks are built under
        a control dependency on the projection tensor, and the toRGB layers
        are built under a control dependency on all conv block tensors.

        Args:
            params: dict, user passed parameters.

        Returns:
            List of toRGB tensors.
        """
        func_name = "build_{}_layers".format(self.kind)

        # Projection layer comes first.
        projection_tensor = self.build_vec_to_img_projection_layer(
            params=params
        )
        print_obj("\n" + func_name, "projection_tensor", projection_tensor)

        with tf.control_dependencies(control_inputs=[projection_tensor]):
            # One list of tensors per block: base block, then growth blocks.
            conv_block_tensors = [
                self.build_vec_to_img_base_conv_layer_block(params=params)
            ]

            num_growth_blocks = len(
                params["{}_growth_conv_blocks".format(self.kind)]
            )
            for growth_block_idx in range(num_growth_blocks):
                conv_block_tensors.append(
                    self.build_vec_to_img_growth_layer_block(
                        params=params, growth_block_idx=growth_block_idx
                    )
                )
            print_obj(func_name, "conv_block_tensors", conv_block_tensors)

            # Flatten the per-block lists into a single list of tensors.
            flattened_tensors = []
            for block_tensors in conv_block_tensors:
                flattened_tensors.extend(block_tensors)
            conv_block_tensors = flattened_tensors
            print_obj(func_name, "conv_block_tensors", conv_block_tensors)

            with tf.control_dependencies(
                    control_inputs=conv_block_tensors):
                # toRGB 1x1 conv layers come last.
                to_rgb_conv_tensors = self.build_vec_to_img_to_rgb_layers(
                    params=params
                )
                print_obj(
                    func_name, "to_rgb_conv_tensors", to_rgb_conv_tensors
                )

        return to_rgb_conv_tensors

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def pixel_norm(self, X, epsilon=1e-8):
        """Scales each feature vector by the inverse root of its mean square.

        Computes X * rsqrt(mean(X ** 2, axis=-1) + epsilon), with the mean
        taken over the last axis and kept as a size-1 dimension.

        Args:
            X: tensor, image feature vectors.
            epsilon: float, small value to add to denominator for numerical
                stability.

        Returns:
            Pixel normalized feature vectors.
        """
        with tf.variable_scope("{}/pixel_norm".format(self.name)):
            # Mean of the squared features along the last axis.
            mean_square = tf.reduce_mean(
                input_tensor=tf.square(x=X), axis=-1, keepdims=True
            )
            inverse_norm = tf.rsqrt(x=tf.add(x=mean_square, y=epsilon))
            return X * inverse_norm

    def use_pixel_norm(self, X, params, epsilon=1e-8):
        """Decides based on user parameter whether to use pixel norm or not.

        Args:
            X: tensor, image feature vectors.
            params: dict, user passed parameters.
            epsilon: float, small value to add to denominator for numerical
                stability.

        Returns:
            Pixel normalized feature vectors if using pixel norm, else
                original feature vectors.
        """
        if params["use_pixel_norm"]:
            return self.pixel_norm(X=X, epsilon=epsilon)
        else:
            return X

    def use_vec_to_img_projection_layer(self, Z, params):
        """Projects latent vectors into an image-shaped tensor.

        The latent vectors are optionally pixel normalized, passed through
        the projection layer, reshaped to the projection dimensions, then
        passed through leaky ReLU and (optionally) pixel norm.

        Args:
            Z: tensor, latent vectors of shape [cur_batch_size, latent_size].
            params: dict, user passed parameters.

        Returns:
            Latent vector projection tensor.
        """
        func_name = "use_{}_projection_layer".format(self.kind)

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            latent = Z
            if params["normalize_latent"]:
                # shape = (cur_batch_size, latent_size)
                latent = self.pixel_norm(
                    X=latent, epsilon=params["pixel_norm_epsilon"]
                )

            # shape = (
            #     cur_batch_size,
            #     projection_height * projection_width * projection_depth
            # )
            projection_tensor = self.projection_layer(inputs=latent)
            print_obj(
                "\n" + func_name, "projection_tensor", projection_tensor
            )

        # All tensors below have shape = (
        #     cur_batch_size,
        #     projection_height,
        #     projection_width,
        #     projection_depth
        # )

        # Reshape the flat projection into an "image".
        image_shape = [-1] + params["{}_projection_dims".format(self.kind)]
        projection_tensor_reshaped = tf.reshape(
            tensor=projection_tensor,
            shape=image_shape,
            name="{}_projection_reshaped".format(self.name)
        )
        print_obj(
            func_name,
            "projection_tensor_reshaped",
            projection_tensor_reshaped
        )

        # Non-linearity.
        projection_tensor_leaky = tf.nn.leaky_relu(
            features=projection_tensor_reshaped,
            alpha=params["{}_leaky_relu_alpha".format(self.kind)],
            name="{}_projection_tensor_reshaped_leaky_relu".format(self.kind)
        )
        print_obj(
            func_name, "projection_tensor_leaky", projection_tensor_leaky
        )

        # Optional pixel norm, controlled by the user parameters.
        pixel_norm_output = self.use_pixel_norm(
            X=projection_tensor_leaky,
            params=params,
            epsilon=params["pixel_norm_epsilon"]
        )
        print_obj(func_name, "pixel_norm_output", pixel_norm_output)

        return pixel_norm_output

    def fused_conv2d_pixel_norm(self, input_image, conv2d_layer, params):
        """Runs a conv layer followed by leaky ReLU and optional pixel norm.

        Args:
            input_image: tensor, input image of rank 4.
            conv2d_layer: `WeightScaledConv2D` layer.
            params: dict, user passed parameters.

        Returns:
            New image tensor of rank 4.
        """
        func_name = "fused_conv2d_pixel_norm"

        # Convolution.
        convolved = conv2d_layer(inputs=input_image)
        print_obj("\n" + func_name, "conv_output", convolved)

        # Non-linearity, with the slope taken from the kind-specific entry.
        leaky_alpha = params["{}_leaky_relu_alpha".format(self.kind)]
        activated = tf.nn.leaky_relu(
            features=convolved,
            alpha=leaky_alpha,
            name="{}_fused_conv2d_pixel_norm_leaky_relu".format(self.kind)
        )
        print_obj(func_name, "conv_output_leaky", activated)

        # Optional pixel norm, controlled by the user parameters.
        normalized = self.use_pixel_norm(
            X=activated, params=params, epsilon=params["pixel_norm_epsilon"]
        )
        print_obj(func_name, "pixel_norm_output", normalized)

        return normalized

    def upsample_vec_to_img_image(self, image, orig_img_size, block_idx):
        """Resizes an image to `orig_img_size * 2 ** block_idx`.

        Resizing uses the "nearest" method.

        Args:
            image: tensor, image created by vec_to_img conv block.
            orig_img_size: list, the height and width dimensions of the
                original image before any growth.
            block_idx: int, index of the current vec_to_img growth block.

        Returns:
            Upsampled image tensor.
        """
        func_name = "upsample_{}_image".format(self.kind)

        # Target size: the original size scaled by this block's factor.
        target_scale = 2 ** block_idx
        target_size = tf.convert_to_tensor(
            value=orig_img_size,
            dtype=tf.int32,
            name="{}_upsample_{}_image_orig_img_size".format(
                self.name, self.kind
            )
        ) * target_scale

        # The op name records the source and target resolutions, where the
        # source is taken to be the previous block's scale.
        source_scale = 2 ** (block_idx - 1)
        op_name = "{}_growth_upsampled_image_{}_{}x{}_{}x{}".format(
            self.name,
            block_idx,
            orig_img_size[0] * source_scale,
            orig_img_size[1] * source_scale,
            orig_img_size[0] * target_scale,
            orig_img_size[1] * target_scale
        )

        # Upsample from s X s to 2s X 2s image.
        upsampled_image = tf.image.resize(
            images=image, size=target_size, method="nearest", name=op_name
        )
        print_obj("\n" + func_name, "upsampled_image", upsampled_image)

        return upsampled_image

    def create_base_vec_to_img_network(self, Z, params):
        """Creates base vec_to_img network.

        Latent vectors are projected into an image, passed through every
        layer of the first conv block, and converted by the first toRGB
        layer.

        Args:
            Z: tensor, latent vectors of shape [cur_batch_size, latent_size].
            params: dict, user passed parameters.

        Returns:
            Final network block conv tensor.
        """
        func_name = "create_base_{}_network".format(self.kind)

        print_obj("\n" + func_name, "Z", Z)
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Project latent noise vectors into image.
            projection = self.use_vec_to_img_projection_layer(
                Z=Z, params=params
            )
            print_obj(func_name, "projection", projection)

            # The base network only uses the first conv block and the first
            # toRGB layer.
            base_block_layers = self.conv_layer_blocks[0]
            base_to_rgb_layer = self.to_rgb_conv_layers[0]

            # Chain the block's layers, feeding each output to the next.
            block_conv = projection
            for layer_idx, conv_layer in enumerate(base_block_layers):
                block_conv = self.fused_conv2d_pixel_norm(
                    input_image=block_conv,
                    conv2d_layer=conv_layer,
                    params=params
                )
                print_obj(
                    func_name, "block_conv_{}".format(layer_idx), block_conv
                )

            # Convert convolution to RGB image.
            to_rgb_conv = base_to_rgb_layer(inputs=block_conv)
            print_obj(func_name, "to_rgb_conv", to_rgb_conv)

        return to_rgb_conv

    def create_growth_transition_vec_to_img_network(
            self, Z, orig_img_size, alpha_var, params, trans_idx):
        """Builds the vec_to_img network used while a new block fades in.

        The latent vectors are projected into an image and passed through the
        conv blocks up to and including index `trans_idx`. That result is
        upsampled and fed to two side chains: the growing chain (the conv
        block and toRGB layer at index `trans_idx + 1`) and the shrinking
        chain (only the toRGB layer at index `trans_idx`). The two outputs are
        blended as `growing * alpha_var + shrinking * (1 - alpha_var)`.

        Args:
            Z: tensor, latent vectors of shape [cur_batch_size, latent_size].
            orig_img_size: list, the height and width dimensions of the
                original image before any growth.
            alpha_var: variable, alpha for weighted sum of fade-in of layers.
            params: dict, user passed parameters.
            trans_idx: int, index of current growth transition.

        Returns:
            Weighted sum tensor of growing and shrinking network paths.
        """
        func_name = "create_growth_transition_{}_network".format(self.kind)

        print_obj("\nEntered {}".format(func_name), "trans_idx", trans_idx)

        print_obj(func_name, "Z", Z)
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Latent noise vectors become the network's starting image.
            network = self.use_vec_to_img_projection_layer(Z=Z, params=params)
            print_obj(func_name, "projection", network)

            # Blocks that are already fully grown in at this transition.
            permanent_blocks = self.conv_layer_blocks[0:trans_idx + 1]

            # The base block is applied directly, without any upsampling.
            base_conv_layers = permanent_blocks[0]
            for layer_idx, conv_layer in enumerate(base_conv_layers):
                network = self.fused_conv2d_pixel_norm(
                    input_image=network,
                    conv2d_layer=conv_layer,
                    params=params
                )
                print_obj(
                    func_name,
                    "base_block_conv_{}_{}".format(trans_idx, layer_idx),
                    network
                )

            # Every later permanent block first upsamples the previous
            # block's image and then applies its own conv layers.
            for block_idx, conv_layers in enumerate(
                    permanent_blocks[1:], start=1):
                network = self.upsample_vec_to_img_image(
                    image=network,
                    orig_img_size=orig_img_size,
                    block_idx=block_idx
                )
                print_obj(
                    func_name,
                    "upsample_vec_to_img_image_block_conv_{}_{}".format(
                        trans_idx, block_idx
                    ),
                    network
                )

                for layer_idx, conv_layer in enumerate(conv_layers):
                    network = self.fused_conv2d_pixel_norm(
                        input_image=network,
                        conv2d_layer=conv_layer,
                        params=params
                    )
                    print_obj(
                        func_name,
                        "block_conv_{}_{}_{}".format(
                            trans_idx, block_idx, layer_idx
                        ),
                        network
                    )

            # Both side chains start from the same upsampled image.
            shared_upsampled = self.upsample_vec_to_img_image(
                image=network,
                orig_img_size=orig_img_size,
                block_idx=len(permanent_blocks)
            )
            print_obj(
                func_name,
                "upsampled_block_conv_{}".format(trans_idx),
                shared_upsampled
            )

            # Growing side chain: the new block followed by its toRGB layer.
            growing_conv_layers = self.conv_layer_blocks[trans_idx + 1]
            growing_to_rgb_layer = self.to_rgb_conv_layers[trans_idx + 1]

            growing_network = shared_upsampled
            for layer_idx, conv_layer in enumerate(growing_conv_layers):
                growing_network = self.fused_conv2d_pixel_norm(
                    input_image=growing_network,
                    conv2d_layer=conv_layer,
                    params=params
                )
                print_obj(
                    func_name,
                    "growing_block_conv_{}_{}".format(trans_idx, layer_idx),
                    growing_network
                )

            growing_rgb = growing_to_rgb_layer(inputs=growing_network)
            print_obj(
                func_name,
                "growing_to_rgb_conv_{}".format(trans_idx),
                growing_rgb
            )

            # Shrinking side chain: only the previous block's toRGB layer,
            # applied straight to the upsampled image.
            shrinking_to_rgb_layer = self.to_rgb_conv_layers[trans_idx]
            shrinking_rgb = shrinking_to_rgb_layer(inputs=shared_upsampled)
            print_obj(
                func_name,
                "shrinking_to_rgb_conv_{}".format(trans_idx),
                shrinking_rgb
            )

            # Blend both chains; alpha_var weights the growing chain.
            growing_term = growing_rgb * alpha_var
            shrinking_term = shrinking_rgb * (1.0 - alpha_var)
            blended = tf.add(
                x=growing_term,
                y=shrinking_term,
                name="growth_transition_weighted_sum_{}".format(trans_idx)
            )
            print_obj(
                func_name,
                "weighted_sum_{}".format(trans_idx),
                blended
            )

        return blended

    def create_growth_stable_vec_to_img_network(
            self, Z, orig_img_size, params, trans_idx):
        """Builds the vec_to_img network for a fully faded-in growth stage.

        The latent vectors are projected into an image and passed through the
        conv blocks up to and including index `trans_idx + 1`, upsampling
        before every block after the base one. The toRGB layer at index
        `trans_idx + 1` produces the output; no weighted sum is involved.

        Args:
            Z: tensor, latent vectors of shape [cur_batch_size, latent_size].
            orig_img_size: list, the height and width dimensions of the
                original image before any growth.
            params: dict, user passed parameters.
            trans_idx: int, index of current growth transition.

        Returns:
            Output tensor of the toRGB conv layer at index `trans_idx + 1`.
        """
        func_name = "create_growth_stable_{}_network".format(self.kind)

        print_obj("\n" + func_name, "Z", Z)
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Latent noise vectors become the network's starting image.
            network = self.use_vec_to_img_projection_layer(Z=Z, params=params)
            print_obj(func_name, "projection", network)

            # All blocks up to and including the newest one are permanent.
            permanent_blocks = self.conv_layer_blocks[0:trans_idx + 2]

            # The base block is applied directly, without any upsampling.
            base_conv_layers = permanent_blocks[0]
            for layer_idx, conv_layer in enumerate(base_conv_layers):
                network = self.fused_conv2d_pixel_norm(
                    input_image=network,
                    conv2d_layer=conv_layer,
                    params=params
                )
                print_obj(
                    func_name,
                    "base_block_conv_{}".format(layer_idx),
                    network
                )

            # Every later block first upsamples the previous block's image
            # and then applies its own conv layers.
            for block_idx, conv_layers in enumerate(
                    permanent_blocks[1:], start=1):
                network = self.upsample_vec_to_img_image(
                    image=network,
                    orig_img_size=orig_img_size,
                    block_idx=block_idx
                )
                print_obj(
                    func_name,
                    "upsample_vec_to_img_image_block_conv_{}".format(
                        block_idx
                    ),
                    network
                )

                for layer_idx, conv_layer in enumerate(conv_layers):
                    network = self.fused_conv2d_pixel_norm(
                        input_image=network,
                        conv2d_layer=conv_layer,
                        params=params
                    )
                    print_obj(
                        func_name,
                        "block_conv_{}_{}".format(block_idx, layer_idx),
                        network
                    )

            # The newest block's toRGB layer produces the output image.
            newest_to_rgb_layer = self.to_rgb_conv_layers[trans_idx + 1]
            rgb_image = newest_to_rgb_layer(inputs=network)
            print_obj(func_name, "to_rgb_conv", rgb_image)

        return rgb_image

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def unknown_switch_case_vec_to_img_outputs(
            self, Z, orig_img_size, alpha_var, params, growth_index):
        """Selects the network for the current growth stage with a switch case.

        The branch list holds the base network followed by a
        (transition, stable) pair for each of eight growth stages, i.e. 17
        branches in total. A stage's transition index is clamped to
        `len(params["conv_num_filters"]) - 2`, so stages beyond the
        configured number of blocks reuse the last configured transition
        index.

        Args:
            Z: tensor, latent vectors of shape [cur_batch_size, latent_size].
            orig_img_size: list, the height and width dimensions of the
                original image before any growth.
            alpha_var: variable, alpha for weighted sum of fade-in of layers.
            params: dict, user passed parameters.
            growth_index: tensor, current growth stage.

        Returns:
            Generated image output tensor.
        """
        func_name = "unknown_switch_case_{}_outputs".format(self.kind)

        # Growth stages with a hard-wired branch pair: 8x8 through 1024x1024.
        num_growth_stages = 8

        def clamped_trans_idx(stage_idx):
            # Evaluated when the branch is called, as in a plain lambda.
            return min(stage_idx, len(params["conv_num_filters"]) - 2)

        def base_branch():
            # 4x4
            return self.create_base_vec_to_img_network(Z=Z, params=params)

        def make_transition_branch(stage_idx):
            # Factory so each branch keeps its own stage index.
            def transition_branch():
                return self.create_growth_transition_vec_to_img_network(
                    Z=Z,
                    orig_img_size=orig_img_size,
                    alpha_var=alpha_var,
                    params=params,
                    trans_idx=clamped_trans_idx(stage_idx)
                )
            return transition_branch

        def make_stable_branch(stage_idx):
            # Factory so each branch keeps its own stage index.
            def stable_branch():
                return self.create_growth_stable_vec_to_img_network(
                    Z=Z,
                    orig_img_size=orig_img_size,
                    params=params,
                    trans_idx=clamped_trans_idx(stage_idx)
                )
            return stable_branch

        # Branch order: base, then transition and stable for each stage.
        branch_fns = [base_branch]
        for stage_idx in range(num_growth_stages):
            branch_fns.append(make_transition_branch(stage_idx))
            branch_fns.append(make_stable_branch(stage_idx))

        # Switch to case based on number of steps for gen outputs.
        generated_outputs = tf.switch_case(
            branch_index=growth_index,
            branch_fns=branch_fns,
            name="{}_switch_case_generated_outputs".format(self.name)
        )
        print_obj(func_name, "generated_outputs", generated_outputs)

        return generated_outputs

    def known_switch_case_vec_to_img_outputs(
            self, Z, orig_img_size, alpha_var, params):
        """Picks the network for a growth index that is known up front.

        `params["growth_idx"]` equal to 0 selects the base network. Any other
        value maps to transition index `(growth_idx - 1) // 2`; odd growth
        indices build the transition (fade-in) network and even ones build
        the stable network.

        Args:
            Z: tensor, latent vectors of shape [batch_size, latent_size].
            orig_img_size: list, the height and width dimensions of the
                original image before any growth.
            alpha_var: variable, alpha for weighted sum of fade-in of layers.
            params: dict, user passed parameters.

        Returns:
            Generated image output tensor.
        """
        func_name = "known_switch_case_{}_outputs".format(self.kind)

        growth_idx = params["growth_idx"]

        if growth_idx == 0:
            # No growth yet, just base block.
            generated_outputs = self.create_base_vec_to_img_network(
                Z=Z, params=params
            )
        elif growth_idx % 2 == 1:
            # Odd index: grow network using weighted sum with smaller network.
            generated_outputs = self.create_growth_transition_vec_to_img_network(
                Z=Z,
                orig_img_size=orig_img_size,
                alpha_var=alpha_var,
                params=params,
                trans_idx=(growth_idx - 1) // 2
            )
        else:
            # Even index: stabilize bigger network without weighted sum.
            generated_outputs = self.create_growth_stable_vec_to_img_network(
                Z=Z,
                orig_img_size=orig_img_size,
                params=params,
                trans_idx=(growth_idx - 1) // 2
            )
        print_obj(func_name, "generated_outputs", generated_outputs)

        return generated_outputs

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def get_train_eval_vec_to_img_outputs(self, Z, alpha_var, params):
        """Builds the vec_to_img output image tensor for train/eval.

        Three cases are handled:
          * only one conv block is configured: the base network is used
            directly;
          * `params["growth_idx"]` is set: the network for that fixed growth
            index is built;
          * otherwise: the growth index is derived from the global step and
            `params["num_steps_until_growth"]`, capped at
            `(number of blocks - 1) * 2`, and a switch case picks the network.

        The result is wrapped in an identity op with control dependencies on
        `self.build_vec_to_img_tensors`.

        Args:
            Z: tensor, latent vectors of shape [cur_batch_size, latent_size].
            alpha_var: variable, alpha for weighted sum of fade-in of layers.
            params: dict, user passed parameters.

        Returns:
            Generated image output tensor of shape
                [cur_batch_size, image_size, image_size, depth].
        """
        func_name = "get_train_eval_{}_outputs".format(self.kind)

        print_obj("\n" + func_name, "Z", Z)

        num_blocks = len(params["conv_num_filters"])

        if num_blocks == 1:
            # If never going to grow, no sense using the switch case.
            print(
                "\n{}: NOT GOING TO GROW, SKIP SWITCH CASE!".format(func_name)
            )
            # 4x4
            generated_outputs = self.create_base_vec_to_img_network(
                Z=Z, params=params
            )
        elif params["growth_idx"] is not None:
            # Growth index is fixed, so the network can be chosen in Python.
            generated_outputs = self.known_switch_case_vec_to_img_outputs(
                Z=Z,
                orig_img_size=params["{}_projection_dims".format(self.kind)][0:2],
                alpha_var=alpha_var,
                params=params,
            )
        else:
            # Find growth index based on global step and growth frequency.
            # NOTE(review): with the global step at 0 this floordiv gives -1
            # (for a positive growth period) - confirm how the switch case is
            # expected to treat a negative index.
            steps_elapsed = tf.train.get_or_create_global_step() - 1
            growth_stage = tf.floordiv(
                x=steps_elapsed,
                y=params["num_steps_until_growth"],
                name="{}_global_step_floordiv".format(self.name)
            )
            growth_index = tf.minimum(
                x=tf.cast(x=growth_stage, dtype=tf.int32),
                y=(num_blocks - 1) * 2,
                name="{}_growth_index".format(self.name)
            )

            # Switch to case based on number of steps for gen outputs.
            generated_outputs = self.unknown_switch_case_vec_to_img_outputs(
                Z=Z,
                orig_img_size=params["{}_projection_dims".format(self.kind)][0:2],
                alpha_var=alpha_var,
                params=params,
                growth_index=growth_index
            )

        print_obj("\n" + func_name, "generated_outputs", generated_outputs)

        # Tie the outputs to the build tensors with a control dependency to
        # ensure the vec_to_img internals are built.
        with tf.control_dependencies(
                control_inputs=self.build_vec_to_img_tensors):
            generated_outputs = tf.identity(
                input=generated_outputs,
                name="{}_generated_outputs_identity".format(self.name)
            )

        return generated_outputs

    def get_predict_vec_to_img_outputs(self, Z, params, block_idx):
        """Builds the vec_to_img output image tensor for predict.

        Block index 0 uses the base network; any other index uses the stable
        growth network with transition index `block_idx - 1`.

        Args:
            Z: tensor, latent vectors of shape [cur_batch_size, latent_size].
            params: dict, user passed parameters.
            block_idx: int, current conv layer block's index.

        Returns:
            Generated image output tensor of shape
                [cur_batch_size, image_size, image_size, depth] or list of
                them for each resolution.
        """
        func_name = "get_predict_{}_outputs".format(self.kind)

        print_obj("\n" + func_name, "Z", Z)

        if block_idx != 0:
            # 8x8 through 1024x1024
            projection_dims = params["{}_projection_dims".format(self.kind)]
            generated_outputs = self.create_growth_stable_vec_to_img_network(
                Z=Z,
                orig_img_size=projection_dims[0:2],
                params=params,
                trans_idx=block_idx - 1
            )
        else:
            # 4x4
            generated_outputs = self.create_base_vec_to_img_network(
                Z=Z, params=params
            )
        print_obj(func_name, "generated_outputs", generated_outputs)

        return generated_outputs


## generator.py

In [10]:
class Generator(VectorToImage):
    """Generator network that maps latent vectors to images.

    Fields:
        name: str, name of `Generator`.
    """
    def __init__(self, kernel_regularizer, bias_regularizer, params, name):
        """Stores the generator's name and initializes the base network.

        Args:
            kernel_regularizer: `l1_l2_regularizer` object, regularizer for
                kernel variables.
            bias_regularizer: `l1_l2_regularizer` object, regularizer for bias
                variables.
            params: dict, user passed parameters.
            name: str, name of `Generator`.
        """
        # The name is assigned before the base class initializer is called.
        self.name = name

        # Initialize base class with "generator" as the kind.
        super().__init__(
            kernel_regularizer, bias_regularizer, params, "generator"
        )

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def get_generator_loss(self, fake_logits, params):
        """Computes the generator's total loss.

        The base loss is the negated mean of the fake logits. The
        regularization loss for the generator's scope is added to it, and
        unless `params["use_tpu"]` is set all three losses are written as
        scalar summaries.

        Args:
            fake_logits: tensor, shape of [cur_batch_size, 1] that came from
                discriminator having processed generator's output image.
            params: dict, user passed parameters.

        Returns:
            Generator's total loss tensor of shape [].
        """
        func_name = "get_generator_loss"

        # Base loss: negated mean of the fake logits.
        base_loss = -tf.reduce_mean(
            input_tensor=fake_logits,
            name="{}_loss".format(self.name)
        )
        print_obj("\n" + func_name, "generator_loss", base_loss)

        # Regularization loss over this generator's scope.
        reg_loss = get_regularization_loss(
            lambda1=params["generator_l1_regularization_scale"],
            lambda2=params["generator_l2_regularization_scale"],
            scope=self.name
        )
        print_obj(func_name, "generator_reg_loss", reg_loss)

        # Total loss is the sum of both.
        total_loss = tf.math.add(
            x=base_loss,
            y=reg_loss,
            name="{}_total_loss".format(self.name)
        )
        print_obj(func_name, "generator_total_loss", total_loss)

        if not params["use_tpu"]:
            # Add summaries for TensorBoard.
            loss_summaries = (
                ("generator_loss", base_loss, "losses"),
                ("generator_reg_loss", reg_loss, "losses"),
                ("generator_total_loss", total_loss, "total_losses"),
            )
            for summary_name, summary_tensor, summary_family in loss_summaries:
                tf.summary.scalar(
                    name=summary_name,
                    tensor=summary_tensor,
                    family=summary_family
                )

        return total_loss


## image_to_vector.py

In [11]:
class ImageToVector(object):
    """Convolutional network takes image input and outputs a vector.

    Fields:
        kind: str, kind of `ImageToVector` instance. Used to build the
            `params` keys read by the methods, e.g.
            "<kind>_base_conv_blocks".
        kernel_regularizer: `l1_l2_regularizer` object, regularizer for kernel
            variables.
        bias_regularizer: `l1_l2_regularizer` object, regularizer for bias
            variables.
        from_rgb_conv_layers: list, fromRGB 1x1 `WeightScaledConv2D` layers.
        conv_layer_blocks: list, lists of block layers for each block. The
            base block is at index 0, followed by the growth blocks.
        flatten_layer: `Flatten` layer prior to logits layer.
        logits_layer: `WeightScaledDense` layer for logits.
        build_img_to_vec_tensors: tensor, logits tensor returned by
            `build_img_to_vec_layers` when the layer internals are built.

    NOTE(review): the methods read `self.name` (variable scope and op names),
    but `__init__` of this class does not assign it. Presumably subclasses
    set `name` before calling this constructor -- TODO confirm.
    """
    def __init__(self, kernel_regularizer, bias_regularizer, params, kind):
        """Instantiates and builds vec_to_img network.

        Args:
            kernel_regularizer: `l1_l2_regularizer` object, regularizar for
                kernel variables.
            bias_regularizer: `l1_l2_regularizer` object, regularizar for bias
                variables.
            params: dict, user passed parameters.
            kind: str, kind of `ImageToVector` instance.
        """
        # Set kind of image to vector network.
        self.kind = kind

        # Regularizer for kernel weights.
        self.kernel_regularizer = kernel_regularizer

        # Regularizer for bias weights.
        self.bias_regularizer = bias_regularizer

        # Instantiate image to vector layers.
        (self.from_rgb_conv_layers,
         self.conv_layer_blocks,
         self.flatten_layer,
         self.logits_layer) = self.instantiate_img_to_vec_layers(
            params
        )

        # Build image to vector layer internals.
        self.build_img_to_vec_tensors = self.build_img_to_vec_layers(
            params
        )

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def instantiate_img_to_vec_from_rgb_layers(self, params):
        """Instantiates img_to_vec fromRGB layers of 1x1 convs.

        One layer is created per entry of `params["<kind>_from_rgb_layers"]`,
        using the first property list of each entry. In a property list,
        entries [0:2] are passed as kernel size, entry [3] as number of
        filters and entries [4:6] as strides. Entry [2] is only used in the
        layer name.

        Args:
            params: dict, user passed parameters.

        Returns:
            List of fromRGB 1x1 WeightScaledConv2D layers.
        """
        func_name = "instantiate_{}_from_rgb_layers".format(self.kind)

        def make_from_rgb_layer(idx, props):
            """Creates the fromRGB conv layer described by `props`."""
            return WeightScaledConv2D(
                filters=props[3],
                kernel_size=props[0:2],
                strides=props[4:6],
                padding="same",
                activation=None,
                kernel_initializer=(
                    tf.random_normal_initializer(mean=0., stddev=1.0)
                    if params["use_equalized_learning_rate"]
                    else "he_normal"
                ),
                kernel_regularizer=self.kernel_regularizer,
                bias_regularizer=self.bias_regularizer,
                equalized_learning_rate=params["use_equalized_learning_rate"],
                name="{}_from_rgb_layers_conv2d_{}_{}x{}_{}_{}".format(
                    self.name, idx, props[0], props[1], props[2], props[3]
                )
            )

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Copy the first property list of every fromRGB entry.
            from_rgb = [
                entry[0][:]
                for entry in params["{}_from_rgb_layers".format(self.kind)]
            ]

            # Create the fromRGB 1x1 convs, in order.
            from_rgb_conv_layers = []
            for idx, props in enumerate(from_rgb):
                from_rgb_conv_layers.append(make_from_rgb_layer(idx, props))
            print_obj(
                "\n" + func_name, "from_rgb_conv_layers", from_rgb_conv_layers
            )

        return from_rgb_conv_layers

    def instantiate_img_to_vec_base_conv_layer_block(self, params):
        """Instantiates img_to_vec base conv layer block.

        Every layer of the base block uses "same" padding except the last
        one, which uses "valid" padding because it sits just before the
        flatten and logits layers.

        Args:
            params: dict, user passed parameters.

        Returns:
            List of base conv layers.
        """
        func_name = "instantiate_{}_base_conv_layer_block".format(self.kind)

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Get conv block layer properties.
            conv_block = params["{}_base_conv_blocks".format(self.kind)][0]

            def make_base_layer(idx, padding):
                """Creates base block layer `idx` with the given padding."""
                props = conv_block[idx]
                return WeightScaledConv2D(
                    filters=props[3],
                    kernel_size=props[0:2],
                    strides=props[4:6],
                    padding=padding,
                    activation=None,
                    kernel_initializer=(
                        tf.random_normal_initializer(mean=0., stddev=1.0)
                        if params["use_equalized_learning_rate"]
                        else "he_normal"
                    ),
                    kernel_regularizer=self.kernel_regularizer,
                    bias_regularizer=self.bias_regularizer,
                    equalized_learning_rate=params["use_equalized_learning_rate"],
                    name="{}_base_layers_conv2d_{}_{}x{}_{}_{}".format(
                        self.name, idx, props[0], props[1], props[2], props[3]
                    )
                )

            last_idx = len(conv_block) - 1

            # All layers but the last use same padding.
            base_conv_layers = [
                make_base_layer(idx, "same") for idx in range(last_idx)
            ]

            # Have valid padding for layer just before flatten and logits.
            base_conv_layers.append(make_base_layer(last_idx, "valid"))
            print_obj(
                "\n" + func_name, "base_conv_layers", base_conv_layers
            )

        return base_conv_layers

    def instantiate_img_to_vec_growth_layer_block(self, params, block_idx):
        """Instantiates img_to_vec growth block layers.

        One same-padded conv layer is created per property list found in
        `params["<kind>_growth_conv_blocks"][block_idx]`.

        Args:
            params: dict, user passed parameters.
            block_idx: int, the current growth block's index.

        Returns:
            List of growth block layers.
        """
        func_name = "instantiate_{}_growth_layer_block".format(self.kind)

        def make_growth_layer(idx, props):
            """Creates growth block layer `idx` described by `props`."""
            return WeightScaledConv2D(
                filters=props[3],
                kernel_size=props[0:2],
                strides=props[4:6],
                padding="same",
                activation=None,
                kernel_initializer=(
                    tf.random_normal_initializer(mean=0., stddev=1.0)
                    if params["use_equalized_learning_rate"]
                    else "he_normal"
                ),
                kernel_regularizer=self.kernel_regularizer,
                bias_regularizer=self.bias_regularizer,
                equalized_learning_rate=params["use_equalized_learning_rate"],
                name="{}_growth_layers_conv2d_{}_{}_{}x{}_{}_{}".format(
                    self.name,
                    block_idx,
                    idx,
                    props[0],
                    props[1],
                    props[2],
                    props[3]
                )
            )

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Get conv block layer properties.
            conv_block = params["{}_growth_conv_blocks".format(self.kind)][block_idx]

            # Create new inner convolutional layers, in order.
            conv_layers = []
            for idx in range(len(conv_block)):
                conv_layers.append(make_growth_layer(idx, conv_block[idx]))
            print_obj("\n" + func_name, "conv_layers", conv_layers)

        return conv_layers

    def instantiate_img_to_vec_layers(self, params):
        """Instantiates layers of img_to_vec network.

        Layers are instantiated in this order: fromRGB layers, base conv
        block, growth conv blocks, and finally flatten and logits layers.

        Args:
            params: dict, user passed parameters.

        Returns:
            from_rgb_conv_layers: list, fromRGB 1x1 `WeightScaledConv2D` layers.
            conv_layer_blocks: list, lists of `WeightScaledConv2D` block layers
                for each block. Base block first, then the growth blocks.
            flatten_layer: `Flatten` layer prior to logits layer.
            logits_layer: `WeightScaledDense` layer for logits.
        """
        func_name = "instantiate_{}_layers".format(self.kind)

        # fromRGB 1x1 `WeightScaledConv2D` layers.
        from_rgb_conv_layers = self.instantiate_img_to_vec_from_rgb_layers(
            params=params
        )
        print_obj(
            "\n" + func_name, "from_rgb_conv_layers", from_rgb_conv_layers
        )

        # The base conv block always occupies index 0 of the block list.
        conv_layer_blocks = [
            self.instantiate_img_to_vec_base_conv_layer_block(params=params)
        ]

        # Growth blocks follow, one per configured growth conv block.
        num_growth_blocks = len(
            params["{}_growth_conv_blocks".format(self.kind)]
        )
        for block_idx in range(num_growth_blocks):
            conv_layer_blocks.append(
                self.instantiate_img_to_vec_growth_layer_block(
                    params=params, block_idx=block_idx
                )
            )
        print_obj(func_name, "conv_layer_blocks", conv_layer_blocks)

        # `Flatten` and `WeightScaledDense` logits layers.
        flatten_layer, logits_layer = (
            self.instantiate_img_to_vec_logits_layer(params=params)
        )
        print_obj(func_name, "flatten_layer", flatten_layer)
        print_obj(func_name, "logits_layer", logits_layer)

        return (
            from_rgb_conv_layers,
            conv_layer_blocks,
            flatten_layer,
            logits_layer
        )

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def build_img_to_vec_from_rgb_layers(self, params):
        """Creates img_to_vec fromRGB layers of 1x1 convs.

        Each already instantiated fromRGB layer is called once on a zeros
        tensor so that its internals get built.

        Args:
            params: dict, user passed parameters.

        Returns:
            List of tensors from fromRGB 1x1 `WeightScaledConv2D` layers.
        """
        func_name = "build_{}_from_rgb_layers".format(self.kind)

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Copy the first property list of every fromRGB entry.
            from_rgb = [
                entry[0][:]
                for entry in params["{}_from_rgb_layers".format(self.kind)]
            ]

            # Call every fromRGB 1x1 conv on a dummy input whose shape is a
            # batch of one followed by the first three layer properties.
            from_rgb_conv_tensors = []
            for idx, props in enumerate(from_rgb):
                layer = self.from_rgb_conv_layers[idx]
                dummy_input = tf.zeros(
                    shape=[1] + props[0:3], dtype=tf.float32
                )
                from_rgb_conv_tensors.append(layer(inputs=dummy_input))
            print_obj(
                "\n" + func_name,
                "from_rgb_conv_tensors",
                from_rgb_conv_tensors
            )

        return from_rgb_conv_tensors

    def build_img_to_vec_growth_layer_block(self, params, block_idx):
        """Creates img_to_vec growth block.

        Each layer of the growth block is called once on a zeros tensor so
        that its internals get built.

        Args:
            params: dict, user passed parameters.
            block_idx: int, the current growth block's index.

        Returns:
            List of tensors from growth block `WeightScaledConv2D` layers.
        """
        func_name = "build_{}_growth_layer_block".format(self.kind)

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Get conv block layer properties.
            conv_block = params["{}_growth_conv_blocks".format(self.kind)][block_idx]

            # Call every layer of the block on a dummy input. Growth block
            # `block_idx` is stored at position `1 + block_idx`.
            conv_tensors = []
            for idx, props in enumerate(conv_block):
                layer = self.conv_layer_blocks[1 + block_idx][idx]
                dummy_input = tf.zeros(
                    shape=[1] + props[0:3], dtype=tf.float32
                )
                conv_tensors.append(layer(inputs=dummy_input))
            print_obj("\n" + func_name, "conv_tensors", conv_tensors)

        return conv_tensors

    def build_img_to_vec_logits_layer(self, params):
        """Builds flatten and logits layer internals using call.

        A zeros tensor of shape [1, 1, 1, C] is flattened and passed through
        the logits layer. C is entry [3] of the last property list of the
        last base conv block.

        Args:
            params: dict, user passed parameters.

        Returns:
            Final logits tensor of img_to_vec.
        """
        func_name = "build_{}_logits_layer".format(self.kind)

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            base_blocks = params["{}_base_conv_blocks".format(self.kind)]
            block_conv_size = base_blocks[-1][-1][3]

            # Flatten a dummy final block conv tensor.
            dummy_block_conv = tf.zeros(
                shape=[1, 1, 1, block_conv_size], dtype=tf.float32
            )
            block_conv_flat = self.flatten_layer(inputs=dummy_block_conv)
            print_obj("\n" + func_name, "block_conv_flat", block_conv_flat)

            # Final linear layer for logits.
            logits = self.logits_layer(inputs=block_conv_flat)
            print_obj(func_name, "logits", logits)

        return logits

    def build_img_to_vec_layers(self, params):
        """Builds img_to_vec layer internals.

        Build order is fromRGB layers, then base and growth conv blocks, then
        the logits layer. Each stage is created under control dependencies
        on the tensors of the previous stage.

        Args:
            params: dict, user passed parameters.

        Returns:
            Logits tensor.
        """
        func_name = "build_{}_layers".format(self.kind)

        # Build fromRGB 1x1 `WeightScaledConv2D` layers internals through call.
        from_rgb_conv_tensors = self.build_img_to_vec_from_rgb_layers(
            params=params
        )
        print_obj(
            "\n" + func_name, "from_rgb_conv_tensors", from_rgb_conv_tensors
        )

        with tf.control_dependencies(control_inputs=from_rgb_conv_tensors):
            # Base convolutional block comes first.
            per_block_tensors = [
                self.build_img_to_vec_base_conv_layer_block(params=params)
            ]

            # Then one list of tensors per growth block.
            num_growth_blocks = len(
                params["{}_growth_conv_blocks".format(self.kind)]
            )
            for block_idx in range(num_growth_blocks):
                per_block_tensors.append(
                    self.build_img_to_vec_growth_layer_block(
                        params=params, block_idx=block_idx
                    )
                )

            # Flatten conv block tensor lists of lists into list.
            conv_block_tensors = []
            for block_tensors in per_block_tensors:
                conv_block_tensors.extend(block_tensors)
            print_obj(func_name, "conv_block_tensors", conv_block_tensors)

            with tf.control_dependencies(control_inputs=conv_block_tensors):
                # Build logits layer internals using call.
                logits_tensor = self.build_img_to_vec_logits_layer(
                    params=params
                )
                print_obj(func_name, "logits_tensor", logits_tensor)

        return logits_tensor

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def use_img_to_vec_logits_layer(self, block_conv, params):
        """Uses flatten and logits layers to get logits tensor.

        Before flattening, the static shape of `block_conv` is set so that
        the dense logits layer sees known spatial dimensions. The spatial
        size is the generator projection size reduced by the valid-padded
        kernel of the last base conv layer, i.e. `size - kernel + 1`.

        Args:
            block_conv: tensor, output of last conv layer of img_to_vec.
            params: dict, user passed parameters.

        Returns:
            Final logits tensor of img_to_vec.
        """
        func_name = "use_{}_logits_layer".format(self.kind)

        print_obj("\n" + func_name, "block_conv", block_conv)
        # Set shape to remove ambiguity for dense layer.
        height, width = params["generator_projection_dims"][0:2]

        # Look the kernel size up under this instance's own kind, like every
        # other method of this class does, so that a non-discriminator kind
        # uses the kernel size of its own last base conv layer.
        valid_kernel_size = (
            params["{}_base_conv_blocks".format(self.kind)][0][-1][0]
        )
        block_conv.set_shape(
            [
                block_conv.get_shape()[0],
                height - valid_kernel_size + 1,
                width - valid_kernel_size + 1,
                block_conv.get_shape()[-1]
            ]
        )
        print_obj(func_name, "block_conv", block_conv)

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Flatten final block conv tensor.
            block_conv_flat = self.flatten_layer(inputs=block_conv)
            print_obj(func_name, "block_conv_flat", block_conv_flat)

            # Final linear layer for logits.
            logits = self.logits_layer(inputs=block_conv_flat)
            print_obj(func_name, "logits", logits)

        return logits

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def create_base_img_to_vec_block_and_logits(self, block_conv, params):
        """Creates base img_to_vec block and logits.

        Passes `block_conv` through every layer of the base conv block, each
        followed by a leaky ReLU, then through the flatten and logits layers.

        Args:
            block_conv: tensor, output of previous `WeightScaledConv2D` block's layer.
            params: dict, user passed parameters.
        Returns:
            Final logits tensor of img_to_vec.
        """
        func_name = "create_base_{}_block_and_logits".format(self.kind)
        alpha_key = "{}_leaky_relu_alpha".format(self.kind)

        print_obj("\n" + func_name, "block_conv", block_conv)
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Only need the first conv layer block for base network.
            for idx, layer in enumerate(self.conv_layer_blocks[0]):
                block_conv = layer(inputs=block_conv)
                print_obj(func_name, "block_conv", block_conv)

                leaky_name = "{}_base_layers_conv2d_{}_leaky_relu".format(
                    self.kind, idx
                )
                block_conv = tf.nn.leaky_relu(
                    features=block_conv,
                    alpha=params[alpha_key],
                    name=leaky_name
                )
                print_obj(func_name, "block_conv_leaky", block_conv)

            # Get logits now.
            logits = self.use_img_to_vec_logits_layer(
                block_conv=block_conv, params=params
            )
            print_obj(func_name, "logits", logits)

        return logits

    def create_growth_transition_img_to_vec_weighted_sum(
            self, X, alpha_var, params, trans_idx):
        """Creates growth transition img_to_vec weighted_sum

        Two paths process the same input image:
          * growing path: fromRGB layer `trans_idx + 1`, the conv block at
            index `trans_idx + 1`, then 2x2 average pooling.
          * shrinking path: 2x2 average pooling of the image, then fromRGB
            layer `trans_idx`.
        The result is `growing * alpha_var + shrinking * (1 - alpha_var)`.

        Args:
            X: tensor, input image to img_to_vec.
            alpha_var: variable, alpha for weighted sum of fade-in of layers.
            params: dict, user passed parameters.
            trans_idx: int, index of current growth transition.

        Returns:
            Tensor of weighted sum between shrinking and growing block paths.
        """
        func_name = "create_growth_transition_{}_weighted_sum".format(
            self.kind
        )
        alpha_key = "{}_leaky_relu_alpha".format(self.kind)

        print_obj("\nEntered {}".format(func_name), "trans_idx", trans_idx)
        print_obj(func_name, "X", X)
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Growing side chain.
            grow_from_rgb = self.from_rgb_conv_layers[trans_idx + 1]
            grow_layers = self.conv_layer_blocks[trans_idx + 1]

            # Pass inputs through layer chain.
            grow_conv = grow_from_rgb(inputs=X)
            print_obj("\n" + func_name, "growing_block_conv", grow_conv)

            grow_conv = tf.nn.leaky_relu(
                features=grow_conv,
                alpha=params[alpha_key],
                name="{}_growth_growing_from_rgb_{}_leaky_relu".format(
                    self.kind, trans_idx
                )
            )
            print_obj(func_name, "growing_block_conv_leaky", grow_conv)

            for layer_idx, layer in enumerate(grow_layers):
                grow_conv = layer(inputs=grow_conv)
                print_obj(func_name, "growing_block_conv", grow_conv)

                grow_conv = tf.nn.leaky_relu(
                    features=grow_conv,
                    alpha=params[alpha_key],
                    name="{}_growth_conv_2d_{}_{}_leaky_relu".format(
                        self.kind, trans_idx, layer_idx
                    )
                )
                print_obj(func_name, "growing_block_conv_leaky", grow_conv)

            # Down sample from 2s X 2s to s X s image.
            grow_pool = tf.layers.AveragePooling2D(
                pool_size=(2, 2),
                strides=(2, 2),
                name="{}_growing_downsampled_image_{}".format(
                    self.name, trans_idx
                )
            )
            grow_conv_downsampled = grow_pool(inputs=grow_conv)
            print_obj(
                func_name,
                "growing_block_conv_downsampled",
                grow_conv_downsampled
            )

            # Shrinking side chain.
            shrink_from_rgb = self.from_rgb_conv_layers[trans_idx]

            # Pass inputs through layer chain.
            # Down sample from 2s X 2s to s X s image.
            shrink_pool = tf.layers.AveragePooling2D(
                pool_size=(2, 2),
                strides=(2, 2),
                name="{}_shrinking_downsampled_image_{}".format(
                    self.name, trans_idx
                )
            )
            X_downsampled = shrink_pool(inputs=X)
            print_obj(func_name, "X_downsampled", X_downsampled)

            shrink_conv = shrink_from_rgb(inputs=X_downsampled)
            print_obj(func_name, "shrinking_from_rgb_conv", shrink_conv)

            shrink_conv = tf.nn.leaky_relu(
                features=shrink_conv,
                alpha=params[alpha_key],
                name="{}_growth_shrinking_from_rgb_{}_leaky_relu".format(
                    self.kind, trans_idx
                )
            )
            print_obj(func_name, "shrinking_from_rgb_conv_leaky", shrink_conv)

            # Weighted sum of both paths, faded by alpha.
            grow_term = grow_conv_downsampled * alpha_var
            shrink_term = shrink_conv * (1.0 - alpha_var)
            weighted_sum = tf.add(
                x=grow_term,
                y=shrink_term,
                name="{}_growth_transition_weighted_sum_{}".format(
                    self.name, trans_idx
                )
            )
            print_obj(func_name, "weighted_sum", weighted_sum)

        return weighted_sum

    def create_img_to_vec_perm_growth_block_network(
            self, block_conv, params, trans_idx):
        """Creates img_to_vec permanent block network.

        The growth blocks at indices 1..trans_idx are applied in reverse
        order. Within a block every layer is followed by a leaky ReLU, and
        every block is followed by a 2x2 average pooling.

        Args:
            block_conv: tensor, output of previous block's layer.
            params: dict, user passed parameters.
            trans_idx: int, index of current growth transition.

        Returns:
            Tensor from final permanent block `WeightScaledConv2D` layer.
        """
        func_name = "create_{}_perm_block_network".format(self.kind)
        alpha_key = "{}_leaky_relu_alpha".format(self.kind)

        print_obj("\nEntered {}".format(func_name), "trans_idx", trans_idx)
        print_obj(func_name, "block_conv", block_conv)
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Permanent growth blocks skip the base block at index 0 and are
            # visited in reverse order.
            permanent_blocks = self.conv_layer_blocks[1:trans_idx + 1][::-1]

            # Pass inputs through layer chain, block by block.
            for block_pos, block_layers in enumerate(permanent_blocks):
                for layer_pos, layer in enumerate(block_layers):
                    block_conv = layer(inputs=block_conv)
                    print_obj(
                        func_name, "block_conv_{}".format(block_pos), block_conv
                    )

                    block_conv = tf.nn.leaky_relu(
                        features=block_conv,
                        alpha=params[alpha_key],
                        name="{}_perm_conv_2d_{}_{}_{}_leaky_relu".format(
                            self.kind, trans_idx, block_pos, layer_pos
                        )
                    )
                    print_obj(func_name, "block_conv_leaky", block_conv)

                # Down sample from 2s X 2s to s X s image.
                downsample = tf.layers.AveragePooling2D(
                    pool_size=(2, 2),
                    strides=(2, 2),
                    name="{}_perm_conv_downsample_{}_{}".format(
                        self.name, trans_idx, block_pos
                    )
                )
                block_conv = downsample(inputs=block_conv)
                print_obj(func_name, "block_conv_downsampled", block_conv)

        return block_conv

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def unknown_switch_case_img_to_vec_logits(
            self, X, alpha_var, params, growth_index):
        """Uses switch case to use the correct network to get logits.

        Branch 0 is the base network. It is followed by eight growth levels,
        each contributing a transition branch and then a stable branch, for
        17 branches in total. The transition index handed to a branch is the
        growth level capped at `len(params["conv_num_filters"]) - 2`.

        Args:
            X: tensor, image tensors of shape
                [cur_batch_size, image_size, image_size, depth].
            alpha_var: variable, alpha for weighted sum of fade-in of layers.
            params: dict, user passed parameters.
            growth_index: tensor, current growth stage.

        Returns:
            Logits tensor of shape [cur_batch_size, 1].
        """
        func_name = "unknown_switch_case_{}_logits".format(self.kind)

        # Growth levels 0..7, labelled 8x8 up to 1024x1024 in the original
        # branch list.
        num_growth_levels = 8

        def make_transition_branch(level):
            """Returns the branch fn of the transition network of `level`."""
            return lambda: self.create_growth_transition_img_to_vec_network(
                X=X,
                alpha_var=alpha_var,
                params=params,
                trans_idx=min(level, len(params["conv_num_filters"]) - 2)
            )

        def make_stable_branch(level):
            """Returns the branch fn of the stable network of `level`."""
            return lambda: self.create_growth_stable_img_to_vec_network(
                X=X,
                params=params,
                trans_idx=min(level, len(params["conv_num_filters"]) - 2)
            )

        # Base network (4x4) comes first.
        branch_fns = [
            lambda: self.create_base_img_to_vec_network(X=X, params=params)
        ]

        # Each level adds its transition branch, then its stable branch.
        for level in range(num_growth_levels):
            branch_fns.append(make_transition_branch(level))
            branch_fns.append(make_stable_branch(level))

        # Switch to case based on number of steps to get logits.
        logits = tf.switch_case(
            branch_index=growth_index,
            branch_fns=branch_fns,
            name="{}_switch_case_logits".format(self.name)
        )
        print_obj("\n" + func_name, "logits", logits)

        return logits

    def known_switch_case_img_to_vec_logits(self, X, alpha_var, params):
        """Picks the network matching a statically known growth index.

        The growth index is read from `params["growth_idx"]` as a plain
        python value, so the choice is made while the graph is being built:
        index 0 builds the base network, odd indices build a growth
        transition network and all remaining indices build a growth stable
        network.

        Args:
            X: tensor, image tensors of shape
                [batch_size, image_size, image_size, depth].
            alpha_var: variable, alpha for weighted sum of fade-in of layers.
            params: dict, user passed parameters.

        Returns:
            Logits tensor of shape [batch_size, 1].
        """
        func_name = "switch_case_{}_logits".format(self.kind)
        growth_idx = params["growth_idx"]

        if growth_idx == 0:
            # Nothing has grown yet, only the base block is needed.
            logits = self.create_base_img_to_vec_network(X=X, params=params)
        elif growth_idx % 2 == 1:
            # Odd index: transition phase, weighted sum with the smaller
            # network. Growth indices 1 and 2 both map to transition 0,
            # indices 3 and 4 to transition 1, and so on.
            logits = self.create_growth_transition_img_to_vec_network(
                X=X,
                alpha_var=alpha_var,
                params=params,
                trans_idx=(growth_idx - 1) // 2
            )
        else:
            # Even index: stable phase, bigger network without weighted sum.
            logits = self.create_growth_stable_img_to_vec_network(
                X=X, params=params, trans_idx=(growth_idx - 1) // 2
            )
        print_obj("\n" + func_name, "logits", logits)

        return logits

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def get_train_eval_img_to_vec_logits(self, X, alpha_var, params):
        """Runs images through the img_to_vec network for train/eval.

        Three construction paths exist: a network that never grows (only one
        entry in `params["conv_num_filters"]`), a statically known growth
        index (`params["growth_idx"]` is not None), and a growth index
        derived inside the graph from the global step.

        Args:
            X: tensor, image tensors of shape
                [cur_batch_size, image_size, image_size, depth].
            alpha_var: variable, alpha for weighted sum of fade-in of layers.
            params: dict, user passed parameters.

        Returns:
            Logits tensor of shape [cur_batch_size, 1].
        """
        func_name = "get_train_eval_{}_logits".format(self.kind)

        print_obj("\n" + func_name, "X", X)

        num_blocks = len(params["conv_num_filters"])
        if num_blocks == 1:
            # A single block can never grow, so the switch case is skipped
            # and only the base (4x4) network is built.
            print(
                "\n {}: NOT GOING TO GROW, SKIP SWITCH CASE!".format(
                    func_name
                )
            )
            logits = self.create_base_img_to_vec_network(X=X, params=params)
        elif params["growth_idx"] is not None:
            # Growth index is a python value, pick the network statically.
            logits = self.known_switch_case_img_to_vec_logits(
                X=X, alpha_var=alpha_var, params=params
            )
        else:
            # Derive the growth index from global step and growth frequency.
            # NOTE(review): only the upper bound is clamped below; at global
            # step 0 the floordiv yields -1 - confirm how the switch case
            # treats a negative index.
            steps_elapsed = tf.train.get_or_create_global_step() - 1
            growth_stage = tf.floordiv(
                x=steps_elapsed,
                y=params["num_steps_until_growth"],
                name="{}_global_step_floordiv".format(self.name)
            )
            growth_index = tf.minimum(
                x=tf.cast(x=growth_stage, dtype=tf.int32),
                y=(num_blocks - 1) * 2,
                name="{}_growth_index".format(self.name)
            )

            # Let the in-graph switch case choose the network for logits.
            logits = self.unknown_switch_case_img_to_vec_logits(
                X=X,
                alpha_var=alpha_var,
                params=params,
                growth_index=growth_index
            )
        print_obj("\n" + func_name, "logits", logits)

        # Tie the returned logits to `self.build_img_to_vec_tensors` through
        # a control dependency so those tensors are evaluated as well.
        build_dependencies = [self.build_img_to_vec_tensors]
        with tf.control_dependencies(control_inputs=build_dependencies):
            logits = tf.identity(
                input=logits, name="{}_logits_identity".format(self.name)
            )

        return logits


## discriminator.py

In [12]:
class Discriminator(ImageToVector):
    """Discriminator that takes image input and outputs logits.

    Fields:
        name: str, name of `Discriminator`.
    """
    def __init__(self, kernel_regularizer, bias_regularizer, params, name):
        """Instantiates and builds discriminator network.

        Args:
            kernel_regularizer: `l1_l2_regularizer` object, regularizer for
                kernel variables.
            bias_regularizer: `l1_l2_regularizer` object, regularizer for bias
                variables.
            params: dict, user passed parameters.
            name: str, name of `Discriminator`.
        """
        # The name is assigned before the base class initializer runs.
        self.name = name

        # Hand everything to `ImageToVector`, tagging this instance with the
        # "discriminator" kind.
        super().__init__(
            kernel_regularizer, bias_regularizer, params, "discriminator"
        )

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def instantiate_img_to_vec_logits_layer(self, params):
        """Instantiates discriminator flatten and logits layers.

        Args:
            params: dict, user passed parameters.

        Returns:
            2-tuple of the flatten layer and the single unit
                `WeightScaledDense` logits layer of the discriminator.
        """
        func_name = "instantiate_img_to_vec_logits_layer"
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Flattens the final block conv tensor ahead of the dense layer.
            flatten_layer = tf.layers.Flatten(
                name="{}_flatten_layer".format(self.name)
            )
            print_obj(func_name, "flatten_layer", flatten_layer)

            # With equalized learning rate the kernel starts as a unit
            # normal, otherwise He normal initialization is used.
            use_equalized_lr = params["use_equalized_learning_rate"]
            if use_equalized_lr:
                kernel_init = tf.random_normal_initializer(
                    mean=0., stddev=1.0
                )
            else:
                kernel_init = "he_normal"

            # Final linear layer producing one logit per example.
            logits_layer = WeightScaledDense(
                units=1,
                activation=None,
                kernel_initializer=kernel_init,
                kernel_regularizer=self.kernel_regularizer,
                bias_regularizer=self.bias_regularizer,
                equalized_learning_rate=params["use_equalized_learning_rate"],
                name="{}_layers_dense_logits".format(self.name)
            )
            print_obj(func_name, "logits_layer", logits_layer)

        return flatten_layer, logits_layer

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def build_img_to_vec_base_conv_layer_block(self, params):
        """Builds the discriminator base conv layers on zero tensors.

        Each layer of the base block is called once with a zeros tensor of
        batch size 1 whose remaining dims are read from the block's layer
        properties.

        Args:
            params: dict, user passed parameters.

        Returns:
            List of tensors from base `WeightScaledConv2D` layers.
        """
        func_name = "build_{}_base_conv_layer_block".format(self.kind)

        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Layer properties of the base conv block.
            block_props = params["{}_base_conv_blocks".format(self.kind)][0]

            # The base conv block is always the 0th one.
            base_layers = self.conv_layer_blocks[0]

            # Minibatch stddev comes before the first base conv layer and
            # creates 1 extra feature map, so that layer gets one more input
            # channel when the feature is enabled.
            # NOTE(review): the channel count is read from index 3 here while
            # the remaining layers take theirs from index 2 - confirm.
            extra_channels = 1 if params["use_minibatch_stddev"] else 0
            first_in_channels = block_props[0][3] + extra_channels

            # Zeros input shape per layer, the first one with the
            # (possibly) bigger channel count.
            input_shapes = [[1] + block_props[0][0:2] + [first_in_channels]]
            input_shapes += [[1] + props[0:3] for props in block_props[1:]]

            # Build every base conv layer in order, collecting the outputs.
            base_conv_tensors = [
                base_layers[idx](
                    inputs=tf.zeros(shape=zeros_shape, dtype=tf.float32)
                )
                for idx, zeros_shape in enumerate(input_shapes)
            ]
            print_obj(
                "\n" + func_name, "base_conv_tensors", base_conv_tensors
            )

        return base_conv_tensors

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def minibatch_stddev_common(
            self,
            variance,
            tile_multiples,
            params,
            caller):
        """Turns a minibatch variance tensor into a stddev feature map.

        This is the code that is common between the grouped and ungrouped
        minibatch stddev functions: take the square root of the variance,
        optionally average it over pixels and feature maps, then tile it.

        Args:
            variance: tensor, variance of minibatch or minibatch groups.
            tile_multiples: list, length 4, used to tile input to final shape
                input_dims[i] * multiples[i].
            params: dict, user passed parameters.
            caller: str, name of the calling function, used as prefix for
                the variable scope and op names.

        Returns:
            Minibatch standard deviation feature map. With
                `params["minibatch_stddev_averaging"]` enabled and the
                multiples the callers pass, the shape is
                [cur_batch_size, image_size, image_size, 1].
        """
        # Plain literal: there is no placeholder to format into the name.
        func_name = "minibatch_stddev_common"

        with tf.variable_scope(
                "{}/{}_minibatch_stddev".format(self.name, caller)):
            # Calculate standard deviation over the group plus small epsilon.
            # The epsilon keeps the sqrt away from an exact zero input.
            # shape = (
            #     {"grouped": cur_batch_size / group_size, "ungrouped": 1},
            #     image_size,
            #     image_size,
            #     num_channels
            # )
            stddev = tf.sqrt(
                x=variance + 1e-8, name="{}_stddev".format(caller)
            )
            print_obj(func_name, "{}_stddev".format(caller), stddev)

            # Take average over feature maps and pixels.
            # NOTE(review): when averaging is disabled the stddev keeps its
            # image_size and num_channels dims, so tiling with the callers'
            # multiples multiplies those dims as well - confirm that the
            # non-averaging path is intended to be used.
            if params["minibatch_stddev_averaging"]:
                # grouped shape = (cur_batch_size / group_size, 1, 1, 1)
                # ungrouped shape = (1, 1, 1, 1)
                stddev = tf.reduce_mean(
                    input_tensor=stddev,
                    axis=[1, 2, 3],
                    keepdims=True,
                    name="{}_stddev_average".format(caller)
                )
                print_obj(
                    func_name, "{}_stddev_average".format(caller), stddev
                )

            # Replicate over group and pixels.
            # shape = (
            #     cur_batch_size,
            #     image_size,
            #     image_size,
            #     1
            # )
            stddev_feature_map = tf.tile(
                input=stddev,
                multiples=tile_multiples,
                name="{}_stddev_feature_map".format(caller)
            )
            print_obj(
                func_name,
                "{}_stddev_feature_map".format(caller),
                stddev_feature_map
            )

        return stddev_feature_map

    def grouped_minibatch_stddev(
            self,
            X,
            cur_batch_size,
            static_image_shape,
            params,
            group_size):
        """Builds the minibatch stddev feature map using image groups.

        The batch is reshaped into groups, the variance is taken across the
        members of each group and `minibatch_stddev_common` turns it into the
        tiled stddev feature map.

        Args:
            X: tf.float32 tensor, image of shape
                [cur_batch_size, image_size, image_size, num_channels].
            cur_batch_size: int or scalar integer tensor, the batch size. A
                python int when `params["growth_idx"]` is set, otherwise the
                dynamic batch size (in case of partial batch).
            static_image_shape: list, the static shape of each image.
            params: dict, user passed parameters.
            group_size: int, size of image groups.

        Returns:
            Minibatch standard deviation feature map of shape
                [cur_batch_size, image_size, image_size, 1].
        """
        # Plain literal: there is no placeholder to format into the name.
        func_name = "grouped_minibatch_stddev"

        with tf.variable_scope(
                "{}/grouped_minibatch_stddev".format(self.name)):
            # The group size should be less than or equal to the batch size.
            if params["growth_idx"] is not None:
                # Known growth index: both values are python numbers.
                group_size = min(group_size, cur_batch_size)
            else:
                # Batch size is only known inside the graph.
                # shape = ()
                group_size = tf.minimum(
                    x=group_size, y=cur_batch_size, name="group_size"
                )
            print_obj("\n" + func_name, "group_size", group_size)

            # Split minibatch into M groups of size group_size, rank 5 tensor.
            # shape = (
            #     group_size,
            #     cur_batch_size / group_size,
            #     image_size,
            #     image_size,
            #     num_channels
            # )
            grouped_image = tf.reshape(
                tensor=X,
                shape=[group_size, -1] + static_image_shape,
                name="grouped_image"
            )
            print_obj(func_name, "grouped_image", grouped_image)

            # Find the mean of each group.
            # shape = (
            #     1,
            #     cur_batch_size / group_size,
            #     image_size,
            #     image_size,
            #     num_channels
            # )
            grouped_mean = tf.reduce_mean(
                input_tensor=grouped_image,
                axis=0,
                keepdims=True,
                name="grouped_mean"
            )
            print_obj(func_name, "grouped_mean", grouped_mean)

            # Center each group using the mean.
            # shape = (
            #     group_size,
            #     cur_batch_size / group_size,
            #     image_size,
            #     image_size,
            #     num_channels
            # )
            centered_grouped_image = tf.subtract(
                x=grouped_image, y=grouped_mean, name="centered_grouped_image"
            )
            print_obj(
                func_name, "centered_grouped_image", centered_grouped_image
            )

            # Calculate variance over group.
            # shape = (
            #     cur_batch_size / group_size,
            #     image_size,
            #     image_size,
            #     num_channels
            # )
            grouped_variance = tf.reduce_mean(
                input_tensor=tf.square(x=centered_grouped_image),
                axis=0,
                name="grouped_variance"
            )
            print_obj(func_name, "grouped_variance", grouped_variance)

            # Get stddev image using ops common to both grouped & ungrouped.
            # Tiling by group_size along the batch axis restores the batch
            # dimension that the group reduction removed.
            stddev_feature_map = self.minibatch_stddev_common(
                variance=grouped_variance,
                tile_multiples=[group_size] + static_image_shape[0:2] + [1],
                params=params,
                caller="grouped"
            )
            print_obj(func_name, "stddev_feature_map", stddev_feature_map)

        return stddev_feature_map

    def ungrouped_minibatch_stddev(
            self,
            X,
            cur_batch_size,
            static_image_shape,
            params):
        """Builds the minibatch stddev feature map over the whole batch.

        The variance is taken across all images of the batch and
        `minibatch_stddev_common` turns it into the tiled stddev feature map.

        Args:
            X: tensor, image of shape
                [cur_batch_size, image_size, image_size, num_channels].
            cur_batch_size: int or scalar integer tensor, the batch size. A
                python int when `params["growth_idx"]` is set, otherwise the
                dynamic batch size (in case of partial batch).
            static_image_shape: list, the static shape of each image.
            params: dict, user passed parameters.

        Returns:
            Minibatch standard deviation feature map of shape
                [cur_batch_size, image_size, image_size, 1].
        """
        # Plain literal: there is no placeholder to format into the name.
        func_name = "ungrouped_minibatch_stddev"

        with tf.variable_scope(
                "{}/ungrouped_minibatch_stddev".format(self.name)):
            # Find the mean over the whole batch.
            # shape = (
            #     1,
            #     image_size,
            #     image_size,
            #     num_channels
            # )
            mean = tf.reduce_mean(
                input_tensor=X, axis=0, keepdims=True, name="mean"
            )
            print_obj("\n" + func_name, "mean", mean)

            # Center the batch using the mean.
            # shape = (
            #     cur_batch_size,
            #     image_size,
            #     image_size,
            #     num_channels
            # )
            centered_image = tf.subtract(
                x=X, y=mean, name="centered_image"
            )
            print_obj(func_name, "centered_image", centered_image)

            # Calculate variance over the batch.
            # shape = (
            #     1,
            #     image_size,
            #     image_size,
            #     num_channels
            # )
            variance = tf.reduce_mean(
                input_tensor=tf.square(x=centered_image),
                axis=0,
                keepdims=True,
                name="variance"
            )
            print_obj(func_name, "variance", variance)

            # Get stddev image using ops common to both grouped & ungrouped.
            # The variance has a leading dim of 1, so tiling by the batch
            # size replicates the same statistic for every image.
            stddev_feature_map = self.minibatch_stddev_common(
                variance=variance,
                tile_multiples=[cur_batch_size] + static_image_shape[0:2] + [1],
                params=params,
                caller="ungrouped"
            )
            print_obj(func_name, "stddev_feature_map", stddev_feature_map)

        return stddev_feature_map

    def minibatch_stddev(self, X, params, group_size=4):
        """Appends a minibatch stddev feature map to the image channels.

        The grouped variant is used when the batch size is divisible by or
        smaller than `group_size`, the ungrouped variant otherwise. With a
        known growth index (`params["growth_idx"]` is not None) that choice
        is made in python from `params["batch_size"]`; otherwise it is made
        inside the graph from the dynamic batch size via `tf.cond`.

        Args:
            X: tensor, image of shape
                [cur_batch_size, image_size, image_size, num_channels].
            params: dict, user passed parameters.
            group_size: int, size of image groups.

        Returns:
            Image with minibatch standard deviation feature map added to
                channels of shape
                [cur_batch_size, image_size, image_size, num_channels + 1].
        """
        # Plain literal: there is no placeholder to format into the name.
        func_name = "minibatch_stddev"

        with tf.variable_scope("{}/minibatch_stddev".format(self.name)):
            # Get static shape of image, taken from the generator's
            # projection dims.
            # shape = (3,)
            static_image_shape = params["generator_projection_dims"]
            print_obj(
                "\n" + func_name, "static_image_shape", static_image_shape
            )

            if params["growth_idx"] is not None:
                # Static path: the user passed batch size is used as is.
                if (params["batch_size"] % group_size == 0 or
                   params["batch_size"] < group_size):
                    stddev_feature_map = self.grouped_minibatch_stddev(
                        X=X,
                        cur_batch_size=params["batch_size"],
                        static_image_shape=static_image_shape,
                        params=params,
                        group_size=group_size
                    )
                else:
                    stddev_feature_map = self.ungrouped_minibatch_stddev(
                        X=X,
                        cur_batch_size=params["batch_size"],
                        static_image_shape=static_image_shape,
                        params=params
                    )
            else:
                # Get dynamic shape of image.
                # shape = (4,)
                dynamic_image_shape = tf.shape(
                    input=X, name="dynamic_image_shape"
                )
                print_obj(
                    func_name, "dynamic_image_shape", dynamic_image_shape
                )

                # Extract current batch size (in case this is a partial batch).
                cur_batch_size = dynamic_image_shape[0]

                # batch_size must be divisible by or smaller than group_size.
                divisibility_condition = tf.equal(
                    x=tf.mod(x=cur_batch_size, y=group_size),
                    y=0,
                    name="divisbility_condition"
                )

                less_than_condition = tf.less(
                    x=cur_batch_size, y=group_size, name="less_than_condition"
                )

                or_condition = tf.logical_or(
                    x=divisibility_condition,
                    y=less_than_condition,
                    name="or_condition"
                )

                # Get minibatch stddev feature map image from grouped or
                # ungrouped branch.
                stddev_feature_map = tf.cond(
                    pred=or_condition,
                    true_fn=lambda: self.grouped_minibatch_stddev(
                        X=X,
                        cur_batch_size=cur_batch_size,
                        static_image_shape=static_image_shape,
                        params=params,
                        group_size=group_size
                    ),
                    false_fn=lambda: self.ungrouped_minibatch_stddev(
                        X=X,
                        cur_batch_size=cur_batch_size,
                        static_image_shape=static_image_shape,
                        params=params
                    ),
                    name="stddev_feature_map_cond"
                )
            print_obj(func_name, "stddev_feature_map", stddev_feature_map)

            # Append to image as new feature map.
            # shape = (
            #     cur_batch_size,
            #     image_size,
            #     image_size,
            #     num_channels + 1
            # )
            appended_image = tf.concat(
                values=[X, stddev_feature_map],
                axis=-1,
                name="appended_image"
            )
            print_obj(func_name, "appended_image", appended_image)

        return appended_image

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def create_base_img_to_vec_network(self, X, params):
        """Creates base discriminator network.

        The image goes through the first fromRGB conv layer and a leaky
        relu, optionally gains the minibatch stddev feature map, and is then
        sent through the base conv block to the logits.

        Args:
            X: tensor, input image to discriminator.
            params: dict, user passed parameters.

        Returns:
            Final logits tensor of discriminator.
        """
        func_name = "create_base_discriminator_network"

        print_obj("\n" + func_name, "X", X)
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # The base network only ever uses the first fromRGB conv layer.
            features = self.from_rgb_conv_layers[0](inputs=X)
            print_obj(func_name, "from_rgb_conv", features)

            features = tf.nn.leaky_relu(
                features=features,
                alpha=params["{}_leaky_relu_alpha".format(self.kind)],
                name="{}_from_rgb_conv_2d_leaky_relu".format(self.kind)
            )
            print_obj(func_name, "from_rgb_conv_leaky", features)

            # Conditionally add minibatch stddev as an extra feature map.
            if params["use_minibatch_stddev"]:
                features = self.minibatch_stddev(
                    X=features,
                    params=params,
                    group_size=params["minibatch_stddev_group_size"]
                )

            # Continue through the base conv block to obtain the logits.
            logits = self.create_base_img_to_vec_block_and_logits(
                block_conv=features, params=params
            )

        return logits

    def create_growth_transition_img_to_vec_network(
            self, X, alpha_var, params, trans_idx):
        """Creates growth transition discriminator network.

        The weighted sum of the shrinking and growing paths is sent through
        the permanent growth blocks, optionally gains the minibatch stddev
        feature map, and finally passes the base conv block to the logits.

        Args:
            X: tensor, input image to img_to_vec.
            alpha_var: variable, alpha for weighted sum of fade-in of layers.
            params: dict, user passed parameters.
            trans_idx: int, index of current growth transition.

        Returns:
            Final logits tensor of discriminator.
        """
        func_name = "create_growth_transition_discriminator_network"

        print_obj("\nEntered {}".format(func_name), "trans_idx", trans_idx)
        print_obj(func_name, "X", X)
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Blend the shrinking and growing block paths.
            blended = self.create_growth_transition_img_to_vec_weighted_sum(
                X=X, alpha_var=alpha_var, params=params, trans_idx=trans_idx
            )
            print_obj(func_name, "weighted_sum", blended)

            # Run the blend through the permanent growth blocks.
            features = self.create_img_to_vec_perm_growth_block_network(
                block_conv=blended, params=params, trans_idx=trans_idx
            )
            print_obj(func_name, "block_conv", features)

            # Conditionally add minibatch stddev as an extra feature map.
            if params["use_minibatch_stddev"]:
                features = self.minibatch_stddev(
                    X=features,
                    params=params,
                    group_size=params["minibatch_stddev_group_size"]
                )
                print_obj(func_name, "minibatch_stddev_block_conv", features)

            # Continue through the base conv block to obtain the logits.
            logits = self.create_base_img_to_vec_block_and_logits(
                block_conv=features, params=params
            )
            print_obj(func_name, "logits", logits)

        return logits

    def create_growth_stable_img_to_vec_network(self, X, params, trans_idx):
        """Creates stable growth discriminator network.

        The image enters through the fromRGB conv layer at index
        `trans_idx + 1`, passes the permanent growth blocks, optionally gains
        the minibatch stddev feature map, and finally passes the base conv
        block to the logits.

        Args:
            X: tensor, input image to discriminator.
            params: dict, user passed parameters.
            trans_idx: int, index of current growth transition.

        Returns:
            Final logits tensor of discriminator.
        """
        func_name = "create_growth_stable_discriminator_network"

        # The stable network sits one block past the current transition.
        block_idx = trans_idx + 1

        print_obj("\n" + func_name, "X", X)
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            # Send the image through the matching fromRGB conv layer.
            rgb_features = self.from_rgb_conv_layers[block_idx](inputs=X)
            print_obj(func_name, "from_rgb_conv", rgb_features)

            features = tf.nn.leaky_relu(
                features=rgb_features,
                alpha=params["{}_leaky_relu_alpha".format(self.kind)],
                name="{}_final_from_rgb_conv_2d_leaky_relu".format(self.kind)
            )
            print_obj(func_name, "from_rgb_conv_leaky", features)

            # Run the activations through the permanent growth blocks.
            features = self.create_img_to_vec_perm_growth_block_network(
                block_conv=features, params=params, trans_idx=block_idx
            )
            print_obj(func_name, "block_conv", features)

            # Conditionally add minibatch stddev as an extra feature map.
            if params["use_minibatch_stddev"]:
                features = self.minibatch_stddev(
                    X=features,
                    params=params,
                    group_size=params["minibatch_stddev_group_size"]
                )
                print_obj(
                    func_name, "minibatch_stddev_block_conv", features
                )

            # Continue through the base conv block to obtain the logits.
            logits = self.create_base_img_to_vec_block_and_logits(
                block_conv=features, params=params
            )

        return logits

    ##########################################################################
    ##########################################################################
    ##########################################################################

    def get_gradient_penalty_loss(
            self,
            cur_batch_size,
            fake_images,
            real_images,
            alpha_var,
            params):
        """Gets discriminator gradient penalty loss.

        Images are sampled on the line between real and fake images, sent
        through the discriminator, and the L2 norm of the logits' gradient
        with respect to those images is penalized for deviating from 1.

        Args:
            cur_batch_size: tensor, in case of a partial batch instead of
                using the user passed int.
            fake_images: tensor, images generated by the generator from random
                noise of shape [cur_batch_size, image_size, image_size, 3].
            real_images: tensor, real images from input of shape
                [cur_batch_size, image_size, image_size, 3].
            alpha_var: variable, alpha for weighted sum of fade-in of layers.
            params: dict, user passed parameters.

        Returns:
            Discriminator's gradient penalty loss of shape [].
        """
        func_name = "get_gradient_penalty_loss"

        with tf.name_scope(name="{}/gradient_penalty".format(self.name)):
            # One uniform mixing weight per image, as a rank 4 tensor so it
            # broadcasts over pixels and channels.
            mix_weights = tf.random.uniform(
                shape=[cur_batch_size, 1, 1, 1],
                minval=0., maxval=1.,
                dtype=tf.float32,
                name="random_uniform_num"
            )
            print_obj("\n" + func_name, "random_uniform_num", mix_weights)

            # Element-wise difference between fake and real images.
            fake_minus_real = fake_images - real_images
            print_obj(func_name, "image_difference", fake_minus_real)

            # Random points on the line between real and fake images.
            mixed_images = mix_weights * fake_minus_real + real_images
            print_obj(func_name, "mixed_images", mixed_images)

            # Score the mixed images with the discriminator.
            mixed_logits = self.get_train_eval_img_to_vec_logits(
                X=mixed_images, alpha_var=alpha_var, params=params
            )
            print_obj(func_name, "mixed_logits", mixed_logits)

            # Sum the logits into a scalar to differentiate.
            mixed_loss = tf.reduce_sum(
                input_tensor=mixed_logits, name="mixed_loss"
            )
            print_obj(func_name, "mixed_loss", mixed_loss)

            # `tf.gradients` returns a list with one entry per element of
            # `xs`; only the single entry is needed.
            gradient_list = tf.gradients(
                ys=mixed_loss, xs=[mixed_images], name="gradients"
            )
            mixed_gradients = gradient_list[0]
            print_obj(func_name, "mixed_gradients", mixed_gradients)

            # Per image L2 norm of the gradient, with a small epsilon under
            # the square root.
            squared_gradients = tf.square(
                x=mixed_gradients, name="squared_grads"
            )
            summed_squares = tf.reduce_sum(
                input_tensor=squared_gradients, axis=[1, 2, 3]
            )
            mixed_norms = tf.sqrt(x=summed_squares + 1e-8)
            print_obj(func_name, "mixed_norms", mixed_norms)

            # Squared difference from the target norm of 1.0.
            squared_difference = tf.square(
                x=mixed_norms - 1.0, name="squared_difference"
            )
            print_obj(func_name, "squared_difference", squared_difference)

            # Average over the batch for the gradient penalty scalar.
            gradient_penalty = tf.reduce_mean(
                input_tensor=squared_difference, name="gradient_penalty"
            )
            print_obj(func_name, "gradient_penalty", gradient_penalty)

            # Scale by the user passed coefficient (lambda) to get the loss.
            penalty_coefficient = params[
                "discriminator_gradient_penalty_coefficient"
            ]
            gradient_penalty_loss = tf.multiply(
                x=penalty_coefficient,
                y=gradient_penalty,
                name="gradient_penalty_loss"
            )

        return gradient_penalty_loss

    def get_discriminator_loss(
            self,
            cur_batch_size,
            fake_images,
            real_images,
            fake_logits,
            real_logits,
            alpha_var,
            params):
        """Builds the discriminator's total loss.

        The total is the base loss (mean fake logit minus mean real logit)
        plus the gradient penalty, the epsilon drift penalty on the real
        logits, and the L1/L2 regularization loss. Scalar summaries of every
        term are added unless running on TPU.

        Args:
            cur_batch_size: tensor, in case of a partial batch instead of
                using the user passed int.
            fake_images: tensor, images generated by the generator from random
                noise of shape [cur_batch_size, image_size, image_size, 3].
            real_images: tensor, real images from input of shape
                [cur_batch_size, image_size, image_size, 3].
            fake_logits: tensor, shape of [cur_batch_size, 1] that came from
                discriminator having processed generator's output image.
            real_logits: tensor, shape of [cur_batch_size, 1] that came from
                discriminator having processed real image.
            alpha_var: variable, alpha for weighted sum of fade-in of layers.
            params: dict, user passed parameters.

        Returns:
            Discriminator's total loss tensor of shape [].
        """
        func_name = "get_discriminator_loss"

        # Mean logit over the real batch.
        real_loss = tf.reduce_mean(
            input_tensor=real_logits, name="{}_real_loss".format(self.name)
        )
        print_obj("\n" + func_name, "discriminator_real_loss", real_loss)

        # Mean logit over the generated batch.
        generated_loss = tf.reduce_mean(
            input_tensor=fake_logits,
            name="{}_generated_loss".format(self.name)
        )
        print_obj(func_name, "discriminator_generated_loss", generated_loss)

        # Base loss: generated mean minus real mean.
        base_loss = tf.subtract(
            x=generated_loss, y=real_loss, name="{}_loss".format(self.name)
        )
        print_obj(func_name, "discriminator_loss", base_loss)

        # Gradient penalty on blended real/fake images.
        gradient_penalty = self.get_gradient_penalty_loss(
            cur_batch_size=cur_batch_size,
            fake_images=fake_images,
            real_images=real_images,
            alpha_var=alpha_var,
            params=params
        )
        print_obj(
            func_name, "discriminator_gradient_penalty", gradient_penalty
        )

        # Epsilon drift penalty: scaled mean of the squared real logits.
        epsilon_drift = params["epsilon_drift"]
        mean_squared_real_logits = tf.reduce_mean(
            input_tensor=tf.square(x=real_logits)
        )
        drift_penalty = tf.multiply(
            x=epsilon_drift,
            y=mean_squared_real_logits,
            name="epsilon_drift_penalty"
        )
        print_obj(func_name, "epsilon_drift_penalty", drift_penalty)

        # Wasserstein GP loss: base loss plus both penalties.
        wasserstein_gp_loss = tf.add_n(
            inputs=[base_loss, gradient_penalty, drift_penalty],
            name="{}_wasserstein_gp_loss".format(self.name)
        )
        print_obj(
            func_name, "discriminator_wasserstein_gp_loss", wasserstein_gp_loss
        )

        # L1/L2 regularization restricted to this network's scope.
        reg_loss = get_regularization_loss(
            lambda1=params["discriminator_l1_regularization_scale"],
            lambda2=params["discriminator_l2_regularization_scale"],
            scope=self.name
        )
        print_obj(func_name, "discriminator_reg_loss", reg_loss)

        # Everything combined.
        total_loss = tf.add(
            x=wasserstein_gp_loss,
            y=reg_loss,
            name="{}_total_loss".format(self.name)
        )
        print_obj(func_name, "discriminator_total_loss", total_loss)

        if not params["use_tpu"]:
            # Add summaries for TensorBoard: (name, tensor, family) in the
            # order the summary ops are created.
            loss_summaries = [
                ("discriminator_real_loss", real_loss, "losses"),
                ("discriminator_generated_loss", generated_loss, "losses"),
                ("discriminator_loss", base_loss, "losses"),
                ("discriminator_gradient_penalty", gradient_penalty, "losses"),
                ("epsilon_drift_penalty", drift_penalty, "losses"),
                (
                    "discriminator_wasserstein_gp_loss",
                    wasserstein_gp_loss,
                    "losses"
                ),
                ("discriminator_reg_loss", reg_loss, "losses"),
                ("discriminator_total_loss", total_loss, "total_losses"),
            ]
            for summary_name, summary_tensor, summary_family in loss_summaries:
                tf.summary.scalar(
                    name=summary_name,
                    tensor=summary_tensor,
                    family=summary_family
                )

        return total_loss


## regularization.py

In [13]:
def get_regularization_loss(lambda1=0., lambda2=0., scope=None):
    """Gets regularization losses from variables attached to a regularizer.

    Reads every tensor in the graph's `REGULARIZATION_LOSSES` collection that
    matches `scope` and combines them as
    `lambda1 * sum(|t|) + lambda2 * sum(t ** 2)`.

    Args:
        lambda1: float, L1 regularization scale parameter.
        lambda2: float, L2 regularization scale parameter.
        scope: str, the name of the variable scope.

    Returns:
        Scalar regularization loss tensor. A zero scalar is returned when
            both scales are non-positive, or when nothing in `scope` has a
            regularizer attached.
    """
    def strip_output_suffix(tensor_name):
        """Drops the trailing `:<output index>` from a tensor name.

        Args:
            tensor_name: str, tensor name such as "scope/kernel:0".

        Returns:
            The name without its output index suffix.
        """
        return tensor_name.rsplit(":", 1)[0]

    def regularizer_op_name(tensor_name, suffix):
        """Builds an op name from the first three scopes of a tensor name.

        Args:
            tensor_name: str, name of a collected regularization tensor.
            suffix: str, label appended to the shortened name.

        Returns:
            Op name of the form "<first three scopes>_<suffix>".
        """
        # Strip the output index first so that names with three or fewer
        # components do not carry a ":" into the op name.
        scopes = strip_output_suffix(tensor_name).split("/")[0:3]
        return "{}_{}".format("/".join(scopes), suffix)

    def sum_nd_tensor_list_to_scalar_tensor(t_list):
        """Sums different shape tensors into a scalar tensor.

        Args:
            t_list: list, tensors of varying shapes. Must not be empty.

        Returns:
            Scalar tensor.
        """
        func_name = "sum_nd_tensor_list_to_scalar_tensor"
        # Sum list of tensors into a list of scalars.
        t_reduce_sum_list = [
            tf.reduce_sum(
                input_tensor=t,
                name="{}_reduce_sum".format(strip_output_suffix(t.name))
            )
            for t in t_list
        ]
        print_obj("\n" + func_name, "t_reduce_sum_list", t_reduce_sum_list)

        # Add all scalars together into one scalar.
        t_scalar_sum_tensor = tf.add_n(
            inputs=t_reduce_sum_list,
            name="{}_t_scalar_sum_tensor".format(scope)
        )
        print_obj(func_name, "t_scalar_sum_tensor", t_scalar_sum_tensor)

        return t_scalar_sum_tensor

    func_name = "get_regularization_loss"
    print_obj("\n" + func_name, "scope", scope)
    if lambda1 <= 0. and lambda2 <= 0.:
        # No regularization so return zero.
        return tf.zeros(shape=[], dtype=tf.float32)

    # Get list of trainable variables with a regularizer attached in scope.
    trainable_reg_vars_list = tf.get_collection(
        tf.GraphKeys.REGULARIZATION_LOSSES, scope=scope)
    print_obj(
        func_name, "trainable_reg_vars_list", trainable_reg_vars_list
    )

    if not trainable_reg_vars_list:
        # tf.add_n rejects an empty input list, so with nothing to
        # regularize the loss is simply zero.
        return tf.zeros(shape=[], dtype=tf.float32)

    for var in trainable_reg_vars_list:
        print_obj(
            "{}_{}".format(func_name, scope), "{}".format(var.name), var.graph
        )

    l1_loss = 0.
    if lambda1 > 0.:
        # For L1 regularization, take the absolute value element-wise of each.
        trainable_reg_vars_abs_list = [
            tf.abs(x=var, name=regularizer_op_name(var.name, "abs"))
            for var in trainable_reg_vars_list
        ]

        # Get L1 loss
        l1_loss = tf.multiply(
            x=lambda1,
            y=sum_nd_tensor_list_to_scalar_tensor(
                t_list=trainable_reg_vars_abs_list
            ),
            name="{}_l1_loss".format(scope)
        )

    l2_loss = 0.
    if lambda2 > 0.:
        # For L2 regularization, square all variables element-wise.
        trainable_reg_vars_squared_list = [
            tf.square(x=var, name=regularizer_op_name(var.name, "squared"))
            for var in trainable_reg_vars_list
        ]
        print_obj(
            func_name,
            "trainable_reg_vars_squared_list",
            trainable_reg_vars_squared_list
        )

        # Get L2 loss
        l2_loss = tf.multiply(
            x=lambda2,
            y=sum_nd_tensor_list_to_scalar_tensor(
                t_list=trainable_reg_vars_squared_list
            ),
            name="{}_l2_loss".format(scope)
        )

    # Whichever term was skipped is still the Python float 0. here.
    l1_l2_loss = tf.add(
        x=l1_loss, y=l2_loss, name="{}_l1_l2_loss".format(scope)
    )

    return l1_l2_loss


## train_and_eval.py

In [14]:
def get_logits_and_losses(
        features, generator, discriminator, alpha_var, mode, params):
    """Gets logits and losses for both train and eval modes.

    Draws a latent vector per example, runs the generator on it, scores both
    the generated and the resized real images with the discriminator, and
    builds each network's total loss.

    Args:
        features: dict, feature tensors from input function.
        generator: instance of generator.`Generator`.
        discriminator: instance of discriminator.`Discriminator`.
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        mode: tf.estimator.ModeKeys with values of either TRAIN or EVAL.
        params: dict, user passed parameters.

    Returns:
        Real and fake logits and generator and discriminator losses.
    """
    func_name = "get_logits_and_losses"
    # Extract image from features dictionary.
    images = features["image"]
    print_obj("\n" + func_name, "X", images)

    if params["growth_idx"] is not None:
        # Static batch size read from the image tensor's shape.
        cur_batch_size = images.shape[0]
    else:
        # Get dynamic batch size in case of partial batch.
        cur_batch_size = tf.shape(
            input=images,
            out_type=tf.int32,
            name="{}_cur_batch_size".format(func_name)
        )[0]

    # Create random noise latent vector for each batch example.
    latent_vectors = tf.random.normal(
        shape=[cur_batch_size, params["latent_size"]],
        mean=0.0,
        stddev=1.0,
        dtype=tf.float32
    )
    print_obj(func_name, "Z", latent_vectors)

    # The fade-in alpha only advances while training.
    if mode == tf.estimator.ModeKeys.TRAIN:
        alpha_var = update_alpha(
            global_step=tf.train.get_or_create_global_step(),
            alpha_var=alpha_var,
            params=params
        )
    print_obj(func_name, "alpha_var", alpha_var)

    summaries_enabled = not params["use_tpu"]
    if summaries_enabled:
        # Add summaries for TensorBoard.
        tf.summary.scalar(
            name="alpha_var", tensor=alpha_var, family="alpha_var"
        )

    # Get generated image from generator network from gaussian noise.
    print("\nCall generator with Z = {}.".format(latent_vectors))
    generator_outputs = generator.get_train_eval_vec_to_img_outputs(
        Z=latent_vectors, alpha_var=alpha_var, params=params
    )

    if summaries_enabled:
        # Add summaries for TensorBoard.
        tf.summary.image(
            name="generator_outputs",
            tensor=generator_outputs,
            max_outputs=5,
        )

    # Get fake logits from discriminator using generator's output image.
    print(
        "\nCall discriminator with generator_outputs = {}.".format(
            generator_outputs
        )
    )
    fake_logits = discriminator.get_train_eval_img_to_vec_logits(
        X=generator_outputs, alpha_var=alpha_var, params=params
    )

    # Resize real images based on the current size of the GAN.
    real_images = resize_real_images(image=images, params=params)

    # Get real logits from discriminator using real image.
    print(
        "\nCall discriminator with real_images = {}.".format(real_images)
    )
    real_logits = discriminator.get_train_eval_img_to_vec_logits(
        X=real_images, alpha_var=alpha_var, params=params
    )

    # Total loss of each network.
    generator_total_loss = generator.get_generator_loss(
        fake_logits=fake_logits, params=params
    )
    discriminator_total_loss = discriminator.get_discriminator_loss(
        cur_batch_size=cur_batch_size,
        fake_images=generator_outputs,
        real_images=real_images,
        fake_logits=fake_logits,
        real_logits=real_logits,
        alpha_var=alpha_var,
        params=params
    )

    return (
        real_logits,
        fake_logits,
        generator_total_loss,
        discriminator_total_loss,
    )


## train.py

In [15]:
def get_variables_and_gradients(loss, scope):
    """Collects a network's trainable variables and their loss gradients.
    Args:
        loss: tensor, shape of [].
        scope: str, the network's name to find its variables to train.
    Returns:
        Lists of variables and their gradients. A gradient entry that is not
            a tensor is passed through unchanged.
    """
    func_name = "get_variables_and_gradients"
    log_prefix = "\n{}_{}".format(func_name, scope)

    # Trainable variables that belong to this network.
    variables = tf.trainable_variables(scope=scope)
    print_obj(log_prefix, "variables", variables)

    # Gradients of the loss with respect to each of those variables.
    raw_gradients = tf.gradients(
        ys=loss, xs=variables, name="{}_gradients".format(scope)
    )
    print_obj(log_prefix, "gradients", raw_gradients)

    # Re-label every gradient tensor with its variable's name (minus the
    # two-character output suffix) for identification.
    named_gradients = []
    for gradient, variable in zip(raw_gradients, variables):
        if tf.is_tensor(x=gradient):
            gradient = tf.identity(
                input=gradient,
                name="{}_{}_gradients".format(func_name, variable.name[:-2])
            )
        named_gradients.append(gradient)
    print_obj(log_prefix, "gradients", named_gradients)

    return variables, named_gradients


def create_variable_and_gradient_histogram_summaries(loss_dict, params):
    """Adds TensorBoard histograms of each network's variables and gradients.
    Does nothing when `params["use_tpu"]` is set.
    Args:
        loss_dict: dict, keys are scopes and values are scalar loss tensors
            for each network kind.
        params: dict, user passed parameters.
    """
    if params["use_tpu"]:
        return

    for scope, loss in loss_dict.items():
        # Get variables and their gradients wrt. loss.
        variables, gradients = get_variables_and_gradients(loss, scope)
        variable_family = "{}_variables".format(scope)
        gradient_family = "{}_gradients".format(scope)

        # Add summaries for TensorBoard.
        for gradient, variable in zip(gradients, variables):
            # Variable name without its two-character output suffix.
            summary_name = "{}".format(variable.name[:-2])
            tf.summary.histogram(
                name=summary_name, values=variable, family=variable_family
            )
            # Gradient entries that are not tensors get no histogram.
            if tf.is_tensor(x=gradient):
                tf.summary.histogram(
                    name=summary_name, values=gradient, family=gradient_family
                )


def instantiate_optimizer_slots(optimizer, variables, params, scope):
    """Instantiates optimizer slots for all parameters ahead of time.

    Applies an all-zero gradient to every variable so the optimizer creates
    its slots for each of them. When
    `params["save_optimizer_metrics_to_checkpoint"]` is set, variables found
    in the `METRIC_VARIABLES` collection under the optimizer's name are also
    added to the `GLOBAL_VARIABLES` collection.
    Args:
        optimizer: instance of `Optimizer`.
        variables: list, list of scoped trainable variables.
        params: dict, user passed parameters.
        scope: str, the network's name to find its variables to train.
    Returns:
        Apply gradients op to instantiate all optimizer slots and an empty
            list of extra ops, the same structure that
            `dont_instantiate_optimizer_slots` returns.
    """
    func_name = "instantiate_optimizer_slots"
    # Create zero gradients for every scoped trainable variable.
    zero_gradients = [
        tf.zeros_like(
            tensor=v,
            dtype=tf.float32,
            name="{}_{}_{}_zeros_like".format(func_name, scope, v.name[:-2])
        )
        for v in variables
    ]
    print_obj(
        "{}_{}".format(func_name, scope), "zero_gradients", zero_gradients
    )

    # Pair gradients with variables. Materialized as a list so the debug
    # print shows the pairs rather than an opaque, single-use zip iterator.
    grads_and_vars = list(zip(zero_gradients, variables))
    print_obj(
        "{}_{}".format(func_name, scope), "grads_and_vars", grads_and_vars
    )

    # Apply zero gradients to create all optimizer slots ahead of time. Since
    # this is when global_step is zero, it won't change the parameters or the
    # moment accumulators.
    instantiate_optimizer_op = optimizer.apply_gradients(
        grads_and_vars=grads_and_vars,
        global_step=None,
        name="{}_{}_apply_gradients".format(func_name, scope)
    )
    print_obj(
        "{}_{}".format(func_name, scope),
        "instantiate_optimizer_op",
        instantiate_optimizer_op
    )

    if params["save_optimizer_metrics_to_checkpoint"]:
        # train_network names the optimizer with the lower-cased optimizer
        # key, so the collection lookup has to use the same spelling.
        optimizer_name = "{}_{}_optimizer".format(
            scope, params["{}_optimizer".format(scope)].lower()
        )
        # Add optimizer slot metric variables to global collection so that they
        # will be written to checkpoints.
        # NOTE(review): confirm the optimizer's variables are registered in
        # METRIC_VARIABLES under this name; otherwise nothing is matched.
        for v in tf.get_collection(
                key=tf.GraphKeys.METRIC_VARIABLES, scope=optimizer_name):
            tf.add_to_collection(name=tf.GraphKeys.GLOBAL_VARIABLES, value=v)

    # tf.add_to_collection updates the collection while the graph is being
    # built and returns None, so there are no ops to hand back. Returning an
    # empty list keeps both tf.cond branches structurally identical.
    add_to_collection_ops = []
    print_obj(
        "{}_{}".format(func_name, scope),
        "add_to_collection_ops",
        add_to_collection_ops
    )

    return instantiate_optimizer_op, add_to_collection_ops


def dont_instantiate_optimizer_slots(scope):
    """No-op counterpart of `instantiate_optimizer_slots` for tf.cond.
    Args:
        scope: str, the network's name, used to name the no op.
    Returns:
        A no op standing in for the apply gradients op, and an empty list
            standing in for the add to collection ops.
    """
    no_op_name = "{}_instantiate_optimizer_no_op".format(scope)
    return tf.no_op(name=no_op_name), []


def train_network(
        loss, global_step, alpha_var, params, scope, increment_global_step):
    """Builds the optimizer and train op for one network.
    Args:
        loss: tensor, shape of [].
        global_step: tensor, the current training step or batch in the
            training loop.
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
            Only logged here.
        params: dict, user passed parameters.
        scope: str, the network's name to find its variables to train.
        increment_global_step: truthy when applying the gradients should also
            increment global step.
    Returns:
        Loss tensor and training op. Unless the optimizer is
            "GradientDescent", the returned loss is an identity of the input
            that depends on the optimizer slot instantiation.
    """
    func_name = "train_network"
    print_obj("\n" + func_name, "loss", loss)
    print_obj(func_name, "global_step", global_step)
    print_obj(func_name, "alpha_var", alpha_var)
    print_obj(func_name, "scope", scope)

    log_prefix = "{}_{}".format(func_name, scope)

    # Optimizer classes keyed by the name the user passes in params.
    optimizer_classes = {
        "Adam": tf.train.AdamOptimizer,
        "Adadelta": tf.train.AdadeltaOptimizer,
        "AdagradDA": tf.train.AdagradDAOptimizer,
        "Adagrad": tf.train.AdagradOptimizer,
        "Ftrl": tf.train.FtrlOptimizer,
        "GradientDescent": tf.train.GradientDescentOptimizer,
        "Momentum": tf.train.MomentumOptimizer,
        "ProximalAdagrad": tf.train.ProximalAdagradOptimizer,
        "ProximalGradientDescent": tf.train.ProximalGradientDescentOptimizer,
        "RMSProp": tf.train.RMSPropOptimizer
    }

    # Every optimizer gets a learning rate and a name; Adam additionally gets
    # its betas and epsilon.
    optimizer_kind = params["{}_optimizer".format(scope)]
    optimizer_class = optimizer_classes[optimizer_kind]
    optimizer_kwargs = {
        "learning_rate": params["{}_learning_rate".format(scope)],
        "name": "{}_{}_optimizer".format(scope, optimizer_kind.lower()),
    }
    if optimizer_kind == "Adam":
        optimizer_kwargs["beta1"] = params["{}_adam_beta1".format(scope)]
        optimizer_kwargs["beta2"] = params["{}_adam_beta2".format(scope)]
        optimizer_kwargs["epsilon"] = params["{}_adam_epsilon".format(scope)]
    optimizer = optimizer_class(**optimizer_kwargs)
    print_obj(log_prefix, "optimizer", optimizer)

    # If using TPU, wrap optimizer to use an allreduce to aggregate gradients
    # and broadcast the result to each shard.
    if params["use_tpu"]:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(opt=optimizer)
        print_obj(log_prefix, "optimizer", optimizer)

    # Get variables and their gradients wrt. loss.
    variables, gradients = get_variables_and_gradients(loss, scope)

    # Clip by global norm when a (truthy) clip value is configured.
    clip_norm = params["{}_clip_gradients".format(scope)]
    if clip_norm:
        gradients, _ = tf.clip_by_global_norm(
            t_list=gradients,
            clip_norm=clip_norm,
            name="{}_clip_by_global_norm_gradients".format(scope)
        )
        print_obj("\n" + log_prefix, "gradients", gradients)

        # Add variable names back in for identification.
        relabeled_gradients = []
        for gradient, variable in zip(gradients, variables):
            if tf.is_tensor(x=gradient):
                gradient = tf.identity(
                    input=gradient,
                    name="{}_{}_clip_gradients".format(
                        func_name, variable.name[:-2]
                    )
                )
            relabeled_gradients.append(gradient)
        gradients = relabeled_gradients
        print_obj("\n" + log_prefix, "gradients", gradients)

    # Zip back together gradients and variables.
    grads_and_vars = zip(gradients, variables)
    print_obj(log_prefix, "grads_and_vars", grads_and_vars)

    if optimizer_kind != "GradientDescent":
        def create_slots():
            """Branch taken when global step is zero."""
            return instantiate_optimizer_slots(
                optimizer=optimizer,
                variables=variables,
                params=params,
                scope=scope
            )

        def skip_slots():
            """Branch taken on every other global step."""
            return dont_instantiate_optimizer_slots(scope)

        # Instantiate ALL optimizer slots, not just for ones without None grad.
        at_first_step = tf.equal(
            x=global_step, y=0, name="instantiate_optimizer_op_pred"
        )
        instantiate_optimizer_op, add_to_collection_ops = tf.cond(
            pred=at_first_step,
            true_fn=create_slots,
            false_fn=skip_slots,
            name="instantiate_optimizer_op_cond"
        )

        # Make the returned loss depend on the slot instantiation.
        with tf.control_dependencies(
                control_inputs=[instantiate_optimizer_op]):
            with tf.control_dependencies(
                    control_inputs=add_to_collection_ops):
                loss = tf.identity(
                    input=loss,
                    name="{}_{}_loss_identity".format(func_name, scope)
                )

    # Create train op by applying gradients to variables and possibly
    # incrementing global step.
    step_to_increment = global_step if increment_global_step else None
    train_op = optimizer.apply_gradients(
        grads_and_vars=grads_and_vars,
        global_step=step_to_increment,
        name="{}_apply_gradients".format(scope)
    )
    print_obj(log_prefix, "train_op", train_op)

    return loss, train_op


def train_discriminator(
        discriminator_loss,
        global_step,
        alpha_var,
        params,
        discriminator_scope):
    """Trains the discriminator by delegating to `train_network`.
    Args:
        discriminator_loss: tensor, discriminator's loss with shape [].
        global_step: tensor, the current training step or batch in the
            training loop.
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        params: dict, user passed parameters.
        discriminator_scope: str, the discriminator's name to find its
            variables.
    Returns:
        Loss tensor and training op.
    """
    # The discriminator's train op always increments global step.
    return train_network(
        loss=discriminator_loss,
        global_step=global_step,
        alpha_var=alpha_var,
        params=params,
        scope=discriminator_scope,
        increment_global_step=True
    )


def train_generator(
        generator_loss,
        global_step,
        alpha_var,
        params,
        generator_scope):
    """Trains the generator by delegating to `train_network`.
    Args:
        generator_loss: tensor, generator's loss with shape [].
        global_step: tensor, the current training step or batch in the
            training loop.
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        params: dict, user passed parameters.
        generator_scope: str, the generator's name to find its variables.
    Returns:
        Loss tensor and training op.
    """
    # The generator's train op always increments global step.
    return train_network(
        loss=generator_loss,
        global_step=global_step,
        alpha_var=alpha_var,
        params=params,
        scope=generator_scope,
        increment_global_step=True
    )


def known_update_alpha(global_step, alpha_var, params):
    """Updates alpha when the growth index is supplied by the user.
    Odd growth indices ramp alpha linearly with the step count; even,
    positive growth indices set alpha to one. Otherwise alpha is untouched.
    Args:
        global_step: tensor, the current training step or batch in the
            training loop.
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        params: dict, user passed parameters.
    Returns:
        Ref for updated alpha variable.
    """
    func_name = "known_update_alpha"
    # If never grow, then no need to update alpha since it is not used.
    grows = len(params["conv_num_filters"]) > 1 and params["growth_idx"] > 0

    if grows and params["growth_idx"] % 2 == 1:
        # Steps taken since the previous training run, wrapped to the length
        # of one growth stage.
        steps_since_previous = tf.subtract(
            x=global_step, y=params["previous_train_steps"]
        )
        steps_into_stage = tf.mod(
            x=steps_since_previous, y=params["num_steps_until_growth"]
        )
        # Add 1 since it trains on global step 0, so off by 1.
        steps_completed = tf.add(x=steps_into_stage, y=1)
        # Update alpha var to linearly scale from 0 to 1 based on steps.
        stage_fraction = tf.divide(
            x=tf.cast(x=steps_completed, dtype=tf.float32),
            y=params["num_steps_until_growth"]
        )
        alpha_var = tf.assign(
            ref=alpha_var,
            value=stage_fraction,
            name="update_alpha_assign_linear"
        )
    elif grows:
        # Even growth index: alpha is set to one.
        alpha_var = tf.assign(
            ref=alpha_var,
            value=tf.ones(shape=[], dtype=tf.float32),
            name="update_alpha_assign_ones"
        )
    print_obj(func_name, "alpha_var", alpha_var)

    return alpha_var


def unknown_update_alpha_transition(global_step, alpha_var, params):
    """Assigns alpha the fraction of the current growth stage completed.
    Args:
        global_step: tensor, the current training step or batch in the
            training loop.
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        params: dict, user passed parameters.
    Returns:
        Ref for updated alpha variable.
    """
    stage_length = params["num_steps_until_growth"]

    # Steps taken since the previous training run, wrapped to one stage.
    steps_since_previous = tf.subtract(
        x=global_step, y=params["previous_train_steps"]
    )
    steps_into_stage = tf.mod(x=steps_since_previous, y=stage_length)

    # Add 1 since it trains on global step 0, so off by 1.
    steps_completed = tf.add(x=steps_into_stage, y=1)

    stage_fraction = tf.divide(
        x=tf.cast(x=steps_completed, dtype=tf.float32), y=stage_length
    )

    return tf.assign(
        ref=alpha_var, value=stage_fraction, name="update_alpha_assign_linear"
    )


def unknown_update_alpha_stable(global_step, alpha_var, params):
    """Assigns one to alpha.
    Args:
        global_step: tensor, the current training step or batch in the
            training loop. Unused here.
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        params: dict, user passed parameters. Unused here.
    Returns:
        Ref for updated alpha variable.
    """
    scalar_one = tf.ones(shape=[], dtype=tf.float32)

    return tf.assign(
        ref=alpha_var, value=scalar_one, name="update_alpha_assign_ones"
    )


def unknown_update_alpha(global_step, alpha_var, params):
    """Updates alpha when the growth stage is derived from the global step.
    Args:
        global_step: tensor, the current training step or batch in the
            training loop.
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        params: dict, user passed parameters.
    Returns:
        Ref for updated alpha variable.
    """
    func_name = "unknown_update_alpha"
    num_blocks = len(params["conv_num_filters"])

    # If never grow, then no need to update alpha since it is not used.
    if num_blocks > 1:
        # Number of whole growth stages elapsed since the previous run.
        steps_since_previous = tf.subtract(
            x=global_step, y=params["previous_train_steps"]
        )
        stages_elapsed = tf.floordiv(
            x=steps_since_previous, y=params["num_steps_until_growth"]
        )
        # Offset by the starting growth index; a falsy one counts as zero.
        if params["growth_idx"]:
            starting_index = params["growth_idx"]
        else:
            starting_index = 0
        unclipped_index = tf.cast(
            x=tf.add(x=stages_elapsed, y=starting_index), dtype=tf.int32
        )
        # Find growth index based on global step and growth frequency, capped
        # at the last stage.
        growth_index = tf.minimum(
            x=unclipped_index,
            y=(num_blocks - 1) * 2,
            name="update_alpha_growth_index"
        )

        # True if this is a transition stage, False if this is a stable stage.
        is_transition = tf.equal(
            x=tf.mod(x=growth_index, y=2),
            y=1,
            name="{}_condition".format(func_name)
        )

        def transition_branch():
            """Alpha update used during a transition stage."""
            return unknown_update_alpha_transition(
                global_step, alpha_var, params
            )

        def stable_branch():
            """Alpha update used during a stable stage."""
            return unknown_update_alpha_stable(
                global_step, alpha_var, params
            )

        # Conditionally update alpha.
        updated_alpha = tf.cond(
            pred=is_transition,
            true_fn=transition_branch,
            false_fn=stable_branch,
            name="{}_cond".format(func_name)
        )
    else:
        updated_alpha = alpha_var
    print_obj(func_name, "alpha_var", updated_alpha)

    return updated_alpha


def update_alpha(global_step, alpha_var, params):
    """Dispatches the alpha update based on whether growth index is given.

    Args:
        global_step: tensor, the current training step or batch in the
            training loop.
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        params: dict, user passed parameters.

    Returns:
        Ref for updated alpha variable.
    """
    func_name = "update_alpha"

    # An explicit growth index (including 0) selects the "known" updater;
    # otherwise the stage is derived from the global step.
    updater = (
        known_update_alpha
        if params["growth_idx"] is not None
        else unknown_update_alpha
    )
    updated_alpha = updater(global_step, alpha_var, params)
    print_obj(func_name, "alpha_var", updated_alpha)

    return updated_alpha


def get_loss_and_train_op(
        generator_total_loss,
        discriminator_total_loss,
        alpha_var,
        params):
    """Builds the loss and train op, alternating between the two networks.

    Args:
        generator_total_loss: tensor, scalar total loss of generator.
        discriminator_total_loss: tensor, scalar total loss of discriminator.
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        params: dict, user passed parameters.

    Returns:
        Loss scalar tensor and train_op to be used by the EstimatorSpec.
    """
    func_name = "get_loss_and_train_op"
    global_step = tf.train.get_or_create_global_step()

    # One cycle is the discriminator steps followed by the generator steps.
    steps_per_cycle = tf.cast(
        x=tf.add(
            x=params["discriminator_train_steps"],
            y=params["generator_train_steps"]
        ),
        dtype=tf.int64
    )

    # Position of the current global step within its cycle.
    cycle_step = tf.mod(
        x=global_step,
        y=steps_per_cycle,
        name="{}_cycle_step".format(func_name)
    )

    # The first `discriminator_train_steps` of a cycle go to discriminator.
    train_discriminator_now = tf.less(
        x=cycle_step, y=params["discriminator_train_steps"]
    )

    def _discriminator_branch():
        """Returns loss and train op for a discriminator step."""
        return train_discriminator(
            discriminator_loss=discriminator_total_loss,
            global_step=global_step,
            alpha_var=alpha_var,
            params=params,
            discriminator_scope="discriminator"
        )

    def _generator_branch():
        """Returns loss and train op for a generator step."""
        return train_generator(
            generator_loss=generator_total_loss,
            global_step=global_step,
            alpha_var=alpha_var,
            params=params,
            generator_scope="generator"
        )

    # Needed for batch normalization, but has no effect otherwise.
    update_ops = tf.get_collection(key=tf.GraphKeys.UPDATE_OPS)

    # Build the conditional training step under the update ops dependency.
    with tf.control_dependencies(control_inputs=update_ops):
        loss, train_op = tf.cond(
            pred=train_discriminator_now,
            true_fn=_discriminator_branch,
            false_fn=_generator_branch,
            name="{}_cond".format(func_name)
        )

    return loss, train_op


## eval_metrics.py

In [16]:
def get_eval_metric_ops(fake_logits, real_logits):
    """Gets eval metric ops.

    Real images are labeled 1 and generated images are labeled 0. Accuracy,
    precision and recall are computed on hard class predictions obtained by
    thresholding the sigmoid probabilities at 0.5, whereas the AUC metrics
    use the probabilities directly since they sweep their own thresholds.

    Args:
        fake_logits: tensor, shape of [cur_batch_size, 1] that came from
            discriminator having processed generator's output image.
        real_logits: tensor, shape of [cur_batch_size, 1] that came from
            discriminator having processed real image.

    Returns:
        Dictionary of eval metric ops.
    """
    func_name = "get_eval_metric_ops"
    # Concatenate discriminator logits and labels.
    discriminator_logits = tf.concat(
        values=[real_logits, fake_logits],
        axis=0,
        name="discriminator_concat_logits"
    )
    print_obj("\n" + func_name, "discriminator_logits", discriminator_logits)

    discriminator_labels = tf.concat(
        values=[
            tf.ones_like(tensor=real_logits),
            tf.zeros_like(tensor=fake_logits)
        ],
        axis=0,
        name="discriminator_concat_labels"
    )
    print_obj(func_name, "discriminator_labels", discriminator_labels)

    # Calculate discriminator probabilities.
    discriminator_probabilities = tf.nn.sigmoid(
        x=discriminator_logits, name="discriminator_probabilities"
    )
    print_obj(
        func_name, "discriminator_probabilities", discriminator_probabilities
    )

    # Threshold probabilities into 0/1 class predictions. `tf.metrics.accuracy`
    # tests predictions for equality with labels and `tf.metrics.precision` /
    # `tf.metrics.recall` cast predictions to bool, so passing raw
    # probabilities would treat every non-zero probability as "real".
    discriminator_classes = tf.cast(
        x=tf.greater_equal(x=discriminator_probabilities, y=0.5),
        dtype=discriminator_labels.dtype,
        name="discriminator_classes"
    )
    print_obj(func_name, "discriminator_classes", discriminator_classes)

    # Create eval metric ops dictionary.
    eval_metric_ops = {
        "accuracy": tf.metrics.accuracy(
            labels=discriminator_labels,
            predictions=discriminator_classes,
            name="discriminator_accuracy"
        ),
        "precision": tf.metrics.precision(
            labels=discriminator_labels,
            predictions=discriminator_classes,
            name="discriminator_precision"
        ),
        "recall": tf.metrics.recall(
            labels=discriminator_labels,
            predictions=discriminator_classes,
            name="discriminator_recall"
        ),
        "auc_roc": tf.metrics.auc(
            labels=discriminator_labels,
            predictions=discriminator_probabilities,
            num_thresholds=200,
            curve="ROC",
            name="discriminator_auc_roc"
        ),
        "auc_pr": tf.metrics.auc(
            labels=discriminator_labels,
            predictions=discriminator_probabilities,
            num_thresholds=200,
            curve="PR",
            name="discriminator_auc_pr"
        )
    }
    print_obj(func_name, "eval_metric_ops", eval_metric_ops)

    return eval_metric_ops


## predict.py

In [17]:
def get_predictions(Z, generator, params, block_idx):
    """Generates images for one block and keys them by resolution.

    Args:
        Z: tensor, latent vectors of shape [cur_batch_size, latent_size].
        generator: instance of generator.`Generator`.
        params: dict, user passed parameters.
        block_idx: int, current conv layer block's index.

    Returns:
        Single-entry predictions dictionary mapping
            "generated_images_{size}x{size}" to the generator's output, where
            size = 4 * 2 ** block_idx.
    """
    func_name = "get_predictions"
    print_obj("\n" + func_name, "Z", Z)

    # Run the latent vectors through the generator for this block.
    generated_images = generator.get_predict_vec_to_img_outputs(
        Z=Z, params=params, block_idx=block_idx
    )
    print_obj("\n" + func_name, "generated_images", generated_images)

    # Side length starts at 4 and doubles with every block index.
    side_length = 4 * 2 ** block_idx
    resolution = "{}x{}".format(side_length, side_length)
    predictions_key = "generated_images_{}".format(resolution)

    return {predictions_key: generated_images}


def get_predictions_and_export_outputs(features, generator, params):
    """Builds the predictions dict and the matching serving export outputs.

    Args:
        features: dict, feature tensors from serving input function.
        generator: instance of `Generator`.
        params: dict, user passed parameters.

    Returns:
        Predictions dictionary and export outputs dictionary.
    """
    func_name = "get_predictions_and_export_outputs"

    # Extract given latent vectors from features dictionary.
    Z = features["Z"]
    print_obj("\n" + func_name, "Z", Z)

    # Predict for every block, or only for the last one.
    num_blocks = len(params["conv_num_filters"])
    if params["predict_all_resolutions"]:
        first_block = 0
    else:
        first_block = num_blocks - 1
    print_obj(func_name, "loop_start", first_block)
    print_obj(func_name, "loop_end", num_blocks)

    # Merge the per-block predictions into a single dictionary.
    predictions_dict = {}
    for block_idx in range(first_block, num_blocks):
        predictions_dict.update(
            get_predictions(
                Z=Z,
                generator=generator,
                params=params,
                block_idx=block_idx
            )
        )
    print_obj(func_name, "predictions_dict", predictions_dict)

    # Wrap predictions for the SavedModel serving signature.
    predict_output = tf.estimator.export.PredictOutput(
        outputs=predictions_dict
    )
    export_outputs = {"predict_export_outputs": predict_output}
    print_obj(func_name, "export_outputs", export_outputs)

    return predictions_dict, export_outputs


## pgan.py

In [18]:
def pgan_model(features, labels, mode, params):
    """Progressively Growing GAN custom Estimator model function.

    Builds the generator, discriminator and fade-in alpha variable, then
    fills in whichever spec fields the requested mode needs:
    PREDICT -> predictions and export outputs; TRAIN -> loss and train op;
    EVAL -> discriminator loss and eval metrics.

    Args:
        features: dict, keys are feature names and values are feature tensors.
        labels: tensor, label data.
        mode: tf.estimator.ModeKeys with values of either TRAIN, EVAL, or
            PREDICT.
        params: dict, user passed parameters.

    Returns:
        Instance of `tf.estimator.tpu.TPUEstimatorSpec` when
            `params["eval_on_tpu"]` is truthy, otherwise instance of
            `tf.estimator.EstimatorSpec`.
    """
    func_name = "pgan_model"
    print_obj("\n" + func_name, "features", features)
    print_obj(func_name, "labels", labels)
    print_obj(func_name, "mode", mode)
    print_obj(func_name, "params", params)

    # Loss function, training/eval ops, etc.
    predictions_dict = None
    loss = None
    train_op = None
    eval_metric_ops = None
    export_outputs = None

    # Instantiate generator.
    pgan_generator = Generator(
        kernel_regularizer=tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params["generator_l1_regularization_scale"],
            scale_l2=params["generator_l2_regularization_scale"]
        ),
        bias_regularizer=None,
        params=params,
        name="generator"
    )

    # Instantiate discriminator.
    pgan_discriminator = Discriminator(
        kernel_regularizer=tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params["discriminator_l1_regularization_scale"],
            scale_l2=params["discriminator_l2_regularization_scale"]
        ),
        bias_regularizer=None,
        params=params,
        name="discriminator"
    )

    # Create alpha variable to use for weighted sum for smooth fade-in.
    alpha_var = tf.get_variable(
        name="alpha_var",
        dtype=tf.float32,
        # When the initializer is a function, tensorflow can place it
        # "outside of the control flow context" to make sure it always runs.
        initializer=lambda: tf.zeros(shape=[], dtype=tf.float32),
        trainable=False
    )
    print_obj(func_name, "alpha_var", alpha_var)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Get predictions and export outputs.
        (predictions_dict,
         export_outputs) = get_predictions_and_export_outputs(
            features=features, generator=pgan_generator, params=params
        )
    else:
        # Get logits and losses from networks for train and eval modes.
        (real_logits,
         fake_logits,
         generator_total_loss,
         discriminator_total_loss) = get_logits_and_losses(
            features=features,
            generator=pgan_generator,
            discriminator=pgan_discriminator,
            alpha_var=alpha_var,
            mode=mode,
            params=params
        )

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Create variable and gradient histogram summaries.
            create_variable_and_gradient_histogram_summaries(
                loss_dict={
                    "generator": generator_total_loss,
                    "discriminator": discriminator_total_loss
                },
                params=params
            )

            # Get loss and train op for EstimatorSpec.
            loss, train_op = get_loss_and_train_op(
                generator_total_loss=generator_total_loss,
                discriminator_total_loss=discriminator_total_loss,
                alpha_var=alpha_var,
                params=params
            )
        else:
            loss = discriminator_total_loss

            # The metric format has to match the spec type returned below,
            # which is selected by `eval_on_tpu`: TPUEstimatorSpec expects a
            # (metric_fn, tensors) tuple, EstimatorSpec expects the dict of
            # already built metric ops.
            if params["eval_on_tpu"]:
                eval_metric_ops = (
                    get_eval_metric_ops,
                    {"real_logits": real_logits, "fake_logits": fake_logits}
                )
            else:
                # Pass by keyword: the signature is (fake_logits, real_logits)
                # so a positional (real, fake) call would invert the labels.
                eval_metric_ops = get_eval_metric_ops(
                    fake_logits=fake_logits, real_logits=real_logits
                )

    if params["eval_on_tpu"]:
        # Return TPUEstimatorSpec
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions_dict,
            loss=loss,
            train_op=train_op,
            eval_metrics=eval_metric_ops,
            export_outputs=export_outputs
        )
    else:
        # Return EstimatorSpec
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions_dict,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            export_outputs=export_outputs
        )


## serving.py

In [19]:
def serving_input_fn(params):
    """Builds the serving-time input receiver for latent vectors.

    Args:
        params: dict, user passed parameters.

    Returns:
        ServingInputReceiver object containing features and receiver tensors.
    """
    func_name = "serving_input_fn"

    # Placeholder accepting latent vectors sent to the model at serving
    # time, shape = (batch_size, latent_size).
    latent_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, params["latent_size"]],
        name="{}_placeholder_Z".format(func_name)
    )
    feature_placeholders = {"Z": latent_placeholder}
    print_obj("\n" + func_name, "feature_placeholders", feature_placeholders)

    # Feed the model identity copies so that the receiver tensors remain the
    # raw placeholders.
    features = {}
    for key, placeholder in feature_placeholders.items():
        features[key] = tf.identity(
            input=placeholder,
            name="{}_identity_placeholder_{}".format(func_name, key)
        )
    print_obj(func_name, "features", features)

    return tf.estimator.export.ServingInputReceiver(
        features=features, receiver_tensors=feature_placeholders
    )


## model.py

In [20]:
def instantiate_estimator(args, config):
    """Builds a `TPUEstimator` around `pgan_model` from user arguments.

    Args:
        args: dict, user passed parameters.
        config: instance of `tf.contrib.tpu.RunConfig`.

    Returns:
        `TPUEstimator` object.
    """
    estimator_class = tf.estimator.tpu.TPUEstimator

    # Everything except the model function and run config comes from args,
    # which is also handed to the model function as its params.
    estimator_kwargs = {
        "model_fn": pgan_model,
        "model_dir": args["output_dir"],
        "config": config,
        "params": args,
        "use_tpu": args["use_tpu"],
        "train_batch_size": args["train_batch_size"],
        "eval_batch_size": args["eval_batch_size"],
        "eval_on_tpu": args["eval_on_tpu"],
        "export_to_tpu": args["export_to_tpu"],
        "export_to_cpu": args["export_to_cpu"],
    }

    return estimator_class(**estimator_kwargs)


def train_estimator(args, estimator, steps):
    """Trains the estimator on the training files for a number of steps.

    Args:
        args: dict, user passed parameters.
        estimator: instance of `TPUEstimator`.
        steps: int, number of steps to train for.
    """
    growth_idx = args["growth_idx"]
    print("CALLING TRAIN WITH GROWTH_IDX {}".format(growth_idx))

    # Input function reading the training file pattern in TRAIN mode.
    train_input_fn = read_dataset(
        filename=args["train_file_pattern"],
        mode=tf.estimator.ModeKeys.TRAIN,
        batch_size=args["train_batch_size"],
        params=args
    )

    estimator.train(input_fn=train_input_fn, steps=steps)


def export_saved_model(args, estimator):
    """Exports a SavedModel under `<output_dir>/export/exporter`.

    Args:
        args: dict, user passed parameters.
        estimator: instance of `TPUEstimator`.
    """
    tf.logging.info("Starting to export model.")

    export_dir_base = os.path.join(args["output_dir"], "export/exporter")

    def _serving_input_receiver_fn():
        """Binds the user parameters to the serving input function."""
        return serving_input_fn(args)

    estimator.export_savedmodel(
        export_dir_base=export_dir_base,
        serving_input_receiver_fn=_serving_input_receiver_fn
    )


def train_loop_iteration(args, config, steps):
    """Runs one instantiate -> train -> export cycle.

    Args:
        args: dict, user passed parameters.
        config: instance of `tf.contrib.tpu.RunConfig`.
        steps: int, number of steps to train for.

    Returns:
        The `TPUEstimator` instance created for this iteration.
    """
    # A fresh estimator is built for every iteration from the current args.
    new_estimator = instantiate_estimator(args=args, config=config)

    # Train, then export a SavedModel from the trained estimator.
    train_estimator(args=args, estimator=new_estimator, steps=steps)
    export_saved_model(args=args, estimator=new_estimator)

    return new_estimator


def progressive_train_loop(args, config):
    """Progressively trains model in a loop.

    Splits `args["train_steps"]` into growth phases of
    `args["num_steps_until_growth"]` steps each, starting at the stage given
    by `args["growth_idx"]`, followed by one steady phase for any steps left
    over. Each phase is one call to `train_loop_iteration`.

    Note that `args["growth_idx"]` is updated in place: a missing value is
    normalized to 0 and it is incremented after every growth phase.

    Args:
        args: dict, user passed parameters.
        config: instance of `tf.contrib.tpu.RunConfig`.

    Returns:
        The estimator returned by the last training iteration, or None if
            there was nothing to train (e.g. `train_steps` <= 0).
    """
    func_name = "progressive_train_loop"

    # Normalize a missing growth index to the first stage.
    args["growth_idx"] = 0 if not args["growth_idx"] else args["growth_idx"]
    start_growth_idx = args["growth_idx"]
    steps_per_stage = args["num_steps_until_growth"]

    # Determine number of stages.
    new_stages = (args["train_steps"] - 1) // steps_per_stage
    # NOTE(review): 17 equals 9 * 2 - 1, presumably the stage count of a
    # nine block network; confirm before changing.
    min_potential_stages = min(start_growth_idx + new_stages + 1, 17)
    print_obj("\n" + func_name, "min_potential_stages", min_potential_stages)

    # Cannot have more stages than the configured blocks allow.
    min_possible_stages = min(
        min_potential_stages, len(args["conv_num_filters"]) * 2 - 1
    )
    print_obj(func_name, "min_possible_stages", min_possible_stages)

    # Only the stages between the starting index and the final stage index
    # are growth phases of this run. Without subtracting the starting index
    # a resumed run (growth_idx > 0) would train extra phases and push
    # growth_idx past the final stage. Clamp at zero so that a non-positive
    # train_steps does not turn into a positive number of remaining steps.
    num_stages = max(min_possible_stages - 1 - start_growth_idx, 0)
    print_obj(func_name, "num_stages", num_stages)

    # Stays None when no training iteration runs at all.
    estimator = None

    # Growth phases.
    for _ in range(num_stages):
        # Perform one training loop iteration.
        estimator = train_loop_iteration(
            args, config, steps=steps_per_stage
        )

        args["growth_idx"] += 1

    # Steady phase for any remaining steps.
    growth_steps = num_stages * steps_per_stage
    print_obj(func_name, "growth_steps", growth_steps)
    remaining_steps = args["train_steps"] - growth_steps
    print_obj(func_name, "remaining_steps", remaining_steps)
    if remaining_steps > 0:
        # Perform one training loop iteration.
        estimator = train_loop_iteration(args, config, steps=remaining_steps)

    return estimator


def train_and_evaluate(args):
    """Trains (and optionally evaluates) the custom Estimator model.

    Three paths are supported: progressive training on TPU, progressive
    training without TPU, and `tf.estimator.train_and_evaluate` without TPU
    when `args["use_estimator_train_and_evaluate"]` is set.

    Args:
        args: dict, user passed parameters.

    Returns:
        The estimator that was trained.
    """
    print_obj("train_and_evaluate", "args", args)
    # Ensure filewriter cache is clear for TensorBoard events file.
    tf.summary.FileWriterCache.clear()

    # Set logging to be level of INFO.
    tf.logging.set_verbosity(tf.logging.INFO)

    def _shared_run_config_kwargs():
        """Returns RunConfig keyword arguments common to every path."""
        return {
            "model_dir": args["output_dir"],
            "save_summary_steps": args["save_summary_steps"],
            "save_checkpoints_steps": args["save_checkpoints_steps"],
            "keep_checkpoint_max": args["keep_checkpoint_max"],
        }

    if args["use_tpu"]:
        # Connect to the TPU cluster; initialization has to happen before
        # anything else uses the TPU system.
        tpu_cluster_resolver = (
            tf.distribute.cluster_resolver.TPUClusterResolver()
        )
        tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)

        # Create TPU RunConfig.
        tpu_run_config = tf.contrib.tpu.RunConfig(
            tpu_config=tf.contrib.tpu.TPUConfig(
                iterations_per_loop=args["num_steps_until_growth"],
                per_host_input_for_training=True
            ),
            cluster=tpu_cluster_resolver,
            **_shared_run_config_kwargs()
        )

        # Run training loop.
        return progressive_train_loop(args, tpu_run_config)

    # Create RunConfig without any TPU specific settings.
    config = tf.contrib.tpu.RunConfig(**_shared_run_config_kwargs())

    if not args["use_estimator_train_and_evaluate"]:
        return progressive_train_loop(args, config)

    # Create our custom estimator using our model function.
    estimator = tf.estimator.tpu.TPUEstimator(
        model_fn=pgan_model,
        model_dir=args["output_dir"],
        config=config,
        params=args,
        use_tpu=False,
        train_batch_size=args["train_batch_size"],
        eval_batch_size=args["eval_batch_size"],
        eval_on_tpu=False,
        export_to_tpu=False,
        export_to_cpu=True
    )

    # Create train spec to read in our training data.
    train_input_fn = read_dataset(
        filename=args["train_file_pattern"],
        mode=tf.estimator.ModeKeys.TRAIN,
        batch_size=args["train_batch_size"],
        params=args
    )
    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn, max_steps=args["train_steps"]
    )

    # Create exporter to save out the complete model to disk.
    exporter = tf.estimator.LatestExporter(
        name="exporter",
        serving_input_receiver_fn=lambda: serving_input_fn(args),
        exports_to_keep=args["exports_to_keep"]
    )

    # Create eval spec to read validation data and export our model.
    eval_input_fn = read_dataset(
        filename=args["eval_file_pattern"],
        mode=tf.estimator.ModeKeys.EVAL,
        batch_size=args["eval_batch_size"],
        params=args
    )
    eval_spec = tf.estimator.EvalSpec(
        input_fn=eval_input_fn,
        steps=args["eval_steps"],
        start_delay_secs=args["start_delay_secs"],
        throttle_secs=args["throttle_secs"],
        exporters=exporter
    )

    # Create train and evaluate loop to train & evaluate our estimator.
    tf.estimator.train_and_evaluate(
        estimator=estimator,
        train_spec=train_spec,
        eval_spec=eval_spec
    )

    return estimator


## Run the training!

In [21]:
# Expose the output directory as an environment variable so that shell cells
# can read it as ${OUTPUT_DIR}.
os.environ["OUTPUT_DIR"] = arguments["output_dir"]

In [None]:
%%bash
# Remove any previous contents of the output directory. The expansion is
# quoted and aborts when OUTPUT_DIR is unset or empty, so the recursive
# delete never runs without an explicit target.
gsutil -m rm -rf "${OUTPUT_DIR:?OUTPUT_DIR must be set}"

In [None]:
# Run training with the notebook's `arguments` dict and keep the returned
# estimator.
estimator = train_and_evaluate(arguments)

train_and_evaluate: args = {'train_file_pattern': 'gs://machine-learning-1234-bucket/gan/data/cifar10_car/train*.tfrecord', 'eval_file_pattern': 'gs://machine-learning-1234-bucket/gan/data/cifar10_car/test*.tfrecord', 'output_dir': 'gs://machine-learning-1234-bucket/gan/pgan/trained_model_local_cifar10_car', 'dataset': 'cifar10', 'train_batch_size': 32, 'train_steps': 59500, 'use_tpu': False, 'use_estimator_train_and_evaluate': False, 'growth_idx': 0, 'previous_train_steps': 0, 'save_optimizer_metrics_to_checkpoint': True, 'save_summary_steps': 100, 'save_checkpoints_steps': 10000, 'keep_checkpoint_max': 100, 'input_fn_autotune': True, 'eval_batch_size': 1, 'eval_steps': 1, 'start_delay_secs': 6000000, 'throttle_secs': 6000000, 'eval_on_tpu': True, 'exports_to_keep': 20, 'export_to_tpu': False, 'export_to_cpu': True, 'predict_all_resolutions': True, 'height': 32, 'width': 32, 'depth': 3, 'num_steps_until_growth': 8500, 'use_equalized_learning_rate': True, 'conv_num_filters': [[512, 512