## resnet.py

In [None]:
%%writefile pca_out_of_core_distributed_module/trainer/resnet.py
import tensorflow as tf


class ResNet(object):
    """Frozen ResNet50 feature extractor for batches of images.

    Attributes:
        params: dict, user passed parameters.
        resnet_model: tf.keras.Model, ResNet50 cut off at the layer named by
            `params["resnet_layer_name"]`.
        pooling_layer: tf.keras.layers.GlobalAvgPool2D, pooling layer applied
            to the output of `resnet_model`.
    """
    def __init__(self, params):
        """Stores the parameters and builds the ResNet layers.

        Args:
            params: dict, user passed parameters.
        """
        self.params = params

        image_shape = (
            self.params["image_height"],
            self.params["image_width"],
            self.params["image_depth"]
        )
        layers = self.get_resnet_layers(input_shape=image_shape)
        self.resnet_model, self.pooling_layer = layers

    def get_resnet_layers(self, input_shape):
        """Builds the truncated ResNet50 model and its pooling layer.

        Args:
            input_shape: tuple, input shape of images.

        Returns:
            Tuple of the truncated `tf.keras.Model` and a
                `tf.keras.layers.GlobalAvgPool2D` layer.
        """
        # Headless ResNet50 whose weights are never updated.
        backbone = tf.keras.applications.resnet50.ResNet50(
            include_top=False,
            weights=self.params["resnet_weights"],
            input_shape=input_shape
        )
        backbone.trainable = False

        # Stop the network at the user chosen layer.
        cut_layer = backbone.get_layer(self.params["resnet_layer_name"])
        truncated_model = tf.keras.Model(
            inputs=backbone.input, outputs=cut_layer.output
        )

        # Global average pooling is applied after the truncated model.
        return truncated_model, tf.keras.layers.GlobalAvgPool2D()

    def preprocess_image_batch(self, images):
        """Casts images to float32 and optionally applies ResNet50 preprocessing.

        Args:
            images: tensor, rank 4 image tensor of shape
                (batch_size, image_height, image_width, image_depth).

        Returns:
            Preprocessed images tensor.
        """
        float_images = tf.cast(x=images, dtype=tf.float32)

        # Keras preprocessing is opt-in through `params["preprocess_input"]`.
        if not self.params["preprocess_input"]:
            return float_images

        return tf.keras.applications.resnet50.preprocess_input(x=float_images)

    def get_image_resnet_feature_vectors(self, images):
        """Runs images through preprocessing, ResNet, and pooling.

        Args:
            images: tensor, rank 4 image tensor of shape
                (batch_size, image_height, image_width, image_depth).

        Returns:
            Pooled ResNet feature vector for each image.
        """
        feature_maps = self.resnet_model(
            inputs=self.preprocess_image_batch(images=images)
        )

        return self.pooling_layer(inputs=feature_maps)


## training_inputs.py

In [None]:
%%writefile pca_out_of_core_distributed_module/trainer/training_inputs.py
import tensorflow as tf


def decode_example(protos, params):
    """Decodes a serialized tf.Example into an image tensor.

    Args:
        protos: serialized tf.Example protobuf from a TFRecord file.
        params: dict, user passed parameters. Reads the keys
            "image_encoding", "tf_record_example_schema",
            "image_feature_name", "image_height", "image_width" and
            "image_depth".

    Returns:
        Image tensor reshaped to (image_height, image_width, image_depth).

    Raises:
        ValueError: if `params["image_encoding"]` is not one of "raw",
            "png" or "jpeg".
    """
    # Fail fast on an unknown encoding; otherwise `image` would never be
    # bound below and the error would surface as an obscure name error.
    image_encoding = params["image_encoding"]
    supported_encodings = ("raw", "png", "jpeg")
    if image_encoding not in supported_encodings:
        raise ValueError(
            "Unsupported image_encoding {!r}; expected one of {}.".format(
                image_encoding, supported_encodings
            )
        )

    dtype_map = {
        "str": tf.string,
        "int": tf.int64,
        "float": tf.float32
    }

    def _feature_spec(feat):
        """Builds the parsing spec for one schema entry."""
        shape = feat["shape"]
        dtype = dtype_map[feat["dtype"]]
        if feat["type"] == "FixedLen":
            return tf.io.FixedLenFeature(shape=shape, dtype=dtype)

        # Parsing a tf.Example (as opposed to a SequenceExample) with a
        # FixedLenSequenceFeature requires allow_missing=True.
        return tf.io.FixedLenSequenceFeature(
            shape=shape, dtype=dtype, allow_missing=True
        )

    # Create feature schema map for protos.
    tf_example_features = {
        feat["name"]: _feature_spec(feat)
        for feat in params["tf_record_example_schema"]
    }

    # Parse features from tf.Example.
    parsed_features = tf.io.parse_single_example(
        serialized=protos, features=tf_example_features
    )
    encoded_image = parsed_features[params["image_feature_name"]]

    if image_encoding == "raw":
        # Convert from a scalar string tensor (whose single string has
        # length height * width * depth) to a uint8 tensor with shape
        # [height * width * depth].
        image = tf.io.decode_raw(
            input_bytes=encoded_image, out_type=tf.uint8
        )
    elif image_encoding == "png":
        image = tf.io.decode_png(
            contents=encoded_image, channels=params["image_depth"]
        )
    else:
        # Only "jpeg" remains after the validation above.
        image = tf.io.decode_jpeg(
            contents=encoded_image, channels=params["image_depth"]
        )

    # Reshape (possibly flattened) image into its configured dimensions.
    image = tf.reshape(
        tensor=image,
        shape=[
            params["image_height"],
            params["image_width"],
            params["image_depth"]
        ]
    )

    return image


