## print_object.py

In [1]:
%%writefile pgan_module/trainer/print_object.py
def print_obj(function_name, object_name, object_value):
    """Writes a one-line debug message describing a named object.

    The message has the form `<function_name>: <object_name> = <object_value>`
    and is sent to standard output.

    Args:
        function_name: str, name of the function reporting the object.
        object_name: str, name of the object being reported.
        object_value: object, value displayed next to the object's name.
    """
    message = "{caller}: {name} = {value}".format(
        caller=function_name, name=object_name, value=object_value
    )
    print(message)


Overwriting pgan_module/trainer/print_object.py


## input.py

In [2]:
%%writefile pgan_module/trainer/input.py
import tensorflow as tf

from .print_object import print_obj


def decode_example(protos, params):
    """Parses one serialized tf.Example into an image dict and a label.

    Args:
        protos: serialized tf.Example protobuf read from a TFRecord file.
        params: dict, user passed parameters. The "height", "width" and
            "depth" entries give the dimensions the image is reshaped to.

    Returns:
        Tuple of a features dictionary holding the float32 "image" tensor
        rescaled to [-1.0, 1.0], and the int32 label tensor.
    """
    func_name = "decode_example"

    # Schema of the fields stored in every tf.Example.
    features = {
        "image_raw": tf.FixedLenFeature(shape=[], dtype=tf.string),
        "label": tf.FixedLenFeature(shape=[], dtype=tf.int64)
    }

    # Extract the raw fields according to the schema.
    parsed_features = tf.parse_single_example(
        serialized=protos, features=features
    )
    print_obj("\n" + func_name, "features", features)

    # The image is stored as one byte string; decode it into a flat uint8
    # tensor of length height * width * depth.
    flat_image = tf.decode_raw(
        input_bytes=parsed_features["image_raw"], out_type=tf.uint8
    )
    print_obj(func_name, "image", flat_image)

    # Restore the [height, width, depth] layout.
    image_dims = [params["height"], params["width"], params["depth"]]
    shaped_image = tf.reshape(tensor=flat_image, shape=image_dims)
    print_obj(func_name, "image", shaped_image)

    # Rescale pixel values from [0, 255] to floats in [-1.0, 1.0].
    float_image = tf.cast(x=shaped_image, dtype=tf.float32)
    image = float_image * (2. / 255) - 1.0
    print_obj(func_name, "image", image)

    # The label is parsed as int64; narrow it to int32.
    label = tf.cast(x=parsed_features["label"], dtype=tf.int32)
    print_obj(func_name, "label", label)

    return {"image": image}, label


def read_dataset(filename, mode, batch_size, params):
    """Builds an input function that reads TFRecord files with tf.data.

    The returned function globs the file pattern, reads the matching files
    as TFRecords, optionally shuffles and repeats them (training only),
    decodes each record with `decode_example`, batches and prefetches.

    Args:
        filename: str, file pattern of the TFRecord files to read.
        mode: The estimator ModeKeys. Can be TRAIN or EVAL.
        batch_size: int, number of examples per batch.
        params: dict, dictionary of user passed parameters.

    Returns:
        An input function.
    """
    def _input_fn():
        """Wrapper input function used by Estimator API to get data tensors.

        Returns:
            Next batch from a one-shot iterator: a dictionary of feature
                tensors and a label tensor.
        """
        def _parse_record(serialized_example):
            """Decodes a single serialized example using the user params."""
            return decode_example(protos=serialized_example, params=params)

        # Resolve the pattern into concrete file names.
        matching_files = tf.gfile.Glob(filename=filename)

        # Read records from the matched files in parallel.
        dataset = tf.data.TFRecordDataset(
            filenames=matching_files, num_parallel_reads=40
        )

        # Training data is shuffled and repeated indefinitely (fused op).
        if mode == tf.estimator.ModeKeys.TRAIN:
            shuffle_repeat_op = tf.contrib.data.shuffle_and_repeat(
                buffer_size=50 * batch_size,
                count=None  # indefinitely
            )
            dataset = dataset.apply(shuffle_repeat_op)

        # Decode each record into tensors and batch them (fused op).
        map_batch_op = tf.contrib.data.map_and_batch(
            map_func=_parse_record,
            batch_size=batch_size,
            num_parallel_calls=4
        )
        dataset = dataset.apply(map_batch_op)

        # Prefetch data to improve latency.
        dataset = dataset.prefetch(buffer_size=2)

        # Pull the next batch from a one-shot iterator over the dataset.
        iterator = dataset.make_one_shot_iterator()
        return iterator.get_next()

    return _input_fn


Overwriting pgan_module/trainer/input.py


## generator.py

In [3]:
%%writefile pgan_module/trainer/generator.py
import tensorflow as tf

from .print_object import print_obj


def generator_projection(Z, regularizer, params):
    """Creates generator projection from noise latent vector.

    The latent vectors go through a dense layer and the result is reshaped
    into an image-like tensor that the generator conv blocks consume.

    Args:
        Z: tensor, latent vectors of shape [cur_batch_size, latent_size].
        regularizer: regularizer object applied to the dense layer's kernel
            variables.
        params: dict, user passed parameters. The first three entries of
            "generator_projection_dims" are the projection's height, width
            and depth.

    Returns:
        Latent vector projection tensor of shape
            [cur_batch_size, projection_height, projection_width,
            projection_depth].
    """
    # Project latent vectors.
    projection_height = params["generator_projection_dims"][0]
    projection_width = params["generator_projection_dims"][1]
    projection_depth = params["generator_projection_dims"][2]

    with tf.variable_scope(name_or_scope="generator", reuse=tf.AUTO_REUSE):
        # shape = (
        #     cur_batch_size,
        #     projection_height * projection_width * projection_depth
        # )
        projection = tf.layers.dense(
            inputs=Z,
            units=projection_height * projection_width * projection_depth,
            activation=tf.nn.leaky_relu,
            kernel_initializer="he_normal",
            kernel_regularizer=regularizer,
            name="projection_layer"
        )
        print_obj("generator_projection", "projection", projection)

    # Reshape projection into "image".
    # shape = (
    #     cur_batch_size,
    #     projection_height,
    #     projection_width,
    #     projection_depth
    # )
    projection = tf.reshape(
        tensor=projection,
        shape=[-1, projection_height, projection_width, projection_depth],
        name="projection_reshaped"
    )
    # Report under this function's own name so the debug output points at
    # the right place.
    print_obj("generator_projection", "projection", projection)

    return projection


def create_generator_base_conv_layer_block(regularizer, params):
    """Instantiates the conv layers of the generator's base block.

    Each layer is described by one entry of
    params["generator_base_conv_blocks"][0]: elements 0:2 are the kernel
    size, element 3 the number of filters and elements 4:6 the strides.
    Element 2 only appears in the layer name (presumably the input channel
    count; confirm against the code that builds params).

    Args:
        regularizer: `l1_l2_regularizer` object. Currently unused because
            kernel regularization is disabled for these layers.
        params: dict, user passed parameters.

    Returns:
        List of base conv layers.
    """
    with tf.variable_scope(name_or_scope="generator", reuse=tf.AUTO_REUSE):
        # Layer descriptions for the base block.
        layer_specs = params["generator_base_conv_blocks"][0]

        # Build one Conv2D layer per description, in order.
        base_conv_layers = []
        for layer_idx, spec in enumerate(layer_specs):
            layer_name = "generator_base_layers_conv2d_{}_{}x{}_{}_{}".format(
                layer_idx, spec[0], spec[1], spec[2], spec[3]
            )
            base_conv_layers.append(
                tf.layers.Conv2D(
                    filters=spec[3],
                    kernel_size=spec[0:2],
                    strides=spec[4:6],
                    padding="same",
                    activation=tf.nn.leaky_relu,
                    kernel_initializer="he_normal",
                    name=layer_name
                )
            )

        print_obj(
            "\ncreate_generator_base_conv_layer_block",
            "base_conv_layers",
            base_conv_layers
        )

    return base_conv_layers


def create_generator_growth_layer_block(block_idx, regularizer, params):
    """Instantiates the conv layers of one generator growth block.

    Each layer is described by one entry of
    params["generator_growth_conv_blocks"][block_idx]: elements 0:2 are the
    kernel size, element 3 the number of filters and elements 4:6 the
    strides. Element 2 only appears in the layer name.

    Args:
        block_idx: int, the current growth block's index.
        regularizer: `l1_l2_regularizer` object. Currently unused because
            kernel regularization is disabled for these layers.
        params: dict, user passed parameters.

    Returns:
        List of growth block layers.
    """
    with tf.variable_scope(name_or_scope="generator", reuse=tf.AUTO_REUSE):
        # Layer descriptions for the requested growth block.
        layer_specs = params["generator_growth_conv_blocks"][block_idx]
        name_template = "generator_growth_layers_conv2d_{}_{}_{}x{}_{}_{}"

        # Build one Conv2D layer per description, in order.
        conv_layers = []
        for layer_idx, spec in enumerate(layer_specs):
            conv_layers.append(
                tf.layers.Conv2D(
                    filters=spec[3],
                    kernel_size=spec[0:2],
                    strides=spec[4:6],
                    padding="same",
                    activation=tf.nn.leaky_relu,
                    kernel_initializer="he_normal",
                    name=name_template.format(
                        block_idx,
                        layer_idx,
                        spec[0],
                        spec[1],
                        spec[2],
                        spec[3]
                    )
                )
            )

        print_obj(
            "\ncreate_generator_growth_layer_block", "conv_layers", conv_layers
        )

    return conv_layers


