diff --git a/docs/advanced_topics.rst b/docs/advanced_topics.rst index 3051598ad2..7506d2d0a2 100644 --- a/docs/advanced_topics.rst +++ b/docs/advanced_topics.rst @@ -11,5 +11,6 @@ Advanced Topics :maxdepth: 4 multiple_plans + compression_settings diff --git a/docs/compression_settings.rst b/docs/compression_settings.rst new file mode 100644 index 0000000000..f3d19e8870 --- /dev/null +++ b/docs/compression_settings.rst @@ -0,0 +1,27 @@ +.. # Copyright (C) 2021 Intel Corporation +.. # Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. + +.. _compression_settings: + +******************** +Compression Settings +******************** + +Federated Learning can enable tens to thousands of participants to work together on the same model, but with this scaling comes increased communication cost. Furthermore, large models exacerbate this problem. For this reason, we make compression a core capability of |productName|, and our framework supports several lossless and lossy compression pipelines out of the box. In general, the weights of a model are typically not robust to information loss, so no compression is applied by default to the model weights sent bidirectionally; however, the deltas between the model weights for each round are inherently more sparse and better suited for lossy compression. The following is the list of compression pipelines that |productName| currently supports: + +* ``NoCompressionPipeline``: This is the default option applied to model weights +* ``RandomShiftPipeline``: A **lossless** pipeline that randomly shifts the weights during transport +* ``STCPipeline``: A **lossy** pipeline consisting of three transformations: *Sparsity Transform* (p_sparsity=0.1), which by default retains only the (p_sparsity*100)% of values with the greatest absolute magnitude; *Ternary Transform*, which discretizes the sparse array into three buckets; and finally a *GZIP Transform*. 
+* ``SKCPipeline``: A **lossy** pipeline consisting of three transformations: *Sparsity Transform* (p_sparsity=0.1), which by default retains only the (p_sparsity*100)% of values with the greatest absolute magnitude; *KMeans Transform* (n_clusters=6), which applies the KMeans algorithm to the sparse array with *n_clusters* centroids; and finally a *GZIP Transform*. +* ``KCPipeline``: A **lossy** pipeline consisting of two transformations: *KMeans Transform* (n_clusters=6), which applies the KMeans algorithm to the original weight array with *n_clusters* centroids, and finally a *GZIP Transform*. + +We provide an example template, **keras_cnn_with_compression**, that utilizes the *KCPipeline* with 6 centroids for KMeans. To gain a better understanding of how experiments perform with greater or fewer centroids, you can modify the *n_clusters* parameter in the template's plan.yaml: + + .. code-block:: yaml + + compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml + template : openfl.pipelines.KCPipeline + settings : + n_clusters : 6 + diff --git a/openfl-workspace/default/plan/plan.yaml b/openfl-workspace/default/plan/plan.yaml index 66c7f15d19..ccfd00904f 100644 --- a/openfl-workspace/default/plan/plan.yaml +++ b/openfl-workspace/default/plan/plan.yaml @@ -31,3 +31,6 @@ assigner : tasks : defaults : plan/defaults/tasks_fast_estimator.yaml + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml diff --git a/openfl-workspace/fe_tf_adversarial_cifar/plan/plan.yaml b/openfl-workspace/fe_tf_adversarial_cifar/plan/plan.yaml index 4125f061bf..73d75f03a2 100644 --- a/openfl-workspace/fe_tf_adversarial_cifar/plan/plan.yaml +++ b/openfl-workspace/fe_tf_adversarial_cifar/plan/plan.yaml @@ -37,3 +37,6 @@ assigner : tasks : defaults : plan/defaults/tasks_fast_estimator.yaml + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml diff --git a/openfl-workspace/fe_torch_adversarial_cifar/plan/plan.yaml b/openfl-workspace/fe_torch_adversarial_cifar/plan/plan.yaml index 
4125f061bf..73d75f03a2 100644 --- a/openfl-workspace/fe_torch_adversarial_cifar/plan/plan.yaml +++ b/openfl-workspace/fe_torch_adversarial_cifar/plan/plan.yaml @@ -37,3 +37,6 @@ assigner : tasks : defaults : plan/defaults/tasks_fast_estimator.yaml + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml diff --git a/openfl-workspace/keras_cnn_mnist/plan/plan.yaml b/openfl-workspace/keras_cnn_mnist/plan/plan.yaml index eee8497df3..c177e6c3ad 100644 --- a/openfl-workspace/keras_cnn_mnist/plan/plan.yaml +++ b/openfl-workspace/keras_cnn_mnist/plan/plan.yaml @@ -37,3 +37,6 @@ assigner : tasks : defaults : plan/defaults/tasks_keras.yaml + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml diff --git a/openfl-workspace/keras_cnn_with_compression/.workspace b/openfl-workspace/keras_cnn_with_compression/.workspace new file mode 100644 index 0000000000..3c2c5d08b4 --- /dev/null +++ b/openfl-workspace/keras_cnn_with_compression/.workspace @@ -0,0 +1,2 @@ +current_plan_name: default + diff --git a/openfl-workspace/keras_cnn_with_compression/code/__init__.py b/openfl-workspace/keras_cnn_with_compression/code/__init__.py new file mode 100644 index 0000000000..f1410b1298 --- /dev/null +++ b/openfl-workspace/keras_cnn_with_compression/code/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""You may copy this file as the starting point of your own model.""" diff --git a/openfl-workspace/keras_cnn_with_compression/code/keras_cnn.py b/openfl-workspace/keras_cnn_with_compression/code/keras_cnn.py new file mode 100644 index 0000000000..abbe8b8320 --- /dev/null +++ b/openfl-workspace/keras_cnn_with_compression/code/keras_cnn.py @@ -0,0 +1,85 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" + +import tensorflow.keras as ke + +from tensorflow.keras import Sequential 
+from tensorflow.keras.layers import Conv2D, Flatten, Dense + +from openfl.federated import KerasTaskRunner + + +class KerasCNN(KerasTaskRunner): + """A basic convolutional neural network model.""" + + def __init__(self, **kwargs): + """ + Initialize. + + Args: + **kwargs: Additional parameters to pass to the function + """ + super().__init__(**kwargs) + + self.model = self.build_model(self.feature_shape, self.data_loader.num_classes, **kwargs) + + self.initialize_tensorkeys_for_functions() + + self.model.summary(print_fn=self.logger.info) + + if self.data_loader is not None: + self.logger.info(f'Train Set Size : {self.get_train_data_size()}') + self.logger.info(f'Valid Set Size : {self.get_valid_data_size()}') + + def build_model(self, + input_shape, + num_classes, + conv_kernel_size=(4, 4), + conv_strides=(2, 2), + conv1_channels_out=16, + conv2_channels_out=32, + final_dense_inputsize=100, + **kwargs): + """ + Define the model architecture. + + Args: + input_shape (numpy.ndarray): The shape of the data + num_classes (int): The number of classes of the dataset + + Returns: + tensorflow.python.keras.engine.sequential.Sequential: The model defined in Keras + + """ + model = Sequential() + + model.add(Conv2D(conv1_channels_out, + kernel_size=conv_kernel_size, + strides=conv_strides, + activation='relu', + input_shape=input_shape)) + + model.add(Conv2D(conv2_channels_out, + kernel_size=conv_kernel_size, + strides=conv_strides, + activation='relu')) + + model.add(Flatten()) + + model.add(Dense(final_dense_inputsize, activation='relu')) + + model.add(Dense(num_classes, activation='softmax')) + + model.compile(loss=ke.losses.categorical_crossentropy, + optimizer=ke.optimizers.Adam(), + metrics=['accuracy']) + + # initialize the optimizer variables + opt_vars = model.optimizer.variables() + + for v in opt_vars: + v.initializer.run(session=self.sess) + + return model diff --git a/openfl-workspace/keras_cnn_with_compression/code/mnist_utils.py 
b/openfl-workspace/keras_cnn_with_compression/code/mnist_utils.py new file mode 100644 index 0000000000..295c173bcd --- /dev/null +++ b/openfl-workspace/keras_cnn_with_compression/code/mnist_utils.py @@ -0,0 +1,118 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" + +import numpy as np + +from logging import getLogger +from tensorflow.python.keras.utils.data_utils import get_file + +logger = getLogger(__name__) + + +def one_hot(labels, classes): + """ + One Hot encode a vector. + + Args: + labels (list): List of labels to onehot encode + classes (int): Total number of categorical classes + + Returns: + np.array: Matrix of one-hot encoded labels + """ + return np.eye(classes)[labels] + + +def _load_raw_datashards(shard_num, collaborator_count): + """ + Load the raw data by shard. + + Returns tuples of the dataset shard divided into training and validation. + + Args: + shard_num (int): The shard number to use + collaborator_count (int): The number of collaborators in the federation + + Returns: + 2 tuples: (image, label) of the training, validation dataset + """ + origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' + path = get_file('mnist.npz', + origin=origin_folder + 'mnist.npz', + file_hash='731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1') + + with np.load(path) as f: + # get all of mnist + X_train_tot = f['x_train'] + y_train_tot = f['y_train'] + + X_valid_tot = f['x_test'] + y_valid_tot = f['y_test'] + + # create the shards + shard_num = int(shard_num) + X_train = X_train_tot[shard_num::collaborator_count] + y_train = y_train_tot[shard_num::collaborator_count] + + X_valid = X_valid_tot[shard_num::collaborator_count] + y_valid = y_valid_tot[shard_num::collaborator_count] + + return (X_train, y_train), (X_valid, y_valid) + + +def load_mnist_shard(shard_num, collaborator_count, categorical=True, + 
channels_last=True, **kwargs): + """ + Load the MNIST dataset. + + Args: + shard_num (int): The shard to use from the dataset + collaborator_count (int): The number of collaborators in the federation + categorical (bool): True = convert the labels to one-hot encoded + vectors (Default = True) + channels_last (bool): True = The input images have the channels + last (Default = True) + **kwargs: Additional parameters to pass to the function + + Returns: + list: The input shape + int: The number of classes + numpy.ndarray: The training data + numpy.ndarray: The training labels + numpy.ndarray: The validation data + numpy.ndarray: The validation labels + """ + img_rows, img_cols = 28, 28 + num_classes = 10 + + (X_train, y_train), (X_valid, y_valid) = _load_raw_datashards( + shard_num, collaborator_count + ) + + if channels_last: + X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1) + X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 1) + input_shape = (img_rows, img_cols, 1) + else: + X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols) + X_valid = X_valid.reshape(X_valid.shape[0], 1, img_rows, img_cols) + input_shape = (1, img_rows, img_cols) + + X_train = X_train.astype('float32') + X_valid = X_valid.astype('float32') + X_train /= 255 + X_valid /= 255 + + logger.info(f'MNIST > X_train Shape : {X_train.shape}') + logger.info(f'MNIST > y_train Shape : {y_train.shape}') + logger.info(f'MNIST > Train Samples : {X_train.shape[0]}') + logger.info(f'MNIST > Valid Samples : {X_valid.shape[0]}') + + if categorical: + # convert class vectors to binary class matrices + y_train = one_hot(y_train, num_classes) + y_valid = one_hot(y_valid, num_classes) + + return input_shape, num_classes, X_train, y_train, X_valid, y_valid diff --git a/openfl-workspace/keras_cnn_with_compression/code/tfmnist_inmemory.py b/openfl-workspace/keras_cnn_with_compression/code/tfmnist_inmemory.py new file mode 100644 index 0000000000..e4fad049cb --- /dev/null +++ 
b/openfl-workspace/keras_cnn_with_compression/code/tfmnist_inmemory.py @@ -0,0 +1,40 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" + +from openfl.federated import TensorFlowDataLoader + +from .mnist_utils import load_mnist_shard + + +class TensorFlowMNISTInMemory(TensorFlowDataLoader): + """TensorFlow Data Loader for MNIST Dataset.""" + + def __init__(self, data_path, batch_size, **kwargs): + """ + Initialize. + + Args: + data_path: File path for the dataset + batch_size (int): The batch size for the data loader + **kwargs: Additional arguments, passed to super init and load_mnist_shard + """ + super().__init__(batch_size, **kwargs) + + # TODO: We should be downloading the dataset shard into a directory + # TODO: There needs to be a method to ask how many collaborators and + # what index/rank is this collaborator. + # Then we have a way to automatically shard based on rank and size of + # collaborator list. + + _, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard( + shard_num=int(data_path), **kwargs + ) + + self.X_train = X_train + self.y_train = y_train + self.X_valid = X_valid + self.y_valid = y_valid + + self.num_classes = num_classes diff --git a/openfl-workspace/keras_cnn_with_compression/plan/cols.yaml b/openfl-workspace/keras_cnn_with_compression/plan/cols.yaml new file mode 100644 index 0000000000..61ffdd470b --- /dev/null +++ b/openfl-workspace/keras_cnn_with_compression/plan/cols.yaml @@ -0,0 +1,5 @@ +# Copyright (C) 2020 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. 
+ +collaborators: + \ No newline at end of file diff --git a/openfl-workspace/keras_cnn_with_compression/plan/data.yaml b/openfl-workspace/keras_cnn_with_compression/plan/data.yaml new file mode 100644 index 0000000000..8d07ca8b3b --- /dev/null +++ b/openfl-workspace/keras_cnn_with_compression/plan/data.yaml @@ -0,0 +1,7 @@ +# Copyright (C) 2020 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. + +# collaborator_name,data_directory_path +one,1 + + diff --git a/openfl-workspace/keras_cnn_with_compression/plan/defaults b/openfl-workspace/keras_cnn_with_compression/plan/defaults new file mode 100644 index 0000000000..fb82f9c5b6 --- /dev/null +++ b/openfl-workspace/keras_cnn_with_compression/plan/defaults @@ -0,0 +1,2 @@ +../../workspace/plan/defaults + diff --git a/openfl-workspace/keras_cnn_with_compression/plan/plan.yaml b/openfl-workspace/keras_cnn_with_compression/plan/plan.yaml new file mode 100644 index 0000000000..ffe3b835c5 --- /dev/null +++ b/openfl-workspace/keras_cnn_with_compression/plan/plan.yaml @@ -0,0 +1,47 @@ +# Copyright (C) 2020 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. 
+ +aggregator : + defaults : plan/defaults/aggregator.yaml + template : openfl.component.Aggregator + settings : + init_state_path : save/keras_cnn_mnist_init.pbuf + best_state_path : save/keras_cnn_mnist_best.pbuf + last_state_path : save/keras_cnn_mnist_last.pbuf + db_store_rounds: 2 + rounds_to_train : 10 + +collaborator : + defaults : plan/defaults/collaborator.yaml + template : openfl.component.Collaborator + settings : + db_store_rounds: 2 + delta_updates : true + opt_treatment : RESET + +data_loader : + defaults : plan/defaults/data_loader.yaml + template : code.tfmnist_inmemory.TensorFlowMNISTInMemory + settings : + collaborator_count : 2 + data_group_name : mnist + batch_size : 256 + +task_runner : + defaults : plan/defaults/task_runner.yaml + template : code.keras_cnn.KerasCNN + +network : + defaults : plan/defaults/network.yaml + +assigner : + defaults : plan/defaults/assigner.yaml + +tasks : + defaults : plan/defaults/tasks_keras.yaml + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml + template : openfl.pipelines.KCPipeline + settings : + n_clusters : 6 diff --git a/openfl-workspace/keras_cnn_with_compression/requirements.txt b/openfl-workspace/keras_cnn_with_compression/requirements.txt new file mode 100644 index 0000000000..60f6e6c9fa --- /dev/null +++ b/openfl-workspace/keras_cnn_with_compression/requirements.txt @@ -0,0 +1 @@ +tensorflow==2.3.1 diff --git a/openfl-workspace/keras_nlp/plan/plan.yaml b/openfl-workspace/keras_nlp/plan/plan.yaml index 939830b8ea..a21bc92e93 100644 --- a/openfl-workspace/keras_nlp/plan/plan.yaml +++ b/openfl-workspace/keras_nlp/plan/plan.yaml @@ -42,3 +42,6 @@ assigner : tasks : defaults : plan/defaults/tasks_keras.yaml + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml diff --git a/openfl-workspace/tf_2dunet/plan/plan.yaml b/openfl-workspace/tf_2dunet/plan/plan.yaml index da653a23c6..c4cb9a2bd0 100644 --- a/openfl-workspace/tf_2dunet/plan/plan.yaml +++ 
b/openfl-workspace/tf_2dunet/plan/plan.yaml @@ -38,3 +38,6 @@ assigner : tasks : defaults : plan/defaults/tasks_tensorflow.yaml + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml diff --git a/openfl-workspace/tf_cnn_histology/plan/plan.yaml b/openfl-workspace/tf_cnn_histology/plan/plan.yaml index 6faeb738fb..5cea6d60dc 100644 --- a/openfl-workspace/tf_cnn_histology/plan/plan.yaml +++ b/openfl-workspace/tf_cnn_histology/plan/plan.yaml @@ -8,6 +8,7 @@ aggregator : init_state_path : save/tf_cnn_histology_init.pbuf last_state_path : save/tf_cnn_histology_latest.pbuf best_state_path : save/tf_cnn_histology_best.pbuf + db_store_rounds: 2 rounds_to_train : 10 collaborator : @@ -15,6 +16,7 @@ collaborator : template : openfl.component.Collaborator settings : delta_updates : true + db_store_rounds: 2 opt_treatment : RESET data_loader : diff --git a/openfl-workspace/torch_cnn_histology/plan/plan.yaml b/openfl-workspace/torch_cnn_histology/plan/plan.yaml index 14e05b668b..14af8f5f81 100644 --- a/openfl-workspace/torch_cnn_histology/plan/plan.yaml +++ b/openfl-workspace/torch_cnn_histology/plan/plan.yaml @@ -36,3 +36,6 @@ tasks: assigner: defaults: plan/defaults/assigner.yaml + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml diff --git a/openfl-workspace/torch_cnn_mnist/plan/plan.yaml b/openfl-workspace/torch_cnn_mnist/plan/plan.yaml index 326130f45a..a91ca0fedc 100644 --- a/openfl-workspace/torch_cnn_mnist/plan/plan.yaml +++ b/openfl-workspace/torch_cnn_mnist/plan/plan.yaml @@ -38,3 +38,5 @@ assigner : tasks : defaults : plan/defaults/tasks_torch.yaml +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml diff --git a/openfl-workspace/torch_unet_kvasir/plan/plan.yaml b/openfl-workspace/torch_unet_kvasir/plan/plan.yaml index bb8f652efc..583b23eba9 100644 --- a/openfl-workspace/torch_unet_kvasir/plan/plan.yaml +++ b/openfl-workspace/torch_unet_kvasir/plan/plan.yaml @@ -55,3 +55,6 @@ tasks : apply : 
local metrics : - dice_coef + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml diff --git a/openfl-workspace/workspace/plan/defaults/compression_pipeline.yaml b/openfl-workspace/workspace/plan/defaults/compression_pipeline.yaml new file mode 100644 index 0000000000..a508f94fd2 --- /dev/null +++ b/openfl-workspace/workspace/plan/defaults/compression_pipeline.yaml @@ -0,0 +1 @@ +template: openfl.pipelines.NoCompressionPipeline diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index d190ac7c81..9efc5ec159 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -324,27 +324,27 @@ def get_aggregated_tensor(self, collaborator_name, tensor_name, named_tensor : protobuf NamedTensor the tensor requested by the collaborator """ - self.logger.debug( - 'Retrieving aggregated tensor {} for collaborator {}'.format( - tensor_name, collaborator_name)) + self.logger.debug(f'Retrieving aggregated tensor {tensor_name},{round_number},{tags} \ + for collaborator {collaborator_name}') if 'compressed' in tags or require_lossless: compress_lossless = True + else: + compress_lossless = False # TODO the TensorDB doesn't support compressed data yet. # The returned tensor will # be recompressed anyway. 
if 'compressed' in tags: tags.remove('compressed') + if 'lossy_compressed' in tags: + tags.remove('lossy_compressed') tensor_key = TensorKey( tensor_name, self.uuid, round_number, report, tuple(tags) ) tensor_name, origin, round_number, report, tags = tensor_key - # send_model_deltas = False - compress_lossless = False - if 'aggregated' in tags and 'delta' in tags and round_number != 0: # send_model_deltas = True agg_tensor_key = TensorKey( @@ -353,7 +353,7 @@ def get_aggregated_tensor(self, collaborator_name, tensor_name, else: agg_tensor_key = tensor_key - nparray = self.tensor_db.get_tensor_from_cache(tensor_key) + nparray = self.tensor_db.get_tensor_from_cache(agg_tensor_key) if nparray is None: raise ValueError("Aggregator does not have an aggregated tensor" @@ -380,14 +380,16 @@ def _nparray_to_named_tensor(self, tensor_key, nparray, send_model_deltas, """ tensor_name, origin, round_number, report, tags = tensor_key # if we have an aggregated tensor, we can make a delta - if 'aggregated' in tensor_name and send_model_deltas: + if 'aggregated' in tags and send_model_deltas: # Should get the pretrained model to create the delta. 
If training # has happened, Model should already be stored in the TensorDB - model_nparray = self.tensor_db.get_tensor_from_cache( - TensorKey(tensor_name, - origin, - round_number - 1, - ('model',))) + model_tk = TensorKey(tensor_name, + origin, + round_number - 1, + report, + ('model',)) + + model_nparray = self.tensor_db.get_tensor_from_cache(model_tk) assert (model_nparray is not None), ( "The original model layer should be present if the latest " @@ -535,7 +537,7 @@ def _process_named_tensor(self, named_tensor, collaborator_name): tuple(named_tensor.tags) ) tensor_name, origin, round_number, report, tags = tensor_key - assert ('compressed' in tags or 'lossy_decompressed' in tags), ( + assert ('compressed' in tags or 'lossy_compressed' in tags), ( 'Named tensor {} is not compressed'.format(tensor_key)) if 'compressed' in tags: dec_tk, decompressed_nparray = self.tensor_codec.decompress( @@ -558,7 +560,8 @@ def _process_named_tensor(self, named_tensor, collaborator_name): dec_tk, decompressed_nparray = self.tensor_codec.decompress( tensor_key, data=raw_bytes, - transformer_metadata=metadata + transformer_metadata=metadata, + require_lossless=False ) dec_name, dec_origin, dec_round_num, dec_report, dec_tags = dec_tk if type(dec_tags) == str: @@ -654,7 +657,6 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result agg_results, base_model_nparray ) - self.tensor_db.cache_tensor({delta_tk: delta_nparray}) else: # This condition is possible for base model # optimizer states (i.e. Adam/iter:0, SGD, etc.) 
@@ -679,8 +681,11 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result metadata ) + self.tensor_db.cache_tensor({decompressed_delta_tk: decompressed_delta_nparray}) + # Apply delta (unless delta couldn't be created) if base_model_nparray is not None: + self.logger.debug(f'Applying delta for layer {decompressed_delta_tk[0]}') new_model_tk, new_model_nparray = self.tensor_codec.apply_delta( decompressed_delta_tk, decompressed_delta_nparray, @@ -704,9 +709,6 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result # Finally, cache the updated model tensor self.tensor_db.cache_tensor({final_model_tk: new_model_nparray}) - # self.logger.debug('TensorDB contents after - # training round {}: - # {}'.format(self.round_number,self.tensor_db)) def _compute_validation_related_task_metrics(self, task_name): """ diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py index d535febc03..a0fb4b62c5 100644 --- a/openfl/component/collaborator/collaborator.py +++ b/openfl/component/collaborator/collaborator.py @@ -65,10 +65,10 @@ def __init__(self, federation_uuid, client, task_runner, - tensor_pipe, task_config, opt_treatment=OptTreatment.RESET, delta_updates=False, + compression_pipeline=None, db_store_rounds=1, **kwargs): """Initialize.""" @@ -82,8 +82,8 @@ def __init__(self, self.aggregator_uuid = aggregator_uuid self.federation_uuid = federation_uuid - self.tensor_pipe = tensor_pipe or NoCompressionPipeline() - self.tensor_codec = TensorCodec(self.tensor_pipe) + self.compression_pipeline = compression_pipeline or NoCompressionPipeline() + self.tensor_codec = TensorCodec(self.compression_pipeline) self.tensor_db = TensorDB() self.db_store_rounds = db_store_rounds @@ -266,8 +266,6 @@ def get_data_for_tensorkey(self, tensor_key): tensor_dependencies = self.tensor_codec.find_dependencies( tensor_key, self.delta_updates ) - # self.logger.info('tensor_dependencies = {}'.format( 
- # tensor_dependencies)) if len(tensor_dependencies) > 0: # Resolve dependencies # tensor_dependencies[0] corresponds to the prior version @@ -285,18 +283,20 @@ def get_data_for_tensorkey(self, tensor_key): new_model_tk, nparray = self.tensor_codec.apply_delta( tensor_dependencies[1], uncompressed_delta, - prior_model_layer - ) - self.logger.debug('Applied delta to tensor {}'.format( - tensor_dependencies[0][0]) + prior_model_layer, + creates_model=True, ) + self.tensor_db.cache_tensor({new_model_tk: nparray}) else: + self.logger.info('Count not find previous model layer.' + 'Fetching latest layer from aggregator') # The original model tensor should be fetched from client nparray = self.get_aggregated_tensor_from_aggregator( - tensor_key + tensor_key, + require_lossless=True ) elif 'model' in tags: - # Pulling the model for the first time or + # Pulling the model for the first time nparray = self.get_aggregated_tensor_from_aggregator( tensor_key, require_lossless=True @@ -344,8 +344,6 @@ def get_aggregated_tensor_from_aggregator(self, tensor_key, # cache this tensor self.tensor_db.cache_tensor({tensor_key: nparray}) - # self.logger.info('Printing updated TensorDB: {}'.format( - # self.tensor_db)) return nparray @@ -413,6 +411,7 @@ def nparray_to_named_tensor(self, tensor_key, nparray): ) delta_comp_tensor_key, delta_comp_nparray, metadata = \ self.tensor_codec.compress(delta_tensor_key, delta_nparray) + named_tensor = utils.construct_named_tensor( delta_comp_tensor_key, delta_comp_nparray, diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index dfa2e6a285..0a0b27a1da 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -255,6 +255,7 @@ def get_aggregator(self, tensor_dict=None): defaults[SETTINGS]['federation_uuid'] = self.federation_uuid defaults[SETTINGS]['authorized_cols'] = self.authorized_cols defaults[SETTINGS]['assigner'] = self.get_assigner() + defaults[SETTINGS]['compression_pipeline'] = 
self.get_tensor_pipe() if self.aggregator_ is None: self.aggregator_ = Plan.Build(**defaults, initial_tensor_dict=tensor_dict) @@ -385,7 +386,7 @@ def get_collaborator(self, collaborator_name, data_loader = self.get_data_loader(collaborator_name) defaults[SETTINGS]['task_runner'] = self.get_task_runner(data_loader) - defaults[SETTINGS]['tensor_pipe'] = self.get_tensor_pipe() + defaults[SETTINGS]['compression_pipeline'] = self.get_tensor_pipe() defaults[SETTINGS]['task_config'] = self.config.get('tasks', {}) if client is not None: defaults[SETTINGS]['client'] = client diff --git a/openfl/pipelines/kc_pipeline.py b/openfl/pipelines/kc_pipeline.py index 54c437603e..912c9bada3 100644 --- a/openfl/pipelines/kc_pipeline.py +++ b/openfl/pipelines/kc_pipeline.py @@ -38,10 +38,16 @@ def forward(self, data, **kwargs): # clustering k_means = cluster.KMeans(n_clusters=self.n_cluster, n_init=self.n_cluster) data = data.reshape((-1, 1)) - k_means.fit(data) - quantized_values = k_means.cluster_centers_.squeeze() - indices = k_means.labels_ - quant_array = np.choose(indices, quantized_values) + if data.shape[0] >= self.n_cluster: + k_means = cluster.KMeans( + n_clusters=self.n_cluster, n_init=self.n_cluster) + k_means.fit(data) + quantized_values = k_means.cluster_centers_.squeeze() + indices = k_means.labels_ + quant_array = np.choose(indices, quantized_values) + else: + quant_array = data + int_array, int2float_map = self._float_to_int(quant_array) metadata['int_to_float'] = int2float_map diff --git a/openfl/pipelines/skc_pipeline.py b/openfl/pipelines/skc_pipeline.py index 3e77378f61..ce6b62bc56 100644 --- a/openfl/pipelines/skc_pipeline.py +++ b/openfl/pipelines/skc_pipeline.py @@ -32,10 +32,9 @@ def forward(self, data, **kwargs): data: an numpy array from the model tensor_dict. Returns: - condensed_data: an numpy array being sparsified. + sparse_data: a flattened, sparse representation of the input tensor metadata: dictionary to store a list of meta information. 
""" - self.p = 1 metadata = {'int_list': list(data.shape)} # sparsification data = data.astype(np.float32) @@ -43,13 +42,9 @@ def forward(self, data, **kwargs): n_elements = flatten_data.shape[0] k_op = int(np.ceil(n_elements * self.p)) topk, topk_indices = self._topk_func(flatten_data, k_op) - # - condensed_data = topk sparse_data = np.zeros(flatten_data.shape) sparse_data[topk_indices] = topk - nonzero_element_bool_indices = sparse_data != 0.0 - metadata['bool_list'] = list(nonzero_element_bool_indices) - return condensed_data, metadata + return sparse_data, metadata def backward(self, data, metadata, **kwargs): """Recover data array with the right shape and numerical type. @@ -64,10 +59,7 @@ def backward(self, data, metadata, **kwargs): """ data = data.astype(np.float32) data_shape = metadata['int_list'] - nonzero_element_bool_indices = list(metadata['bool_list']) - recovered_data = np.zeros(data_shape).reshape(-1).astype(np.float32) - recovered_data[nonzero_element_bool_indices] = data - recovered_data = recovered_data.reshape(data_shape) + recovered_data = data.reshape(data_shape) return recovered_data @staticmethod @@ -115,12 +107,15 @@ def forward(self, data, **kwargs): """ # clustering data = data.reshape((-1, 1)) - k_means = cluster.KMeans( - n_clusters=self.n_cluster, n_init=self.n_cluster) - k_means.fit(data) - quantized_values = k_means.cluster_centers_.squeeze() - indices = k_means.labels_ - quant_array = np.choose(indices, quantized_values) + if data.shape[0] >= self.n_cluster: + k_means = cluster.KMeans( + n_clusters=self.n_cluster, n_init=self.n_cluster) + k_means.fit(data) + quantized_values = k_means.cluster_centers_.squeeze() + indices = k_means.labels_ + quant_array = np.choose(indices, quantized_values) + else: + quant_array = data int_array, int2float_map = self._float_to_int(quant_array) metadata = {'int_to_float': int2float_map} int_array = int_array.reshape(-1) @@ -209,11 +204,11 @@ def backward(self, data, metadata, **kwargs): class 
SKCPipeline(TransformationPipeline): """A pipeline class to compress data lossly using sparsity and k-means methods.""" - def __init__(self, p_sparsity=0.01, n_clusters=6, **kwargs): + def __init__(self, p_sparsity=0.1, n_clusters=6, **kwargs): """Initialize a pipeline of transformers. Args: - p_sparsity (float): Sparsity factor (Default=0.01) + p_sparsity (float): Sparsity factor (Default=0.1) n_cluster (int): Number of K-Means clusters (Default=6) Returns: diff --git a/openfl/pipelines/stc_pipeline.py b/openfl/pipelines/stc_pipeline.py index 0eb8efd298..a5f1f97914 100644 --- a/openfl/pipelines/stc_pipeline.py +++ b/openfl/pipelines/stc_pipeline.py @@ -22,13 +22,14 @@ def __init__(self, p=0.01): self.p = p def forward(self, data, **kwargs): - """Sparsify data and pass over only non-sparsified elements by reducing the array size. + """ + Sparsify data and pass over only non-sparsified elements by reducing the array size. Args: - data: an numpy array from the model tensor_dict + data: an numpy array from the model tensor_dict. Returns: - condensed_data: an numpy array being sparsified. + sparse_data: a flattened, sparse representation of the input tensor metadata: dictionary to store a list of meta information. """ metadata = {'int_list': list(data.shape)} @@ -38,31 +39,24 @@ def forward(self, data, **kwargs): n_elements = flatten_data.shape[0] k_op = int(np.ceil(n_elements * self.p)) topk, topk_indices = self._topk_func(flatten_data, k_op) - # - condensed_data = topk sparse_data = np.zeros(flatten_data.shape) sparse_data[topk_indices] = topk - nonzero_element_bool_indices = sparse_data != 0.0 - metadata['bool_list'] = list(nonzero_element_bool_indices) - return condensed_data, metadata - # return sparse_data, metadata + return sparse_data, metadata def backward(self, data, metadata, **kwargs): """Recover data array with the right shape and numerical type. Args: data: an numpy array with non-zero values. 
- metadata: dictionary to contain information for recovering back to original data array. + metadata: dictionary to contain information for recovering back + to original data array. Returns: recovered_data: an numpy array with original shape. """ data = data.astype(np.float32) data_shape = metadata['int_list'] - nonzero_element_bool_indices = list(metadata['bool_list']) - recovered_data = np.zeros(data_shape).reshape(-1).astype(np.float32) - recovered_data[nonzero_element_bool_indices] = data - recovered_data = recovered_data.reshape(data_shape) + recovered_data = data.reshape(data_shape) return recovered_data @staticmethod @@ -203,7 +197,7 @@ def backward(self, data, metadata, **kwargs): class STCPipeline(TransformationPipeline): """A pipeline class to compress data lossly using sparsity and ternerization methods.""" - def __init__(self, p_sparsity=0.01, n_clusters=6, **kwargs): + def __init__(self, p_sparsity=0.1, n_clusters=6, **kwargs): """Initialize a pipeline of transformers. Args: diff --git a/openfl/pipelines/tensor_codec.py b/openfl/pipelines/tensor_codec.py index fb226c5d4b..f1faef658b 100644 --- a/openfl/pipelines/tensor_codec.py +++ b/openfl/pipelines/tensor_codec.py @@ -179,7 +179,7 @@ def generate_delta(tensor_key, nparray, base_model_nparray): return delta_tensor_key, nparray - base_model_nparray @staticmethod - def apply_delta(tensor_key, delta, base_model_nparray): + def apply_delta(tensor_key, delta, base_model_nparray, creates_model=False): """ Add delta to the nparray. 
@@ -191,6 +191,8 @@ def apply_delta(tensor_key, delta, base_model_nparray): old model base_model_nparray: The nparray that corresponds to the prior weights + creates_model: If flag is set, the tensorkey returned + will correspond to the aggregator model Returns: new_model_tensor_key: Latest model layer tensorkey @@ -205,7 +207,7 @@ def apply_delta(tensor_key, delta, base_model_nparray): # assert('model' in tensor_key[3]), 'The tensorkey should be provided # from the base model' # Aggregator UUID has the prefix 'aggregator' - if 'aggregator' in origin: + if 'aggregator' in origin and not creates_model: tags = list(tags) tags.remove('delta') new_tags = tuple(tags) diff --git a/tests/openfl/component/collaborator/test_collaborator.py b/tests/openfl/component/collaborator/test_collaborator.py index 2e7badd59c..5f67b2b04b 100644 --- a/tests/openfl/component/collaborator/test_collaborator.py +++ b/tests/openfl/component/collaborator/test_collaborator.py @@ -15,7 +15,7 @@ def collaborator_mock(): """Initialize the collaborator mock.""" col = Collaborator('col1', 'some_uuid', 'federation_uuid', - mock.Mock(), mock.Mock(), None, mock.Mock(), opt_treatment='RESET') + mock.Mock(), mock.Mock(), mock.Mock(), opt_treatment='RESET') col.tensor_db = mock.Mock() return col