def read_dataset(file_pattern, batch_size, params):
    """Reads TF Record data using tf.data, doing necessary preprocessing.

    Given a file pattern, batch size, and other parameters, builds an input
    function that reads TFRecord files with the Dataset API, decodes each
    example with `decode_example`, and batches and prefetches the result.

    Args:
        file_pattern: str, file pattern to read into our tf.data dataset.
        batch_size: int, number of examples per batch.
        params: dict, dictionary of user passed parameters. Reads the key
            "input_fn_autotune" here; the dict is also forwarded to
            `decode_example`.

    Returns:
        An input function that takes no arguments and returns the batched
            dataset.
    """
    def fetch_dataset(filename):
        """Fetches TFRecord Dataset from given filename.

        Args:
            filename: str, name of TFRecord file.

        Returns:
            Dataset containing TFRecord Examples.
        """
        buffer_size = 8 * 1024 * 1024  # 8 MiB per file
        dataset = tf.data.TFRecordDataset(
            filenames=filename, buffer_size=buffer_size
        )

        return dataset

    def _input_fn():
        """Wrapper input function that builds the batched dataset.

        Returns:
            Batched dataset object of decoded example tensors.
        """
        use_autotune = params["input_fn_autotune"]

        # Create dataset to contain list of files matching pattern.
        dataset = tf.data.Dataset.list_files(
            file_pattern=file_pattern, shuffle=False
        )

        # Parallel interleaves multiple files at once with map function.
        dataset = dataset.apply(
            tf.data.experimental.parallel_interleave(
                map_func=fetch_dataset, cycle_length=64, sloppy=True
            )
        )

        # Decode TF Record Example into tensors. `tf.contrib` does not exist
        # in TF 2.x, so AUTOTUNE comes from `tf.data.experimental`, the same
        # constant used for prefetching below.
        dataset = dataset.map(
            map_func=lambda x: decode_example(
                protos=x, params=params
            ),
            num_parallel_calls=(
                tf.data.experimental.AUTOTUNE if use_autotune else None
            )
        )

        # Batch dataset. The remainder is kept, so the final batch may be
        # smaller than batch_size.
        dataset = dataset.batch(batch_size=batch_size, drop_remainder=False)

        # Prefetch data to improve latency.
        dataset = dataset.prefetch(
            buffer_size=(
                tf.data.experimental.AUTOTUNE if use_autotune else 1
            )
        )

        return dataset

    return _input_fn


## covariance.py

In [None]:
%%writefile pca_out_of_core_distributed_module/trainer/covariance.py
import tensorflow as tf