def create_generator_to_rgb_layers(regularizer, params):
    """Instantiates the generator's toRGB conv layers.

    One layer is built per entry of params["generator_to_rgb_layers"],
    using the first description of each entry: elements 0:2 are the kernel
    size, element 3 the number of filters and elements 4:6 the strides.
    Element 2 only appears in the layer name.

    Args:
        regularizer: `l1_l2_regularizer` object. Currently unused because
            kernel regularization is disabled for these layers.
        params: dict, user passed parameters.

    Returns:
        List of toRGB 1x1 conv layers.
    """
    with tf.variable_scope(name_or_scope="generator", reuse=tf.AUTO_REUSE):
        # Copy the first layer description out of every toRGB entry.
        to_rgb_specs = [
            entry[0][:] for entry in params["generator_to_rgb_layers"]
        ]

        # Build one toRGB conv layer per description, in order.
        to_rgb_conv_layers = []
        for layer_idx, spec in enumerate(to_rgb_specs):
            layer_name = (
                "generator_to_rgb_layers_conv2d_{}_{}x{}_{}_{}".format(
                    layer_idx, spec[0], spec[1], spec[2], spec[3]
                )
            )
            to_rgb_conv_layers.append(
                tf.layers.Conv2D(
                    filters=spec[3],
                    kernel_size=spec[0:2],
                    strides=spec[4:6],
                    padding="same",
                    activation=tf.nn.leaky_relu,
                    kernel_initializer="he_normal",
                    name=layer_name
                )
            )

        print_obj(
            "\ncreate_generator_to_rgb_layers",
            "to_rgb_conv_layers",
            to_rgb_conv_layers
        )

    return to_rgb_conv_layers


def upsample_generator_image(image, original_image_size, block_idx):
    """Resizes a generator image to the resolution of a growth block.

    The target size is the original image size multiplied by
    2 ** block_idx, using nearest neighbor interpolation.

    Args:
        image: tensor, image created by generator conv block.
        original_image_size: list, the height and width dimensions of the
            original image before any growth.
        block_idx: int, index of the current generator growth block.

    Returns:
        Upsampled image tensor.
    """
    # Growth factors relative to the original size before and after this
    # block's upsampling step; the former is only used in the op name.
    previous_scale = 2 ** (block_idx - 1)
    current_scale = 2 ** block_idx

    # Target height and width: original size scaled for this block.
    base_size = tf.convert_to_tensor(
        value=original_image_size,
        dtype=tf.int32,
        name="upsample_generator_image_original_image_size"
    )
    target_size = base_size * current_scale

    # Op name records the block index plus the source and target sizes.
    op_name = "generator_growth_upsampled_image_{}_{}x{}_{}x{}".format(
        block_idx,
        original_image_size[0] * previous_scale,
        original_image_size[1] * previous_scale,
        original_image_size[0] * current_scale,
        original_image_size[1] * current_scale
    )

    upsampled_image = tf.image.resize(
        images=image,
        size=target_size,
        method="nearest",
        name=op_name
    )
    print_obj(
        "\nupsample_generator_image",
        "upsampled_image",
        upsampled_image
    )

    return upsampled_image


def create_base_generator_network(X, to_rgb_conv_layers, blocks):
    """Builds the generator network that uses only the base block.

    The input is passed through every layer of the first block and then
    through the first toRGB conv layer.

    Args:
        X: tensor, input image to generator.
        to_rgb_conv_layers: list, toRGB 1x1 conv layers.
        blocks: list, lists of block layers for each block.

    Returns:
        Output tensor of the first toRGB conv layer.
    """
    func_name = "create_base_generator_network"
    print_obj("\n" + func_name, "X", X)
    with tf.variable_scope(name_or_scope="generator", reuse=tf.AUTO_REUSE):
        # The base network uses just the first block and first toRGB layer.
        base_layers = blocks[0]
        base_to_rgb_layer = to_rgb_conv_layers[0]

        # The first layer consumes the input, the rest chain on its output.
        block_conv = base_layers[0](inputs=X)
        print_obj(func_name, "block_conv_0", block_conv)

        for layer_idx, layer in enumerate(base_layers[1:], start=1):
            block_conv = layer(inputs=block_conv)
            print_obj(
                func_name, "block_conv_{}".format(layer_idx), block_conv
            )

        to_rgb_conv = base_to_rgb_layer(inputs=block_conv)
        print_obj(func_name, "to_rgb_conv", to_rgb_conv)

    return to_rgb_conv


def create_growth_transition_generator_network(
        X,
        to_rgb_conv_layers,
        blocks,
        original_image_size,
        alpha_var,
        trans_idx):
    """Builds the generator network for one growth transition.

    The input is passed through the blocks up to and including index
    `trans_idx` (upsampling before every block after the first), then
    upsampled once more and fed to two side chains: a growing chain (the
    next block followed by its toRGB layer) and a shrinking chain (the
    current toRGB layer applied directly to the upsampled image). The two
    results are blended as
    growing * alpha_var + shrinking * (1 - alpha_var).

    Args:
        X: tensor, input image to generator.
        to_rgb_conv_layers: list, toRGB 1x1 conv layers.
        blocks: list, lists of block layers for each block.
        original_image_size: list, the height and width dimensions of the
            original image before any growth.
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        trans_idx: int, index of current growth transition.

    Returns:
        Weighted sum tensor of the growing and shrinking toRGB outputs.
    """
    func_name = "create_growth_transition_generator_network"
    print_obj("\nEntered " + func_name, "trans_idx", trans_idx)
    print_obj(func_name, "X", X)
    with tf.variable_scope(name_or_scope="generator", reuse=tf.AUTO_REUSE):
        # Blocks that are already fully part of the network.
        permanent_blocks = blocks[0:trans_idx + 1]

        # The base block consumes the input directly, with no upsampling.
        base_layers = permanent_blocks[0]
        block_conv = base_layers[0](inputs=X)
        print_obj(
            "\n" + func_name,
            "base_block_conv_{}_0".format(trans_idx),
            block_conv
        )
        for layer_idx, layer in enumerate(base_layers[1:], start=1):
            block_conv = layer(inputs=block_conv)
            print_obj(
                func_name,
                "base_block_conv_{}_{}".format(trans_idx, layer_idx),
                block_conv
            )

        # Every later permanent block sees its predecessor's image
        # upsampled first.
        for block_idx, block_layers in enumerate(
                permanent_blocks[1:], start=1):
            block_conv = upsample_generator_image(
                image=block_conv,
                original_image_size=original_image_size,
                block_idx=block_idx
            )
            print_obj(
                func_name,
                "upsample_generator_image_block_conv_{}_{}".format(
                    trans_idx, block_idx
                ),
                block_conv
            )

            for layer_idx, layer in enumerate(block_layers):
                block_conv = layer(inputs=block_conv)
                print_obj(
                    func_name,
                    "block_conv_{}_{}_{}".format(
                        trans_idx, block_idx, layer_idx
                    ),
                    block_conv
                )

        # Both side chains start from the same upsampled image.
        upsampled_block_conv = upsample_generator_image(
            image=block_conv,
            original_image_size=original_image_size,
            block_idx=len(permanent_blocks)
        )
        print_obj(
            func_name,
            "upsampled_block_conv_{}".format(trans_idx),
            upsampled_block_conv
        )

        # Growing side chain: next block followed by its toRGB layer.
        growing_layers = blocks[trans_idx + 1]
        growing_to_rgb_layer = to_rgb_conv_layers[trans_idx + 1]

        block_conv = growing_layers[0](inputs=upsampled_block_conv)
        print_obj(
            func_name,
            "growing_block_conv_{}_0".format(trans_idx),
            block_conv
        )
        for layer_idx, layer in enumerate(growing_layers[1:], start=1):
            block_conv = layer(inputs=block_conv)
            print_obj(
                func_name,
                "growing_block_conv_{}_{}".format(trans_idx, layer_idx),
                block_conv
            )
        growing_to_rgb_conv = growing_to_rgb_layer(inputs=block_conv)
        print_obj(
            func_name,
            "growing_to_rgb_conv_{}".format(trans_idx),
            growing_to_rgb_conv
        )

        # Shrinking side chain: current toRGB layer on the upsampled image.
        shrinking_to_rgb_layer = to_rgb_conv_layers[trans_idx]
        shrinking_to_rgb_conv = shrinking_to_rgb_layer(
            inputs=upsampled_block_conv
        )
        print_obj(
            func_name,
            "shrinking_to_rgb_conv_{}".format(trans_idx),
            shrinking_to_rgb_conv
        )

        # Blend both chains according to alpha.
        weighted_sum = tf.add(
            x=growing_to_rgb_conv * alpha_var,
            y=shrinking_to_rgb_conv * (1.0 - alpha_var),
            name="growth_transition_weighted_sum_{}".format(trans_idx)
        )
        print_obj(
            func_name, "weighted_sum_{}".format(trans_idx), weighted_sum
        )

    return weighted_sum


def create_final_generator_network(
        X, to_rgb_conv_layers, blocks, original_image_size):
    """Builds the fully grown generator network.

    The input is passed through every block (upsampling before each block
    after the first) and finally through the last toRGB conv layer.

    Args:
        X: tensor, input image to generator.
        to_rgb_conv_layers: list, toRGB 1x1 conv layers.
        blocks: list, lists of block layers for each block.
        original_image_size: list, the height and width dimensions of the
            original image before any growth.

    Returns:
        Output tensor of the last toRGB conv layer.
    """
    func_name = "create_final_generator_network"
    print_obj("\n" + func_name, "X", X)
    with tf.variable_scope(name_or_scope="generator", reuse=tf.AUTO_REUSE):
        # The base block consumes the input directly, with no upsampling.
        base_layers = blocks[0]
        block_conv = base_layers[0](inputs=X)
        print_obj("\n" + func_name, "base_block_conv", block_conv)

        for layer_idx, layer in enumerate(base_layers[1:], start=1):
            block_conv = layer(inputs=block_conv)
            print_obj(
                func_name,
                "base_block_conv_{}".format(layer_idx),
                block_conv
            )

        # Every growth block sees its predecessor's image upsampled first.
        for block_idx, block_layers in enumerate(blocks[1:], start=1):
            block_conv = upsample_generator_image(
                image=block_conv,
                original_image_size=original_image_size,
                block_idx=block_idx
            )
            print_obj(
                func_name,
                "upsample_generator_image_block_conv_{}".format(block_idx),
                block_conv
            )

            for layer_idx, layer in enumerate(block_layers):
                block_conv = layer(inputs=block_conv)
                print_obj(
                    func_name,
                    "block_conv_{}_{}".format(block_idx, layer_idx),
                    block_conv
                )

        # The fully grown network only uses the last toRGB conv layer.
        final_to_rgb_layer = to_rgb_conv_layers[-1]
        to_rgb_conv = final_to_rgb_layer(inputs=block_conv)
        print_obj(func_name, "to_rgb_conv", to_rgb_conv)

    return to_rgb_conv


def generator_network(Z, alpha_var, params):
    """Creates generator network and returns generated output.

    The latent vectors are projected, all conv blocks and toRGB layers are
    instantiated, and the network variant matching the current growth stage
    is selected. The growth stage is the global step floor-divided by
    params["num_steps_until_growth"]: stage 0 is the base network, stage
    k (1 <= k < number of blocks) is the growth transition with index
    k - 1, and any later stage is the fully grown network.

    Args:
        Z: tensor, latent vectors of shape [cur_batch_size, latent_size].
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        params: dict, user passed parameters.

    Returns:
        Generated outputs tensor produced by the selected network's toRGB
            conv layer.
    """
    print_obj("\ngenerator_network", "Z", Z)

    # Create regularizer for layer kernel weights.
    regularizer = tf.contrib.layers.l1_l2_regularizer(
        scale_l1=params["generator_l1_regularization_scale"],
        scale_l2=params["generator_l2_regularization_scale"]
    )

    # Project latent vectors.
    projection = generator_projection(Z, regularizer, params)
    print_obj("generator_network", "projection", projection)

    # Base conv block first, followed by one block per growth stage.
    blocks = [create_generator_base_conv_layer_block(regularizer, params)]
    for block_idx in range(len(params["generator_growth_conv_blocks"])):
        blocks.append(
            create_generator_growth_layer_block(
                block_idx, regularizer, params
            )
        )
    print_obj("generator_network", "blocks", blocks)

    # Create list of toRGB 1x1 conv layers.
    to_rgb_conv_layers = create_generator_to_rgb_layers(
        regularizer, params
    )
    print_obj("generator_network", "to_rgb_conv_layers", to_rgb_conv_layers)

    # Get generated outputs.
    if (params["train_steps"] // params["num_steps_until_growth"] <= 0 or
            len(params["conv_num_filters"]) == 1):
        print("\ngenerator_network: NEVER GOING TO GROW, SKIP SWITCH CASE")
        # If we never are going to grow, no sense using the switch case.
        generated_outputs = create_base_generator_network(
            projection, to_rgb_conv_layers, blocks
        )
    else:
        # Find growth index based on global step and growth frequency.
        growth_index = tf.cast(
            x=tf.floordiv(
                x=tf.train.get_or_create_global_step(),
                y=params["num_steps_until_growth"]
            ),
            dtype=tf.int32,
            name="generator_growth_index"
        )

        original_image_size = params["generator_projection_dims"][0:2]

        def _transition_branch(trans_idx):
            """Returns a zero-argument branch for one growth transition."""
            # A factory function gives each branch its own trans_idx and
            # avoids the late-binding pitfall of lambdas created in a loop.
            return lambda: create_growth_transition_generator_network(
                projection,
                to_rgb_conv_layers,
                blocks,
                original_image_size,
                alpha_var,
                trans_idx
            )

        # Branch 0: base network only.
        branch_fns = [
            lambda: create_base_generator_network(
                projection, to_rgb_conv_layers, blocks
            )
        ]

        # One transition branch per growth block. Deriving the count from
        # the configured blocks keeps every branch's block and toRGB
        # indices in range, rather than assuming a fixed number of blocks.
        # Transition i reads to_rgb_conv_layers[i + 1], so there has to be
        # at least one toRGB layer per block.
        branch_fns.extend(
            _transition_branch(trans_idx)
            for trans_idx in range(len(blocks) - 1)
        )

        # Last branch: fully grown network. tf.switch_case falls back to the
        # last branch when the index is out of range, so it also serves
        # every later growth index.
        branch_fns.append(
            lambda: create_final_generator_network(
                projection,
                to_rgb_conv_layers,
                blocks,
                original_image_size
            )
        )

        # Switch to case based on number of steps for network creation.
        generated_outputs = tf.switch_case(
            branch_index=growth_index,
            branch_fns=branch_fns,
            name="generator_switch_case_generated_outputs"
        )

    print_obj("generator_network", "generated_outputs", generated_outputs)

    return generated_outputs


def get_generator_loss(generated_logits):
    """Computes the generator's total loss.

    The base loss is the negated mean over all elements of the logits. The
    regularization losses collected under the "generator" scope are added
    to it.

    Args:
        generated_logits: tensor, logits for the generated images.

    Returns:
        Tensor of generator's total loss of shape [].
    """
    func_name = "get_generator_loss"

    # Base loss: negative mean of the logits.
    mean_logits = tf.reduce_mean(
        input_tensor=generated_logits,
        name="generator_loss"
    )
    generator_loss = -mean_logits
    print_obj("\n" + func_name, "generator_loss", generator_loss)

    # Regularization losses registered under the generator scope.
    generator_regularization_loss = tf.losses.get_regularization_loss(
        scope="generator",
        name="generator_regularization_loss"
    )
    print_obj(
        func_name,
        "generator_regularization_loss",
        generator_regularization_loss
    )

    # Total loss is the sum of both parts.
    generator_total_loss = tf.math.add(
        x=generator_loss,
        y=generator_regularization_loss,
        name="generator_total_loss"
    )
    print_obj(func_name, "generator_total_loss", generator_total_loss)

    return generator_total_loss


Overwriting pgan_module/trainer/generator.py


## discriminator.py

In [4]:
%%writefile pgan_module/trainer/discriminator.py
import tensorflow as tf

from .print_object import print_obj


def create_discriminator_from_rgb_layers(regularizer, params):
    """Instantiates the discriminator's fromRGB conv layers.

    One layer is built per entry of params["discriminator_from_rgb_layers"],
    using the first description of each entry: elements 0:2 are the kernel
    size, element 3 the number of filters and elements 4:6 the strides.
    Element 2 only appears in the layer name.

    Args:
        regularizer: `l1_l2_regularizer` object. Currently unused because
            kernel regularization is disabled for these layers.
        params: dict, user passed parameters.

    Returns:
        List of fromRGB 1x1 conv layers.
    """
    with tf.variable_scope(name_or_scope="discriminator", reuse=tf.AUTO_REUSE):
        # Copy the first layer description out of every fromRGB entry.
        from_rgb_specs = [
            entry[0][:] for entry in params["discriminator_from_rgb_layers"]
        ]
        name_template = "discriminator_from_rgb_layers_conv2d_{}_{}x{}_{}_{}"

        # Build one fromRGB conv layer per description, in order.
        from_rgb_conv_layers = []
        for layer_idx, spec in enumerate(from_rgb_specs):
            from_rgb_conv_layers.append(
                tf.layers.Conv2D(
                    filters=spec[3],
                    kernel_size=spec[0:2],
                    strides=spec[4:6],
                    padding="same",
                    activation=tf.nn.leaky_relu,
                    kernel_initializer="he_normal",
                    name=name_template.format(
                        layer_idx, spec[0], spec[1], spec[2], spec[3]
                    )
                )
            )

        print_obj(
            "\ncreate_discriminator_from_rgb_layers",
            "from_rgb_conv_layers",
            from_rgb_conv_layers
        )

    return from_rgb_conv_layers


def create_discriminator_base_conv_layer_block(regularizer, params):
    """Instantiates the conv layers of the discriminator's base block.

    Each layer is described by one entry of
    params["discriminator_base_conv_blocks"][0]: elements 0:2 are the
    kernel size, element 3 the number of filters and elements 4:6 the
    strides. Element 2 only appears in the layer name. All layers use
    "same" padding except the last one, which uses "valid" padding.

    Args:
        regularizer: `l1_l2_regularizer` object. Currently unused because
            kernel regularization is disabled for these layers.
        params: dict, user passed parameters.

    Returns:
        List of base conv layers.
    """
    with tf.variable_scope(name_or_scope="discriminator", reuse=tf.AUTO_REUSE):
        # Layer descriptions for the base block.
        layer_specs = params["discriminator_base_conv_blocks"][0]
        last_idx = len(layer_specs) - 1

        def _build_conv_layer(layer_idx, spec, padding):
            """Creates one named Conv2D layer from its description."""
            return tf.layers.Conv2D(
                filters=spec[3],
                kernel_size=spec[0:2],
                strides=spec[4:6],
                padding=padding,
                activation=tf.nn.leaky_relu,
                kernel_initializer="he_normal",
                name="discriminator_base_layers_conv2d_{}_{}x{}_{}_{}".format(
                    layer_idx, spec[0], spec[1], spec[2], spec[3]
                )
            )

        # Every layer except the last keeps the spatial size.
        base_conv_layers = [
            _build_conv_layer(layer_idx, layer_specs[layer_idx], "same")
            for layer_idx in range(last_idx)
        ]

        # Have valid padding for layer just before flatten and logits.
        base_conv_layers.append(
            _build_conv_layer(last_idx, layer_specs[-1], "valid")
        )

        print_obj(
            "\ncreate_discriminator_base_conv_layer_block",
            "base_conv_layers",
            base_conv_layers
        )

    return base_conv_layers