class CovarianceMatrix(object):
    """Class that batch updates covariance matrix.

    Attributes:
        params: dict, user passed parameters.
        rollover_singleton_example: tensor or None, the first example seen
            when batches arrive one example at a time; it is held back and
            combined with the second example.
        seen_example_count: tf.Variable, rank 0 of shape () containing
            the count of the number of examples seen so far.
        col_means_vector: tf.Variable, rank 1 of shape (num_cols,) containing
            column means.
        covariance_matrix: tf.Variable, rank 2 of shape (num_cols, num_cols)
            containing covariance matrix.
    """
    def __init__(self, params):
        """Initializes `CovarianceMatrix` class instance.

        Args:
            params: dict, user passed parameters.
        """
        self.params = params

        # Used by `singleton_batch_update`; defined here so the attribute
        # always exists on this class, not only on subclasses.
        self.rollover_singleton_example = None

        self.seen_example_count = tf.Variable(
            initial_value=tf.zeros(shape=(), dtype=tf.int64), trainable=False
        )
        self.col_means_vector = tf.Variable(
            initial_value=tf.zeros(
                shape=(self.params["num_cols"],), dtype=tf.float32
            ),
            trainable=False
        )
        self.covariance_matrix = tf.Variable(
            initial_value=tf.zeros(
                shape=(self.params["num_cols"], self.params["num_cols"]),
                dtype=tf.float32
            ),
            trainable=False
        )

    @tf.function
    def assign_seen_example_count(self, seen_example_count):
        """Assigns seen example count tf.Variable.

        Args:
            seen_example_count: tensor, rank 0 of shape () containing
                the count of the number of examples seen so far.
        """
        self.seen_example_count.assign(value=seen_example_count)

    @tf.function
    def assign_col_means_vector(self, col_means_vector):
        """Assigns column means vector tf.Variable.

        Args:
            col_means_vector: tensor, rank 1 of shape (num_cols,) containing
                column means.
        """
        self.col_means_vector.assign(value=col_means_vector)

    @tf.function
    def assign_covariance_matrix(self, covariance_matrix):
        """Assigns covariance matrix tf.Variable.

        Args:
            covariance_matrix: tensor, rank 2 of shape (num_cols, num_cols)
                containing covariance matrix.
        """
        self.covariance_matrix.assign(value=covariance_matrix)

    def update_example_count(self, count_a, count_b):
        """Updates the running number of examples processed.

        Given previous running total and current batch size, return new
        running total.

        Args:
            count_a: tensor, tf.int64 rank 0 tensor of previous running total
                of examples.
            count_b: tensor, tf.int64 rank 0 tensor of current batch size.

        Returns:
            A tf.int64 rank 0 tensor of new running total of examples.
        """
        return count_a + count_b

    def update_mean_incremental(self, count_a, mean_a, value_b):
        """Updates the running mean vector incrementally.

        Given previous running total, running column means, and single
        example's column values, return new running column means.

        Args:
            count_a: tensor, tf.int64 rank 0 tensor of previous running total
                of examples.
            mean_a: tensor, tf.float32 rank 1 tensor of previous running column
                means.
            value_b: tensor, tf.float32 rank 2 tensor of shape (1, num_cols)
                holding the single example's column values.

        Returns:
            A tf.float32 rank 1 tensor of new running column means.
        """
        # Recover the running column sums, add the new example, re-average.
        umean_a = mean_a * tf.cast(x=count_a, dtype=tf.float32)
        mean_ab_num = umean_a + tf.squeeze(input=value_b, axis=0)
        mean_ab = mean_ab_num / tf.cast(x=count_a + 1, dtype=tf.float32)

        return mean_ab

    def update_covariance_incremental(
        self, count_a, mean_a, cov_a, value_b, mean_ab, use_sample_covariance
    ):
        """Updates the running covariance matrix incrementally.

        Given previous running total, running column means, running covariance
        matrix, single example's column values, new running column means, and
        whether to use sample covariance or not, return new running covariance
        matrix.

        Args:
            count_a: tensor, tf.int64 rank 0 tensor of previous running total
                of examples.
            mean_a: tensor, tf.float32 rank 1 tensor of previous running column
                means.
            cov_a: tensor, tf.float32 rank 2 tensor of previous running
                covariance matrix.
            value_b: tensor, tf.float32 rank 2 tensor of shape (1, num_cols)
                holding the single example's column values.
            mean_ab: tensor, tf.float32 rank 1 tensor of new running column
                means.
            use_sample_covariance: bool, flag on whether sample or population
                covariance is used.

        Returns:
            A tf.float32 rank 2 tensor of new covariance matrix.
        """
        # Outer product of the deviations from the old and the new mean.
        # shape = (num_cols, num_cols)
        mean_diff = tf.matmul(
            a=value_b - mean_a, b=value_b - mean_ab, transpose_a=True
        )

        # Un-normalize the running covariance, add the new term, and
        # re-normalize for the new example count (count_a + 1).
        if use_sample_covariance:
            ucov_a = cov_a * tf.cast(x=count_a - 1, dtype=tf.float32)
            cov_ab_denominator = tf.cast(x=count_a, dtype=tf.float32)
        else:
            ucov_a = cov_a * tf.cast(x=count_a, dtype=tf.float32)
            cov_ab_denominator = tf.cast(x=count_a + 1, dtype=tf.float32)
        cov_ab_numerator = ucov_a + mean_diff
        cov_ab = cov_ab_numerator / cov_ab_denominator

        return cov_ab

    def singleton_batch_update(
        self,
        X,
        running_count,
        running_mean,
        running_covariance,
        use_sample_covariance
    ):
        """Updates running tensors incrementally when batch_size equals 1.

        Given the data vector X, the tensor tracking running example
        counts, the tensor tracking running column means, and the tensor
        tracking running covariance matrix, returns updated running example
        count tensor, column means tensor, and covariance matrix tensor.

        Args:
            X: tensor, tf.float32 rank 2 tensor of input data.
            running_count: tensor, tf.int64 rank 0 tensor tracking running
                example counts.
            running_mean: tensor, tf.float32 rank 1 tensor tracking running
                column means.
            running_covariance: tensor, tf.float32 rank 2 tensor tracking
                running covariance matrix.
            use_sample_covariance: bool, flag on whether sample or population
                covariance is used.

        Returns:
            Updated running example count tensor, column means tensor,
                and covariance matrix tensor.
        """
        if running_count == 0:
            # Would produce NaNs, so rollover example for next iteration.
            self.rollover_singleton_example = X

            # Update count though so that we don't end up in this block again.
            count = self.update_example_count(
                count_a=running_count, count_b=1
            )

            # No need to update mean or covariance this iteration
            mean = running_mean
            covariance = running_covariance
        elif running_count == 1:
            # Batch update since we're combining previous & current batches.
            # running_count is passed as 0 because the rolled over example
            # has not been folded into the running mean or covariance yet.
            count, mean, covariance = self.non_singleton_batch_update(
                batch_size=2,
                X=tf.concat(
                    values=[self.rollover_singleton_example, X], axis=0
                ),
                running_count=0,
                running_mean=running_mean,
                running_covariance=running_covariance,
                use_sample_covariance=use_sample_covariance
            )
        else:
            # Calculate new combined mean for incremental covariance matrix.
            # shape = (num_cols,)
            mean = self.update_mean_incremental(
                count_a=running_count, mean_a=running_mean, value_b=X
            )

            # Update running tensors from single example
            # shape = ()
            count = self.update_example_count(
                count_a=running_count, count_b=1
            )

            # shape = (num_cols, num_cols)
            covariance = self.update_covariance_incremental(
                count_a=running_count,
                mean_a=running_mean,
                cov_a=running_covariance,
                value_b=X,
                mean_ab=mean,
                use_sample_covariance=use_sample_covariance
            )

        return count, mean, covariance

    def update_mean_batch(self, count_a, mean_a, count_b, mean_b):
        """Updates the running mean vector with a batch of data.

        Given previous running example count, running column means, current
        batch size, and batch's column means, return new running column means.

        Args:
            count_a: tensor, tf.int64 rank 0 tensor of previous running total
                of examples.
            mean_a: tensor, tf.float32 rank 1 tensor of previous running column
                means.
            count_b: tensor, tf.int64 rank 0 tensor of current batch size.
            mean_b: tensor, tf.float32 rank 1 tensor of batch's column means.

        Returns:
            A tf.float32 rank 1 tensor of new running column means.
        """
        # Count-weighted average of the two means.
        sum_a = mean_a * tf.cast(x=count_a, dtype=tf.float32)
        sum_b = mean_b * tf.cast(x=count_b, dtype=tf.float32)
        mean_ab_denominator = tf.cast(x=count_a + count_b, dtype=tf.float32)
        mean_ab = (sum_a + sum_b) / mean_ab_denominator

        return mean_ab

    def update_covariance_batch(
        self,
        count_a,
        mean_a,
        cov_a,
        count_b,
        mean_b,
        cov_b,
        use_sample_covariance
    ):
        """Updates the running covariance matrix with batch of data.

        Given previous running example count, column means, and
        covariance matrix, current batch size, column means, and covariance
        matrix, and whether to use sample covariance or not, return new running
        covariance matrix.

        Both `cov_a` and `cov_b` must be normalized with the convention
        selected by `use_sample_covariance` (count - 1 for sample, count for
        population), because that is how they are un-normalized here.

        Args:
            count_a: tensor, tf.int64 rank 0 tensor of previous running total
                of examples.
            mean_a: tensor, tf.float32 rank 1 tensor of previous running column
                means.
            cov_a: tensor, tf.float32 rank 2 tensor of previous running
                covariance matrix.
            count_b: tensor, tf.int64 rank 0 tensor of current batch size.
            mean_b: tensor, tf.float32 rank 1 tensor of batch's column means.
            cov_b: tensor, tf.float32 rank 2 tensor of batch's covariance
                matrix.
            use_sample_covariance: bool, flag on whether sample or population
                covariance is used.

        Returns:
            A tf.float32 rank 2 tensor of new running covariance matrix.
        """
        # shape = (1, num_cols)
        mean_diff = tf.expand_dims(input=mean_a - mean_b, axis=0)

        if use_sample_covariance:
            ucov_a = cov_a * tf.cast(x=count_a - 1, dtype=tf.float32)
            ucov_b = cov_b * tf.cast(x=count_b - 1, dtype=tf.float32)
            den = tf.cast(x=count_a + count_b - 1, dtype=tf.float32)
        else:
            ucov_a = cov_a * tf.cast(x=count_a, dtype=tf.float32)
            ucov_b = cov_b * tf.cast(x=count_b, dtype=tf.float32)
            den = tf.cast(x=count_a + count_b, dtype=tf.float32)

        # Correction term for the shift between the two groups' means.
        # shape = (num_cols, num_cols)
        mean_diff = tf.matmul(a=mean_diff, b=mean_diff, transpose_a=True)
        mean_scaling_num = tf.cast(x=count_a * count_b, dtype=tf.float32)
        mean_scaling_den = tf.cast(x=count_a + count_b, dtype=tf.float32)
        mean_scaling = mean_scaling_num / mean_scaling_den
        cov_ab = (ucov_a + ucov_b + mean_diff * mean_scaling) / den

        return cov_ab

    def non_singleton_batch_update(
        self,
        batch_size,
        X,
        running_count,
        running_mean,
        running_covariance,
        use_sample_covariance
    ):
        """Updates running tensors when batch_size does NOT equal 1.

        Given the current batch size, the data matrix X, the tensor tracking
        running example counts, the tensor tracking running column means, and
        the tensor tracking running covariance matrix, returns updated running
        example count tensor, column means tensor, and covariance matrix
        tensor.

        Args:
            batch_size: int, number of examples in current batch (could be
                partial).
            X: tensor, tf.float32 rank 2 tensor of input data.
            running_count: tensor, tf.int64 rank 0 tensor tracking running
                example counts.
            running_mean: tensor, tf.float32 rank 1 tensor tracking running
                column means.
            running_covariance: tensor, tf.float32 rank 2 tensor tracking
                running covariance matrix.
            use_sample_covariance: bool, flag on whether sample or population
                covariance is used.

        Returns:
            Updated running example count tensor, column means tensor,
                and covariance matrix tensor.
        """
        # shape = (num_cols,)
        X_mean = tf.reduce_mean(input_tensor=X, axis=0)

        # shape = (batch_size, num_cols)
        X_centered = X - X_mean

        # shape = (num_cols, num_cols)
        X_cov = tf.matmul(
            a=X_centered,
            b=X_centered,
            transpose_a=True
        )

        # Normalize with the same convention `update_covariance_batch` uses
        # to un-normalize `cov_b`; dividing by batch_size - 1 in population
        # mode would inflate the batch's contribution by n / (n - 1).
        if use_sample_covariance:
            X_cov_denominator = batch_size - 1
        else:
            X_cov_denominator = batch_size
        X_cov /= tf.cast(x=X_cov_denominator, dtype=tf.float32)

        # Update running tensors from batch statistics.
        # shape = ()
        count = self.update_example_count(
            count_a=running_count, count_b=batch_size
        )

        # shape = (num_cols,)
        mean = self.update_mean_batch(
            count_a=running_count,
            mean_a=running_mean,
            count_b=batch_size,
            mean_b=X_mean
        )

        # shape = (num_cols, num_cols)
        covariance = self.update_covariance_batch(
            count_a=running_count,
            mean_a=running_mean,
            cov_a=running_covariance,
            count_b=batch_size,
            mean_b=X_mean,
            cov_b=X_cov,
            use_sample_covariance=use_sample_covariance
        )

        return count, mean, covariance

    def calculate_data_stats(self, data):
        """Calculates statistics of data and stores them in the variables.

        Updates `seen_example_count`, `col_means_vector` and
        `covariance_matrix` in place; nothing is returned.

        Args:
            data: tensor, rank 2 tensor of shape
                (current_batch_size, num_cols) containing batch of input data.
        """
        # Static batch dimension of the incoming tensor.
        current_batch_size = data.shape[0]

        if current_batch_size == 1:
            (seen_example_count,
             col_means_vector,
             covariance_matrix) = self.singleton_batch_update(
                X=data,
                running_count=self.seen_example_count,
                running_mean=self.col_means_vector,
                running_covariance=self.covariance_matrix,
                use_sample_covariance=self.params["use_sample_covariance"]
            )
        else:
            (seen_example_count,
             col_means_vector,
             covariance_matrix) = self.non_singleton_batch_update(
                batch_size=current_batch_size,
                X=data,
                running_count=self.seen_example_count,
                running_mean=self.col_means_vector,
                running_covariance=self.covariance_matrix,
                use_sample_covariance=self.params["use_sample_covariance"]
            )

        self.assign_seen_example_count(seen_example_count=seen_example_count)
        self.assign_col_means_vector(col_means_vector=col_means_vector)
        self.assign_covariance_matrix(covariance_matrix=covariance_matrix)


## pca.py

In [None]:
%%writefile pca_out_of_core_distributed_module/trainer/pca.py
import tensorflow as tf

from . import covariance


class PCA(covariance.CovarianceMatrix):
    """Projects data onto principal components and reconstructs it.

    Attributes:
        rollover_singleton_example: tensor, rank 2 tensor of shape
            (1, num_cols) containing a rollover singleton example in case the
            data batch size begins at 1. This avoids NaN covariances.
        eigenvalues: tf.Variable, rank 1 of shape (num_cols,) containing the
            eigenvalues of the covariance matrix.
        eigenvectors: tf.Variable, rank 2 of shape (num_cols, num_cols)
            containing the eigenvectors of the covariance matrix.
        top_k_eigenvectors: tensor, rank 2 tensor of shape
            (num_cols, top_k_pc) containing the eigenvectors associated with
            the top_k_pc eigenvalues.
    """
    def __init__(self, params):
        """Sets up the covariance state plus eigen-decomposition storage.

        Args:
            params: dict, user passed parameters.
        """
        super().__init__(params=params)
        self.params = params

        self.rollover_singleton_example = None

        num_cols = self.params["num_cols"]

        # Zero-initialized holders for the eigen-decomposition.
        zero_eigenvalues = tf.zeros(shape=(num_cols,), dtype=tf.float32)
        self.eigenvalues = tf.Variable(
            initial_value=zero_eigenvalues, trainable=False
        )

        zero_eigenvectors = tf.zeros(
            shape=(num_cols, num_cols), dtype=tf.float32
        )
        self.eigenvectors = tf.Variable(
            initial_value=zero_eigenvectors, trainable=False
        )

        self.top_k_eigenvectors = tf.zeros(
            shape=(num_cols, self.params["top_k_pc"])
        )

    @tf.function
    def assign_eigenvalues(self, eigenvalues):
        """Writes new values into the eigenvalues tf.Variable.

        Args:
            eigenvalues: tensor, rank 1 of shape (num_cols,) containing the
                eigenvalues of the covariance matrix.
        """
        self.eigenvalues.assign(value=eigenvalues)

    @tf.function
    def assign_eigenvectors(self, eigenvectors):
        """Writes new values into the eigenvectors tf.Variable.

        Args:
            eigenvectors: tensor, rank 2 of shape (num_cols, num_cols)
                containing the eigenvectors of the covariance matrix.
        """
        self.eigenvectors.assign(value=eigenvectors)

    def calculate_eigenvalues_and_eigenvectors(self):
        """Decomposes the covariance matrix and stores the result.
        """
        # shapes = (num_cols,) and (num_cols, num_cols)
        decomposition = tf.linalg.eigh(tensor=self.covariance_matrix)
        new_eigenvalues, new_eigenvectors = decomposition

        self.assign_eigenvalues(eigenvalues=new_eigenvalues)
        self.assign_eigenvectors(eigenvectors=new_eigenvectors)

    def pca_projection_to_top_k_pc(self, data):
        """Centers data and projects it onto the top_k principal components.

        Also refreshes `top_k_eigenvectors` from the current eigenvectors.

        Args:
            data: tensor, rank 2 tensor of shape (num_examples, num_cols)
                containing batch of input data.

        Returns:
            Rank 2 tensor of shape (num_examples, top_k_pc) containing
                projected centered data.
        """
        # Keep the trailing top_k_pc eigenvector columns.
        # shape = (num_cols, top_k_pc)
        top_k = self.params["top_k_pc"]
        self.top_k_eigenvectors = self.eigenvectors[:, -top_k:]

        # shape = (num_examples, top_k_pc)
        return tf.matmul(
            a=data - self.col_means_vector, b=self.top_k_eigenvectors
        )

    def pca_reconstruction_from_top_k_pc(self, data):
        """Projects data down to top_k components and maps it back up.

        Args:
            data: tensor, rank 2 tensor of shape (num_examples, num_cols)
                containing batch of input data.

        Returns:
            Rank 2 tensor of shape (num_examples, num_cols) containing
                lossy, reconstructed input data.
        """
        # shape = (num_examples, top_k_pc)
        low_dim_data = self.pca_projection_to_top_k_pc(data=data)

        # shape = (num_examples, num_cols)
        restored_centered_data = tf.matmul(
            a=low_dim_data, b=self.top_k_eigenvectors, transpose_b=True
        )

        # Undo the centering applied during projection.
        return restored_centered_data + self.col_means_vector