def create_discriminator_growth_layer_block(
        block_idx, regularizer, params):
    """Builds the layer objects for one discriminator growth block.

    The block consists of one "same" padded, leaky-ReLU activated conv layer
    per entry of `params["discriminator_growth_conv_blocks"][block_idx]`,
    followed by a 2x2 average pooling layer with stride 2.

    Args:
        block_idx: int, index of the growth block to build.
        regularizer: `l1_l2_regularizer` object for kernel variables.
            Currently unused because kernel regularization is not applied to
            the layers created here.
        params: dict, user passed parameters.

    Returns:
        List of the block's conv layers with the downsample layer last.
    """
    with tf.variable_scope(name_or_scope="discriminator", reuse=tf.AUTO_REUSE):
        # Each entry describes one conv layer: indices 0-1 are the kernel
        # size, index 3 the number of filters and indices 4-5 the strides.
        # Index 2 is only used in the layer name (presumably the number of
        # input channels -- TODO confirm).
        block_props = params["discriminator_growth_conv_blocks"][block_idx]
        name_template = "discriminator_growth_layers_conv2d_{}_{}_{}x{}_{}_{}"

        # Build the block's convolutional layers one at a time.
        conv_layers = []
        for layer_idx, props in enumerate(block_props):
            layer_name = name_template.format(
                block_idx, layer_idx, props[0], props[1], props[2], props[3]
            )
            conv_layers.append(
                tf.layers.Conv2D(
                    filters=props[3],
                    kernel_size=props[0:2],
                    strides=props[4:6],
                    padding="same",
                    activation=tf.nn.leaky_relu,
                    kernel_initializer="he_normal",
                    name=layer_name
                )
            )
        print_obj(
            "\ncreate_discriminator_growth_layer_block",
            "conv_layers",
            conv_layers
        )

        # Halve the spatial size: 2s X 2s image -> s X s image.
        pool_name = "discriminator_growth_downsampled_image_{}".format(
            block_idx
        )
        downsampled_image_layer = tf.layers.AveragePooling2D(
            pool_size=(2, 2), strides=(2, 2), name=pool_name
        )
        print_obj(
            "create_discriminator_growth_layer_block",
            "downsampled_image_layer",
            downsampled_image_layer
        )

    return conv_layers + [downsampled_image_layer]


def create_discriminator_growth_transition_downsample_layers(params):
    """Builds the average pooling layers used during growth transitions.

    One more layer is created than there are entries in
    `params["discriminator_growth_conv_blocks"]`.

    Args:
        params: dict, user passed parameters.

    Returns:
        List of 2x2, stride 2 average pooling layers.
    """
    with tf.variable_scope(name_or_scope="discriminator", reuse=tf.AUTO_REUSE):
        num_layers = 1 + len(params["discriminator_growth_conv_blocks"])

        # Each layer halves the spatial size: 2s X 2s image -> s X s image.
        downsample_layers = []
        for layer_idx in range(num_layers):
            layer_name = (
                "discriminator_growth_transition_downsample_layer_{}".format(
                    layer_idx
                )
            )
            downsample_layers.append(
                tf.layers.AveragePooling2D(
                    pool_size=(2, 2), strides=(2, 2), name=layer_name
                )
            )
        print_obj(
            "\ncreate_discriminator_growth_transition_downsample_layers",
            "downsample_layers",
            downsample_layers
        )

    return downsample_layers


def create_base_discriminator_network(
        X, from_rgb_conv_layers, blocks, params):
    """Wires up the smallest (non-grown) discriminator network.

    The image goes through the first fromRGB layer and then through every
    layer of the first block, in order.

    Args:
        X: tensor, input image to discriminator.
        from_rgb_conv_layers: list, fromRGB 1x1 conv layers.
        blocks: list, lists of block layers for each block.
        params: dict, user passed parameters. Not read by this function.

    Returns:
        Output tensor of the first block's last layer.
    """
    with tf.variable_scope(name_or_scope="discriminator", reuse=tf.AUTO_REUSE):
        # The base network only uses the first fromRGB layer and first block.
        block_conv = from_rgb_conv_layers[0](inputs=X)
        print_obj(
            "\ncreate_base_discriminator_network",
            "block_conv",
            block_conv
        )

        # Chain the block's layers, feeding each output into the next layer.
        for layer in blocks[0]:
            block_conv = layer(inputs=block_conv)
            print_obj(
                "create_base_discriminator_network", "block_conv", block_conv
            )

    return block_conv


def create_growth_transition_discriminator_network(
        X,
        from_rgb_conv_layers,
        blocks,
        transition_downsample_layers,
        alpha_var,
        params,
        trans_idx):
    """Creates growth transition discriminator network.

    Two side chains process the input image and are blended together:

    * Growing chain: fromRGB layer `trans_idx + 1` followed by every layer of
      block `trans_idx + 1`.
    * Shrinking chain: transition downsample layer `trans_idx` followed by
      fromRGB layer `trans_idx`.

    The blend is `alpha_var * growing + (1 - alpha_var) * shrinking`, which
    is then passed through blocks `trans_idx`, ..., 0 (in that order).

    Args:
        X: tensor, input image to discriminator.
        from_rgb_conv_layers: list, fromRGB 1x1 conv layers.
        blocks: list, lists of block layers for each block.
        transition_downsample_layers: list, downsample layers for transition.
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        params: dict, user passed parameters. Not read by this function.
        trans_idx: int, index of current growth transition.

    Returns:
        Last block's last conv layer's tensor.
    """
    # Label for this function's debug prints. Previously the prints in the
    # side chains were labelled with the base network function's name.
    func_name = "create_growth_transition_discriminator_network"

    with tf.variable_scope(name_or_scope="discriminator", reuse=tf.AUTO_REUSE):
        # Growing side chain.
        growing_from_rgb_conv_layer = from_rgb_conv_layers[trans_idx + 1]
        growing_block_layers = blocks[trans_idx + 1]

        # Pass inputs through layer chain.
        growing_block_conv = growing_from_rgb_conv_layer(inputs=X)
        print_obj(
            "\n{}".format(func_name),
            "growing_block_conv",
            growing_block_conv
        )
        for layer in growing_block_layers:
            growing_block_conv = layer(inputs=growing_block_conv)
            print_obj(func_name, "growing_block_conv", growing_block_conv)

        # Shrinking side chain.
        transition_downsample_layer = transition_downsample_layers[trans_idx]
        shrinking_from_rgb_conv_layer = from_rgb_conv_layers[trans_idx]

        # Pass inputs through layer chain.
        transition_downsample = transition_downsample_layer(inputs=X)
        print_obj(func_name, "transition_downsample", transition_downsample)

        shrinking_from_rgb_conv = shrinking_from_rgb_conv_layer(
            inputs=transition_downsample
        )
        print_obj(
            func_name, "shrinking_from_rgb_conv", shrinking_from_rgb_conv
        )

        # Weighted sum: alpha fades the growing chain in and the shrinking
        # chain out.
        weighted_sum = tf.add(
            x=growing_block_conv * alpha_var,
            y=shrinking_from_rgb_conv * (1.0 - alpha_var),
            name="growth_transition_weighted_sum_{}".format(trans_idx)
        )
        print_obj(func_name, "weighted_sum", weighted_sum)

        # Permanent blocks.
        permanent_blocks = blocks[0:trans_idx + 1]

        # Reverse order of blocks and flatten.
        permanent_block_layers = [
            item for sublist in permanent_blocks[::-1] for item in sublist
        ]

        # Pass inputs through layer chain.
        block_conv = weighted_sum
        for layer in permanent_block_layers:
            block_conv = layer(inputs=block_conv)
            print_obj(func_name, "block_conv", block_conv)

    return block_conv


def create_final_discriminator_network(
        X, from_rgb_conv_layers, blocks, params):
    """Wires up the fully grown discriminator network.

    The image goes through the last fromRGB layer and then through every
    block's layers, visiting the blocks from last to first.

    Args:
        X: tensor, input image to discriminator.
        from_rgb_conv_layers: list, fromRGB 1x1 conv layers.
        blocks: list, lists of block layers for each block.
        params: dict, user passed parameters. Not read by this function.

    Returns:
        Output tensor of the first block's last layer.
    """
    with tf.variable_scope(name_or_scope="discriminator", reuse=tf.AUTO_REUSE):
        # The fully grown network enters through the last fromRGB layer.
        block_conv = from_rgb_conv_layers[-1](inputs=X)
        print_obj(
            "\ncreate_final_discriminator_network",
            "block_conv",
            block_conv
        )

        # Walk the blocks in reverse order, chaining each block's layers.
        for block_layers in reversed(blocks):
            for layer in block_layers:
                block_conv = layer(inputs=block_conv)
                print_obj(
                    "create_final_discriminator_network",
                    "block_conv",
                    block_conv
                )

    return block_conv


def discriminator_logits(block_conv, regularizer):
    """Finds logits from discriminator's last conv layer.

    Flattens the conv output and applies a single-unit linear dense layer.

    Args:
        block_conv: tensor, output of last conv layer of discriminator.
        regularizer: `l1_l2_regularizer` object, regularizer for kernel
            variables. Currently unused because kernel regularization is not
            applied to the dense layer.

    Returns:
        Final logits tensor of discriminator.
    """
    with tf.variable_scope(name_or_scope="discriminator", reuse=tf.AUTO_REUSE):
        # Flatten final block conv tensor.
        block_conv_flat = tf.layers.Flatten()(inputs=block_conv)
        # Debug prints are labelled with this function's own name; they
        # previously carried the names of other functions.
        print_obj(
            "discriminator_logits",
            "block_conv_flat",
            block_conv_flat
        )

        # Final linear layer for logits.
        logits = tf.layers.Dense(
            units=1,
            activation=None,
            name="layers_dense_logits"
        )(inputs=block_conv_flat)
        print_obj("discriminator_logits", "logits", logits)

    return logits