## checkpoints.py

In [None]:
%%writefile pca_out_of_core_distributed_module/trainer/checkpoints.py
import os
import tensorflow as tf


class Checkpoints(object):
    """Helpers for saving and restoring training checkpoints.

    The methods read `self.params`, `self.global_step` and `self.pca_model`,
    none of which are set in this class.
    """
    def __init__(self):
        """Instantiate instance of `Checkpoints`.
        """
        pass

    def create_checkpoint_manager(self):
        """Builds the manager that reads and writes checkpoints.
        """
        checkpoint_dir = os.path.join(
            self.params["output_dir"], "checkpoints"
        )

        self.checkpoint_manager = tf.train.CheckpointManager(
            checkpoint=self.checkpoint,
            directory=checkpoint_dir,
            max_to_keep=self.params["keep_checkpoint_max"],
            checkpoint_name="ckpt",
            step_counter=self.global_step,
            checkpoint_interval=self.params["save_checkpoints_steps"]
        )

    def create_checkpoint_machinery(self):
        """Builds the checkpoint, its manager, and restores the latest save.
        """
        # Objects tracked by the checkpoint.
        tracked_objects = {
            "global_step": self.global_step,
            "seen_example_count": self.pca_model.seen_example_count,
            "col_means_vector": self.pca_model.col_means_vector,
            "covariance_matrix": self.pca_model.covariance_matrix,
            "eigenvalues": self.pca_model.eigenvalues,
            "eigenvectors": self.pca_model.eigenvectors
        }
        self.checkpoint = tf.train.Checkpoint(**tracked_objects)

        # The manager needs `self.checkpoint`, so it is created second.
        self.create_checkpoint_manager()

        # Restore any prior checkpoints.
        latest_checkpoint = self.checkpoint_manager.latest_checkpoint
        print("Loading latest checkpoint: {}".format(latest_checkpoint))
        status = self.checkpoint.restore(save_path=latest_checkpoint)

        # Only verify the restore when there was something to restore.
        if latest_checkpoint:
            status.assert_consumed()


## train_step.py

In [None]:
%%writefile pca_out_of_core_distributed_module/trainer/train_step.py
import tensorflow as tf


class TrainStep(object):
    """Class that contains methods concerning train steps.

    Intended to be mixed into a trainer class. The methods read attributes
    that are not set here and must be provided by that class: `params`,
    `strategy`, `resnet_instance`, `pca_model`, `global_step`, and
    `checkpoint_manager`.
    """
    def __init__(self):
        """Instantiate instance of `TrainStep`.
        """
        pass

    def train_batch(self, features):
        """Trains model with a batch of feature data.

        Args:
            features: tensor, batch handed to ResNet as `images`
                (presumably a rank 4 image tensor; confirm against the
                input function).

        Returns:
            Scalar float32 loss tensor, which is always zero.
        """
        # Pass images through ResNet to get feature vectors.
        resnet_feature_vectors = (
            self.resnet_instance.get_image_resnet_feature_vectors(
                images=features
            )
        )

        # Train PCA model.
        self.pca_model.calculate_data_stats(data=resnet_feature_vectors)

        return tf.zeros(shape=(), dtype=tf.float32)

    def _distributed_train_step(self, features):
        """Runs `train_batch` on every replica and sums the replica losses.

        Shared by the eager and graph distributed train steps.

        Args:
            features: feature tensors from input function.

        Returns:
            Scalar loss of model, summed across replicas.
        """
        # Use `Strategy.run` whenever the strategy offers it and fall back to
        # its predecessor `experimental_run_v2` otherwise. Detecting the
        # method avoids comparing the float `tf_version` parameter, which
        # cannot tell e.g. version 2.10 apart from 2.1.
        run_function = getattr(self.strategy, "run", None)
        if run_function is None:
            run_function = self.strategy.experimental_run_v2

        per_replica_losses = run_function(
            fn=self.train_batch,
            kwargs={"features": features}
        )

        return self.strategy.reduce(
            reduce_op=tf.distribute.ReduceOp.SUM,
            value=per_replica_losses,
            axis=None
        )

    def distributed_eager_train_step(self, features):
        """Perform one distributed, eager train step.

        Args:
            features: feature tensors from input function.

        Returns:
            Scalar loss of model.
        """
        return self._distributed_train_step(features=features)

    def non_distributed_eager_train_step(self, features):
        """Perform one non-distributed, eager train step.

        Args:
            features: feature tensors from input function.

        Returns:
            Scalar loss of model.
        """
        return self.train_batch(features=features)

    @tf.function
    def distributed_graph_train_step(self, features):
        """Perform one distributed, graph train step.

        Args:
            features: feature tensors from input function.

        Returns:
            Scalar loss of model.
        """
        return self._distributed_train_step(features=features)

    @tf.function
    def non_distributed_graph_train_step(self, features):
        """Perform one non-distributed, graph train step.

        Args:
            features: feature tensors from input function.

        Returns:
            Scalar loss of model.
        """
        return self.train_batch(features=features)

    def get_train_step_functions(self):
        """Gets model train step functions for strategy and mode.

        Stores the chosen function on `self.train_step_fn`. The choice depends
        on whether `self.strategy` is set and on `params["use_graph_mode"]`.
        """
        if self.strategy:
            if self.params["use_graph_mode"]:
                self.train_step_fn = self.distributed_graph_train_step
            else:
                self.train_step_fn = self.distributed_eager_train_step
        else:
            if self.params["use_graph_mode"]:
                self.train_step_fn = self.non_distributed_graph_train_step
            else:
                self.train_step_fn = self.non_distributed_eager_train_step

    @tf.function
    def increment_global_step_var(self):
        """Increments global step variable by one.
        """
        self.global_step.assign_add(
            delta=tf.ones(shape=(), dtype=tf.int64)
        )

    def perform_training_step(self, train_dataset_iterator, train_step_fn):
        """Performs one training step of model.

        Stores the loss of the step on `self.loss`.

        Args:
            train_dataset_iterator: iterator, iterator of instance of
                `Dataset` for training data.
            train_step_fn: callable, trains the given model with a given
                set of features passed as the `features` keyword.
        """
        # Fetch the next batch of features.
        features = next(train_dataset_iterator)

        # Train for a step and get loss.
        self.loss = train_step_fn(features=features)

        # Let the manager decide, based on its checkpoint interval, whether
        # to write a checkpoint now.
        # NOTE(review): this runs before the step counter is incremented, so
        # a checkpoint written here records a `global_step` that does not yet
        # include the batch trained above - confirm that this is intended.
        checkpoint_saved = self.checkpoint_manager.save(
            checkpoint_number=self.global_step, check_interval=True
        )

        if checkpoint_saved:
            print("Checkpoint saved at {}".format(checkpoint_saved))

        # Increment steps.
        self.increment_global_step_var()


## training_loop.py

In [None]:
%%writefile pca_out_of_core_distributed_module/trainer/training_loop.py
import tensorflow as tf


class TrainingLoop(object):
    """Mixin that drives the training loop.

    Relies on members provided by the class that mixes this in: `params`,
    `global_batch_size`, `global_step`, `train_dataset_iterator`,
    `train_step_fn`, `get_train_step_functions`, `perform_training_step`,
    and `training_loop_end_save_model`.
    """
    def __init__(self):
        """Instantiate instance of `TrainingLoop`.
        """

    def training_loop(self):
        """Runs training steps until the step budget is reached, then saves.

        The budget is the number of whole global batches that fit into
        `params["train_dataset_length"]`.
        """
        # Pick the train step function matching the strategy and mode.
        self.get_train_step_functions()

        # Floor division: a trailing partial batch does not add a step.
        total_steps = (
            self.params["train_dataset_length"] // self.global_batch_size
        )

        while True:
            # The step counter may already be past zero on entry.
            if self.global_step.numpy() >= total_steps:
                break

            self.perform_training_step(
                train_dataset_iterator=self.train_dataset_iterator,
                train_step_fn=self.train_step_fn
            )

        # Training is done: hand over to the end-of-training save hook.
        self.training_loop_end_save_model()


## export.py

In [None]:
%%writefile pca_out_of_core_distributed_module/trainer/export.py
import datetime
import os
import tensorflow as tf


class Export(object):
    """Mixin with the methods that export the trained model.

    Relies on attributes provided by the class that mixes this in: `params`,
    `resnet_instance`, `pca_model`, `global_step`, and `checkpoint_manager`.
    """
    def __init__(self):
        """Instantiate instance of `Export`.
        """

    def create_serving_model(self):
        """Builds the Keras `Model` used for serving.

        The model takes uint8 images, turns them into ResNet feature vectors,
        and passes those through the PCA model's top-k projection.

        Returns:
            `tf.keras.Model` for serving predictions.
        """
        # Serving inputs are raw images of the configured size.
        image_shape = (
            self.params["image_height"],
            self.params["image_width"],
            self.params["image_depth"]
        )
        serving_inputs = tf.keras.Input(
            shape=image_shape, name="serving_inputs", dtype=tf.uint8
        )

        # Images -> ResNet feature vectors.
        feature_vectors = (
            self.resnet_instance.get_image_resnet_feature_vectors(
                images=serving_inputs
            )
        )

        # Feature vectors -> PCA projections.
        projections = self.pca_model.pca_projection_to_top_k_pc(
            data=feature_vectors
        )

        # The identity op only gives the output tensor a name.
        named_projections = tf.identity(
            input=projections, name="pca_projections"
        )

        return tf.keras.Model(
            inputs=serving_inputs,
            outputs=named_projections,
            name="serving_model"
        )

    def export_saved_model(self):
        """Writes a SavedModel of the serving model to the output directory.

        Each export goes into its own timestamped subdirectory of
        `<output_dir>/export`.
        """
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        export_path = os.path.join(
            self.params["output_dir"], "export", timestamp
        )

        # No signatures are passed to the save call.
        tf.saved_model.save(
            obj=self.create_serving_model(),
            export_dir=export_path
        )

    def training_loop_end_save_model(self):
        """Writes a final checkpoint and exports the model for serving.
        """
        # `check_interval=False` skips the manager's interval check.
        final_checkpoint = self.checkpoint_manager.save(
            checkpoint_number=self.global_step, check_interval=False
        )

        if final_checkpoint:
            print("Checkpoint saved at {}".format(final_checkpoint))

        self.export_saved_model()


## model.py

In [None]:
%%writefile pca_out_of_core_distributed_module/trainer/model.py
import os
import tensorflow as tf

from . import checkpoints
from . import export
from . import pca
from . import resnet
from . import train_step
from . import training_inputs
from . import training_loop