def discriminator_network(X, alpha_var, params):
    """Creates discriminator network and returns logits.

    Builds all layer objects (fromRGB layers, base block, growth blocks and
    transition downsample layers), wires up the network variant matching the
    current growth stage, and finishes with the logits layer.

    When growth is possible, the variant is selected at run time by
    `tf.switch_case` on `global_step // params["num_steps_until_growth"]`.
    Branch 0 is the base network, branch `i + 1` is growth transition `i`,
    and the last branch is the fully grown network. The number of transition
    branches follows the number of growth blocks, so configurations with
    fewer (or more) growth blocks than the 1024x1024 setup are supported.

    Args:
        X: tensor, image tensors of shape
            [cur_batch_size, height, width, depth].
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        params: dict, user passed parameters.

    Returns:
        Logits tensor of shape [cur_batch_size, 1].
    """
    print_obj("\ndiscriminator_network", "X", X)

    # Create regularizer for layer kernel weights.
    regularizer = tf.contrib.layers.l1_l2_regularizer(
        scale_l1=params["discriminator_l1_regularization_scale"],
        scale_l2=params["discriminator_l2_regularization_scale"]
    )

    # Create list of fromRGB 1x1 conv layers.
    from_rgb_conv_layers = create_discriminator_from_rgb_layers(
        regularizer, params
    )
    print_obj(
        "discriminator_network",
        "from_rgb_conv_layers",
        from_rgb_conv_layers
    )

    # Create empty list to hold discriminator convolutional layer blocks.
    blocks = []

    # Create base convolutional layers, for post-growth.
    blocks.append(
        create_discriminator_base_conv_layer_block(regularizer, params)
    )

    # Create growth layer blocks.
    for block_idx in range(len(params["discriminator_growth_conv_blocks"])):
        blocks.append(
            create_discriminator_growth_layer_block(
                block_idx, regularizer, params
            )
        )
    print_obj("discriminator_network", "blocks", blocks)

    # Create list of transition downsample layers.
    transition_downsample_layers = (
        create_discriminator_growth_transition_downsample_layers(params)
    )
    print_obj(
        "discriminator_network",
        "transition_downsample_layers",
        transition_downsample_layers
    )

    # Get final convolutional block's final layer output.
    if (params["train_steps"] // params["num_steps_until_growth"] <= 0 or
       len(params["conv_num_filters"]) == 1):
        print("\ndiscriminator_network: NEVER GOING TO GROW, SKIP SWITCH CASE")
        # If we never are going to grow, no sense using the switch case.
        block_conv = create_base_discriminator_network(
            X, from_rgb_conv_layers, blocks, params
        )
    else:
        # Find growth index based on global step and growth frequency.
        growth_index = tf.cast(
            x=tf.floordiv(
                x=tf.train.get_or_create_global_step(),
                y=params["num_steps_until_growth"]
            ),
            dtype=tf.int32,
            name="discriminator_growth_index"
        )

        def _transition_branch(trans_idx):
            """Returns a zero-argument branch fn for one growth transition."""
            # A factory is used so each branch captures its own trans_idx
            # rather than the loop variable's final value.
            return lambda: create_growth_transition_discriminator_network(
                X,
                from_rgb_conv_layers,
                blocks,
                transition_downsample_layers,
                alpha_var,
                params,
                trans_idx
            )

        # Every branch is traced when the graph is built, so only branches
        # whose layers actually exist may be listed. A hard-coded list of
        # eight transitions raises IndexError for smaller configurations.
        branch_fns = [
            # Base network.
            lambda: create_base_discriminator_network(
                X, from_rgb_conv_layers, blocks, params
            )
        ]

        # One transition per growth block.
        branch_fns.extend(
            _transition_branch(trans_idx)
            for trans_idx in range(len(blocks) - 1)
        )

        # Fully grown network. With no explicit default, switch_case falls
        # back to the highest branch index, i.e. this branch, whenever
        # growth_index exceeds the number of branches.
        branch_fns.append(
            lambda: create_final_discriminator_network(
                X, from_rgb_conv_layers, blocks, params
            )
        )

        # Switch to case based on number of steps for network creation.
        block_conv = tf.switch_case(
            branch_index=growth_index,
            branch_fns=branch_fns,
            name="discriminator_switch_case_block_conv"
        )

    # Set shape to remove ambiguity for dense layer. Floor division keeps the
    # spatial dimensions integers; true division would produce floats.
    block_conv.set_shape(
        [
            block_conv.get_shape()[0],
            params["generator_projection_dims"][0] // 4,
            params["generator_projection_dims"][1] // 4,
            block_conv.get_shape()[-1]
        ]
    )
    print_obj(
        "discriminator_network",
        "block_conv",
        block_conv
    )

    # Get final logits.
    logits = discriminator_logits(block_conv, regularizer)

    return logits


def get_discriminator_loss(generated_logits, real_logits, params):
    """Builds the discriminator's total loss tensor.

    The total loss is the sum of three terms:

    * the mean real logit minus the mean generated logit,
    * a penalty based on the global norm of that difference's gradients with
      respect to the discriminator's trainable variables,
    * the regularization loss collected under the "discriminator" scope.

    Args:
        generated_logits: tensor, discriminator logits for generated images.
            The discriminator network documents its logits as having shape
            [cur_batch_size, 1].
        real_logits: tensor, discriminator logits for real images, same
            shape convention as `generated_logits`.
        params: dict, user passed parameters.

    Returns:
        Tensor of discriminator's total loss of shape [].
    """
    # Mean logit over the real images.
    real_loss = tf.reduce_mean(
        input_tensor=real_logits, name="discriminator_real_loss"
    )
    print_obj(
        "\nget_discriminator_loss", "discriminator_real_loss", real_loss
    )

    # Mean logit over the generated images.
    generated_loss = tf.reduce_mean(
        input_tensor=generated_logits, name="discriminator_generated_loss"
    )
    print_obj(
        "get_discriminator_loss",
        "discriminator_generated_loss",
        generated_loss
    )

    # Base loss: real mean minus generated mean.
    # NOTE(review): the sign convention must agree with the generator's
    # loss, which is defined elsewhere -- confirm.
    negated_generated_loss = -generated_loss
    base_loss = tf.add(
        x=real_loss, y=negated_generated_loss, name="discriminator_loss"
    )
    print_obj("get_discriminator_loss", "discriminator_loss", base_loss)

    # Gradients of the base loss w.r.t. the discriminator's variables.
    discriminator_variables = tf.trainable_variables(scope="discriminator")
    penalty_gradients = tf.gradients(
        ys=base_loss,
        xs=discriminator_variables,
        name="discriminator_gradients_for_penalty"
    )

    # Penalize the deviation of the gradients' global norm from one.
    # NOTE(review): the coefficient is multiplied in before squaring, so the
    # penalty scales with the coefficient squared -- confirm this is
    # intended.
    gradients_norm = tf.linalg.global_norm(
        t_list=penalty_gradients, name="discriminator_gradients_global_norm"
    )
    scaled_norm_deviation = tf.multiply(
        x=params["discriminator_gradient_penalty_coefficient"],
        y=gradients_norm - 1.0
    )
    gradient_penalty = tf.square(
        x=scaled_norm_deviation, name="discriminator_gradient_penalty"
    )

    penalized_loss = tf.add(
        x=base_loss,
        y=gradient_penalty,
        name="discriminator_wasserstein_gp_loss"
    )

    # Regularization losses registered under the discriminator scope.
    regularization_loss = tf.losses.get_regularization_loss(
        scope="discriminator", name="discriminator_regularization_loss"
    )
    print_obj(
        "get_discriminator_loss",
        "discriminator_regularization_loss",
        regularization_loss
    )

    # Everything summed together is the loss that gets returned.
    total_loss = tf.math.add(
        x=penalized_loss,
        y=regularization_loss,
        name="discriminator_total_loss"
    )
    print_obj(
        "get_discriminator_loss", "discriminator_total_loss", total_loss
    )

    return total_loss


Overwriting pgan_module/trainer/discriminator.py


## pgan.py

In [5]:
%%writefile pgan_module/trainer/pgan.py
import tensorflow as tf

from . import discriminator
from . import generator
from .print_object import print_obj