class TrainModel(
    checkpoints.Checkpoints,
    train_step.TrainStep,
    training_loop.TrainingLoop,
    export.Export
):
    """Class that trains a model.

    Attributes:
        params: dict, user passed parameters.
        resnet_instance: instance of `ResNet` class.
        pca_model: instance of `PCA` class.
        strategy: instance of tf.distribute.Strategy, or None when training
            is not distributed.
        global_batch_size: int, global batch size after summing batch sizes
            across replicas. Set by `train_model`.
        train_dataset_iterator: iterator, iterator of instance of `Dataset`
            for training data.
        train_step_fn: function, function for a train step using
            correct strategy and mode.
        global_step: tf.Variable, the global step counter.
        checkpoint: instance of tf.train.Checkpoint, for saving and restoring
            checkpoints.
        checkpoint_manager: instance of tf.train.CheckpointManager, for
            managing checkpoint path, how often to write, etc.
    """
    def __init__(self, params):
        """Instantiate trainer.

        Args:
            params: dict, user passed parameters.
        """
        super().__init__()
        self.params = params

        # Each sub-model only receives the parameters it needs.
        self.resnet_instance = resnet.ResNet(
            params={
                "image_height": self.params["image_height"],
                "image_width": self.params["image_width"],
                "image_depth": self.params["image_depth"],
                "resnet_weights": self.params["resnet_weights"],
                "resnet_layer_name": self.params["resnet_layer_name"],
                "preprocess_input": self.params["preprocess_input"]
            }
        )

        self.pca_model = pca.PCA(
            params={
                "num_cols": self.params["num_cols"],
                "use_sample_covariance": self.params["use_sample_covariance"],
                "top_k_pc": self.params["top_k_pc"]
            }
        )

        self.strategy = None

        # Placeholder; `train_model` replaces it with the integer batch size.
        self.global_batch_size = []

        self.train_dataset_iterator = None

        self.train_step_fn = None

        self.global_step = tf.Variable(
            initial_value=tf.zeros(shape=[], dtype=tf.int64),
            trainable=False,
            name="global_step"
        )

        self.checkpoint = None
        self.checkpoint_manager = None

    def get_train_dataset(self, num_replicas):
        """Gets train dataset.

        Args:
            num_replicas: int, number of device replicas.

        Returns:
            `tf.data.Dataset` for training data.
        """
        # `read_dataset` returns a callable; calling it yields the dataset.
        # The batch size passed in is the global one across all replicas.
        return training_inputs.read_dataset(
            file_pattern=self.params["train_file_pattern"],
            batch_size=self.params["train_batch_size"] * num_replicas,
            params=self.params
        )()

    def train_block(self, train_dataset):
        """Training block setups training, then loops through datasets.

        Args:
            train_dataset: instance of `Dataset` for training data.
        """
        # Create iterators of datasets.
        self.train_dataset_iterator = iter(train_dataset)

        # Create checkpoint machinery to save/restore checkpoints.
        self.create_checkpoint_machinery()

        # Run training loop.
        self.training_loop()

    def train_model(self):
        """Trains the model, distributed if a strategy is requested.

        Raises:
            ValueError: if `params["distribution_strategy"]` is non-empty,
                is not "Mirrored", and no strategy was set on the instance.
        """
        if self.params["distribution_strategy"]:
            # If the list of devices is not specified in the
            # Strategy constructor, it will be auto-detected.
            if self.params["distribution_strategy"] == "Mirrored":
                self.strategy = tf.distribute.MirroredStrategy()

            # Fail with a clear message instead of an AttributeError on None
            # when the requested strategy is not one this trainer creates.
            if self.strategy is None:
                raise ValueError(
                    "Unsupported distribution_strategy {!r}; only 'Mirrored' "
                    "is supported.".format(
                        self.params["distribution_strategy"]
                    )
                )

            num_replicas = self.strategy.num_replicas_in_sync
            print("Number of devices = {}".format(num_replicas))

            # Set global batch size for training.
            self.global_batch_size = (
                self.params["train_batch_size"] * num_replicas
            )

            # Get input dataset. Batch size is split evenly between replicas.
            train_dataset = self.get_train_dataset(num_replicas=num_replicas)

            with self.strategy.scope():
                # Create distributed datasets.
                train_dataset = (
                    self.strategy.experimental_distribute_dataset(
                        dataset=train_dataset
                    )
                )

                # Training block setups training, then loops through datasets.
                self.train_block(train_dataset=train_dataset)
        else:
            # Set global batch size for training.
            self.global_batch_size = self.params["train_batch_size"]

            # Get input datasets.
            train_dataset = self.get_train_dataset(num_replicas=1)

            # Training block setups training, then loops through datasets.
            self.train_block(train_dataset=train_dataset)


## cli_parser.py

In [None]:
%%writefile pca_out_of_core_distributed_module/trainer/cli_parser.py

import argparse


def parse_file_arguments(parser):
    """Registers the file related command line flags on the parser.

    Adds `--train_file_pattern` and `--output_dir` (both required) and the
    optional `--job-dir`.

    Args:
        parser: instance of `argparse.ArgumentParser`.
    """
    file_flags = (
        (
            "--train_file_pattern",
            {
                "help": "GCS location to read training data.",
                "type": str,
                "required": True,
            },
        ),
        (
            "--output_dir",
            {
                "help": "GCS location to write checkpoints and export models.",
                "type": str,
                "required": True,
            },
        ),
        (
            "--job-dir",
            {
                "help": (
                    "This model ignores this field, but it is required by "
                    "gcloud."
                ),
                "type": str,
                "default": "junk",
            },
        ),
    )

    for flag, options in file_flags:
        parser.add_argument(flag, **options)


def parse_data_arguments(parser):
    """Registers the data related command line flags on the parser.

    Adds the required `--tf_record_example_schema` plus optional flags that
    describe the image and label features.

    Args:
        parser: instance of `argparse.ArgumentParser`.
    """
    # The schema has no default, so it must be passed on the command line.
    parser.add_argument(
        "--tf_record_example_schema",
        help="Serialized TF Record Example schema.",
        type=str,
        required=True
    )

    # Remaining flags are optional: (flag, help text, type, default value).
    optional_flags = (
        ("--image_feature_name", "Name of image feature.", str, "image"),
        (
            "--image_encoding",
            "Encoding of image: raw, png, or jpeg.",
            str,
            "raw",
        ),
        ("--image_height", "Height of image.", int, 32),
        ("--image_width", "Width of image.", int, 32),
        ("--image_depth", "Depth of image.", int, 3),
        ("--label_feature_name", "Name of label feature.", str, "label"),
    )

    for flag, help_text, value_type, default_value in optional_flags:
        parser.add_argument(
            flag, help=help_text, type=value_type, default=default_value
        )