def train_network(loss, global_step, alpha_var, params, scope):
    """Trains network and returns loss and train op.

    Computes the gradients of `loss` with respect to the trainable variables
    under `scope`, optionally clips them by global norm, applies them with
    the configured optimizer (which also increments `global_step`), and
    updates `alpha_var` to the fraction of the current growth period that
    has elapsed.

    Args:
        loss: tensor, shape of [].
        global_step: tensor, the current training step or batch in the
            training loop.
        alpha_var: variable, alpha for weighted sum of fade-in of layers.
        params: dict, user passed parameters.
        scope: str, variable scope whose trainable variables get trained.
            Also used as the prefix of the params keys read here, e.g.
            "{scope}_optimizer" and "{scope}_learning_rate".

    Returns:
        Loss tensor and training op. Both carry a control dependency on the
        alpha variable update, so evaluating either one runs the update.
    """
    # Create optimizer map.
    optimizers = {
        "Adam": tf.train.AdamOptimizer,
        "Adadelta": tf.train.AdadeltaOptimizer,
        "AdagradDA": tf.train.AdagradDAOptimizer,
        "Adagrad": tf.train.AdagradOptimizer,
        "Ftrl": tf.train.FtrlOptimizer,
        "GradientDescent": tf.train.GradientDescentOptimizer,
        "Momentum": tf.train.MomentumOptimizer,
        "ProximalAdagrad": tf.train.ProximalAdagradOptimizer,
        "ProximalGradientDescent": tf.train.ProximalGradientDescentOptimizer,
        "RMSProp": tf.train.RMSPropOptimizer
    }

    # Variables that this call is responsible for training.
    variables = tf.trainable_variables(scope=scope)

    # Get gradients.
    gradients = tf.gradients(
        ys=loss,
        xs=variables,
        name="{}_gradients".format(scope)
    )

    # Clip gradients. A falsy clip value (e.g. None or 0) disables clipping.
    clip_norm = params["{}_clip_gradients".format(scope)]
    if clip_norm:
        gradients, _ = tf.clip_by_global_norm(
            t_list=gradients,
            clip_norm=clip_norm,
            name="{}_clip_by_global_norm_gradients".format(scope)
        )

    # Zip back together gradients and variables.
    grads_and_vars = zip(gradients, variables)

    # Get optimizer and instantiate it.
    optimizer = optimizers[params["{}_optimizer".format(scope)]](
        learning_rate=params["{}_learning_rate".format(scope)]
    )

    # Create train op by applying gradients to variables and incrementing
    # global step.
    train_op = optimizer.apply_gradients(
        grads_and_vars=grads_and_vars,
        global_step=global_step,
        name="{}_apply_gradients".format(scope)
    )

    # Update alpha variable to linearly scale from 0 to 1 based on steps.
    alpha_var_update_op = tf.assign(
        ref=alpha_var,
        value=tf.divide(
            x=tf.cast(
                x=tf.mod(x=global_step, y=params["num_steps_until_growth"]),
                dtype=tf.float32
            ),
            y=params["num_steps_until_growth"]
        )
    )

    # Ensure alpha variable gets updated. Control dependencies only attach to
    # ops created inside the context, so returning the existing tensors from
    # within it would leave the update unconnected and never executed. New
    # ops are therefore created here to carry the dependency.
    with tf.control_dependencies(control_inputs=[alpha_var_update_op]):
        loss = tf.identity(
            input=loss, name="{}_loss_with_alpha_update".format(scope)
        )
        train_op = tf.group(
            train_op, name="{}_train_op_with_alpha_update".format(scope)
        )

    return loss, train_op


def resize_real_image(block_idx, image, params):
    """Resizes real images to the GAN's image size at a given block.

    The target size is the first two entries of
    `params["generator_projection_dims"]`, each multiplied by
    `2 ** block_idx`. Nearest neighbor resizing is used.

    Args:
        block_idx: int, index of current block.
        image: tensor, original image.
        params: dict, user passed parameters.

    Returns:
        Resized image tensor.
    """
    print_obj("\nresize_real_image", "block_idx", block_idx)
    print_obj("resize_real_image", "image", image)

    # Each block doubles the size that the projection dimensions start at.
    projection_dims = params["generator_projection_dims"]
    scale = 2 ** block_idx
    target_size = [projection_dims[0] * scale, projection_dims[1] * scale]

    resized_image = tf.image.resize(
        images=image,
        size=target_size,
        method="nearest",
        name="resize_real_images_resized_image_{}".format(block_idx)
    )
    print_obj("resize_real_images", "resized_image", resized_image)

    return resized_image


def resize_real_images(image, params):
    """Resizes real images to match the GAN's current size.

    If the network can grow, the target size is chosen at run time from nine
    candidate sizes (block indices 0 through 8) using
    `global_step // params["num_steps_until_growth"]`. Otherwise the images
    are always resized to the size of block index 0.

    Args:
        image: tensor, original image.
        params: dict, user passed parameters.

    Returns:
        Resized image tensor.
    """
    print_obj("\nresize_real_images", "image", image)

    num_growth_periods = (
        params["train_steps"] // params["num_steps_until_growth"]
    )

    if num_growth_periods > 0 and len(params["conv_num_filters"]) != 1:
        # Find growth index based on global step and growth frequency.
        growth_index = tf.cast(
            x=tf.floordiv(
                x=tf.train.get_or_create_global_step(),
                y=params["num_steps_until_growth"]
            ),
            dtype=tf.int32,
            name="resize_real_images_growth_index"
        )

        def _resize_branch(block_idx):
            """Returns a zero-argument branch fn resizing for one block."""
            return lambda: resize_real_image(block_idx, image, params)

        # One branch per image size: 4x4, 8x8, ..., 1024x1024 when the
        # projection dimensions are 4x4.
        resized_image = tf.switch_case(
            branch_index=growth_index,
            branch_fns=[_resize_branch(block_idx) for block_idx in range(9)],
            name="resize_real_images_switch_case_resized_image"
        )
        print_obj(
            "resize_real_images", "selected resized_image", resized_image
        )
    else:
        # The network never grows, so the switch case can be skipped and the
        # smallest size is always used.
        resized_image = resize_real_image(0, image, params)
        print_obj(
            "resize_real_images", "slipped resized_image", resized_image
        )

    return resized_image


def pgan_model(features, labels, mode, params):
    """Progressively Growing GAN custom Estimator model function.

    * PREDICT: runs the generator on the latent vectors in `features["Z"]`
      and exposes the generated images as predictions and export outputs.
    * TRAIN: builds generator and discriminator losses and, based on the
      global step, conditionally trains either the generator or the
      discriminator.
    * EVAL: reports the discriminator's total loss plus classification
      metrics of the discriminator on real (label 1) versus generated
      (label 0) images.

    Args:
        features: dict, keys are feature names and values are feature tensors.
        labels: tensor, label data.
        mode: tf.estimator.ModeKeys with values of either TRAIN, EVAL, or
            PREDICT.
        params: dict, user passed parameters.

    Returns:
        Instance of `tf.estimator.EstimatorSpec` class.
    """
    print_obj("\npgan_model", "features", features)
    print_obj("pgan_model", "labels", labels)
    print_obj("pgan_model", "mode", mode)
    print_obj("pgan_model", "params", params)

    # Loss function, training/eval ops, etc.
    predictions_dict = None
    loss = None
    train_op = None
    eval_metric_ops = None
    export_outputs = None

    # Create alpha variable to use for weighted sum for smooth fade-in.
    alpha_var = tf.get_variable(
        name="alpha_var",
        dtype=tf.float32,
        initializer=tf.zeros(shape=[], dtype=tf.float32),
        trainable=False
    )
    print_obj("pgan_model", "alpha_var", alpha_var)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Extract given latent vectors from features dictionary.
        Z = tf.cast(x=features["Z"], dtype=tf.float32)

        # Get predictions from generator.
        generated_images = generator.generator_network(Z, alpha_var, params)

        # Create predictions dictionary.
        predictions_dict = {
            "generated_images": generated_images
        }

        # Create export outputs.
        export_outputs = {
            "predict_export_outputs": tf.estimator.export.PredictOutput(
                outputs=predictions_dict)
        }
    else:
        # Extract image from features dictionary.
        X = features["image"]

        # Get dynamic batch size in case of partial batch.
        cur_batch_size = tf.shape(
            input=X,
            out_type=tf.int32,
            name="pgan_model_cur_batch_size"
        )[0]

        # Create random noise latent vector for each batch example.
        Z = tf.random.normal(
            shape=[cur_batch_size, params["latent_size"]],
            mean=0.0,
            stddev=1.0,
            dtype=tf.float32
        )

        # Establish generator network subgraph with gaussian noise.
        print("\nCall generator with Z = {}.".format(Z))
        generator_outputs = generator.generator_network(Z, alpha_var, params)

        # Resize real images based on the current size of the GAN.
        real_image = resize_real_images(X, params)

        # Establish discriminator network subgraph with real data.
        print("\nCall discriminator with real_image = {}.".format(
            real_image
        ))

        real_logits = discriminator.discriminator_network(
            real_image, alpha_var, params
        )

        # Get generated logits too.
        print("\nCall discriminator with generator_outputs = {}.".format(
            generator_outputs
        ))

        generated_logits = discriminator.discriminator_network(
            generator_outputs, alpha_var, params
        )

        # Get generator total loss.
        generator_total_loss = generator.get_generator_loss(generated_logits)

        # Get discriminator total loss.
        discriminator_total_loss = discriminator.get_discriminator_loss(
            generated_logits, real_logits, params
        )

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Get global step.
            global_step = tf.train.get_or_create_global_step()

            # Determine if it is time to train generator or discriminator.
            # Each cycle is generator_train_steps generator steps followed
            # by discriminator_train_steps discriminator steps.
            cycle_step = tf.mod(
                x=global_step,
                y=tf.cast(
                    x=tf.add(
                        x=params["generator_train_steps"],
                        y=params["discriminator_train_steps"]
                    ),
                    dtype=tf.int64
                )
            )

            # Create choose generator condition.
            condition = tf.less(
                x=cycle_step, y=params["generator_train_steps"]
            )

            # Needed for batch normalization, but has no effect otherwise.
            update_ops = tf.get_collection(key=tf.GraphKeys.UPDATE_OPS)

            with tf.control_dependencies(control_inputs=update_ops):
                # Conditionally choose to train generator or discriminator.
                loss, train_op = tf.cond(
                    pred=condition,
                    true_fn=lambda: train_network(
                        loss=generator_total_loss,
                        global_step=global_step,
                        alpha_var=alpha_var,
                        params=params,
                        scope="generator"
                    ),
                    false_fn=lambda: train_network(
                        loss=discriminator_total_loss,
                        global_step=global_step,
                        alpha_var=alpha_var,
                        params=params,
                        scope="discriminator"
                    )
                )
        else:
            loss = discriminator_total_loss

            # Concatenate discriminator logits and labels.
            discriminator_logits = tf.concat(
                values=[real_logits, generated_logits],
                axis=0,
                name="discriminator_concat_logits"
            )

            # Real images are labelled 1, generated images 0.
            discriminator_labels = tf.concat(
                values=[
                    tf.ones_like(tensor=real_logits),
                    tf.zeros_like(tensor=generated_logits)
                ],
                axis=0,
                name="discriminator_concat_labels"
            )

            # Calculate discriminator probabilities.
            discriminator_probabilities = tf.nn.sigmoid(
                x=discriminator_logits, name="discriminator_probabilities"
            )

            # Threshold probabilities into hard 0/1 class predictions.
            # Accuracy tests predictions for equality with the labels, and
            # precision and recall cast predictions to bool, so feeding them
            # raw probabilities gives meaningless values: any nonzero
            # probability counts as a positive and floats almost never
            # equal the labels exactly.
            discriminator_predictions = tf.cast(
                x=tf.greater(x=discriminator_probabilities, y=0.5),
                dtype=tf.float32,
                name="discriminator_class_predictions"
            )

            # Create eval metric ops dictionary. The AUC metrics sweep their
            # own thresholds and therefore take the probabilities.
            eval_metric_ops = {
                "accuracy": tf.metrics.accuracy(
                    labels=discriminator_labels,
                    predictions=discriminator_predictions,
                    name="pgan_model_accuracy"
                ),
                "precision": tf.metrics.precision(
                    labels=discriminator_labels,
                    predictions=discriminator_predictions,
                    name="pgan_model_precision"
                ),
                "recall": tf.metrics.recall(
                    labels=discriminator_labels,
                    predictions=discriminator_predictions,
                    name="pgan_model_recall"
                ),
                "auc_roc": tf.metrics.auc(
                    labels=discriminator_labels,
                    predictions=discriminator_probabilities,
                    num_thresholds=200,
                    curve="ROC",
                    name="pgan_model_auc_roc"
                ),
                "auc_pr": tf.metrics.auc(
                    labels=discriminator_labels,
                    predictions=discriminator_probabilities,
                    num_thresholds=200,
                    curve="PR",
                    name="pgan_model_auc_pr"
                )
            }

    # Return EstimatorSpec
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions_dict,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs
    )


Overwriting pgan_module/trainer/pgan.py


## serving.py

In [6]:
%%writefile pgan_module/trainer/serving.py
import tensorflow as tf

from .print_object import print_obj


def serving_input_fn(params):
    """Builds the input receiver used when the model is exported for serving.

    Args:
        params: dict, user passed parameters. Only `params["latent_size"]`
            is read here; it sets the second dimension of the placeholder.

    Returns:
        ServingInputReceiver object whose receiver tensors are the raw
        placeholders and whose features are identity copies of them.
    """
    fn_name = "serving_input_fn"

    # Placeholder that receives the data sent to the model at serving time.
    # shape = (None, latent_size); the leading dimension is left unspecified.
    latent_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, params["latent_size"]],
        name="serving_input_placeholder_Z"
    )
    feature_placeholders = {"Z": latent_placeholder}
    print_obj(fn_name, "feature_placeholders", feature_placeholders)

    # Route every placeholder through an identity op so the features handed
    # to the model are distinct tensors from the receiver placeholders.
    features = {}
    for key, placeholder in feature_placeholders.items():
        features[key] = tf.identity(
            input=placeholder,
            name="serving_input_fn_identity_placeholder_{}".format(key)
        )
    print_obj(fn_name, "features", features)

    return tf.estimator.export.ServingInputReceiver(
        features=features, receiver_tensors=feature_placeholders
    )


Overwriting pgan_module/trainer/serving.py


## model.py

In [7]:
%%writefile pgan_module/trainer/model.py
import tensorflow as tf

from . import input
from . import serving
from . import pgan


def train_and_evaluate(args):
    """Builds the custom Estimator and runs its train-and-evaluate loop.

    Args:
        args: dict, user passed parameters. It is forwarded unchanged as the
            Estimator `params` and to the input and serving functions.

    Returns:
        None. The result of `tf.estimator.train_and_evaluate` is not
        propagated to the caller.
    """
    # Log at INFO level.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Custom estimator wrapping the PGAN model function.
    estimator = tf.estimator.Estimator(
        model_fn=pgan.pgan_model, model_dir=args["output_dir"], params=args
    )

    # Training spec.
    # NOTE(review): presumably read_dataset returns an input function rather
    # than a dataset -- confirm against input.py.
    train_input_fn = input.read_dataset(
        filename=args["train_file_pattern"],
        mode=tf.estimator.ModeKeys.TRAIN,
        batch_size=args["train_batch_size"],
        params=args
    )
    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn, max_steps=args["train_steps"]
    )

    # Exporter that writes the complete model to disk.
    def serving_input_receiver_fn():
        return serving.serving_input_fn(args)

    exporter = tf.estimator.LatestExporter(
        name="exporter", serving_input_receiver_fn=serving_input_receiver_fn
    )

    # Evaluation spec, which also triggers the export above.
    eval_input_fn = input.read_dataset(
        filename=args["eval_file_pattern"],
        mode=tf.estimator.ModeKeys.EVAL,
        batch_size=args["eval_batch_size"],
        params=args
    )
    eval_spec = tf.estimator.EvalSpec(
        input_fn=eval_input_fn,
        steps=args["eval_steps"],
        start_delay_secs=args["start_delay_secs"],
        throttle_secs=args["throttle_secs"],
        exporters=exporter
    )

    # Run the combined train and evaluate loop.
    tf.estimator.train_and_evaluate(
        estimator=estimator, train_spec=train_spec, eval_spec=eval_spec
    )


Overwriting pgan_module/trainer/model.py


## task.py

In [8]:
%%writefile pgan_module/trainer/task.py
import argparse
import json
import os

from . import model


def calc_generator_discriminator_conv_layer_properties(
        conv_num_filters, conv_kernel_sizes, conv_strides, depth):
    """Derives conv layer properties for the generator and discriminator.

    Growth-block and toRGB layers are described by a flat list of the form
    [kernel, kernel, in_channels, out_channels, stride, stride]. Base-block
    layers are [kernel, kernel] + conv_num_filters[0] + [stride, stride], so
    they only have that six-element form when the first filter list holds
    exactly two entries.

    Args:
        conv_num_filters: list, nested list of ints of the number of filters
            for each conv layer, one inner list per block.
        conv_kernel_sizes: list, nested list of ints of the kernel sizes for
            each conv layer.
        conv_strides: list, nested list of ints of the strides for each conv
            layer.
        depth: int, depth dimension of images.

    Returns:
        Tuple of (generator, discriminator) nested lists of conv layer
            properties. The discriminator mirrors the generator: blocks and
            layers are in reverse order with in/out channels swapped.
    """
    # --- Generator ---------------------------------------------------------
    # Base block: every layer embeds the whole first filter list between its
    # kernel pair and its stride pair.
    base_filters = conv_num_filters[0]
    base_block = []
    for idx in range(len(base_filters)):
        base_block.append(
            [conv_kernel_sizes[0][idx]] * 2
            + base_filters
            + [conv_strides[0][idx]] * 2
        )
    generator = [base_block]

    # Growth blocks: each layer's input channels are the output channels of
    # the layer before it, starting from the last layer of the prior block.
    for level in range(1, len(conv_num_filters)):
        channels_in = generator[-1][-1][-3]
        growth_block = []
        for idx, channels_out in enumerate(conv_num_filters[level]):
            kernel = conv_kernel_sizes[level][idx]
            stride = conv_strides[level][idx]
            growth_block.append(
                [kernel, kernel, channels_in, channels_out, stride, stride]
            )
            channels_in = channels_out
        generator.append(growth_block)

    # Add toRGB conv: a 1x1, stride 1 layer at the end of the final block.
    final_block = generator[-1]
    final_block.append([1, 1, final_block[-1][-3], depth, 1, 1])

    # --- Discriminator -----------------------------------------------------
    # Walk the generator backwards, swapping input and output channels.
    discriminator = []
    for block in reversed(generator):
        mirrored_block = []
        for conv in reversed(block):
            swapped_channels = list(reversed(conv[2:4]))
            mirrored_block.append(conv[:2] + swapped_channels + conv[-2:])
        discriminator.append(mirrored_block)

    return generator, discriminator


def split_up_generator_conv_layer_properties(
        generator, num_filters, strides, depth):
    """Separates generator conv layer properties into three groups.

    Args:
        generator: list, nested list of conv layer properties for
            generator; its final block ends with the toRGB layer.
        num_filters: list, nested list of ints of the number of filters
            for each conv layer.
        strides: list, nested list of ints of the strides for each conv
            layer.
        depth: int, depth dimension of images.

    Returns:
        Tuple of (base conv blocks, growth conv blocks, toRGB layers), each
            a nested list of conv layer properties.
    """
    # The base group is just the first block.
    base_conv_blocks = [generator[0]]

    # Growth group: the blocks strictly between first and last, followed by
    # the last block with its trailing layer removed.
    last_block_without_rgb = generator[-1][:-1]
    growth_conv_blocks = generator[1:-1] + [last_block_without_rgb]

    # One single-layer 1x1 block per entry of num_filters, mapping that
    # entry's first filter count to `depth` with that entry's first stride.
    to_rgb_layers = []
    for level, filters in enumerate(num_filters):
        to_rgb_layers.append(
            [[1, 1, filters[0], depth] + [strides[level][0]] * 2]
        )

    return base_conv_blocks, growth_conv_blocks, to_rgb_layers