def parse_training_arguments(parser):
    """Registers the training related command line flags on the parser.

    All flags are optional. Boolean-like flags are registered as strings
    with the default "True".

    Args:
        parser: instance of `argparse.ArgumentParser`.
    """
    # One entry per flag: (flag, help text, type, default value).
    training_flags = (
        ("--tf_version", "Version of TensorFlow", float, 2.3),
        (
            "--use_graph_mode",
            "Whether to use graph mode or not (eager).",
            str,
            "True",
        ),
        (
            "--distribution_strategy",
            "Which distribution strategy to use, if any.",
            str,
            "",
        ),
        (
            "--train_dataset_length",
            "Number of examples in one epoch of training set.",
            int,
            100,
        ),
        (
            "--train_batch_size",
            "Number of examples in training batch.",
            int,
            32,
        ),
        (
            "--input_fn_autotune",
            "Whether to autotune input function performance.",
            str,
            "True",
        ),
        (
            "--save_checkpoints_steps",
            "How many steps to train before saving a checkpoint.",
            int,
            100,
        ),
        (
            "--keep_checkpoint_max",
            "Max number of checkpoints to keep.",
            int,
            100,
        ),
    )

    for flag, help_text, value_type, default_value in training_flags:
        parser.add_argument(
            flag, help=help_text, type=value_type, default=default_value
        )


def parse_resnet_arguments(parser):
    """Parses command line ResNet arguments.

    Adds the optional flags `--resnet_weights`, `--resnet_layer_name`, and
    `--preprocess_input`. The last one is registered as a string with the
    default "True".

    Args:
        parser: instance of `argparse.ArgumentParser`.
    """
    parser.add_argument(
        "--resnet_weights",
        help="The type of weights to use in Resnet, i.e. imagenet.",
        type=str,
        default="imagenet"
    )
    # The help text previously duplicated the one of `--top_k_pc`; it now
    # describes what this flag actually selects.
    parser.add_argument(
        "--resnet_layer_name",
        help="Name of the ResNet50 layer whose output is used as features.",
        type=str,
        default="conv4_block1_0_conv"
    )
    parser.add_argument(
        "--preprocess_input",
        help="Whether to preprocess input for ResNet.",
        type=str,
        default="True"
    )


def parse_pca_arguments(parser):
    """Registers the PCA related command line flags on the parser.

    Adds the optional flags `--num_cols`, `--use_sample_covariance`, and
    `--top_k_pc`.

    Args:
        parser: instance of `argparse.ArgumentParser`.
    """
    # One entry per flag: (flag, help text, type, default value).
    pca_flags = (
        (
            "--num_cols",
            "Number of dimensions for each data instance.",
            int,
            1,
        ),
        (
            "--use_sample_covariance",
            "Whether using sample or population covariance.",
            str,
            "True",
        ),
        (
            "--top_k_pc",
            "Number of top principal components to keep.",
            int,
            1,
        ),
    )

    for flag, help_text, value_type, default_value in pca_flags:
        parser.add_argument(
            flag, help=help_text, type=value_type, default=default_value
        )


def parse_command_line_arguments():
    """Builds the full argument parser and parses the command line.

    Returns:
        Dictionary containing command line arguments, keyed by argument
        name.
    """
    parser = argparse.ArgumentParser()

    # Each function registers one group of flags on the shared parser.
    argument_groups = (
        parse_file_arguments,
        parse_data_arguments,
        parse_training_arguments,
        parse_resnet_arguments,
        parse_pca_arguments,
    )
    for register_group in argument_groups:
        register_group(parser)

    # The namespace's attribute dictionary is returned directly.
    return vars(parser.parse_args())


## cli_argument_reformat.py

In [None]:
%%writefile pca_out_of_core_distributed_module/trainer/cli_argument_reformat.py
import json


def convert_string_to_bool(string):
    """Converts string to bool.

    Surrounding whitespace and letter case are ignored. The spellings
    "false", "f", "0", "no", "n", and "off" convert to False; every other
    string converts to True. A bool is returned unchanged, so converting an
    already converted value is safe.

    Args:
        string: str or bool, value to convert.

    Returns:
        Boolean conversion of string.
    """
    if isinstance(string, bool):
        return string

    false_spellings = ("false", "f", "0", "no", "n", "off")

    return string.strip().lower() not in false_spellings


def fix_arguments(arguments):
    """Converts string valued command line arguments in place.

    The schema string is decoded from JSON and the boolean-like flags are
    converted from strings to bools.

    Args:
        arguments: dict, command line arguments; modified in place.
    """
    # Semicolons are swapped for spaces before decoding (presumably they
    # stand in for spaces on the command line - confirm with the launcher).
    raw_schema = arguments["tf_record_example_schema"]
    arguments["tf_record_example_schema"] = json.loads(
        raw_schema.replace(";", " ")
    )

    # These flags arrive as strings and are turned into bools.
    boolean_flags = (
        "use_graph_mode",
        "input_fn_autotune",
        "preprocess_input",
        "use_sample_covariance",
    )
    for flag in boolean_flags:
        arguments[flag] = convert_string_to_bool(string=arguments[flag])


## task.py

In [None]:
%%writefile pca_out_of_core_distributed_module/trainer/task.py
import json
import os

from trainer import cli_argument_reformat
from trainer import cli_parser
from trainer import model


if __name__ == "__main__":
    # Read the command line flags into a plain dictionary.
    arguments = cli_parser.parse_command_line_arguments()

    # The service passes a job directory that the trainer never reads.
    for unused_key in ("job_dir", "job-dir"):
        arguments.pop(unused_key, None)

    # Convert string valued flags to their real types, in place.
    cli_argument_reformat.fix_arguments(arguments)

    # Append trial_id to path if we are doing hptuning.
    # This code can be removed if you are not using hyperparameter tuning.
    base_output_dir = arguments["output_dir"]
    tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
    trial_id = tf_config.get("task", {}).get("trial", "")
    arguments["output_dir"] = os.path.join(base_output_dir, trial_id)

    print(arguments)

    # Build the trainer from the final arguments and run the training job.
    trainer = model.TrainModel(params=arguments)
    trainer.train_model()