def split_up_discriminator_conv_layer_properties(
        discriminator, num_filters, strides, depth):
    """Separates discriminator conv layer properties into three groups.

    Args:
        discriminator: list, nested list of conv layer properties for
            discriminator; its first block starts with the fromRGB layer.
        num_filters: list, nested list of ints of the number of filters
            for each conv layer.
        strides: list, nested list of ints of the strides for each conv
            layer.
        depth: int, depth dimension of images.

    Returns:
        Tuple of (fromRGB layers, base conv blocks, growth conv blocks),
            each a nested list of conv layer properties.
    """
    # One single-layer 1x1 block per entry of num_filters, mapping `depth`
    # to that entry's first filter count with that entry's first stride.
    from_rgb_layers = []
    for level, filters in enumerate(num_filters):
        from_rgb_layers.append(
            [[1, 1, depth, filters[0]] + [strides[level][0]] * 2]
        )

    # The base group is the final block of the discriminator.
    base_conv_blocks = [discriminator[-1]]

    # Growth group, emitted in reverse of discriminator order: the middle
    # blocks back to front, then the first block minus its leading layer.
    first_block_without_rgb = discriminator[0][1:]
    growth_conv_blocks = discriminator[1:-1][::-1] + [first_block_without_rgb]

    return from_rgb_layers, base_conv_blocks, growth_conv_blocks


if __name__ == "__main__":
    # Command line entry point: parse the flags, convert the string-encoded
    # ones into Python values, derive the conv layer property lists, and
    # hand the resulting dict to the training loop.
    parser = argparse.ArgumentParser()
    # File arguments.
    parser.add_argument(
        "--train_file_pattern",
        help="GCS location to read training data.",
        required=True
    )
    parser.add_argument(
        "--eval_file_pattern",
        help="GCS location to read evaluation data.",
        required=True
    )
    parser.add_argument(
        "--output_dir",
        help="GCS location to write checkpoints and export models.",
        required=True
    )
    parser.add_argument(
        "--job-dir",
        help="This model ignores this field, but it is required by gcloud.",
        default="junk"
    )

    # Training parameters.
    parser.add_argument(
        "--train_batch_size",
        help="Number of examples in training batch.",
        type=int,
        default=32
    )
    parser.add_argument(
        "--train_steps",
        help="Number of steps to train for.",
        type=int,
        default=100
    )

    # Eval parameters.
    parser.add_argument(
        "--eval_batch_size",
        help="Number of examples in evaluation batch.",
        type=int,
        default=32
    )
    parser.add_argument(
        "--eval_steps",
        help="Number of steps to evaluate for.",
        type=str,
        default="None"
    )
    parser.add_argument(
        "--start_delay_secs",
        help="Number of seconds to wait before first evaluation.",
        type=int,
        default=60
    )
    parser.add_argument(
        "--throttle_secs",
        help="Number of seconds to wait between evaluations.",
        type=int,
        default=120
    )

    # Image parameters.
    parser.add_argument(
        "--height",
        help="Height of image.",
        type=int,
        default=32
    )
    parser.add_argument(
        "--width",
        help="Width of image.",
        type=int,
        default=32
    )
    parser.add_argument(
        "--depth",
        help="Depth of image.",
        type=int,
        default=3
    )

    # Shared parameters.
    parser.add_argument(
        "--num_steps_until_growth",
        help="Number of steps until layer added to generator & discriminator.",
        type=int,
        default=100
    )
    parser.add_argument(
        "--conv_num_filters",
        help="Number of filters for growth conv layers.",
        type=str,
        default="512,512;512,512"
    )
    parser.add_argument(
        "--conv_kernel_sizes",
        help="Kernel sizes for growth conv layers.",
        type=str,
        default="3,3;3,3"
    )
    parser.add_argument(
        "--conv_strides",
        help="Strides for growth conv layers.",
        type=str,
        default="1,1;1,1"
    )

    # Generator parameters.
    parser.add_argument(
        "--latent_size",
        help="The latent size of the noise vector.",
        type=int,
        default=3
    )
    parser.add_argument(
        "--generator_projection_dims",
        help="The 3D dimensions to project latent noise vector into.",
        type=str,
        default="8,8,256"
    )
    parser.add_argument(
        "--generator_l1_regularization_scale",
        help="Scale factor for L1 regularization for generator.",
        type=float,
        default=0.0
    )
    parser.add_argument(
        "--generator_l2_regularization_scale",
        help="Scale factor for L2 regularization for generator.",
        type=float,
        default=0.0
    )
    parser.add_argument(
        "--generator_optimizer",
        help="Name of optimizer to use for generator.",
        type=str,
        default="Adam"
    )
    parser.add_argument(
        "--generator_learning_rate",
        help="How quickly we train our model by scaling the gradient for generator.",
        type=float,
        default=0.1
    )
    parser.add_argument(
        "--generator_clip_gradients",
        help="Global clipping to prevent gradient norm to exceed this value for generator.",
        type=str,
        default="None"
    )
    parser.add_argument(
        "--generator_train_steps",
        help="Number of steps to train generator for per cycle.",
        type=int,
        default=100
    )

    # Discriminator parameters.
    parser.add_argument(
        "--discriminator_l1_regularization_scale",
        help="Scale factor for L1 regularization for discriminator.",
        type=float,
        default=0.0
    )
    parser.add_argument(
        "--discriminator_l2_regularization_scale",
        help="Scale factor for L2 regularization for discriminator.",
        type=float,
        default=0.0
    )
    parser.add_argument(
        "--discriminator_optimizer",
        help="Name of optimizer to use for discriminator.",
        type=str,
        default="Adam"
    )
    parser.add_argument(
        "--discriminator_learning_rate",
        help="How quickly we train our model by scaling the gradient for discriminator.",
        type=float,
        default=0.1
    )
    parser.add_argument(
        "--discriminator_clip_gradients",
        help="Global clipping to prevent gradient norm to exceed this value for discriminator.",
        type=str,
        default="None"
    )
    parser.add_argument(
        "--discriminator_gradient_penalty_coefficient",
        help="Coefficient of gradient penalty for discriminator.",
        type=float,
        default=10.0
    )
    parser.add_argument(
        "--discriminator_train_steps",
        help="Number of steps to train discriminator for per cycle.",
        type=int,
        default=100
    )

    # Parse all arguments into a plain dict.
    args = parser.parse_args()
    arguments = vars(args)

    # Unused args provided by service. argparse stores `--job-dir` under the
    # dest `job_dir`, so that is the only key that can be present.
    arguments.pop("job_dir", None)

    # Fix eval steps: the literal string "None" means no step limit.
    arguments["eval_steps"] = (
        None if arguments["eval_steps"] == "None"
        else int(arguments["eval_steps"])
    )

    # Fix generator_projection_dims: "a,b,c" -> [a, b, c].
    arguments["generator_projection_dims"] = [
        int(x)
        for x in arguments["generator_projection_dims"].split(",")
    ]

    # Fix conv layer property parameters: ";" separates blocks and ","
    # separates the layers within a block, e.g. "3,3;3,3" -> [[3, 3], [3, 3]].
    conv_keys = ("conv_num_filters", "conv_kernel_sizes", "conv_strides")
    for key in conv_keys:
        arguments[key] = [
            [int(y) for y in x.split(",")]
            for x in arguments[key].split(";")
        ]

    # Validate that all three flags describe the same number of blocks.
    # parser.error is used instead of assert so the check still runs under
    # `python -O` and the user gets a usage message rather than a traceback.
    # No emptiness check is needed: str.split with a separator always yields
    # at least one piece and int() rejects an empty piece.
    num_blocks = len(arguments["conv_num_filters"])
    for key in ("conv_kernel_sizes", "conv_strides"):
        if len(arguments[key]) != num_blocks:
            parser.error(
                "--{} describes {} block(s) but --conv_num_filters "
                "describes {}.".format(key, len(arguments[key]), num_blocks)
            )

    # Truncate lists if over the 1024x1024 current limit.
    # NOTE(review): the guard fires above 9 blocks but the slice keeps 10;
    # existing behavior is preserved here -- confirm the intended limit.
    if num_blocks > 9:
        for key in conv_keys:
            arguments[key] = arguments[key][0:10]

    # Get conv layer properties for generator and discriminator.
    (generator,
     discriminator) = calc_generator_discriminator_conv_layer_properties(
        arguments["conv_num_filters"],
        arguments["conv_kernel_sizes"],
        arguments["conv_strides"],
        arguments["depth"]
    )

    # Split up generator properties into separate lists.
    (arguments["generator_base_conv_blocks"],
     arguments["generator_growth_conv_blocks"],
     arguments["generator_to_rgb_layers"]) = (
        split_up_generator_conv_layer_properties(
            generator,
            arguments["conv_num_filters"],
            arguments["conv_strides"],
            arguments["depth"]
        )
    )

    # Split up discriminator properties into separate lists.
    (arguments["discriminator_from_rgb_layers"],
     arguments["discriminator_base_conv_blocks"],
     arguments["discriminator_growth_conv_blocks"]) = (
        split_up_discriminator_conv_layer_properties(
            discriminator,
            arguments["conv_num_filters"],
            arguments["conv_strides"],
            arguments["depth"]
        )
    )

    # Fix clip_gradients: the literal string "None" disables clipping.
    for key in ("generator_clip_gradients", "discriminator_clip_gradients"):
        arguments[key] = (
            None if arguments[key] == "None" else float(arguments[key])
        )

    # Append trial_id to path if we are doing hptuning.
    # This code can be removed if you are not using hyperparameter tuning.
    # With no TF_CONFIG (or no trial in it) an empty string is joined.
    trial_id = json.loads(
        os.environ.get("TF_CONFIG", "{}")
    ).get("task", {}).get("trial", "")
    arguments["output_dir"] = os.path.join(arguments["output_dir"], trial_id)

    # Run the training job.
    model.train_and_evaluate(arguments)


Overwriting pgan_module/trainer/task.py
